Add replayer

This commit is contained in:
Andrzej Janik 2025-09-20 00:43:29 +00:00
commit 2b9c8946ec
9 changed files with 215 additions and 22 deletions

11
Cargo.lock generated
View file

@ -3826,6 +3826,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "zluda_replay"
version = "0.0.0"
dependencies = [
"cuda_macros",
"cuda_types",
"libloading",
"zluda_trace_common",
]
[[package]]
name = "zluda_sparse"
version = "0.0.0"
@ -3903,6 +3913,7 @@ dependencies = [
"format",
"libc",
"libloading",
"rustc-hash 2.0.0",
"serde",
"serde_json",
"tar",

View file

@ -37,6 +37,7 @@ members = [
"zluda_inject",
"zluda_ld",
"zluda_ml",
"zluda_replay",
"zluda_redirect",
"zluda_sparse",
"compiler",

View file

@ -370,7 +370,7 @@ pub fn parse_for_errors_and_params<'input>(
.func_directive
.input_arguments
.iter()
.map(|arg| arg.v_type.layout())
.map(|arg| arg.info.v_type.layout())
.collect();
Some((func.func_directive.name().to_string(), layouts))
} else {

17
zluda_replay/Cargo.toml Normal file
View file

@ -0,0 +1,17 @@
[package]
name = "zluda_replay"
version = "0.0.0"
authors = ["Andrzej Janik <vosen@vosen.pl>"]
edition = "2021"
[[bin]]
name = "zluda_replay"
[dependencies]
zluda_trace_common = { path = "../zluda_trace_common" }
cuda_macros = { path = "../cuda_macros" }
cuda_types = { path = "../cuda_types" }
libloading = "0.8"
[package.metadata.zluda]
debug_only = true

98
zluda_replay/src/main.rs Normal file
View file

@ -0,0 +1,98 @@
use std::mem;
use cuda_types::cuda::{CUdeviceptr_v2, CUstream};
/// Handle to a dynamically loaded library whose exported functions are looked
/// up by name through `handle`.
struct CudaDynamicFns {
    /// The loaded library that symbols are resolved from.
    handle: libloading::Library,
}
impl CudaDynamicFns {
    /// Opens the dynamic library located at `path` and wraps it.
    ///
    /// # Errors
    /// Returns the `libloading::Error` reported when the library cannot be loaded.
    ///
    /// # Safety
    /// This forwards to `libloading::Library::new`, which is itself `unsafe`;
    /// the caller must uphold that function's requirements for the library
    /// being loaded.
    unsafe fn new(path: &str) -> Result<Self, libloading::Error> {
        libloading::Library::new(path).map(|handle| Self { handle })
    }
}
/// Expands a list of `"abi" fn name(args...) -> ret;` declarations into one
/// `unsafe` method per declaration on [`CudaDynamicFns`].
///
/// Each generated method looks the symbol up in the loaded library by its
/// NUL-terminated name and forwards all arguments to it.
///
/// # Panics
/// A generated method panics, naming the function, when the library does not
/// export the requested symbol.
macro_rules! emit_cuda_fn_table {
    ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
        impl CudaDynamicFns {
            $(
                // A wrapper is generated for every declaration handed to the
                // macro, but this binary calls only a few of them.
                #[allow(dead_code)]
                unsafe fn $fn_name(&self, $($arg_id : $arg_type),*) -> $ret_type {
                    let func = self
                        .handle
                        .get::<unsafe extern $abi fn ($($arg_type),*) -> $ret_type>(
                            concat!(stringify!($fn_name), "\0").as_bytes(),
                        )
                        // Name the symbol explicitly: the underlying loader error
                        // does not always say which lookup failed.
                        .expect(concat!(
                            "failed to resolve symbol `",
                            stringify!($fn_name),
                            "` in the loaded library"
                        ));
                    (func)($($arg_id),*)
                }
            )*
        }
    };
}

// Feeds the function declarations provided by `cuda_macros` into the macro above.
cuda_macros::cuda_function_declarations!(emit_cuda_fn_table);
fn main() {
let args: Vec<String> = std::env::args().collect();
let libcuda = unsafe { CudaDynamicFns::new(&args[1]).unwrap() };
unsafe { libcuda.cuInit(0) }.unwrap();
unsafe { libcuda.cuCtxCreate_v2(&mut mem::zeroed(), 0, 0) }.unwrap();
let reader = std::fs::File::open(&args[2]).unwrap();
let (mut manifest, mut source, mut buffers) = zluda_trace_common::replay::load(reader);
let mut args = manifest
.parameters
.iter()
.enumerate()
.map(|(i, param)| {
let mut buffer = buffers.remove(&format!("param_{i}.bin")).unwrap();
for param_ptr in param.pointer_offsets.iter() {
let buffer_param_slice = &mut buffer[param_ptr.offset_in_param
..param_ptr.offset_in_param + std::mem::size_of::<usize>()];
let mut dev_ptr = unsafe { mem::zeroed() };
let host_buffer = buffers
.remove(&format!(
"param_{i}_ptr_{}_pre.bin",
param_ptr.offset_in_param
))
.unwrap();
unsafe { libcuda.cuMemAlloc_v2(&mut dev_ptr, host_buffer.len()) }.unwrap();
unsafe {
libcuda.cuMemcpyHtoD_v2(dev_ptr, host_buffer.as_ptr().cast(), host_buffer.len())
}
.unwrap();
dev_ptr = CUdeviceptr_v2(unsafe {
dev_ptr
.0
.cast::<u8>()
.add(param_ptr.offset_in_buffer)
.cast()
});
buffer_param_slice.copy_from_slice(&(dev_ptr.0 as usize).to_ne_bytes());
}
})
.collect::<Vec<_>>();
let mut module = unsafe { mem::zeroed() };
std::fs::write("/tmp/source.ptx", &source).unwrap();
source.push('\0');
unsafe { libcuda.cuModuleLoadData(&mut module, source.as_ptr().cast()) }.unwrap();
let mut function = unsafe { mem::zeroed() };
manifest.kernel_name.push('\0');
unsafe {
libcuda.cuModuleGetFunction(&mut function, module, manifest.kernel_name.as_ptr().cast())
}
.unwrap();
unsafe {
libcuda.cuLaunchKernel(
function,
manifest.config.grid_dim.0,
manifest.config.grid_dim.1,
manifest.config.grid_dim.2,
manifest.config.block_dim.0,
manifest.config.block_dim.1,
manifest.config.block_dim.2,
manifest.config.shared_mem_bytes,
CUstream(std::ptr::null_mut()),
args.as_mut_ptr().cast(),
std::ptr::null_mut(),
)
}
.unwrap();
todo!();
}

View file

@ -1552,14 +1552,14 @@ fn launch_kernel_pre(
#[allow(non_snake_case)]
pub(crate) fn cuLaunchKernel_Post(
_f: cuda_types::cuda::CUfunction,
_gridDimX: ::core::ffi::c_uint,
_gridDimY: ::core::ffi::c_uint,
_gridDimZ: ::core::ffi::c_uint,
_blockDimX: ::core::ffi::c_uint,
_blockDimY: ::core::ffi::c_uint,
_blockDimZ: ::core::ffi::c_uint,
_sharedMemBytes: ::core::ffi::c_uint,
stream: cuda_types::cuda::CUstream,
gridDimX: ::core::ffi::c_uint,
gridDimY: ::core::ffi::c_uint,
gridDimZ: ::core::ffi::c_uint,
blockDimX: ::core::ffi::c_uint,
blockDimY: ::core::ffi::c_uint,
blockDimZ: ::core::ffi::c_uint,
sharedMemBytes: ::core::ffi::c_uint,
hStream: cuda_types::cuda::CUstream,
kernel_params: *mut *mut ::core::ffi::c_void,
_extra: *mut *mut ::core::ffi::c_void,
pre_state: Option<replay::LaunchPreState>,
@ -1569,7 +1569,25 @@ pub(crate) fn cuLaunchKernel_Post(
_result: CUresult,
) {
let pre_state = unwrap_some_or!(pre_state, return);
replay::post_kernel_launch(libcuda, state, fn_logger, stream, kernel_params, pre_state);
replay::post_kernel_launch(
libcuda,
state,
fn_logger,
CUlaunchConfig {
gridDimX,
gridDimY,
gridDimZ,
blockDimX,
blockDimY,
blockDimZ,
sharedMemBytes,
hStream,
attrs: ptr::null_mut(),
numAttrs: 0,
},
kernel_params,
pre_state,
);
}
#[allow(non_snake_case)]
@ -1609,7 +1627,7 @@ pub(crate) fn cuLaunchKernelEx_Post(
libcuda,
state,
fn_logger,
unsafe { *config }.hStream,
unsafe { *config },
kernel_params,
pre_state,
);

View file

@ -97,11 +97,11 @@ pub(crate) fn post_kernel_launch(
libcuda: &mut CudaDynamicFns,
state: &trace::StateTracker,
fn_logger: &mut FnCallLog,
stream: CUstream,
config: CUlaunchConfig,
kernel_params: *mut *mut std::ffi::c_void,
mut pre_state: LaunchPreState,
) -> Option<()> {
fn_logger.try_cuda(|| libcuda.cuStreamSynchronize(stream))?;
fn_logger.try_cuda(|| libcuda.cuStreamSynchronize(config.hStream))?;
let raw_args =
unsafe { std::slice::from_raw_parts(kernel_params, pre_state.kernel_params.len()) };
for (raw_arg, param) in raw_args.iter().zip(pre_state.kernel_params.iter_mut()) {
@ -128,6 +128,11 @@ pub(crate) fn post_kernel_launch(
zluda_trace_common::replay::save(
file,
pre_state.kernel_name,
zluda_trace_common::replay::LaunchConfig {
grid_dim: (config.gridDimX, config.gridDimY, config.gridDimZ),
block_dim: (config.blockDimX, config.blockDimY, config.blockDimZ),
shared_mem_bytes: config.sharedMemBytes,
},
pre_state.source,
pre_state.kernel_params,
)

View file

@ -15,6 +15,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.142"
tar = "0.4"
zstd = "0.13"
rustc-hash = "2.0.0"
[target.'cfg(not(windows))'.dependencies]
libc = "0.2"

View file

@ -1,21 +1,30 @@
use std::io::Write;
use rustc_hash::FxHashMap;
use std::io::{Read, Write};
use tar::Header;
#[derive(serde::Serialize, serde::Deserialize)]
struct Manifest {
kernel_name: String,
parameters: Vec<Parameter>,
pub struct Manifest {
pub kernel_name: String,
pub config: LaunchConfig,
pub parameters: Vec<Parameter>,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct Parameter {
pointer_offsets: Vec<ParameterPointer>,
pub struct LaunchConfig {
pub grid_dim: (u32, u32, u32),
pub block_dim: (u32, u32, u32),
pub shared_mem_bytes: u32,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct ParameterPointer {
offset_in_param: usize,
offset_in_buffer: usize,
pub struct Parameter {
pub pointer_offsets: Vec<ParameterPointer>,
}
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ParameterPointer {
pub offset_in_param: usize,
pub offset_in_buffer: usize,
}
impl Manifest {
@ -37,6 +46,7 @@ pub struct KernelParameter {
pub fn save(
writer: impl Write,
kernel_name: String,
config: LaunchConfig,
source: String,
kernel_params: Vec<KernelParameter>,
) -> std::io::Result<()> {
@ -44,6 +54,7 @@ pub fn save(
let mut builder = tar::Builder::new(archive);
let (mut header, manifest) = Manifest {
kernel_name,
config,
parameters: kernel_params
.iter()
.map(|param| Parameter {
@ -85,3 +96,34 @@ pub fn save(
builder.into_inner()?.finish()?;
Ok(())
}
/// Reads a replay archive (a zstd-compressed tar stream) from `reader`.
///
/// Returns a tuple of:
/// * the deserialized [`Manifest`] (the entry stored at `Manifest::PATH`),
/// * the contents of the `source.ptx` entry,
/// * every other entry, keyed by its path inside the archive, as raw bytes.
///
/// # Panics
/// Panics when the stream is not a valid zstd/tar archive, when an entry
/// cannot be read or decoded, or when the manifest or `source.ptx` entry is
/// missing.
pub fn load(reader: impl Read) -> (Manifest, String, FxHashMap<String, Vec<u8>>) {
    let decoder = zstd::Decoder::new(reader).expect("failed to initialize zstd decoder");
    let mut archive = tar::Archive::new(decoder);
    let mut manifest = None;
    let mut source = None;
    let mut buffers = FxHashMap::default();
    for entry in archive
        .entries()
        .expect("failed to read replay archive entries")
    {
        let mut entry = entry.expect("corrupted replay archive entry");
        let path = entry
            .path()
            .expect("replay archive entry has an invalid path")
            .to_string_lossy()
            .into_owned();
        match path.as_str() {
            Manifest::PATH => {
                manifest = Some(
                    serde_json::from_reader::<_, Manifest>(&mut entry)
                        .expect("failed to parse replay manifest"),
                );
            }
            "source.ptx" => {
                let mut string = String::new();
                entry
                    .read_to_string(&mut string)
                    .expect("failed to read source.ptx from replay archive");
                source = Some(string);
            }
            _ => {
                let mut buffer = Vec::new();
                entry
                    .read_to_end(&mut buffer)
                    .expect("failed to read buffer from replay archive");
                buffers.insert(path, buffer);
            }
        }
    }
    let manifest = manifest.expect("replay archive does not contain a manifest");
    let source = source.expect("replay archive does not contain source.ptx");
    (manifest, source, buffers)
}