mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-10-04 07:09:55 +00:00
Overhaul DLL injection
This commit is contained in:
parent
c869a0d611
commit
2753d956df
13 changed files with 295 additions and 430 deletions
|
@ -10,6 +10,8 @@ path = "src/main.rs"
|
|||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "synchapi", "winbase", "std"] }
|
||||
tempfile = "3"
|
||||
argh = "0.1"
|
||||
detours-sys = { path = "../detours-sys" }
|
||||
|
||||
[dev-dependencies]
|
||||
|
|
|
@ -43,6 +43,8 @@ fn main() -> Result<(), VarError> {
|
|||
.arg("-ldylib=nvcuda")
|
||||
.arg("-C")
|
||||
.arg(format!("opt-level={}", opt_level))
|
||||
.arg("-L")
|
||||
.arg(format!("{}", out_dir))
|
||||
.arg("--out-dir")
|
||||
.arg(format!("{}", out_dir))
|
||||
.arg("--target")
|
||||
|
@ -52,11 +54,11 @@ fn main() -> Result<(), VarError> {
|
|||
}
|
||||
std::fs::copy(
|
||||
format!(
|
||||
"{}{}do_cuinit_main_clr.exe",
|
||||
"{}{}do_cuinit_late_clr.exe",
|
||||
helpers_dir_as_string,
|
||||
path::MAIN_SEPARATOR
|
||||
),
|
||||
format!("{}{}do_cuinit_main_clr.exe", out_dir, path::MAIN_SEPARATOR),
|
||||
format!("{}{}do_cuinit_late_clr.exe", out_dir, path::MAIN_SEPARATOR),
|
||||
)
|
||||
.unwrap();
|
||||
println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir);
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
use std::env;
|
||||
use std::os::windows;
|
||||
use std::os::windows::ffi::OsStrExt;
|
||||
use std::path::Path;
|
||||
use std::ptr;
|
||||
use std::{env, ops::Deref};
|
||||
use std::{error::Error, process};
|
||||
use std::{fs, io, ptr};
|
||||
use std::{mem, path::PathBuf};
|
||||
|
||||
use argh::FromArgs;
|
||||
use mem::size_of_val;
|
||||
use tempfile::TempDir;
|
||||
use winapi::um::processenv::SearchPathW;
|
||||
use winapi::um::{
|
||||
jobapi2::{AssignProcessToJobObject, SetInformationJobObject},
|
||||
processthreadsapi::{GetExitCodeProcess, ResumeThread},
|
||||
|
@ -20,28 +23,46 @@ use winapi::um::{
|
|||
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
|
||||
|
||||
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
|
||||
static ZLUDA_DLL: &'static str = "nvcuda.dll";
|
||||
static ZLUDA_ML_DLL: &'static str = "nvml.dll";
|
||||
static NVCUDA_DLL: &'static str = "nvcuda.dll";
|
||||
static NVML_DLL: &'static str = "nvml.dll";
|
||||
|
||||
include!("../../zluda_redirect/src/payload_guid.rs");
|
||||
|
||||
#[derive(FromArgs)]
|
||||
/// Launch application with custom CUDA libraries
|
||||
struct ProgramArguments {
|
||||
/// DLL to be injected instead of system nvcuda.dll. If not provided {0} will use nvcuda.dll from its directory
|
||||
#[argh(option)]
|
||||
nvcuda: Option<PathBuf>,
|
||||
|
||||
/// DLL to be injected instead of system nvml.dll. If not provided {0} will use nvml.dll from its directory
|
||||
#[argh(option)]
|
||||
nvml: Option<PathBuf>,
|
||||
|
||||
/// executable to be injected with custom CUDA libraries
|
||||
#[argh(positional)]
|
||||
exe: String,
|
||||
|
||||
/// arguments to the executable
|
||||
#[argh(positional)]
|
||||
args: Vec<String>,
|
||||
}
|
||||
|
||||
pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
||||
let args = env::args().collect::<Vec<_>>();
|
||||
if args.len() <= 1 {
|
||||
print_help_and_exit();
|
||||
}
|
||||
let injector_path = env::current_exe()?;
|
||||
let injector_dir = injector_path.parent().unwrap();
|
||||
let redirect_path = create_redirect_path(injector_dir);
|
||||
let (mut inject_nvcuda_path, mut inject_nvml_path, cmd) =
|
||||
create_inject_path(&args[1..], injector_dir)?;
|
||||
let mut cmd_line = construct_command_line(cmd);
|
||||
let raw_args = argh::from_env::<ProgramArguments>();
|
||||
let normalized_args = NormalizedArguments::new(raw_args)?;
|
||||
let mut environment = Environment::setup(normalized_args)?;
|
||||
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
|
||||
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
|
||||
let mut dlls_to_inject = [
|
||||
environment.nvml_path_zero_terminated.as_ptr() as *const i8,
|
||||
environment.nvcuda_path_zero_terminated.as_ptr() as _,
|
||||
environment.redirect_path_zero_terminated.as_ptr() as _,
|
||||
];
|
||||
os_call!(
|
||||
detours_sys::DetourCreateProcessWithDllExW(
|
||||
detours_sys::DetourCreateProcessWithDllsW(
|
||||
ptr::null(),
|
||||
cmd_line.as_mut_ptr(),
|
||||
environment.winapi_command_line_zero_terminated.as_mut_ptr(),
|
||||
ptr::null_mut(),
|
||||
ptr::null_mut(),
|
||||
0,
|
||||
|
@ -50,7 +71,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||
ptr::null(),
|
||||
&mut startup_info as *mut _,
|
||||
&mut proc_info as *mut _,
|
||||
redirect_path.as_ptr() as *const i8,
|
||||
dlls_to_inject.len() as u32,
|
||||
dlls_to_inject.as_mut_ptr(),
|
||||
Option::None
|
||||
),
|
||||
|x| x != 0
|
||||
|
@ -60,8 +82,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||
detours_sys::DetourCopyPayloadToProcess(
|
||||
proc_info.hProcess,
|
||||
&PAYLOAD_NVCUDA_GUID,
|
||||
inject_nvcuda_path.as_mut_ptr() as *mut _,
|
||||
(inject_nvcuda_path.len() * mem::size_of::<u16>()) as u32
|
||||
environment.nvcuda_path_zero_terminated.as_ptr() as *mut _,
|
||||
environment.nvcuda_path_zero_terminated.len() as u32
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
|
@ -69,8 +91,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||
detours_sys::DetourCopyPayloadToProcess(
|
||||
proc_info.hProcess,
|
||||
&PAYLOAD_NVML_GUID,
|
||||
inject_nvml_path.as_mut_ptr() as *mut _,
|
||||
(inject_nvml_path.len() * mem::size_of::<u16>()) as u32
|
||||
environment.nvml_path_zero_terminated.as_ptr() as *mut _,
|
||||
environment.nvml_path_zero_terminated.len() as u32
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
|
@ -85,6 +107,135 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||
process::exit(child_exit_code as i32)
|
||||
}
|
||||
|
||||
struct NormalizedArguments {
|
||||
nvml_path: PathBuf,
|
||||
nvcuda_path: PathBuf,
|
||||
redirect_path: PathBuf,
|
||||
winapi_command_line_zero_terminated: Vec<u16>,
|
||||
}
|
||||
|
||||
impl NormalizedArguments {
|
||||
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
|
||||
let current_exe = env::current_exe()?;
|
||||
let nvml_path = Self::get_absolute_path(¤t_exe, prog_args.nvml, NVML_DLL)?;
|
||||
let nvcuda_path = Self::get_absolute_path(¤t_exe, prog_args.nvcuda, NVCUDA_DLL)?;
|
||||
let winapi_command_line_zero_terminated =
|
||||
construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args));
|
||||
let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
|
||||
redirect_path.push(REDIRECT_DLL);
|
||||
Ok(Self {
|
||||
nvml_path,
|
||||
nvcuda_path,
|
||||
redirect_path,
|
||||
winapi_command_line_zero_terminated,
|
||||
})
|
||||
}
|
||||
|
||||
const WIN_MAX_PATH: usize = 260;
|
||||
|
||||
fn get_absolute_path(
|
||||
current_exe: &PathBuf,
|
||||
dll: Option<PathBuf>,
|
||||
default: &str,
|
||||
) -> Result<PathBuf, Box<dyn Error>> {
|
||||
Ok(if let Some(dll) = dll {
|
||||
if dll.is_absolute() {
|
||||
dll
|
||||
} else {
|
||||
let mut full_dll_path = vec![0; Self::WIN_MAX_PATH];
|
||||
let mut dll_utf16 = dll.as_os_str().encode_wide().collect::<Vec<_>>();
|
||||
dll_utf16.push(0);
|
||||
loop {
|
||||
let copied_len = os_call!(
|
||||
SearchPathW(
|
||||
ptr::null_mut(),
|
||||
dll_utf16.as_ptr(),
|
||||
ptr::null(),
|
||||
full_dll_path.len() as u32,
|
||||
full_dll_path.as_mut_ptr(),
|
||||
ptr::null_mut()
|
||||
),
|
||||
|x| x != 0
|
||||
) as usize;
|
||||
if copied_len > full_dll_path.len() {
|
||||
full_dll_path.resize(copied_len + 1, 0);
|
||||
} else {
|
||||
full_dll_path.truncate(copied_len);
|
||||
break;
|
||||
}
|
||||
}
|
||||
PathBuf::from(String::from_utf16_lossy(&full_dll_path))
|
||||
}
|
||||
} else {
|
||||
let mut dll_path = current_exe.parent().unwrap().to_path_buf();
|
||||
dll_path.push(default);
|
||||
dll_path
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct Environment {
|
||||
nvml_path_zero_terminated: String,
|
||||
nvcuda_path_zero_terminated: String,
|
||||
redirect_path_zero_terminated: String,
|
||||
winapi_command_line_zero_terminated: Vec<u16>,
|
||||
_temp_dir: TempDir,
|
||||
}
|
||||
|
||||
// This structs represents "enviroment". By environment we mean all paths
|
||||
// (nvcuda.dll, nvml.dll, etc.) and all related resources like the temporary
|
||||
// directory which contains nvcuda.dll
|
||||
impl Environment {
|
||||
fn setup(args: NormalizedArguments) -> io::Result<Self> {
|
||||
let _temp_dir = TempDir::new()?;
|
||||
let nvml_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||
args.nvml_path,
|
||||
&_temp_dir,
|
||||
NVML_DLL,
|
||||
)?);
|
||||
let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||
args.nvcuda_path,
|
||||
&_temp_dir,
|
||||
NVCUDA_DLL,
|
||||
)?);
|
||||
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
|
||||
Ok(Self {
|
||||
nvml_path_zero_terminated,
|
||||
nvcuda_path_zero_terminated,
|
||||
redirect_path_zero_terminated,
|
||||
winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated,
|
||||
_temp_dir,
|
||||
})
|
||||
}
|
||||
|
||||
fn copy_to_correct_name(
|
||||
path_buf: PathBuf,
|
||||
temp_dir: &TempDir,
|
||||
correct_name: &str,
|
||||
) -> io::Result<PathBuf> {
|
||||
let file_name = path_buf.file_name().unwrap();
|
||||
if file_name == correct_name {
|
||||
Ok(path_buf)
|
||||
} else {
|
||||
let mut temp_file_path = temp_dir.path().to_path_buf();
|
||||
temp_file_path.push(correct_name);
|
||||
match windows::fs::symlink_file(&path_buf, &temp_file_path) {
|
||||
Ok(()) => {}
|
||||
Err(_) => {
|
||||
fs::copy(&path_buf, &temp_file_path)?;
|
||||
}
|
||||
}
|
||||
Ok(temp_file_path)
|
||||
}
|
||||
}
|
||||
|
||||
fn zero_terminate(p: PathBuf) -> String {
|
||||
let mut s = p.to_string_lossy().to_string();
|
||||
s.push('\0');
|
||||
s
|
||||
}
|
||||
}
|
||||
|
||||
fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
|
||||
let job_handle = os_call!(CreateJobObjectA(ptr::null_mut(), ptr::null()), |x| x
|
||||
!= ptr::null_mut());
|
||||
|
@ -103,29 +254,11 @@ fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn print_help_and_exit() -> ! {
|
||||
let current_exe = env::current_exe().unwrap();
|
||||
let exe_name = current_exe.file_name().unwrap().to_string_lossy();
|
||||
println!(
|
||||
"USAGE:
|
||||
{0} -- <EXE> [ARGS]...
|
||||
{0} <DLL> -- <EXE> [ARGS]...
|
||||
ARGS:
|
||||
<DLL> DLL to be injected instead of system nvcuda.dll, if not provided
|
||||
will use nvcuda.dll from the directory where {0} is located
|
||||
<EXE> Path to the executable to be injected with <DLL>
|
||||
<ARGS>... Arguments that will be passed to <EXE>
|
||||
",
|
||||
exe_name
|
||||
);
|
||||
process::exit(1)
|
||||
}
|
||||
|
||||
// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
|
||||
fn construct_command_line(args: &[String]) -> Vec<u16> {
|
||||
fn construct_command_line(args: impl Iterator<Item = String>) -> Vec<u16> {
|
||||
let mut cmd_line = Vec::new();
|
||||
let args_len = args.len();
|
||||
for (idx, arg) in args.iter().enumerate() {
|
||||
let args_len = args.size_hint().0;
|
||||
for (idx, arg) in args.enumerate() {
|
||||
if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) {
|
||||
cmd_line.extend(arg.encode_utf16());
|
||||
} else {
|
||||
|
@ -176,55 +309,3 @@ fn construct_command_line(args: &[String]) -> Vec<u16> {
|
|||
cmd_line.push(0);
|
||||
cmd_line
|
||||
}
|
||||
|
||||
fn create_redirect_path(injector_dir: &Path) -> Vec<u8> {
|
||||
let mut injector_dir = injector_dir.to_path_buf();
|
||||
injector_dir.push(REDIRECT_DLL);
|
||||
let mut result = injector_dir.to_string_lossy().into_owned().into_bytes();
|
||||
result.push(0);
|
||||
result
|
||||
}
|
||||
|
||||
fn create_inject_path<'a>(
|
||||
args: &'a [String],
|
||||
injector_dir: &Path,
|
||||
) -> std::io::Result<(Vec<u16>, Vec<u16>, &'a [String])> {
|
||||
let injector_dir = injector_dir.to_path_buf();
|
||||
let (nvcuda_path, unparsed_args) = if args.get(0).map(Deref::deref) == Some("--") {
|
||||
(
|
||||
encode_file_in_directory_raw(injector_dir.clone(), ZLUDA_DLL),
|
||||
&args[1..],
|
||||
)
|
||||
} else if args.get(1).map(Deref::deref) == Some("--") {
|
||||
let dll_path = make_absolute_and_encode(&args[0])?;
|
||||
(dll_path, &args[2..])
|
||||
} else {
|
||||
print_help_and_exit()
|
||||
};
|
||||
let nvml_path = encode_file_in_directory_raw(injector_dir, ZLUDA_ML_DLL);
|
||||
Ok((nvcuda_path, nvml_path, unparsed_args))
|
||||
}
|
||||
|
||||
fn encode_file_in_directory_raw(mut dir: PathBuf, file: &'static str) -> Vec<u16> {
|
||||
dir.push(file);
|
||||
let mut result = dir
|
||||
.to_string_lossy()
|
||||
.as_ref()
|
||||
.encode_utf16()
|
||||
.collect::<Vec<_>>();
|
||||
result.push(0);
|
||||
result
|
||||
}
|
||||
|
||||
fn make_absolute_and_encode(maybe_path: &str) -> std::io::Result<Vec<u16>> {
|
||||
let path = Path::new(maybe_path);
|
||||
let mut encoded_path = if path.is_relative() {
|
||||
let mut current_dir = env::current_dir()?;
|
||||
current_dir.push(path);
|
||||
current_dir.as_os_str().encode_wide().collect::<Vec<_>>()
|
||||
} else {
|
||||
maybe_path.encode_utf16().collect::<Vec<_>>()
|
||||
};
|
||||
encoded_path.push(0);
|
||||
Ok(encoded_path)
|
||||
}
|
||||
|
|
10
zluda_inject/tests/helpers/do_cuinit_early.rs
Normal file
10
zluda_inject/tests/helpers/do_cuinit_early.rs
Normal file
|
@ -0,0 +1,10 @@
|
|||
#![crate_type = "bin"]
|
||||
|
||||
#[link(name = "do_cuinit")]
|
||||
extern "system" {
|
||||
fn do_cuinit(flags: u32) -> u32;
|
||||
}
|
||||
|
||||
fn main() {
|
||||
unsafe { do_cuinit(0) };
|
||||
}
|
10
zluda_inject/tests/helpers/subprocess.rs
Normal file
10
zluda_inject/tests/helpers/subprocess.rs
Normal file
|
@ -0,0 +1,10 @@
|
|||
#![crate_type = "bin"]
|
||||
|
||||
use std::io;
|
||||
use std::process::Command;
|
||||
|
||||
fn main() -> io::Result<()> {
|
||||
let status = Command::new("direct_cuinit.exe").status()?;
|
||||
assert!(status.success());
|
||||
Ok(())
|
||||
}
|
|
@ -5,19 +5,29 @@ fn direct_cuinit() -> io::Result<()> {
|
|||
run_process_and_check_for_zluda_dump("direct_cuinit")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn do_cuinit_early() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("do_cuinit_early")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn do_cuinit_late() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("do_cuinit_late")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn do_cuinit_late_clr() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("do_cuinit_late_clr")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn indirect_cuinit() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("indirect_cuinit")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn do_cuinit() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("do_cuinit_main")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn do_cuinit_clr() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("do_cuinit_main_clr")
|
||||
fn subprocess() -> io::Result<()> {
|
||||
run_process_and_check_for_zluda_dump("subprocess")
|
||||
}
|
||||
|
||||
fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
|
||||
|
@ -27,7 +37,11 @@ fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
|
|||
let helpers_dir = env!("HELPERS_OUT_DIR");
|
||||
let exe_under_test = format!("{}{}{}.exe", helpers_dir, std::path::MAIN_SEPARATOR, name);
|
||||
let mut test_cmd = Command::new(&zluda_with_exe);
|
||||
let test_cmd = test_cmd.arg(&zluda_dump_dll).arg("--").arg(&exe_under_test);
|
||||
let test_cmd = test_cmd
|
||||
.arg("--nvcuda")
|
||||
.arg(&zluda_dump_dll)
|
||||
.arg("--")
|
||||
.arg(&exe_under_test);
|
||||
let test_output = test_cmd.output()?;
|
||||
assert!(test_output.status.success());
|
||||
let stderr_text = String::from_utf8(test_output.stderr).unwrap();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue