mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
Add a library for dumping kernels arguments before and after launch (#18)
This commit is contained in:
parent
09f679693b
commit
ff8135e8a3
16 changed files with 4951 additions and 66 deletions
2
.github/workflows/rust.yml
vendored
2
.github/workflows/rust.yml
vendored
|
@ -48,7 +48,7 @@ jobs:
|
|||
sudo apt update
|
||||
sudo apt install ocl-icd-opencl-dev
|
||||
- name: Build
|
||||
run: cargo build --verbose
|
||||
run: cargo build --workspace --verbose
|
||||
# TODO(take-cheeze): Support testing
|
||||
# - name: Run tests
|
||||
# run: cargo test --verbose
|
||||
|
|
|
@ -6,6 +6,7 @@ members = [
|
|||
"level_zero",
|
||||
"spirv_tools-sys",
|
||||
"zluda",
|
||||
"zluda_dump",
|
||||
"zluda_lib",
|
||||
"zluda_inject",
|
||||
"zluda_redirect",
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
use std::error::Error;
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn main() {}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
windows::main_impl()
|
||||
}
|
||||
|
||||
|
@ -37,18 +35,15 @@ mod windows {
|
|||
.try_compile("detours")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(target_env = "msvc")]
|
||||
fn add_target_options(build: &mut cc::Build) -> &mut cc::Build {
|
||||
build
|
||||
}
|
||||
|
||||
#[cfg(not(target_env = "msvc"))]
|
||||
fn add_target_options(build: &mut cc::Build) -> &mut cc::Build {
|
||||
build
|
||||
.compiler("clang")
|
||||
.cpp(true)
|
||||
.flag("-fms-extensions")
|
||||
.flag("-Wno-everything")
|
||||
if std::env::var("CARGO_CFG_TARGET_ENV").unwrap() != "msvc" {
|
||||
build
|
||||
.compiler("clang")
|
||||
.cpp(true)
|
||||
.flag("-fms-extensions")
|
||||
.flag("-Wno-everything")
|
||||
} else {
|
||||
build
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -294,6 +294,17 @@ pub enum KernelArgumentType {
|
|||
Shared,
|
||||
}
|
||||
|
||||
impl From<KernelArgumentType> for Type {
|
||||
fn from(this: KernelArgumentType) -> Self {
|
||||
match this {
|
||||
KernelArgumentType::Normal(typ) => typ.into(),
|
||||
KernelArgumentType::Shared => {
|
||||
Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FnArgumentType {
|
||||
pub fn to_type(&self, is_kernel: bool) -> Type {
|
||||
if is_kernel {
|
||||
|
|
|
@ -6039,7 +6039,7 @@ impl ast::Type {
|
|||
}
|
||||
}
|
||||
|
||||
fn size_of(&self) -> usize {
|
||||
pub fn size_of(&self) -> usize {
|
||||
match self {
|
||||
ast::Type::Scalar(typ) => typ.size_of() as usize,
|
||||
ast::Type::Vector(typ, len) => (typ.size_of() as usize) * (*len as usize),
|
||||
|
@ -6253,18 +6253,6 @@ impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<ast::KernelArgumentType> for ast::Type {
|
||||
fn from(this: ast::KernelArgumentType) -> Self {
|
||||
match this {
|
||||
ast::KernelArgumentType::Normal(typ) => typ.into(),
|
||||
ast::KernelArgumentType::Shared => ast::Type::Pointer(
|
||||
ast::PointerType::Scalar(ast::ScalarType::B8),
|
||||
ast::LdStateSpace::Shared,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ArgParamsEx> ast::Arg1<T> {
|
||||
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
||||
self,
|
||||
|
|
|
@ -13,7 +13,7 @@ level_zero = { path = "../level_zero" }
|
|||
level_zero-sys = { path = "../level_zero-sys" }
|
||||
lazy_static = "1.4"
|
||||
num_enum = "0.4"
|
||||
lz4 = "1.23"
|
||||
lz4-sys = "1.9"
|
||||
|
||||
[dev-dependencies]
|
||||
cuda-driver-sys = "0.3.0"
|
||||
|
|
|
@ -5,12 +5,12 @@ use crate::{
|
|||
};
|
||||
|
||||
use super::{context, context::ContextData, device, module, Decuda, Encuda, GlobalState};
|
||||
use std::mem;
|
||||
use std::os::raw::{c_uint, c_ulong, c_ushort};
|
||||
use std::{
|
||||
ffi::{c_void, CStr},
|
||||
ptr, slice,
|
||||
ptr,
|
||||
};
|
||||
use std::{mem, os::raw::c_int};
|
||||
|
||||
pub fn get(table: *mut *const std::os::raw::c_void, id: *const CUuuid) -> CUresult {
|
||||
if table == ptr::null_mut() || id == ptr::null_mut() {
|
||||
|
@ -177,6 +177,7 @@ const FATBIN_FILE_HEADER_VERSION_CURRENT: c_ushort = 0x101;
|
|||
|
||||
// assembly file header is a bit different, but we don't care
|
||||
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
struct FatbinFileHeader {
|
||||
kind: c_ushort,
|
||||
version: c_ushort,
|
||||
|
@ -221,12 +222,10 @@ unsafe extern "C" fn get_module_from_cubin(
|
|||
let mut ptx_files = get_ptx_files(file, end);
|
||||
ptx_files.sort_unstable_by_key(|f| c_uint::max_value() - (**f).sm_version);
|
||||
for file in ptx_files {
|
||||
let slice = slice::from_raw_parts(
|
||||
(file as *const u8).add((*file).header_size as usize),
|
||||
(*file).payload_size as usize,
|
||||
);
|
||||
let kernel_text =
|
||||
lz4::block::decompress(slice, Some((*file).uncompressed_payload as i32)).unwrap();
|
||||
let kernel_text = match decompress_kernel_module(file) {
|
||||
None => continue,
|
||||
Some(vec) => vec,
|
||||
};
|
||||
let kernel_text_string = match CStr::from_bytes_with_nul(&kernel_text) {
|
||||
Ok(c_str) => match c_str.to_str() {
|
||||
Ok(s) => s,
|
||||
|
@ -264,6 +263,33 @@ unsafe fn get_ptx_files(file: *const u8, end: *const u8) -> Vec<*const FatbinFil
|
|||
result
|
||||
}
|
||||
|
||||
const MAX_PTX_MODULE_DECOMPRESSION_BOUND: usize = 16 * 1024 * 1024;
|
||||
|
||||
unsafe fn decompress_kernel_module(file: *const FatbinFileHeader) -> Option<Vec<u8>> {
|
||||
let decompressed_size = usize::max(1024, (*file).uncompressed_payload as usize);
|
||||
let mut decompressed_vec = vec![0u8; decompressed_size];
|
||||
loop {
|
||||
match lz4_sys::LZ4_decompress_safe(
|
||||
(file as *const u8).add((*file).header_size as usize) as *const _,
|
||||
decompressed_vec.as_mut_ptr() as *mut _,
|
||||
(*file).payload_size as c_int,
|
||||
decompressed_vec.len() as c_int,
|
||||
) {
|
||||
error if error < 0 => {
|
||||
let new_size = decompressed_vec.len() * 2;
|
||||
if new_size > MAX_PTX_MODULE_DECOMPRESSION_BOUND {
|
||||
return None;
|
||||
}
|
||||
decompressed_vec.resize(decompressed_vec.len() * 2, 0);
|
||||
}
|
||||
real_decompressed_size => {
|
||||
decompressed_vec.truncate(real_decompressed_size as usize);
|
||||
return Some(decompressed_vec);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "C" fn cudart_interface_fn6(_: u64) {}
|
||||
|
||||
const TOOLS_TLS_GUID: CUuuid = CUuuid {
|
||||
|
|
|
@ -4,7 +4,6 @@ extern crate level_zero_sys as l0_sys;
|
|||
extern crate lazy_static;
|
||||
#[cfg(test)]
|
||||
extern crate cuda_driver_sys;
|
||||
extern crate lz4;
|
||||
#[cfg(test)]
|
||||
#[macro_use]
|
||||
extern crate paste;
|
||||
|
|
22
zluda_dump/Cargo.toml
Normal file
22
zluda_dump/Cargo.toml
Normal file
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "zluda_dump"
|
||||
version = "0.0.0"
|
||||
authors = ["Andrzej Janik <vosen@vosen.pl>"]
|
||||
edition = "2018"
|
||||
|
||||
[lib]
|
||||
name = "zluda_dump"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
ptx = { path = "../ptx" }
|
||||
lz4-sys = "1.9"
|
||||
regex = "1.4"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
winapi = { version = "0.3", features = ["libloaderapi", "debugapi"] }
|
||||
wchar = "0.6"
|
||||
detours-sys = { path = "../detours-sys" }
|
||||
|
||||
[target.'cfg(not(windows))'.dependencies]
|
||||
libc = "0.2"
|
4072
zluda_dump/src/cuda.rs
Normal file
4072
zluda_dump/src/cuda.rs
Normal file
File diff suppressed because it is too large
Load diff
670
zluda_dump/src/lib.rs
Normal file
670
zluda_dump/src/lib.rs
Normal file
|
@ -0,0 +1,670 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
env,
|
||||
error::Error,
|
||||
ffi::{c_void, CStr},
|
||||
fs,
|
||||
io::prelude::*,
|
||||
mem,
|
||||
os::raw::{c_int, c_uint, c_ulong, c_ushort},
|
||||
path::PathBuf,
|
||||
rc::Rc,
|
||||
slice,
|
||||
};
|
||||
use std::{fs::File, ptr};
|
||||
|
||||
use cuda::{CUdeviceptr, CUfunction, CUjit_option, CUmodule, CUresult, CUstream, CUuuid};
|
||||
use ptx::ast;
|
||||
use regex::Regex;
|
||||
|
||||
#[cfg_attr(windows, path = "os_win.rs")]
|
||||
#[cfg_attr(not(windows), path = "os_unix.rs")]
|
||||
mod os;
|
||||
|
||||
macro_rules! extern_redirect {
|
||||
(pub fn $fn_name:ident ( $($arg_id:ident: $arg_type:ty),* $(,)? ) -> $ret_type:ty ;) => {
|
||||
#[no_mangle]
|
||||
pub fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
||||
unsafe { $crate::init_libcuda_handle() };
|
||||
let name = std::ffi::CString::new(stringify!($fn_name)).unwrap();
|
||||
let fn_ptr = unsafe { crate::os::get_proc_address($crate::LIBCUDA_HANDLE, &name) };
|
||||
if fn_ptr == std::ptr::null_mut() {
|
||||
return CUresult::CUDA_ERROR_UNKNOWN;
|
||||
}
|
||||
let typed_fn = unsafe { std::mem::transmute::<_, fn( $( $arg_id : $arg_type),* ) -> $ret_type>(fn_ptr) };
|
||||
typed_fn($( $arg_id ),*)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! extern_redirect_with {
|
||||
(
|
||||
pub fn $fn_name:ident ( $($arg_id:ident: $arg_type:ty),* $(,)? ) -> $ret_type:ty ;
|
||||
$receiver:path ;
|
||||
) => {
|
||||
#[no_mangle]
|
||||
pub fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
||||
unsafe { $crate::init_libcuda_handle() };
|
||||
let continuation = |$( $arg_id : $arg_type),* | {
|
||||
let name = std::ffi::CString::new(stringify!($fn_name)).unwrap();
|
||||
let fn_ptr = unsafe { crate::os::get_proc_address($crate::LIBCUDA_HANDLE, &name) };
|
||||
if fn_ptr == std::ptr::null_mut() {
|
||||
return CUresult::CUDA_ERROR_UNKNOWN;
|
||||
}
|
||||
let typed_fn = unsafe { std::mem::transmute::<_, fn( $( $arg_id : $arg_type),* ) -> $ret_type>(fn_ptr) };
|
||||
typed_fn($( $arg_id ),*)
|
||||
};
|
||||
unsafe { $receiver($( $arg_id ),* , continuation) }
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(warnings)]
|
||||
mod cuda;
|
||||
|
||||
pub static mut LIBCUDA_HANDLE: *mut c_void = ptr::null_mut();
|
||||
pub static mut MODULES: Option<HashMap<CUmodule, ModuleDump>> = None;
|
||||
pub static mut KERNELS: Option<HashMap<CUfunction, KernelDump>> = None;
|
||||
pub static mut BUFFERS: Vec<(usize, usize)> = Vec::new();
|
||||
pub static mut LAUNCH_COUNTER: usize = 0;
|
||||
pub static mut KERNEL_PATTERN: Option<Regex> = None;
|
||||
|
||||
pub struct ModuleDump {
|
||||
content: Rc<String>,
|
||||
kernels_args: HashMap<String, Vec<usize>>,
|
||||
}
|
||||
|
||||
pub struct KernelDump {
|
||||
module_content: Rc<String>,
|
||||
name: String,
|
||||
arguments: Vec<usize>,
|
||||
}
|
||||
|
||||
// We are doing dlopen here instead of just using LD_PRELOAD,
|
||||
// it's because CUDA Runtime API does dlopen to open libcuda.so, which ignores LD_PRELOAD
|
||||
pub unsafe fn init_libcuda_handle() {
|
||||
if LIBCUDA_HANDLE == ptr::null_mut() {
|
||||
let libcuda_handle = os::load_cuda_library();
|
||||
assert_ne!(libcuda_handle, ptr::null_mut());
|
||||
LIBCUDA_HANDLE = libcuda_handle;
|
||||
match env::var("ZLUDA_DUMP_KERNEL") {
|
||||
Ok(kernel_filter) => match Regex::new(&kernel_filter) {
|
||||
Ok(r) => KERNEL_PATTERN = Some(r),
|
||||
Err(err) => {
|
||||
eprintln!(
|
||||
"[ZLUDA_DUMP] Env variable ZLUDA_DUMP_KERNEL is not a regex: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
},
|
||||
Err(_) => (),
|
||||
}
|
||||
eprintln!("[ZLUDA_DUMP] Initialized");
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
pub unsafe fn cuModuleLoadData(
|
||||
module: *mut CUmodule,
|
||||
raw_image: *const ::std::os::raw::c_void,
|
||||
cont: impl FnOnce(*mut CUmodule, *const c_void) -> CUresult,
|
||||
) -> CUresult {
|
||||
let result = cont(module, raw_image);
|
||||
if result == CUresult::CUDA_SUCCESS {
|
||||
record_module_image_raw(*module, raw_image);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
unsafe fn record_module_image_raw(module: CUmodule, raw_image: *const ::std::os::raw::c_void) {
|
||||
let image = to_str(raw_image);
|
||||
match image {
|
||||
None => eprintln!("[ZLUDA_DUMP] Malformed module image: {:?}", raw_image),
|
||||
Some(image) => record_module_image(module, image),
|
||||
};
|
||||
}
|
||||
|
||||
unsafe fn record_module_image(module: CUmodule, image: &str) {
|
||||
if !image.contains(&".address_size") {
|
||||
eprintln!("[ZLUDA_DUMP] Malformed module image: {:?}", module)
|
||||
} else {
|
||||
let mut errors = Vec::new();
|
||||
let ast = ptx::ModuleParser::new().parse(&mut errors, image);
|
||||
match (&*errors, ast) {
|
||||
(&[], Ok(ast)) => {
|
||||
let kernels_args = ast
|
||||
.directives
|
||||
.iter()
|
||||
.filter_map(directive_to_kernel)
|
||||
.collect::<HashMap<_, _>>();
|
||||
let modules = MODULES.get_or_insert_with(|| HashMap::new());
|
||||
modules.insert(
|
||||
module,
|
||||
ModuleDump {
|
||||
content: Rc::new(image.to_string()),
|
||||
kernels_args,
|
||||
},
|
||||
);
|
||||
}
|
||||
(errs, ast) => {
|
||||
let err_string = errs
|
||||
.iter()
|
||||
.map(|e| format!("{:?}", e))
|
||||
.chain(ast.err().iter().map(|e| format!("{:?}", e)))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
eprintln!(
|
||||
"[ZLUDA_DUMP] Errors when parsing module:\n---ERRORS---\n{}\n---MODULE---\n{}",
|
||||
err_string, image
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn to_str<T>(image: *const T) -> Option<&'static str> {
|
||||
let ptr = image as *const u8;
|
||||
let mut offset = 0;
|
||||
loop {
|
||||
let c = *ptr.add(offset);
|
||||
if !c.is_ascii() {
|
||||
return None;
|
||||
}
|
||||
if c == 0 {
|
||||
return Some(std::str::from_utf8_unchecked(slice::from_raw_parts(
|
||||
ptr, offset,
|
||||
)));
|
||||
}
|
||||
offset += 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn directive_to_kernel(dir: &ast::Directive<ast::ParsedArgParams>) -> Option<(String, Vec<usize>)> {
|
||||
match dir {
|
||||
ast::Directive::Method(ast::Function {
|
||||
func_directive: ast::MethodDecl::Kernel { name, in_args },
|
||||
..
|
||||
}) => {
|
||||
let arg_sizes = in_args
|
||||
.iter()
|
||||
.map(|arg| ast::Type::from(arg.v_type.clone()).size_of())
|
||||
.collect();
|
||||
Some((name.to_string(), arg_sizes))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
pub unsafe fn cuModuleLoadDataEx(
|
||||
module: *mut CUmodule,
|
||||
image: *const c_void,
|
||||
numOptions: c_uint,
|
||||
options: *mut CUjit_option,
|
||||
optionValues: *mut *mut c_void,
|
||||
cont: impl FnOnce(
|
||||
*mut CUmodule,
|
||||
*const c_void,
|
||||
c_uint,
|
||||
*mut CUjit_option,
|
||||
*mut *mut c_void,
|
||||
) -> CUresult,
|
||||
) -> CUresult {
|
||||
let result = cont(module, image, numOptions, options, optionValues);
|
||||
if result == CUresult::CUDA_SUCCESS {
|
||||
record_module_image_raw(*module, image);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
unsafe fn cuModuleGetFunction(
|
||||
hfunc: *mut CUfunction,
|
||||
hmod: CUmodule,
|
||||
name: *const ::std::os::raw::c_char,
|
||||
cont: impl FnOnce(*mut CUfunction, CUmodule, *const ::std::os::raw::c_char) -> CUresult,
|
||||
) -> CUresult {
|
||||
let result = cont(hfunc, hmod, name);
|
||||
if result != CUresult::CUDA_SUCCESS {
|
||||
return result;
|
||||
}
|
||||
if let Some(modules) = &MODULES {
|
||||
if let Some(module_dump) = modules.get(&hmod) {
|
||||
if let Some(kernel) = to_str(name) {
|
||||
if let Some(args) = module_dump.kernels_args.get(kernel) {
|
||||
let kernel_args = KERNELS.get_or_insert_with(|| HashMap::new());
|
||||
kernel_args.insert(
|
||||
*hfunc,
|
||||
KernelDump {
|
||||
module_content: module_dump.content.clone(),
|
||||
name: kernel.to_string(),
|
||||
arguments: args.clone(),
|
||||
},
|
||||
);
|
||||
} else {
|
||||
eprintln!("[ZLUDA_DUMP] Unknown kernel: {}", kernel);
|
||||
}
|
||||
} else {
|
||||
eprintln!("[ZLUDA_DUMP] Unknown kernel name at: {:?}", hfunc);
|
||||
}
|
||||
} else {
|
||||
eprintln!("[ZLUDA_DUMP] Unknown module: {:?}", hmod);
|
||||
}
|
||||
} else {
|
||||
eprintln!("[ZLUDA_DUMP] Unknown module: {:?}", hmod);
|
||||
}
|
||||
CUresult::CUDA_SUCCESS
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
pub unsafe fn cuMemAlloc_v2(
|
||||
dptr: *mut CUdeviceptr,
|
||||
bytesize: usize,
|
||||
cont: impl FnOnce(*mut CUdeviceptr, usize) -> CUresult,
|
||||
) -> CUresult {
|
||||
let result = cont(dptr, bytesize);
|
||||
assert_eq!(result, CUresult::CUDA_SUCCESS);
|
||||
let start = (*dptr).0 as usize;
|
||||
BUFFERS.push((start, bytesize));
|
||||
CUresult::CUDA_SUCCESS
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
pub unsafe fn cuLaunchKernel(
|
||||
f: CUfunction,
|
||||
gridDimX: ::std::os::raw::c_uint,
|
||||
gridDimY: ::std::os::raw::c_uint,
|
||||
gridDimZ: ::std::os::raw::c_uint,
|
||||
blockDimX: ::std::os::raw::c_uint,
|
||||
blockDimY: ::std::os::raw::c_uint,
|
||||
blockDimZ: ::std::os::raw::c_uint,
|
||||
sharedMemBytes: ::std::os::raw::c_uint,
|
||||
hStream: CUstream,
|
||||
kernelParams: *mut *mut ::std::os::raw::c_void,
|
||||
extra: *mut *mut ::std::os::raw::c_void,
|
||||
cont: impl FnOnce(
|
||||
CUfunction,
|
||||
::std::os::raw::c_uint,
|
||||
::std::os::raw::c_uint,
|
||||
::std::os::raw::c_uint,
|
||||
::std::os::raw::c_uint,
|
||||
::std::os::raw::c_uint,
|
||||
::std::os::raw::c_uint,
|
||||
::std::os::raw::c_uint,
|
||||
CUstream,
|
||||
*mut *mut ::std::os::raw::c_void,
|
||||
*mut *mut ::std::os::raw::c_void,
|
||||
) -> CUresult,
|
||||
) -> CUresult {
|
||||
let mut error;
|
||||
let dump_env = match create_dump_dir(f, LAUNCH_COUNTER) {
|
||||
Ok(dump_env) => dump_env,
|
||||
Err(err) => {
|
||||
eprintln!("[ZLUDA_DUMP] {:#?}", err);
|
||||
None
|
||||
}
|
||||
};
|
||||
if let Some(dump_env) = &dump_env {
|
||||
dump_pre_data(
|
||||
gridDimX,
|
||||
gridDimY,
|
||||
gridDimZ,
|
||||
blockDimX,
|
||||
blockDimY,
|
||||
blockDimZ,
|
||||
sharedMemBytes,
|
||||
kernelParams,
|
||||
dump_env,
|
||||
)
|
||||
.unwrap_or_else(|err| eprintln!("[ZLUDA_DUMP] {:#?}", err));
|
||||
};
|
||||
error = cont(
|
||||
f,
|
||||
gridDimX,
|
||||
gridDimY,
|
||||
gridDimZ,
|
||||
blockDimX,
|
||||
blockDimY,
|
||||
blockDimZ,
|
||||
sharedMemBytes,
|
||||
hStream,
|
||||
kernelParams,
|
||||
extra,
|
||||
);
|
||||
assert_eq!(error, CUresult::CUDA_SUCCESS);
|
||||
error = cuda::cuStreamSynchronize(hStream);
|
||||
assert_eq!(error, CUresult::CUDA_SUCCESS);
|
||||
if let Some((_, kernel_dump)) = &dump_env {
|
||||
dump_arguments(
|
||||
kernelParams,
|
||||
"post",
|
||||
&kernel_dump.name,
|
||||
LAUNCH_COUNTER,
|
||||
&kernel_dump.arguments,
|
||||
)
|
||||
.unwrap_or_else(|err| eprintln!("[ZLUDA_DUMP] {:#?}", err));
|
||||
}
|
||||
LAUNCH_COUNTER += 1;
|
||||
CUresult::CUDA_SUCCESS
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
fn dump_launch_arguments(
|
||||
gridDimX: u32,
|
||||
gridDimY: u32,
|
||||
gridDimZ: u32,
|
||||
blockDimX: u32,
|
||||
blockDimY: u32,
|
||||
blockDimZ: u32,
|
||||
sharedMemBytes: u32,
|
||||
dump_dir: &PathBuf,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
let mut module_file_path = dump_dir.clone();
|
||||
module_file_path.push("launch.txt");
|
||||
let mut module_file = File::create(module_file_path)?;
|
||||
write!(&mut module_file, "{}\n", gridDimX)?;
|
||||
write!(&mut module_file, "{}\n", gridDimY)?;
|
||||
write!(&mut module_file, "{}\n", gridDimZ)?;
|
||||
write!(&mut module_file, "{}\n", blockDimX)?;
|
||||
write!(&mut module_file, "{}\n", blockDimY)?;
|
||||
write!(&mut module_file, "{}\n", blockDimZ)?;
|
||||
write!(&mut module_file, "{}\n", sharedMemBytes)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe fn should_dump_kernel(name: &str) -> bool {
|
||||
match &KERNEL_PATTERN {
|
||||
Some(pattern) => pattern.is_match(name),
|
||||
None => true,
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn create_dump_dir(
|
||||
f: CUfunction,
|
||||
counter: usize,
|
||||
) -> Result<Option<(PathBuf, &'static KernelDump)>, Box<dyn Error>> {
|
||||
match KERNELS.as_ref().and_then(|kernels| kernels.get(&f)) {
|
||||
Some(kernel_dump) => {
|
||||
if !should_dump_kernel(&kernel_dump.name) {
|
||||
return Ok(None);
|
||||
}
|
||||
let mut dump_dir = get_dump_dir()?;
|
||||
dump_dir.push(format!("{:04}_{}", counter, kernel_dump.name));
|
||||
fs::create_dir_all(&dump_dir)?;
|
||||
Ok(Some((dump_dir, kernel_dump)))
|
||||
}
|
||||
None => Err("Unknown kernel: {:?}")?,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
unsafe fn dump_pre_data(
|
||||
gridDimX: ::std::os::raw::c_uint,
|
||||
gridDimY: ::std::os::raw::c_uint,
|
||||
gridDimZ: ::std::os::raw::c_uint,
|
||||
blockDimX: ::std::os::raw::c_uint,
|
||||
blockDimY: ::std::os::raw::c_uint,
|
||||
blockDimZ: ::std::os::raw::c_uint,
|
||||
sharedMemBytes: ::std::os::raw::c_uint,
|
||||
kernelParams: *mut *mut ::std::os::raw::c_void,
|
||||
(dump_dir, kernel_dump): &(PathBuf, &'static KernelDump),
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
dump_launch_arguments(
|
||||
gridDimX,
|
||||
gridDimY,
|
||||
gridDimZ,
|
||||
blockDimX,
|
||||
blockDimY,
|
||||
blockDimZ,
|
||||
sharedMemBytes,
|
||||
dump_dir,
|
||||
)?;
|
||||
let mut module_file_path = dump_dir.clone();
|
||||
module_file_path.push("module.ptx");
|
||||
let mut module_file = File::create(module_file_path)?;
|
||||
module_file.write_all(kernel_dump.module_content.as_bytes())?;
|
||||
dump_arguments(
|
||||
kernelParams,
|
||||
"pre",
|
||||
&kernel_dump.name,
|
||||
LAUNCH_COUNTER,
|
||||
&kernel_dump.arguments,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe fn dump_arguments(
|
||||
kernel_params: *mut *mut ::std::os::raw::c_void,
|
||||
prefix: &str,
|
||||
kernel_name: &str,
|
||||
counter: usize,
|
||||
args: &[usize],
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
let mut dump_dir = get_dump_dir()?;
|
||||
dump_dir.push(format!("{:04}_{}", counter, kernel_name));
|
||||
dump_dir.push(prefix);
|
||||
if dump_dir.exists() {
|
||||
fs::remove_dir_all(&dump_dir)?;
|
||||
}
|
||||
fs::create_dir_all(&dump_dir)?;
|
||||
for (i, arg_len) in args.iter().enumerate() {
|
||||
let dev_ptr = *(*kernel_params.add(i) as *mut usize);
|
||||
match BUFFERS.iter().find(|(start, _)| *start == dev_ptr as usize) {
|
||||
Some((start, len)) => {
|
||||
let mut output = vec![0u8; *len];
|
||||
let error =
|
||||
cuda::cuMemcpyDtoH_v2(output.as_mut_ptr() as *mut _, CUdeviceptr(*start), *len);
|
||||
assert_eq!(error, CUresult::CUDA_SUCCESS);
|
||||
let mut path = dump_dir.clone();
|
||||
path.push(format!("arg_{:03}.buffer", i));
|
||||
let mut file = File::create(path)?;
|
||||
file.write_all(&mut output)?;
|
||||
}
|
||||
None => {
|
||||
let mut path = dump_dir.clone();
|
||||
path.push(format!("arg_{:03}", i));
|
||||
let mut file = File::create(path)?;
|
||||
file.write_all(slice::from_raw_parts(
|
||||
*kernel_params.add(i) as *mut u8,
|
||||
*arg_len,
|
||||
))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_dump_dir() -> Result<PathBuf, Box<dyn Error>> {
|
||||
let dir = env::var("ZLUDA_DUMP_DIR")?;
|
||||
let mut main_dir = PathBuf::from(dir);
|
||||
let current_exe = env::current_exe()?;
|
||||
main_dir.push(current_exe.file_name().unwrap());
|
||||
fs::create_dir_all(&main_dir)?;
|
||||
Ok(main_dir)
|
||||
}
|
||||
|
||||
// TODO make this more common with ZLUDA implementation
|
||||
const CUDART_INTERFACE_GUID: CUuuid = CUuuid {
|
||||
bytes: [
|
||||
0x6b, 0xd5, 0xfb, 0x6c, 0x5b, 0xf4, 0xe7, 0x4a, 0x89, 0x87, 0xd9, 0x39, 0x12, 0xfd, 0x9d,
|
||||
0xf9,
|
||||
],
|
||||
};
|
||||
|
||||
const GET_MODULE_OFFSET: usize = 6;
|
||||
static mut CUDART_INTERFACE_VTABLE: Vec<*const c_void> = Vec::new();
|
||||
static mut ORIGINAL_GET_MODULE_FROM_CUBIN: Option<
|
||||
unsafe extern "C" fn(
|
||||
result: *mut CUmodule,
|
||||
fatbinc_wrapper: *const FatbincWrapper,
|
||||
ptr1: *mut c_void,
|
||||
ptr2: *mut c_void,
|
||||
) -> CUresult,
|
||||
> = None;
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
pub unsafe fn cuGetExportTable(
|
||||
ppExportTable: *mut *const ::std::os::raw::c_void,
|
||||
pExportTableId: *const CUuuid,
|
||||
cont: impl FnOnce(*mut *const ::std::os::raw::c_void, *const CUuuid) -> CUresult,
|
||||
) -> CUresult {
|
||||
if *pExportTableId == CUDART_INTERFACE_GUID {
|
||||
if CUDART_INTERFACE_VTABLE.len() == 0 {
|
||||
let mut base_table = ptr::null();
|
||||
let base_result = cont(&mut base_table, pExportTableId);
|
||||
if base_result != CUresult::CUDA_SUCCESS {
|
||||
return base_result;
|
||||
}
|
||||
let len = *(base_table as *const usize);
|
||||
CUDART_INTERFACE_VTABLE = vec![ptr::null(); len];
|
||||
ptr::copy_nonoverlapping(
|
||||
base_table as *const _,
|
||||
CUDART_INTERFACE_VTABLE.as_mut_ptr(),
|
||||
len,
|
||||
);
|
||||
if GET_MODULE_OFFSET >= len {
|
||||
return CUresult::CUDA_ERROR_UNKNOWN;
|
||||
}
|
||||
ORIGINAL_GET_MODULE_FROM_CUBIN =
|
||||
mem::transmute(CUDART_INTERFACE_VTABLE[GET_MODULE_OFFSET]);
|
||||
CUDART_INTERFACE_VTABLE[GET_MODULE_OFFSET] = get_module_from_cubin as *const _;
|
||||
}
|
||||
*ppExportTable = CUDART_INTERFACE_VTABLE.as_ptr() as *const _;
|
||||
return CUresult::CUDA_SUCCESS;
|
||||
} else {
|
||||
cont(ppExportTable, pExportTableId)
|
||||
}
|
||||
}
|
||||
|
||||
const FATBINC_MAGIC: c_uint = 0x466243B1;
|
||||
const FATBINC_VERSION: c_uint = 0x1;
|
||||
|
||||
#[repr(C)]
|
||||
struct FatbincWrapper {
|
||||
magic: c_uint,
|
||||
version: c_uint,
|
||||
data: *const FatbinHeader,
|
||||
filename_or_fatbins: *const c_void,
|
||||
}
|
||||
|
||||
const FATBIN_MAGIC: c_uint = 0xBA55ED50;
|
||||
const FATBIN_VERSION: c_ushort = 0x01;
|
||||
|
||||
#[repr(C, align(8))]
|
||||
struct FatbinHeader {
|
||||
magic: c_uint,
|
||||
version: c_ushort,
|
||||
header_size: c_ushort,
|
||||
files_size: c_ulong, // excluding frame header, size of all blocks framed by this frame
|
||||
}
|
||||
|
||||
const FATBIN_FILE_HEADER_KIND_PTX: c_ushort = 0x01;
|
||||
const FATBIN_FILE_HEADER_VERSION_CURRENT: c_ushort = 0x101;
|
||||
|
||||
// assembly file header is a bit different, but we don't care
|
||||
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
struct FatbinFileHeader {
|
||||
kind: c_ushort,
|
||||
version: c_ushort,
|
||||
header_size: c_uint,
|
||||
padded_payload_size: c_uint,
|
||||
unknown0: c_uint, // check if it's written into separately
|
||||
payload_size: c_uint,
|
||||
unknown1: c_uint,
|
||||
unknown2: c_uint,
|
||||
sm_version: c_uint,
|
||||
bit_width: c_uint,
|
||||
unknown3: c_uint,
|
||||
unknown4: c_ulong,
|
||||
unknown5: c_ulong,
|
||||
uncompressed_payload: c_ulong,
|
||||
}
|
||||
|
||||
unsafe extern "C" fn get_module_from_cubin(
|
||||
module: *mut CUmodule,
|
||||
fatbinc_wrapper: *const FatbincWrapper,
|
||||
ptr1: *mut c_void,
|
||||
ptr2: *mut c_void,
|
||||
) -> CUresult {
|
||||
if module == ptr::null_mut()
|
||||
|| (*fatbinc_wrapper).magic != FATBINC_MAGIC
|
||||
|| (*fatbinc_wrapper).version != FATBINC_VERSION
|
||||
{
|
||||
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
||||
}
|
||||
let fatbin_header = (*fatbinc_wrapper).data;
|
||||
if (*fatbin_header).magic != FATBIN_MAGIC || (*fatbin_header).version != FATBIN_VERSION {
|
||||
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
||||
}
|
||||
let file = (fatbin_header as *const u8).add((*fatbin_header).header_size as usize);
|
||||
let end = file.add((*fatbin_header).files_size as usize);
|
||||
let mut ptx_files = get_ptx_files(file, end);
|
||||
ptx_files.sort_unstable_by_key(|f| c_uint::max_value() - (**f).sm_version);
|
||||
let mut maybe_kernel_text = None;
|
||||
for file in ptx_files {
|
||||
match decompress_kernel_module(file) {
|
||||
None => continue,
|
||||
Some(vec) => {
|
||||
maybe_kernel_text = Some(vec);
|
||||
break;
|
||||
}
|
||||
};
|
||||
}
|
||||
let result = ORIGINAL_GET_MODULE_FROM_CUBIN.unwrap()(module, fatbinc_wrapper, ptr1, ptr2);
|
||||
if result != CUresult::CUDA_SUCCESS {
|
||||
return result;
|
||||
}
|
||||
if let Some(text) = maybe_kernel_text {
|
||||
match CStr::from_bytes_with_nul(&text) {
|
||||
Ok(cstr) => match cstr.to_str() {
|
||||
Ok(utf8_str) => record_module_image(*module, utf8_str),
|
||||
Err(_) => {}
|
||||
},
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
unsafe fn get_ptx_files(file: *const u8, end: *const u8) -> Vec<*const FatbinFileHeader> {
|
||||
let mut index = file;
|
||||
let mut result = Vec::new();
|
||||
while index < end {
|
||||
let file = index as *const FatbinFileHeader;
|
||||
if (*file).kind == FATBIN_FILE_HEADER_KIND_PTX
|
||||
&& (*file).version == FATBIN_FILE_HEADER_VERSION_CURRENT
|
||||
{
|
||||
result.push(file)
|
||||
}
|
||||
index = index.add((*file).header_size as usize + (*file).padded_payload_size as usize);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
const MAX_PTX_MODULE_DECOMPRESSION_BOUND: usize = 16 * 1024 * 1024;
|
||||
|
||||
unsafe fn decompress_kernel_module(file: *const FatbinFileHeader) -> Option<Vec<u8>> {
|
||||
let decompressed_size = usize::max(1024, (*file).uncompressed_payload as usize);
|
||||
let mut decompressed_vec = vec![0u8; decompressed_size];
|
||||
loop {
|
||||
match lz4_sys::LZ4_decompress_safe(
|
||||
(file as *const u8).add((*file).header_size as usize) as *const _,
|
||||
decompressed_vec.as_mut_ptr() as *mut _,
|
||||
(*file).payload_size as c_int,
|
||||
decompressed_vec.len() as c_int,
|
||||
) {
|
||||
error if error < 0 => {
|
||||
let new_size = decompressed_vec.len() * 2;
|
||||
if new_size > MAX_PTX_MODULE_DECOMPRESSION_BOUND {
|
||||
return None;
|
||||
}
|
||||
decompressed_vec.resize(decompressed_vec.len() * 2, 0);
|
||||
}
|
||||
real_decompressed_size => {
|
||||
decompressed_vec.truncate(real_decompressed_size as usize);
|
||||
return Some(decompressed_vec);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
14
zluda_dump/src/os_unix.rs
Normal file
14
zluda_dump/src/os_unix.rs
Normal file
|
@ -0,0 +1,14 @@
|
|||
use std::ffi::{c_void, CStr};
|
||||
|
||||
const NVCUDA_DEFAULT_PATH: &'static [u8] = b"/usr/lib/x86_64-linux-gnu/libcuda.so.1\0";
|
||||
|
||||
pub unsafe fn load_cuda_library() -> *mut c_void {
|
||||
libc::dlopen(
|
||||
NVCUDA_DEFAULT_PATH.as_ptr() as *const _,
|
||||
libc::RTLD_LOCAL | libc::RTLD_NOW,
|
||||
)
|
||||
}
|
||||
|
||||
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
|
||||
libc::dlsym(handle, func.as_ptr() as *const _)
|
||||
}
|
68
zluda_dump/src/os_win.rs
Normal file
68
zluda_dump/src/os_win.rs
Normal file
|
@ -0,0 +1,68 @@
|
|||
use std::{
|
||||
ffi::{c_void, CStr},
|
||||
mem,
|
||||
os::raw::c_ushort,
|
||||
ptr,
|
||||
};
|
||||
|
||||
use wchar::wch_c;
|
||||
use winapi::{
|
||||
shared::minwindef::HMODULE,
|
||||
um::libloaderapi::{GetProcAddress, LoadLibraryW},
|
||||
};
|
||||
|
||||
const NVCUDA_DEFAULT_PATH: &[u16] = wch_c!(r"C:\Windows\System32\nvcuda.dll");
|
||||
const LOAD_LIBRARY_NO_REDIRECT: &'static [u8] = b"ZludaLoadLibraryW_NoRedirect\0";
|
||||
|
||||
include!("../../zluda_redirect/src/payload_guid.rs");
|
||||
|
||||
pub unsafe fn load_cuda_library() -> *mut c_void {
|
||||
let load_lib = if is_detoured() {
|
||||
match get_non_detoured_load_library() {
|
||||
Some(load_lib) => load_lib,
|
||||
None => return ptr::null_mut(),
|
||||
}
|
||||
} else {
|
||||
LoadLibraryW
|
||||
};
|
||||
load_lib(NVCUDA_DEFAULT_PATH.as_ptr()) as *mut _
|
||||
}
|
||||
|
||||
unsafe fn is_detoured() -> bool {
|
||||
let mut module = ptr::null_mut();
|
||||
loop {
|
||||
module = detours_sys::DetourEnumerateModules(module);
|
||||
if module == ptr::null_mut() {
|
||||
break;
|
||||
}
|
||||
let mut size = 0;
|
||||
let payload = detours_sys::DetourFindPayload(module, &PAYLOAD_GUID, &mut size);
|
||||
if payload != ptr::null_mut() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
unsafe fn get_non_detoured_load_library(
|
||||
) -> Option<unsafe extern "system" fn(*const c_ushort) -> HMODULE> {
|
||||
let mut module = ptr::null_mut();
|
||||
loop {
|
||||
module = detours_sys::DetourEnumerateModules(module);
|
||||
if module == ptr::null_mut() {
|
||||
break;
|
||||
}
|
||||
let result = GetProcAddress(
|
||||
module as *mut _,
|
||||
LOAD_LIBRARY_NO_REDIRECT.as_ptr() as *mut _,
|
||||
);
|
||||
if result != ptr::null_mut() {
|
||||
return Some(mem::transmute(result));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub unsafe fn get_proc_address(handle: *mut c_void, func: &CStr) -> *mut c_void {
|
||||
GetProcAddress(handle as *mut _, func.as_ptr()) as *mut _
|
||||
}
|
|
@ -9,5 +9,5 @@ name = "zluda_with"
|
|||
path = "src/main.rs"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "std", "synchapi"] }
|
||||
winapi = { version = "0.3", features = ["jobapi2", "processthreadsapi", "std", "synchapi", "winbase"] }
|
||||
detours-sys = { path = "../detours-sys" }
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
use std::env;
|
||||
use std::env::Args;
|
||||
use std::mem;
|
||||
use std::path::Path;
|
||||
use std::ptr;
|
||||
use std::{env, ops::Deref};
|
||||
use std::{error::Error, process};
|
||||
|
||||
use mem::size_of_val;
|
||||
|
@ -25,15 +24,15 @@ static ZLUDA_DLL: &'static str = "nvcuda.dll";
|
|||
include!("../../zluda_redirect/src/payload_guid.rs");
|
||||
|
||||
pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
||||
let args = env::args();
|
||||
if args.len() == 0 {
|
||||
print_help();
|
||||
process::exit(1);
|
||||
let args = env::args().collect::<Vec<_>>();
|
||||
if args.len() <= 1 {
|
||||
print_help_and_exit();
|
||||
}
|
||||
let mut cmd_line = construct_command_line(args);
|
||||
let injector_path = env::current_exe()?;
|
||||
let injector_dir = injector_path.parent().unwrap();
|
||||
let redirect_path = create_redirect_path(injector_dir);
|
||||
let (mut inject_path, cmd) = create_inject_path(&args[1..], injector_dir);
|
||||
let mut cmd_line = construct_command_line(cmd);
|
||||
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
|
||||
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
|
||||
os_call!(
|
||||
|
@ -54,13 +53,12 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||
|x| x != 0
|
||||
);
|
||||
kill_child_on_process_exit(proc_info.hProcess)?;
|
||||
let mut zluda_path = create_zluda_path(injector_dir);
|
||||
os_call!(
|
||||
detours_sys::DetourCopyPayloadToProcess(
|
||||
proc_info.hProcess,
|
||||
&PAYLOAD_GUID,
|
||||
zluda_path.as_mut_ptr() as *mut _,
|
||||
(zluda_path.len() * mem::size_of::<u16>()) as u32
|
||||
inject_path.as_mut_ptr() as *mut _,
|
||||
(inject_path.len() * mem::size_of::<u16>()) as u32
|
||||
),
|
||||
|x| x != 0
|
||||
);
|
||||
|
@ -93,22 +91,29 @@ fn kill_child_on_process_exit(child: HANDLE) -> Result<(), Box<dyn Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn print_help() {
|
||||
fn print_help_and_exit() -> ! {
|
||||
let current_exe = env::current_exe().unwrap();
|
||||
let exe_name = current_exe.file_name().unwrap().to_string_lossy();
|
||||
println!(
|
||||
"USAGE:
|
||||
zluda <EXE> [ARGS]...
|
||||
{0} -- <EXE> [ARGS]...
|
||||
{0} <DLL> -- <EXE> [ARGS]...
|
||||
ARGS:
|
||||
<EXE> Path to the executable to be injected with ZLUDA
|
||||
<DLL> DLL to ne injected instead of system nvcuda.dll, if not provided
|
||||
will use nvcuda.dll from the directory where {0} is located
|
||||
<EXE> Path to the executable to be injected with <DLL>
|
||||
<ARGS>... Arguments that will be passed to <EXE>
|
||||
"
|
||||
",
|
||||
exe_name
|
||||
);
|
||||
process::exit(1)
|
||||
}
|
||||
|
||||
// Adapted from https://docs.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way
|
||||
fn construct_command_line(args: Args) -> Vec<u16> {
|
||||
fn construct_command_line(args: &[String]) -> Vec<u16> {
|
||||
let mut cmd_line = Vec::new();
|
||||
let args_len = args.len();
|
||||
for (idx, arg) in args.enumerate().skip(1) {
|
||||
for (idx, arg) in args.iter().enumerate() {
|
||||
if !arg.contains(&[' ', '\t', '\n', '\u{2B7F}', '\"'][..]) {
|
||||
cmd_line.extend(arg.encode_utf16());
|
||||
} else {
|
||||
|
@ -168,14 +173,22 @@ fn create_redirect_path(injector_dir: &Path) -> Vec<u8> {
|
|||
result
|
||||
}
|
||||
|
||||
fn create_zluda_path(injector_dir: &Path) -> Vec<u16> {
|
||||
let mut injector_dir = injector_dir.to_path_buf();
|
||||
injector_dir.push(ZLUDA_DLL);
|
||||
let mut result = injector_dir
|
||||
.to_string_lossy()
|
||||
.as_ref()
|
||||
.encode_utf16()
|
||||
.collect::<Vec<_>>();
|
||||
result.push(0);
|
||||
result
|
||||
fn create_inject_path<'a>(args: &'a [String], injector_dir: &Path) -> (Vec<u16>, &'a [String]) {
|
||||
if args.get(0).map(Deref::deref) == Some("--") {
|
||||
let mut injector_dir = injector_dir.to_path_buf();
|
||||
injector_dir.push(ZLUDA_DLL);
|
||||
let mut result = injector_dir
|
||||
.to_string_lossy()
|
||||
.as_ref()
|
||||
.encode_utf16()
|
||||
.collect::<Vec<_>>();
|
||||
result.push(0);
|
||||
(result, &args[1..])
|
||||
} else if args.get(1).map(Deref::deref) == Some("--") {
|
||||
let mut dll_path = args[0].encode_utf16().collect::<Vec<_>>();
|
||||
dll_path.push(0);
|
||||
(dll_path, &args[2..])
|
||||
} else {
|
||||
print_help_and_exit()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,6 +46,12 @@ static mut LOAD_LIBRARY_EX_W: unsafe extern "system" fn(
|
|||
dwFlags: DWORD,
|
||||
) -> HMODULE = LoadLibraryExW;
|
||||
|
||||
#[no_mangle]
|
||||
#[allow(non_snake_case)]
|
||||
unsafe extern "system" fn ZludaLoadLibraryW_NoRedirect(lpLibFileName: LPCWSTR) -> HMODULE {
|
||||
(LOAD_LIBRARY_W)(lpLibFileName)
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
unsafe extern "system" fn ZludaLoadLibraryA(lpLibFileName: LPCSTR) -> HMODULE {
|
||||
let nvcuda_file_name = if is_nvcuda_dll_utf8(lpLibFileName as *const _) {
|
||||
|
|
Loading…
Add table
Reference in a new issue