Add tests for injecting into CLR process

This commit is contained in:
Andrzej Janik 2022-02-03 12:28:42 +01:00
parent 9923a36b76
commit c869a0d611
6 changed files with 54 additions and 6 deletions

View file

@ -50,6 +50,15 @@ fn main() -> Result<(), VarError> {
.arg(full_file_path);
assert!(rustc_cmd.status().unwrap().success());
}
std::fs::copy(
format!(
"{}{}do_cuinit_main_clr.exe",
helpers_dir_as_string,
path::MAIN_SEPARATOR
),
format!("{}{}do_cuinit_main_clr.exe", out_dir, path::MAIN_SEPARATOR),
)
.unwrap();
println!("cargo:rustc-env=HELPERS_OUT_DIR={}", &out_dir);
Ok(())
}

View file

@ -17,7 +17,7 @@ fn main() {
dll.push("do_cuinit.dll");
let dll_cstring = CString::new(dll.to_str().unwrap()).unwrap();
let nvcuda = unsafe { LoadLibraryA(dll_cstring.as_ptr()) };
let cuInit = unsafe { GetProcAddress(nvcuda, b"do_cuinit\0".as_ptr()) };
let cuInit = unsafe { mem::transmute::<_, unsafe extern "system" fn(u32) -> u32>(cuInit) };
unsafe { cuInit(0) };
let cu_init = unsafe { GetProcAddress(nvcuda, b"do_cuinit\0".as_ptr()) };
let cu_init = unsafe { mem::transmute::<_, unsafe extern "system" fn(u32) -> u32>(cu_init) };
unsafe { cu_init(0) };
}

View file

@ -0,0 +1,34 @@
using System;
using System.IO;
using System.Reflection;
using System.Runtime.InteropServices;
namespace Zluda
{
class Program
{
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
private delegate int CuInit(int flags);
static int Main(string[] args)
{
DirectoryInfo exeDirectory = Directory.GetParent(Assembly.GetEntryAssembly().Location);
string dllPath = Path.Combine(exeDirectory.ToString(), "do_cuinit.dll");
IntPtr nvcuda = NativeMethods.LoadLibrary(dllPath);
if (nvcuda == IntPtr.Zero)
return 1;
IntPtr doCuinitPtr = NativeMethods.GetProcAddress(nvcuda, "do_cuinit");
CuInit cuinit = (CuInit)Marshal.GetDelegateForFunctionPointer(doCuinitPtr, typeof(CuInit));
return cuinit(0);
}
}
static class NativeMethods
{
[DllImport("kernel32.dll")]
public static extern IntPtr LoadLibrary(string dllToLoad);
[DllImport("kernel32.dll")]
public static extern IntPtr GetProcAddress(IntPtr hModule, string procedureName);
}
}

Binary file not shown.

View file

@ -10,7 +10,7 @@ extern "system" {
fn main() {
let nvcuda = unsafe { LoadLibraryA(b"C:\\Windows\\System32\\nvcuda.dll\0".as_ptr()) };
let cuInit = unsafe { GetProcAddress(nvcuda, b"cuInit\0".as_ptr()) };
let cuInit = unsafe { mem::transmute::<_, unsafe extern "system" fn(u32) -> u32>(cuInit) };
unsafe { cuInit(0) };
let cu_init = unsafe { GetProcAddress(nvcuda, b"cuInit\0".as_ptr()) };
let cu_init = unsafe { mem::transmute::<_, unsafe extern "system" fn(u32) -> u32>(cu_init) };
unsafe { cu_init(0) };
}

View file

@ -15,6 +15,11 @@ fn do_cuinit() -> io::Result<()> {
run_process_and_check_for_zluda_dump("do_cuinit_main")
}
#[test]
fn do_cuinit_clr() -> io::Result<()> {
run_process_and_check_for_zluda_dump("do_cuinit_main_clr")
}
fn run_process_and_check_for_zluda_dump(name: &'static str) -> io::Result<()> {
let zluda_with_exe = PathBuf::from(env!("CARGO_BIN_EXE_zluda_with"));
let mut zluda_dump_dll = zluda_with_exe.parent().unwrap().to_path_buf();