mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-27 11:49:04 +00:00
Merge 7b007074bd
into 93820e3159
This commit is contained in:
commit
1dc09827e2
18 changed files with 1022 additions and 190 deletions
32
Cargo.lock
generated
32
Cargo.lock
generated
|
@ -420,7 +420,7 @@ version = "0.0.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"rustc-hash 1.1.0",
|
"rustc-hash 2.0.0",
|
||||||
"syn 2.0.89",
|
"syn 2.0.89",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -3706,7 +3706,7 @@ dependencies = [
|
||||||
"paste",
|
"paste",
|
||||||
"ptx",
|
"ptx",
|
||||||
"ptx_parser",
|
"ptx_parser",
|
||||||
"rustc-hash 1.1.0",
|
"rustc-hash 2.0.0",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
|
@ -3726,7 +3726,7 @@ dependencies = [
|
||||||
"prettyplease",
|
"prettyplease",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"rustc-hash 1.1.0",
|
"rustc-hash 2.0.0",
|
||||||
"syn 2.0.89",
|
"syn 2.0.89",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -3826,6 +3826,16 @@ dependencies = [
|
||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zluda_replay"
|
||||||
|
version = "0.0.0"
|
||||||
|
dependencies = [
|
||||||
|
"cuda_macros",
|
||||||
|
"cuda_types",
|
||||||
|
"libloading",
|
||||||
|
"zluda_trace_common",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zluda_sparse"
|
name = "zluda_sparse"
|
||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
|
@ -3854,7 +3864,7 @@ dependencies = [
|
||||||
"ptx",
|
"ptx",
|
||||||
"ptx_parser",
|
"ptx_parser",
|
||||||
"regex",
|
"regex",
|
||||||
"rustc-hash 1.1.0",
|
"rustc-hash 2.0.0",
|
||||||
"unwrap_or",
|
"unwrap_or",
|
||||||
"wchar",
|
"wchar",
|
||||||
"winapi",
|
"winapi",
|
||||||
|
@ -3903,6 +3913,11 @@ dependencies = [
|
||||||
"format",
|
"format",
|
||||||
"libc",
|
"libc",
|
||||||
"libloading",
|
"libloading",
|
||||||
|
"rustc-hash 2.0.0",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"tar",
|
||||||
|
"zstd",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3979,6 +3994,15 @@ dependencies = [
|
||||||
"simd-adler32",
|
"simd-adler32",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zstd"
|
||||||
|
version = "0.13.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a"
|
||||||
|
dependencies = [
|
||||||
|
"zstd-safe",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zstd-safe"
|
name = "zstd-safe"
|
||||||
version = "7.2.4"
|
version = "7.2.4"
|
||||||
|
|
|
@ -37,6 +37,7 @@ members = [
|
||||||
"zluda_inject",
|
"zluda_inject",
|
||||||
"zluda_ld",
|
"zluda_ld",
|
||||||
"zluda_ml",
|
"zluda_ml",
|
||||||
|
"zluda_replay",
|
||||||
"zluda_redirect",
|
"zluda_redirect",
|
||||||
"zluda_sparse",
|
"zluda_sparse",
|
||||||
"compiler",
|
"compiler",
|
||||||
|
|
|
@ -219,6 +219,8 @@ pub fn compile_bitcode(
|
||||||
compile_to_exec.set_isa_name(gcn_arch)?;
|
compile_to_exec.set_isa_name(gcn_arch)?;
|
||||||
compile_to_exec.set_language(Language::LlvmIr)?;
|
compile_to_exec.set_language(Language::LlvmIr)?;
|
||||||
let common_options = [
|
let common_options = [
|
||||||
|
c"-Xlinker",
|
||||||
|
c"--no-undefined",
|
||||||
c"-mllvm",
|
c"-mllvm",
|
||||||
c"-ignore-tti-inline-compatible",
|
c"-ignore-tti-inline-compatible",
|
||||||
// c"-mllvm",
|
// c"-mllvm",
|
||||||
|
|
|
@ -8,7 +8,7 @@ edition = "2021"
|
||||||
quote = "1.0"
|
quote = "1.0"
|
||||||
syn = { version = "2.0", features = ["full", "visit-mut", "extra-traits"] }
|
syn = { version = "2.0", features = ["full", "visit-mut", "extra-traits"] }
|
||||||
proc-macro2 = "1.0"
|
proc-macro2 = "1.0"
|
||||||
rustc-hash = "1.1.0"
|
rustc-hash = "2.0.0"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
proc-macro = true
|
proc-macro = true
|
||||||
|
|
|
@ -1653,25 +1653,23 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
.ok_or_else(|| error_mismatched_type())?,
|
.ok_or_else(|| error_mismatched_type())?,
|
||||||
);
|
);
|
||||||
let src2 = self.resolver.value(src2)?;
|
let src2 = self.resolver.value(src2)?;
|
||||||
self.resolver.with_result(arguments.dst, |dst| {
|
let vec = unsafe {
|
||||||
let vec = unsafe {
|
LLVMBuildInsertElement(
|
||||||
LLVMBuildInsertElement(
|
self.builder,
|
||||||
self.builder,
|
LLVMGetPoison(dst_type),
|
||||||
LLVMGetPoison(dst_type),
|
llvm_fn(self.builder, src, packed_type, LLVM_UNNAMED.as_ptr()),
|
||||||
llvm_fn(self.builder, src, packed_type, LLVM_UNNAMED.as_ptr()),
|
LLVMConstInt(LLVMInt32TypeInContext(self.context), 1, false as i32),
|
||||||
LLVMConstInt(LLVMInt32TypeInContext(self.context), 1, false as i32),
|
LLVM_UNNAMED.as_ptr(),
|
||||||
LLVM_UNNAMED.as_ptr(),
|
)
|
||||||
)
|
};
|
||||||
};
|
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||||
unsafe {
|
LLVMBuildInsertElement(
|
||||||
LLVMBuildInsertElement(
|
self.builder,
|
||||||
self.builder,
|
vec,
|
||||||
vec,
|
llvm_fn(self.builder, src2, packed_type, LLVM_UNNAMED.as_ptr()),
|
||||||
llvm_fn(self.builder, src2, packed_type, LLVM_UNNAMED.as_ptr()),
|
LLVMConstInt(LLVMInt32TypeInContext(self.context), 0, false as i32),
|
||||||
LLVMConstInt(LLVMInt32TypeInContext(self.context), 0, false as i32),
|
dst,
|
||||||
dst,
|
)
|
||||||
)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
self.resolver.with_result(arguments.dst, |dst| unsafe {
|
||||||
|
@ -2197,7 +2195,7 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
Some(&ast::ScalarType::F32.into()),
|
Some(&ast::ScalarType::F32.into()),
|
||||||
vec![(
|
vec![(
|
||||||
self.resolver.value(arguments.src)?,
|
self.resolver.value(arguments.src)?,
|
||||||
get_scalar_type(self.context, ast::ScalarType::F32.into()),
|
get_scalar_type(self.context, ast::ScalarType::F32),
|
||||||
)],
|
)],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -2658,14 +2656,14 @@ impl<'a> MethodEmitContext<'a> {
|
||||||
|
|
||||||
let load = unsafe { LLVMBuildLoad2(self.builder, from_type, from, LLVM_UNNAMED.as_ptr()) };
|
let load = unsafe { LLVMBuildLoad2(self.builder, from_type, from, LLVM_UNNAMED.as_ptr()) };
|
||||||
unsafe {
|
unsafe {
|
||||||
LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8);
|
LLVMSetAlignment(load, cp_size.as_u64() as u32);
|
||||||
}
|
}
|
||||||
|
|
||||||
let extended = unsafe { LLVMBuildZExt(self.builder, load, to_type, LLVM_UNNAMED.as_ptr()) };
|
let extended = unsafe { LLVMBuildZExt(self.builder, load, to_type, LLVM_UNNAMED.as_ptr()) };
|
||||||
|
|
||||||
unsafe { LLVMBuildStore(self.builder, extended, to) };
|
let store = unsafe { LLVMBuildStore(self.builder, extended, to) };
|
||||||
unsafe {
|
unsafe {
|
||||||
LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8);
|
LLVMSetAlignment(store, cp_size.as_u64() as u32);
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -2945,7 +2943,7 @@ fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
|
||||||
Ok(match scope {
|
Ok(match scope {
|
||||||
ast::MemScope::Cta => c"workgroup",
|
ast::MemScope::Cta => c"workgroup",
|
||||||
ast::MemScope::Gpu => c"agent",
|
ast::MemScope::Gpu => c"agent",
|
||||||
ast::MemScope::Sys => c"",
|
ast::MemScope::Sys => c"system",
|
||||||
ast::MemScope::Cluster => todo!(),
|
ast::MemScope::Cluster => todo!(),
|
||||||
}
|
}
|
||||||
.as_ptr())
|
.as_ptr())
|
||||||
|
|
|
@ -2,6 +2,7 @@ use derive_more::Display;
|
||||||
use logos::Logos;
|
use logos::Logos;
|
||||||
use ptx_parser_macros::derive_parser;
|
use ptx_parser_macros::derive_parser;
|
||||||
use rustc_hash::FxHashMap;
|
use rustc_hash::FxHashMap;
|
||||||
|
use std::alloc::Layout;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::num::{NonZeroU8, ParseFloatError, ParseIntError};
|
use std::num::{NonZeroU8, ParseFloatError, ParseIntError};
|
||||||
use std::{iter, usize};
|
use std::{iter, usize};
|
||||||
|
@ -226,8 +227,9 @@ fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult<ast::
|
||||||
take_error((opt(Token::Minus), num).map(|(neg, x)| {
|
take_error((opt(Token::Minus), num).map(|(neg, x)| {
|
||||||
let (num, radix, is_unsigned) = x;
|
let (num, radix, is_unsigned) = x;
|
||||||
if neg.is_some() {
|
if neg.is_some() {
|
||||||
match i64::from_str_radix(num, radix) {
|
let full_number = format!("-{num}");
|
||||||
Ok(x) => Ok(ast::ImmediateValue::S64(-x)),
|
match i64::from_str_radix(&full_number, radix) {
|
||||||
|
Ok(x) => Ok(ast::ImmediateValue::S64(x)),
|
||||||
Err(err) => Err((ast::ImmediateValue::S64(0), PtxError::from(err))),
|
Err(err) => Err((ast::ImmediateValue::S64(0), PtxError::from(err))),
|
||||||
}
|
}
|
||||||
} else if is_unsigned {
|
} else if is_unsigned {
|
||||||
|
@ -345,7 +347,9 @@ fn reg_or_immediate<'a, 'input>(
|
||||||
.parse_next(stream)
|
.parse_next(stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parse_for_errors<'input>(text: &'input str) -> Vec<PtxError<'input>> {
|
pub fn parse_for_errors_and_params<'input>(
|
||||||
|
text: &'input str,
|
||||||
|
) -> (Vec<PtxError<'input>>, FxHashMap<String, Vec<Layout>>) {
|
||||||
let (tokens, mut errors) = lex_with_span_unchecked(text);
|
let (tokens, mut errors) = lex_with_span_unchecked(text);
|
||||||
let parse_result = {
|
let parse_result = {
|
||||||
let state = PtxParserState::new(text, &mut errors);
|
let state = PtxParserState::new(text, &mut errors);
|
||||||
|
@ -357,13 +361,30 @@ pub fn parse_for_errors<'input>(text: &'input str) -> Vec<PtxError<'input>> {
|
||||||
.parse(parser)
|
.parse(parser)
|
||||||
.map_err(|err| PtxError::Parser(err.into_inner()))
|
.map_err(|err| PtxError::Parser(err.into_inner()))
|
||||||
};
|
};
|
||||||
match parse_result {
|
let params = match parse_result {
|
||||||
Ok(_) => {}
|
Ok(module) => module
|
||||||
|
.directives
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|directive| {
|
||||||
|
if let ast::Directive::Method(_, func) = directive {
|
||||||
|
let layouts = func
|
||||||
|
.func_directive
|
||||||
|
.input_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|arg| arg.info.v_type.layout())
|
||||||
|
.collect();
|
||||||
|
Some((func.func_directive.name().to_string(), layouts))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
errors.push(err);
|
errors.push(err);
|
||||||
|
FxHashMap::default()
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
errors
|
(errors, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lex_with_span_unchecked<'input>(
|
fn lex_with_span_unchecked<'input>(
|
||||||
|
|
|
@ -22,7 +22,7 @@ num_enum = "0.4"
|
||||||
lz4-sys = "1.9"
|
lz4-sys = "1.9"
|
||||||
tempfile = "3"
|
tempfile = "3"
|
||||||
paste = "1.0"
|
paste = "1.0"
|
||||||
rustc-hash = "1.1"
|
rustc-hash = "2.0.0"
|
||||||
zluda_common = { path = "../zluda_common" }
|
zluda_common = { path = "../zluda_common" }
|
||||||
blake3 = "1.8.2"
|
blake3 = "1.8.2"
|
||||||
serde = "1.0.219"
|
serde = "1.0.219"
|
||||||
|
|
|
@ -9,6 +9,6 @@ syn = { version = "2.0", features = ["full", "visit-mut"] }
|
||||||
proc-macro2 = "1.0.89"
|
proc-macro2 = "1.0.89"
|
||||||
quote = "1.0"
|
quote = "1.0"
|
||||||
prettyplease = "0.2.25"
|
prettyplease = "0.2.25"
|
||||||
rustc-hash = "1.1.0"
|
rustc-hash = "2.0.0"
|
||||||
libloading = "0.8"
|
libloading = "0.8"
|
||||||
cuda_types = { path = "../cuda_types" }
|
cuda_types = { path = "../cuda_types" }
|
||||||
|
|
17
zluda_replay/Cargo.toml
Normal file
17
zluda_replay/Cargo.toml
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
[package]
|
||||||
|
name = "zluda_replay"
|
||||||
|
version = "0.0.0"
|
||||||
|
authors = ["Andrzej Janik <vosen@vosen.pl>"]
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "zluda_replay"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
zluda_trace_common = { path = "../zluda_trace_common" }
|
||||||
|
cuda_macros = { path = "../cuda_macros" }
|
||||||
|
cuda_types = { path = "../cuda_types" }
|
||||||
|
libloading = "0.8"
|
||||||
|
|
||||||
|
[package.metadata.zluda]
|
||||||
|
debug_only = true
|
103
zluda_replay/src/main.rs
Normal file
103
zluda_replay/src/main.rs
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
use std::mem;
|
||||||
|
|
||||||
|
use cuda_types::cuda::{CUdeviceptr_v2, CUstream};
|
||||||
|
|
||||||
|
struct CudaDynamicFns {
|
||||||
|
handle: libloading::Library,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CudaDynamicFns {
|
||||||
|
unsafe fn new(path: &str) -> Result<Self, libloading::Error> {
|
||||||
|
let handle = libloading::Library::new(path)?;
|
||||||
|
Ok(Self { handle })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! emit_cuda_fn_table {
|
||||||
|
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
|
||||||
|
impl CudaDynamicFns {
|
||||||
|
$(
|
||||||
|
#[allow(dead_code)]
|
||||||
|
unsafe fn $fn_name(&self, $($arg_id : $arg_type),*) -> $ret_type {
|
||||||
|
let func = self.handle.get::<unsafe extern $abi fn ($($arg_type),*) -> $ret_type>(concat!(stringify!($fn_name), "\0").as_bytes());
|
||||||
|
(func.unwrap())($($arg_id),*)
|
||||||
|
}
|
||||||
|
)*
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
cuda_macros::cuda_function_declarations!(emit_cuda_fn_table);
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let args: Vec<String> = std::env::args().collect();
|
||||||
|
let libcuda = unsafe { CudaDynamicFns::new(&args[1]).unwrap() };
|
||||||
|
unsafe { libcuda.cuInit(0) }.unwrap();
|
||||||
|
unsafe { libcuda.cuCtxCreate_v2(&mut mem::zeroed(), 0, 0) }.unwrap();
|
||||||
|
let reader = std::fs::File::open(&args[2]).unwrap();
|
||||||
|
let (mut manifest, mut source, mut buffers) = zluda_trace_common::replay::load(reader);
|
||||||
|
let mut args = manifest
|
||||||
|
.parameters
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, param)| {
|
||||||
|
let mut buffer = buffers.remove(&format!("param_{i}.bin")).unwrap();
|
||||||
|
for param_ptr in param.pointer_offsets.iter() {
|
||||||
|
let buffer_param_slice = &mut buffer[param_ptr.offset_in_param
|
||||||
|
..param_ptr.offset_in_param + std::mem::size_of::<usize>()];
|
||||||
|
let mut dev_ptr = unsafe { mem::zeroed() };
|
||||||
|
let host_buffer = buffers
|
||||||
|
.remove(&format!(
|
||||||
|
"param_{i}_ptr_{}_pre.bin",
|
||||||
|
param_ptr.offset_in_param
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
unsafe { libcuda.cuMemAlloc_v2(&mut dev_ptr, host_buffer.len()) }.unwrap();
|
||||||
|
unsafe {
|
||||||
|
libcuda.cuMemcpyHtoD_v2(dev_ptr, host_buffer.as_ptr().cast(), host_buffer.len())
|
||||||
|
}
|
||||||
|
.unwrap();
|
||||||
|
dev_ptr = CUdeviceptr_v2(unsafe {
|
||||||
|
dev_ptr
|
||||||
|
.0
|
||||||
|
.cast::<u8>()
|
||||||
|
.add(param_ptr.offset_in_buffer)
|
||||||
|
.cast()
|
||||||
|
});
|
||||||
|
buffer_param_slice.copy_from_slice(&(dev_ptr.0 as usize).to_ne_bytes());
|
||||||
|
}
|
||||||
|
buffer
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let mut module = unsafe { mem::zeroed() };
|
||||||
|
std::fs::write("/tmp/source.ptx", &source).unwrap();
|
||||||
|
source.push('\0');
|
||||||
|
unsafe { libcuda.cuModuleLoadData(&mut module, source.as_ptr().cast()) }.unwrap();
|
||||||
|
let mut function = unsafe { mem::zeroed() };
|
||||||
|
manifest.kernel_name.push('\0');
|
||||||
|
unsafe {
|
||||||
|
libcuda.cuModuleGetFunction(&mut function, module, manifest.kernel_name.as_ptr().cast())
|
||||||
|
}
|
||||||
|
.unwrap();
|
||||||
|
let mut cuda_args = args
|
||||||
|
.iter_mut()
|
||||||
|
.map(|arg| arg.as_mut_ptr().cast::<std::ffi::c_void>())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
unsafe {
|
||||||
|
libcuda.cuLaunchKernel(
|
||||||
|
function,
|
||||||
|
manifest.config.grid_dim.0,
|
||||||
|
manifest.config.grid_dim.1,
|
||||||
|
manifest.config.grid_dim.2,
|
||||||
|
manifest.config.block_dim.0,
|
||||||
|
manifest.config.block_dim.1,
|
||||||
|
manifest.config.block_dim.2,
|
||||||
|
manifest.config.shared_mem_bytes,
|
||||||
|
CUstream(std::ptr::null_mut()),
|
||||||
|
cuda_args.as_mut_ptr().cast(),
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
.unwrap();
|
||||||
|
unsafe { libcuda.cuCtxSynchronize() }.unwrap();
|
||||||
|
}
|
|
@ -24,7 +24,7 @@ paste = "1.0"
|
||||||
cuda_macros = { path = "../cuda_macros" }
|
cuda_macros = { path = "../cuda_macros" }
|
||||||
cuda_types = { path = "../cuda_types" }
|
cuda_types = { path = "../cuda_types" }
|
||||||
parking_lot = "0.12.3"
|
parking_lot = "0.12.3"
|
||||||
rustc-hash = "1.1.0"
|
rustc-hash = "2.0.0"
|
||||||
cglue = "0.3.5"
|
cglue = "0.3.5"
|
||||||
zstd-safe = { version = "7.2.4", features = ["std"] }
|
zstd-safe = { version = "7.2.4", features = ["std"] }
|
||||||
unwrap_or = "1.0.1"
|
unwrap_or = "1.0.1"
|
||||||
|
|
|
@ -12,6 +12,7 @@ use std::ptr::NonNull;
|
||||||
use std::sync::LazyLock;
|
use std::sync::LazyLock;
|
||||||
use std::{env, error::Error, fs, path::PathBuf, sync::Mutex};
|
use std::{env, error::Error, fs, path::PathBuf, sync::Mutex};
|
||||||
use std::{io, mem, ptr, usize};
|
use std::{io, mem, ptr, usize};
|
||||||
|
use unwrap_or::unwrap_some_or;
|
||||||
|
|
||||||
extern crate cuda_types;
|
extern crate cuda_types;
|
||||||
|
|
||||||
|
@ -110,7 +111,7 @@ macro_rules! override_fn_core {
|
||||||
).ok();
|
).ok();
|
||||||
formatted_args
|
formatted_args
|
||||||
};
|
};
|
||||||
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| Some(());
|
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| Some(((), ()));
|
||||||
let cuda_call = |_| {
|
let cuda_call = |_| {
|
||||||
paste!{ [<$fn_name _impl >] ( $($arg_id),* ) }
|
paste!{ [<$fn_name _impl >] ( $($arg_id),* ) }
|
||||||
};
|
};
|
||||||
|
@ -121,7 +122,7 @@ macro_rules! override_fn_core {
|
||||||
format_curesult,
|
format_curesult,
|
||||||
extract_fn_ptr,
|
extract_fn_ptr,
|
||||||
cuda_call,
|
cuda_call,
|
||||||
move |_, _, _, _| {}
|
move |_, _, _, _, _| {}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
|
@ -157,9 +158,9 @@ impl ::dark_api::zluda_trace::CudaDarkApi for InternalTableImpl {
|
||||||
Some(|| args.call().to_vec()),
|
Some(|| args.call().to_vec()),
|
||||||
internal_error,
|
internal_error,
|
||||||
|status| format_status(status).to_vec(),
|
|status| format_status(status).to_vec(),
|
||||||
|_, _| Some(()),
|
|_, _| Some(((), ())),
|
||||||
|_| fn_.call(),
|
|_| fn_.call(),
|
||||||
move |_, _, _, _| {},
|
move |_, _, _, _, _| {},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -201,7 +202,7 @@ macro_rules! dark_api_fn_redirect_log {
|
||||||
).ok();
|
).ok();
|
||||||
formatted_args
|
formatted_args
|
||||||
};
|
};
|
||||||
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| { Some(()) };
|
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| { Some(((), ())) };
|
||||||
let cuda_call = |_: () | {
|
let cuda_call = |_: () | {
|
||||||
ReprUsize::to_usize(original_fn( $( $arg_id ),* ))
|
ReprUsize::to_usize(original_fn( $( $arg_id ),* ))
|
||||||
};
|
};
|
||||||
|
@ -215,7 +216,7 @@ macro_rules! dark_api_fn_redirect_log {
|
||||||
|status| <$ret_type as ReprUsize>::format_status(status).to_vec(),
|
|status| <$ret_type as ReprUsize>::format_status(status).to_vec(),
|
||||||
extract_fn_ptr,
|
extract_fn_ptr,
|
||||||
cuda_call,
|
cuda_call,
|
||||||
move |_, _, _, _| {}
|
move |_, _, _, _, _| {}
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
)+
|
)+
|
||||||
|
@ -256,7 +257,7 @@ macro_rules! dark_api_fn_redirect_log_post {
|
||||||
).ok();
|
).ok();
|
||||||
formatted_args
|
formatted_args
|
||||||
};
|
};
|
||||||
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| { Some(()) };
|
let extract_fn_ptr = |_: &mut GlobalDelayedState, _: &mut FnCallLog| { Some(((), ())) };
|
||||||
let cuda_call = |_: () | {
|
let cuda_call = |_: () | {
|
||||||
ReprUsize::to_usize(original_fn( $( $arg_id ),* ))
|
ReprUsize::to_usize(original_fn( $( $arg_id ),* ))
|
||||||
};
|
};
|
||||||
|
@ -270,7 +271,7 @@ macro_rules! dark_api_fn_redirect_log_post {
|
||||||
|status| <$ret_type as ReprUsize>::format_status(status).to_vec(),
|
|status| <$ret_type as ReprUsize>::format_status(status).to_vec(),
|
||||||
extract_fn_ptr,
|
extract_fn_ptr,
|
||||||
cuda_call,
|
cuda_call,
|
||||||
move |state, logger, _, cuda_result| paste! { Self:: [<$fn_ _post>] } ( $( $arg_id ),* , &mut state.cuda_state, logger, <$ret_type as ReprUsize>::from_usize(cuda_result))
|
move |state, logger, _, _, cuda_result| paste! { Self:: [<$fn_ _post>] } ( $( $arg_id ),* , &mut state.cuda_state, logger, <$ret_type as ReprUsize>::from_usize(cuda_result))
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
)+
|
)+
|
||||||
|
@ -287,7 +288,11 @@ impl DarkApiTrace {
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
_result: CUresult,
|
_result: CUresult,
|
||||||
) {
|
) {
|
||||||
state.record_new_library(unsafe { *module }, fatbinc_wrapper.cast(), fn_logger)
|
state.record_new_library(
|
||||||
|
unsafe { *module }.0.cast(),
|
||||||
|
fatbinc_wrapper.cast(),
|
||||||
|
fn_logger,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_module_from_cubin_ext1_post(
|
fn get_module_from_cubin_ext1_post(
|
||||||
|
@ -321,7 +326,11 @@ impl DarkApiTrace {
|
||||||
observed: UInt::U32(arg5),
|
observed: UInt::U32(arg5),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
state.record_new_library(unsafe { *module }, fatbinc_wrapper.cast(), fn_logger)
|
state.record_new_library(
|
||||||
|
unsafe { *module }.0.cast(),
|
||||||
|
fatbinc_wrapper.cast(),
|
||||||
|
fn_logger,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_module_from_cubin_ext2_post(
|
fn get_module_from_cubin_ext2_post(
|
||||||
|
@ -355,7 +364,7 @@ impl DarkApiTrace {
|
||||||
observed: UInt::U32(arg5),
|
observed: UInt::U32(arg5),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
state.record_new_library(unsafe { *module }, fatbin_header.cast(), fn_logger)
|
state.record_new_library(unsafe { *module }.0.cast(), fatbin_header.cast(), fn_logger)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -770,7 +779,7 @@ macro_rules! extern_redirect {
|
||||||
};
|
};
|
||||||
let extract_fn_ptr = |state: &mut GlobalDelayedState, _: &mut FnCallLog| {
|
let extract_fn_ptr = |state: &mut GlobalDelayedState, _: &mut FnCallLog| {
|
||||||
paste::paste! {
|
paste::paste! {
|
||||||
state.libcuda. [<get_ $fn_name>]()
|
state.libcuda. [<get_ $fn_name>]().map(|x| ((), x) )
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | {
|
let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | {
|
||||||
|
@ -783,7 +792,7 @@ macro_rules! extern_redirect {
|
||||||
format_curesult,
|
format_curesult,
|
||||||
extract_fn_ptr,
|
extract_fn_ptr,
|
||||||
cuda_call,
|
cuda_call,
|
||||||
move |_, _, _, _| {}
|
move |_, _, _, _, _| {}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
|
@ -806,7 +815,7 @@ macro_rules! extern_redirect_with_post {
|
||||||
};
|
};
|
||||||
let extract_fn_ptr = |state: &mut GlobalDelayedState, _: &mut FnCallLog| {
|
let extract_fn_ptr = |state: &mut GlobalDelayedState, _: &mut FnCallLog| {
|
||||||
paste::paste! {
|
paste::paste! {
|
||||||
state.libcuda. [<get_ $fn_name>]()
|
state.libcuda. [<get_ $fn_name>]().map(|x| ((), x) )
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | {
|
let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | {
|
||||||
|
@ -819,7 +828,43 @@ macro_rules! extern_redirect_with_post {
|
||||||
format_curesult,
|
format_curesult,
|
||||||
extract_fn_ptr,
|
extract_fn_ptr,
|
||||||
cuda_call,
|
cuda_call,
|
||||||
move |state, logger, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , &mut state.cuda_state, logger, cuda_result )
|
move |state, logger, _, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , &mut state.cuda_state, logger, cuda_result )
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)*
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! extern_redirect_with_pre_post {
|
||||||
|
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
|
||||||
|
$(
|
||||||
|
#[no_mangle]
|
||||||
|
#[allow(improper_ctypes_definitions)]
|
||||||
|
pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
||||||
|
let format_args = || {
|
||||||
|
let mut formatted_args = Vec::new();
|
||||||
|
(paste! { format :: [<write_ $fn_name>] }) (
|
||||||
|
&mut formatted_args
|
||||||
|
$(,$arg_id)*
|
||||||
|
).ok();
|
||||||
|
formatted_args
|
||||||
|
};
|
||||||
|
let extract_fn_ptr = |state: &mut GlobalDelayedState, logger: &mut FnCallLog| {
|
||||||
|
paste::paste! {
|
||||||
|
state.libcuda. [<get_ $fn_name>]().map(|x| (paste! { [<$fn_name _Pre>] } ( $( $arg_id ),* , &mut state.libcuda, &mut state.cuda_state, logger ), x ))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let cuda_call = |fn_ptr: extern $abi fn ( $($arg_type),* ) -> $ret_type | {
|
||||||
|
fn_ptr( $( $arg_id ),* )
|
||||||
|
};
|
||||||
|
GlobalState2::under_lock(
|
||||||
|
CudaFunctionName::Normal(stringify!($fn_name)),
|
||||||
|
Some(format_args),
|
||||||
|
CUresult::INTERNAL_ERROR,
|
||||||
|
format_curesult,
|
||||||
|
extract_fn_ptr,
|
||||||
|
cuda_call,
|
||||||
|
move |state, logger, pre_state, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , pre_state, &mut state.libcuda, &mut state.cuda_state, logger, cuda_result )
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
|
@ -843,13 +888,15 @@ cuda_function_declarations!(
|
||||||
cuModuleLoad,
|
cuModuleLoad,
|
||||||
cuModuleLoadData,
|
cuModuleLoadData,
|
||||||
cuModuleLoadDataEx,
|
cuModuleLoadDataEx,
|
||||||
|
cuLibraryGetFunction,
|
||||||
cuModuleGetFunction,
|
cuModuleGetFunction,
|
||||||
cuDeviceGetAttribute,
|
cuDeviceGetAttribute,
|
||||||
cuDeviceComputeCapability,
|
cuDeviceComputeCapability,
|
||||||
cuModuleLoadFatBinary,
|
cuModuleLoadFatBinary,
|
||||||
cuLibraryGetModule,
|
cuLibraryGetModule,
|
||||||
cuLibraryLoadData
|
cuLibraryLoadData,
|
||||||
],
|
],
|
||||||
|
extern_redirect_with_pre_post <= [cuLaunchKernel, cuLaunchKernelEx],
|
||||||
override_fn_core <= [cuGetProcAddress, cuGetProcAddress_v2],
|
override_fn_core <= [cuGetProcAddress, cuGetProcAddress_v2],
|
||||||
override_fn_full <= [cuGetExportTable],
|
override_fn_full <= [cuGetExportTable],
|
||||||
);
|
);
|
||||||
|
@ -859,6 +906,7 @@ mod log;
|
||||||
#[cfg_attr(windows, path = "os_win.rs")]
|
#[cfg_attr(windows, path = "os_win.rs")]
|
||||||
#[cfg_attr(not(windows), path = "os_unix.rs")]
|
#[cfg_attr(not(windows), path = "os_unix.rs")]
|
||||||
mod os;
|
mod os;
|
||||||
|
mod replay;
|
||||||
mod trace;
|
mod trace;
|
||||||
|
|
||||||
struct GlobalState2 {
|
struct GlobalState2 {
|
||||||
|
@ -907,27 +955,33 @@ impl GlobalState2 {
|
||||||
// * Post-call:
|
// * Post-call:
|
||||||
// We log the output of the CUDA function and any errors that may have occurred. This phase
|
// We log the output of the CUDA function and any errors that may have occurred. This phase
|
||||||
// is also covered by a drop guard which will flush the log buffer in case of panic
|
// is also covered by a drop guard which will flush the log buffer in case of panic
|
||||||
fn under_lock<'a, FnPtr: Copy, InnerResult: Copy>(
|
fn under_lock<'a, PreState, FnPtr: Copy, InnerResult: Copy>(
|
||||||
name: CudaFunctionName,
|
name: CudaFunctionName,
|
||||||
args: Option<impl FnOnce() -> Vec<u8>>,
|
args: Option<impl FnOnce() -> Vec<u8>>,
|
||||||
internal_error: InnerResult,
|
internal_error: InnerResult,
|
||||||
format_status: impl FnOnce(InnerResult) -> Vec<u8>,
|
format_status: impl FnOnce(InnerResult) -> Vec<u8>,
|
||||||
pre_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog) -> Option<FnPtr>,
|
pre_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog) -> Option<(PreState, FnPtr)>,
|
||||||
inner_call: impl FnOnce(FnPtr) -> InnerResult,
|
inner_call: impl FnOnce(FnPtr) -> InnerResult,
|
||||||
post_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog, FnPtr, InnerResult),
|
post_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog, PreState, FnPtr, InnerResult),
|
||||||
) -> InnerResult {
|
) -> InnerResult {
|
||||||
fn under_lock_impl<'a, FnPtr: Copy, InnerResult: Copy>(
|
fn under_lock_impl<'a, PreState, FnPtr: Copy, InnerResult: Copy>(
|
||||||
name: CudaFunctionName,
|
name: CudaFunctionName,
|
||||||
args: Option<impl FnOnce() -> Vec<u8>>,
|
args: Option<impl FnOnce() -> Vec<u8>>,
|
||||||
internal_error: InnerResult,
|
internal_error: InnerResult,
|
||||||
format_status: impl FnOnce(InnerResult) -> Vec<u8>,
|
format_status: impl FnOnce(InnerResult) -> Vec<u8>,
|
||||||
pre_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog) -> Option<FnPtr>,
|
pre_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog) -> Option<(PreState, FnPtr)>,
|
||||||
inner_call: impl FnOnce(FnPtr) -> InnerResult,
|
inner_call: impl FnOnce(FnPtr) -> InnerResult,
|
||||||
post_call: impl FnOnce(&mut GlobalDelayedState, &mut FnCallLog, FnPtr, InnerResult),
|
post_call: impl FnOnce(
|
||||||
|
&mut GlobalDelayedState,
|
||||||
|
&mut FnCallLog,
|
||||||
|
PreState,
|
||||||
|
FnPtr,
|
||||||
|
InnerResult,
|
||||||
|
),
|
||||||
) -> InnerResult {
|
) -> InnerResult {
|
||||||
let global_state = GLOBAL_STATE2.lock();
|
let global_state = GLOBAL_STATE2.lock();
|
||||||
let global_state_ref_cell = &*global_state;
|
let global_state_ref_cell = &*global_state;
|
||||||
let pre_value = {
|
let (pre_state, pre_ptr) = {
|
||||||
let mut global_state_ref_mut = global_state_ref_cell.borrow_mut();
|
let mut global_state_ref_mut = global_state_ref_cell.borrow_mut();
|
||||||
let global_state = &mut *global_state_ref_mut;
|
let global_state = &mut *global_state_ref_mut;
|
||||||
let panic_guard = OuterCallGuard {
|
let panic_guard = OuterCallGuard {
|
||||||
|
@ -963,7 +1017,7 @@ impl GlobalState2 {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let panic_guard = InnerCallGuard(global_state_ref_cell);
|
let panic_guard = InnerCallGuard(global_state_ref_cell);
|
||||||
let inner_result = inner_call(pre_value);
|
let inner_result = inner_call(pre_ptr);
|
||||||
let global_state = &mut *global_state_ref_cell.borrow_mut();
|
let global_state = &mut *global_state_ref_cell.borrow_mut();
|
||||||
mem::forget(panic_guard);
|
mem::forget(panic_guard);
|
||||||
let _drop_guard = OuterCallGuard {
|
let _drop_guard = OuterCallGuard {
|
||||||
|
@ -978,7 +1032,8 @@ impl GlobalState2 {
|
||||||
post_call(
|
post_call(
|
||||||
global_state.delayed_state.as_mut().unwrap(),
|
global_state.delayed_state.as_mut().unwrap(),
|
||||||
&mut logger,
|
&mut logger,
|
||||||
pre_value,
|
pre_state,
|
||||||
|
pre_ptr,
|
||||||
inner_result,
|
inner_result,
|
||||||
);
|
);
|
||||||
inner_result
|
inner_result
|
||||||
|
@ -1098,6 +1153,22 @@ impl FnCallLog {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn try_cuda(&mut self, fn_: impl FnOnce() -> Option<CUresult>) -> Option<()> {
|
||||||
|
match fn_() {
|
||||||
|
Some(Ok(())) => Some(()),
|
||||||
|
None => {
|
||||||
|
self.subcalls
|
||||||
|
.push(LogEntry::Error(ErrorEntry::CudaError(None)));
|
||||||
|
None
|
||||||
|
}
|
||||||
|
Some(Err(err)) => {
|
||||||
|
self.subcalls
|
||||||
|
.push(LogEntry::Error(ErrorEntry::CudaError(Some(err))));
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn try_<T>(&mut self, f: impl FnOnce(&mut Self) -> Result<T, ErrorEntry>) -> Option<T> {
|
fn try_<T>(&mut self, f: impl FnOnce(&mut Self) -> Result<T, ErrorEntry>) -> Option<T> {
|
||||||
match f(self) {
|
match f(self) {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -1209,6 +1280,8 @@ struct Settings {
|
||||||
dump_dir: Option<PathBuf>,
|
dump_dir: Option<PathBuf>,
|
||||||
libcuda_path: String,
|
libcuda_path: String,
|
||||||
override_cc: Option<(u32, u32)>,
|
override_cc: Option<(u32, u32)>,
|
||||||
|
kernel_name_filter: Option<regex::Regex>,
|
||||||
|
kernel_no_output: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Settings {
|
impl Settings {
|
||||||
|
@ -1257,10 +1330,42 @@ impl Settings {
|
||||||
})
|
})
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
let kernel_name_filter = match env::var("ZLUDA_SAVE_KERNELS") {
|
||||||
|
Err(env::VarError::NotPresent) => None,
|
||||||
|
Err(e) => {
|
||||||
|
logger.log(log::ErrorEntry::ErrorBox(Box::new(e) as _));
|
||||||
|
None
|
||||||
|
}
|
||||||
|
Ok(env_string) => logger.try_return(|| {
|
||||||
|
regex::Regex::new(&env_string).map_err(|e| ErrorEntry::InvalidEnvVar {
|
||||||
|
var: "ZLUDA_SAVE_KERNELS",
|
||||||
|
pattern: "valid regex",
|
||||||
|
value: format!("{} ({})", env_string, e),
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
let kernel_no_output = match env::var("ZLUDA_SAVE_KERNELS_NO_OUTPUT") {
|
||||||
|
Err(env::VarError::NotPresent) => None,
|
||||||
|
Err(e) => {
|
||||||
|
logger.log(log::ErrorEntry::ErrorBox(Box::new(e) as _));
|
||||||
|
None
|
||||||
|
}
|
||||||
|
Ok(env_string) => logger
|
||||||
|
.try_return(|| {
|
||||||
|
str::parse::<u8>(&env_string).map_err(|err| ErrorEntry::InvalidEnvVar {
|
||||||
|
var: "ZLUDA_SAVE_KERNELS_NO_OUTPUT",
|
||||||
|
pattern: "number",
|
||||||
|
value: format!("{} ({})", env_string, err),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.map(|x| x != 0),
|
||||||
|
};
|
||||||
Settings {
|
Settings {
|
||||||
dump_dir,
|
dump_dir,
|
||||||
libcuda_path,
|
libcuda_path,
|
||||||
override_cc,
|
override_cc,
|
||||||
|
kernel_name_filter,
|
||||||
|
kernel_no_output,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1307,7 +1412,7 @@ pub(crate) fn cuModuleLoadData_Post(
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
_result: CUresult,
|
_result: CUresult,
|
||||||
) {
|
) {
|
||||||
state.record_new_library(unsafe { *module }, raw_image, fn_logger)
|
state.record_new_library(unsafe { *module }.0.cast(), raw_image, fn_logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
|
@ -1326,13 +1431,17 @@ pub(crate) fn cuModuleLoadDataEx_Post(
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
pub(crate) fn cuModuleGetFunction_Post(
|
pub(crate) fn cuModuleGetFunction_Post(
|
||||||
_hfunc: *mut CUfunction,
|
hfunc: *mut CUfunction,
|
||||||
_hmod: CUmodule,
|
hmod: CUmodule,
|
||||||
_name: *const ::std::os::raw::c_char,
|
name: *const ::std::os::raw::c_char,
|
||||||
_state: &mut trace::StateTracker,
|
state: &mut trace::StateTracker,
|
||||||
_fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
_result: CUresult,
|
result: CUresult,
|
||||||
) {
|
) {
|
||||||
|
if !result.is_ok() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
state.record_function_from_module(fn_logger, unsafe { *hfunc }, hmod, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
|
@ -1385,7 +1494,7 @@ pub(crate) fn cuModuleLoadFatBinary_Post(
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
_result: CUresult,
|
_result: CUresult,
|
||||||
) {
|
) {
|
||||||
state.record_new_library(unsafe { *module }, fatbin_header.cast(), fn_logger)
|
state.record_new_library(unsafe { *module }.0.cast(), fatbin_header.cast(), fn_logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
|
@ -1393,13 +1502,13 @@ pub(crate) fn cuLibraryGetModule_Post(
|
||||||
module: *mut cuda_types::cuda::CUmodule,
|
module: *mut cuda_types::cuda::CUmodule,
|
||||||
library: cuda_types::cuda::CUlibrary,
|
library: cuda_types::cuda::CUlibrary,
|
||||||
state: &mut trace::StateTracker,
|
state: &mut trace::StateTracker,
|
||||||
fn_logger: &mut FnCallLog,
|
_fn_logger: &mut FnCallLog,
|
||||||
_result: CUresult,
|
result: CUresult,
|
||||||
) {
|
) {
|
||||||
match state.libraries.get(&library).copied() {
|
if !result.is_ok() {
|
||||||
None => fn_logger.log(log::ErrorEntry::UnknownLibrary(library)),
|
return;
|
||||||
Some(code) => state.record_new_library(unsafe { *module }, code.0, fn_logger),
|
|
||||||
}
|
}
|
||||||
|
state.record_module_in_library(unsafe { *module }, library);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
|
@ -1416,10 +1525,149 @@ pub(crate) fn cuLibraryLoadData_Post(
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
_result: CUresult,
|
_result: CUresult,
|
||||||
) {
|
) {
|
||||||
state
|
state.record_new_library(unsafe { *library }.0.cast(), code, fn_logger);
|
||||||
.libraries
|
}
|
||||||
.insert(unsafe { *library }, trace::CodePointer(code));
|
|
||||||
// TODO: this is not correct, but it's enough for now, we just want to
|
#[allow(non_snake_case)]
|
||||||
// save the binary to disk
|
pub(crate) fn cuLaunchKernel_Pre(
|
||||||
state.record_new_library(unsafe { CUmodule((*library).0.cast()) }, code, fn_logger);
|
f: cuda_types::cuda::CUfunction,
|
||||||
|
gridDimX: ::core::ffi::c_uint,
|
||||||
|
gridDimY: ::core::ffi::c_uint,
|
||||||
|
gridDimZ: ::core::ffi::c_uint,
|
||||||
|
blockDimX: ::core::ffi::c_uint,
|
||||||
|
blockDimY: ::core::ffi::c_uint,
|
||||||
|
blockDimZ: ::core::ffi::c_uint,
|
||||||
|
sharedMemBytes: ::core::ffi::c_uint,
|
||||||
|
hStream: cuda_types::cuda::CUstream,
|
||||||
|
kernel_params: *mut *mut ::core::ffi::c_void,
|
||||||
|
_extra: *mut *mut ::core::ffi::c_void,
|
||||||
|
libcuda: &mut CudaDynamicFns,
|
||||||
|
state: &mut trace::StateTracker,
|
||||||
|
fn_logger: &mut FnCallLog,
|
||||||
|
) -> Option<replay::LaunchPreState> {
|
||||||
|
launch_kernel_pre(
|
||||||
|
f,
|
||||||
|
CUlaunchConfig {
|
||||||
|
gridDimX,
|
||||||
|
gridDimY,
|
||||||
|
gridDimZ,
|
||||||
|
blockDimX,
|
||||||
|
blockDimY,
|
||||||
|
blockDimZ,
|
||||||
|
sharedMemBytes,
|
||||||
|
hStream,
|
||||||
|
attrs: ptr::null_mut(),
|
||||||
|
numAttrs: 0,
|
||||||
|
},
|
||||||
|
hStream,
|
||||||
|
kernel_params,
|
||||||
|
libcuda,
|
||||||
|
state,
|
||||||
|
fn_logger,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn launch_kernel_pre(
|
||||||
|
f: cuda_types::cuda::CUfunction,
|
||||||
|
config: CUlaunchConfig,
|
||||||
|
stream: cuda_types::cuda::CUstream,
|
||||||
|
kernel_params: *mut *mut ::core::ffi::c_void,
|
||||||
|
libcuda: &mut CudaDynamicFns,
|
||||||
|
state: &mut trace::StateTracker,
|
||||||
|
fn_logger: &mut FnCallLog,
|
||||||
|
) -> Option<replay::LaunchPreState> {
|
||||||
|
state.enqueue_counter += 1;
|
||||||
|
if kernel_params.is_null() {
|
||||||
|
fn_logger.log(ErrorEntry::NullPointer("kernel_params"));
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
if state.dump_dir().is_none() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
replay::pre_kernel_launch(libcuda, state, fn_logger, config, f, stream, kernel_params)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
pub(crate) fn cuLaunchKernel_Post(
|
||||||
|
_f: cuda_types::cuda::CUfunction,
|
||||||
|
gridDimX: ::core::ffi::c_uint,
|
||||||
|
gridDimY: ::core::ffi::c_uint,
|
||||||
|
gridDimZ: ::core::ffi::c_uint,
|
||||||
|
blockDimX: ::core::ffi::c_uint,
|
||||||
|
blockDimY: ::core::ffi::c_uint,
|
||||||
|
blockDimZ: ::core::ffi::c_uint,
|
||||||
|
sharedMemBytes: ::core::ffi::c_uint,
|
||||||
|
hStream: cuda_types::cuda::CUstream,
|
||||||
|
kernel_params: *mut *mut ::core::ffi::c_void,
|
||||||
|
_extra: *mut *mut ::core::ffi::c_void,
|
||||||
|
pre_state: Option<replay::LaunchPreState>,
|
||||||
|
libcuda: &mut CudaDynamicFns,
|
||||||
|
state: &mut trace::StateTracker,
|
||||||
|
fn_logger: &mut FnCallLog,
|
||||||
|
_result: CUresult,
|
||||||
|
) {
|
||||||
|
let pre_state = unwrap_some_or!(pre_state, return);
|
||||||
|
replay::post_kernel_launch(
|
||||||
|
libcuda,
|
||||||
|
state,
|
||||||
|
fn_logger,
|
||||||
|
CUlaunchConfig {
|
||||||
|
gridDimX,
|
||||||
|
gridDimY,
|
||||||
|
gridDimZ,
|
||||||
|
blockDimX,
|
||||||
|
blockDimY,
|
||||||
|
blockDimZ,
|
||||||
|
sharedMemBytes,
|
||||||
|
hStream,
|
||||||
|
attrs: ptr::null_mut(),
|
||||||
|
numAttrs: 0,
|
||||||
|
},
|
||||||
|
kernel_params,
|
||||||
|
pre_state,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
pub(crate) fn cuLaunchKernelEx_Pre(
|
||||||
|
config: *const cuda_types::cuda::CUlaunchConfig,
|
||||||
|
f: cuda_types::cuda::CUfunction,
|
||||||
|
kernel_params: *mut *mut ::core::ffi::c_void,
|
||||||
|
_extra: *mut *mut ::core::ffi::c_void,
|
||||||
|
libcuda: &mut CudaDynamicFns,
|
||||||
|
state: &mut trace::StateTracker,
|
||||||
|
fn_logger: &mut FnCallLog,
|
||||||
|
) -> Option<replay::LaunchPreState> {
|
||||||
|
launch_kernel_pre(
|
||||||
|
f,
|
||||||
|
unsafe { *config },
|
||||||
|
unsafe { *config }.hStream,
|
||||||
|
kernel_params,
|
||||||
|
libcuda,
|
||||||
|
state,
|
||||||
|
fn_logger,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
pub(crate) fn cuLaunchKernelEx_Post(
|
||||||
|
config: *const cuda_types::cuda::CUlaunchConfig,
|
||||||
|
_f: cuda_types::cuda::CUfunction,
|
||||||
|
kernel_params: *mut *mut ::core::ffi::c_void,
|
||||||
|
_extra: *mut *mut ::core::ffi::c_void,
|
||||||
|
pre_state: Option<replay::LaunchPreState>,
|
||||||
|
libcuda: &mut CudaDynamicFns,
|
||||||
|
state: &mut trace::StateTracker,
|
||||||
|
fn_logger: &mut FnCallLog,
|
||||||
|
_result: CUresult,
|
||||||
|
) {
|
||||||
|
let pre_state = unwrap_some_or!(pre_state, return);
|
||||||
|
replay::post_kernel_launch(
|
||||||
|
libcuda,
|
||||||
|
state,
|
||||||
|
fn_logger,
|
||||||
|
unsafe { *config },
|
||||||
|
kernel_params,
|
||||||
|
pre_state,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
use super::Settings;
|
use super::Settings;
|
||||||
|
use crate::trace::SendablePtr;
|
||||||
use crate::FnCallLog;
|
use crate::FnCallLog;
|
||||||
use crate::LogEntry;
|
use crate::LogEntry;
|
||||||
use cuda_types::cuda::*;
|
use cuda_types::cuda::*;
|
||||||
use format::CudaDisplay;
|
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::ffi::NulError;
|
use std::ffi::NulError;
|
||||||
|
@ -267,13 +267,12 @@ pub(crate) enum ErrorEntry {
|
||||||
CreatedDumpDirectory(PathBuf),
|
CreatedDumpDirectory(PathBuf),
|
||||||
ErrorBox(Box<dyn Error>),
|
ErrorBox(Box<dyn Error>),
|
||||||
UnsupportedModule {
|
UnsupportedModule {
|
||||||
module: CUmodule,
|
handle: *mut c_void,
|
||||||
raw_image: *const c_void,
|
raw_image: *const c_void,
|
||||||
kind: &'static str,
|
kind: &'static str,
|
||||||
},
|
},
|
||||||
FunctionNotFound(CudaFunctionName),
|
FunctionNotFound(CudaFunctionName),
|
||||||
MalformedModulePath(Utf8Error),
|
Utf8Error(Utf8Error),
|
||||||
NonUtf8ModuleText(Utf8Error),
|
|
||||||
NulInsideModuleText(NulError),
|
NulInsideModuleText(NulError),
|
||||||
ModuleParsingError(String),
|
ModuleParsingError(String),
|
||||||
Lz4DecompressionFailure,
|
Lz4DecompressionFailure,
|
||||||
|
@ -302,8 +301,11 @@ pub(crate) enum ErrorEntry {
|
||||||
overriden: [u64; 2],
|
overriden: [u64; 2],
|
||||||
},
|
},
|
||||||
NullPointer(&'static str),
|
NullPointer(&'static str),
|
||||||
UnknownLibrary(CUlibrary),
|
|
||||||
SavedModule(String),
|
SavedModule(String),
|
||||||
|
UnknownFunctionHandle(CUfunction),
|
||||||
|
UnknownLibrary(CUfunction, SendablePtr),
|
||||||
|
UnknownFunction(CUfunction, SendablePtr, String),
|
||||||
|
CudaError(Option<CUerror>),
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl Send for ErrorEntry {}
|
unsafe impl Send for ErrorEntry {}
|
||||||
|
@ -345,94 +347,100 @@ impl Display for ErrorEntry {
|
||||||
match self {
|
match self {
|
||||||
ErrorEntry::IoError(e) => e.fmt(f),
|
ErrorEntry::IoError(e) => e.fmt(f),
|
||||||
ErrorEntry::CreatedDumpDirectory(dir) => {
|
ErrorEntry::CreatedDumpDirectory(dir) => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
"Created trace directory {} ",
|
"Created trace directory {} ",
|
||||||
dir.as_os_str().to_string_lossy()
|
dir.as_os_str().to_string_lossy()
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
ErrorEntry::ErrorBox(e) => e.fmt(f),
|
ErrorEntry::ErrorBox(e) => e.fmt(f),
|
||||||
ErrorEntry::UnsupportedModule {
|
ErrorEntry::UnsupportedModule {
|
||||||
module,
|
handle,
|
||||||
raw_image,
|
raw_image,
|
||||||
kind,
|
kind,
|
||||||
} => {
|
} => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
"Unsupported {} module {:?} loaded from module image {:?}",
|
"Unsupported {} module {:p} loaded from module image {:p}",
|
||||||
kind, module, raw_image
|
kind, handle, raw_image
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
ErrorEntry::MalformedModulePath(e) => e.fmt(f),
|
ErrorEntry::Utf8Error(e) => e.fmt(f),
|
||||||
ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f),
|
|
||||||
ErrorEntry::ModuleParsingError(file_name) => {
|
ErrorEntry::ModuleParsingError(file_name) => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
"Error parsing module, log has been written to {}",
|
"Error parsing module, log has been written to {}",
|
||||||
file_name
|
file_name
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
ErrorEntry::NulInsideModuleText(e) => e.fmt(f),
|
ErrorEntry::NulInsideModuleText(e) => e.fmt(f),
|
||||||
ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"),
|
ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"),
|
||||||
ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)),
|
ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)),
|
||||||
ErrorEntry::UnexpectedBinaryField {
|
ErrorEntry::UnexpectedBinaryField {
|
||||||
field_name,
|
field_name,
|
||||||
expected,
|
expected,
|
||||||
observed,
|
observed,
|
||||||
} => write!(
|
} => write!(
|
||||||
f,
|
f,
|
||||||
"Unexpected field {}. Expected one of: [{}], observed: {}",
|
"Unexpected field {}. Expected one of: [{}], observed: {}",
|
||||||
field_name,
|
field_name,
|
||||||
expected
|
expected
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| x.to_string())
|
.map(|x| x.to_string())
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(", "),
|
.join(", "),
|
||||||
observed
|
observed
|
||||||
),
|
),
|
||||||
ErrorEntry::UnexpectedArgument {
|
ErrorEntry::UnexpectedArgument {
|
||||||
arg_name,
|
arg_name,
|
||||||
expected,
|
expected,
|
||||||
observed,
|
observed,
|
||||||
} => write!(
|
} => write!(
|
||||||
f,
|
f,
|
||||||
"Unexpected argument {}. Expected one of: {{{}}}, observed: {}",
|
"Unexpected argument {}. Expected one of: {{{}}}, observed: {}",
|
||||||
arg_name,
|
arg_name,
|
||||||
expected
|
expected
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| x.to_string())
|
.map(|x| x.to_string())
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(", "),
|
.join(", "),
|
||||||
observed
|
observed
|
||||||
),
|
),
|
||||||
ErrorEntry::InvalidEnvVar {
|
ErrorEntry::InvalidEnvVar {
|
||||||
var,
|
var,
|
||||||
pattern,
|
pattern,
|
||||||
value,
|
value,
|
||||||
} => write!(
|
} => write!(
|
||||||
f,
|
f,
|
||||||
"Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}"
|
"Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}"
|
||||||
),
|
),
|
||||||
ErrorEntry::FunctionNotFound(cuda_function_name) => write!(
|
ErrorEntry::FunctionNotFound(cuda_function_name) => write!(
|
||||||
f,
|
f,
|
||||||
"No function {cuda_function_name} in the underlying library"
|
"No function {cuda_function_name} in the underlying library"
|
||||||
),
|
),
|
||||||
ErrorEntry::UnexpectedExportTableSize { expected, computed } => {
|
ErrorEntry::UnexpectedExportTableSize { expected, computed } => {
|
||||||
write!(f, "Table length mismatch. Expected: {expected}, got: {computed}")
|
write!(f, "Table length mismatch. Expected: {expected}, got: {computed}")
|
||||||
}
|
}
|
||||||
ErrorEntry::IntegrityCheck { original, overriden } => {
|
ErrorEntry::IntegrityCheck { original, overriden } => {
|
||||||
write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}")
|
write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}")
|
||||||
}
|
}
|
||||||
ErrorEntry::NullPointer(type_) => {
|
ErrorEntry::NullPointer(type_) => {
|
||||||
write!(f, "Null pointer of type {type_} encountered")
|
write!(f, "Null pointer of type {type_} encountered")
|
||||||
}
|
}
|
||||||
ErrorEntry::UnknownLibrary(culibrary) => {
|
|
||||||
write!(f, "Unknown library: ")?;
|
|
||||||
let mut temp_buffer = Vec::new();
|
|
||||||
CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok();
|
|
||||||
f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) })
|
|
||||||
}
|
|
||||||
ErrorEntry::SavedModule(file) => write!(f, "Saved module to {file}"),
|
ErrorEntry::SavedModule(file) => write!(f, "Saved module to {file}"),
|
||||||
|
ErrorEntry::UnknownFunctionHandle(cuda_function_name) => {
|
||||||
|
write!(f, "Function with unknown provenance: {cuda_function_name:p}")
|
||||||
|
}
|
||||||
|
ErrorEntry::UnknownLibrary(cuda_function_name, owner) => {
|
||||||
|
write!(f, "Function with unknown provenance: {cuda_function_name:p}, owner: {owner:p}")
|
||||||
|
}
|
||||||
|
ErrorEntry::UnknownFunction(cuda_function_name, owner, name) => {
|
||||||
|
write!(f, "Function with unknown provenance: {cuda_function_name:p}, owner: {owner:p}, name: {name}")
|
||||||
|
}
|
||||||
|
ErrorEntry::CudaError(cuerror) => {
|
||||||
|
let cuerror = cuerror.map(|e| e.0);
|
||||||
|
write!(f, "CUDA error encountered: {cuerror:#?}")
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
171
zluda_trace/src/replay.rs
Normal file
171
zluda_trace/src/replay.rs
Normal file
|
@ -0,0 +1,171 @@
|
||||||
|
use crate::{
|
||||||
|
log::ErrorEntry,
|
||||||
|
trace::{self, ParsedModule, SavedKernel},
|
||||||
|
CudaDynamicFns, FnCallLog,
|
||||||
|
};
|
||||||
|
use cuda_types::cuda::*;
|
||||||
|
use zluda_trace_common::replay::KernelParameter;
|
||||||
|
|
||||||
|
pub struct LaunchPreState {
|
||||||
|
kernel_name: String,
|
||||||
|
source: String,
|
||||||
|
kernel_params: Vec<KernelParameter>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn pre_kernel_launch(
|
||||||
|
libcuda: &mut CudaDynamicFns,
|
||||||
|
state: &mut trace::StateTracker,
|
||||||
|
fn_logger: &mut FnCallLog,
|
||||||
|
config: CUlaunchConfig,
|
||||||
|
f: CUfunction,
|
||||||
|
stream: CUstream,
|
||||||
|
args: *mut *mut std::ffi::c_void,
|
||||||
|
) -> Option<LaunchPreState> {
|
||||||
|
fn_logger.try_cuda(|| libcuda.cuStreamSynchronize(stream))?;
|
||||||
|
let SavedKernel { name, owner } = fn_logger.try_return(|| {
|
||||||
|
state
|
||||||
|
.kernels
|
||||||
|
.get(&f)
|
||||||
|
.ok_or(ErrorEntry::UnknownFunctionHandle(f))
|
||||||
|
})?;
|
||||||
|
let kernel_name_filter = state.kernel_name_filter.as_ref()?;
|
||||||
|
if !kernel_name_filter.is_match(name) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let ParsedModule { source, kernels } = fn_logger.try_return(|| {
|
||||||
|
state
|
||||||
|
.parsed_libraries
|
||||||
|
.get(owner)
|
||||||
|
.ok_or(ErrorEntry::UnknownLibrary(f, *owner))
|
||||||
|
})?;
|
||||||
|
let kernel_params = fn_logger.try_return(|| {
|
||||||
|
kernels
|
||||||
|
.get(name)
|
||||||
|
.ok_or_else(|| ErrorEntry::UnknownFunction(f, *owner, name.clone()))
|
||||||
|
})?;
|
||||||
|
let raw_args = unsafe { std::slice::from_raw_parts(args, kernel_params.len()) };
|
||||||
|
let mut all_params = Vec::new();
|
||||||
|
for (raw_arg, layout) in raw_args.iter().zip(kernel_params.iter()) {
|
||||||
|
let mut offset = 0;
|
||||||
|
let mut ptr_overrides = Vec::new();
|
||||||
|
while offset + std::mem::size_of::<usize>() <= layout.size() {
|
||||||
|
let maybe_ptr = unsafe { raw_arg.cast::<u8>().add(offset) };
|
||||||
|
let maybe_ptr = unsafe { maybe_ptr.cast::<usize>().read_unaligned() };
|
||||||
|
let attrs = &mut [
|
||||||
|
CUpointer_attribute_enum::CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||||
|
CUpointer_attribute_enum::CU_POINTER_ATTRIBUTE_RANGE_SIZE,
|
||||||
|
];
|
||||||
|
let mut start = 0usize;
|
||||||
|
let mut size = 0usize;
|
||||||
|
let mut data = [
|
||||||
|
(&mut start as *mut usize).cast::<std::ffi::c_void>(),
|
||||||
|
(&mut size as *mut usize).cast::<std::ffi::c_void>(),
|
||||||
|
];
|
||||||
|
fn_logger.try_cuda(|| {
|
||||||
|
libcuda.cuPointerGetAttributes(
|
||||||
|
2,
|
||||||
|
attrs.as_mut_ptr(),
|
||||||
|
data.as_mut_ptr(),
|
||||||
|
CUdeviceptr_v2(maybe_ptr as _),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
if size != 0 {
|
||||||
|
let mut pre_buffer = vec![0u8; size];
|
||||||
|
let post_buffer = vec![0u8; size];
|
||||||
|
fn_logger.try_cuda(|| {
|
||||||
|
libcuda.cuMemcpyDtoH_v2(
|
||||||
|
pre_buffer.as_mut_ptr().cast(),
|
||||||
|
CUdeviceptr_v2(start as _),
|
||||||
|
size,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let buffer_offset = maybe_ptr - start;
|
||||||
|
ptr_overrides.push((offset, buffer_offset, pre_buffer, post_buffer));
|
||||||
|
}
|
||||||
|
offset += std::mem::size_of::<usize>();
|
||||||
|
}
|
||||||
|
all_params.push(KernelParameter {
|
||||||
|
data: unsafe { std::slice::from_raw_parts(raw_arg.cast::<u8>(), layout.size()) }
|
||||||
|
.to_vec(),
|
||||||
|
device_ptrs: ptr_overrides,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if state.kernel_no_output {
|
||||||
|
let enqueue_counter = state.enqueue_counter;
|
||||||
|
let kernel_name = name;
|
||||||
|
let mut path = state.dump_dir()?.to_path_buf();
|
||||||
|
path.push(format!("kernel_{enqueue_counter}_{kernel_name}.tar.zst"));
|
||||||
|
let file = fn_logger
|
||||||
|
.try_return(|| std::fs::File::create_new(path).map_err(ErrorEntry::IoError))?;
|
||||||
|
fn_logger.try_return(|| {
|
||||||
|
zluda_trace_common::replay::save(
|
||||||
|
file,
|
||||||
|
name.to_string(),
|
||||||
|
false,
|
||||||
|
zluda_trace_common::replay::LaunchConfig {
|
||||||
|
grid_dim: (config.gridDimX, config.gridDimY, config.gridDimZ),
|
||||||
|
block_dim: (config.blockDimX, config.blockDimY, config.blockDimZ),
|
||||||
|
shared_mem_bytes: config.sharedMemBytes,
|
||||||
|
},
|
||||||
|
source.to_string(),
|
||||||
|
all_params,
|
||||||
|
)
|
||||||
|
.map_err(ErrorEntry::IoError)
|
||||||
|
});
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(LaunchPreState {
|
||||||
|
kernel_name: name.to_string(),
|
||||||
|
source: source.to_string(),
|
||||||
|
kernel_params: all_params,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn post_kernel_launch(
|
||||||
|
libcuda: &mut CudaDynamicFns,
|
||||||
|
state: &trace::StateTracker,
|
||||||
|
fn_logger: &mut FnCallLog,
|
||||||
|
config: CUlaunchConfig,
|
||||||
|
kernel_params: *mut *mut std::ffi::c_void,
|
||||||
|
mut pre_state: LaunchPreState,
|
||||||
|
) -> Option<()> {
|
||||||
|
fn_logger.try_cuda(|| libcuda.cuStreamSynchronize(config.hStream))?;
|
||||||
|
let raw_args =
|
||||||
|
unsafe { std::slice::from_raw_parts(kernel_params, pre_state.kernel_params.len()) };
|
||||||
|
for (raw_arg, param) in raw_args.iter().zip(pre_state.kernel_params.iter_mut()) {
|
||||||
|
for (offset_in_param, offset_in_buffer, _, data_after) in param.device_ptrs.iter_mut() {
|
||||||
|
let dev_ptr_param = unsafe { raw_arg.cast::<u8>().add(*offset_in_param) };
|
||||||
|
let mut dev_ptr = unsafe { dev_ptr_param.cast::<usize>().read_unaligned() };
|
||||||
|
dev_ptr -= *offset_in_buffer;
|
||||||
|
fn_logger.try_cuda(|| {
|
||||||
|
libcuda.cuMemcpyDtoH_v2(
|
||||||
|
data_after.as_mut_ptr().cast(),
|
||||||
|
CUdeviceptr_v2(dev_ptr as _),
|
||||||
|
data_after.len(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let enqueue_counter = state.enqueue_counter;
|
||||||
|
let kernel_name = &pre_state.kernel_name;
|
||||||
|
let mut path = state.dump_dir()?.to_path_buf();
|
||||||
|
path.push(format!("kernel_{enqueue_counter}_{kernel_name}.tar.zst"));
|
||||||
|
let file =
|
||||||
|
fn_logger.try_return(|| std::fs::File::create_new(path).map_err(ErrorEntry::IoError))?;
|
||||||
|
fn_logger.try_return(|| {
|
||||||
|
zluda_trace_common::replay::save(
|
||||||
|
file,
|
||||||
|
pre_state.kernel_name,
|
||||||
|
true,
|
||||||
|
zluda_trace_common::replay::LaunchConfig {
|
||||||
|
grid_dim: (config.gridDimX, config.gridDimY, config.gridDimZ),
|
||||||
|
block_dim: (config.blockDimX, config.blockDimY, config.blockDimZ),
|
||||||
|
shared_mem_bytes: config.sharedMemBytes,
|
||||||
|
},
|
||||||
|
pre_state.source,
|
||||||
|
pre_state.kernel_params,
|
||||||
|
)
|
||||||
|
.map_err(ErrorEntry::IoError)
|
||||||
|
})
|
||||||
|
}
|
|
@ -4,8 +4,9 @@ use crate::{
|
||||||
};
|
};
|
||||||
use cuda_types::cuda::*;
|
use cuda_types::cuda::*;
|
||||||
use goblin::{elf, elf32, elf64};
|
use goblin::{elf, elf32, elf64};
|
||||||
use rustc_hash::{FxHashMap, FxHashSet};
|
use rustc_hash::FxHashMap;
|
||||||
use std::{
|
use std::{
|
||||||
|
alloc::Layout,
|
||||||
ffi::{c_void, CStr, CString},
|
ffi::{c_void, CStr, CString},
|
||||||
fs::{self, File},
|
fs::{self, File},
|
||||||
io::{self, Read, Write},
|
io::{self, Read, Write},
|
||||||
|
@ -20,29 +21,51 @@ use unwrap_or::unwrap_some_or;
|
||||||
// * writes out relevant state change and details to disk and log
|
// * writes out relevant state change and details to disk and log
|
||||||
pub(crate) struct StateTracker {
|
pub(crate) struct StateTracker {
|
||||||
writer: DumpWriter,
|
writer: DumpWriter,
|
||||||
pub(crate) libraries: FxHashMap<CUlibrary, CodePointer>,
|
pub(crate) parsed_libraries: FxHashMap<SendablePtr, ParsedModule>,
|
||||||
saved_modules: FxHashSet<CUmodule>,
|
pub(crate) submodules: FxHashMap<CUmodule, CUlibrary>,
|
||||||
|
pub(crate) kernels: FxHashMap<CUfunction, SavedKernel>,
|
||||||
library_counter: usize,
|
library_counter: usize,
|
||||||
|
pub(crate) enqueue_counter: usize,
|
||||||
pub(crate) override_cc: Option<(u32, u32)>,
|
pub(crate) override_cc: Option<(u32, u32)>,
|
||||||
|
pub(crate) kernel_name_filter: Option<regex::Regex>,
|
||||||
|
pub(crate) kernel_no_output: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
pub(crate) struct ParsedModule {
|
||||||
pub(crate) struct CodePointer(pub *const c_void);
|
pub source: String,
|
||||||
|
pub kernels: FxHashMap<String, Vec<Layout>>,
|
||||||
|
}
|
||||||
|
|
||||||
unsafe impl Send for CodePointer {}
|
pub(crate) struct SavedKernel {
|
||||||
unsafe impl Sync for CodePointer {}
|
pub name: String,
|
||||||
|
pub owner: SendablePtr,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw handle wrapper that can be used as a hash-map key and shared between
/// threads.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct SendablePtr(*mut c_void);

// NOTE(review): these impls are sound only while the wrapped pointer is used
// as an opaque identifier (compared, hashed, printed) and never dereferenced
// through this type. Confirm against all users.
unsafe impl Send for SendablePtr {}
unsafe impl Sync for SendablePtr {}
|
||||||
|
|
||||||
impl StateTracker {
|
impl StateTracker {
|
||||||
pub(crate) fn new(settings: &Settings) -> Self {
|
pub(crate) fn new(settings: &Settings) -> Self {
|
||||||
StateTracker {
|
StateTracker {
|
||||||
writer: DumpWriter::new(settings.dump_dir.clone()),
|
writer: DumpWriter::new(settings.dump_dir.clone()),
|
||||||
libraries: FxHashMap::default(),
|
parsed_libraries: FxHashMap::default(),
|
||||||
saved_modules: FxHashSet::default(),
|
submodules: FxHashMap::default(),
|
||||||
|
kernels: FxHashMap::default(),
|
||||||
library_counter: 0,
|
library_counter: 0,
|
||||||
|
enqueue_counter: 0,
|
||||||
override_cc: settings.override_cc,
|
override_cc: settings.override_cc,
|
||||||
|
kernel_name_filter: settings.kernel_name_filter.clone(),
|
||||||
|
kernel_no_output: settings.kernel_no_output.unwrap_or(false),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn dump_dir(&self) -> Option<&PathBuf> {
|
||||||
|
self.writer.dump_dir.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn record_new_module_file(
|
pub(crate) fn record_new_module_file(
|
||||||
&mut self,
|
&mut self,
|
||||||
module: CUmodule,
|
module: CUmodule,
|
||||||
|
@ -52,7 +75,7 @@ impl StateTracker {
|
||||||
let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() {
|
let file_name = match unsafe { CStr::from_ptr(file_name) }.to_str() {
|
||||||
Ok(f) => f,
|
Ok(f) => f,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
fn_logger.log(log::ErrorEntry::MalformedModulePath(err));
|
fn_logger.log(log::ErrorEntry::Utf8Error(err));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -69,21 +92,26 @@ impl StateTracker {
|
||||||
let mut module_file = fs::File::open(file_name)?;
|
let mut module_file = fs::File::open(file_name)?;
|
||||||
let mut read_buff = Vec::new();
|
let mut read_buff = Vec::new();
|
||||||
module_file.read_to_end(&mut read_buff)?;
|
module_file.read_to_end(&mut read_buff)?;
|
||||||
self.record_new_library(module, read_buff.as_ptr() as *const _, fn_logger);
|
self.record_new_library(module.0.cast(), read_buff.as_ptr() as *const _, fn_logger);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn record_new_library(
|
pub(crate) fn record_new_library(
|
||||||
&mut self,
|
&mut self,
|
||||||
cu_module: CUmodule,
|
handle: *mut c_void,
|
||||||
raw_image: *const c_void,
|
raw_image: *const c_void,
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
) {
|
) {
|
||||||
self.saved_modules.insert(cu_module);
|
fn overwrite<T>(current: &mut Option<T>, value: Option<T>) {
|
||||||
|
if value.is_some() {
|
||||||
|
*current = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut kernel_arguments = None;
|
||||||
self.library_counter += 1;
|
self.library_counter += 1;
|
||||||
let code_ref = fn_logger.try_return(|| {
|
let code_ref = fn_logger.try_return(|| {
|
||||||
unsafe { zluda_common::CodeLibraryRef::try_load(raw_image) }
|
unsafe { zluda_common::CodeLibraryRef::try_load(raw_image) }
|
||||||
.map_err(ErrorEntry::NonUtf8ModuleText)
|
.map_err(ErrorEntry::Utf8Error)
|
||||||
});
|
});
|
||||||
let code_ref = unwrap_some_or!(code_ref, return);
|
let code_ref = unwrap_some_or!(code_ref, return);
|
||||||
unsafe {
|
unsafe {
|
||||||
|
@ -92,17 +120,20 @@ impl StateTracker {
|
||||||
Ok(zluda_common::CodeModuleRef::Elf(elf)) => match get_elf_size(elf) {
|
Ok(zluda_common::CodeModuleRef::Elf(elf)) => match get_elf_size(elf) {
|
||||||
Some(len) => {
|
Some(len) => {
|
||||||
let elf_image = std::slice::from_raw_parts(elf.cast::<u8>(), len);
|
let elf_image = std::slice::from_raw_parts(elf.cast::<u8>(), len);
|
||||||
self.record_new_submodule(index, elf_image, fn_logger, "elf");
|
overwrite(
|
||||||
|
&mut kernel_arguments,
|
||||||
|
self.record_new_submodule(index, elf_image, fn_logger, "elf"),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
None => fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
None => fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
||||||
module: cu_module,
|
handle,
|
||||||
raw_image: elf,
|
raw_image: elf,
|
||||||
kind: "ELF",
|
kind: "ELF",
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
Ok(zluda_common::CodeModuleRef::Archive(archive)) => {
|
Ok(zluda_common::CodeModuleRef::Archive(archive)) => {
|
||||||
fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
||||||
module: cu_module,
|
handle,
|
||||||
raw_image: archive,
|
raw_image: archive,
|
||||||
kind: "archive",
|
kind: "archive",
|
||||||
})
|
})
|
||||||
|
@ -111,23 +142,39 @@ impl StateTracker {
|
||||||
if let Some(buffer) = fn_logger
|
if let Some(buffer) = fn_logger
|
||||||
.try_(|_| file.get_or_decompress_content().map_err(ErrorEntry::from))
|
.try_(|_| file.get_or_decompress_content().map_err(ErrorEntry::from))
|
||||||
{
|
{
|
||||||
self.record_new_submodule(index, &*buffer, fn_logger, file.kind());
|
overwrite(
|
||||||
|
&mut kernel_arguments,
|
||||||
|
self.record_new_submodule(index, &*buffer, fn_logger, file.kind()),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(zluda_common::CodeModuleRef::Text(ptx)) => {
|
Ok(zluda_common::CodeModuleRef::Text(ptx)) => {
|
||||||
self.record_new_submodule(index, ptx.as_bytes(), fn_logger, "ptx");
|
overwrite(
|
||||||
|
&mut kernel_arguments,
|
||||||
|
self.record_new_submodule(index, ptx.as_bytes(), fn_logger, "ptx"),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
if let Some((source, kernel_arguments)) = kernel_arguments {
|
||||||
|
self.parsed_libraries.insert(
|
||||||
|
SendablePtr(handle),
|
||||||
|
ParsedModule {
|
||||||
|
source,
|
||||||
|
kernels: kernel_arguments,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub(crate) fn record_new_submodule(
|
pub(crate) fn record_new_submodule(
|
||||||
&mut self,
|
&mut self,
|
||||||
index: Option<(usize, Option<usize>)>,
|
index: Option<(usize, Option<usize>)>,
|
||||||
submodule: &[u8],
|
submodule: &[u8],
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
type_: &'static str,
|
type_: &'static str,
|
||||||
) {
|
) -> Option<(String, FxHashMap<String, Vec<Layout>>)> {
|
||||||
fn_logger.try_(|fn_logger| {
|
fn_logger.try_(|fn_logger| {
|
||||||
self.writer
|
self.writer
|
||||||
.save_module(fn_logger, self.library_counter, index, submodule, type_)
|
.save_module(fn_logger, self.library_counter, index, submodule, type_)
|
||||||
|
@ -135,28 +182,36 @@ impl StateTracker {
|
||||||
});
|
});
|
||||||
if type_ == "ptx" {
|
if type_ == "ptx" {
|
||||||
match CString::new(submodule) {
|
match CString::new(submodule) {
|
||||||
Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)),
|
Err(e) => {
|
||||||
|
fn_logger.log(log::ErrorEntry::NulInsideModuleText(e));
|
||||||
|
None
|
||||||
|
}
|
||||||
Ok(submodule_cstring) => match submodule_cstring.to_str() {
|
Ok(submodule_cstring) => match submodule_cstring.to_str() {
|
||||||
Err(e) => fn_logger.log(log::ErrorEntry::NonUtf8ModuleText(e)),
|
Err(e) => {
|
||||||
Ok(submodule_text) => self.try_parse_and_record_kernels(
|
fn_logger.log(log::ErrorEntry::Utf8Error(e));
|
||||||
|
None
|
||||||
|
}
|
||||||
|
Ok(submodule_text) => Some(self.try_parse_and_record_kernels(
|
||||||
fn_logger,
|
fn_logger,
|
||||||
self.library_counter,
|
self.library_counter,
|
||||||
index,
|
index,
|
||||||
submodule_text,
|
submodule_text,
|
||||||
),
|
)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_parse_and_record_kernels(
|
fn try_parse_and_record_kernels<'input>(
|
||||||
&mut self,
|
&mut self,
|
||||||
fn_logger: &mut FnCallLog,
|
fn_logger: &mut FnCallLog,
|
||||||
module_index: usize,
|
module_index: usize,
|
||||||
submodule_index: Option<(usize, Option<usize>)>,
|
submodule_index: Option<(usize, Option<usize>)>,
|
||||||
module_text: &str,
|
module_text: &'input str,
|
||||||
) {
|
) -> (String, FxHashMap<String, Vec<Layout>>) {
|
||||||
let errors = ptx_parser::parse_for_errors(module_text);
|
let (errors, params) = ptx_parser::parse_for_errors_and_params(module_text);
|
||||||
if !errors.is_empty() {
|
if !errors.is_empty() {
|
||||||
fn_logger.log(log::ErrorEntry::ModuleParsingError(
|
fn_logger.log(log::ErrorEntry::ModuleParsingError(
|
||||||
DumpWriter::get_file_name(module_index, submodule_index, "log"),
|
DumpWriter::get_file_name(module_index, submodule_index, "log"),
|
||||||
|
@ -167,6 +222,46 @@ impl StateTracker {
|
||||||
&*errors,
|
&*errors,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
(module_text.to_string(), params)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remembers that `module` belongs to `library`.
///
/// The association is stored in `self.submodules`; an existing entry for
/// the same `module` is replaced.
pub(crate) fn record_module_in_library(&mut self, module: CUmodule, library: CUlibrary) {
    // Any previously recorded owner for this module is intentionally discarded.
    let _ = self.submodules.insert(module, library);
}
|
||||||
|
|
||||||
|
pub(crate) fn record_function_from_module(
|
||||||
|
&mut self,
|
||||||
|
fn_logger: &mut FnCallLog,
|
||||||
|
func: CUfunction,
|
||||||
|
hmod: CUmodule,
|
||||||
|
name: *const i8,
|
||||||
|
) {
|
||||||
|
let owner = match self.submodules.get(&hmod) {
|
||||||
|
Some(m) => m.0.cast::<c_void>(),
|
||||||
|
None => hmod.0.cast::<c_void>(),
|
||||||
|
};
|
||||||
|
self.record_function_from_impl(fn_logger, func, owner, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_function_from_impl(
|
||||||
|
&mut self,
|
||||||
|
fn_logger: &mut FnCallLog,
|
||||||
|
func: CUfunction,
|
||||||
|
owner: *mut c_void,
|
||||||
|
name: *const i8,
|
||||||
|
) {
|
||||||
|
let name = match unsafe { CStr::from_ptr(name) }.to_str() {
|
||||||
|
Ok(f) => f,
|
||||||
|
Err(err) => {
|
||||||
|
fn_logger.log(log::ErrorEntry::Utf8Error(err));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let saved_kernel = SavedKernel {
|
||||||
|
name: name.to_string(),
|
||||||
|
owner: SendablePtr(owner),
|
||||||
|
};
|
||||||
|
self.kernels.insert(func, saved_kernel);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,11 @@ cuda_types = { path = "../cuda_types" }
|
||||||
dark_api = { path = "../dark_api" }
|
dark_api = { path = "../dark_api" }
|
||||||
format = { path = "../format" }
|
format = { path = "../format" }
|
||||||
cglue = "0.3.5"
|
cglue = "0.3.5"
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
serde_json = "1.0.142"
|
||||||
|
tar = "0.4"
|
||||||
|
zstd = "0.13"
|
||||||
|
rustc-hash = "2.0.0"
|
||||||
|
|
||||||
[target.'cfg(not(windows))'.dependencies]
|
[target.'cfg(not(windows))'.dependencies]
|
||||||
libc = "0.2"
|
libc = "0.2"
|
||||||
|
|
|
@ -8,6 +8,8 @@ use cuda_types::{
|
||||||
use dark_api::ByteVecFfi;
|
use dark_api::ByteVecFfi;
|
||||||
use std::{borrow::Cow, ffi::c_void, num::NonZero, ptr, sync::LazyLock};
|
use std::{borrow::Cow, ffi::c_void, num::NonZero, ptr, sync::LazyLock};
|
||||||
|
|
||||||
|
pub mod replay;
|
||||||
|
|
||||||
pub fn get_export_table() -> Option<::dark_api::zluda_trace::ZludaTraceInternal> {
|
pub fn get_export_table() -> Option<::dark_api::zluda_trace::ZludaTraceInternal> {
|
||||||
static CU_GET_EXPORT_TABLE: LazyLock<
|
static CU_GET_EXPORT_TABLE: LazyLock<
|
||||||
Result<
|
Result<
|
||||||
|
|
137
zluda_trace_common/src/replay.rs
Normal file
137
zluda_trace_common/src/replay.rs
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
use rustc_hash::FxHashMap;
|
||||||
|
use std::io::{Read, Write};
|
||||||
|
use tar::Header;
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize)]
|
||||||
|
pub struct Manifest {
|
||||||
|
pub kernel_name: String,
|
||||||
|
pub outputs: bool,
|
||||||
|
pub config: LaunchConfig,
|
||||||
|
pub parameters: Vec<Parameter>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize)]
|
||||||
|
pub struct LaunchConfig {
|
||||||
|
pub grid_dim: (u32, u32, u32),
|
||||||
|
pub block_dim: (u32, u32, u32),
|
||||||
|
pub shared_mem_bytes: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize)]
|
||||||
|
pub struct Parameter {
|
||||||
|
pub pointer_offsets: Vec<ParameterPointer>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize)]
|
||||||
|
pub struct ParameterPointer {
|
||||||
|
pub offset_in_param: usize,
|
||||||
|
pub offset_in_buffer: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Manifest {
|
||||||
|
const PATH: &'static str = "manifest.json";
|
||||||
|
|
||||||
|
fn serialize(&self) -> std::io::Result<(Header, Vec<u8>)> {
|
||||||
|
let vec = serde_json::to_vec(self)?;
|
||||||
|
let header = tar_header(vec.len());
|
||||||
|
Ok((header, vec))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Contents of one kernel parameter handed to [`save`].
pub struct KernelParameter {
    /// Parameter bytes; written to the archive as `param_{i}.bin`.
    pub data: Vec<u8>,
    /// Tuples of `(offset_in_param, offset_in_buffer, data_before, data_after)`.
    ///
    /// [`save`] copies the two offsets into the manifest, writes
    /// `data_before` to `param_{i}_ptr_{offset_in_param}_pre.bin` and, only
    /// when outputs are requested, `data_after` to the matching `_post.bin`.
    pub device_ptrs: Vec<(usize, usize, Vec<u8>, Vec<u8>)>,
}
|
||||||
|
|
||||||
|
pub fn save(
|
||||||
|
writer: impl Write,
|
||||||
|
kernel_name: String,
|
||||||
|
has_outputs: bool,
|
||||||
|
config: LaunchConfig,
|
||||||
|
source: String,
|
||||||
|
kernel_params: Vec<KernelParameter>,
|
||||||
|
) -> std::io::Result<()> {
|
||||||
|
let archive = zstd::Encoder::new(writer, 0)?;
|
||||||
|
let mut builder = tar::Builder::new(archive);
|
||||||
|
let (mut header, manifest) = Manifest {
|
||||||
|
kernel_name,
|
||||||
|
outputs: has_outputs,
|
||||||
|
config,
|
||||||
|
parameters: kernel_params
|
||||||
|
.iter()
|
||||||
|
.map(|param| Parameter {
|
||||||
|
pointer_offsets: param
|
||||||
|
.device_ptrs
|
||||||
|
.iter()
|
||||||
|
.map(
|
||||||
|
|(offset_in_param, offset_in_buffer, _, _)| ParameterPointer {
|
||||||
|
offset_in_param: *offset_in_param,
|
||||||
|
offset_in_buffer: *offset_in_buffer,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.collect(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
.serialize()?;
|
||||||
|
builder.append_data(&mut header, Manifest::PATH, &*manifest)?;
|
||||||
|
let mut header = tar_header(source.len());
|
||||||
|
builder.append_data(&mut header, "source.ptx", source.as_bytes())?;
|
||||||
|
for (i, param) in kernel_params.into_iter().enumerate() {
|
||||||
|
let path = format!("param_{i}.bin");
|
||||||
|
let mut header = tar_header(param.data.len());
|
||||||
|
builder.append_data(&mut header, &*path, &*param.data)?;
|
||||||
|
for (offset_in_param, _, data_before, data_after) in param.device_ptrs {
|
||||||
|
let path = format!("param_{i}_ptr_{offset_in_param}_pre.bin");
|
||||||
|
let mut header = tar_header(data_before.len());
|
||||||
|
builder.append_data(&mut header, &*path, &*data_before)?;
|
||||||
|
if !has_outputs {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let path = format!("param_{i}_ptr_{offset_in_param}_post.bin");
|
||||||
|
let mut header = tar_header(data_after.len());
|
||||||
|
builder.append_data(&mut header, &*path, &*data_after)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
builder.finish()?;
|
||||||
|
builder.into_inner()?.finish()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tar_header(size: usize) -> Header {
|
||||||
|
let mut header = Header::new_gnu();
|
||||||
|
header.set_mode(0o644);
|
||||||
|
header.set_size(size as u64);
|
||||||
|
header
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load(reader: impl Read) -> (Manifest, String, FxHashMap<String, Vec<u8>>) {
|
||||||
|
let archive = zstd::Decoder::new(reader).unwrap();
|
||||||
|
let mut archive = tar::Archive::new(archive);
|
||||||
|
let mut manifest = None;
|
||||||
|
let mut source = None;
|
||||||
|
let mut buffers = FxHashMap::default();
|
||||||
|
for entry in archive.entries().unwrap() {
|
||||||
|
let mut entry = entry.unwrap();
|
||||||
|
let path = entry.path().unwrap().to_string_lossy().to_string();
|
||||||
|
match &*path {
|
||||||
|
Manifest::PATH => {
|
||||||
|
manifest = Some(serde_json::from_reader::<_, Manifest>(&mut entry).unwrap());
|
||||||
|
}
|
||||||
|
"source.ptx" => {
|
||||||
|
let mut string = String::new();
|
||||||
|
entry.read_to_string(&mut string).unwrap();
|
||||||
|
dbg!(string.len());
|
||||||
|
source = Some(string);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
entry.read_to_end(&mut buffer).unwrap();
|
||||||
|
buffers.insert(path, buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let manifest = manifest.unwrap();
|
||||||
|
let source = source.unwrap();
|
||||||
|
(manifest, source, buffers)
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue