mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
Fix remaining issues with detouring nvcuda
This commit is contained in:
parent
26bf0eeaf2
commit
2c6d7ffb7a
4 changed files with 112 additions and 29 deletions
|
@ -66,7 +66,6 @@ impl PlatformLibrary {
|
|||
if module == ptr::null_mut() {
|
||||
break;
|
||||
}
|
||||
let mut size = 0;
|
||||
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
|
||||
if payload != ptr::null_mut() {
|
||||
return Some(module as _);
|
||||
|
|
|
@ -13,7 +13,7 @@ winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchap
|
|||
detours-sys = { path = "../detours-sys" }
|
||||
|
||||
[dev-dependencies]
|
||||
# dependency for integration tests
|
||||
# all of those are used in integration tests
|
||||
zluda_redirect = { path = "../zluda_redirect" }
|
||||
# dependency for integration tests
|
||||
zluda_dump = { path = "../zluda_dump" }
|
||||
zluda_ml = { path = "../zluda_ml" }
|
||||
|
|
|
@ -10,4 +10,5 @@ crate-type = ["cdylib"]
|
|||
[target.'cfg(windows)'.dependencies]
|
||||
detours-sys = { path = "../detours-sys" }
|
||||
wchar = "0.6"
|
||||
winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] }
|
||||
winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] }
|
||||
tempfile = "3"
|
|
@ -6,33 +6,18 @@ extern crate winapi;
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
ffi::{c_void, CStr},
|
||||
mem,
|
||||
io, mem,
|
||||
os::raw::c_uint,
|
||||
ptr, slice, usize,
|
||||
};
|
||||
|
||||
use detours_sys::{
|
||||
DetourAttach, DetourEnumerateExports, DetourRestoreAfterWith, DetourTransactionAbort,
|
||||
DetourTransactionBegin, DetourTransactionCommit, DetourUpdateProcessWithDll,
|
||||
DetourUpdateThread,
|
||||
DetourAttach, DetourEnumerateExports, DetourGetEntryPoint, DetourRestoreAfterWith,
|
||||
DetourTransactionAbort, DetourTransactionBegin, DetourTransactionCommit,
|
||||
DetourUpdateProcessWithDll, DetourUpdateThread,
|
||||
};
|
||||
use tempfile::TempDir;
|
||||
use wchar::wch;
|
||||
use winapi::{
|
||||
shared::minwindef::{BOOL, LPVOID},
|
||||
um::{
|
||||
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
|
||||
minwinbase::LPSECURITY_ATTRIBUTES,
|
||||
processthreadsapi::{
|
||||
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
|
||||
SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW,
|
||||
},
|
||||
tlhelp32::{
|
||||
CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32,
|
||||
},
|
||||
winbase::CREATE_SUSPENDED,
|
||||
winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME},
|
||||
},
|
||||
};
|
||||
use winapi::{
|
||||
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
|
||||
um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
|
||||
|
@ -50,6 +35,26 @@ use winapi::{
|
|||
shared::winerror::NO_ERROR,
|
||||
um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW},
|
||||
};
|
||||
use winapi::{
|
||||
shared::{
|
||||
minwindef::{BOOL, LPVOID},
|
||||
winerror::E_UNEXPECTED,
|
||||
},
|
||||
um::{
|
||||
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
|
||||
libloaderapi::GetModuleHandleW,
|
||||
minwinbase::LPSECURITY_ATTRIBUTES,
|
||||
processthreadsapi::{
|
||||
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
|
||||
SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW,
|
||||
},
|
||||
tlhelp32::{
|
||||
CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32,
|
||||
},
|
||||
winbase::{CopyFileW, CreateSymbolicLinkW, CREATE_SUSPENDED},
|
||||
winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME},
|
||||
},
|
||||
};
|
||||
|
||||
include!("payload_guid.rs");
|
||||
|
||||
|
@ -375,6 +380,59 @@ unsafe extern "system" fn ZludaCreateProcessWithTokenW(
|
|||
continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation)
|
||||
}
|
||||
|
||||
static mut MAIN: unsafe extern "system" fn() -> DWORD = ZludaMain;
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#search-order-for-desktop-applications
|
||||
// "If a DLL with the same module name is already loaded in memory, the system
|
||||
// uses the loaded DLL, no matter which directory it is in. The system does not
|
||||
// search for the DLL."
|
||||
#[allow(non_snake_case)]
|
||||
unsafe extern "system" fn ZludaMain() -> DWORD {
|
||||
let temp_dir = match do_zluda_preload() {
|
||||
Ok(f) => f,
|
||||
Err(e) => return e.raw_os_error().unwrap_or(E_UNEXPECTED) as u32,
|
||||
};
|
||||
let result = MAIN();
|
||||
drop(temp_dir);
|
||||
result
|
||||
}
|
||||
|
||||
unsafe fn do_zluda_preload() -> std::io::Result<TempDir> {
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
do_single_zluda_preload(&temp_dir, ZLUDA_PATH_UTF16.unwrap().as_ptr(), NVCUDA_UTF8)?;
|
||||
do_single_zluda_preload(&temp_dir, ZLUDA_ML_PATH_UTF16.unwrap().as_ptr(), NVML_UTF8)?;
|
||||
Ok(temp_dir)
|
||||
}
|
||||
|
||||
unsafe fn do_single_zluda_preload(
|
||||
temp_dir: &TempDir,
|
||||
full_path: *const u16,
|
||||
file_name: &'static str,
|
||||
) -> io::Result<()> {
|
||||
let mut temp_file_path = temp_dir.path().to_path_buf();
|
||||
temp_file_path.push(file_name);
|
||||
let mut temp_file_path_utf16 = temp_file_path
|
||||
.into_os_string()
|
||||
.to_string_lossy()
|
||||
.encode_utf16()
|
||||
.collect::<Vec<_>>();
|
||||
temp_file_path_utf16.push(0);
|
||||
// Probably we are not in developer mode, do a copty then
|
||||
if 0 == CreateSymbolicLinkW(
|
||||
temp_file_path_utf16.as_ptr(),
|
||||
full_path,
|
||||
0x2, //SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE
|
||||
) {
|
||||
if 0 == CopyFileW(full_path, temp_file_path_utf16.as_ptr(), 1) {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
}
|
||||
if ptr::null_mut() == ZludaLoadLibraryW_NoRedirect(temp_file_path_utf16.as_ptr()) {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// This type encapsulates typical calling sequence of detours and cleanup.
|
||||
// We have two ways we do detours:
|
||||
// * If we are loaded before nvcuda.dll, we hook LoadLibrary*
|
||||
|
@ -668,8 +726,8 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
|
|||
// redirecting LoadLibrary* to load ZLUDA, we override already loaded
|
||||
// functions
|
||||
let detach_guard = match get_cuinit() {
|
||||
Some((nvcuda_mod, _)) => attach_cuinit(nvcuda_mod),
|
||||
None => attach_load_libary(),
|
||||
Some((nvcuda_mod, _)) => detour_already_loaded_nvcuda(nvcuda_mod),
|
||||
None => detour_main(),
|
||||
};
|
||||
match detach_guard {
|
||||
Some(g) => {
|
||||
|
@ -724,7 +782,7 @@ unsafe fn get_cuinit() -> Option<(HMODULE, FARPROC)> {
|
|||
}
|
||||
|
||||
#[must_use]
|
||||
unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
|
||||
unsafe fn detour_already_loaded_nvcuda(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
|
||||
let zluda_module = LoadLibraryW(ZLUDA_PATH_UTF16.unwrap().as_ptr());
|
||||
if zluda_module == ptr::null_mut() {
|
||||
return None;
|
||||
|
@ -747,7 +805,22 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
|
|||
(original_fn_address as _, override_fn_address),
|
||||
);
|
||||
}
|
||||
DetourDetachGuard::detour_functions(nvcuda_mod, Vec::new(), override_fn_pairs)
|
||||
let detour_functions = vec![
|
||||
(
|
||||
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
|
||||
ZludaLoadLibraryA as *mut c_void,
|
||||
),
|
||||
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
|
||||
(
|
||||
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
|
||||
ZludaLoadLibraryExA as _,
|
||||
),
|
||||
(
|
||||
&mut LOAD_LIBRARY_EX_W as *mut _ as _,
|
||||
ZludaLoadLibraryExW as _,
|
||||
),
|
||||
];
|
||||
DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, override_fn_pairs)
|
||||
}
|
||||
|
||||
unsafe extern "system" fn cuda_unsupported() -> c_uint {
|
||||
|
@ -776,8 +849,18 @@ unsafe extern "stdcall" fn gather_imports_impl(
|
|||
}
|
||||
|
||||
#[must_use]
|
||||
unsafe fn attach_load_libary() -> Option<DetourDetachGuard> {
|
||||
unsafe fn detour_main() -> Option<DetourDetachGuard> {
|
||||
let exe_handle = GetModuleHandleW(ptr::null());
|
||||
let entry_point = DetourGetEntryPoint(exe_handle as _);
|
||||
if entry_point == ptr::null_mut() {
|
||||
return None;
|
||||
}
|
||||
MAIN = mem::transmute(entry_point);
|
||||
let detour_functions = vec![
|
||||
(
|
||||
&mut MAIN as *mut _ as *mut *mut c_void,
|
||||
ZludaMain as *mut c_void,
|
||||
),
|
||||
(
|
||||
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
|
||||
ZludaLoadLibraryA as *mut c_void,
|
||||
|
|
Loading…
Add table
Reference in a new issue