Trace magic pointer

This commit is contained in:
Andrzej Janik 2025-09-24 00:15:25 +00:00
commit ed1ea1f6de
2 changed files with 50 additions and 0 deletions

View file

@ -895,6 +895,7 @@ cuda_function_declarations!(
cuModuleLoadFatBinary,
cuLibraryGetModule,
cuLibraryLoadData,
cuMemAlloc_v2,
],
extern_redirect_with_pre_post <= [cuLaunchKernel, cuLaunchKernelEx],
override_fn_core <= [cuGetProcAddress, cuGetProcAddress_v2],
@ -1606,6 +1607,7 @@ pub(crate) fn cuLaunchKernel_Post(
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
save_magic_ptr(libcuda, state, _f, hStream, state.magic_ptr);
let pre_state = unwrap_some_or!(pre_state, return);
replay::post_kernel_launch(
libcuda,
@ -1661,6 +1663,13 @@ pub(crate) fn cuLaunchKernelEx_Post(
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
save_magic_ptr(
libcuda,
state,
_f,
unsafe { *config }.hStream,
state.magic_ptr,
);
let pre_state = unwrap_some_or!(pre_state, return);
replay::post_kernel_launch(
libcuda,
@ -1671,3 +1680,42 @@ pub(crate) fn cuLaunchKernelEx_Post(
pre_state,
);
}
/// Hook run after `cuMemAlloc_v2` that remembers the "magic" allocation.
///
/// When a successful allocation of exactly `MAGIC_ALLOC_SIZE` bytes is seen,
/// the device pointer written through `dptr` is stored in `state.magic_ptr`,
/// replacing any previously remembered pointer. Other allocations, failed
/// calls, and calls with a null `dptr` leave the tracker untouched.
///
/// * `dptr` - out-parameter of the traced call; read only after the checks below.
/// * `bytesize` - requested allocation size in bytes.
/// * `state` - tracker that receives the remembered pointer.
/// * `_result` - status of the traced call; only successful calls are recorded.
#[allow(non_snake_case)]
pub(crate) fn cuMemAlloc_v2_Post(
    dptr: *mut cuda_types::cuda::CUdeviceptr,
    bytesize: usize,
    state: &mut trace::StateTracker,
    _fn_logger: &mut FnCallLog,
    _result: CUresult,
) {
    // Size (2 MiB) identifying the allocation that should be traced.
    const MAGIC_ALLOC_SIZE: usize = 2097152;

    if bytesize != MAGIC_ALLOC_SIZE {
        return;
    }
    // On failure the out-parameter is not guaranteed to hold a usable pointer,
    // and a null `dptr` must never be dereferenced.
    if !_result.is_ok() || dptr.is_null() {
        return;
    }
    // SAFETY: `dptr` is non-null and the traced call reported success, so it is
    // assumed to point at a device pointer written by the driver.
    // NOTE(review): validity of a non-null `dptr` is the application's contract
    // with the CUDA driver API and cannot be verified here.
    state.magic_ptr = Some(unsafe { *dptr });
}
/// Dumps the contents of the remembered "magic" device allocation to a file.
///
/// Does nothing when `magic_ptr` is `None` or when `f` is not a kernel known
/// to `state.kernels`. Otherwise it waits for `stream` to finish, copies
/// `MAGIC_ALLOC_SIZE` bytes from `magic_ptr` to host memory and writes them to
/// `magic_ptr_<enqueue_counter>_<kernel name>.bin` inside the dump directory.
/// A failed synchronization or a failed device-to-host copy skips the dump.
///
/// # Panics
///
/// Panics if the driver entry points cannot be resolved, if `state.dump_dir()`
/// yields no directory, or if the file cannot be written.
fn save_magic_ptr(
    libcuda: &mut CudaDynamicFns,
    state: &mut trace::StateTracker,
    f: cuda_types::cuda::CUfunction,
    stream: CUstream,
    magic_ptr: Option<CUdeviceptr_v2>,
) {
    // Size (2 MiB) of the traced allocation; must match the size recorded at
    // allocation time.
    const MAGIC_ALLOC_SIZE: usize = 2097152;
    // Upper bound, in bytes, on the kernel-name part of the file name.
    const MAX_KERNEL_NAME_LEN: usize = 224;

    let magic_ptr = unwrap_some_or!(magic_ptr, return);
    let mut kernel_name = unwrap_some_or!(state.kernels.get(&f), return).name.clone();
    // `String::truncate` panics when the cut is not on a char boundary, so back
    // off to the nearest boundary at or below the limit (index 0 always is one).
    let mut cut = kernel_name.len().min(MAX_KERNEL_NAME_LEN);
    while !kernel_name.is_char_boundary(cut) {
        cut -= 1;
    }
    kernel_name.truncate(cut);

    // Wait for the work queued on `stream` before reading the buffer back.
    // A failed sync is treated like a failed copy: skip the dump, do not panic.
    let sync_err = libcuda.cuStreamSynchronize(stream).unwrap();
    if !sync_err.is_ok() {
        return;
    }
    let mut host = vec![0u8; MAGIC_ALLOC_SIZE];
    let cpy_err = libcuda
        .cuMemcpyDtoH_v2(host.as_mut_ptr().cast(), magic_ptr, MAGIC_ALLOC_SIZE)
        .unwrap();
    if !cpy_err.is_ok() {
        return;
    }

    // NOTE(review): this still panics when no dump directory is available;
    // consider returning early instead once `dump_dir`'s contract is confirmed.
    let mut dump_path = state.dump_dir().unwrap().clone();
    dump_path.push(format!(
        "magic_ptr_{}_{}.bin",
        state.enqueue_counter, kernel_name
    ));
    if let Err(err) = std::fs::write(&dump_path, &host) {
        panic!(
            "failed to write magic pointer dump {}: {}",
            dump_path.display(),
            err
        );
    }
}

View file

@ -29,6 +29,7 @@ pub(crate) struct StateTracker {
pub(crate) override_cc: Option<(u32, u32)>,
pub(crate) kernel_name_filter: Option<regex::Regex>,
pub(crate) kernel_no_output: bool,
pub(crate) magic_ptr: Option<CUdeviceptr>,
}
pub(crate) struct ParsedModule {
@ -59,6 +60,7 @@ impl StateTracker {
override_cc: settings.override_cc,
kernel_name_filter: settings.kernel_name_filter.clone(),
kernel_no_output: settings.kernel_no_output.unwrap_or(false),
magic_ptr: None,
}
}