Fix remaining issues with detouring nvcuda

This commit is contained in:
Andrzej Janik 2021-12-05 23:01:46 +01:00
parent 26bf0eeaf2
commit 2c6d7ffb7a
4 changed files with 112 additions and 29 deletions

View file

@ -66,7 +66,6 @@ impl PlatformLibrary {
if module == ptr::null_mut() {
break;
}
let mut size = 0;
let payload = GetProcAddress(module as _, b"ZLUDA_REDIRECT\0".as_ptr() as _);
if payload != ptr::null_mut() {
return Some(module as _);

View file

@ -13,7 +13,7 @@ winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchap
detours-sys = { path = "../detours-sys" }
[dev-dependencies]
# dependency for integration tests
# all of those are used in integration tests
zluda_redirect = { path = "../zluda_redirect" }
# dependency for integration tests
zluda_dump = { path = "../zluda_dump" }
zluda_ml = { path = "../zluda_ml" }

View file

@ -10,4 +10,5 @@ crate-type = ["cdylib"]
[target.'cfg(windows)'.dependencies]
detours-sys = { path = "../detours-sys" }
wchar = "0.6"
winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] }
winapi = { version = "0.3", features = ["processthreadsapi", "winbase", "winnt", "winerror", "libloaderapi", "tlhelp32", "handleapi", "std"] }
tempfile = "3"

View file

