diff --git a/zluda_dump/Cargo.toml b/zluda_dump/Cargo.toml index 80c7ddc..c88dca7 100644 --- a/zluda_dump/Cargo.toml +++ b/zluda_dump/Cargo.toml @@ -12,6 +12,8 @@ crate-type = ["cdylib"] ptx = { path = "../ptx" } lz4-sys = "1.9" regex = "1.4" +dynasm = "1.1" +dynasmrt = "1.1" [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["libloaderapi", "debugapi", "std"] } diff --git a/zluda_dump/src/cuda.rs b/zluda_dump/src/cuda.rs index 3f78b14..d715689 100644 --- a/zluda_dump/src/cuda.rs +++ b/zluda_dump/src/cuda.rs @@ -103,7 +103,7 @@ pub struct CUgraphExec_st { } pub type CUgraphExec = *mut CUgraphExec_st; #[repr(C)] -#[derive(Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] pub struct CUuuid_st { pub bytes: [::std::os::raw::c_uchar; 16usize], } diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs index 4ea449c..eecd573 100644 --- a/zluda_dump/src/lib.rs +++ b/zluda_dump/src/lib.rs @@ -191,7 +191,10 @@ unsafe fn record_module_image(module: CUmodule, image: &str) { unsafe fn try_dump_module_image(image: &str) -> Result<(), Box> { let mut dump_path = get_dump_dir()?; - dump_path.push(format!("module_{:04}.ptx", MODULES.as_ref().unwrap().len() - 1)); + dump_path.push(format!( + "module_{:04}.ptx", + MODULES.as_ref().unwrap().len() - 1 + )); let mut file = File::create(dump_path)?; file.write_all(image.as_bytes())?; Ok(()) @@ -665,9 +668,8 @@ const CUDART_INTERFACE_GUID: CUuuid = CUuuid { ], }; -static mut CUDART_INTERFACE_VTABLE: Vec<*const c_void> = Vec::new(); -const GET_MODULE_FROM_CUBIN_OFFSET: usize = 1; -const GET_MODULE_FROM_CUBIN_EXT_OFFSET: usize = 6; +static mut OVERRIDEN_INTERFACE_VTABLES: Option, Vec<*const c_void>>> = None; + static mut ORIGINAL_GET_MODULE_FROM_CUBIN: Option< unsafe extern "system" fn( result: *mut CUmodule, @@ -683,44 +685,104 @@ static mut ORIGINAL_GET_MODULE_FROM_CUBIN_EXT: Option< ) -> CUresult, > = None; +unsafe extern "stdcall" fn report_unknown_export_table_call( + export_table: *const CUuuid, + idx: usize, +) { + let guid = (*export_table).bytes; + os_log!("Call to an unsupported export table function: {{{:02X}{:02X}{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}::{}", guid[0], guid[1], guid[2], guid[3], guid[4], guid[5], guid[6], guid[7], guid[8], guid[9], guid[10], guid[11], guid[12], guid[13], guid[14], guid[15], idx); +} + #[allow(non_snake_case)] pub unsafe fn cuGetExportTable( ppExportTable: *mut *const ::std::os::raw::c_void, pExportTableId: *const CUuuid, cont: impl FnOnce(*mut *const ::std::os::raw::c_void, *const CUuuid) -> CUresult, ) -> CUresult { - if *pExportTableId == CUDART_INTERFACE_GUID { - if CUDART_INTERFACE_VTABLE.len() == 0 { - let mut base_table = ptr::null(); - let base_result = cont(&mut base_table, pExportTableId); - if base_result != CUresult::CUDA_SUCCESS { - return base_result; + let guid = (*pExportTableId).bytes; + os_log!("Requested export table id: {{{:02X}{:02X}{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}", guid[0], guid[1], guid[2], guid[3], guid[4], guid[5], guid[6], guid[7], guid[8], guid[9], guid[10], guid[11], guid[12], guid[13], guid[14], guid[15]); + let result = cont(ppExportTable, pExportTableId); + if result == CUresult::CUDA_SUCCESS { + override_export_table(ppExportTable, pExportTableId); + } + result +} + +unsafe fn override_export_table( + export_table_ptr: *mut *const ::std::os::raw::c_void, + export_table_id: *const CUuuid, +) { + let overrides_map = OVERRIDEN_INTERFACE_VTABLES.get_or_insert_with(|| HashMap::new()); + if overrides_map.contains_key(&*export_table_id) { + return; + } + let export_table = (*export_table_ptr) as *mut *const c_void; + let boxed_guid = Box::new(*export_table_id); + let byte_length: usize = *(export_table as *const usize); + let mut override_table = Vec::new(); + if byte_length < 0x10000 { + override_table.push(byte_length as *const _); + let length = byte_length / mem::size_of::(); + for i in 1..length { + let current_fn = export_table.add(i); + if (*current_fn as usize) == usize::max_value() { + override_table.push(usize::max_value() as *const _); + break; } - let len = *(base_table as *const usize); - CUDART_INTERFACE_VTABLE = vec![ptr::null(); len]; - ptr::copy_nonoverlapping( - base_table as *const _, - CUDART_INTERFACE_VTABLE.as_mut_ptr(), - len, - ); - if GET_MODULE_FROM_CUBIN_EXT_OFFSET >= len { - return CUresult::CUDA_ERROR_UNKNOWN; - } - ORIGINAL_GET_MODULE_FROM_CUBIN = - mem::transmute(CUDART_INTERFACE_VTABLE[GET_MODULE_FROM_CUBIN_OFFSET]); - CUDART_INTERFACE_VTABLE[GET_MODULE_FROM_CUBIN_OFFSET] = - get_module_from_cubin as *const _; - ORIGINAL_GET_MODULE_FROM_CUBIN_EXT = - mem::transmute(CUDART_INTERFACE_VTABLE[GET_MODULE_FROM_CUBIN_EXT_OFFSET]); - CUDART_INTERFACE_VTABLE[GET_MODULE_FROM_CUBIN_EXT_OFFSET] = - get_module_from_cubin_ext as *const _; + override_table.push(get_export_override_fn(*current_fn, &*boxed_guid, i)); } - *ppExportTable = CUDART_INTERFACE_VTABLE.as_ptr() as *const _; - return CUresult::CUDA_SUCCESS; } else { - let guid = (*pExportTableId).bytes; - os_log!("Unsupported export table id: {{{:02X}{:02X}{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}", guid[0], guid[1], guid[2], guid[3], guid[4], guid[5], guid[6], guid[7], guid[8], guid[9], guid[10], guid[11], guid[12], guid[13], guid[14], guid[15]); - cont(ppExportTable, pExportTableId) + let mut i = 0; + loop { + let current_fn = export_table.add(i); + if (*current_fn as usize) == usize::max_value() { + override_table.push(usize::max_value() as *const _); + break; + } + override_table.push(get_export_override_fn(*current_fn, &*boxed_guid, i)); + i += 1; + } + } + *export_table_ptr = override_table.as_ptr() as *const _; + overrides_map.insert(boxed_guid, override_table); +} + +const TOOLS_RUNTIME_CALLBACK_HOOKS_GUID: CUuuid = CUuuid { + bytes: [ + 0xa0, 0x94, 0x79, 0x8c, 0x2e, 0x74, 0x2e, 0x74, 0x93, 0xf2, 0x08, 0x00, 0x20, 0x0c, 0x0a, + 0x66, + ], +}; + +const CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID: CUuuid = CUuuid { + bytes: [ + 0xc6, 0x93, 0x33, 0x6e, 0x11, 0x21, 0xdf, 0x11, 0xa8, 0xc3, 0x68, 0xf3, 0x55, 0xd8, 0x95, + 0x93, + ], +}; + +unsafe fn get_export_override_fn( + original_fn: *const c_void, + guid: *const CUuuid, + idx: usize, +) -> *const c_void { + match (*guid, idx) { + (TOOLS_RUNTIME_CALLBACK_HOOKS_GUID, 2) + | (TOOLS_RUNTIME_CALLBACK_HOOKS_GUID, 6) + | (CUDART_INTERFACE_GUID, 2) + | (CUDART_INTERFACE_GUID, 7) + | (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 0) + | (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 1) + | (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 2) => original_fn, + (CUDART_INTERFACE_GUID, 1) => { + ORIGINAL_GET_MODULE_FROM_CUBIN = mem::transmute(original_fn); + get_module_from_cubin as *const _ + } + (CUDART_INTERFACE_GUID, 6) => { + ORIGINAL_GET_MODULE_FROM_CUBIN_EXT = mem::transmute(original_fn); + get_module_from_cubin_ext as *const _ + } + _ => os::get_thunk(original_fn, report_unknown_export_table_call, guid, idx), } } diff --git a/zluda_dump/src/os_unix.rs b/zluda_dump/src/os_unix.rs index 91a004a..2cf8dad 100644 --- a/zluda_dump/src/os_unix.rs +++ b/zluda_dump/src/os_unix.rs @@ -28,3 +28,34 @@ macro_rules! os_log { } }; } + +//RDI, RSI, RDX, RCX, R8, R9 +#[cfg(target_arch = "x86_64")] +pub fn get_thunk( + original_fn: *const c_void, + report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize), + guid: *const CUuuid, + idx: usize, +) -> *const c_void { + use dynasmrt::{dynasm, DynasmApi}; + let mut ops = dynasmrt::x86::Assembler::new().unwrap(); + let start = ops.offset(); + dynasm!(ops + ; .arch x64 + ; push rdi + ; push rsi + ; mov rdi, QWORD guid as i64 + ; mov rsi, QWORD idx as i64 + ; mov rax, QWORD report_fn as i64 + ; call rax + ; pop rsi + ; pop rdi + ; mov rax, QWORD original_fn as i64 + ; jmp rax + ; int 3 + ); + let exe_buf = ops.finalize().unwrap(); + let result_fn = exe_buf.ptr(start); + mem::forget(exe_buf); + result_fn as *const _ +} diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs index 70a2b42..55b69da 100644 --- a/zluda_dump/src/os_win.rs +++ b/zluda_dump/src/os_win.rs @@ -13,6 +13,8 @@ use winapi::{ um::libloaderapi::{GetProcAddress, LoadLibraryW}, }; +use crate::cuda::CUuuid; + const NVCUDA_DEFAULT_PATH: &[u16] = wch_c!(r"C:\Windows\System32\nvcuda.dll"); const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0"; @@ -97,3 +99,60 @@ pub fn __log_impl(s: String) { unsafe { OutputDebugStringA(win_str.as_ptr() as *const _) }; } } + +#[cfg(target_arch = "x86")] +pub fn get_thunk( + original_fn: *const c_void, + report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize), + guid: *const CUuuid, + idx: usize, +) -> *const c_void { + use dynasmrt::{dynasm, DynasmApi}; + let mut ops = dynasmrt::x86::Assembler::new().unwrap(); + let start = ops.offset(); + dynasm!(ops + ; .arch x86 + ; push idx as i32 + ; push guid as i32 + ; mov eax, report_fn as i32 + ; call eax + ; mov eax, original_fn as i32 + ; jmp eax + ; int 3 + ); + let exe_buf = ops.finalize().unwrap(); + let result_fn = exe_buf.ptr(start); + mem::forget(exe_buf); + result_fn as *const _ +} + +//RCX, RDX, R8, R9 +#[cfg(target_arch = "x86_64")] +pub fn get_thunk( + original_fn: *const c_void, + report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize), + guid: *const CUuuid, + idx: usize, +) -> *const c_void { + use dynasmrt::{dynasm, DynasmApi}; + let mut ops = dynasmrt::x86::Assembler::new().unwrap(); + let start = ops.offset(); + dynasm!(ops + ; .arch x64 + ; push rcx + ; push rdx + ; mov rcx, QWORD guid as i64 + ; mov rdx, QWORD idx as i64 + ; mov rax, QWORD report_fn as i64 + ; call rax + ; pop rdx + ; pop rcx + ; mov rax, QWORD original_fn as i64 + ; jmp rax + ; int 3 + ); + let exe_buf = ops.finalize().unwrap(); + let result_fn = exe_buf.ptr(start); + mem::forget(exe_buf); + result_fn as *const _ +}