From 62ec652e7c2fe8a8493d3e5e56a7716bf508e523 Mon Sep 17 00:00:00 2001 From: Violet Date: Thu, 18 Sep 2025 19:11:30 -0700 Subject: [PATCH 1/4] Disable virtual memory management (#515) We don't currently support it, so report it as unsupported. --- zluda/src/impl/device.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index ed8bb8c..6e63f4b 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -254,6 +254,10 @@ pub(crate) fn get_attribute( CUdevice_attribute::CU_DEVICE_ATTRIBUTE_UNIFIED_FUNCTION_POINTERS => { return get_device_prop(pi, dev_idx, |props| props.unifiedFunctionPointers) } + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED => { + *pi = 0; + return Ok(()); + } _ => {} } let attrib = remap_attribute! { From 875ac13be285e68d0ffdf7d5910c9736aed212be Mon Sep 17 00:00:00 2001 From: Violet Date: Fri, 19 Sep 2025 13:36:48 -0700 Subject: [PATCH 2/4] Support lists of variables to be declared (#516) For example, ``` .reg .u32 a, b; ``` --- ptx/src/pass/deparamize_functions.rs | 40 +++--- ptx/src/pass/hoist_globals.rs | 18 ++- ptx/src/pass/insert_explicit_load_store.rs | 18 +-- .../instruction_mode_to_global_mode/test.rs | 2 +- ptx/src/pass/llvm/emit.rs | 31 ++--- ptx/src/pass/mod.rs | 38 ++--- ...entifiers2.rs => normalize_identifiers.rs} | 57 ++++---- .../replace_instructions_with_functions.rs | 10 +- ...instructions_with_functions_fp_required.rs | 130 +++++++++++------- ptx/src/pass/test/expand_operands/mod.rs | 2 +- .../test/insert_implicit_conversions/mod.rs | 2 +- ptx/src/test/ll/reg_multi.ll | 37 +++++ ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/test/spirv_run/reg_multi.ptx | 22 +++ ptx_parser/src/ast.rs | 34 +++-- ptx_parser/src/lib.rs | 60 ++++---- 16 files changed, 315 insertions(+), 187 deletions(-) rename ptx/src/pass/{normalize_identifiers2.rs => normalize_identifiers.rs} (81%) create mode 100644 ptx/src/test/ll/reg_multi.ll create mode 100644 ptx/src/test/spirv_run/reg_multi.ptx diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs index e80f6a3..2145b89 100644 --- a/ptx/src/pass/deparamize_functions.rs +++ b/ptx/src/pass/deparamize_functions.rs @@ -29,22 +29,24 @@ fn run_method<'input>( let mut remap_returns = Vec::new(); if !method.is_kernel { for arg in method.return_arguments.iter_mut() { - match arg.state_space { + match arg.info.state_space { ptx_parser::StateSpace::Param => { - arg.state_space = ptx_parser::StateSpace::Reg; + arg.info.state_space = ptx_parser::StateSpace::Reg; let old_name = arg.name; - arg.name = - resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); + arg.name = resolver + .register_unnamed(Some((arg.info.v_type.clone(), arg.info.state_space))); if is_declaration { continue; } - remap_returns.push((old_name, arg.name, arg.v_type.clone())); + remap_returns.push((old_name, arg.name, arg.info.v_type.clone())); body.push(Statement::Variable(ast::Variable { - align: None, + info: ast::VariableInfo { + align: None, + v_type: arg.info.v_type.clone(), + state_space: ptx_parser::StateSpace::Param, + array_init: Vec::new(), + }, name: old_name, - v_type: arg.v_type.clone(), - state_space: ptx_parser::StateSpace::Param, - array_init: Vec::new(), })); } ptx_parser::StateSpace::Reg => {} @@ -52,28 +54,30 @@ fn run_method<'input>( } } for arg in method.input_arguments.iter_mut() { - match arg.state_space { + match arg.info.state_space { ptx_parser::StateSpace::Param => { - arg.state_space = ptx_parser::StateSpace::Reg; + arg.info.state_space = ptx_parser::StateSpace::Reg; let old_name = arg.name; - arg.name = - resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); + arg.name = resolver + .register_unnamed(Some((arg.info.v_type.clone(), arg.info.state_space))); if is_declaration { continue; } body.push(Statement::Variable(ast::Variable { - align: None, + info: ast::VariableInfo { + align: None, + v_type: arg.info.v_type.clone(), + state_space: ptx_parser::StateSpace::Param, + array_init: Vec::new(), + }, name: old_name, - v_type: arg.v_type.clone(), - state_space: ptx_parser::StateSpace::Param, - array_init: Vec::new(), })); body.push(Statement::Instruction(ast::Instruction::St { data: ast::StData { qualifier: ast::LdStQualifier::Weak, state_space: ast::StateSpace::Param, caching: ast::StCacheOperator::Writethrough, - typ: arg.v_type.clone(), + typ: arg.info.v_type.clone(), }, arguments: ast::StArgs { src1: old_name, diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs index dfc88c2..097dff1 100644 --- a/ptx/src/pass/hoist_globals.rs +++ b/ptx/src/pass/hoist_globals.rs @@ -30,11 +30,19 @@ fn run_function<'input>( statements .into_iter() .filter_map(|statement| match statement { - Statement::Variable(var @ ast::Variable { - state_space: - ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared, - .. - }) => { + Statement::Variable( + var @ ast::Variable { + info: + ast::VariableInfo { + state_space: + ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::Shared, + .. + }, + .. + }, + ) => { result.push(Directive2::Variable(ast::LinkingDirective::NONE, var)); None } diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index 3350a82..696bb70 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -40,14 +40,14 @@ fn run_method<'a, 'input>( if is_kernel { for arg in method.input_arguments.iter_mut() { let old_name = arg.name; - let old_space = arg.state_space; + let old_space = arg.info.state_space; let new_space = ast::StateSpace::ParamEntry; let new_name = visitor .resolver - .register_unnamed(Some((arg.v_type.clone(), new_space))); + .register_unnamed(Some((arg.info.v_type.clone(), new_space))); visitor.input_argument(old_name, new_name, old_space)?; arg.name = new_name; - arg.state_space = new_space; + arg.info.state_space = new_space; } }; for arg in method.return_arguments.iter_mut() { @@ -83,10 +83,10 @@ fn run_statement<'a, 'input>( return_arguments .iter() .map(|arg| { - if arg.state_space != ast::StateSpace::Local { + if arg.info.state_space != ast::StateSpace::Local { return Err(error_unreachable()); } - Ok((arg.name, arg.v_type.clone())) + Ok((arg.name, arg.info.v_type.clone())) }) .collect::, _>>()?, ) @@ -332,7 +332,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { } fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { - let old_space = match var.state_space { + let old_space = match var.info.state_space { space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, // Do nothing ptx_parser::StateSpace::Local => return Ok(()), @@ -350,10 +350,10 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { let new_space = ast::StateSpace::Local; let new_name = self .resolver - .register_unnamed(Some((var.v_type.clone(), new_space))); - self.variable(&var.v_type, old_name, new_name, old_space)?; + .register_unnamed(Some((var.info.v_type.clone(), new_space))); + self.variable(&var.info.v_type, old_name, new_name, old_space)?; var.name = new_name; - var.state_space = new_space; + var.info.state_space = new_space; Ok(()) } } diff --git a/ptx/src/pass/instruction_mode_to_global_mode/test.rs b/ptx/src/pass/instruction_mode_to_global_mode/test.rs index 78d1d66..94fe735 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/test.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/test.rs @@ -195,7 +195,7 @@ fn compile_methods(ptx: &str) -> Vec, Spir let module = ptx_parser::parse_module_checked(ptx).unwrap(); let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); - let directives = normalize_identifiers2::run(&mut scoped_resolver, module.directives).unwrap(); + let directives = normalize_identifiers::run(&mut scoped_resolver, module.directives).unwrap(); let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); let directives = normalize_basic_blocks::run(&mut flat_resolver, directives).unwrap(); diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index 144f5e6..76717e1 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -122,11 +122,10 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { if fn_ == ptr::null_mut() { let fn_type = get_function_type( self.context, - method.return_arguments.iter().map(|v| &v.v_type), - method - .input_arguments - .iter() - .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)), + method.return_arguments.iter().map(|v| &v.info.v_type), + method.input_arguments.iter().map(|v| { + get_input_argument_type(self.context, &v.info.v_type, v.info.state_space) + }), )?; fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; self.emit_fn_attribute(fn_, "amdgpu-unsafe-fp-atomics", "true"); @@ -153,7 +152,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { for (i, param) in method.input_arguments.iter().enumerate() { let value = unsafe { LLVMGetParam(fn_, i as u32) }; let name = self.resolver.get_or_add(param.name); - if let Some(align) = param.align { + if let Some(align) = param.info.align { unsafe { LLVMSetParamAlignment(value, align) }; } unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; @@ -166,7 +165,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { LLVMCreateTypeAttribute( self.context, attr_kind, - get_type(self.context, ¶m.v_type)?, + get_type(self.context, ¶m.info.v_type)?, ) }; unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) }; @@ -241,17 +240,17 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { let global = unsafe { LLVMAddGlobalInAddressSpace( self.module, - get_type(self.context, &var.v_type)?, + get_type(self.context, &var.info.v_type)?, name.as_ptr(), - get_state_space(var.state_space)?, + get_state_space(var.info.state_space)?, ) }; self.resolver.register(var.name, global); - if let Some(align) = var.align { + if let Some(align) = var.info.align { unsafe { LLVMSetAlignment(global, align) }; } - if !var.array_init.is_empty() { - let initializer = self.get_array_init(&var.v_type, &*var.array_init)?; + if !var.info.array_init.is_empty() { + let initializer = self.get_array_init(&var.info.v_type, &*var.info.array_init)?; unsafe { LLVMSetInitializer(global, initializer) }; } Ok(()) @@ -422,16 +421,16 @@ impl<'a> MethodEmitContext<'a> { let alloca = unsafe { LLVMZludaBuildAlloca( self.variables_builder.get(), - get_type(self.context, &var.v_type)?, - get_state_space(var.state_space)?, + get_type(self.context, &var.info.v_type)?, + get_state_space(var.info.state_space)?, self.resolver.get_or_add_raw(var.name), ) }; self.resolver.register(var.name, alloca); - if let Some(align) = var.align { + if let Some(align) = var.info.align { unsafe { LLVMSetAlignment(alloca, align) }; } - if !var.array_init.is_empty() { + if !var.info.array_init.is_empty() { return Err(error_unreachable()); } Ok(()) diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 4f87dc3..b14903d 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -21,7 +21,7 @@ mod insert_post_saturation; mod instruction_mode_to_global_mode; pub mod llvm; mod normalize_basic_blocks; -mod normalize_identifiers2; +mod normalize_identifiers; mod normalize_predicates2; mod remove_unreachable_basic_blocks; mod replace_instructions_with_functions; @@ -65,8 +65,8 @@ pub fn to_llvm_module<'input>( let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?; - let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; - on_pass_end("normalize_identifiers2"); + let directives = normalize_identifiers::run(&mut scoped_resolver, ast.directives)?; + on_pass_end("normalize_identifiers"); let directives = replace_known_functions::run(&mut flat_resolver, directives); on_pass_end("replace_known_functions"); let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; @@ -308,16 +308,18 @@ impl> Statement, T> { Statement::Variable(var) => { let name = visitor.visit_ident( var.name, - Some((&var.v_type, var.state_space)), + Some((&var.info.v_type, var.info.state_space)), true, false, )?; Statement::Variable(ast::Variable { - align: var.align, - v_type: var.v_type, - state_space: var.state_space, + info: ast::VariableInfo { + align: var.info.align, + v_type: var.info.v_type, + state_space: var.info.state_space, + array_init: var.info.array_init, + }, name, - array_init: var.array_init, }) } Statement::Conditional(conditional) => { @@ -978,20 +980,24 @@ impl SpecialRegistersMap { let return_type = sreg.get_function_return_type(); let input_type = sreg.get_function_input_type(); let return_arguments = vec![ast::Variable { - align: None, - v_type: return_type.into(), - state_space: ast::StateSpace::Reg, + info: ast::VariableInfo { + align: None, + v_type: return_type.into(), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, name: resolver.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))), - array_init: Vec::new(), }]; let input_arguments = input_type .into_iter() .map(|type_| ast::Variable { - align: None, - v_type: type_.into(), - state_space: ast::StateSpace::Reg, + info: ast::VariableInfo { + align: None, + v_type: type_.into(), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, name: resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))), - array_init: Vec::new(), }) .collect::>(); fn_(sreg, (return_arguments, name, input_arguments)); diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers.rs similarity index 81% rename from ptx/src/pass/normalize_identifiers2.rs rename to ptx/src/pass/normalize_identifiers.rs index 901f628..abe2bdf 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers.rs @@ -80,19 +80,28 @@ fn run_function_decl<'input, 'b>( Ok((return_arguments, input_arguments)) } +fn run_variable_info<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + info: ast::VariableInfo<&'input str>, +) -> Result, TranslateError> { + Ok(ast::VariableInfo { + align: info.align, + v_type: info.v_type, + state_space: info.state_space, + array_init: run_array_init(resolver, &info.array_init)?, + }) +} + fn run_variable<'input, 'b>( resolver: &mut ScopedResolver<'input, 'b>, variable: ast::Variable<&'input str>, ) -> Result, TranslateError> { Ok(ast::Variable { + info: run_variable_info(resolver, variable.info.clone())?, name: resolver.add( Cow::Borrowed(variable.name), - Some((variable.v_type.clone(), variable.state_space)), + Some((variable.info.v_type.clone(), variable.info.state_space)), )?, - align: variable.align, - v_type: variable.v_type, - state_space: variable.state_space, - array_init: run_array_init(resolver, &variable.array_init)?, }) } @@ -158,36 +167,26 @@ fn run_multivariable<'input, 'b>( result: &mut Vec, variable: ast::MultiVariable<&'input str>, ) -> Result<(), TranslateError> { - match variable.count { - Some(count) => { + match variable { + ptx_parser::MultiVariable::Parameterized { info, name, count } => { for i in 0..count { - let name = Cow::Owned(format!("{}{}", variable.var.name, i)); - let ident = resolver.add( - name, - Some((variable.var.v_type.clone(), variable.var.state_space)), - )?; + let name = Cow::Owned(format!("{}{}", name, i)); + let ident = resolver.add(name, Some((info.v_type.clone(), info.state_space)))?; result.push(Statement::Variable(ast::Variable { - align: variable.var.align, - v_type: variable.var.v_type.clone(), - state_space: variable.var.state_space, + info: run_variable_info(resolver, info.clone())?, name: ident, - array_init: run_array_init(resolver, &variable.var.array_init)?, })); } } - None => { - let name = Cow::Borrowed(variable.var.name); - let ident = resolver.add( - name, - Some((variable.var.v_type.clone(), variable.var.state_space)), - )?; - result.push(Statement::Variable(ast::Variable { - align: variable.var.align, - v_type: variable.var.v_type.clone(), - state_space: variable.var.state_space, - name: ident, - array_init: run_array_init(resolver, &variable.var.array_init)?, - })); + ptx_parser::MultiVariable::Names { info, names } => { + for name in names { + let name = Cow::Borrowed(name); + let ident = resolver.add(name, Some((info.v_type.clone(), info.state_space)))?; + result.push(Statement::Variable(ast::Variable { + info: run_variable_info(resolver, info.clone())?, + name: ident, + })); + } } } Ok(()) diff --git a/ptx/src/pass/replace_instructions_with_functions.rs b/ptx/src/pass/replace_instructions_with_functions.rs index f7c976e..a68008f 100644 --- a/ptx/src/pass/replace_instructions_with_functions.rs +++ b/ptx/src/pass/replace_instructions_with_functions.rs @@ -580,11 +580,13 @@ fn to_variables<'input>( arguments .iter() .map(|(type_, space)| ast::Variable { - align: None, - v_type: type_.clone(), - state_space: *space, + info: ast::VariableInfo { + align: None, + v_type: type_.clone(), + state_space: *space, + array_init: Vec::new(), + }, name: resolver.register_unnamed(Some((type_.clone(), *space))), - array_init: Vec::new(), }) .collect::>() } diff --git a/ptx/src/pass/replace_instructions_with_functions_fp_required.rs b/ptx/src/pass/replace_instructions_with_functions_fp_required.rs index bf2d690..211d5b6 100644 --- a/ptx/src/pass/replace_instructions_with_functions_fp_required.rs +++ b/ptx/src/pass/replace_instructions_with_functions_fp_required.rs @@ -33,40 +33,48 @@ pub(crate) fn run<'input>( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::U8), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::U8), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::U8), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ], name: imports.part1, @@ -76,20 +84,24 @@ pub(crate) fn run<'input>( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ], body: None, @@ -108,10 +120,12 @@ pub(crate) fn run<'input>( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }], name: imports.part2, input_arguments: vec![ @@ -120,60 +134,72 @@ pub(crate) fn run<'input>( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::U8), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::U8), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::U8), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ], body: None, diff --git a/ptx/src/pass/test/expand_operands/mod.rs b/ptx/src/pass/test/expand_operands/mod.rs index 20efae8..3d83651 100644 --- a/ptx/src/pass/test/expand_operands/mod.rs +++ b/ptx/src/pass/test/expand_operands/mod.rs @@ -12,7 +12,7 @@ fn run_expand_operands(ptx: ptx_parser::Module) -> String { // We run the minimal number of passes required to produce the input expected by expand_operands let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); - let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap(); + let directives = normalize_identifiers::run(&mut scoped_resolver, ptx.directives).unwrap(); let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); directive2_vec_to_string(&flat_resolver, directives) diff --git a/ptx/src/pass/test/insert_implicit_conversions/mod.rs b/ptx/src/pass/test/insert_implicit_conversions/mod.rs index 1b5fffb..429d943 100644 --- a/ptx/src/pass/test/insert_implicit_conversions/mod.rs +++ b/ptx/src/pass/test/insert_implicit_conversions/mod.rs @@ -12,7 +12,7 @@ fn run_insert_implicit_conversions(ptx: ptx_parser::Module) -> String { // We run the minimal number of passes required to produce the input expected by insert_implicit_conversions let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); - let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap(); + let directives = normalize_identifiers::run(&mut scoped_resolver, ptx.directives).unwrap(); let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives).unwrap(); diff --git a/ptx/src/test/ll/reg_multi.ll b/ptx/src/test/ll/reg_multi.ll new file mode 100644 index 0000000..62f0459 --- /dev/null +++ b/ptx/src/test/ll/reg_multi.ll @@ -0,0 +1,37 @@ +define amdgpu_kernel void @reg_multi(ptr addrspace(4) byref(i64) %"38", ptr addrspace(4) byref(i64) %"39") #0 { + %"40" = alloca i64, align 8, addrspace(5) + %"41" = alloca i64, align 8, addrspace(5) + %"42" = alloca i32, align 4, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"37" + +"37": ; preds = %1 + %"44" = load i64, ptr addrspace(4) %"38", align 8 + store i64 %"44", ptr addrspace(5) %"40", align 8 + %"45" = load i64, ptr addrspace(4) %"39", align 8 + store i64 %"45", ptr addrspace(5) %"41", align 8 + %"47" = load i64, ptr addrspace(5) %"40", align 8 + %"54" = inttoptr i64 %"47" to ptr + %"46" = load i32, ptr %"54", align 4 + store i32 %"46", ptr addrspace(5) %"42", align 4 + %"48" = load i64, ptr addrspace(5) %"40", align 8 + %"55" = inttoptr i64 %"48" to ptr + %"34" = getelementptr inbounds i8, ptr %"55", i64 4 + %"49" = load i32, ptr %"34", align 4 + store i32 %"49", ptr addrspace(5) %"43", align 4 + %"50" = load i64, ptr addrspace(5) %"41", align 8 + %"51" = load i32, ptr addrspace(5) %"42", align 4 + %"56" = inttoptr i64 %"50" to ptr + store i32 %"51", ptr %"56", align 4 + %"52" = load i64, ptr addrspace(5) %"41", align 8 + %"57" = inttoptr i64 %"52" to ptr + %"36" = getelementptr inbounds i8, ptr %"57", i64 4 + %"53" = load i32, ptr addrspace(5) %"43", align 4 + store i32 %"53", ptr %"36", align 4 + ret void +} + +attributes #0 = { "amdgpu-ieee"="false" "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 46bdd0b..bd5d900 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -124,6 +124,7 @@ test_ptx!(vector4, [1u32, 2u32, 3u32, 4u32], [4u32]); test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]); test_ptx!(ntid, [3u32], [4u32]); test_ptx!(reg_local, [12u64], [13u64]); +test_ptx!(reg_multi, [123u32, 456u32], [123u32, 456u32]); test_ptx!(mov_address, [0xDEADu64], [0u64]); test_ptx!(b64tof64, [111u64], [111u64]); // This segfaults NV compiler diff --git a/ptx/src/test/spirv_run/reg_multi.ptx b/ptx/src/test/spirv_run/reg_multi.ptx new file mode 100644 index 0000000..99f234b --- /dev/null +++ b/ptx/src/test/spirv_run/reg_multi.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry reg_multi( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 a, b; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 a, [in_addr]; + ld.u32 b, [in_addr+4]; + st.u32 [out_addr], a; + st.u32 [out_addr+4], b; + ret; +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 1bc622c..84d5f57 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -998,29 +998,41 @@ impl MapOperand for Option { } } -pub struct MultiVariable { - pub var: Variable, - pub count: Option, +pub enum MultiVariable { + Parameterized { + info: VariableInfo, + name: ID, + count: u32, + }, + Names { + info: VariableInfo, + names: Vec, + }, +} + +#[derive(Clone)] +pub struct VariableInfo { + pub align: Option, + pub v_type: Type, + pub state_space: StateSpace, + pub array_init: Vec>, } #[derive(Clone)] pub struct Variable { - pub align: Option, - pub v_type: Type, - pub state_space: StateSpace, + pub info: VariableInfo, pub name: ID, - pub array_init: Vec>, } impl std::fmt::Display for Variable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.state_space)?; + write!(f, "{}", self.info.state_space)?; - if let Some(align) = self.align { + if let Some(align) = self.info.align { write!(f, " .align {}", align)?; } - let (vector_size, scalar_type, array_dims) = match &self.v_type { + let (vector_size, scalar_type, array_dims) = match &self.info.v_type { Type::Scalar(scalar_type) => (None, *scalar_type, &vec![]), Type::Vector(size, scalar_type) => (Some(*size), *scalar_type, &vec![]), Type::Array(vector_size, scalar_type, array_dims) => { @@ -1038,7 +1050,7 @@ impl std::fmt::Display for Variable { write!(f, "[{}]", dim)?; } - if self.array_init.len() > 0 { + if self.info.array_init.len() > 0 { todo!("Need to interpret the array initializer data as the appropriate type"); } diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index a4f9080..2c9003b 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -135,7 +135,7 @@ impl<'a, 'input> PtxParserState<'a, 'input> { fn get_type_space(input_arguments: &[Variable<&str>]) -> Vec<(Type, StateSpace)> { input_arguments .iter() - .map(|var| (var.v_type.clone(), var.state_space)) + .map(|var| (var.info.v_type.clone(), var.info.state_space)) .collect::>() } } @@ -552,7 +552,13 @@ fn module_variable<'a, 'input>( let var = global_space .flat_map(|space| multi_variable(linking.contains(LinkingDirective::EXTERN), space)) // TODO: support multi var in globals - .map(|multi_var| multi_var.var) + .verify_map(|multi_var| match multi_var { + MultiVariable::Names { info, names } if names.len() == 1 => Some(ast::Variable { + info, + name: names[0], + }), + _ => None, + }) .parse_next(stream)?; Ok((linking, var)) } @@ -886,7 +892,7 @@ fn method_parameter<'a, 'input: 'a>( ) -> impl Parser, Variable<&'input str>, ContextError> { fn nvptx_kernel_declaration<'a, 'input>( stream: &mut PtxParser<'a, 'input>, - ) -> PResult<(Option, Option, ScalarType, &'input str)> { + ) -> PResult<((Option, Option, ScalarType), &'input str)> { trace( "nvptx_kernel_declaration", ( @@ -897,15 +903,15 @@ fn method_parameter<'a, 'input: 'a>( ident, ), ) - .map(|(vector, type_, _, align, name)| (align, vector, type_, name)) + .map(|(vector, type_, _, align, name)| ((align, vector, type_), name)) .parse_next(stream) } trace( "method_parameter", move |stream: &mut PtxParser<'a, 'input>| { if kernel_decl_rules {} - let (align, vector, type_, name) = - alt((variable_declaration, nvptx_kernel_declaration)).parse_next(stream)?; + let ((align, vector, type_), name) = + alt(((variable_info, ident), nvptx_kernel_declaration)).parse_next(stream)?; let array_dimensions = if state_space != StateSpace::Reg { opt(array_dimensions).parse_next(stream)? } else { @@ -918,27 +924,28 @@ fn method_parameter<'a, 'input: 'a>( } } Ok(Variable { - align, - v_type: Type::maybe_array(vector, type_, array_dimensions), - state_space, + info: VariableInfo { + align, + v_type: Type::maybe_array(vector, type_, array_dimensions), + state_space, + array_init: Vec::new(), + }, name, - array_init: Vec::new(), }) }, ) } // TODO: split to a separate type -fn variable_declaration<'a, 'input>( +fn variable_info<'a, 'input>( stream: &mut PtxParser<'a, 'input>, -) -> PResult<(Option, Option, ScalarType, &'input str)> { +) -> PResult<(Option, Option, ScalarType)> { trace( - "variable_declaration", + "variable_info", ( opt(align.verify(|x| x.count_ones() == 1)), vector_prefix, scalar_type, - ident, ), ) .parse_next(stream) @@ -951,21 +958,27 @@ fn multi_variable<'a, 'input: 'a>( trace( "multi_variable", move |stream: &mut PtxParser<'a, 'input>| { - let ((align, vector, type_, name), count) = ( - variable_declaration, + let ((align, vector, type_), names, count): (_, Vec<_>, _) = ( + variable_info, + separated(1.., ident, Token::Comma), // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names opt(delimited(Token::Lt, u32.verify(|x| *x != 0), Token::Gt)), ) .parse_next(stream)?; - if count.is_some() { - return Ok(MultiVariable { - var: Variable { + if let Some(count) = count { + if names.len() > 1 { + // nvcc does not support parameterized variable names in comma-separated lists of names. + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + let name = names[0]; + return Ok(MultiVariable::Parameterized { + info: VariableInfo { align, v_type: Type::maybe_vector_parsed(vector, type_), state_space, - name, array_init: Vec::new(), }, + name, count, }); } @@ -988,15 +1001,14 @@ fn multi_variable<'a, 'input: 'a>( return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); } } - Ok(MultiVariable { - var: Variable { + Ok(MultiVariable::Names { + info: VariableInfo { align, v_type: Type::maybe_array(vector, type_, array_dimensions), state_space, - name, array_init: initializer.unwrap_or(Vec::new()), }, - count, + names, }) }, ) From 160048a293e98b4e9a424c3dbc2d357562768f30 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 19 Sep 2025 23:30:29 +0000 Subject: [PATCH 3/4] Fix cuCtxPopCurrent --- zluda/src/impl/context.rs | 49 ++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index 8933116..e6fb35e 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -188,12 +188,25 @@ pub(crate) unsafe fn push_current_v2(ctx: CUcontext) -> CUresult { push_current(ctx) } -pub(crate) unsafe fn pop_current(ctx: &mut CUcontext) -> CUresult { - STACK.with(|stack| { - if let Some((_ctx, _)) = stack.borrow_mut().pop() { - *ctx = _ctx; - } +pub(crate) unsafe fn pop_current(result: Option<&mut CUcontext>) -> CUresult { + let old_ctx_and_new_device = STACK.with(|stack| { + let mut stack = stack.borrow_mut(); + stack + .pop() + .map(|(ctx, _)| (ctx, stack.last().map(|(_, dev)| *dev))) }); + let ctx = match old_ctx_and_new_device { + Some((old_ctx, new_device)) => { + if let Some(new_device) = new_device { + hipSetDevice(new_device)?; + } + old_ctx + } + None => return CUresult::ERROR_INVALID_CONTEXT, + }; + if let Some(out) = result { + *out = ctx; + } Ok(()) } @@ -213,7 +226,7 @@ pub(crate) unsafe fn destroy_v2(ctx: CUcontext) -> CUresult { zluda_common::drop_checked::(ctx) } -pub(crate) unsafe fn pop_current_v2(ctx: &mut CUcontext) -> CUresult { +pub(crate) unsafe fn pop_current_v2(ctx: Option<&mut CUcontext>) -> CUresult { pop_current(ctx) } @@ -241,3 +254,27 @@ pub(crate) unsafe fn get_api_version( *version = 3020; Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::CudaApi; + use cuda_macros::test_cuda; + use std::mem; + + #[test_cuda] + fn empty_pop_fails(api: impl CudaApi) { + api.cuInit(0); + assert_eq!( + api.cuCtxPopCurrent_v2_unchecked(&mut unsafe { mem::zeroed() }), + CUresult::ERROR_INVALID_CONTEXT + ); + } + + #[test_cuda] + fn pop_into_null_succeeds(api: impl CudaApi) { + api.cuInit(0); + api.cuCtxCreate_v2(&mut unsafe { mem::zeroed() }, 0, 0); + api.cuCtxPopCurrent_v2(ptr::null_mut()); + } +} From 2b9c8946ecb8b0a59fc09c1a42b99a37a4df545c Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 20 Sep 2025 00:43:29 +0000 Subject: [PATCH 4/4] Add replayer --- Cargo.lock | 11 ++++ Cargo.toml | 1 + ptx_parser/src/lib.rs | 2 +- zluda_replay/Cargo.toml | 17 ++++++ zluda_replay/src/main.rs | 98 ++++++++++++++++++++++++++++++++ zluda_trace/src/lib.rs | 38 +++++++++---- zluda_trace/src/replay.rs | 9 ++- zluda_trace_common/Cargo.toml | 1 + zluda_trace_common/src/replay.rs | 60 ++++++++++++++++--- 9 files changed, 215 insertions(+), 22 deletions(-) create mode 100644 zluda_replay/Cargo.toml create mode 100644 zluda_replay/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index ee0d570..31ec2d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3826,6 +3826,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "zluda_replay" +version = "0.0.0" +dependencies = [ + "cuda_macros", + "cuda_types", + "libloading", + "zluda_trace_common", +] + [[package]] name = "zluda_sparse" version = "0.0.0" @@ -3903,6 +3913,7 @@ dependencies = [ "format", "libc", "libloading", + "rustc-hash 2.0.0", "serde", "serde_json", "tar", diff --git a/Cargo.toml b/Cargo.toml index ca051ac..63a82c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ members = [ "zluda_inject", "zluda_ld", "zluda_ml", + "zluda_replay", "zluda_redirect", "zluda_sparse", "compiler", diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 359669a..6078dc5 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -370,7 +370,7 @@ pub fn parse_for_errors_and_params<'input>( .func_directive .input_arguments .iter() - .map(|arg| arg.v_type.layout()) + .map(|arg| arg.info.v_type.layout()) .collect(); Some((func.func_directive.name().to_string(), layouts)) } else { diff --git a/zluda_replay/Cargo.toml b/zluda_replay/Cargo.toml new file mode 100644 index 0000000..73295d4 --- /dev/null +++ b/zluda_replay/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "zluda_replay" +version = "0.0.0" +authors = ["Andrzej Janik "] +edition = "2021" + +[[bin]] +name = "zluda_replay" + +[dependencies] +zluda_trace_common = { path = "../zluda_trace_common" } +cuda_macros = { path = "../cuda_macros" } +cuda_types = { path = "../cuda_types" } +libloading = "0.8" + +[package.metadata.zluda] +debug_only = true diff --git a/zluda_replay/src/main.rs b/zluda_replay/src/main.rs new file mode 100644 index 0000000..50d2d99 --- /dev/null +++ b/zluda_replay/src/main.rs @@ -0,0 +1,98 @@ +use std::mem; + +use cuda_types::cuda::{CUdeviceptr_v2, CUstream}; + +struct CudaDynamicFns { + handle: libloading::Library, +} + +impl CudaDynamicFns { + unsafe fn new(path: &str) -> Result { + let handle = libloading::Library::new(path)?; + Ok(Self { handle }) + } +} + +macro_rules! emit_cuda_fn_table { + ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => { + impl CudaDynamicFns { + $( + #[allow(dead_code)] + unsafe fn $fn_name(&self, $($arg_id : $arg_type),*) -> $ret_type { + let func = self.handle.get:: $ret_type>(concat!(stringify!($fn_name), "\0").as_bytes()); + (func.unwrap())($($arg_id),*) + } + )* + } + }; +} + +cuda_macros::cuda_function_declarations!(emit_cuda_fn_table); + +fn main() { + let args: Vec = std::env::args().collect(); + let libcuda = unsafe { CudaDynamicFns::new(&args[1]).unwrap() }; + unsafe { libcuda.cuInit(0) }.unwrap(); + unsafe { libcuda.cuCtxCreate_v2(&mut mem::zeroed(), 0, 0) }.unwrap(); + let reader = std::fs::File::open(&args[2]).unwrap(); + let (mut manifest, mut source, mut buffers) = zluda_trace_common::replay::load(reader); + let mut args = manifest + .parameters + .iter() + .enumerate() + .map(|(i, param)| { + let mut buffer = buffers.remove(&format!("param_{i}.bin")).unwrap(); + for param_ptr in param.pointer_offsets.iter() { + let buffer_param_slice = &mut buffer[param_ptr.offset_in_param + ..param_ptr.offset_in_param + std::mem::size_of::()]; + let mut dev_ptr = unsafe { mem::zeroed() }; + let host_buffer = buffers + .remove(&format!( + "param_{i}_ptr_{}_pre.bin", + param_ptr.offset_in_param + )) + .unwrap(); + unsafe { libcuda.cuMemAlloc_v2(&mut dev_ptr, host_buffer.len()) }.unwrap(); + unsafe { + libcuda.cuMemcpyHtoD_v2(dev_ptr, host_buffer.as_ptr().cast(), host_buffer.len()) + } + .unwrap(); + dev_ptr = CUdeviceptr_v2(unsafe { + dev_ptr + .0 + .cast::() + .add(param_ptr.offset_in_buffer) + .cast() + }); + buffer_param_slice.copy_from_slice(&(dev_ptr.0 as usize).to_ne_bytes()); + } + }) + .collect::>(); + let mut module = unsafe { mem::zeroed() }; + std::fs::write("/tmp/source.ptx", &source).unwrap(); + source.push('\0'); + unsafe { libcuda.cuModuleLoadData(&mut module, source.as_ptr().cast()) }.unwrap(); + let mut function = unsafe { mem::zeroed() }; + manifest.kernel_name.push('\0'); + unsafe { + libcuda.cuModuleGetFunction(&mut function, module, manifest.kernel_name.as_ptr().cast()) + } + .unwrap(); + unsafe { + libcuda.cuLaunchKernel( + function, + manifest.config.grid_dim.0, + manifest.config.grid_dim.1, + manifest.config.grid_dim.2, + manifest.config.block_dim.0, + manifest.config.block_dim.1, + manifest.config.block_dim.2, + manifest.config.shared_mem_bytes, + CUstream(std::ptr::null_mut()), + args.as_mut_ptr().cast(), + std::ptr::null_mut(), + ) + } + .unwrap(); + todo!(); +} diff --git a/zluda_trace/src/lib.rs b/zluda_trace/src/lib.rs index 46ef901..fe2e41d 100644 --- a/zluda_trace/src/lib.rs +++ b/zluda_trace/src/lib.rs @@ -1552,14 +1552,14 @@ fn launch_kernel_pre( #[allow(non_snake_case)] pub(crate) fn cuLaunchKernel_Post( _f: cuda_types::cuda::CUfunction, - _gridDimX: ::core::ffi::c_uint, - _gridDimY: ::core::ffi::c_uint, - _gridDimZ: ::core::ffi::c_uint, - _blockDimX: ::core::ffi::c_uint, - _blockDimY: ::core::ffi::c_uint, - _blockDimZ: ::core::ffi::c_uint, - _sharedMemBytes: ::core::ffi::c_uint, - stream: cuda_types::cuda::CUstream, + gridDimX: ::core::ffi::c_uint, + gridDimY: ::core::ffi::c_uint, + gridDimZ: ::core::ffi::c_uint, + blockDimX: ::core::ffi::c_uint, + blockDimY: ::core::ffi::c_uint, + blockDimZ: ::core::ffi::c_uint, + sharedMemBytes: ::core::ffi::c_uint, + hStream: cuda_types::cuda::CUstream, kernel_params: *mut *mut ::core::ffi::c_void, _extra: *mut *mut ::core::ffi::c_void, pre_state: Option, @@ -1569,7 +1569,25 @@ pub(crate) fn cuLaunchKernel_Post( _result: CUresult, ) { let pre_state = unwrap_some_or!(pre_state, return); - replay::post_kernel_launch(libcuda, state, fn_logger, stream, kernel_params, pre_state); + replay::post_kernel_launch( + libcuda, + state, + fn_logger, + CUlaunchConfig { + gridDimX, + gridDimY, + gridDimZ, + blockDimX, + blockDimY, + blockDimZ, + sharedMemBytes, + hStream, + attrs: ptr::null_mut(), + numAttrs: 0, + }, + kernel_params, + pre_state, + ); } #[allow(non_snake_case)] @@ -1609,7 +1627,7 @@ pub(crate) fn cuLaunchKernelEx_Post( libcuda, state, fn_logger, - unsafe { *config }.hStream, + unsafe { *config }, kernel_params, pre_state, ); diff --git a/zluda_trace/src/replay.rs b/zluda_trace/src/replay.rs index 74fe292..1b6c01d 100644 --- a/zluda_trace/src/replay.rs +++ b/zluda_trace/src/replay.rs @@ -97,11 +97,11 @@ pub(crate) fn post_kernel_launch( libcuda: &mut CudaDynamicFns, state: &trace::StateTracker, fn_logger: &mut FnCallLog, - stream: CUstream, + config: CUlaunchConfig, kernel_params: *mut *mut std::ffi::c_void, mut pre_state: LaunchPreState, ) -> Option<()> { - fn_logger.try_cuda(|| libcuda.cuStreamSynchronize(stream))?; + fn_logger.try_cuda(|| libcuda.cuStreamSynchronize(config.hStream))?; let raw_args = unsafe { std::slice::from_raw_parts(kernel_params, pre_state.kernel_params.len()) }; for (raw_arg, param) in raw_args.iter().zip(pre_state.kernel_params.iter_mut()) { @@ -128,6 +128,11 @@ pub(crate) fn post_kernel_launch( zluda_trace_common::replay::save( file, pre_state.kernel_name, + zluda_trace_common::replay::LaunchConfig { + grid_dim: (config.gridDimX, config.gridDimY, config.gridDimZ), + block_dim: (config.blockDimX, config.blockDimY, config.blockDimZ), + shared_mem_bytes: config.sharedMemBytes, + }, pre_state.source, pre_state.kernel_params, ) diff --git a/zluda_trace_common/Cargo.toml b/zluda_trace_common/Cargo.toml index 3eb012a..fc24d59 100644 --- a/zluda_trace_common/Cargo.toml +++ b/zluda_trace_common/Cargo.toml @@ -15,6 +15,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.142" tar = "0.4" zstd = "0.13" +rustc-hash = "2.0.0" [target.'cfg(not(windows))'.dependencies] libc = "0.2" diff --git a/zluda_trace_common/src/replay.rs b/zluda_trace_common/src/replay.rs index fe98e7c..53005dc 100644 --- a/zluda_trace_common/src/replay.rs +++ b/zluda_trace_common/src/replay.rs @@ -1,21 +1,30 @@ -use std::io::Write; +use rustc_hash::FxHashMap; +use std::io::{Read, Write}; use tar::Header; #[derive(serde::Serialize, serde::Deserialize)] -struct Manifest { - kernel_name: String, - parameters: Vec, +pub struct Manifest { + pub kernel_name: String, + pub config: LaunchConfig, + pub parameters: Vec, } #[derive(serde::Serialize, serde::Deserialize)] -struct Parameter { - pointer_offsets: Vec, +pub struct LaunchConfig { + pub grid_dim: (u32, u32, u32), + pub block_dim: (u32, u32, u32), + pub shared_mem_bytes: u32, } #[derive(serde::Serialize, serde::Deserialize)] -struct ParameterPointer { - offset_in_param: usize, - offset_in_buffer: usize, +pub struct Parameter { + pub pointer_offsets: Vec, +} + +#[derive(serde::Serialize, serde::Deserialize)] +pub struct ParameterPointer { + pub offset_in_param: usize, + pub offset_in_buffer: usize, } impl Manifest { @@ -37,6 +46,7 @@ pub struct KernelParameter { pub fn save( writer: impl Write, kernel_name: String, + config: LaunchConfig, source: String, kernel_params: Vec, ) -> std::io::Result<()> { @@ -44,6 +54,7 @@ pub fn save( let mut builder = tar::Builder::new(archive); let (mut header, manifest) = Manifest { kernel_name, + config, parameters: kernel_params .iter() .map(|param| Parameter { @@ -85,3 +96,34 @@ pub fn save( builder.into_inner()?.finish()?; Ok(()) } + +pub fn load(reader: impl Read) -> (Manifest, String, FxHashMap>) { + let archive = zstd::Decoder::new(reader).unwrap(); + let mut archive = tar::Archive::new(archive); + let mut manifest = None; + let mut source = None; + let mut buffers = FxHashMap::default(); + for entry in archive.entries().unwrap() { + let mut entry = entry.unwrap(); + let path = entry.path().unwrap().to_string_lossy().to_string(); + match &*path { + Manifest::PATH => { + manifest = Some(serde_json::from_reader::<_, Manifest>(&mut entry).unwrap()); + } + "source.ptx" => { + let mut string = String::new(); + entry.read_to_string(&mut string).unwrap(); + dbg!(string.len()); + source = Some(string); + } + _ => { + let mut buffer = Vec::new(); + entry.read_to_end(&mut buffer).unwrap(); + buffers.insert(path, buffer); + } + } + } + let manifest = manifest.unwrap(); + let source = source.unwrap(); + (manifest, source, buffers) +}