From 875ac13be285e68d0ffdf7d5910c9736aed212be Mon Sep 17 00:00:00 2001 From: Violet Date: Fri, 19 Sep 2025 13:36:48 -0700 Subject: [PATCH] Support lists of variables to be declared (#516) For example, ``` .reg .u32 a, b; ``` --- ptx/src/pass/deparamize_functions.rs | 40 +++--- ptx/src/pass/hoist_globals.rs | 18 ++- ptx/src/pass/insert_explicit_load_store.rs | 18 +-- .../instruction_mode_to_global_mode/test.rs | 2 +- ptx/src/pass/llvm/emit.rs | 31 ++--- ptx/src/pass/mod.rs | 38 ++--- ...entifiers2.rs => normalize_identifiers.rs} | 57 ++++---- .../replace_instructions_with_functions.rs | 10 +- ...instructions_with_functions_fp_required.rs | 130 +++++++++++------- ptx/src/pass/test/expand_operands/mod.rs | 2 +- .../test/insert_implicit_conversions/mod.rs | 2 +- ptx/src/test/ll/reg_multi.ll | 37 +++++ ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/test/spirv_run/reg_multi.ptx | 22 +++ ptx_parser/src/ast.rs | 34 +++-- ptx_parser/src/lib.rs | 60 ++++---- 16 files changed, 315 insertions(+), 187 deletions(-) rename ptx/src/pass/{normalize_identifiers2.rs => normalize_identifiers.rs} (81%) create mode 100644 ptx/src/test/ll/reg_multi.ll create mode 100644 ptx/src/test/spirv_run/reg_multi.ptx diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs index e80f6a3..2145b89 100644 --- a/ptx/src/pass/deparamize_functions.rs +++ b/ptx/src/pass/deparamize_functions.rs @@ -29,22 +29,24 @@ fn run_method<'input>( let mut remap_returns = Vec::new(); if !method.is_kernel { for arg in method.return_arguments.iter_mut() { - match arg.state_space { + match arg.info.state_space { ptx_parser::StateSpace::Param => { - arg.state_space = ptx_parser::StateSpace::Reg; + arg.info.state_space = ptx_parser::StateSpace::Reg; let old_name = arg.name; - arg.name = - resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); + arg.name = resolver + .register_unnamed(Some((arg.info.v_type.clone(), arg.info.state_space))); if is_declaration { continue; } - remap_returns.push((old_name, arg.name, arg.v_type.clone())); + remap_returns.push((old_name, arg.name, arg.info.v_type.clone())); body.push(Statement::Variable(ast::Variable { - align: None, + info: ast::VariableInfo { + align: None, + v_type: arg.info.v_type.clone(), + state_space: ptx_parser::StateSpace::Param, + array_init: Vec::new(), + }, name: old_name, - v_type: arg.v_type.clone(), - state_space: ptx_parser::StateSpace::Param, - array_init: Vec::new(), })); } ptx_parser::StateSpace::Reg => {} @@ -52,28 +54,30 @@ fn run_method<'input>( } } for arg in method.input_arguments.iter_mut() { - match arg.state_space { + match arg.info.state_space { ptx_parser::StateSpace::Param => { - arg.state_space = ptx_parser::StateSpace::Reg; + arg.info.state_space = ptx_parser::StateSpace::Reg; let old_name = arg.name; - arg.name = - resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); + arg.name = resolver + .register_unnamed(Some((arg.info.v_type.clone(), arg.info.state_space))); if is_declaration { continue; } body.push(Statement::Variable(ast::Variable { - align: None, + info: ast::VariableInfo { + align: None, + v_type: arg.info.v_type.clone(), + state_space: ptx_parser::StateSpace::Param, + array_init: Vec::new(), + }, name: old_name, - v_type: arg.v_type.clone(), - state_space: ptx_parser::StateSpace::Param, - array_init: Vec::new(), })); body.push(Statement::Instruction(ast::Instruction::St { data: ast::StData { qualifier: ast::LdStQualifier::Weak, state_space: ast::StateSpace::Param, caching: ast::StCacheOperator::Writethrough, - typ: arg.v_type.clone(), + typ: arg.info.v_type.clone(), }, arguments: ast::StArgs { src1: old_name, diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs index dfc88c2..097dff1 100644 --- a/ptx/src/pass/hoist_globals.rs +++ b/ptx/src/pass/hoist_globals.rs @@ -30,11 +30,19 @@ fn run_function<'input>( statements .into_iter() .filter_map(|statement| match statement { - Statement::Variable(var @ ast::Variable { - state_space: - ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared, - .. - }) => { + Statement::Variable( + var @ ast::Variable { + info: + ast::VariableInfo { + state_space: + ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::Shared, + .. + }, + .. + }, + ) => { result.push(Directive2::Variable(ast::LinkingDirective::NONE, var)); None } diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index 3350a82..696bb70 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -40,14 +40,14 @@ fn run_method<'a, 'input>( if is_kernel { for arg in method.input_arguments.iter_mut() { let old_name = arg.name; - let old_space = arg.state_space; + let old_space = arg.info.state_space; let new_space = ast::StateSpace::ParamEntry; let new_name = visitor .resolver - .register_unnamed(Some((arg.v_type.clone(), new_space))); + .register_unnamed(Some((arg.info.v_type.clone(), new_space))); visitor.input_argument(old_name, new_name, old_space)?; arg.name = new_name; - arg.state_space = new_space; + arg.info.state_space = new_space; } }; for arg in method.return_arguments.iter_mut() { @@ -83,10 +83,10 @@ fn run_statement<'a, 'input>( return_arguments .iter() .map(|arg| { - if arg.state_space != ast::StateSpace::Local { + if arg.info.state_space != ast::StateSpace::Local { return Err(error_unreachable()); } - Ok((arg.name, arg.v_type.clone())) + Ok((arg.name, arg.info.v_type.clone())) }) .collect::, _>>()?, ) @@ -332,7 +332,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { } fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { - let old_space = match var.state_space { + let old_space = match var.info.state_space { space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, // Do nothing ptx_parser::StateSpace::Local => return Ok(()), @@ -350,10 +350,10 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { let new_space = ast::StateSpace::Local; let new_name = self .resolver - .register_unnamed(Some((var.v_type.clone(), new_space))); - self.variable(&var.v_type, old_name, new_name, old_space)?; + .register_unnamed(Some((var.info.v_type.clone(), new_space))); + self.variable(&var.info.v_type, old_name, new_name, old_space)?; var.name = new_name; - var.state_space = new_space; + var.info.state_space = new_space; Ok(()) } } diff --git a/ptx/src/pass/instruction_mode_to_global_mode/test.rs b/ptx/src/pass/instruction_mode_to_global_mode/test.rs index 78d1d66..94fe735 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/test.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/test.rs @@ -195,7 +195,7 @@ fn compile_methods(ptx: &str) -> Vec, Spir let module = ptx_parser::parse_module_checked(ptx).unwrap(); let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); - let directives = normalize_identifiers2::run(&mut scoped_resolver, module.directives).unwrap(); + let directives = normalize_identifiers::run(&mut scoped_resolver, module.directives).unwrap(); let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); let directives = normalize_basic_blocks::run(&mut flat_resolver, directives).unwrap(); diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index 144f5e6..76717e1 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -122,11 +122,10 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { if fn_ == ptr::null_mut() { let fn_type = get_function_type( self.context, - method.return_arguments.iter().map(|v| &v.v_type), - method - .input_arguments - .iter() - .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)), + method.return_arguments.iter().map(|v| &v.info.v_type), + method.input_arguments.iter().map(|v| { + get_input_argument_type(self.context, &v.info.v_type, v.info.state_space) + }), )?; fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; self.emit_fn_attribute(fn_, "amdgpu-unsafe-fp-atomics", "true"); @@ -153,7 +152,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { for (i, param) in method.input_arguments.iter().enumerate() { let value = unsafe { LLVMGetParam(fn_, i as u32) }; let name = self.resolver.get_or_add(param.name); - if let Some(align) = param.align { + if let Some(align) = param.info.align { unsafe { LLVMSetParamAlignment(value, align) }; } unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; @@ -166,7 +165,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { LLVMCreateTypeAttribute( self.context, attr_kind, - get_type(self.context, ¶m.v_type)?, + get_type(self.context, ¶m.info.v_type)?, ) }; unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) }; @@ -241,17 +240,17 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { let global = unsafe { LLVMAddGlobalInAddressSpace( self.module, - get_type(self.context, &var.v_type)?, + get_type(self.context, &var.info.v_type)?, name.as_ptr(), - get_state_space(var.state_space)?, + get_state_space(var.info.state_space)?, ) }; self.resolver.register(var.name, global); - if let Some(align) = var.align { + if let Some(align) = var.info.align { unsafe { LLVMSetAlignment(global, align) }; } - if !var.array_init.is_empty() { - let initializer = self.get_array_init(&var.v_type, &*var.array_init)?; + if !var.info.array_init.is_empty() { + let initializer = self.get_array_init(&var.info.v_type, &*var.info.array_init)?; unsafe { LLVMSetInitializer(global, initializer) }; } Ok(()) @@ -422,16 +421,16 @@ impl<'a> MethodEmitContext<'a> { let alloca = unsafe { LLVMZludaBuildAlloca( self.variables_builder.get(), - get_type(self.context, &var.v_type)?, - get_state_space(var.state_space)?, + get_type(self.context, &var.info.v_type)?, + get_state_space(var.info.state_space)?, self.resolver.get_or_add_raw(var.name), ) }; self.resolver.register(var.name, alloca); - if let Some(align) = var.align { + if let Some(align) = var.info.align { unsafe { LLVMSetAlignment(alloca, align) }; } - if !var.array_init.is_empty() { + if !var.info.array_init.is_empty() { return Err(error_unreachable()); } Ok(()) diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 4f87dc3..b14903d 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -21,7 +21,7 @@ mod insert_post_saturation; mod instruction_mode_to_global_mode; pub mod llvm; mod normalize_basic_blocks; -mod normalize_identifiers2; +mod normalize_identifiers; mod normalize_predicates2; mod remove_unreachable_basic_blocks; mod replace_instructions_with_functions; @@ -65,8 +65,8 @@ pub fn to_llvm_module<'input>( let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?; - let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; - on_pass_end("normalize_identifiers2"); + let directives = normalize_identifiers::run(&mut scoped_resolver, ast.directives)?; + on_pass_end("normalize_identifiers"); let directives = replace_known_functions::run(&mut flat_resolver, directives); on_pass_end("replace_known_functions"); let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; @@ -308,16 +308,18 @@ impl> Statement, T> { Statement::Variable(var) => { let name = visitor.visit_ident( var.name, - Some((&var.v_type, var.state_space)), + Some((&var.info.v_type, var.info.state_space)), true, false, )?; Statement::Variable(ast::Variable { - align: var.align, - v_type: var.v_type, - state_space: var.state_space, + info: ast::VariableInfo { + align: var.info.align, + v_type: var.info.v_type, + state_space: var.info.state_space, + array_init: var.info.array_init, + }, name, - array_init: var.array_init, }) } Statement::Conditional(conditional) => { @@ -978,20 +980,24 @@ impl SpecialRegistersMap { let return_type = sreg.get_function_return_type(); let input_type = sreg.get_function_input_type(); let return_arguments = vec![ast::Variable { - align: None, - v_type: return_type.into(), - state_space: ast::StateSpace::Reg, + info: ast::VariableInfo { + align: None, + v_type: return_type.into(), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, name: resolver.register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))), - array_init: Vec::new(), }]; let input_arguments = input_type .into_iter() .map(|type_| ast::Variable { - align: None, - v_type: type_.into(), - state_space: ast::StateSpace::Reg, + info: ast::VariableInfo { + align: None, + v_type: type_.into(), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, name: resolver.register_unnamed(Some((type_.into(), ast::StateSpace::Reg))), - array_init: Vec::new(), }) .collect::>(); fn_(sreg, (return_arguments, name, input_arguments)); diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers.rs similarity index 81% rename from ptx/src/pass/normalize_identifiers2.rs rename to ptx/src/pass/normalize_identifiers.rs index 901f628..abe2bdf 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers.rs @@ -80,19 +80,28 @@ fn run_function_decl<'input, 'b>( Ok((return_arguments, input_arguments)) } +fn run_variable_info<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + info: ast::VariableInfo<&'input str>, +) -> Result, TranslateError> { + Ok(ast::VariableInfo { + align: info.align, + v_type: info.v_type, + state_space: info.state_space, + array_init: run_array_init(resolver, &info.array_init)?, + }) +} + fn run_variable<'input, 'b>( resolver: &mut ScopedResolver<'input, 'b>, variable: ast::Variable<&'input str>, ) -> Result, TranslateError> { Ok(ast::Variable { + info: run_variable_info(resolver, variable.info.clone())?, name: resolver.add( Cow::Borrowed(variable.name), - Some((variable.v_type.clone(), variable.state_space)), + Some((variable.info.v_type.clone(), variable.info.state_space)), )?, - align: variable.align, - v_type: variable.v_type, - state_space: variable.state_space, - array_init: run_array_init(resolver, &variable.array_init)?, }) } @@ -158,36 +167,26 @@ fn run_multivariable<'input, 'b>( result: &mut Vec, variable: ast::MultiVariable<&'input str>, ) -> Result<(), TranslateError> { - match variable.count { - Some(count) => { + match variable { + ptx_parser::MultiVariable::Parameterized { info, name, count } => { for i in 0..count { - let name = Cow::Owned(format!("{}{}", variable.var.name, i)); - let ident = resolver.add( - name, - Some((variable.var.v_type.clone(), variable.var.state_space)), - )?; + let name = Cow::Owned(format!("{}{}", name, i)); + let ident = resolver.add(name, Some((info.v_type.clone(), info.state_space)))?; result.push(Statement::Variable(ast::Variable { - align: variable.var.align, - v_type: variable.var.v_type.clone(), - state_space: variable.var.state_space, + info: run_variable_info(resolver, info.clone())?, name: ident, - array_init: run_array_init(resolver, &variable.var.array_init)?, })); } } - None => { - let name = Cow::Borrowed(variable.var.name); - let ident = resolver.add( - name, - Some((variable.var.v_type.clone(), variable.var.state_space)), - )?; - result.push(Statement::Variable(ast::Variable { - align: variable.var.align, - v_type: variable.var.v_type.clone(), - state_space: variable.var.state_space, - name: ident, - array_init: run_array_init(resolver, &variable.var.array_init)?, - })); + ptx_parser::MultiVariable::Names { info, names } => { + for name in names { + let name = Cow::Borrowed(name); + let ident = resolver.add(name, Some((info.v_type.clone(), info.state_space)))?; + result.push(Statement::Variable(ast::Variable { + info: run_variable_info(resolver, info.clone())?, + name: ident, + })); + } } } Ok(()) diff --git a/ptx/src/pass/replace_instructions_with_functions.rs b/ptx/src/pass/replace_instructions_with_functions.rs index f7c976e..a68008f 100644 --- a/ptx/src/pass/replace_instructions_with_functions.rs +++ b/ptx/src/pass/replace_instructions_with_functions.rs @@ -580,11 +580,13 @@ fn to_variables<'input>( arguments .iter() .map(|(type_, space)| ast::Variable { - align: None, - v_type: type_.clone(), - state_space: *space, + info: ast::VariableInfo { + align: None, + v_type: type_.clone(), + state_space: *space, + array_init: Vec::new(), + }, name: resolver.register_unnamed(Some((type_.clone(), *space))), - array_init: Vec::new(), }) .collect::>() } diff --git a/ptx/src/pass/replace_instructions_with_functions_fp_required.rs b/ptx/src/pass/replace_instructions_with_functions_fp_required.rs index bf2d690..211d5b6 100644 --- a/ptx/src/pass/replace_instructions_with_functions_fp_required.rs +++ b/ptx/src/pass/replace_instructions_with_functions_fp_required.rs @@ -33,40 +33,48 @@ pub(crate) fn run<'input>( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::U8), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::U8), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::U8), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ], name: imports.part1, @@ -76,20 +84,24 @@ pub(crate) fn run<'input>( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ], body: None, @@ -108,10 +120,12 @@ pub(crate) fn run<'input>( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }], name: imports.part2, input_arguments: vec![ @@ -120,60 +134,72 @@ pub(crate) fn run<'input>( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::F32), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::F32), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::F32), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ast::Variable { name: resolver.register_unnamed(Some(( ast::Type::Scalar(ast::ScalarType::U8), ast::StateSpace::Reg, ))), - align: None, - v_type: ast::Type::Scalar(ast::ScalarType::U8), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), + info: ast::VariableInfo { + align: None, + v_type: ast::Type::Scalar(ast::ScalarType::U8), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), + }, }, ], body: None, diff --git a/ptx/src/pass/test/expand_operands/mod.rs b/ptx/src/pass/test/expand_operands/mod.rs index 20efae8..3d83651 100644 --- a/ptx/src/pass/test/expand_operands/mod.rs +++ b/ptx/src/pass/test/expand_operands/mod.rs @@ -12,7 +12,7 @@ fn run_expand_operands(ptx: ptx_parser::Module) -> String { // We run the minimal number of passes required to produce the input expected by expand_operands let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); - let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap(); + let directives = normalize_identifiers::run(&mut scoped_resolver, ptx.directives).unwrap(); let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); directive2_vec_to_string(&flat_resolver, directives) diff --git a/ptx/src/pass/test/insert_implicit_conversions/mod.rs b/ptx/src/pass/test/insert_implicit_conversions/mod.rs index 1b5fffb..429d943 100644 --- a/ptx/src/pass/test/insert_implicit_conversions/mod.rs +++ b/ptx/src/pass/test/insert_implicit_conversions/mod.rs @@ -12,7 +12,7 @@ fn run_insert_implicit_conversions(ptx: ptx_parser::Module) -> String { // We run the minimal number of passes required to produce the input expected by insert_implicit_conversions let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); - let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap(); + let directives = normalize_identifiers::run(&mut scoped_resolver, ptx.directives).unwrap(); let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives).unwrap(); diff --git a/ptx/src/test/ll/reg_multi.ll b/ptx/src/test/ll/reg_multi.ll new file mode 100644 index 0000000..62f0459 --- /dev/null +++ b/ptx/src/test/ll/reg_multi.ll @@ -0,0 +1,37 @@ +define amdgpu_kernel void @reg_multi(ptr addrspace(4) byref(i64) %"38", ptr addrspace(4) byref(i64) %"39") #0 { + %"40" = alloca i64, align 8, addrspace(5) + %"41" = alloca i64, align 8, addrspace(5) + %"42" = alloca i32, align 4, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"37" + +"37": ; preds = %1 + %"44" = load i64, ptr addrspace(4) %"38", align 8 + store i64 %"44", ptr addrspace(5) %"40", align 8 + %"45" = load i64, ptr addrspace(4) %"39", align 8 + store i64 %"45", ptr addrspace(5) %"41", align 8 + %"47" = load i64, ptr addrspace(5) %"40", align 8 + %"54" = inttoptr i64 %"47" to ptr + %"46" = load i32, ptr %"54", align 4 + store i32 %"46", ptr addrspace(5) %"42", align 4 + %"48" = load i64, ptr addrspace(5) %"40", align 8 + %"55" = inttoptr i64 %"48" to ptr + %"34" = getelementptr inbounds i8, ptr %"55", i64 4 + %"49" = load i32, ptr %"34", align 4 + store i32 %"49", ptr addrspace(5) %"43", align 4 + %"50" = load i64, ptr addrspace(5) %"41", align 8 + %"51" = load i32, ptr addrspace(5) %"42", align 4 + %"56" = inttoptr i64 %"50" to ptr + store i32 %"51", ptr %"56", align 4 + %"52" = load i64, ptr addrspace(5) %"41", align 8 + %"57" = inttoptr i64 %"52" to ptr + %"36" = getelementptr inbounds i8, ptr %"57", i64 4 + %"53" = load i32, ptr addrspace(5) %"43", align 4 + store i32 %"53", ptr %"36", align 4 + ret void +} + +attributes #0 = { "amdgpu-ieee"="false" "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 46bdd0b..bd5d900 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -124,6 +124,7 @@ test_ptx!(vector4, [1u32, 2u32, 3u32, 4u32], [4u32]); test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]); test_ptx!(ntid, [3u32], [4u32]); test_ptx!(reg_local, [12u64], [13u64]); +test_ptx!(reg_multi, [123u32, 456u32], [123u32, 456u32]); test_ptx!(mov_address, [0xDEADu64], [0u64]); test_ptx!(b64tof64, [111u64], [111u64]); // This segfaults NV compiler diff --git a/ptx/src/test/spirv_run/reg_multi.ptx b/ptx/src/test/spirv_run/reg_multi.ptx new file mode 100644 index 0000000..99f234b --- /dev/null +++ b/ptx/src/test/spirv_run/reg_multi.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry reg_multi( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 a, b; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 a, [in_addr]; + ld.u32 b, [in_addr+4]; + st.u32 [out_addr], a; + st.u32 [out_addr+4], b; + ret; +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 1bc622c..84d5f57 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -998,29 +998,41 @@ impl MapOperand for Option { } } -pub struct MultiVariable { - pub var: Variable, - pub count: Option, +pub enum MultiVariable { + Parameterized { + info: VariableInfo, + name: ID, + count: u32, + }, + Names { + info: VariableInfo, + names: Vec, + }, +} + +#[derive(Clone)] +pub struct VariableInfo { + pub align: Option, + pub v_type: Type, + pub state_space: StateSpace, + pub array_init: Vec>, } #[derive(Clone)] pub struct Variable { - pub align: Option, - pub v_type: Type, - pub state_space: StateSpace, + pub info: VariableInfo, pub name: ID, - pub array_init: Vec>, } impl std::fmt::Display for Variable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.state_space)?; + write!(f, "{}", self.info.state_space)?; - if let Some(align) = self.align { + if let Some(align) = self.info.align { write!(f, " .align {}", align)?; } - let (vector_size, scalar_type, array_dims) = match &self.v_type { + let (vector_size, scalar_type, array_dims) = match &self.info.v_type { Type::Scalar(scalar_type) => (None, *scalar_type, &vec![]), Type::Vector(size, scalar_type) => (Some(*size), *scalar_type, &vec![]), Type::Array(vector_size, scalar_type, array_dims) => { @@ -1038,7 +1050,7 @@ impl std::fmt::Display for Variable { write!(f, "[{}]", dim)?; } - if self.array_init.len() > 0 { + if self.info.array_init.len() > 0 { todo!("Need to interpret the array initializer data as the appropriate type"); } diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index a4f9080..2c9003b 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -135,7 +135,7 @@ impl<'a, 'input> PtxParserState<'a, 'input> { fn get_type_space(input_arguments: &[Variable<&str>]) -> Vec<(Type, StateSpace)> { input_arguments .iter() - .map(|var| (var.v_type.clone(), var.state_space)) + .map(|var| (var.info.v_type.clone(), var.info.state_space)) .collect::>() } } @@ -552,7 +552,13 @@ fn module_variable<'a, 'input>( let var = global_space .flat_map(|space| multi_variable(linking.contains(LinkingDirective::EXTERN), space)) // TODO: support multi var in globals - .map(|multi_var| multi_var.var) + .verify_map(|multi_var| match multi_var { + MultiVariable::Names { info, names } if names.len() == 1 => Some(ast::Variable { + info, + name: names[0], + }), + _ => None, + }) .parse_next(stream)?; Ok((linking, var)) } @@ -886,7 +892,7 @@ fn method_parameter<'a, 'input: 'a>( ) -> impl Parser, Variable<&'input str>, ContextError> { fn nvptx_kernel_declaration<'a, 'input>( stream: &mut PtxParser<'a, 'input>, - ) -> PResult<(Option, Option, ScalarType, &'input str)> { + ) -> PResult<((Option, Option, ScalarType), &'input str)> { trace( "nvptx_kernel_declaration", ( @@ -897,15 +903,15 @@ fn method_parameter<'a, 'input: 'a>( ident, ), ) - .map(|(vector, type_, _, align, name)| (align, vector, type_, name)) + .map(|(vector, type_, _, align, name)| ((align, vector, type_), name)) .parse_next(stream) } trace( "method_parameter", move |stream: &mut PtxParser<'a, 'input>| { if kernel_decl_rules {} - let (align, vector, type_, name) = - alt((variable_declaration, nvptx_kernel_declaration)).parse_next(stream)?; + let ((align, vector, type_), name) = + alt(((variable_info, ident), nvptx_kernel_declaration)).parse_next(stream)?; let array_dimensions = if state_space != StateSpace::Reg { opt(array_dimensions).parse_next(stream)? } else { @@ -918,27 +924,28 @@ fn method_parameter<'a, 'input: 'a>( } } Ok(Variable { - align, - v_type: Type::maybe_array(vector, type_, array_dimensions), - state_space, + info: VariableInfo { + align, + v_type: Type::maybe_array(vector, type_, array_dimensions), + state_space, + array_init: Vec::new(), + }, name, - array_init: Vec::new(), }) }, ) } // TODO: split to a separate type -fn variable_declaration<'a, 'input>( +fn variable_info<'a, 'input>( stream: &mut PtxParser<'a, 'input>, -) -> PResult<(Option, Option, ScalarType, &'input str)> { +) -> PResult<(Option, Option, ScalarType)> { trace( - "variable_declaration", + "variable_info", ( opt(align.verify(|x| x.count_ones() == 1)), vector_prefix, scalar_type, - ident, ), ) .parse_next(stream) @@ -951,21 +958,27 @@ fn multi_variable<'a, 'input: 'a>( trace( "multi_variable", move |stream: &mut PtxParser<'a, 'input>| { - let ((align, vector, type_, name), count) = ( - variable_declaration, + let ((align, vector, type_), names, count): (_, Vec<_>, _) = ( + variable_info, + separated(1.., ident, Token::Comma), // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names opt(delimited(Token::Lt, u32.verify(|x| *x != 0), Token::Gt)), ) .parse_next(stream)?; - if count.is_some() { - return Ok(MultiVariable { - var: Variable { + if let Some(count) = count { + if names.len() > 1 { + // nvcc does not support parameterized variable names in comma-separated lists of names. + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + let name = names[0]; + return Ok(MultiVariable::Parameterized { + info: VariableInfo { align, v_type: Type::maybe_vector_parsed(vector, type_), state_space, - name, array_init: Vec::new(), }, + name, count, }); } @@ -988,15 +1001,14 @@ fn multi_variable<'a, 'input: 'a>( return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); } } - Ok(MultiVariable { - var: Variable { + Ok(MultiVariable::Names { + info: VariableInfo { align, v_type: Type::maybe_array(vector, type_, array_dimensions), state_space, - name, array_init: initializer.unwrap_or(Vec::new()), }, - count, + names, }) }, )