This commit is contained in:
Andrzej Janik 2025-09-23 23:24:18 +02:00 committed by GitHub
commit 1dc09827e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 1022 additions and 190 deletions

32
Cargo.lock generated
View file

@ -420,7 +420,7 @@ version = "0.0.0"
dependencies = [
"proc-macro2",
"quote",
"rustc-hash 1.1.0",
"rustc-hash 2.0.0",
"syn 2.0.89",
]
@ -3706,7 +3706,7 @@ dependencies = [
"paste",
"ptx",
"ptx_parser",
"rustc-hash 1.1.0",
"rustc-hash 2.0.0",
"serde",
"serde_json",
"tempfile",
@ -3726,7 +3726,7 @@ dependencies = [
"prettyplease",
"proc-macro2",
"quote",
"rustc-hash 1.1.0",
"rustc-hash 2.0.0",
"syn 2.0.89",
]
@ -3826,6 +3826,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "zluda_replay"
version = "0.0.0"
dependencies = [
"cuda_macros",
"cuda_types",
"libloading",
"zluda_trace_common",
]
[[package]]
name = "zluda_sparse"
version = "0.0.0"
@ -3854,7 +3864,7 @@ dependencies = [
"ptx",
"ptx_parser",
"regex",
"rustc-hash 1.1.0",
"rustc-hash 2.0.0",
"unwrap_or",
"wchar",
"winapi",
@ -3903,6 +3913,11 @@ dependencies = [
"format",
"libc",
"libloading",
"rustc-hash 2.0.0",
"serde",
"serde_json",
"tar",
"zstd",
]
[[package]]
@ -3979,6 +3994,15 @@ dependencies = [
"simd-adler32",
]
[[package]]
name = "zstd"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a"
dependencies = [
"zstd-safe",
]
[[package]]
name = "zstd-safe"
version = "7.2.4"

View file

@ -37,6 +37,7 @@ members = [
"zluda_inject",
"zluda_ld",
"zluda_ml",
"zluda_replay",
"zluda_redirect",
"zluda_sparse",
"compiler",

View file

@ -219,6 +219,8 @@ pub fn compile_bitcode(
compile_to_exec.set_isa_name(gcn_arch)?;
compile_to_exec.set_language(Language::LlvmIr)?;
let common_options = [
c"-Xlinker",
c"--no-undefined",
c"-mllvm",
c"-ignore-tti-inline-compatible",
// c"-mllvm",

View file

@ -8,7 +8,7 @@ edition = "2021"
quote = "1.0"
syn = { version = "2.0", features = ["full", "visit-mut", "extra-traits"] }
proc-macro2 = "1.0"
rustc-hash = "1.1.0"
rustc-hash = "2.0.0"
[lib]
proc-macro = true

View file

@ -1653,25 +1653,23 @@ impl<'a> MethodEmitContext<'a> {
.ok_or_else(|| error_mismatched_type())?,
);
let src2 = self.resolver.value(src2)?;
self.resolver.with_result(arguments.dst, |dst| {
let vec = unsafe {
LLVMBuildInsertElement(
self.builder,
LLVMGetPoison(dst_type),
llvm_fn(self.builder, src, packed_type, LLVM_UNNAMED.as_ptr()),
LLVMConstInt(LLVMInt32TypeInContext(self.context), 1, false as i32),
LLVM_UNNAMED.as_ptr(),
)
};
unsafe {
LLVMBuildInsertElement(
self.builder,
vec,
llvm_fn(self.builder, src2, packed_type, LLVM_UNNAMED.as_ptr()),
LLVMConstInt(LLVMInt32TypeInContext(self.context), 0, false as i32),
dst,
)
}
let vec = unsafe {
LLVMBuildInsertElement(
self.builder,
LLVMGetPoison(dst_type),
llvm_fn(self.builder, src, packed_type, LLVM_UNNAMED.as_ptr()),
LLVMConstInt(LLVMInt32TypeInContext(self.context), 1, false as i32),
LLVM_UNNAMED.as_ptr(),
)
};
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildInsertElement(
self.builder,
vec,
llvm_fn(self.builder, src2, packed_type, LLVM_UNNAMED.as_ptr()),
LLVMConstInt(LLVMInt32TypeInContext(self.context), 0, false as i32),
dst,
)
})
} else {
self.resolver.with_result(arguments.dst, |dst| unsafe {
@ -2197,7 +2195,7 @@ impl<'a> MethodEmitContext<'a> {
Some(&ast::ScalarType::F32.into()),
vec![(
self.resolver.value(arguments.src)?,
get_scalar_type(self.context, ast::ScalarType::F32.into()),
get_scalar_type(self.context, ast::ScalarType::F32),
)],
)?;
Ok(())
@ -2658,14 +2656,14 @@ impl<'a> MethodEmitContext<'a> {
let load = unsafe { LLVMBuildLoad2(self.builder, from_type, from, LLVM_UNNAMED.as_ptr()) };
unsafe {
LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8);
LLVMSetAlignment(load, cp_size.as_u64() as u32);
}
let extended = unsafe { LLVMBuildZExt(self.builder, load, to_type, LLVM_UNNAMED.as_ptr()) };
unsafe { LLVMBuildStore(self.builder, extended, to) };
let store = unsafe { LLVMBuildStore(self.builder, extended, to) };
unsafe {
LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8);
LLVMSetAlignment(store, cp_size.as_u64() as u32);
}
Ok(())
}
@ -2945,7 +2943,7 @@ fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
Ok(match scope {
ast::MemScope::Cta => c"workgroup",
ast::MemScope::Gpu => c"agent",
ast::MemScope::Sys => c"",
ast::MemScope::Sys => c"system",
ast::MemScope::Cluster => todo!(),
}
.as_ptr())

View file

@ -2,6 +2,7 @@ use derive_more::Display;
use logos::Logos;
use ptx_parser_macros::derive_parser;
use rustc_hash::FxHashMap;
use std::alloc::Layout;
use std::fmt::Debug;
use std::num::{NonZeroU8, ParseFloatError, ParseIntError};
use std::{iter, usize};
@ -226,8 +227,9 @@ fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult<ast::
take_error((opt(Token::Minus), num).map(|(neg, x)| {
let (num, radix, is_unsigned) = x;
if neg.is_some() {
match i64::from_str_radix(num, radix) {
Ok(x) => Ok(ast::ImmediateValue::S64(-x)),
let full_number = format!("-{num}");
match i64::from_str_radix(&full_number, radix) {
Ok(x) => Ok(ast::ImmediateValue::S64(x)),
Err(err) => Err((ast::ImmediateValue::S64(0), PtxError::from(err))),
}
} else if is_unsigned {
@ -345,7 +347,9 @@ fn reg_or_immediate<'a, 'input>(
.parse_next(stream)
}
pub fn parse_for_errors<'input>(text: &'input str) -> Vec<PtxError<'input>> {
pub fn parse_for_errors_and_params<'input>(
text: &'input str,
) -> (Vec<PtxError<'input>>, FxHashMap<String, Vec<Layout>>) {
let (tokens, mut errors) = lex_with_span_unchecked(text);
let parse_result = {
let state = PtxParserState::new(text, &mut errors);
@ -357,13 +361,30 @@ pub fn parse_for_errors<'input>(text: &'input str) -> Vec<PtxError<'input>> {
.parse(parser)
.map_err(|err| PtxError::Parser(err.into_inner()))
};
match parse_result {
Ok(_) => {}
let params = match parse_result {
Ok(module) => module
.directives
.into_iter()
.filter_map(|directive| {
if let ast::Directive::Method(_, func) = directive {
let layouts = func
.func_directive
.input_arguments
.iter()
.map(|arg| arg.info.v_type.layout())
.collect();
Some((func.func_directive.name().to_string(), layouts))
} else {
None
}
})
.collect(),
Err(err) => {
errors.push(err);
FxHashMap::default()
}
}
errors
};
(errors, params)
}
fn lex_with_span_unchecked<'input>(

View file

@ -22,7 +22,7 @@ num_enum = "0.4"
lz4-sys = "1.9"
tempfile = "3"
paste = "1.0"
rustc-hash = "1.1"
rustc-hash = "2.0.0"
zluda_common = { path = "../zluda_common" }
blake3 = "1.8.2"
serde = "1.0.219"

View file

@ -9,6 +9,6 @@ syn = { version = "2.0", features = ["full", "visit-mut"] }
proc-macro2 = "1.0.89"
quote = "1.0"
prettyplease = "0.2.25"
rustc-hash = "1.1.0"
rustc-hash = "2.0.0"
libloading = "0.8"
cuda_types = { path = "../cuda_types" }

17
zluda_replay/Cargo.toml Normal file
View file

@ -0,0 +1,17 @@
[package]
name = "zluda_replay"
version = "0.0.0"
authors = ["Andrzej Janik <vosen@vosen.pl>"]
edition = "2021"
[[bin]]
name = "zluda_replay"
[dependencies]
zluda_trace_common = { path = "../zluda_trace_common" }
cuda_macros = { path = "../cuda_macros" }
cuda_types = { path = "../cuda_types" }
libloading = "0.8"
[package.metadata.zluda]
debug_only = true

103
zluda_replay/src/main.rs Normal file
View file

@ -0,0 +1,103 @@
use std::mem;
use cuda_types::cuda::{CUdeviceptr_v2, CUstream};
struct CudaDynamicFns {
handle: libloading::Library,
}
impl CudaDynamicFns {
unsafe fn new(path: &str) -> Result<Self, libloading::Error> {
let handle = libloading::Library::new(path)?;
Ok(Self { handle })
}
}
macro_rules! emit_cuda_fn_table {
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
impl CudaDynamicFns {
$(
#[allow(dead_code)]
unsafe fn $fn_name(&self, $($arg_id : $arg_type),*) -> $ret_type {
let func = self.handle.get::<unsafe extern $abi fn ($($arg_type),*) -> $ret_type>(concat!(stringify!($fn_name), "\0").as_bytes());
(func.unwrap())($($arg_id),*)
}
)*
}
};
}
cuda_macros::cuda_function_declarations!(emit_cuda_fn_table);
fn main() {
let args: Vec<String> = std::env::args().collect();
let libcuda = unsafe { CudaDynamicFns::new(&args[1]).unwrap() };
unsafe { libcuda.cuInit(0) }.unwrap();
unsafe { libcuda.cuCtxCreate_v2(&mut mem::zeroed(), 0, 0) }.unwrap();
let reader = std::fs::File::open(&args[2]).unwrap();
let (mut manifest, mut source, mut buffers) = zluda_trace_common::replay::load(reader);
let mut args = manifest
.parameters
.iter()
.enumerate()
.map(|(i, param)| {
let mut buffer = buffers.remove(&format!("param_{i}.bin")).unwrap();
for param_ptr in param.pointer_offsets.iter() {
let buffer_param_slice = &mut buffer[param_ptr.offset_in_param
..param_ptr.offset_in_param + std::mem::size_of::<usize>()];
let mut dev_ptr = unsafe { mem::zeroed() };
let host_buffer = buffers
.remove(&format!(
"param_{i}_ptr_{}_pre.bin",
param_ptr.offset_in_param
))
.unwrap();
unsafe { libcuda.cuMemAlloc_v2(&mut dev_ptr, host_buffer.len()) }.unwrap();
unsafe {
libcuda.cuMemcpyHtoD_v2(dev_ptr, host_buffer.as_ptr().cast(), host_buffer.len())
}
.unwrap();
dev_ptr = CUdeviceptr_v2(unsafe {
dev_ptr
.0
.cast::<u8>()
.add(param_ptr.offset_in_buffer)
.cast()
});
buffer_param_slice.copy_from_slice(&(dev_ptr.0 as usize).to_ne_bytes());
}
buffer
})
.collect::<Vec<_>>();
let mut module = unsafe { mem::zeroed() };
std::fs::write("/tmp/source.ptx", &source).unwrap();
source.push('\0');
unsafe { libcuda.cuModuleLoadData(&mut module, source.as_ptr().cast()) }.unwrap();
let mut function = unsafe { mem::zeroed() };
manifest.kernel_name.push('\0');
unsafe {
libcuda.cuModuleGetFunction(&mut function, module, manifest.kernel_name.as_ptr().cast())
}
.unwrap();
let mut cuda_args = args
.iter_mut()
.map(|arg| arg.as_mut_ptr().cast::<std::ffi::c_void>())
.collect::<Vec<_>>();
unsafe {
libcuda.cuLaunchKernel(
function,
manifest.config.grid_dim.0,
manifest.config.grid_dim.1,
manifest.config.grid_dim.2,
manifest.config.block_dim.0,
manifest.config.block_dim.1,
manifest.config.block_dim.2,
manifest.config.shared_mem_bytes,
CUstream(std::ptr::null_mut()),
cuda_args.as_mut_ptr().cast(),
std::ptr::null_mut(),
)
}
.unwrap();
unsafe { libcuda.cuCtxSynchronize() }.unwrap();
}

View file

@ -24,7 +24,7 @@ paste = "1.0"
cuda_macros = { path = "../cuda_macros" }
cuda_types = { path = "../cuda_types" }
parking_lot = "0.12.3"
rustc-hash = "1.1.0"
rustc-hash = "2.0.0"
cglue = "0.3.5"
zstd-safe = { version = "7.2.4", features = ["std"] }
unwrap_or = "1.0.1"

View file

@ -12,6 +12,7 @@ use std::ptr::NonNull;
use std::sync::LazyLock;
use std::{env, error::Error, fs, path::PathBuf, sync::Mutex};
use std::{io, mem, ptr, usize};
use unwrap_or::unwrap_some_or;
extern crate cuda_types;
@ -110,7 +111,7 @@ macro_rules! override_fn_core {
).ok();
formatted_args
};
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| Some(());
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| Some(((), ()));
let cuda_call = |_| {
paste!{ [<$fn_name _impl >] ( $($arg_id),* ) }
};
@ -121,7 +122,7 @@ macro_rules! override_fn_core {
format_curesult,
extract_fn_ptr,
cuda_call,
move |_, _, _, _| {}
move |_, _, _, _, _| {}
)
}
)*
@ -157,9 +158,9 @@ impl ::dark_api::zluda_trace::CudaDarkApi for InternalTableImpl {
Some(|| args.call().to_vec()),
internal_error,
|status| format_status(status).to_vec(),
|_, _| Some(()),
|_, _| Some(((), ())),
|_| fn_.call(),
move |_, _, _, _| {},
move |_, _, _, _, _| {},
)
}
}
@ -201,7 +202,7 @@ macro_rules! dark_api_fn_redirect_log {
).ok();
formatted_args
};
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| { Some(()) };
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| { Some(((), ())) };
let cuda_call = |_: () | {
ReprUsize::to_usize(original_fn( $( $arg_id ),* ))
};
@ -215,7 +216,7 @@ macro_rules! dark_api_fn_redirect_log {
|status| <$ret_type as ReprUsize>::format_status(status).to_vec(),
extract_fn_ptr,
cuda_call,
move |_, _, _, _| {}
move |_, _, _, _, _| {}
))
}
)+
@ -256,7 +257,7 @@ macro_rules! dark_api_fn_redirect_log_post {
).ok();
formatted_args
};
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| { Some(()) };
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| { Some(((), ())) };
let cuda_call = |_: () | {
ReprUsize::to_usize(original_fn( $( $arg_id ),* ))
};
@ -270,7 +271,7 @@ macro_rules! dark_api_fn_redirect_log_post {
|status| <$ret_type as ReprUsize>::format_status(status).to_vec(),
extract_fn_ptr,
cuda_call,
move |state, logger, _, cuda_result| paste! { Self:: [<$fn_ _post>] } ( $( $arg_id ),* , &mut state.cuda_state, logger, <$ret_type as ReprUsize>::from_usize(cuda_result))
move |state, logger, _, _, cuda_result| paste! { Self:: [<$fn_ _post>] } ( $( $arg_id ),* , &mut state.cuda_state, logger, <$ret_type as ReprUsize>::from_usize(cuda_result))
))
}
)+
@ -287,7 +288,11 @@ impl DarkApiTrace {
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
state.record_new_library(unsafe { *module }, fatbinc_wrapper.cast(), fn_logger)
state.record_new_library(
unsafe { *module }.0.cast(),
fatbinc_wrapper.cast(),
fn_logger,
)
}
fn get_module_from_cubin_ext1_post(
@ -321,7 +326,11 @@ impl DarkApiTrace {
observed: UInt::U32(arg5),
});
}
state.record_new_library(unsafe { *module }, fatbinc_wrapper.cast(), fn_logger)
state.record_new_library(
unsafe { *module }.0.cast(),
fatbinc_wrapper.cast(),
fn_logger,
)
}
fn get_module_from_cubin_ext2_post(
@ -355,7 +364,7 @@ impl DarkApiTrace {
observed: UInt::U32(arg5),
});
}
state.record_new_library(unsafe { *module }, fatbin_header.cast(), fn_logger)
state.record_new_library(unsafe { *module }.0.cast(), fatbin_header.cast(), fn_logger)
}
}
@ -770,7 +779,7 @@ macro_rules! extern_redirect {
};
let extract_fn_ptr = |state: &mut GlobalDelayedState, _: &mut FnCallLog| {
paste::paste! {
state.libcuda. [<get_ $fn_name>]()
state.libcuda. [<get_ $fn_name>]().map(|x| ((), x) )
}
};
let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | {
@ -783,7 +792,7 @@ macro_rules! extern_redirect {
format_curesult,
extract_fn_ptr,
cuda_call,
move |_, _, _, _| {}
move |_, _, _, _, _| {}
)
}
)*
@ -806,7 +815,7 @@ macro_rules! extern_redirect_with_post {
};
let extract_fn_ptr = |state: &mut GlobalDelayedState, _: &mut FnCallLog| {
paste::paste! {
state.libcuda. [<get_ $fn_name>]()
state.libcuda. [<get_ $fn_name>]().map(|x| ((), x) )
}
};
let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | {
@ -819,7 +828,43 @@ macro_rules! extern_redirect_with_post {
format_curesult,
extract_fn_ptr,
cuda_call,
move |state, logger, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , &mut state.cuda_state, logger, cuda_result )
move |state, logger, _, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , &mut state.cuda_state, logger, cuda_result )
)
}
)*
};
}
macro_rules! extern_redirect_with_pre_post {
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
$(
#[no_mangle]
#[allow(improper_ctypes_definitions)]
pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
let format_args = || {
let mut formatted_args = Vec::new();
(paste! { format :: [<write_ $fn_name>] }) (
&mut formatted_args
$(,$arg_id)*
).ok();
formatted_args
};
let extract_fn_ptr = |state: &mut GlobalDelayedState, logger: &mut FnCallLog| {
paste::paste! {
state.libcuda. [<get_ $fn_name>]().map(|x| (paste! { [<$fn_name _Pre>] } ( $( $arg_id ),* , &mut state.libcuda, &mut state.cuda_state, logger ), x ))
}
};
let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | {
fn_ptr( $( $arg_id ),* )
};
GlobalState2::under_lock(
CudaFunctionName::Normal(stringify!($fn_name)),
Some(format_args),
CUresult::INTERNAL_ERROR,
format_curesult,
extract_fn_ptr,
cuda_call,
move |state, logger, pre_state, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , pre_state, &mut state.libcuda, &mut state.cuda_state, logger, cuda_result )
)
}
)*
@ -843,13 +888,15 @@ cuda_function_declarations!(
cuModuleLoad,
cuModuleLoadData,
cuModuleLoadDataEx,
cuLibraryGetFunction,
cuModuleGetFunction,
cuDeviceGetAttribute,
cuDeviceComputeCapability,
cuModuleLoadFatBinary,
cuLibraryGetModule,
cuLibraryLoadData
cuLibraryLoadData,
],
extern_redirect_with_pre_post <= [cuLaunchKernel, cuLaunchKernelEx],
override_fn_core <= [cuGetProcAddress, cuGetProcAddress_v2],
override_fn_full <= [cuGetExportTable],
);
@ -859,6 +906,7 @@ mod log;
#[cfg_attr(windows, path = "os_win.rs")]
#[cfg_attr(not(windows), path = "os_unix.rs")]
mod os;
mod replay;
mod trace;
struct GlobalState2 {
@ -907,27 +955,33 @@ impl GlobalState2 {
// * Post-call:
// We log the output of the CUDA function and any errors that may have occurred. This phase
// is also covered by a drop guard which will flush the log buffer in case of panic
fn under_lock<'a, FnPtr: Copy, InnerResult: Copy>(
fn under_lock<'a, PreState, FnPtr: Copy, InnerResult: Copy>(
name: CudaFunctionName,
args: Option<impl FnOnce() -> Vec<u8>>,
internal_error: InnerResult,
format_status: impl FnOnce(InnerResult) -> Vec<u8>,
pre_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog) -> Option<FnPtr>,
pre_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog) -> Option<(PreState, FnPtr)>,
inner_call: impl FnOnce(FnPtr) -> InnerResult,
post_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog, FnPtr, InnerResult),
post_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog, PreState, FnPtr, InnerResult),
) -> InnerResult {
fn under_lock_impl<'a, FnPtr: Copy, InnerResult: Copy>(
fn under_lock_impl<'a, PreState, FnPtr: Copy, InnerResult: Copy>(
name: CudaFunctionName,
args: Option<impl FnOnce() -> Vec<u8>>,
internal_error: InnerResult,
format_status: impl FnOnce(InnerResult) -> Vec<u8>,
pre_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog) -> Option<FnPtr>,
pre_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog) -> Option<(PreState, FnPtr)>,
inner_call: impl FnOnce(FnPtr) -> InnerResult,
post_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog, FnPtr, InnerResult),
post_call: impl FnOnce(
&mut GlobalDelayedState,
&mut FnCallLog,
PreState,
FnPtr,
InnerResult,
),
) -> InnerResult {
let global_state = GLOBAL_STATE2.lock();
let global_state_ref_cell = &*global_state;
let pre_value = {
let (pre_state, pre_ptr) = {
let mut global_state_ref_mut = global_state_ref_cell.borrow_mut();
let global_state = &mut *global_state_ref_mut;
let panic_guard = OuterCallGuard {
@ -963,7 +1017,7 @@ impl GlobalState2 {
}
};
let panic_guard = InnerCallGuard(global_state_ref_cell);
let inner_result = inner_call(pre_value);
let inner_result = inner_call(pre_ptr);
let global_state = &mut *global_state_ref_cell.borrow_mut();
mem::forget(panic_guard);
let _drop_guard = OuterCallGuard {
@ -978,7 +1032,8 @@ impl GlobalState2 {
post_call(
global_state.delayed_state.as_mut().unwrap(),
&mut logger,
pre_value,
pre_state,
pre_ptr,
inner_result,
);
inner_result
@ -1098,6 +1153,22 @@ impl FnCallLog {
}
}
fn try_cuda(&mut self, fn_: impl FnOnce() -> Option<CUresult>) -> Option<()> {
match fn_() {
Some(Ok(())) => Some(()),
None => {
self.subcalls
.push(LogEntry::Error(ErrorEntry::CudaError(None)));
None
}
Some(Err(err)) => {
self.subcalls
.push(LogEntry::Error(ErrorEntry::CudaError(Some(err))));
None
}
}
}
fn try_<T>(&mut self, f: impl FnOnce(&mut Self) -> Result<T, ErrorEntry>) -> Option<T> {
match f(self) {
Err(e) => {
@ -1209,6 +1280,8 @@ struct Settings {
dump_dir: Option<PathBuf>,
libcuda_path: String,
override_cc: Option<(u32, u32)>,
kernel_name_filter: Option<regex::Regex>,
kernel_no_output: Option<bool>,
}
impl Settings {
@ -1257,10 +1330,42 @@ impl Settings {
})
}),
};
let kernel_name_filter = match env::var("ZLUDA_SAVE_KERNELS") {
Err(env::VarError::NotPresent) => None,
Err(e) => {
logger.log(log::ErrorEntry::ErrorBox(Box::new(e) as _));
None
}
Ok(env_string) => logger.try_return(|| {
regex::Regex::new(&env_string).map_err(|e| ErrorEntry::InvalidEnvVar {
var: "ZLUDA_SAVE_KERNELS",
pattern: "valid regex",
value: format!("{} ({})", env_string, e),
})
}),
};
let kernel_no_output = match env::var("ZLUDA_SAVE_KERNELS_NO_OUTPUT") {
Err(env::VarError::NotPresent) => None,
Err(e) => {
logger.log(log::ErrorEntry::ErrorBox(Box::new(e) as _));
None
}
Ok(env_string) => logger
.try_return(|| {
str::parse::<u8>(&env_string).map_err(|err| ErrorEntry::InvalidEnvVar {
var: "ZLUDA_SAVE_KERNELS_NO_OUTPUT",
pattern: "number",
value: format!("{} ({})", env_string, err),
})
})
.map(|x| x != 0),
};
Settings {
dump_dir,
libcuda_path,
override_cc,
kernel_name_filter,
kernel_no_output,
}
}
@ -1307,7 +1412,7 @@ pub(crate) fn cuModuleLoadData_Post(
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
state.record_new_library(unsafe { *module }, raw_image, fn_logger)
state.record_new_library(unsafe { *module }.0.cast(), raw_image, fn_logger)
}
#[allow(non_snake_case)]
@ -1326,13 +1431,17 @@ pub(crate) fn cuModuleLoadDataEx_Post(
#[allow(non_snake_case)]
pub(crate) fn cuModuleGetFunction_Post(
_hfunc: *mut CUfunction,
_hmod: CUmodule,
_name: *const ::std::os::raw::c_char,
_state: &mut trace::StateTracker,
_fn_logger: &mut FnCallLog,
_result: CUresult,
hfunc: *mut CUfunction,
hmod: CUmodule,
name: *const ::std::os::raw::c_char,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
result: CUresult,
) {
if !result.is_ok() {
return;
}
state.record_function_from_module(fn_logger, unsafe { *hfunc }, hmod, name);
}
#[allow(non_snake_case)]
@ -1385,7 +1494,7 @@ pub(crate) fn cuModuleLoadFatBinary_Post(
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
state.record_new_library(unsafe { *module }, fatbin_header.cast(), fn_logger)
state.record_new_library(unsafe { *module }.0.cast(), fatbin_header.cast(), fn_logger)
}
#[allow(non_snake_case)]
@ -1393,13 +1502,13 @@ pub(crate) fn cuLibraryGetModule_Post(
module: *mut cuda_types::cuda::CUmodule,
library: cuda_types::cuda::CUlibrary,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
_result: CUresult,
_fn_logger: &mut FnCallLog,
result: CUresult,
) {
match state.libraries.get(&library).copied() {
None => fn_logger.log(log::ErrorEntry::UnknownLibrary(library)),
Some(code) => state.record_new_library(unsafe { *module }, code.0, fn_logger),
if !result.is_ok() {
return;
}
state.record_module_in_library(unsafe { *module }, library);
}
#[allow(non_snake_case)]
@ -1416,10 +1525,149 @@ pub(crate) fn cuLibraryLoadData_Post(
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
state
.libraries
.insert(unsafe { *library }, trace::CodePointer(code));
// TODO: this is not correct, but it's enough for now, we just want to
// save the binary to disk
state.record_new_library(unsafe { CUmodule((*library).0.cast()) }, code, fn_logger);
state.record_new_library(unsafe { *library }.0.cast(), code, fn_logger);
}
#[allow(non_snake_case)]
pub(crate) fn cuLaunchKernel_Pre(
f: cuda_types::cuda::CUfunction,
gridDimX: ::core::ffi::c_uint,
gridDimY: ::core::ffi::c_uint,
gridDimZ: ::core::ffi::c_uint,
blockDimX: ::core::ffi::c_uint,
blockDimY: ::core::ffi::c_uint,
blockDimZ: ::core::ffi::c_uint,
sharedMemBytes: ::core::ffi::c_uint,
hStream: cuda_types::cuda::CUstream,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
) -> Option<replay::LaunchPreState> {
launch_kernel_pre(
f,
CUlaunchConfig {
gridDimX,
gridDimY,
gridDimZ,
blockDimX,
blockDimY,
blockDimZ,
sharedMemBytes,
hStream,
attrs: ptr::null_mut(),
numAttrs: 0,
},
hStream,
kernel_params,
libcuda,
state,
fn_logger,
)
}
fn launch_kernel_pre(
f: cuda_types::cuda::CUfunction,
config: CUlaunchConfig,
stream: cuda_types::cuda::CUstream,
kernel_params: *mut *mut ::core::ffi::c_void,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
) -> Option<replay::LaunchPreState> {
state.enqueue_counter += 1;
if kernel_params.is_null() {
fn_logger.log(ErrorEntry::NullPointer("kernel_params"));
return None;
}
if state.dump_dir().is_none() {
return None;
}
replay::pre_kernel_launch(libcuda, state, fn_logger, config, f, stream, kernel_params)
}
#[allow(non_snake_case)]
pub(crate) fn cuLaunchKernel_Post(
_f: cuda_types::cuda::CUfunction,
gridDimX: ::core::ffi::c_uint,
gridDimY: ::core::ffi::c_uint,
gridDimZ: ::core::ffi::c_uint,
blockDimX: ::core::ffi::c_uint,
blockDimY: ::core::ffi::c_uint,
blockDimZ: ::core::ffi::c_uint,
sharedMemBytes: ::core::ffi::c_uint,
hStream: cuda_types::cuda::CUstream,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
pre_state: Option<replay::LaunchPreState>,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
let pre_state = unwrap_some_or!(pre_state, return);
replay::post_kernel_launch(
libcuda,
state,
fn_logger,
CUlaunchConfig {
gridDimX,
gridDimY,
gridDimZ,
blockDimX,
blockDimY,
blockDimZ,
sharedMemBytes,
hStream,
attrs: ptr::null_mut(),
numAttrs: 0,
},
kernel_params,
pre_state,
);
}
#[allow(non_snake_case)]
pub(crate) fn cuLaunchKernelEx_Pre(
config: *const cuda_types::cuda::CUlaunchConfig,
f: cuda_types::cuda::CUfunction,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
) -> Option<replay::LaunchPreState> {
launch_kernel_pre(
f,
unsafe { *config },
unsafe { *config }.hStream,
kernel_params,
libcuda,
state,
fn_logger,
)
}
#[allow(non_snake_case)]
pub(crate) fn cuLaunchKernelEx_Post(
config: *const cuda_types::cuda::CUlaunchConfig,
_f: cuda_types::cuda::CUfunction,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
pre_state: Option<replay::LaunchPreState>,
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
_result: CUresult,
) {
let pre_state = unwrap_some_or!(pre_state, return);
replay::post_kernel_launch(
libcuda,
state,
fn_logger,
unsafe { *config },
kernel_params,
pre_state,
);
}

View file

@ -1,8 +1,8 @@
use super::Settings;
use crate::trace::SendablePtr;
use crate::FnCallLog;
use crate::LogEntry;
use cuda_types::cuda::*;
use format::CudaDisplay;
use std::error::Error;
use std::ffi::c_void;
use std::ffi::NulError;
@ -267,13 +267,12 @@ pub(crate) enum ErrorEntry {
CreatedDumpDirectory(PathBuf),
ErrorBox(Box<dyn Error>),
UnsupportedModule {
module: CUmodule,
handle: *mut c_void,
raw_image: *const c_void,
kind: &'static str,
},
FunctionNotFound(CudaFunctionName),
MalformedModulePath(Utf8Error),
NonUtf8ModuleText(Utf8Error),
Utf8Error(Utf8Error),
NulInsideModuleText(NulError),
ModuleParsingError(String),
Lz4DecompressionFailure,
@ -302,8 +301,11 @@ pub(crate) enum ErrorEntry {
overriden: [u64; 2],
},
NullPointer(&'static str),
UnknownLibrary(CUlibrary),
SavedModule(String),
UnknownFunctionHandle(CUfunction),
UnknownLibrary(CUfunction, SendablePtr),
UnknownFunction(CUfunction, SendablePtr, String),
CudaError(Option<CUerror>),
}
unsafe impl Send for ErrorEntry {}
@ -345,94 +347,100 @@ impl Display for ErrorEntry {
match self {
ErrorEntry::IoError(e) => e.fmt(f),
ErrorEntry::CreatedDumpDirectory(dir) => {
write!(
f,
"Created trace directory {} ",
dir.as_os_str().to_string_lossy()
)
}
write!(
f,
"Created trace directory {} ",
dir.as_os_str().to_string_lossy()
)
}
ErrorEntry::ErrorBox(e) => e.fmt(f),
ErrorEntry::UnsupportedModule {
module,
raw_image,
kind,
} => {
write!(
f,
"Unsupported {} module {:?} loaded from module image {:?}",
kind, module, raw_image
)
}
ErrorEntry::MalformedModulePath(e) => e.fmt(f),
ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f),
handle,
raw_image,
kind,
} => {
write!(
f,
"Unsupported {} module {:p} loaded from module image {:p}",
kind, handle, raw_image
)
}
ErrorEntry::Utf8Error(e) => e.fmt(f),
ErrorEntry::ModuleParsingError(file_name) => {
write!(
f,
"Error parsing module, log has been written to {}",
file_name
)
}
write!(
f,
"Error parsing module, log has been written to {}",
file_name
)
}
ErrorEntry::NulInsideModuleText(e) => e.fmt(f),
ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"),
ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)),
ErrorEntry::UnexpectedBinaryField {
field_name,
expected,
observed,
} => write!(
f,
"Unexpected field {}. Expected one of: [{}], observed: {}",
field_name,
expected
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(", "),
observed
),
field_name,
expected,
observed,
} => write!(
f,
"Unexpected field {}. Expected one of: [{}], observed: {}",
field_name,
expected
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(", "),
observed
),
ErrorEntry::UnexpectedArgument {
arg_name,
expected,
observed,
} => write!(
f,
"Unexpected argument {}. Expected one of: {{{}}}, observed: {}",
arg_name,
expected
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(", "),
observed
),
arg_name,
expected,
observed,
} => write!(
f,
"Unexpected argument {}. Expected one of: {{{}}}, observed: {}",
arg_name,
expected
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(", "),
observed
),
ErrorEntry::InvalidEnvVar {
var,
pattern,
value,
} => write!(
f,
"Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}"
),
var,
pattern,
value,
} => write!(
f,
"Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}"
),
ErrorEntry::FunctionNotFound(cuda_function_name) => write!(
f,
"No function {cuda_function_name} in the underlying library"
),
f,
"No function {cuda_function_name} in the underlying library"
),
ErrorEntry::UnexpectedExportTableSize { expected, computed } => {
write!(f, "Table length mismatch. Expected: {expected}, got: {computed}")
}
write!(f, "Table length mismatch. Expected: {expected}, got: {computed}")
}
ErrorEntry::IntegrityCheck { original, overriden } => {
write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}")
}
write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}")
}
ErrorEntry::NullPointer(type_) => {
write!(f, "Null pointer of type {type_} encountered")
}
ErrorEntry::UnknownLibrary(culibrary) => {
write!(f, "Unknown library: ")?;
let mut temp_buffer = Vec::new();
CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok();
f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) })
}
write!(f, "Null pointer of type {type_} encountered")
}
ErrorEntry::SavedModule(file) => write!(f, "Saved module to {file}"),
ErrorEntry::UnknownFunctionHandle(cuda_function_name) => {
write!(f, "Function with unknown provenance: {cuda_function_name:p}")
}
ErrorEntry::UnknownLibrary(cuda_function_name, owner) => {
write!(f, "Function with unknown provenance: {cuda_function_name:p}, owner: {owner:p}")
}
ErrorEntry::UnknownFunction(cuda_function_name, owner, name) => {
write!(f, "Function with unknown provenance: {cuda_function_name:p}, owner: {owner:p}, name: {name}")
}
ErrorEntry::CudaError(cuerror) => {
let cuerror = cuerror.map(|e| e.0);
write!(f, "CUDA error encountered: {cuerror:#?}")
},
}
}
}

