diff --git a/zluda_trace/src/lib.rs b/zluda_trace/src/lib.rs
index 4a52791..f6edea8 100644
--- a/zluda_trace/src/lib.rs
+++ b/zluda_trace/src/lib.rs
@@ -895,6 +895,7 @@ cuda_function_declarations!(
         cuModuleLoadFatBinary,
         cuLibraryGetModule,
         cuLibraryLoadData,
+        cuMemAlloc_v2,
     ],
     extern_redirect_with_pre_post <= [cuLaunchKernel, cuLaunchKernelEx],
     override_fn_core <= [cuGetProcAddress, cuGetProcAddress_v2],
@@ -1606,6 +1607,7 @@ pub(crate) fn cuLaunchKernel_Post(
     fn_logger: &mut FnCallLog,
     _result: CUresult,
 ) {
+    save_magic_ptr(libcuda, state, _f, hStream, state.magic_ptr);
     let pre_state = unwrap_some_or!(pre_state, return);
     replay::post_kernel_launch(
         libcuda,
@@ -1661,6 +1663,13 @@ pub(crate) fn cuLaunchKernelEx_Post(
     fn_logger: &mut FnCallLog,
     _result: CUresult,
 ) {
+    save_magic_ptr(
+        libcuda,
+        state,
+        _f,
+        unsafe { *config }.hStream,
+        state.magic_ptr,
+    );
     let pre_state = unwrap_some_or!(pre_state, return);
     replay::post_kernel_launch(
         libcuda,
@@ -1671,3 +1680,68 @@ pub(crate) fn cuLaunchKernelEx_Post(
         pre_state,
     );
 }
+
+/// Size in bytes (2 MiB) of the device allocation that is tracked and dumped.
+const MAGIC_ALLOCATION_SIZE: usize = 2097152;
+
+/// Post-call hook for `cuMemAlloc_v2`.
+///
+/// Records the device pointer written to `dptr` in `state.magic_ptr` when the
+/// allocation succeeded and is exactly `MAGIC_ALLOCATION_SIZE` bytes long.
+#[allow(non_snake_case)]
+pub(crate) fn cuMemAlloc_v2_Post(
+    dptr: *mut cuda_types::cuda::CUdeviceptr,
+    bytesize: usize,
+    state: &mut trace::StateTracker,
+    _fn_logger: &mut FnCallLog,
+    _result: CUresult,
+) {
+    // `*dptr` is only meaningful if the driver reported success.
+    if bytesize != MAGIC_ALLOCATION_SIZE || !_result.is_ok() || dptr.is_null() {
+        return;
+    }
+    // SAFETY: `dptr` is non-null and the call succeeded, so the driver has
+    // stored the new allocation's address through it.
+    state.magic_ptr = Some(unsafe { *dptr });
+}
+
+/// Writes the contents of the tracked allocation to the dump directory.
+///
+/// Waits for `stream`, copies `MAGIC_ALLOCATION_SIZE` bytes from `magic_ptr`
+/// to the host and stores them as `magic_ptr_<enqueue counter>_<kernel>.bin`.
+/// This is a best-effort debugging aid: it returns without writing anything
+/// if no pointer is tracked, `f` is not a known kernel, there is no dump
+/// directory, or a CUDA call fails.
+fn save_magic_ptr(
+    libcuda: &mut CudaDynamicFns,
+    state: &mut trace::StateTracker,
+    f: cuda_types::cuda::CUfunction,
+    stream: CUstream,
+    magic_ptr: Option<cuda_types::cuda::CUdeviceptr>,
+) {
+    let magic_ptr = unwrap_some_or!(magic_ptr, return);
+    let mut kernel_name = unwrap_some_or!(state.kernels.get(&f), return).name.clone();
+    // Keep the generated file name reasonably short.
+    kernel_name.truncate(224);
+    // Resolve the output directory before doing any GPU work.
+    let mut dump_path = unwrap_some_or!(state.dump_dir(), return).clone();
+    // The launched kernel must finish before its output can be read back.
+    let sync_result = unwrap_some_or!(libcuda.cuStreamSynchronize(stream), return);
+    if !sync_result.is_ok() {
+        return;
+    }
+    let mut host = vec![0u8; MAGIC_ALLOCATION_SIZE];
+    let copy_result = unwrap_some_or!(
+        libcuda.cuMemcpyDtoH_v2(host.as_mut_ptr().cast(), magic_ptr, MAGIC_ALLOCATION_SIZE),
+        return
+    );
+    if !copy_result.is_ok() {
+        return;
+    }
+    dump_path.push(format!(
+        "magic_ptr_{}_{}.bin",
+        state.enqueue_counter, kernel_name
+    ));
+    // Ignore I/O errors: a failed dump must not panic inside the traced process.
+    let _ = std::fs::write(dump_path, host);
+}
diff --git a/zluda_trace/src/trace.rs b/zluda_trace/src/trace.rs
index 9fe8660..941661a 100644
--- a/zluda_trace/src/trace.rs
+++ b/zluda_trace/src/trace.rs
@@ -29,6 +29,7 @@ pub(crate) struct StateTracker {
     pub(crate) override_cc: Option<(u32, u32)>,
     pub(crate) kernel_name_filter: Option,
     pub(crate) kernel_no_output: bool,
+    pub(crate) magic_ptr: Option<cuda_types::cuda::CUdeviceptr>,
 }
 
 pub(crate) struct ParsedModule {
@@ -59,6 +60,7 @@ impl StateTracker {
             override_cc: settings.override_cc,
             kernel_name_filter: settings.kernel_name_filter.clone(),
             kernel_no_output: settings.kernel_no_output.unwrap_or(false),
+            magic_ptr: None,
         }
     }
 