Fix some bugs

This commit is contained in:
Andrzej Janik 2025-09-19 00:58:42 +00:00
commit d880ee78b5
2 changed files with 54 additions and 22 deletions

View file

@ -896,7 +896,7 @@ cuda_function_declarations!(
cuLibraryGetModule,
cuLibraryLoadData,
],
extern_redirect_with_pre_post <= [cuLaunchKernelEx],
extern_redirect_with_pre_post <= [cuLaunchKernel, cuLaunchKernelEx],
override_fn_core <= [cuGetProcAddress, cuGetProcAddress_v2],
override_fn_full <= [cuGetExportTable],
);
@ -1494,28 +1494,60 @@ pub(crate) fn cuLibraryLoadData_Post(
state.record_new_library(unsafe { *library }.0.cast(), code, fn_logger);
}
/*
#[allow(non_snake_case)]
pub(crate) fn cuLaunchKernel_Pre(
f: cuda_types::cuda::CUfunction,
_gridDimX: ::core::ffi::c_uint,
_gridDimY: ::core::ffi::c_uint,
_gridDimZ: ::core::ffi::c_uint,
_blockDimX: ::core::ffi::c_uint,
_blockDimY: ::core::ffi::c_uint,
_blockDimZ: ::core::ffi::c_uint,
_sharedMemBytes: ::core::ffi::c_uint,
_hStream: cuda_types::cuda::CUstream,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
) -> Option<(String, Vec<zluda_trace_common::replay::KernelParameter>)> {
state.enqueue_counter += 1;
if kernel_params.is_null() {
fn_logger.log(ErrorEntry::NullPointer("kernel_params"));
return None;
}
replay::pre_kernel_launch(libcuda, state, fn_logger, f, kernel_params)
}
#[allow(non_snake_case)]
pub(crate) fn cuLaunchKernel_Post(
f: cuda_types::cuda::CUfunction,
gridDimX: ::core::ffi::c_uint,
gridDimY: ::core::ffi::c_uint,
gridDimZ: ::core::ffi::c_uint,
blockDimX: ::core::ffi::c_uint,
blockDimY: ::core::ffi::c_uint,
blockDimZ: ::core::ffi::c_uint,
sharedMemBytes: ::core::ffi::c_uint,
hStream: cuda_types::cuda::CUstream,
kernelParams: *mut *mut ::core::ffi::c_void,
extra: *mut *mut ::core::ffi::c_void,
_f: cuda_types::cuda::CUfunction,
_gridDimX: ::core::ffi::c_uint,
_gridDimY: ::core::ffi::c_uint,
_gridDimZ: ::core::ffi::c_uint,
_blockDimX: ::core::ffi::c_uint,
_blockDimY: ::core::ffi::c_uint,
_blockDimZ: ::core::ffi::c_uint,
_sharedMemBytes: ::core::ffi::c_uint,
_hStream: cuda_types::cuda::CUstream,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
pre_state: Option<(String, Vec<zluda_trace_common::replay::KernelParameter>)>,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
todo!()
let (kernel_name, pre_state) = unwrap_some_or!(pre_state, return);
replay::post_kernel_launch(
libcuda,
fn_logger,
kernel_params,
pre_state,
state.enqueue_counter,
kernel_name,
);
}
*/
#[allow(non_snake_case)]
pub(crate) fn cuLaunchKernelEx_Pre(
@ -1526,7 +1558,7 @@ pub(crate) fn cuLaunchKernelEx_Pre(
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
) -> Option<Vec<zluda_trace_common::replay::KernelParameter>> {
) -> Option<(String, Vec<zluda_trace_common::replay::KernelParameter>)> {
state.enqueue_counter += 1;
if kernel_params.is_null() {
fn_logger.log(ErrorEntry::NullPointer("kernel_params"));
@ -1541,19 +1573,19 @@ pub(crate) fn cuLaunchKernelEx_Post(
_f: cuda_types::cuda::CUfunction,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
pre_state: Option<Vec<zluda_trace_common::replay::KernelParameter>>,
pre_state: Option<(String, Vec<zluda_trace_common::replay::KernelParameter>)>,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
let pre_state = unwrap_some_or!(pre_state, return);
let (kernel_name, pre_state) = unwrap_some_or!(pre_state, return);
replay::post_kernel_launch(
libcuda,
fn_logger,
kernel_params,
pre_state,
state.enqueue_counter,
"".to_string(),
kernel_name,
);
}

View file

@ -12,7 +12,7 @@ pub(crate) fn pre_kernel_launch(
fn_logger: &mut FnCallLog,
f: CUfunction,
args: *mut *mut std::ffi::c_void,
) -> Option<Vec<KernelParameter>> {
) -> Option<(String, Vec<KernelParameter>)> {
let SavedKernel { name, owner } = fn_logger.try_return(|| {
state
.kernels
@ -74,7 +74,7 @@ pub(crate) fn pre_kernel_launch(
device_ptrs: ptr_overrides,
});
}
Some(all_params)
Some((name.to_string(), all_params))
}
pub(crate) fn post_kernel_launch(
@ -100,7 +100,7 @@ pub(crate) fn post_kernel_launch(
})?;
}
}
let path = format!("kernel_{enqueue_counter}_.tar.zst");
let path = format!("kernel_{enqueue_counter}_{kernel_name}.tar.zst");
let file =
fn_logger.try_return(|| std::fs::File::create_new(path).map_err(ErrorEntry::IoError))?;
fn_logger.try_return(|| {