171
zluda_trace/src/replay.rs Normal file
View file

@ -0,0 +1,171 @@
use crate::{
log::ErrorEntry,
trace::{self, ParsedModule, SavedKernel},
CudaDynamicFns, FnCallLog,
};
use cuda_types::cuda::*;
use zluda_trace_common::replay::KernelParameter;
pub struct LaunchPreState {
kernel_name: String,
source: String,
kernel_params: Vec<KernelParameter>,
}
pub(crate) fn pre_kernel_launch(
libcuda: &mut CudaDynamicFns,
state: &mut trace::StateTracker,
fn_logger: &mut FnCallLog,
config: CUlaunchConfig,
f: CUfunction,
stream: CUstream,
args: *mut *mut std::ffi::c_void,
) -> Option<LaunchPreState> {
fn_logger.try_cuda(|| libcuda.cuStreamSynchronize(stream))?;
let SavedKernel { name, owner } = fn_logger.try_return(|| {
state
.kernels
.get(&f)
.ok_or(ErrorEntry::UnknownFunctionHandle(f))
})?;
let kernel_name_filter = state.kernel_name_filter.as_ref()?;
if !kernel_name_filter.is_match(name) {
return None;
}
let ParsedModule { source, kernels } = fn_logger.try_return(|| {
state
.parsed_libraries
.get(owner)
.ok_or(ErrorEntry::UnknownLibrary(f, *owner))
})?;
let kernel_params = fn_logger.try_return(|| {
kernels
.get(name)
.ok_or_else(|| ErrorEntry::UnknownFunction(f, *owner, name.clone()))
})?;
let raw_args = unsafe { std::slice::from_raw_parts(args, kernel_params.len()) };
let mut all_params = Vec::new();
for (raw_arg, layout) in raw_args.iter().zip(kernel_params.iter()) {
let mut offset = 0;
let mut ptr_overrides = Vec::new();
while offset + std::mem::size_of::<usize>() <= layout.size() {
let maybe_ptr = unsafe { raw_arg.cast::<u8>().add(offset) };
let maybe_ptr = unsafe { maybe_ptr.cast::<usize>().read_unaligned() };
let attrs = &mut [
CUpointer_attribute_enum::CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
CUpointer_attribute_enum::CU_POINTER_ATTRIBUTE_RANGE_SIZE,
];
let mut start = 0usize;
let mut size = 0usize;
let mut data = [
(&mut start as *mut usize).cast::<std::ffi::c_void>(),
(&mut size as *mut usize).cast::<std::ffi::c_void>(),
];
fn_logger.try_cuda(|| {
libcuda.cuPointerGetAttributes(
2,
attrs.as_mut_ptr(),
data.as_mut_ptr(),
CUdeviceptr_v2(maybe_ptr as _),
)
})?;
if size != 0 {
let mut pre_buffer = vec![0u8; size];
let post_buffer = vec![0u8; size];
fn_logger.try_cuda(|| {
libcuda.cuMemcpyDtoH_v2(
pre_buffer.as_mut_ptr().cast(),
CUdeviceptr_v2(start as _),
size,
)
})?;
let buffer_offset = maybe_ptr - start;
ptr_overrides.push((offset, buffer_offset, pre_buffer, post_buffer));
}
offset += std::mem::size_of::<usize>();
}
all_params.push(KernelParameter {
data: unsafe { std::slice::from_raw_parts(raw_arg.cast::<u8>(), layout.size()) }
.to_vec(),
device_ptrs: ptr_overrides,
});
}
if state.kernel_no_output {
let enqueue_counter = state.enqueue_counter;
let kernel_name = name;
let mut path = state.dump_dir()?.to_path_buf();
path.push(format!("kernel_{enqueue_counter}_{kernel_name}.tar.zst"));
let file = fn_logger
.try_return(|| std::fs::File::create_new(path).map_err(ErrorEntry::IoError))?;
fn_logger.try_return(|| {
zluda_trace_common::replay::save(
file,
name.to_string(),
false,
zluda_trace_common::replay::LaunchConfig {
grid_dim: (config.gridDimX, config.gridDimY, config.gridDimZ),
block_dim: (config.blockDimX, config.blockDimY, config.blockDimZ),
shared_mem_bytes: config.sharedMemBytes,
},
source.to_string(),
all_params,
)
.map_err(ErrorEntry::IoError)
});
None
} else {
Some(LaunchPreState {
kernel_name: name.to_string(),
source: source.to_string(),
kernel_params: all_params,
})
}
}
pub(crate) fn post_kernel_launch(
libcuda: &mut CudaDynamicFns,
state: &trace::StateTracker,
fn_logger: &mut FnCallLog,
config: CUlaunchConfig,
kernel_params: *mut *mut std::ffi::c_void,
mut pre_state: LaunchPreState,
) -> Option<()> {
fn_logger.try_cuda(|| libcuda.cuStreamSynchronize(config.hStream))?;
let raw_args =
unsafe { std::slice::from_raw_parts(kernel_params, pre_state.kernel_params.len()) };
for (raw_arg, param) in raw_args.iter().zip(pre_state.kernel_params.iter_mut()) {
for (offset_in_param, offset_in_buffer, _, data_after) in param.device_ptrs.iter_mut() {
let dev_ptr_param = unsafe { raw_arg.cast::<u8>().add(*offset_in_param) };
let mut dev_ptr = unsafe { dev_ptr_param.cast::<usize>().read_unaligned() };
dev_ptr -= *offset_in_buffer;
fn_logger.try_cuda(|| {
libcuda.cuMemcpyDtoH_v2(
data_after.as_mut_ptr().cast(),
CUdeviceptr_v2(dev_ptr as _),
data_after.len(),
)
})?;
}
}
let enqueue_counter = state.enqueue_counter;
let kernel_name = &pre_state.kernel_name;
let mut path = state.dump_dir()?.to_path_buf();
path.push(format!("kernel_{enqueue_counter}_{kernel_name}.tar.zst"));
let file =
fn_logger.try_return(|| std::fs::File::create_new(path).map_err(ErrorEntry::IoError))?;
fn_logger.try_return(|| {
zluda_trace_common::replay::save(
file,
pre_state.kernel_name,
true,
zluda_trace_common::replay::LaunchConfig {
grid_dim: (config.gridDimX, config.gridDimY, config.gridDimZ),
block_dim: (config.blockDimX, config.blockDimY, config.blockDimZ),
shared_mem_bytes: config.sharedMemBytes,
},
pre_state.source,
pre_state.kernel_params,
)
.map_err(ErrorEntry::IoError)
})
}

