Report calls to unsupported export table functions

This commit is contained in:
Andrzej Janik 2021-05-16 01:08:59 +02:00
parent a005c92c61
commit dca4c5bd21
5 changed files with 188 additions and 34 deletions

View file

@ -12,6 +12,8 @@ crate-type = ["cdylib"]
ptx = { path = "../ptx" }
lz4-sys = "1.9"
regex = "1.4"
dynasm = "1.1"
dynasmrt = "1.1"
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["libloaderapi", "debugapi", "std"] }

View file

@ -103,7 +103,7 @@ pub struct CUgraphExec_st {
}
pub type CUgraphExec = *mut CUgraphExec_st;
#[repr(C)]
#[derive(Copy, Clone, PartialEq)]
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct CUuuid_st {
pub bytes: [::std::os::raw::c_uchar; 16usize],
}

View file

@ -191,7 +191,10 @@ unsafe fn record_module_image(module: CUmodule, image: &str) {
unsafe fn try_dump_module_image(image: &str) -> Result<(), Box<dyn Error>> {
let mut dump_path = get_dump_dir()?;
dump_path.push(format!("module_{:04}.ptx", MODULES.as_ref().unwrap().len() - 1));
dump_path.push(format!(
"module_{:04}.ptx",
MODULES.as_ref().unwrap().len() - 1
));
let mut file = File::create(dump_path)?;
file.write_all(image.as_bytes())?;
Ok(())
@ -665,9 +668,8 @@ const CUDART_INTERFACE_GUID: CUuuid = CUuuid {
],
};
static mut CUDART_INTERFACE_VTABLE: Vec<*const c_void> = Vec::new();
const GET_MODULE_FROM_CUBIN_OFFSET: usize = 1;
const GET_MODULE_FROM_CUBIN_EXT_OFFSET: usize = 6;
static mut OVERRIDEN_INTERFACE_VTABLES: Option<HashMap<Box<CUuuid>, Vec<*const c_void>>> = None;
static mut ORIGINAL_GET_MODULE_FROM_CUBIN: Option<
unsafe extern "system" fn(
result: *mut CUmodule,
@ -683,44 +685,104 @@ static mut ORIGINAL_GET_MODULE_FROM_CUBIN_EXT: Option<
) -> CUresult,
> = None;
/// Logs a call made through an export-table slot that has no dedicated override.
///
/// `export_table` points at the GUID identifying the table and `idx` is the
/// slot index within that table. The GUID is printed in the
/// `{XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}` layout followed by `::idx`.
///
/// # Safety
///
/// `export_table` is dereferenced unconditionally, so it must point at a
/// readable `CUuuid`.
unsafe extern "stdcall" fn report_unknown_export_table_call(
    export_table: *const CUuuid,
    idx: usize,
) {
    // Borrow the 16 raw GUID bytes; they are printed in order, one hex pair each.
    let b = &(*export_table).bytes;
    os_log!(
        "Call to an unsupported export table function: {{{:02X}{:02X}{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}::{}",
        b[0], b[1], b[2], b[3],
        b[4], b[5],
        b[6], b[7],
        b[8], b[9],
        b[10], b[11], b[12], b[13], b[14], b[15],
        idx
    );
}
#[allow(non_snake_case)]
pub unsafe fn cuGetExportTable(
ppExportTable: *mut *const ::std::os::raw::c_void,
pExportTableId: *const CUuuid,
cont: impl FnOnce(*mut *const ::std::os::raw::c_void, *const CUuuid) -> CUresult,
) -> CUresult {
if *pExportTableId == CUDART_INTERFACE_GUID {
if CUDART_INTERFACE_VTABLE.len() == 0 {
let mut base_table = ptr::null();
let base_result = cont(&mut base_table, pExportTableId);
if base_result != CUresult::CUDA_SUCCESS {
return base_result;
let guid = (*pExportTableId).bytes;
os_log!("Requested export table id: {{{:02X}{:02X}{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}", guid[0], guid[1], guid[2], guid[3], guid[4], guid[5], guid[6], guid[7], guid[8], guid[9], guid[10], guid[11], guid[12], guid[13], guid[14], guid[15]);
let result = cont(ppExportTable, pExportTableId);
if result == CUresult::CUDA_SUCCESS {
override_export_table(ppExportTable, pExportTableId);
}
result
}
unsafe fn override_export_table(
export_table_ptr: *mut *const ::std::os::raw::c_void,
export_table_id: *const CUuuid,
) {
let overrides_map = OVERRIDEN_INTERFACE_VTABLES.get_or_insert_with(|| HashMap::new());
if overrides_map.contains_key(&*export_table_id) {
return;
}
let export_table = (*export_table_ptr) as *mut *const c_void;
let boxed_guid = Box::new(*export_table_id);
let byte_length: usize = *(export_table as *const usize);
let mut override_table = Vec::new();
if byte_length < 0x10000 {
override_table.push(byte_length as *const _);
let length = byte_length / mem::size_of::<usize>();
for i in 1..length {
let current_fn = export_table.add(i);
if (*current_fn as usize) == usize::max_value() {
override_table.push(usize::max_value() as *const _);
break;
}
let len = *(base_table as *const usize);
CUDART_INTERFACE_VTABLE = vec![ptr::null(); len];
ptr::copy_nonoverlapping(
base_table as *const _,
CUDART_INTERFACE_VTABLE.as_mut_ptr(),
len,
);
if GET_MODULE_FROM_CUBIN_EXT_OFFSET >= len {
return CUresult::CUDA_ERROR_UNKNOWN;
}
ORIGINAL_GET_MODULE_FROM_CUBIN =
mem::transmute(CUDART_INTERFACE_VTABLE[GET_MODULE_FROM_CUBIN_OFFSET]);
CUDART_INTERFACE_VTABLE[GET_MODULE_FROM_CUBIN_OFFSET] =
get_module_from_cubin as *const _;
ORIGINAL_GET_MODULE_FROM_CUBIN_EXT =
mem::transmute(CUDART_INTERFACE_VTABLE[GET_MODULE_FROM_CUBIN_EXT_OFFSET]);
CUDART_INTERFACE_VTABLE[GET_MODULE_FROM_CUBIN_EXT_OFFSET] =
get_module_from_cubin_ext as *const _;
override_table.push(get_export_override_fn(*current_fn, &*boxed_guid, i));
}
*ppExportTable = CUDART_INTERFACE_VTABLE.as_ptr() as *const _;
return CUresult::CUDA_SUCCESS;
} else {
let guid = (*pExportTableId).bytes;
os_log!("Unsupported export table id: {{{:02X}{:02X}{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}", guid[0], guid[1], guid[2], guid[3], guid[4], guid[5], guid[6], guid[7], guid[8], guid[9], guid[10], guid[11], guid[12], guid[13], guid[14], guid[15]);
cont(ppExportTable, pExportTableId)
let mut i = 0;
loop {
let current_fn = export_table.add(i);
if (*current_fn as usize) == usize::max_value() {
override_table.push(usize::max_value() as *const _);
break;
}
override_table.push(get_export_override_fn(*current_fn, &*boxed_guid, i));
i += 1;
}
}
*export_table_ptr = override_table.as_ptr() as *const _;
overrides_map.insert(boxed_guid, override_table);
}
/// GUID of an export table that `get_export_override_fn` recognizes: slots 2
/// and 6 of this table are forwarded to the original implementation untouched,
/// every other slot gets a reporting thunk.
// NOTE(review): judging by the name this is the "tools runtime callback hooks"
// table — confirm against the driver documentation/headers.
const TOOLS_RUNTIME_CALLBACK_HOOKS_GUID: CUuuid = CUuuid {
bytes: [
0xa0, 0x94, 0x79, 0x8c, 0x2e, 0x74, 0x2e, 0x74, 0x93, 0xf2, 0x08, 0x00, 0x20, 0x0c, 0x0a,
0x66,
],
};
/// GUID of an export table that `get_export_override_fn` recognizes: slots 0,
/// 1 and 2 of this table are forwarded to the original implementation
/// untouched, every other slot gets a reporting thunk.
// NOTE(review): judging by the name this is the context-local-storage
// interface (v0301) — confirm against the driver documentation/headers.
const CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID: CUuuid = CUuuid {
bytes: [
0xc6, 0x93, 0x33, 0x6e, 0x11, 0x21, 0xdf, 0x11, 0xa8, 0xc3, 0x68, 0xf3, 0x55, 0xd8, 0x95,
0x93,
],
};
/// Picks the pointer that should occupy slot `idx` of the overriding copy of
/// the export table identified by `guid`.
///
/// Three outcomes are possible:
/// * slots 1 and 6 of `CUDART_INTERFACE_GUID` are replaced by
///   `get_module_from_cubin` / `get_module_from_cubin_ext`, and the original
///   pointer is stashed in the matching `ORIGINAL_*` static;
/// * a fixed list of known slots is returned as-is (`original_fn`);
/// * anything else is wrapped in a thunk from `os::get_thunk` that is handed
///   `report_unknown_export_table_call`, `guid` and `idx` together with
///   `original_fn`.
///
/// # Safety
///
/// `guid` is dereferenced and must point at a readable `CUuuid`. The raw
/// `guid` pointer is also handed to `os::get_thunk`.
// NOTE(review): the thunk appears to embed `guid` as an immediate, so the
// pointee presumably has to outlive every generated thunk — confirm at the
// call site.
unsafe fn get_export_override_fn(
    original_fn: *const c_void,
    guid: *const CUuuid,
    idx: usize,
) -> *const c_void {
    let slot = (*guid, idx);
    match slot {
        // Slots we intercept: remember the real entry point, return ours.
        (CUDART_INTERFACE_GUID, 1) => {
            ORIGINAL_GET_MODULE_FROM_CUBIN = mem::transmute(original_fn);
            get_module_from_cubin as *const c_void
        }
        (CUDART_INTERFACE_GUID, 6) => {
            ORIGINAL_GET_MODULE_FROM_CUBIN_EXT = mem::transmute(original_fn);
            get_module_from_cubin_ext as *const c_void
        }
        // Known slots that are passed straight through without reporting.
        (CUDART_INTERFACE_GUID, 2)
        | (CUDART_INTERFACE_GUID, 7)
        | (TOOLS_RUNTIME_CALLBACK_HOOKS_GUID, 2)
        | (TOOLS_RUNTIME_CALLBACK_HOOKS_GUID, 6)
        | (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 0)
        | (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 1)
        | (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 2) => original_fn,
        // Everything else: report the call, then forward to the original.
        _ => os::get_thunk(original_fn, report_unknown_export_table_call, guid, idx),
    }
}

View file

@ -28,3 +28,34 @@ macro_rules! os_log {
}
};
}
/// Generates a tiny executable trampoline that first calls
/// `report_fn(guid, idx)` and then tail-jumps to `original_fn` with the
/// caller's integer arguments intact.
///
/// * `original_fn` - real export-table entry the trampoline forwards to.
/// * `report_fn` - callback invoked before forwarding.
/// * `guid`, `idx` - values baked into the trampoline as immediates and passed
///   to `report_fn`; the memory behind `guid` must therefore stay valid for as
///   long as the trampoline can be called.
///
/// Returns a pointer to the start of the generated code. The backing
/// executable buffer is intentionally leaked (`mem::forget`) so the code stays
/// mapped for the rest of the process lifetime.
///
/// # Panics
///
/// Panics if the assembler cannot be created or finalized.
// System V AMD64 passes integer arguments in RDI, RSI, RDX, RCX, R8, R9. All
// six are caller-saved, so `report_fn` may clobber any of them; every one has
// to be preserved for the forwarded call, not just the two we overwrite.
// Floating-point argument registers (XMM0-7) are not preserved.
#[cfg(target_arch = "x86_64")]
pub fn get_thunk(
    original_fn: *const c_void,
    report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize),
    guid: *const CUuuid,
    idx: usize,
) -> *const c_void {
    use dynasmrt::{dynasm, DynasmApi};
    let mut ops = dynasmrt::x64::Assembler::new().unwrap();
    let start = ops.offset();
    dynasm!(ops
        ; .arch x64
        // Save every integer argument register of the forwarded call.
        ; push rdi
        ; push rsi
        ; push rdx
        ; push rcx
        ; push r8
        ; push r9
        // Return address (8) + six pushes (48) leave RSP 8 bytes short of the
        // 16-byte alignment the ABI requires at a call site.
        ; sub rsp, 8
        ; mov rdi, QWORD guid as i64
        ; mov rsi, QWORD idx as i64
        ; mov rax, QWORD report_fn as i64
        ; call rax
        ; add rsp, 8
        ; pop r9
        ; pop r8
        ; pop rcx
        ; pop rdx
        ; pop rsi
        ; pop rdi
        // Stack is back to its state at entry: tail-jump so the original
        // function returns directly to our caller.
        ; mov rax, QWORD original_fn as i64
        ; jmp rax
        ; int 3
    );
    let exe_buf = ops.finalize().unwrap();
    let result_fn = exe_buf.ptr(start);
    // Leak the buffer on purpose: the thunk must remain executable forever.
    mem::forget(exe_buf);
    result_fn as *const _
}

View file

@ -13,6 +13,8 @@ use winapi::{
um::libloaderapi::{GetProcAddress, LoadLibraryW},
};
use crate::cuda::CUuuid;
const NVCUDA_DEFAULT_PATH: &[u16] = wch_c!(r"C:\Windows\System32\nvcuda.dll");
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
@ -97,3 +99,60 @@ pub fn __log_impl(s: String) {
unsafe { OutputDebugStringA(win_str.as_ptr() as *const _) };
}
}
/// Generates a tiny executable trampoline for 32-bit x86 that calls
/// `report_fn(guid, idx)` and then jumps to `original_fn`.
///
/// The two report arguments are pushed on the stack (right to left), matching
/// the `extern "stdcall"` signature of `report_fn`. The forwarded call's own
/// stack arguments are not touched.
///
/// `guid` and `idx` are baked into the generated code as 32-bit immediates.
/// The executable buffer is deliberately leaked with `mem::forget` so the
/// returned pointer stays valid.
///
/// # Panics
///
/// Panics if the assembler cannot be created or finalized.
#[cfg(target_arch = "x86")]
pub fn get_thunk(
    original_fn: *const c_void,
    report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize),
    guid: *const CUuuid,
    idx: usize,
) -> *const c_void {
    use dynasmrt::{dynasm, DynasmApi};
    let mut asm = dynasmrt::x86::Assembler::new().unwrap();
    let entry = asm.offset();
    dynasm!(asm
        ; .arch x86
        ; push idx as i32
        ; push guid as i32
        ; mov eax, report_fn as i32
        ; call eax
        ; mov eax, original_fn as i32
        ; jmp eax
        ; int 3
    );
    let code = asm.finalize().unwrap();
    let thunk = code.ptr(entry);
    // Never free the generated code: callers keep the raw pointer around.
    mem::forget(code);
    thunk as *const c_void
}
/// Generates a tiny executable trampoline that first calls
/// `report_fn(guid, idx)` and then tail-jumps to `original_fn` with the
/// caller's integer arguments intact.
///
/// * `original_fn` - real export-table entry the trampoline forwards to.
/// * `report_fn` - callback invoked before forwarding.
/// * `guid`, `idx` - values baked into the trampoline as immediates and passed
///   to `report_fn`; the memory behind `guid` must therefore stay valid for as
///   long as the trampoline can be called.
///
/// Returns a pointer to the start of the generated code. The backing
/// executable buffer is intentionally leaked (`mem::forget`) so the code stays
/// mapped for the rest of the process lifetime.
///
/// # Panics
///
/// Panics if the assembler cannot be created or finalized.
// The Windows x64 convention passes integer arguments in RCX, RDX, R8, R9.
// All four are volatile, so `report_fn` may clobber any of them. The callee
// also owns 32 bytes of "shadow space" directly above its return address;
// without reserving it, `report_fn` spilling its arguments would overwrite
// our saved registers and the thunk's own return address.
// Floating-point argument registers (XMM0-3) are not preserved.
#[cfg(target_arch = "x86_64")]
pub fn get_thunk(
    original_fn: *const c_void,
    report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize),
    guid: *const CUuuid,
    idx: usize,
) -> *const c_void {
    use dynasmrt::{dynasm, DynasmApi};
    let mut ops = dynasmrt::x64::Assembler::new().unwrap();
    let start = ops.offset();
    dynasm!(ops
        ; .arch x64
        // Save every integer argument register of the forwarded call.
        ; push rcx
        ; push rdx
        ; push r8
        ; push r9
        // 0x20 bytes of shadow space for `report_fn` plus 8 bytes of padding:
        // return address (8) + four pushes (32) + 0x28 keeps RSP 16-byte
        // aligned at the call.
        ; sub rsp, 0x28
        ; mov rcx, QWORD guid as i64
        ; mov rdx, QWORD idx as i64
        ; mov rax, QWORD report_fn as i64
        ; call rax
        ; add rsp, 0x28
        ; pop r9
        ; pop r8
        ; pop rdx
        ; pop rcx
        // Stack is back to its state at entry: tail-jump so the original
        // function returns directly to our caller.
        ; mov rax, QWORD original_fn as i64
        ; jmp rax
        ; int 3
    );
    let exe_buf = ops.finalize().unwrap();
    let result_fn = exe_buf.ptr(start);
    // Leak the buffer on purpose: the thunk must remain executable forever.
    mem::forget(exe_buf);
    result_fn as *const _
}