Fix method arg load generation

This commit is contained in:
Andrzej Janik 2021-06-05 00:46:41 +02:00
commit 90960fd923

View file

@ -1059,7 +1059,7 @@ fn emit_function_header<'a>(
let (ret_type, func_type) = get_function_type( let (ret_type, func_type) = get_function_type(
builder, builder,
map, map,
&func_decl.input_arguments, func_decl.effective_input_arguments().map(|(_, typ)| typ),
&func_decl.return_arguments, &func_decl.return_arguments,
); );
let fn_id = match func_decl.name { let fn_id = match func_decl.name {
@ -1120,9 +1120,9 @@ fn emit_function_header<'a>(
} }
} }
*/ */
for input in &func_decl.input_arguments { for (name, typ) in func_decl.effective_input_arguments() {
let result_type = map.get_or_add(builder, SpirvType::new(input.v_type.clone())); let result_type = map.get_or_add(builder, typ);
builder.function_parameter(Some(input.name), result_type)?; builder.function_parameter(Some(name), result_type)?;
} }
Ok(fn_id) Ok(fn_id)
} }
@ -1233,7 +1233,7 @@ fn to_ssa<'input, 'b>(
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>, f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
tuning: Vec<ast::TuningDirective>, tuning: Vec<ast::TuningDirective>,
) -> Result<Function<'input>, TranslateError> { ) -> Result<Function<'input>, TranslateError> {
deparamize_function_decl(&func_decl)?; //deparamize_function_decl(&func_decl)?;
let f_body = match f_body { let f_body = match f_body {
Some(vec) => vec, Some(vec) => vec,
None => { None => {
@ -1997,30 +1997,38 @@ fn normalize_predicates(
/* /*
How do we handle arguments: How do we handle arguments:
- input .params - input .params in kernels
.param .b64 in_arg .param .b64 in_arg
get turned into this SPIR-V: get turned into this SPIR-V:
%1 = OpFunctionParameter %ulong %1 = OpFunctionParameter %ulong
%2 = OpVariable %%_ptr_Function_ulong Function %2 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %1 OpStore %2 %1
We do this for two reasons. One, common treatment for argument-declared We do this for two reasons. One, common treatment for argument-declared
.param variables and .param variables inside function (we assume that .param variables and .param variables inside function (we assume that
at SPIR-V level every .param is a pointer in Function storage class). Two, at SPIR-V level every .param is a pointer in Function storage class)
PTX devs in their infinite wisdom decided that .reg arguments are writable - input .params in functions
.param .b64 in_arg
get turned into this SPIR-V:
%1 = OpFunctionParameter %_ptr_Function_ulong
- input .regs - input .regs
.reg .b64 in_arg .reg .b64 in_arg
get turned into this SPIR-V: get turned into the same SPIR-V as kernel .params:
%1 = OpFunctionParameter %ulong %1 = OpFunctionParameter %ulong
%2 = OpVariable %%_ptr_Function_ulong Function %2 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %1 OpStore %2 %1
with the difference that %2 is defined as a variable and not temp
- output .regs - output .regs
.reg .b64 out_arg .reg .b64 out_arg
get just a variable declaration: get just a variable declaration:
%2 = OpVariable %%_ptr_Function_ulong Function %2 = OpVariable %%_ptr_Function_ulong Function
- output .params - output .params don't exist, they have been moved to input positions
.param .b64 out_arg by an earlier pass
get treated the same as input .params, because there's no difference Distinguishing betweem kernel .params and function .params is not the
cleanest solution. Alternatively, we could "deparamize" all kernel .param
arguments by turning them into .reg arguments like this:
.param .b64 arg -> .reg ptr<.b64,.param> arg
This has the massive downside that this transformation would have to run
very early and would muddy up already difficult code. It's simpler to just
have an if here
*/ */
fn insert_mem_ssa_statements<'a, 'b>( fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>, func: Vec<TypedStatement>,
@ -2029,7 +2037,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
) -> Result<Vec<TypedStatement>, TranslateError> { ) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len()); let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.input_arguments.iter_mut() { for arg in fn_decl.input_arguments.iter_mut() {
insert_mem_ssa_argument(id_def, &mut result, arg); insert_mem_ssa_argument(id_def, &mut result, arg, fn_decl.name.is_kernel());
} }
for arg in fn_decl.return_arguments.iter() { for arg in fn_decl.return_arguments.iter() {
insert_mem_ssa_argument_reg_return(&mut result, arg); insert_mem_ssa_argument_reg_return(&mut result, arg);
@ -2103,7 +2111,11 @@ fn insert_mem_ssa_argument(
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
func: &mut Vec<TypedStatement>, func: &mut Vec<TypedStatement>,
arg: &mut ast::Variable<spirv::Word>, arg: &mut ast::Variable<spirv::Word>,
is_kernel: bool,
) { ) {
if !is_kernel && arg.state_space == ast::StateSpace::Param {
return;
}
let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space))); let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
func.push(Statement::Variable(ast::Variable { func.push(Statement::Variable(ast::Variable {
align: arg.align, align: arg.align,
@ -2559,14 +2571,12 @@ fn insert_implicit_conversions_impl(
fn get_function_type( fn get_function_type(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
spirv_input: &[ast::Variable<spirv::Word>], spirv_input: impl ExactSizeIterator<Item = SpirvType>,
spirv_output: &[ast::Variable<spirv::Word>], spirv_output: &[ast::Variable<spirv::Word>],
) -> (spirv::Word, spirv::Word) { ) -> (spirv::Word, spirv::Word) {
map.get_or_add_fn( map.get_or_add_fn(
builder, builder,
spirv_input spirv_input,
.iter()
.map(|var| SpirvType::new(var.v_type.clone())),
spirv_output spirv_output
.iter() .iter()
.map(|var| SpirvType::new(var.v_type.clone())), .map(|var| SpirvType::new(var.v_type.clone())),
@ -7542,6 +7552,23 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
} }
} }
impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
fn effective_input_arguments(
&self,
) -> impl ExactSizeIterator<Item = (spirv::Word, SpirvType)> + '_ {
let is_kernel = self.name.is_kernel();
self.input_arguments.iter().map(move |arg| {
if !is_kernel {
let spirv_type =
SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
(arg.name, spirv_type)
} else {
(arg.name, SpirvType::new(arg.v_type.clone()))
}
})
}
}
impl<'input, ID> ast::MethodName<'input, ID> { impl<'input, ID> ast::MethodName<'input, ID> {
fn is_kernel(&self) -> bool { fn is_kernel(&self) -> bool {
match self { match self {