View file

@ -4,8 +4,9 @@ use crate::{
};
use cuda_types::cuda::*;
use goblin::{elf, elf32, elf64};
use rustc_hash::{FxHashMap, FxHashSet};
use rustc_hash::FxHashMap;
use std::{
alloc::Layout,
ffi::{c_void, CStr, CString},
fs::{self, File},
io::{self, Read, Write},
@ -20,29 +21,51 @@ use unwrap_or::unwrap_some_or;
// * writes out relevant state change and details to disk and log
pub(crate) struct StateTracker {
writer: DumpWriter,
pub(crate) libraries: FxHashMap<CUlibrary, CodePointer>,
saved_modules: FxHashSet<CUmodule>,
pub(crate) parsed_libraries: FxHashMap<SendablePtr, ParsedModule>,
pub(crate) submodules: FxHashMap<CUmodule, CUlibrary>,
pub(crate) kernels: FxHashMap<CUfunction, SavedKernel>,
library_counter: usize,
pub(crate) enqueue_counter: usize,
pub(crate) override_cc: Option<(u32, u32)>,
pub(crate) kernel_name_filter: Option<regex::Regex>,
pub(crate) kernel_no_output: bool,
}
#[derive(Clone, Copy)]
pub(crate) struct CodePointer(pub *const c_void);
pub(crate) struct ParsedModule {
pub source: String,
pub kernels: FxHashMap<String, Vec<Layout>>,
}
unsafe impl Send for CodePointer {}
unsafe impl Sync for CodePointer {}
pub(crate) struct SavedKernel {
pub name: String,
pub owner: SendablePtr,
}
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct SendablePtr(*mut c_void);
unsafe impl Send for SendablePtr {}
unsafe impl Sync for SendablePtr {}
impl StateTracker {
pub(crate) fn new(settings: &Settings) -> Self {
StateTracker {
writer: DumpWriter::new(settings.dump_dir.clone()),
libraries: FxHashMap::default(),
saved_modules: FxHashSet::default(),
parsed_libraries: FxHashMap::default(),
submodules: FxHashMap::default(),
kernels: FxHashMap::default(),
library_counter: 0,
enqueue_counter: 0,
override_cc: settings.override_cc,
kernel_name_filter: settings.kernel_name_filter.clone(),
kernel_no_output: settings.kernel_no_output.unwrap_or(false),
}
}
pub(crate) fn dump_dir(&self) -> Option<&PathBuf> {
self.writer.dump_dir.as_ref()
}
pub(crate) fn record_new_module_file(
&mut self,
module: CUmodule,
@ -52,7 +75,7 @@ impl StateTracker {
let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() {
Ok(f) => f,
Err(err) => {
fn_logger.log(log::ErrorEntry::MalformedModulePath(err));
fn_logger.log(log::ErrorEntry::Utf8Error(err));
return;
}
};
@ -69,21 +92,26 @@ impl StateTracker {
let mut module_file = fs::File::open(file_name)?;
let mut read_buff = Vec::new();
module_file.read_to_end(&mut read_buff)?;
self.record_new_library(module, read_buff.as_ptr() as *const _, fn_logger);
self.record_new_library(module.0.cast(), read_buff.as_ptr() as *const _, fn_logger);
Ok(())
}
pub(crate) fn record_new_library(
&mut self,
cu_module: CUmodule,
handle: *mut c_void,
raw_image: *const c_void,
fn_logger: &mut FnCallLog,
) {
self.saved_modules.insert(cu_module);
fn overwrite<T>(current: &mut Option<T>, value: Option<T>) {
if value.is_some() {
*current = value;
}
}
let mut kernel_arguments = None;
self.library_counter += 1;
let code_ref = fn_logger.try_return(|| {
unsafe { zluda_common::CodeLibraryRef::try_load(raw_image) }
.map_err(ErrorEntry::NonUtf8ModuleText)
.map_err(ErrorEntry::Utf8Error)
});
let code_ref = unwrap_some_or!(code_ref, return);
unsafe {
@ -92,17 +120,20 @@ impl StateTracker {
Ok(zluda_common::CodeModuleRef::Elf(elf)) => match get_elf_size(elf) {
Some(len) => {
let elf_image = std::slice::from_raw_parts(elf.cast::<u8>(), len);
self.record_new_submodule(index, elf_image, fn_logger, "elf");
overwrite(
&mut kernel_arguments,
self.record_new_submodule(index, elf_image, fn_logger, "elf"),
);
}
None => fn_logger.log(log::ErrorEntry::UnsupportedModule {
module: cu_module,
handle,
raw_image: elf,
kind: "ELF",
}),
},
Ok(zluda_common::CodeModuleRef::Archive(archive)) => {
fn_logger.log(log::ErrorEntry::UnsupportedModule {
module: cu_module,
handle,
raw_image: archive,
kind: "archive",
})
@ -111,23 +142,39 @@ impl StateTracker {
if let Some(buffer) = fn_logger
.try_(|_| file.get_or_decompress_content().map_err(ErrorEntry::from))
{
self.record_new_submodule(index, &*buffer, fn_logger, file.kind());
overwrite(
&mut kernel_arguments,
self.record_new_submodule(index, &*buffer, fn_logger, file.kind()),
);
}
}
Ok(zluda_common::CodeModuleRef::Text(ptx)) => {
self.record_new_submodule(index, ptx.as_bytes(), fn_logger, "ptx");
overwrite(
&mut kernel_arguments,
self.record_new_submodule(index, ptx.as_bytes(), fn_logger, "ptx"),
);
}
});
};
if let Some((source, kernel_arguments)) = kernel_arguments {
self.parsed_libraries.insert(
SendablePtr(handle),
ParsedModule {
source,
kernels: kernel_arguments,
},
);
}
}
#[must_use]
pub(crate) fn record_new_submodule(
&mut self,
index: Option<(usize, Option<usize>)>,
submodule: &[u8],
fn_logger: &mut FnCallLog,
type_: &'static str,
) {
) -> Option<(String, FxHashMap<String, Vec<Layout>>)> {
fn_logger.try_(|fn_logger| {
self.writer
.save_module(fn_logger, self.library_counter, index, submodule, type_)
@ -135,28 +182,36 @@ impl StateTracker {
});
if type_ == "ptx" {
match CString::new(submodule) {
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
Err(e) => {
fn_logger.log(log::ErrorEntry::NulInsideModuleText(e));
None
}
Ok(submodule_cstring) => match submodule_cstring.to_str() {
Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)),
Ok(submodule_text) => self.try_parse_and_record_kernels(
Err(e) => {
fn_logger.log(log::ErrorEntry::Utf8Error(e));
None
}
Ok(submodule_text) => Some(self.try_parse_and_record_kernels(
fn_logger,
self.library_counter,
index,
submodule_text,
),
)),
},
}
} else {
None
}
}
fn try_parse_and_record_kernels(
fn try_parse_and_record_kernels<'input>(
&mut self,
fn_logger: &mut FnCallLog,
module_index: usize,
submodule_index: Option<(usize, Option<usize>)>,
module_text: &str,
) {
let errors = ptx_parser::parse_for_errors(module_text);
module_text: &'input str,
) -> (String, FxHashMap<String, Vec<Layout>>) {
let (errors, params) = ptx_parser::parse_for_errors_and_params(module_text);
if !errors.is_empty() {
fn_logger.log(log::ErrorEntry::ModuleParsingError(
DumpWriter::get_file_name(module_index, submodule_index, "log"),
@ -167,6 +222,46 @@ impl StateTracker {
&*errors,
));
}
(module_text.to_string(), params)
}
pub(crate) fn record_module_in_library(&mut self, module: CUmodule, library: CUlibrary) {
self.submodules.insert(module, library);
}
pub(crate) fn record_function_from_module(
&mut self,
fn_logger: &mut FnCallLog,
func: CUfunction,
hmod: CUmodule,
name: *const i8,
) {
let owner = match self.submodules.get(&hmod) {
Some(m) => m.0.cast::<c_void>(),
None => hmod.0.cast::<c_void>(),
};
self.record_function_from_impl(fn_logger, func, owner, name);
}
fn record_function_from_impl(
&mut self,
fn_logger: &mut FnCallLog,
func: CUfunction,
owner: *mut c_void,
name: *const i8,
) {
let name = match unsafe { CStr::from_ptr(name) }.to_str() {
Ok(f) => f,
Err(err) => {
fn_logger.log(log::ErrorEntry::Utf8Error(err));
return;
}
};
let saved_kernel = SavedKernel {
name: name.to_string(),
owner: SendablePtr(owner),
};
self.kernels.insert(func, saved_kernel);
}
}

View file

@ -11,6 +11,11 @@ cuda_types = { path = "../cuda_types" }
dark_api = { path = "../dark_api" }
format = { path = "../format" }
cglue = "0.3.5"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.142"
tar = "0.4"
zstd = "0.13"
rustc-hash = "2.0.0"
[target.'cfg(not(windows))'.dependencies]
libc = "0.2"

View file

@ -8,6 +8,8 @@ use cuda_types::{
use dark_api::ByteVecFfi;
use std::{borrow::Cow, ffi::c_void, num::NonZero, ptr, sync::LazyLock};
pub mod replay;
pub fn get_export_table() -> Option<::dark_api::zluda_trace::ZludaTraceInternal> {
static CU_GET_EXPORT_TABLE: LazyLock<
Result<

View file

@ -0,0 +1,137 @@
use rustc_hash::FxHashMap;
use std::io::{Read, Write};
use tar::Header;
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Manifest {
pub kernel_name: String,
pub outputs: bool,
pub config: LaunchConfig,
pub parameters: Vec<Parameter>,
}
#[derive(serde::Serialize, serde::Deserialize)]
pub struct LaunchConfig {
pub grid_dim: (u32, u32, u32),
pub block_dim: (u32, u32, u32),
pub shared_mem_bytes: u32,
}
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Parameter {
pub pointer_offsets: Vec<ParameterPointer>,
}
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ParameterPointer {
pub offset_in_param: usize,
pub offset_in_buffer: usize,
}
impl Manifest {
const PATH: &'static str = "manifest.json";
fn serialize(&self) -> std::io::Result<(Header, Vec<u8>)> {
let vec = serde_json::to_vec(self)?;
let header = tar_header(vec.len());
Ok((header, vec))
}
}
pub struct KernelParameter {
pub data: Vec<u8>,
pub device_ptrs: Vec<(usize, usize, Vec<u8>, Vec<u8>)>,
}
pub fn save(
writer: impl Write,
kernel_name: String,
has_outputs: bool,
config: LaunchConfig,
source: String,
kernel_params: Vec<KernelParameter>,
) -> std::io::Result<()> {
let archive = zstd::Encoder::new(writer, 0)?;
let mut builder = tar::Builder::new(archive);
let (mut header, manifest) = Manifest {
kernel_name,
outputs: has_outputs,
config,
parameters: kernel_params
.iter()
.map(|param| Parameter {
pointer_offsets: param
.device_ptrs
.iter()
.map(
|(offset_in_param, offset_in_buffer, _, _)| ParameterPointer {
offset_in_param: *offset_in_param,
offset_in_buffer: *offset_in_buffer,
},
)
.collect(),
})
.collect(),
}
.serialize()?;
builder.append_data(&mut header, Manifest::PATH, &*manifest)?;
let mut header = tar_header(source.len());
builder.append_data(&mut header, "source.ptx", source.as_bytes())?;
for (i, param) in kernel_params.into_iter().enumerate() {
let path = format!("param_{i}.bin");
let mut header = tar_header(param.data.len());
builder.append_data(&mut header, &*path, &*param.data)?;
for (offset_in_param, _, data_before, data_after) in param.device_ptrs {
let path = format!("param_{i}_ptr_{offset_in_param}_pre.bin");
let mut header = tar_header(data_before.len());
builder.append_data(&mut header, &*path, &*data_before)?;
if !has_outputs {
continue;
}
let path = format!("param_{i}_ptr_{offset_in_param}_post.bin");
let mut header = tar_header(data_after.len());
builder.append_data(&mut header, &*path, &*data_after)?;
}
}
builder.finish()?;
builder.into_inner()?.finish()?;
Ok(())
}
fn tar_header(size: usize) -> Header {
let mut header = Header::new_gnu();
header.set_mode(0o644);
header.set_size(size as u64);
header
}
pub fn load(reader: impl Read) -> (Manifest, String, FxHashMap<String, Vec<u8>>) {
let archive = zstd::Decoder::new(reader).unwrap();
let mut archive = tar::Archive::new(archive);
let mut manifest = None;
let mut source = None;
let mut buffers = FxHashMap::default();
for entry in archive.entries().unwrap() {
let mut entry = entry.unwrap();
let path = entry.path().unwrap().to_string_lossy().to_string();
match &*path {
Manifest::PATH => {
manifest = Some(serde_json::from_reader::<_, Manifest>(&mut entry).unwrap());
}
"source.ptx" => {
let mut string = String::new();
entry.read_to_string(&mut string).unwrap();
dbg!(string.len());
source = Some(string);
}
_ => {
let mut buffer = Vec::new();
entry.read_to_end(&mut buffer).unwrap();
buffers.insert(path, buffer);
}
}
}
let manifest = manifest.unwrap();
let source = source.unwrap();
(manifest, source, buffers)
}