diff --git a/zluda_trace/src/lib.rs b/zluda_trace/src/lib.rs index a0c56b4..e0f5daf 100644 --- a/zluda_trace/src/lib.rs +++ b/zluda_trace/src/lib.rs @@ -896,7 +896,7 @@ cuda_function_declarations!( cuLibraryGetModule, cuLibraryLoadData, ], - extern_redirect_with_pre_post <= [cuLaunchKernelEx], + extern_redirect_with_pre_post <= [cuLaunchKernel, cuLaunchKernelEx], override_fn_core <= [cuGetProcAddress, cuGetProcAddress_v2], override_fn_full <= [cuGetExportTable], ); @@ -1494,28 +1494,60 @@ pub(crate) fn cuLibraryLoadData_Post( state.record_new_library(unsafe { *library }.0.cast(), code, fn_logger); } -/* +#[allow(non_snake_case)] +pub(crate) fn cuLaunchKernel_Pre( + f: cuda_types::cuda::CUfunction, + _gridDimX: ::core::ffi::c_uint, + _gridDimY: ::core::ffi::c_uint, + _gridDimZ: ::core::ffi::c_uint, + _blockDimX: ::core::ffi::c_uint, + _blockDimY: ::core::ffi::c_uint, + _blockDimZ: ::core::ffi::c_uint, + _sharedMemBytes: ::core::ffi::c_uint, + _hStream: cuda_types::cuda::CUstream, + kernel_params: *mut *mut ::core::ffi::c_void, + _extra: *mut *mut ::core::ffi::c_void, + libcuda: &mut CudaDynamicFns, + state: &mut trace::StateTracker, + fn_logger: &mut FnCallLog, +) -> Option<(String, Vec)> { + state.enqueue_counter += 1; + if kernel_params.is_null() { + fn_logger.log(ErrorEntry::NullPointer("kernel_params")); + return None; + } + replay::pre_kernel_launch(libcuda, state, fn_logger, f, kernel_params) +} + #[allow(non_snake_case)] pub(crate) fn cuLaunchKernel_Post( - f: cuda_types::cuda::CUfunction, - gridDimX: ::core::ffi::c_uint, - gridDimY: ::core::ffi::c_uint, - gridDimZ: ::core::ffi::c_uint, - blockDimX: ::core::ffi::c_uint, - blockDimY: ::core::ffi::c_uint, - blockDimZ: ::core::ffi::c_uint, - sharedMemBytes: ::core::ffi::c_uint, - hStream: cuda_types::cuda::CUstream, - kernelParams: *mut *mut ::core::ffi::c_void, - extra: *mut *mut ::core::ffi::c_void, + _f: cuda_types::cuda::CUfunction, + _gridDimX: ::core::ffi::c_uint, + _gridDimY: ::core::ffi::c_uint, + _gridDimZ: ::core::ffi::c_uint, + _blockDimX: ::core::ffi::c_uint, + _blockDimY: ::core::ffi::c_uint, + _blockDimZ: ::core::ffi::c_uint, + _sharedMemBytes: ::core::ffi::c_uint, + _hStream: cuda_types::cuda::CUstream, + kernel_params: *mut *mut ::core::ffi::c_void, + _extra: *mut *mut ::core::ffi::c_void, + pre_state: Option<(String, Vec)>, libcuda: &mut CudaDynamicFns, state: &mut trace::StateTracker, fn_logger: &mut FnCallLog, _result: CUresult, ) { - todo!() + let (kernel_name, pre_state) = unwrap_some_or!(pre_state, return); + replay::post_kernel_launch( + libcuda, + fn_logger, + kernel_params, + pre_state, + state.enqueue_counter, + kernel_name, + ); } - */ #[allow(non_snake_case)] pub(crate) fn cuLaunchKernelEx_Pre( @@ -1526,7 +1558,7 @@ pub(crate) fn cuLaunchKernelEx_Pre( libcuda: &mut CudaDynamicFns, state: &mut trace::StateTracker, fn_logger: &mut FnCallLog, -) -> Option> { +) -> Option<(String, Vec)> { state.enqueue_counter += 1; if kernel_params.is_null() { fn_logger.log(ErrorEntry::NullPointer("kernel_params")); @@ -1541,19 +1573,19 @@ pub(crate) fn cuLaunchKernelEx_Post( _f: cuda_types::cuda::CUfunction, kernel_params: *mut *mut ::core::ffi::c_void, _extra: *mut *mut ::core::ffi::c_void, - pre_state: Option>, + pre_state: Option<(String, Vec)>, libcuda: &mut CudaDynamicFns, state: &mut trace::StateTracker, fn_logger: &mut FnCallLog, _result: CUresult, ) { - let pre_state = unwrap_some_or!(pre_state, return); + let (kernel_name, pre_state) = unwrap_some_or!(pre_state, return); replay::post_kernel_launch( libcuda, fn_logger, kernel_params, pre_state, state.enqueue_counter, - "".to_string(), + kernel_name, ); } diff --git a/zluda_trace/src/replay.rs b/zluda_trace/src/replay.rs index fd30836..7bfd7da 100644 --- a/zluda_trace/src/replay.rs +++ b/zluda_trace/src/replay.rs @@ -12,7 +12,7 @@ pub(crate) fn pre_kernel_launch( fn_logger: &mut FnCallLog, f: CUfunction, args: *mut *mut std::ffi::c_void, -) -> Option> { +) -> Option<(String, Vec)> { let SavedKernel { name, owner } = fn_logger.try_return(|| { state .kernels @@ -74,7 +74,7 @@ pub(crate) fn pre_kernel_launch( device_ptrs: ptr_overrides, }); } - Some(all_params) + Some((name.to_string(), all_params)) } pub(crate) fn post_kernel_launch( @@ -100,7 +100,7 @@ pub(crate) fn post_kernel_launch( })?; } } - let path = format!("kernel_{enqueue_counter}_.tar.zst"); + let path = format!("kernel_{enqueue_counter}_{kernel_name}.tar.zst"); let file = fn_logger.try_return(|| std::fs::File::create_new(path).map_err(ErrorEntry::IoError))?; fn_logger.try_return(|| {