mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-06 00:00:13 +00:00
Handle even more export table functions
This commit is contained in:
parent
dca4c5bd21
commit
89e72e4e95
6 changed files with 163 additions and 14 deletions
|
@ -15,6 +15,9 @@ lazy_static = "1.4"
|
||||||
num_enum = "0.4"
|
num_enum = "0.4"
|
||||||
lz4-sys = "1.9"
|
lz4-sys = "1.9"
|
||||||
|
|
||||||
|
[target.'cfg(windows)'.dependencies]
|
||||||
|
winapi = { version = "0.3", features = ["heapapi", "std"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
cuda-driver-sys = "0.3.0"
|
cuda-driver-sys = "0.3.0"
|
||||||
paste = "1.0"
|
paste = "1.0"
|
|
@ -1,3 +1,5 @@
|
||||||
|
use winapi::um::heapapi::{HeapAlloc, HeapFree};
|
||||||
|
|
||||||
use crate::cuda::CUresult;
|
use crate::cuda::CUresult;
|
||||||
use crate::{
|
use crate::{
|
||||||
cuda::{CUcontext, CUdevice, CUmodule, CUuuid},
|
cuda::{CUcontext, CUdevice, CUmodule, CUuuid},
|
||||||
|
@ -34,6 +36,14 @@ pub fn get(table: *mut *const std::os::raw::c_void, id: *const CUuuid) -> CUresu
|
||||||
unsafe { *table = CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_VTABLE.as_ptr() as *const _ };
|
unsafe { *table = CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_VTABLE.as_ptr() as *const _ };
|
||||||
CUresult::CUDA_SUCCESS
|
CUresult::CUDA_SUCCESS
|
||||||
}
|
}
|
||||||
|
CTX_CREATE_BYPASS_GUID => {
|
||||||
|
unsafe { *table = CTX_CREATE_BYPASS_VTABLE.as_ptr() as *const _ };
|
||||||
|
CUresult::CUDA_SUCCESS
|
||||||
|
}
|
||||||
|
HEAP_ACCESS_GUID => {
|
||||||
|
unsafe { *table = HEAP_ACCESS_VTABLE.as_ptr() as *const _ };
|
||||||
|
CUresult::CUDA_SUCCESS
|
||||||
|
}
|
||||||
_ => CUresult::CUDA_ERROR_NOT_SUPPORTED,
|
_ => CUresult::CUDA_ERROR_NOT_SUPPORTED,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -412,3 +422,106 @@ fn lock_context<T>(
|
||||||
})?
|
})?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const CTX_CREATE_BYPASS_GUID: CUuuid = CUuuid {
|
||||||
|
bytes: [
|
||||||
|
0x0C, 0xA5, 0x0B, 0x8C, 0x10, 0x04, 0x92, 0x9A, 0x89, 0xA7, 0xD0, 0xDF, 0x10, 0xE7, 0x72,
|
||||||
|
0x86,
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
const CTX_CREATE_BYPASS_LENGTH: usize = 2;
|
||||||
|
static CTX_CREATE_BYPASS_VTABLE: [VTableEntry; CTX_CREATE_BYPASS_LENGTH] = [
|
||||||
|
VTableEntry {
|
||||||
|
length: mem::size_of::<[VTableEntry; CTX_CREATE_BYPASS_LENGTH]>(),
|
||||||
|
},
|
||||||
|
VTableEntry {
|
||||||
|
ptr: ctx_create_v2_bypass as *const (),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
// I have no idea what is the difference between this function and
|
||||||
|
// cuCtxCreate_v2, but PhysX uses both interchangeably
|
||||||
|
extern "system" fn ctx_create_v2_bypass(
|
||||||
|
pctx: *mut CUcontext,
|
||||||
|
flags: ::std::os::raw::c_uint,
|
||||||
|
dev: CUdevice,
|
||||||
|
) -> CUresult {
|
||||||
|
context::create_v2(pctx.decuda(), flags, dev.decuda()).encuda()
|
||||||
|
}
|
||||||
|
|
||||||
|
const HEAP_ACCESS_GUID: CUuuid = CUuuid {
|
||||||
|
bytes: [
|
||||||
|
0x19, 0x5B, 0xCB, 0xF4, 0xD6, 0x7D, 0x02, 0x4A, 0xAC, 0xC5, 0x1D, 0x29, 0xCE, 0xA6, 0x31,
|
||||||
|
0xAE,
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
struct HeapAllocRecord {
|
||||||
|
arg1: usize,
|
||||||
|
arg2: usize,
|
||||||
|
_unknown: usize,
|
||||||
|
global_heap: *mut c_void,
|
||||||
|
}
|
||||||
|
|
||||||
|
const HEAP_ACCESS_LENGTH: usize = 3;
|
||||||
|
static HEAP_ACCESS_VTABLE: [VTableEntry; HEAP_ACCESS_LENGTH] = [
|
||||||
|
VTableEntry {
|
||||||
|
length: mem::size_of::<[VTableEntry; HEAP_ACCESS_LENGTH]>(),
|
||||||
|
},
|
||||||
|
VTableEntry {
|
||||||
|
ptr: heap_alloc as *const (),
|
||||||
|
},
|
||||||
|
VTableEntry {
|
||||||
|
ptr: heap_free as *const (),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
// TODO: reverse and implement for Linux
|
||||||
|
unsafe extern "system" fn heap_alloc(
|
||||||
|
halloc_ptr: *mut *const HeapAllocRecord,
|
||||||
|
arg1: usize,
|
||||||
|
arg2: usize,
|
||||||
|
) -> CUresult {
|
||||||
|
if halloc_ptr == ptr::null_mut() {
|
||||||
|
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
||||||
|
}
|
||||||
|
let halloc = GlobalState::lock(|global_state| {
|
||||||
|
let halloc = HeapAlloc(
|
||||||
|
global_state.global_heap,
|
||||||
|
0,
|
||||||
|
mem::size_of::<HeapAllocRecord>(),
|
||||||
|
) as *mut HeapAllocRecord;
|
||||||
|
if halloc == ptr::null_mut() {
|
||||||
|
return Err(CUresult::CUDA_ERROR_OUT_OF_MEMORY);
|
||||||
|
}
|
||||||
|
(*halloc).arg1 = arg1;
|
||||||
|
(*halloc).arg2 = arg2;
|
||||||
|
(*halloc)._unknown = 0;
|
||||||
|
(*halloc).global_heap = global_state.global_heap;
|
||||||
|
Ok(halloc)
|
||||||
|
});
|
||||||
|
match halloc {
|
||||||
|
Ok(Ok(halloc)) => {
|
||||||
|
*halloc_ptr = halloc;
|
||||||
|
CUresult::CUDA_SUCCESS
|
||||||
|
}
|
||||||
|
Err(err) | Ok(Err(err)) => err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: reverse and implement for Linux
|
||||||
|
unsafe extern "system" fn heap_free(halloc: *mut HeapAllocRecord, arg1: *mut usize) -> CUresult {
|
||||||
|
if halloc == ptr::null_mut() {
|
||||||
|
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
||||||
|
}
|
||||||
|
if arg1 != ptr::null_mut() {
|
||||||
|
*arg1 = (*halloc).arg2;
|
||||||
|
}
|
||||||
|
GlobalState::lock(|global_state| {
|
||||||
|
HeapFree(global_state.global_heap, 0, halloc as *mut _);
|
||||||
|
()
|
||||||
|
})
|
||||||
|
.encuda()
|
||||||
|
}
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
use winapi::um::{heapapi::HeapCreate, winnt::HEAP_NO_SERIALIZE};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st},
|
cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st},
|
||||||
r#impl::device::Device,
|
r#impl::device::Device,
|
||||||
|
@ -203,6 +205,7 @@ lazy_static! {
|
||||||
|
|
||||||
struct GlobalState {
|
struct GlobalState {
|
||||||
devices: Vec<Device>,
|
devices: Vec<Device>,
|
||||||
|
global_heap: *mut c_void,
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl Send for GlobalState {}
|
unsafe impl Send for GlobalState {}
|
||||||
|
@ -301,7 +304,14 @@ pub fn init() -> Result<(), CUresult> {
|
||||||
None => return Err(CUresult::CUDA_ERROR_UNKNOWN),
|
None => return Err(CUresult::CUDA_ERROR_UNKNOWN),
|
||||||
Some(driver) => device::init(&driver)?,
|
Some(driver) => device::init(&driver)?,
|
||||||
};
|
};
|
||||||
*global_state = Some(GlobalState { devices });
|
let global_heap = unsafe { HeapCreate(HEAP_NO_SERIALIZE, 0, 0) };
|
||||||
|
if global_heap == ptr::null_mut() {
|
||||||
|
return Err(CUresult::CUDA_ERROR_OUT_OF_MEMORY);
|
||||||
|
}
|
||||||
|
*global_state = Some(GlobalState {
|
||||||
|
devices,
|
||||||
|
global_heap,
|
||||||
|
});
|
||||||
drop(global_state);
|
drop(global_state);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -685,7 +685,7 @@ static mut ORIGINAL_GET_MODULE_FROM_CUBIN_EXT: Option<
|
||||||
) -> CUresult,
|
) -> CUresult,
|
||||||
> = None;
|
> = None;
|
||||||
|
|
||||||
unsafe extern "stdcall" fn report_unknown_export_table_call(
|
unsafe extern "system" fn report_unknown_export_table_call(
|
||||||
export_table: *const CUuuid,
|
export_table: *const CUuuid,
|
||||||
idx: usize,
|
idx: usize,
|
||||||
) {
|
) {
|
||||||
|
@ -699,22 +699,27 @@ pub unsafe fn cuGetExportTable(
|
||||||
pExportTableId: *const CUuuid,
|
pExportTableId: *const CUuuid,
|
||||||
cont: impl FnOnce(*mut *const ::std::os::raw::c_void, *const CUuuid) -> CUresult,
|
cont: impl FnOnce(*mut *const ::std::os::raw::c_void, *const CUuuid) -> CUresult,
|
||||||
) -> CUresult {
|
) -> CUresult {
|
||||||
|
if ppExportTable == ptr::null_mut() || pExportTableId == ptr::null() {
|
||||||
|
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
||||||
|
}
|
||||||
let guid = (*pExportTableId).bytes;
|
let guid = (*pExportTableId).bytes;
|
||||||
os_log!("Requested export table id: {{{:02X}{:02X}{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}", guid[0], guid[1], guid[2], guid[3], guid[4], guid[5], guid[6], guid[7], guid[8], guid[9], guid[10], guid[11], guid[12], guid[13], guid[14], guid[15]);
|
os_log!("Requested export table id: {{{:02X}{:02X}{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}", guid[0], guid[1], guid[2], guid[3], guid[4], guid[5], guid[6], guid[7], guid[8], guid[9], guid[10], guid[11], guid[12], guid[13], guid[14], guid[15]);
|
||||||
let result = cont(ppExportTable, pExportTableId);
|
override_export_table(ppExportTable, pExportTableId, cont)
|
||||||
if result == CUresult::CUDA_SUCCESS {
|
|
||||||
override_export_table(ppExportTable, pExportTableId);
|
|
||||||
}
|
|
||||||
result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn override_export_table(
|
unsafe fn override_export_table(
|
||||||
export_table_ptr: *mut *const ::std::os::raw::c_void,
|
export_table_ptr: *mut *const ::std::os::raw::c_void,
|
||||||
export_table_id: *const CUuuid,
|
export_table_id: *const CUuuid,
|
||||||
) {
|
cont: impl FnOnce(*mut *const ::std::os::raw::c_void, *const CUuuid) -> CUresult,
|
||||||
|
) -> CUresult {
|
||||||
let overrides_map = OVERRIDEN_INTERFACE_VTABLES.get_or_insert_with(|| HashMap::new());
|
let overrides_map = OVERRIDEN_INTERFACE_VTABLES.get_or_insert_with(|| HashMap::new());
|
||||||
if overrides_map.contains_key(&*export_table_id) {
|
if let Some(override_table) = overrides_map.get(&*export_table_id) {
|
||||||
return;
|
*export_table_ptr = override_table.as_ptr() as *const _;
|
||||||
|
return CUresult::CUDA_SUCCESS;
|
||||||
|
}
|
||||||
|
let base_result = cont(export_table_ptr, export_table_id);
|
||||||
|
if base_result != CUresult::CUDA_SUCCESS {
|
||||||
|
return base_result;
|
||||||
}
|
}
|
||||||
let export_table = (*export_table_ptr) as *mut *const c_void;
|
let export_table = (*export_table_ptr) as *mut *const c_void;
|
||||||
let boxed_guid = Box::new(*export_table_id);
|
let boxed_guid = Box::new(*export_table_id);
|
||||||
|
@ -745,6 +750,7 @@ unsafe fn override_export_table(
|
||||||
}
|
}
|
||||||
*export_table_ptr = override_table.as_ptr() as *const _;
|
*export_table_ptr = override_table.as_ptr() as *const _;
|
||||||
overrides_map.insert(boxed_guid, override_table);
|
overrides_map.insert(boxed_guid, override_table);
|
||||||
|
CUresult::CUDA_SUCCESS
|
||||||
}
|
}
|
||||||
|
|
||||||
const TOOLS_RUNTIME_CALLBACK_HOOKS_GUID: CUuuid = CUuuid {
|
const TOOLS_RUNTIME_CALLBACK_HOOKS_GUID: CUuuid = CUuuid {
|
||||||
|
@ -761,6 +767,20 @@ const CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID: CUuuid = CUuuid {
|
||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const CTX_CREATE_BYPASS_GUID: CUuuid = CUuuid {
|
||||||
|
bytes: [
|
||||||
|
0x0C, 0xA5, 0x0B, 0x8C, 0x10, 0x04, 0x92, 0x9A, 0x89, 0xA7, 0xD0, 0xDF, 0x10, 0xE7, 0x72,
|
||||||
|
0x86,
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
const HEAP_ACCESS_GUID: CUuuid = CUuuid {
|
||||||
|
bytes: [
|
||||||
|
0x19, 0x5B, 0xCB, 0xF4, 0xD6, 0x7D, 0x02, 0x4A, 0xAC, 0xC5, 0x1D, 0x29, 0xCE, 0xA6, 0x31,
|
||||||
|
0xAE,
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
unsafe fn get_export_override_fn(
|
unsafe fn get_export_override_fn(
|
||||||
original_fn: *const c_void,
|
original_fn: *const c_void,
|
||||||
guid: *const CUuuid,
|
guid: *const CUuuid,
|
||||||
|
@ -773,7 +793,10 @@ unsafe fn get_export_override_fn(
|
||||||
| (CUDART_INTERFACE_GUID, 7)
|
| (CUDART_INTERFACE_GUID, 7)
|
||||||
| (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 0)
|
| (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 0)
|
||||||
| (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 1)
|
| (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 1)
|
||||||
| (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 2) => original_fn,
|
| (CONTEXT_LOCAL_STORAGE_INTERFACE_V0301_GUID, 2)
|
||||||
|
| (CTX_CREATE_BYPASS_GUID, 1)
|
||||||
|
| (HEAP_ACCESS_GUID, 1)
|
||||||
|
| (HEAP_ACCESS_GUID, 2) => original_fn,
|
||||||
(CUDART_INTERFACE_GUID, 1) => {
|
(CUDART_INTERFACE_GUID, 1) => {
|
||||||
ORIGINAL_GET_MODULE_FROM_CUBIN = mem::transmute(original_fn);
|
ORIGINAL_GET_MODULE_FROM_CUBIN = mem::transmute(original_fn);
|
||||||
get_module_from_cubin as *const _
|
get_module_from_cubin as *const _
|
||||||
|
|
|
@ -33,7 +33,7 @@ macro_rules! os_log {
|
||||||
#[cfg(target_arch = "x86_64")]
|
#[cfg(target_arch = "x86_64")]
|
||||||
pub fn get_thunk(
|
pub fn get_thunk(
|
||||||
original_fn: *const c_void,
|
original_fn: *const c_void,
|
||||||
report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize),
|
report_fn: unsafe extern "system" fn(*const CUuuid, usize),
|
||||||
guid: *const CUuuid,
|
guid: *const CUuuid,
|
||||||
idx: usize,
|
idx: usize,
|
||||||
) -> *const c_void {
|
) -> *const c_void {
|
||||||
|
|
|
@ -103,7 +103,7 @@ pub fn __log_impl(s: String) {
|
||||||
#[cfg(target_arch = "x86")]
|
#[cfg(target_arch = "x86")]
|
||||||
pub fn get_thunk(
|
pub fn get_thunk(
|
||||||
original_fn: *const c_void,
|
original_fn: *const c_void,
|
||||||
report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize),
|
report_fn: unsafe extern "system" fn(*const CUuuid, usize),
|
||||||
guid: *const CUuuid,
|
guid: *const CUuuid,
|
||||||
idx: usize,
|
idx: usize,
|
||||||
) -> *const c_void {
|
) -> *const c_void {
|
||||||
|
@ -130,7 +130,7 @@ pub fn get_thunk(
|
||||||
#[cfg(target_arch = "x86_64")]
|
#[cfg(target_arch = "x86_64")]
|
||||||
pub fn get_thunk(
|
pub fn get_thunk(
|
||||||
original_fn: *const c_void,
|
original_fn: *const c_void,
|
||||||
report_fn: unsafe extern "stdcall" fn(*const CUuuid, usize),
|
report_fn: unsafe extern "system" fn(*const CUuuid, usize),
|
||||||
guid: *const CUuuid,
|
guid: *const CUuuid,
|
||||||
idx: usize,
|
idx: usize,
|
||||||
) -> *const c_void {
|
) -> *const c_void {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue