Start working on trace replay

This commit is contained in:
Andrzej Janik 2025-09-19 00:39:27 +00:00
commit bfef3317dc
13 changed files with 621 additions and 164 deletions

21
Cargo.lock generated
View file

@ -420,7 +420,7 @@ version = "0.0.0"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustc-hash 1.1.0", "rustc-hash 2.0.0",
"syn 2.0.89", "syn 2.0.89",
] ]
@ -3706,7 +3706,7 @@ dependencies = [
"paste", "paste",
"ptx", "ptx",
"ptx_parser", "ptx_parser",
"rustc-hash 1.1.0", "rustc-hash 2.0.0",
"serde", "serde",
"serde_json", "serde_json",
"tempfile", "tempfile",
@ -3726,7 +3726,7 @@ dependencies = [
"prettyplease", "prettyplease",
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustc-hash 1.1.0", "rustc-hash 2.0.0",
"syn 2.0.89", "syn 2.0.89",
] ]
@ -3854,7 +3854,7 @@ dependencies = [
"ptx", "ptx",
"ptx_parser", "ptx_parser",
"regex", "regex",
"rustc-hash 1.1.0", "rustc-hash 2.0.0",
"unwrap_or", "unwrap_or",
"wchar", "wchar",
"winapi", "winapi",
@ -3903,6 +3903,10 @@ dependencies = [
"format", "format",
"libc", "libc",
"libloading", "libloading",
"serde",
"serde_json",
"tar",
"zstd",
] ]
[[package]] [[package]]
@ -3979,6 +3983,15 @@ dependencies = [
"simd-adler32", "simd-adler32",
] ]
[[package]]
name = "zstd"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a"
dependencies = [
"zstd-safe",
]
[[package]] [[package]]
name = "zstd-safe" name = "zstd-safe"
version = "7.2.4" version = "7.2.4"

View file

@ -8,7 +8,7 @@ edition = "2021"
quote = "1.0" quote = "1.0"
syn = { version = "2.0", features = ["full", "visit-mut", "extra-traits"] } syn = { version = "2.0", features = ["full", "visit-mut", "extra-traits"] }
proc-macro2 = "1.0" proc-macro2 = "1.0"
rustc-hash = "1.1.0" rustc-hash = "2.0.0"
[lib] [lib]
proc-macro = true proc-macro = true

View file

@ -2,6 +2,7 @@ use derive_more::Display;
use logos::Logos; use logos::Logos;
use ptx_parser_macros::derive_parser; use ptx_parser_macros::derive_parser;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use std::alloc::Layout;
use std::fmt::Debug; use std::fmt::Debug;
use std::num::{NonZeroU8, ParseFloatError, ParseIntError}; use std::num::{NonZeroU8, ParseFloatError, ParseIntError};
use std::{iter, usize}; use std::{iter, usize};
@ -345,7 +346,9 @@ fn reg_or_immediate<'a, 'input>(
.parse_next(stream) .parse_next(stream)
} }
pub fn parse_for_errors<'input>(text: &'input str) -> Vec<PtxError<'input>> { pub fn parse_for_errors_and_params<'input>(
text: &'input str,
) -> (Vec<PtxError<'input>>, FxHashMap<String, Vec<Layout>>) {
let (tokens, mut errors) = lex_with_span_unchecked(text); let (tokens, mut errors) = lex_with_span_unchecked(text);
let parse_result = { let parse_result = {
let state = PtxParserState::new(text, &mut errors); let state = PtxParserState::new(text, &mut errors);
@ -357,13 +360,30 @@ pub fn parse_for_errors<'input>(text: &'input str) -> Vec<PtxError<'input>> {
.parse(parser) .parse(parser)
.map_err(|err| PtxError::Parser(err.into_inner())) .map_err(|err| PtxError::Parser(err.into_inner()))
}; };
match parse_result { let params = match parse_result {
Ok(_) => {} Ok(module) => module
.directives
.into_iter()
.filter_map(|directive| {
if let ast::Directive::Method(_, func) = directive {
let layouts = func
.func_directive
.input_arguments
.iter()
.map(|arg| arg.v_type.layout())
.collect();
Some((func.func_directive.name().to_string(), layouts))
} else {
None
}
})
.collect(),
Err(err) => { Err(err) => {
errors.push(err); errors.push(err);
FxHashMap::default()
} }
} };
errors (errors, params)
} }
fn lex_with_span_unchecked<'input>( fn lex_with_span_unchecked<'input>(

View file

@ -22,7 +22,7 @@ num_enum = "0.4"
lz4-sys = "1.9" lz4-sys = "1.9"
tempfile = "3" tempfile = "3"
paste = "1.0" paste = "1.0"
rustc-hash = "1.1" rustc-hash = "2.0.0"
zluda_common = { path = "../zluda_common" } zluda_common = { path = "../zluda_common" }
blake3 = "1.8.2" blake3 = "1.8.2"
serde = "1.0.219" serde = "1.0.219"

View file

@ -9,6 +9,6 @@ syn = { version = "2.0", features = ["full", "visit-mut"] }
proc-macro2 = "1.0.89" proc-macro2 = "1.0.89"
quote = "1.0" quote = "1.0"
prettyplease = "0.2.25" prettyplease = "0.2.25"
rustc-hash = "1.1.0" rustc-hash = "2.0.0"
libloading = "0.8" libloading = "0.8"
cuda_types = { path = "../cuda_types" } cuda_types = { path = "../cuda_types" }

View file

@ -24,7 +24,7 @@ paste = "1.0"
cuda_macros = { path = "../cuda_macros" } cuda_macros = { path = "../cuda_macros" }
cuda_types = { path = "../cuda_types" } cuda_types = { path = "../cuda_types" }
parking_lot = "0.12.3" parking_lot = "0.12.3"
rustc-hash = "1.1.0" rustc-hash = "2.0.0"
cglue = "0.3.5" cglue = "0.3.5"
zstd-safe = { version = "7.2.4", features = ["std"] } zstd-safe = { version = "7.2.4", features = ["std"] }
unwrap_or = "1.0.1" unwrap_or = "1.0.1"

View file

@ -12,6 +12,7 @@ use std::ptr::NonNull;
use std::sync::LazyLock; use std::sync::LazyLock;
use std::{env, error::Error, fs, path::PathBuf, sync::Mutex}; use std::{env, error::Error, fs, path::PathBuf, sync::Mutex};
use std::{io, mem, ptr, usize}; use std::{io, mem, ptr, usize};
use unwrap_or::unwrap_some_or;
extern crate cuda_types; extern crate cuda_types;
@ -110,7 +111,7 @@ macro_rules! override_fn_core {
).ok(); ).ok();
formatted_args formatted_args
}; };
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| Some(()); let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| Some(((), ()));
let cuda_call = |_| { let cuda_call = |_| {
paste!{ [<$fn_name _impl >] ( $($arg_id),* ) } paste!{ [<$fn_name _impl >] ( $($arg_id),* ) }
}; };
@ -121,7 +122,7 @@ macro_rules! override_fn_core {
format_curesult, format_curesult,
extract_fn_ptr, extract_fn_ptr,
cuda_call, cuda_call,
move |_, _, _, _| {} move |_, _, _, _, _| {}
) )
} }
)* )*
@ -157,9 +158,9 @@ impl ::dark_api::zluda_trace::CudaDarkApi for InternalTableImpl {
Some(|| args.call().to_vec()), Some(|| args.call().to_vec()),
internal_error, internal_error,
|status| format_status(status).to_vec(), |status| format_status(status).to_vec(),
|_, _| Some(()), |_, _| Some(((), ())),
|_| fn_.call(), |_| fn_.call(),
move |_, _, _, _| {}, move |_, _, _, _, _| {},
) )
} }
} }
@ -201,7 +202,7 @@ macro_rules! dark_api_fn_redirect_log {
).ok(); ).ok();
formatted_args formatted_args
}; };
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| { Some(()) }; let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| { Some(((), ())) };
let cuda_call = |_: () | { let cuda_call = |_: () | {
ReprUsize::to_usize(original_fn( $( $arg_id ),* )) ReprUsize::to_usize(original_fn( $( $arg_id ),* ))
}; };
@ -215,7 +216,7 @@ macro_rules! dark_api_fn_redirect_log {
|status| <$ret_type as ReprUsize>::format_status(status).to_vec(), |status| <$ret_type as ReprUsize>::format_status(status).to_vec(),
extract_fn_ptr, extract_fn_ptr,
cuda_call, cuda_call,
move |_, _, _, _| {} move |_, _, _, _, _| {}
)) ))
} }
)+ )+
@ -256,7 +257,7 @@ macro_rules! dark_api_fn_redirect_log_post {
).ok(); ).ok();
formatted_args formatted_args
}; };
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| { Some(()) }; let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| { Some(((), ())) };
let cuda_call = |_: () | { let cuda_call = |_: () | {
ReprUsize::to_usize(original_fn( $( $arg_id ),* )) ReprUsize::to_usize(original_fn( $( $arg_id ),* ))
}; };
@ -270,7 +271,7 @@ macro_rules! dark_api_fn_redirect_log_post {
|status| <$ret_type as ReprUsize>::format_status(status).to_vec(), |status| <$ret_type as ReprUsize>::format_status(status).to_vec(),
extract_fn_ptr, extract_fn_ptr,
cuda_call, cuda_call,
move |state, logger, _, cuda_result| paste! { Self:: [<$fn_ _post>] } ( $( $arg_id ),* , &mut state.cuda_state, logger, <$ret_type as ReprUsize>::from_usize(cuda_result)) move |state, logger, _, _, cuda_result| paste! { Self:: [<$fn_ _post>] } ( $( $arg_id ),* , &mut state.cuda_state, logger, <$ret_type as ReprUsize>::from_usize(cuda_result))
)) ))
} }
)+ )+
@ -287,7 +288,11 @@ impl DarkApiTrace {
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
_result: CUresult, _result: CUresult,
) { ) {
state.record_new_library(unsafe { *module }, fatbinc_wrapper.cast(), fn_logger) state.record_new_library(
unsafe { *module }.0.cast(),
fatbinc_wrapper.cast(),
fn_logger,
)
} }
fn get_module_from_cubin_ext1_post( fn get_module_from_cubin_ext1_post(
@ -321,7 +326,11 @@ impl DarkApiTrace {
observed: UInt::U32(arg5), observed: UInt::U32(arg5),
}); });
} }
state.record_new_library(unsafe { *module }, fatbinc_wrapper.cast(), fn_logger) state.record_new_library(
unsafe { *module }.0.cast(),
fatbinc_wrapper.cast(),
fn_logger,
)
} }
fn get_module_from_cubin_ext2_post( fn get_module_from_cubin_ext2_post(
@ -355,7 +364,7 @@ impl DarkApiTrace {
observed: UInt::U32(arg5), observed: UInt::U32(arg5),
}); });
} }
state.record_new_library(unsafe { *module }, fatbin_header.cast(), fn_logger) state.record_new_library(unsafe { *module }.0.cast(), fatbin_header.cast(), fn_logger)
} }
} }
@ -770,7 +779,7 @@ macro_rules! extern_redirect {
}; };
let extract_fn_ptr = |state: &mut GlobalDelayedState, _: &mut FnCallLog| { let extract_fn_ptr = |state: &mut GlobalDelayedState, _: &mut FnCallLog| {
paste::paste! { paste::paste! {
state.libcuda. [<get_ $fn_name>]() state.libcuda. [<get_ $fn_name>]().map(|x| ((), x) )
} }
}; };
let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | { let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | {
@ -783,7 +792,7 @@ macro_rules! extern_redirect {
format_curesult, format_curesult,
extract_fn_ptr, extract_fn_ptr,
cuda_call, cuda_call,
move |_, _, _, _| {} move |_, _, _, _, _| {}
) )
} }
)* )*
@ -806,7 +815,7 @@ macro_rules! extern_redirect_with_post {
}; };
let extract_fn_ptr = |state: &mut GlobalDelayedState, _: &mut FnCallLog| { let extract_fn_ptr = |state: &mut GlobalDelayedState, _: &mut FnCallLog| {
paste::paste! { paste::paste! {
state.libcuda. [<get_ $fn_name>]() state.libcuda. [<get_ $fn_name>]().map(|x| ((), x) )
} }
}; };
let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | { let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | {
@ -819,7 +828,43 @@ macro_rules! extern_redirect_with_post {
format_curesult, format_curesult,
extract_fn_ptr, extract_fn_ptr,
cuda_call, cuda_call,
move |state, logger, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , &mut state.cuda_state, logger, cuda_result ) move |state, logger, _, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , &mut state.cuda_state, logger, cuda_result )
)
}
)*
};
}
macro_rules! extern_redirect_with_pre_post {
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
$(
#[no_mangle]
#[allow(improper_ctypes_definitions)]
pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
let format_args = || {
let mut formatted_args = Vec::new();
(paste! { format :: [<write_ $fn_name>] }) (
&mut formatted_args
$(,$arg_id)*
).ok();
formatted_args
};
let extract_fn_ptr = |state: &mut GlobalDelayedState, logger: &mut FnCallLog| {
paste::paste! {
state.libcuda. [<get_ $fn_name>]().map(|x| (paste! { [<$fn_name _Pre>] } ( $( $arg_id ),* , &mut state.libcuda, &mut state.cuda_state, logger ), x ))
}
};
let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | {
fn_ptr( $( $arg_id ),* )
};
GlobalState2::under_lock(
CudaFunctionName::Normal(stringify!($fn_name)),
Some(format_args),
CUresult::INTERNAL_ERROR,
format_curesult,
extract_fn_ptr,
cuda_call,
move |state, logger, pre_state, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , pre_state, &mut state.libcuda, &mut state.cuda_state, logger, cuda_result )
) )
} }
)* )*
@ -843,13 +888,15 @@ cuda_function_declarations!(
cuModuleLoad, cuModuleLoad,
cuModuleLoadData, cuModuleLoadData,
cuModuleLoadDataEx, cuModuleLoadDataEx,
cuLibraryGetFunction,
cuModuleGetFunction, cuModuleGetFunction,
cuDeviceGetAttribute, cuDeviceGetAttribute,
cuDeviceComputeCapability, cuDeviceComputeCapability,
cuModuleLoadFatBinary, cuModuleLoadFatBinary,
cuLibraryGetModule, cuLibraryGetModule,
cuLibraryLoadData cuLibraryLoadData,
], ],
extern_redirect_with_pre_post <= [cuLaunchKernelEx],
override_fn_core <= [cuGetProcAddress, cuGetProcAddress_v2], override_fn_core <= [cuGetProcAddress, cuGetProcAddress_v2],
override_fn_full <= [cuGetExportTable], override_fn_full <= [cuGetExportTable],
); );
@ -859,6 +906,7 @@ mod log;
#[cfg_attr(windows, path = "os_win.rs")] #[cfg_attr(windows, path = "os_win.rs")]
#[cfg_attr(not(windows), path = "os_unix.rs")] #[cfg_attr(not(windows), path = "os_unix.rs")]
mod os; mod os;
mod replay;
mod trace; mod trace;
struct GlobalState2 { struct GlobalState2 {
@ -907,27 +955,33 @@ impl GlobalState2 {
// * Post-call: // * Post-call:
// We log the output of the CUDA function and any errors that may have occurred. This phase // We log the output of the CUDA function and any errors that may have occurred. This phase
// is also covered by a drop guard which will flush the log buffer in case of panic // is also covered by a drop guard which will flush the log buffer in case of panic
fn under_lock<'a, FnPtr: Copy, InnerResult: Copy>( fn under_lock<'a, PreState, FnPtr: Copy, InnerResult: Copy>(
name: CudaFunctionName, name: CudaFunctionName,
args: Option<impl FnOnce() -> Vec<u8>>, args: Option<impl FnOnce() -> Vec<u8>>,
internal_error: InnerResult, internal_error: InnerResult,
format_status: impl FnOnce(InnerResult) -> Vec<u8>, format_status: impl FnOnce(InnerResult) -> Vec<u8>,
pre_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog) -> Option<FnPtr>, pre_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog) -> Option<(PreState, FnPtr)>,
inner_call: impl FnOnce(FnPtr) -> InnerResult, inner_call: impl FnOnce(FnPtr) -> InnerResult,
post_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog, FnPtr, InnerResult), post_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog, PreState, FnPtr, InnerResult),
) -> InnerResult { ) -> InnerResult {
fn under_lock_impl<'a, FnPtr: Copy, InnerResult: Copy>( fn under_lock_impl<'a, PreState, FnPtr: Copy, InnerResult: Copy>(
name: CudaFunctionName, name: CudaFunctionName,
args: Option<impl FnOnce() -> Vec<u8>>, args: Option<impl FnOnce() -> Vec<u8>>,
internal_error: InnerResult, internal_error: InnerResult,
format_status: impl FnOnce(InnerResult) -> Vec<u8>, format_status: impl FnOnce(InnerResult) -> Vec<u8>,
pre_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog) -> Option<FnPtr>, pre_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog) -> Option<(PreState, FnPtr)>,
inner_call: impl FnOnce(FnPtr) -> InnerResult, inner_call: impl FnOnce(FnPtr) -> InnerResult,
post_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog, FnPtr, InnerResult), post_call: impl FnOnce(
&mut GlobalDelayedState,
&mut FnCallLog,
PreState,
FnPtr,
InnerResult,
),
) -> InnerResult { ) -> InnerResult {
let global_state = GLOBAL_STATE2.lock(); let global_state = GLOBAL_STATE2.lock();
let global_state_ref_cell = &*global_state; let global_state_ref_cell = &*global_state;
let pre_value = { let (pre_state, pre_ptr) = {
let mut global_state_ref_mut = global_state_ref_cell.borrow_mut(); let mut global_state_ref_mut = global_state_ref_cell.borrow_mut();
let global_state = &mut *global_state_ref_mut; let global_state = &mut *global_state_ref_mut;
let panic_guard = OuterCallGuard { let panic_guard = OuterCallGuard {
@ -963,7 +1017,7 @@ impl GlobalState2 {
} }
}; };
let panic_guard = InnerCallGuard(global_state_ref_cell); let panic_guard = InnerCallGuard(global_state_ref_cell);
let inner_result = inner_call(pre_value); let inner_result = inner_call(pre_ptr);
let global_state = &mut *global_state_ref_cell.borrow_mut(); let global_state = &mut *global_state_ref_cell.borrow_mut();
mem::forget(panic_guard); mem::forget(panic_guard);
let _drop_guard = OuterCallGuard { let _drop_guard = OuterCallGuard {
@ -978,7 +1032,8 @@ impl GlobalState2 {
post_call( post_call(
global_state.delayed_state.as_mut().unwrap(), global_state.delayed_state.as_mut().unwrap(),
&mut logger, &mut logger,
pre_value, pre_state,
pre_ptr,
inner_result, inner_result,
); );
inner_result inner_result
@ -1098,6 +1153,22 @@ impl FnCallLog {
} }
} }
fn try_cuda(&mut self, fn_: impl FnOnce() -> Option<CUresult>) -> Option<()> {
match fn_() {
Some(Ok(())) => Some(()),
None => {
self.subcalls
.push(LogEntry::Error(ErrorEntry::CudaError(None)));
None
}
Some(Err(err)) => {
self.subcalls
.push(LogEntry::Error(ErrorEntry::CudaError(Some(err))));
None
}
}
}
fn try_<T>(&mut self, f: impl FnOnce(&mut Self) -> Result<T, ErrorEntry>) -> Option<T> { fn try_<T>(&mut self, f: impl FnOnce(&mut Self) -> Result<T, ErrorEntry>) -> Option<T> {
match f(self) { match f(self) {
Err(e) => { Err(e) => {
@ -1307,7 +1378,7 @@ pub(crate) fn cuModuleLoadData_Post(
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
_result: CUresult, _result: CUresult,
) { ) {
state.record_new_library(unsafe { *module }, raw_image, fn_logger) state.record_new_library(unsafe { *module }.0.cast(), raw_image, fn_logger)
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
@ -1326,13 +1397,17 @@ pub(crate) fn cuModuleLoadDataEx_Post(
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub(crate) fn cuModuleGetFunction_Post( pub(crate) fn cuModuleGetFunction_Post(
_hfunc: *mut CUfunction, hfunc: *mut CUfunction,
_hmod: CUmodule, hmod: CUmodule,
_name: *const ::std::os::raw::c_char, name: *const ::std::os::raw::c_char,
_state: &mut trace::StateTracker, state: &mut trace::StateTracker,
_fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
_result: CUresult, result: CUresult,
) { ) {
if !result.is_ok() {
return;
}
state.record_function_from_module(fn_logger, unsafe { *hfunc }, hmod, name);
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
@ -1385,7 +1460,7 @@ pub(crate) fn cuModuleLoadFatBinary_Post(
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
_result: CUresult, _result: CUresult,
) { ) {
state.record_new_library(unsafe { *module }, fatbin_header.cast(), fn_logger) state.record_new_library(unsafe { *module }.0.cast(), fatbin_header.cast(), fn_logger)
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
@ -1393,13 +1468,13 @@ pub(crate) fn cuLibraryGetModule_Post(
module: *mut cuda_types::cuda::CUmodule, module: *mut cuda_types::cuda::CUmodule,
library: cuda_types::cuda::CUlibrary, library: cuda_types::cuda::CUlibrary,
state: &mut trace::StateTracker, state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog, _fn_logger: &mut FnCallLog,
_result: CUresult, result: CUresult,
) { ) {
match state.libraries.get(&library).copied() { if !result.is_ok() {
None => fn_logger.log(log::ErrorEntry::UnknownLibrary(library)), return;
Some(code) => state.record_new_library(unsafe { *module }, code.0, fn_logger),
} }
state.record_module_in_library(unsafe { *module }, library);
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
@ -1416,10 +1491,69 @@ pub(crate) fn cuLibraryLoadData_Post(
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
_result: CUresult, _result: CUresult,
) { ) {
state state.record_new_library(unsafe { *library }.0.cast(), code, fn_logger);
.libraries }
.insert(unsafe { *library }, trace::CodePointer(code));
// TODO: this is not correct, but it's enough for now, we just want to /*
// save the binary to disk #[allow(non_snake_case)]
state.record_new_library(unsafe { CUmodule((*library).0.cast()) }, code, fn_logger); pub(crate) fn cuLaunchKernel_Post(
f: cuda_types::cuda::CUfunction,
gridDimX: ::core::ffi::c_uint,
gridDimY: ::core::ffi::c_uint,
gridDimZ: ::core::ffi::c_uint,
blockDimX: ::core::ffi::c_uint,
blockDimY: ::core::ffi::c_uint,
blockDimZ: ::core::ffi::c_uint,
sharedMemBytes: ::core::ffi::c_uint,
hStream: cuda_types::cuda::CUstream,
kernelParams: *mut *mut ::core::ffi::c_void,
extra: *mut *mut ::core::ffi::c_void,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
todo!()
}
*/
#[allow(non_snake_case)]
pub(crate) fn cuLaunchKernelEx_Pre(
_config: *const cuda_types::cuda::CUlaunchConfig,
f: cuda_types::cuda::CUfunction,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
) -> Option<Vec<zluda_trace_common::replay::KernelParameter>> {
state.enqueue_counter += 1;
if kernel_params.is_null() {
fn_logger.log(ErrorEntry::NullPointer("kernel_params"));
return None;
}
replay::pre_kernel_launch(libcuda, state, fn_logger, f, kernel_params)
}
#[allow(non_snake_case)]
pub(crate) fn cuLaunchKernelEx_Post(
_config: *const cuda_types::cuda::CUlaunchConfig,
_f: cuda_types::cuda::CUfunction,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
pre_state: Option<Vec<zluda_trace_common::replay::KernelParameter>>,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
let pre_state = unwrap_some_or!(pre_state, return);
replay::post_kernel_launch(
libcuda,
fn_logger,
kernel_params,
pre_state,
state.enqueue_counter,
"".to_string(),
);
} }

View file

@ -1,8 +1,8 @@
use super::Settings; use super::Settings;
use crate::trace::SendablePtr;
use crate::FnCallLog; use crate::FnCallLog;
use crate::LogEntry; use crate::LogEntry;
use cuda_types::cuda::*; use cuda_types::cuda::*;
use format::CudaDisplay;
use std::error::Error; use std::error::Error;
use std::ffi::c_void; use std::ffi::c_void;
use std::ffi::NulError; use std::ffi::NulError;
@ -267,13 +267,12 @@ pub(crate) enum ErrorEntry {
CreatedDumpDirectory(PathBuf), CreatedDumpDirectory(PathBuf),
ErrorBox(Box<dyn Error>), ErrorBox(Box<dyn Error>),
UnsupportedModule { UnsupportedModule {
module: CUmodule, handle: *mut c_void,
raw_image: *const c_void, raw_image: *const c_void,
kind: &'static str, kind: &'static str,
}, },
FunctionNotFound(CudaFunctionName), FunctionNotFound(CudaFunctionName),
MalformedModulePath(Utf8Error), Utf8Error(Utf8Error),
NonUtf8ModuleText(Utf8Error),
NulInsideModuleText(NulError), NulInsideModuleText(NulError),
ModuleParsingError(String), ModuleParsingError(String),
Lz4DecompressionFailure, Lz4DecompressionFailure,
@ -302,8 +301,11 @@ pub(crate) enum ErrorEntry {
overriden: [u64; 2], overriden: [u64; 2],
}, },
NullPointer(&'static str), NullPointer(&'static str),
UnknownLibrary(CUlibrary),
SavedModule(String), SavedModule(String),
UnknownFunctionHandle(CUfunction),
UnknownLibrary(CUfunction, SendablePtr),
UnknownFunction(CUfunction, SendablePtr, String),
CudaError(Option<CUerror>),
} }
unsafe impl Send for ErrorEntry {} unsafe impl Send for ErrorEntry {}
@ -345,94 +347,100 @@ impl Display for ErrorEntry {
match self { match self {
ErrorEntry::IoError(e) => e.fmt(f), ErrorEntry::IoError(e) => e.fmt(f),
ErrorEntry::CreatedDumpDirectory(dir) => { ErrorEntry::CreatedDumpDirectory(dir) => {
write!( write!(
f, f,
"Created trace directory {} ", "Created trace directory {} ",
dir.as_os_str().to_string_lossy() dir.as_os_str().to_string_lossy()
) )
} }
ErrorEntry::ErrorBox(e) => e.fmt(f), ErrorEntry::ErrorBox(e) => e.fmt(f),
ErrorEntry::UnsupportedModule { ErrorEntry::UnsupportedModule {
module, handle,
raw_image, raw_image,
kind, kind,
} => { } => {
write!( write!(
f, f,
"Unsupported {} module {:?} loaded from module image {:?}", "Unsupported {} module {:p} loaded from module image {:p}",
kind, module, raw_image kind, handle, raw_image
) )
} }
ErrorEntry::MalformedModulePath(e) => e.fmt(f), ErrorEntry::Utf8Error(e) => e.fmt(f),
ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f),
ErrorEntry::ModuleParsingError(file_name) => { ErrorEntry::ModuleParsingError(file_name) => {
write!( write!(
f, f,
"Error parsing module, log has been written to {}", "Error parsing module, log has been written to {}",
file_name file_name
) )
} }
ErrorEntry::NulInsideModuleText(e) => e.fmt(f), ErrorEntry::NulInsideModuleText(e) => e.fmt(f),
ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"), ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"),
ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)), ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)),
ErrorEntry::UnexpectedBinaryField { ErrorEntry::UnexpectedBinaryField {
field_name, field_name,
expected, expected,
observed, observed,
} => write!( } => write!(
f, f,
"Unexpected field {}. Expected one of: [{}], observed: {}", "Unexpected field {}. Expected one of: [{}], observed: {}",
field_name, field_name,
expected expected
.iter() .iter()
.map(|x| x.to_string()) .map(|x| x.to_string())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "), .join(", "),
observed observed
), ),
ErrorEntry::UnexpectedArgument { ErrorEntry::UnexpectedArgument {
arg_name, arg_name,
expected, expected,
observed, observed,
} => write!( } => write!(
f, f,
"Unexpected argument {}. Expected one of: {{{}}}, observed: {}", "Unexpected argument {}. Expected one of: {{{}}}, observed: {}",
arg_name, arg_name,
expected expected
.iter() .iter()
.map(|x| x.to_string()) .map(|x| x.to_string())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "), .join(", "),
observed observed
), ),
ErrorEntry::InvalidEnvVar { ErrorEntry::InvalidEnvVar {
var, var,
pattern, pattern,
value, value,
} => write!( } => write!(
f, f,
"Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}" "Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}"
), ),
ErrorEntry::FunctionNotFound(cuda_function_name) => write!( ErrorEntry::FunctionNotFound(cuda_function_name) => write!(
f, f,
"No function {cuda_function_name} in the underlying library" "No function {cuda_function_name} in the underlying library"
), ),
ErrorEntry::UnexpectedExportTableSize { expected, computed } => { ErrorEntry::UnexpectedExportTableSize { expected, computed } => {
write!(f, "Table length mismatch. Expected: {expected}, got: {computed}") write!(f, "Table length mismatch. Expected: {expected}, got: {computed}")
} }
ErrorEntry::IntegrityCheck { original, overriden } => { ErrorEntry::IntegrityCheck { original, overriden } => {
write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}") write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}")
} }
ErrorEntry::NullPointer(type_) => { ErrorEntry::NullPointer(type_) => {
write!(f, "Null pointer of type {type_} encountered") write!(f, "Null pointer of type {type_} encountered")
} }
ErrorEntry::UnknownLibrary(culibrary) => {
write!(f, "Unknown library: ")?;
let mut temp_buffer = Vec::new();
CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok();
f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) })
}
ErrorEntry::SavedModule(file) => write!(f, "Saved module to {file}"), ErrorEntry::SavedModule(file) => write!(f, "Saved module to {file}"),
ErrorEntry::UnknownFunctionHandle(cuda_function_name) => {
write!(f, "Function with unknown provenance: {cuda_function_name:p}")
}
ErrorEntry::UnknownLibrary(cuda_function_name, owner) => {
write!(f, "Function with unknown provenance: {cuda_function_name:p}, owner: {owner:p}")
}
ErrorEntry::UnknownFunction(cuda_function_name, owner, name) => {
write!(f, "Function with unknown provenance: {cuda_function_name:p}, owner: {owner:p}, name: {name}")
}
ErrorEntry::CudaError(cuerror) => {
let cuerror = cuerror.map(|e| e.0);
write!(f, "CUDA error encountered: {cuerror:#?}")
},
} }
} }
} }

110
zluda_trace/src/replay.rs Normal file
View file

@ -0,0 +1,110 @@
use crate::{
log::ErrorEntry,
trace::{self, ParsedModule, SavedKernel},
CudaDynamicFns, FnCallLog,
};
use cuda_types::cuda::*;
use zluda_trace_common::replay::KernelParameter;
pub(crate) fn pre_kernel_launch(
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
f: CUfunction,
args: *mut *mut std::ffi::c_void,
) -> Option<Vec<KernelParameter>> {
let SavedKernel { name, owner } = fn_logger.try_return(|| {
state
.kernels
.get(&f)
.ok_or(ErrorEntry::UnknownFunctionHandle(f))
})?;
let ParsedModule { kernels } = fn_logger.try_return(|| {
state
.parsed_libraries
.get(owner)
.ok_or(ErrorEntry::UnknownLibrary(f, *owner))
})?;
let kernel_params = fn_logger.try_return(|| {
kernels
.get(name)
.ok_or_else(|| ErrorEntry::UnknownFunction(f, *owner, name.clone()))
})?;
let raw_args = unsafe { std::slice::from_raw_parts(args, kernel_params.len()) };
let mut all_params = Vec::new();
for (raw_arg, layout) in raw_args.iter().zip(kernel_params.iter()) {
let mut offset = 0;
let mut ptr_overrides = Vec::new();
while offset + std::mem::size_of::<usize>() <= layout.size() {
let maybe_ptr = unsafe { raw_arg.cast::<u8>().add(offset) };
let maybe_ptr = unsafe { maybe_ptr.cast::<usize>().read_unaligned() };
let attrs = &mut [
CUpointer_attribute_enum::CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
CUpointer_attribute_enum::CU_POINTER_ATTRIBUTE_RANGE_SIZE,
];
let mut start = 0usize;
let mut size = 0usize;
let mut data = [
(&mut start as *mut usize).cast::<std::ffi::c_void>(),
(&mut size as *mut usize).cast::<std::ffi::c_void>(),
];
if let Some(Ok(())) = libcuda.cuPointerGetAttributes(
2,
attrs.as_mut_ptr(),
data.as_mut_ptr(),
CUdeviceptr_v2(maybe_ptr as _),
) {
let mut pre_buffer = vec![0u8; size];
let post_buffer = vec![0u8; size];
fn_logger.try_cuda(|| {
libcuda.cuMemcpyDtoH_v2(
pre_buffer.as_mut_ptr().cast(),
CUdeviceptr_v2(start as _),
size,
)
})?;
let buffer_offset = maybe_ptr - start;
ptr_overrides.push((offset, buffer_offset, pre_buffer, post_buffer));
}
offset += std::mem::size_of::<usize>();
}
all_params.push(KernelParameter {
data: unsafe { std::slice::from_raw_parts(raw_arg.cast::<u8>(), layout.size()) }
.to_vec(),
device_ptrs: ptr_overrides,
});
}
Some(all_params)
}
pub(crate) fn post_kernel_launch(
libcuda: &mut CudaDynamicFns,
fn_logger: &mut FnCallLog,
args: *mut *mut std::ffi::c_void,
mut kernel_params: Vec<KernelParameter>,
enqueue_counter: usize,
kernel_name: String,
) -> Option<()> {
let raw_args = unsafe { std::slice::from_raw_parts(args, kernel_params.len()) };
for (raw_arg, param) in raw_args.iter().zip(kernel_params.iter_mut()) {
for (offset_in_param, offset_in_buffer, _, data_after) in param.device_ptrs.iter_mut() {
let dev_ptr_param = unsafe { raw_arg.cast::<u8>().add(*offset_in_param) };
let mut dev_ptr = unsafe { dev_ptr_param.cast::<usize>().read_unaligned() };
dev_ptr -= *offset_in_buffer;
fn_logger.try_cuda(|| {
libcuda.cuMemcpyDtoH_v2(
data_after.as_mut_ptr().cast(),
CUdeviceptr_v2(dev_ptr as _),
data_after.len(),
)
})?;
}
}
let path = format!("kernel_{enqueue_counter}_.tar.zst");
let file =
fn_logger.try_return(|| std::fs::File::create_new(path).map_err(ErrorEntry::IoError))?;
fn_logger.try_return(|| {
zluda_trace_common::replay::save(file, kernel_name, kernel_params)
.map_err(ErrorEntry::IoError)
})
}

View file

@ -4,8 +4,9 @@ use crate::{
}; };
use cuda_types::cuda::*; use cuda_types::cuda::*;
use goblin::{elf, elf32, elf64}; use goblin::{elf, elf32, elf64};
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::FxHashMap;
use std::{ use std::{
alloc::Layout,
ffi::{c_void, CStr, CString}, ffi::{c_void, CStr, CString},
fs::{self, File}, fs::{self, File},
io::{self, Read, Write}, io::{self, Read, Write},
@ -20,25 +21,38 @@ use unwrap_or::unwrap_some_or;
// * writes out relevant state change and details to disk and log // * writes out relevant state change and details to disk and log
pub(crate) struct StateTracker { pub(crate) struct StateTracker {
writer: DumpWriter, writer: DumpWriter,
pub(crate) libraries: FxHashMap<CUlibrary, CodePointer>, pub(crate) parsed_libraries: FxHashMap<SendablePtr, ParsedModule>,
saved_modules: FxHashSet<CUmodule>, pub(crate) submodules: FxHashMap<CUmodule, CUlibrary>,
pub(crate) kernels: FxHashMap<CUfunction, SavedKernel>,
library_counter: usize, library_counter: usize,
pub(crate) enqueue_counter: usize,
pub(crate) override_cc: Option<(u32, u32)>, pub(crate) override_cc: Option<(u32, u32)>,
} }
#[derive(Clone, Copy)] pub(crate) struct ParsedModule {
pub(crate) struct CodePointer(pub *const c_void); pub kernels: FxHashMap<String, Vec<Layout>>,
}
unsafe impl Send for CodePointer {} pub(crate) struct SavedKernel {
unsafe impl Sync for CodePointer {} pub name: String,
pub owner: SendablePtr,
}
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct SendablePtr(*mut c_void);
unsafe impl Send for SendablePtr {}
unsafe impl Sync for SendablePtr {}
impl StateTracker { impl StateTracker {
pub(crate) fn new(settings: &Settings) -> Self { pub(crate) fn new(settings: &Settings) -> Self {
StateTracker { StateTracker {
writer: DumpWriter::new(settings.dump_dir.clone()), writer: DumpWriter::new(settings.dump_dir.clone()),
libraries: FxHashMap::default(), parsed_libraries: FxHashMap::default(),
saved_modules: FxHashSet::default(), submodules: FxHashMap::default(),
kernels: FxHashMap::default(),
library_counter: 0, library_counter: 0,
enqueue_counter: 0,
override_cc: settings.override_cc, override_cc: settings.override_cc,
} }
} }
@ -52,7 +66,7 @@ impl StateTracker {
let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() { let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() {
Ok(f) => f, Ok(f) => f,
Err(err) => { Err(err) => {
fn_logger.log(log::ErrorEntry::MalformedModulePath(err)); fn_logger.log(log::ErrorEntry::Utf8Error(err));
return; return;
} }
}; };
@ -69,21 +83,26 @@ impl StateTracker {
let mut module_file = fs::File::open(file_name)?; let mut module_file = fs::File::open(file_name)?;
let mut read_buff = Vec::new(); let mut read_buff = Vec::new();
module_file.read_to_end(&mut read_buff)?; module_file.read_to_end(&mut read_buff)?;
self.record_new_library(module, read_buff.as_ptr() as *const _, fn_logger); self.record_new_library(module.0.cast(), read_buff.as_ptr() as *const _, fn_logger);
Ok(()) Ok(())
} }
pub(crate) fn record_new_library( pub(crate) fn record_new_library(
&mut self, &mut self,
cu_module: CUmodule, handle: *mut c_void,
raw_image: *const c_void, raw_image: *const c_void,
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
) { ) {
self.saved_modules.insert(cu_module); fn overwrite<T>(current: &mut Option<T>, value: Option<T>) {
if value.is_some() {
*current = value;
}
}
let mut kernel_arguments = None;
self.library_counter += 1; self.library_counter += 1;
let code_ref = fn_logger.try_return(|| { let code_ref = fn_logger.try_return(|| {
unsafe { zluda_common::CodeLibraryRef::try_load(raw_image) } unsafe { zluda_common::CodeLibraryRef::try_load(raw_image) }
.map_err(ErrorEntry::NonUtf8ModuleText) .map_err(ErrorEntry::Utf8Error)
}); });
let code_ref = unwrap_some_or!(code_ref, return); let code_ref = unwrap_some_or!(code_ref, return);
unsafe { unsafe {
@ -92,17 +111,20 @@ impl StateTracker {
Ok(zluda_common::CodeModuleRef::Elf(elf)) => match get_elf_size(elf) { Ok(zluda_common::CodeModuleRef::Elf(elf)) => match get_elf_size(elf) {
Some(len) => { Some(len) => {
let elf_image = std::slice::from_raw_parts(elf.cast::<u8>(), len); let elf_image = std::slice::from_raw_parts(elf.cast::<u8>(), len);
self.record_new_submodule(index, elf_image, fn_logger, "elf"); overwrite(
&mut kernel_arguments,
self.record_new_submodule(index, elf_image, fn_logger, "elf"),
);
} }
None => fn_logger.log(log::ErrorEntry::UnsupportedModule { None => fn_logger.log(log::ErrorEntry::UnsupportedModule {
module: cu_module, handle,
raw_image: elf, raw_image: elf,
kind: "ELF", kind: "ELF",
}), }),
}, },
Ok(zluda_common::CodeModuleRef::Archive(archive)) => { Ok(zluda_common::CodeModuleRef::Archive(archive)) => {
fn_logger.log(log::ErrorEntry::UnsupportedModule { fn_logger.log(log::ErrorEntry::UnsupportedModule {
module: cu_module, handle,
raw_image: archive, raw_image: archive,
kind: "archive", kind: "archive",
}) })
@ -111,23 +133,36 @@ impl StateTracker {
if let Some(buffer) = fn_logger if let Some(buffer) = fn_logger
.try_(|_| file.get_or_decompress_content().map_err(ErrorEntry::from)) .try_(|_| file.get_or_decompress_content().map_err(ErrorEntry::from))
{ {
self.record_new_submodule(index, &*buffer, fn_logger, file.kind()); overwrite(
&mut kernel_arguments,
self.record_new_submodule(index, &*buffer, fn_logger, file.kind()),
);
} }
} }
Ok(zluda_common::CodeModuleRef::Text(ptx)) => { Ok(zluda_common::CodeModuleRef::Text(ptx)) => {
self.record_new_submodule(index, ptx.as_bytes(), fn_logger, "ptx"); overwrite(
&mut kernel_arguments,
self.record_new_submodule(index, ptx.as_bytes(), fn_logger, "ptx"),
);
} }
}); });
}; };
self.parsed_libraries.insert(
SendablePtr(handle),
ParsedModule {
kernels: kernel_arguments.unwrap_or_default(),
},
);
} }
#[must_use]
pub(crate) fn record_new_submodule( pub(crate) fn record_new_submodule(
&mut self, &mut self,
index: Option<(usize, Option<usize>)>, index: Option<(usize, Option<usize>)>,
submodule: &[u8], submodule: &[u8],
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
type_: &'static str, type_: &'static str,
) { ) -> Option<FxHashMap<String, Vec<Layout>>> {
fn_logger.try_(|fn_logger| { fn_logger.try_(|fn_logger| {
self.writer self.writer
.save_module(fn_logger, self.library_counter, index, submodule, type_) .save_module(fn_logger, self.library_counter, index, submodule, type_)
@ -135,28 +170,36 @@ impl StateTracker {
}); });
if type_ == "ptx" { if type_ == "ptx" {
match CString::new(submodule) { match CString::new(submodule) {
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)), Err(e) => {
fn_logger.log(log::ErrorEntry::NulInsideModuleText(e));
None
}
Ok(submodule_cstring) => match submodule_cstring.to_str() { Ok(submodule_cstring) => match submodule_cstring.to_str() {
Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)), Err(e) => {
Ok(submodule_text) => self.try_parse_and_record_kernels( fn_logger.log(log::ErrorEntry::Utf8Error(e));
None
}
Ok(submodule_text) => Some(self.try_parse_and_record_kernels(
fn_logger, fn_logger,
self.library_counter, self.library_counter,
index, index,
submodule_text, submodule_text,
), )),
}, },
} }
} else {
None
} }
} }
fn try_parse_and_record_kernels( fn try_parse_and_record_kernels<'input>(
&mut self, &mut self,
fn_logger: &mut FnCallLog, fn_logger: &mut FnCallLog,
module_index: usize, module_index: usize,
submodule_index: Option<(usize, Option<usize>)>, submodule_index: Option<(usize, Option<usize>)>,
module_text: &str, module_text: &'input str,
) { ) -> FxHashMap<String, Vec<Layout>> {
let errors = ptx_parser::parse_for_errors(module_text); let (errors, params) = ptx_parser::parse_for_errors_and_params(module_text);
if !errors.is_empty() { if !errors.is_empty() {
fn_logger.log(log::ErrorEntry::ModuleParsingError( fn_logger.log(log::ErrorEntry::ModuleParsingError(
DumpWriter::get_file_name(module_index, submodule_index, "log"), DumpWriter::get_file_name(module_index, submodule_index, "log"),
@ -167,6 +210,46 @@ impl StateTracker {
&*errors, &*errors,
)); ));
} }
params
}
pub(crate) fn record_module_in_library(&mut self, module: CUmodule, library: CUlibrary) {
self.submodules.insert(module, library);
}
pub(crate) fn record_function_from_module(
&mut self,
fn_logger: &mut FnCallLog,
func: CUfunction,
hmod: CUmodule,
name: *const i8,
) {
let owner = match self.submodules.get(&hmod) {
Some(m) => m.0.cast::<c_void>(),
None => hmod.0.cast::<c_void>(),
};
self.record_function_from_impl(fn_logger, func, owner, name);
}
fn record_function_from_impl(
&mut self,
fn_logger: &mut FnCallLog,
func: CUfunction,
owner: *mut c_void,
name: *const i8,
) {
let name = match unsafe { CStr::from_ptr(name) }.to_str() {
Ok(f) => f,
Err(err) => {
fn_logger.log(log::ErrorEntry::Utf8Error(err));
return;
}
};
let saved_kernel = SavedKernel {
name: name.to_string(),
owner: SendablePtr(owner),
};
self.kernels.insert(func, saved_kernel);
} }
} }