@ -6,33 +6,18 @@ extern crate winapi;
use std::{
collections::HashMap,
ffi::{c_void, CStr},
mem,
io, mem,
os::raw::c_uint,
ptr, slice, usize,
};
use detours_sys::{
DetourAttach, DetourEnumerateExports, DetourRestoreAfterWith, DetourTransactionAbort,
DetourTransactionBegin, DetourTransactionCommit, DetourUpdateProcessWithDll,
DetourUpdateThread,
DetourAttach, DetourEnumerateExports, DetourGetEntryPoint, DetourRestoreAfterWith,
DetourTransactionAbort, DetourTransactionBegin, DetourTransactionCommit,
DetourUpdateProcessWithDll, DetourUpdateThread,
};
use tempfile::TempDir;
use wchar::wch;
use winapi::{
shared::minwindef::{BOOL, LPVOID},
um::{
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
minwinbase::LPSECURITY_ATTRIBUTES,
processthreadsapi::{
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW,
},
tlhelp32::{
CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32,
},
winbase::CREATE_SUSPENDED,
winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME},
},
};
use winapi::{
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
@ -50,6 +35,26 @@ use winapi::{
shared::winerror::NO_ERROR,
um::libloaderapi::{LoadLibraryA, LoadLibraryExW, LoadLibraryW},
};
use winapi::{
shared::{
minwindef::{BOOL, LPVOID},
winerror::E_UNEXPECTED,
},
um::{
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
libloaderapi::GetModuleHandleW,
minwinbase::LPSECURITY_ATTRIBUTES,
processthreadsapi::{
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
SuspendThread, TerminateProcess, LPPROCESS_INFORMATION, LPSTARTUPINFOA, LPSTARTUPINFOW,
},
tlhelp32::{
CreateToolhelp32Snapshot, Thread32First, Thread32Next, TH32CS_SNAPTHREAD, THREADENTRY32,
},
winbase::{CopyFileW, CreateSymbolicLinkW, CREATE_SUSPENDED},
winnt::{LPSTR, LPWSTR, THREAD_SUSPEND_RESUME},
},
};
include!("payload_guid.rs");
@ -375,6 +380,59 @@ unsafe extern "system" fn ZludaCreateProcessWithTokenW(
continue_create_process_hook(create_proc_result, dwCreationFlags, lpProcessInformation)
}
static mut MAIN: unsafe extern "system" fn() -> DWORD = ZludaMain;
// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#search-order-for-desktop-applications
// "If a DLL with the same module name is already loaded in memory, the system
// uses the loaded DLL, no matter which directory it is in. The system does not
// search for the DLL."
#[allow(non_snake_case)]
unsafe extern "system" fn ZludaMain() -> DWORD {
let temp_dir = match do_zluda_preload() {
Ok(f) => f,
Err(e) => return e.raw_os_error().unwrap_or(E_UNEXPECTED) as u32,
};
let result = MAIN();
drop(temp_dir);
result
}
unsafe fn do_zluda_preload() -> std::io::Result<TempDir> {
let temp_dir = tempfile::tempdir()?;
do_single_zluda_preload(&temp_dir, ZLUDA_PATH_UTF16.unwrap().as_ptr(), NVCUDA_UTF8)?;
do_single_zluda_preload(&temp_dir, ZLUDA_ML_PATH_UTF16.unwrap().as_ptr(), NVML_UTF8)?;
Ok(temp_dir)
}
unsafe fn do_single_zluda_preload(
temp_dir: &TempDir,
full_path: *const u16,
file_name: &'static str,
) -> io::Result<()> {
let mut temp_file_path = temp_dir.path().to_path_buf();
temp_file_path.push(file_name);
let mut temp_file_path_utf16 = temp_file_path
.into_os_string()
.to_string_lossy()
.encode_utf16()
.collect::<Vec<_>>();
temp_file_path_utf16.push(0);
// Probably we are not in developer mode, do a copty then
if 0 == CreateSymbolicLinkW(
temp_file_path_utf16.as_ptr(),
full_path,
0x2, //SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE
) {
if 0 == CopyFileW(full_path, temp_file_path_utf16.as_ptr(), 1) {
return Err(io::Error::last_os_error());
}
}
if ptr::null_mut() == ZludaLoadLibraryW_NoRedirect(temp_file_path_utf16.as_ptr()) {
return Err(io::Error::last_os_error());
}
Ok(())
}
// This type encapsulates typical calling sequence of detours and cleanup.
// We have two ways we do detours:
// * If we are loaded before nvcuda.dll, we hook LoadLibrary*
@ -668,8 +726,8 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
// redirecting LoadLibrary* to load ZLUDA, we override already loaded
// functions
let detach_guard = match get_cuinit() {
Some((nvcuda_mod, _)) => attach_cuinit(nvcuda_mod),
None => attach_load_libary(),
Some((nvcuda_mod, _)) => detour_already_loaded_nvcuda(nvcuda_mod),
None => detour_main(),
};
match detach_guard {
Some(g) => {
@ -724,7 +782,7 @@ unsafe fn get_cuinit() -> Option<(HMODULE, FARPROC)> {
}
#[must_use]
unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
unsafe fn detour_already_loaded_nvcuda(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
let zluda_module = LoadLibraryW(ZLUDA_PATH_UTF16.unwrap().as_ptr());
if zluda_module == ptr::null_mut() {
return None;
@ -747,7 +805,22 @@ unsafe fn attach_cuinit(nvcuda_mod: HMODULE) -> Option<DetourDetachGuard> {
(original_fn_address as _, override_fn_address),
);
}
DetourDetachGuard::detour_functions(nvcuda_mod, Vec::new(), override_fn_pairs)
let detour_functions = vec![
(
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
ZludaLoadLibraryA as *mut c_void,
),
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
(
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
ZludaLoadLibraryExA as _,
),
(
&mut LOAD_LIBRARY_EX_W as *mut _ as _,
ZludaLoadLibraryExW as _,
),
];
DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, override_fn_pairs)
}
unsafe extern "system" fn cuda_unsupported() -> c_uint {
@ -776,8 +849,18 @@ unsafe extern "stdcall" fn gather_imports_impl(
}
#[must_use]
unsafe fn attach_load_libary() -> Option<DetourDetachGuard> {
unsafe fn detour_main() -> Option<DetourDetachGuard> {
let exe_handle = GetModuleHandleW(ptr::null());
let entry_point = DetourGetEntryPoint(exe_handle as _);
if entry_point == ptr::null_mut() {
return None;
}
MAIN = mem::transmute(entry_point);
let detour_functions = vec![
(
&mut MAIN as *mut _ as *mut *mut c_void,
ZludaMain as *mut c_void,
),
(
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
ZludaLoadLibraryA as *mut c_void,