diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 6d5d5bc..c4efe55 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1059,7 +1059,7 @@ fn emit_function_header<'a>( let (ret_type, func_type) = get_function_type( builder, map, - &func_decl.input_arguments, + func_decl.effective_input_arguments().map(|(_, typ)| typ), &func_decl.return_arguments, ); let fn_id = match func_decl.name { @@ -1120,9 +1120,9 @@ fn emit_function_header<'a>( } } */ - for input in &func_decl.input_arguments { - let result_type = map.get_or_add(builder, SpirvType::new(input.v_type.clone())); - builder.function_parameter(Some(input.name), result_type)?; + for (name, typ) in func_decl.effective_input_arguments() { + let result_type = map.get_or_add(builder, typ); + builder.function_parameter(Some(name), result_type)?; } Ok(fn_id) } @@ -1233,7 +1233,7 @@ fn to_ssa<'input, 'b>( f_body: Option>>>, tuning: Vec, ) -> Result, TranslateError> { - deparamize_function_decl(&func_decl)?; + //deparamize_function_decl(&func_decl)?; let f_body = match f_body { Some(vec) => vec, None => { @@ -1997,30 +1997,38 @@ fn normalize_predicates( /* How do we handle arguments: - - input .params + - input .params in kernels .param .b64 in_arg get turned into this SPIR-V: %1 = OpFunctionParameter %ulong - %2 = OpVariable %%_ptr_Function_ulong Function + %2 = OpVariable %_ptr_Function_ulong Function OpStore %2 %1 We do this for two reasons. One, common treatment for argument-declared .param variables and .param variables inside function (we assume that - at SPIR-V level every .param is a pointer in Function storage class). Two, - PTX devs in their infinite wisdom decided that .reg arguments are writable + at SPIR-V level every .param is a pointer in Function storage class) + - input .params in functions + .param .b64 in_arg + get turned into this SPIR-V: + %1 = OpFunctionParameter %_ptr_Function_ulong - input .regs .reg .b64 in_arg - get turned into this SPIR-V: + get turned into the same SPIR-V as kernel .params: %1 = OpFunctionParameter %ulong - %2 = OpVariable %%_ptr_Function_ulong Function + %2 = OpVariable %_ptr_Function_ulong Function OpStore %2 %1 - with the difference that %2 is defined as a variable and not temp - output .regs .reg .b64 out_arg get just a variable declaration: %2 = OpVariable %%_ptr_Function_ulong Function - - output .params - .param .b64 out_arg - get treated the same as input .params, because there's no difference + - output .params don't exist, they have been moved to input positions + by an earlier pass + Distinguishing betweem kernel .params and function .params is not the + cleanest solution. Alternatively, we could "deparamize" all kernel .param + arguments by turning them into .reg arguments like this: + .param .b64 arg -> .reg ptr<.b64,.param> arg + This has the massive downside that this transformation would have to run + very early and would muddy up already difficult code. It's simpler to just + have an if here */ fn insert_mem_ssa_statements<'a, 'b>( func: Vec, @@ -2029,7 +2037,7 @@ fn insert_mem_ssa_statements<'a, 'b>( ) -> Result, TranslateError> { let mut result = Vec::with_capacity(func.len()); for arg in fn_decl.input_arguments.iter_mut() { - insert_mem_ssa_argument(id_def, &mut result, arg); + insert_mem_ssa_argument(id_def, &mut result, arg, fn_decl.name.is_kernel()); } for arg in fn_decl.return_arguments.iter() { insert_mem_ssa_argument_reg_return(&mut result, arg); @@ -2103,7 +2111,11 @@ fn insert_mem_ssa_argument( id_def: &mut NumericIdResolver, func: &mut Vec, arg: &mut ast::Variable, + is_kernel: bool, ) { + if !is_kernel && arg.state_space == ast::StateSpace::Param { + return; + } let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space))); func.push(Statement::Variable(ast::Variable { align: arg.align, @@ -2559,14 +2571,12 @@ fn insert_implicit_conversions_impl( fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, - spirv_input: &[ast::Variable], + spirv_input: impl ExactSizeIterator, spirv_output: &[ast::Variable], ) -> (spirv::Word, spirv::Word) { map.get_or_add_fn( builder, - spirv_input - .iter() - .map(|var| SpirvType::new(var.v_type.clone())), + spirv_input, spirv_output .iter() .map(|var| SpirvType::new(var.v_type.clone())), @@ -7542,6 +7552,23 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> { } } +impl<'a> ast::MethodDeclaration<'a, spirv::Word> { + fn effective_input_arguments( + &self, + ) -> impl ExactSizeIterator + '_ { + let is_kernel = self.name.is_kernel(); + self.input_arguments.iter().map(move |arg| { + if !is_kernel { + let spirv_type = + SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv()); + (arg.name, spirv_type) + } else { + (arg.name, SpirvType::new(arg.v_type.clone())) + } + }) + } +} + impl<'input, ID> ast::MethodName<'input, ID> { fn is_kernel(&self) -> bool { match self {