View file

@ -11,6 +11,10 @@ cuda_types = { path = "../cuda_types" }
dark_api = { path = "../dark_api" } dark_api = { path = "../dark_api" }
format = { path = "../format" } format = { path = "../format" }
cglue = "0.3.5" cglue = "0.3.5"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.142"
tar = "0.4"
zstd = "0.13"
[target.'cfg(not(windows))'.dependencies] [target.'cfg(not(windows))'.dependencies]
libc = "0.2" libc = "0.2"

View file

@ -8,6 +8,8 @@ use cuda_types::{
use dark_api::ByteVecFfi; use dark_api::ByteVecFfi;
use std::{borrow::Cow, ffi::c_void, num::NonZero, ptr, sync::LazyLock}; use std::{borrow::Cow, ffi::c_void, num::NonZero, ptr, sync::LazyLock};
pub mod replay;
pub fn get_export_table() -> Option<::dark_api::zluda_trace::ZludaTraceInternal> { pub fn get_export_table() -> Option<::dark_api::zluda_trace::ZludaTraceInternal> {
static CU_GET_EXPORT_TABLE: LazyLock< static CU_GET_EXPORT_TABLE: LazyLock<
Result< Result<

View file

@ -0,0 +1,83 @@
use std::io::Write;
use tar::Header;
#[derive(serde::Serialize, serde::Deserialize)]
struct Manifest {
kernel_name: String,
parameters: Vec<Parameter>,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct Parameter {
pointer_offsets: Vec<ParameterPointer>,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct ParameterPointer {
offset_in_param: usize,
offset_in_buffer: usize,
}
impl Manifest {
const PATH: &'static str = "manifest.json";
fn serialize(&self) -> std::io::Result<(Header, Vec<u8>)> {
let vec = serde_json::to_vec(self)?;
let mut header = Header::new_gnu();
header.set_size(vec.len() as u64);
Ok((header, vec))
}
}
pub struct KernelParameter {
pub data: Vec<u8>,
pub device_ptrs: Vec<(usize, usize, Vec<u8>, Vec<u8>)>,
}
pub fn save(
writer: impl Write,
kernel_name: String,
kernel_params: Vec<KernelParameter>,
) -> std::io::Result<()> {
let archive = zstd::Encoder::new(writer, 0)?;
let mut builder = tar::Builder::new(archive);
let (mut header, manifest) = Manifest {
kernel_name,
parameters: kernel_params
.iter()
.map(|param| Parameter {
pointer_offsets: param
.device_ptrs
.iter()
.map(
|(offset_in_param, offset_in_buffer, _, _)| ParameterPointer {
offset_in_param: *offset_in_param,
offset_in_buffer: *offset_in_buffer,
},
)
.collect(),
})
.collect(),
}
.serialize()?;
builder.append_data(&mut header, Manifest::PATH, &*manifest)?;
for (i, param) in kernel_params.into_iter().enumerate() {
let path = format!("param_{i}.bin");
let mut header = Header::new_gnu();
header.set_size(param.data.len() as u64);
builder.append_data(&mut header, &*path, &*param.data)?;
for (offset_in_param, _, data_before, data_after) in param.device_ptrs {
let path = format!("param_{i}_ptr_{offset_in_param}_pre.bin");
let mut header = Header::new_gnu();
header.set_size(data_before.len() as u64);
builder.append_data(&mut header, &*path, &*data_before)?;
let path = format!("param_{i}_ptr_{offset_in_param}_post.bin");
let mut header = Header::new_gnu();
header.set_size(data_after.len() as u64);
builder.append_data(&mut header, &*path, &*data_after)?;
}
}
builder.finish()?;
builder.into_inner()?.finish()?;
Ok(())
}