Save source ptx and save to the right path

This commit is contained in:
Andrzej Janik 2025-09-19 01:53:01 +00:00
commit f3e143d8dd
4 changed files with 91 additions and 48 deletions

View file

@ -1504,19 +1504,33 @@ pub(crate) fn cuLaunchKernel_Pre(
_blockDimY: ::core::ffi::c_uint,
_blockDimZ: ::core::ffi::c_uint,
_sharedMemBytes: ::core::ffi::c_uint,
_hStream: cuda_types::cuda::CUstream,
stream: cuda_types::cuda::CUstream,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
) -> Option<(String, Vec<zluda_trace_common::replay::KernelParameter>)> {
) -> Option<replay::LaunchPreState> {
launch_kernel_pre(f, stream, kernel_params, libcuda, state, fn_logger)
}
fn launch_kernel_pre(
f: cuda_types::cuda::CUfunction,
stream: cuda_types::cuda::CUstream,
kernel_params: *mut *mut ::core::ffi::c_void,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
) -> Option<replay::LaunchPreState> {
state.enqueue_counter += 1;
if kernel_params.is_null() {
fn_logger.log(ErrorEntry::NullPointer("kernel_params"));
return None;
}
replay::pre_kernel_launch(libcuda, state, fn_logger, f, kernel_params)
if state.dump_dir().is_none() {
return None;
}
replay::pre_kernel_launch(libcuda, state, fn_logger, f, stream, kernel_params)
}
#[allow(non_snake_case)]
@ -1529,63 +1543,58 @@ pub(crate) fn cuLaunchKernel_Post(
_blockDimY: ::core::ffi::c_uint,
_blockDimZ: ::core::ffi::c_uint,
_sharedMemBytes: ::core::ffi::c_uint,
_hStream: cuda_types::cuda::CUstream,
stream: cuda_types::cuda::CUstream,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
pre_state: Option<(String, Vec<zluda_trace_common::replay::KernelParameter>)>,
pre_state: Option<replay::LaunchPreState>,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
let (kernel_name, pre_state) = unwrap_some_or!(pre_state, return);
replay::post_kernel_launch(
libcuda,
fn_logger,
kernel_params,
pre_state,
state.enqueue_counter,
kernel_name,
);
let pre_state = unwrap_some_or!(pre_state, return);
replay::post_kernel_launch(libcuda, state, fn_logger, stream, kernel_params, pre_state);
}
#[allow(non_snake_case)]
pub(crate) fn cuLaunchKernelEx_Pre(
_config: *const cuda_types::cuda::CUlaunchConfig,
config: *const cuda_types::cuda::CUlaunchConfig,
f: cuda_types::cuda::CUfunction,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
) -> Option<(String, Vec<zluda_trace_common::replay::KernelParameter>)> {
state.enqueue_counter += 1;
if kernel_params.is_null() {
fn_logger.log(ErrorEntry::NullPointer("kernel_params"));
return None;
}
replay::pre_kernel_launch(libcuda, state, fn_logger, f, kernel_params)
) -> Option<replay::LaunchPreState> {
launch_kernel_pre(
f,
unsafe { *config }.hStream,
kernel_params,
libcuda,
state,
fn_logger,
)
}
#[allow(non_snake_case)]
pub(crate) fn cuLaunchKernelEx_Post(
_config: *const cuda_types::cuda::CUlaunchConfig,
config: *const cuda_types::cuda::CUlaunchConfig,
_f: cuda_types::cuda::CUfunction,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
pre_state: Option<(String, Vec<zluda_trace_common::replay::KernelParameter>)>,
pre_state: Option<replay::LaunchPreState>,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
let (kernel_name, pre_state) = unwrap_some_or!(pre_state, return);
let pre_state = unwrap_some_or!(pre_state, return);
replay::post_kernel_launch(
libcuda,
state,
fn_logger,
unsafe { *config }.hStream,
kernel_params,
pre_state,
state.enqueue_counter,
kernel_name,
);
}

View file

@ -6,20 +6,28 @@ use crate::{
use cuda_types::cuda::*;
use zluda_trace_common::replay::KernelParameter;
/// State produced by `pre_kernel_launch` and handed to `post_kernel_launch`,
/// which forwards all three fields to `zluda_trace_common::replay::save`.
pub struct LaunchPreState {
/// Name of the launched kernel; `post_kernel_launch` also embeds it in the
/// dump file name (`kernel_{enqueue_counter}_{kernel_name}.tar.zst`).
kernel_name: String,
/// Copy of `ParsedModule::source` for the module that owns the kernel.
source: String,
/// Per-parameter records gathered before the launch; `post_kernel_launch`
/// iterates over them mutably before they are saved.
kernel_params: Vec<KernelParameter>,
}
pub(crate) fn pre_kernel_launch(
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
f: CUfunction,
stream: CUstream,
args: *mut *mut std::ffi::c_void,
) -> Option<(String, Vec<KernelParameter>)> {
) -> Option<LaunchPreState> {
fn_logger.try_cuda(|| libcuda.cuStreamSynchronize(stream))?;
let SavedKernel { name, owner } = fn_logger.try_return(|| {
state
.kernels
.get(&f)
.ok_or(ErrorEntry::UnknownFunctionHandle(f))
})?;
let ParsedModule { kernels } = fn_logger.try_return(|| {
let ParsedModule { source, kernels } = fn_logger.try_return(|| {
state
.parsed_libraries
.get(owner)
@ -74,19 +82,25 @@ pub(crate) fn pre_kernel_launch(
device_ptrs: ptr_overrides,
});
}
Some((name.to_string(), all_params))
Some(LaunchPreState {
kernel_name: name.to_string(),
source: source.to_string(),
kernel_params: all_params,
})
}
pub(crate) fn post_kernel_launch(
libcuda: &mut CudaDynamicFns,
state: &trace::StateTracker,
fn_logger: &mut FnCallLog,
args: *mut *mut std::ffi::c_void,
mut kernel_params: Vec<KernelParameter>,
enqueue_counter: usize,
kernel_name: String,
stream: CUstream,
kernel_params: *mut *mut std::ffi::c_void,
mut pre_state: LaunchPreState,
) -> Option<()> {
let raw_args = unsafe { std::slice::from_raw_parts(args, kernel_params.len()) };
for (raw_arg, param) in raw_args.iter().zip(kernel_params.iter_mut()) {
fn_logger.try_cuda(|| libcuda.cuStreamSynchronize(stream))?;
let raw_args =
unsafe { std::slice::from_raw_parts(kernel_params, pre_state.kernel_params.len()) };
for (raw_arg, param) in raw_args.iter().zip(pre_state.kernel_params.iter_mut()) {
for (offset_in_param, offset_in_buffer, _, data_after) in param.device_ptrs.iter_mut() {
let dev_ptr_param = unsafe { raw_arg.cast::<u8>().add(*offset_in_param) };
let mut dev_ptr = unsafe { dev_ptr_param.cast::<usize>().read_unaligned() };
@ -100,11 +114,19 @@ pub(crate) fn post_kernel_launch(
})?;
}
}
let path = format!("kernel_{enqueue_counter}_{kernel_name}.tar.zst");
let enqueue_counter = state.enqueue_counter;
let kernel_name = &pre_state.kernel_name;
let mut path = state.dump_dir()?.to_path_buf();
path.push(format!("kernel_{enqueue_counter}_{kernel_name}.tar.zst"));
let file =
fn_logger.try_return(|| std::fs::File::create_new(path).map_err(ErrorEntry::IoError))?;
fn_logger.try_return(|| {
zluda_trace_common::replay::save(file, kernel_name, kernel_params)
.map_err(ErrorEntry::IoError)
zluda_trace_common::replay::save(
file,
pre_state.kernel_name,
pre_state.source,
pre_state.kernel_params,
)
.map_err(ErrorEntry::IoError)
})
}

View file

@ -30,6 +30,7 @@ pub(crate) struct StateTracker {
}
/// Parse results kept for a module, stored in `StateTracker::parsed_libraries`.
pub(crate) struct ParsedModule {
/// Owned copy of the module text that was given to the parser.
pub source: String,
/// Parameter layouts returned by `ptx_parser::parse_for_errors_and_params`.
// NOTE(review): the `String` key is presumably the kernel name - confirm
// against the lookup in `pre_kernel_launch`.
pub kernels: FxHashMap<String, Vec<Layout>>,
}
@ -57,6 +58,10 @@ impl StateTracker {
}
}
/// Returns a reference to the writer's dump directory, or `None` when the
/// writer has none set.
///
/// Kernel-launch capture relies on this: `launch_kernel_pre` bails out when
/// it is `None`, and `post_kernel_launch` uses it as the parent directory of
/// the archive it creates.
pub(crate) fn dump_dir(&self) -> Option<&PathBuf> {
self.writer.dump_dir.as_ref()
}
pub(crate) fn record_new_module_file(
&mut self,
module: CUmodule,
@ -147,12 +152,15 @@ impl StateTracker {
}
});
};
self.parsed_libraries.insert(
SendablePtr(handle),
ParsedModule {
kernels: kernel_arguments.unwrap_or_default(),
},
);
if let Some((source, kernel_arguments)) = kernel_arguments {
self.parsed_libraries.insert(
SendablePtr(handle),
ParsedModule {
source,
kernels: kernel_arguments,
},
);
}
}
#[must_use]
@ -162,7 +170,7 @@ impl StateTracker {
submodule: &[u8],
fn_logger: &mut FnCallLog,
type_: &'static str,
) -> Option<FxHashMap<String, Vec<Layout>>> {
) -> Option<(String, FxHashMap<String, Vec<Layout>>)> {
fn_logger.try_(|fn_logger| {
self.writer
.save_module(fn_logger, self.library_counter, index, submodule, type_)
@ -198,7 +206,7 @@ impl StateTracker {
module_index: usize,
submodule_index: Option<(usize, Option<usize>)>,
module_text: &'input str,
) -> FxHashMap<String, Vec<Layout>> {
) -> (String, FxHashMap<String, Vec<Layout>>) {
let (errors, params) = ptx_parser::parse_for_errors_and_params(module_text);
if !errors.is_empty() {
fn_logger.log(log::ErrorEntry::ModuleParsingError(
@ -210,7 +218,7 @@ impl StateTracker {
&*errors,
));
}
params
(module_text.to_string(), params)
}
pub(crate) fn record_module_in_library(&mut self, module: CUmodule, library: CUlibrary) {

View file

@ -37,6 +37,7 @@ pub struct KernelParameter {
pub fn save(
writer: impl Write,
kernel_name: String,
source: String,
kernel_params: Vec<KernelParameter>,
) -> std::io::Result<()> {
let archive = zstd::Encoder::new(writer, 0)?;
@ -61,6 +62,9 @@ pub fn save(
}
.serialize()?;
builder.append_data(&mut header, Manifest::PATH, &*manifest)?;
let mut header = Header::new_gnu();
header.set_size(source.len() as u64);
builder.append_data(&mut header, "source.ptx", source.as_bytes())?;
for (i, param) in kernel_params.into_iter().enumerate() {
let path = format!("param_{i}.bin");
let mut header = Header::new_gnu();