Fix method arg load generation

This commit is contained in:
Andrzej Janik 2021-06-05 00:46:41 +02:00
commit 90960fd923

View file

@ -1059,7 +1059,7 @@ fn emit_function_header<'a>(
let (ret_type, func_type) = get_function_type(
builder,
map,
&func_decl.input_arguments,
func_decl.effective_input_arguments().map(|(_, typ)| typ),
&func_decl.return_arguments,
);
let fn_id = match func_decl.name {
@ -1120,9 +1120,9 @@ fn emit_function_header<'a>(
}
}
*/
for input in &func_decl.input_arguments {
let result_type = map.get_or_add(builder, SpirvType::new(input.v_type.clone()));
builder.function_parameter(Some(input.name), result_type)?;
for (name, typ) in func_decl.effective_input_arguments() {
let result_type = map.get_or_add(builder, typ);
builder.function_parameter(Some(name), result_type)?;
}
Ok(fn_id)
}
@ -1233,7 +1233,7 @@ fn to_ssa<'input, 'b>(
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
tuning: Vec<ast::TuningDirective>,
) -> Result<Function<'input>, TranslateError> {
deparamize_function_decl(&func_decl)?;
//deparamize_function_decl(&func_decl)?;
let f_body = match f_body {
Some(vec) => vec,
None => {
@ -1997,30 +1997,38 @@ fn normalize_predicates(
/*
How do we handle arguments:
- input .params
- input .params in kernels
.param .b64 in_arg
get turned into this SPIR-V:
%1 = OpFunctionParameter %ulong
%2 = OpVariable %%_ptr_Function_ulong Function
%2 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %1
We do this for two reasons. One, common treatment for argument-declared
.param variables and .param variables inside function (we assume that
at SPIR-V level every .param is a pointer in Function storage class). Two,
PTX devs in their infinite wisdom decided that .reg arguments are writable
at SPIR-V level every .param is a pointer in Function storage class)
- input .params in functions
.param .b64 in_arg
get turned into this SPIR-V:
%1 = OpFunctionParameter %_ptr_Function_ulong
- input .regs
.reg .b64 in_arg
get turned into this SPIR-V:
get turned into the same SPIR-V as kernel .params:
%1 = OpFunctionParameter %ulong
%2 = OpVariable %%_ptr_Function_ulong Function
%2 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %1
with the difference that %2 is defined as a variable and not temp
- output .regs
.reg .b64 out_arg
get just a variable declaration:
%2 = OpVariable %%_ptr_Function_ulong Function
- output .params
.param .b64 out_arg
get treated the same as input .params, because there's no difference
- output .params don't exist, they have been moved to input positions
by an earlier pass
Distinguishing betweem kernel .params and function .params is not the
cleanest solution. Alternatively, we could "deparamize" all kernel .param
arguments by turning them into .reg arguments like this:
.param .b64 arg -> .reg ptr<.b64,.param> arg
This has the massive downside that this transformation would have to run
very early and would muddy up already difficult code. It's simpler to just
have an if here
*/
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
@ -2029,7 +2037,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.input_arguments.iter_mut() {
insert_mem_ssa_argument(id_def, &mut result, arg);
insert_mem_ssa_argument(id_def, &mut result, arg, fn_decl.name.is_kernel());
}
for arg in fn_decl.return_arguments.iter() {
insert_mem_ssa_argument_reg_return(&mut result, arg);
@ -2103,7 +2111,11 @@ fn insert_mem_ssa_argument(
id_def: &mut NumericIdResolver,
func: &mut Vec<TypedStatement>,
arg: &mut ast::Variable<spirv::Word>,
is_kernel: bool,
) {
if !is_kernel && arg.state_space == ast::StateSpace::Param {
return;
}
let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
func.push(Statement::Variable(ast::Variable {
align: arg.align,
@ -2559,14 +2571,12 @@ fn insert_implicit_conversions_impl(
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
spirv_input: &[ast::Variable<spirv::Word>],
spirv_input: impl ExactSizeIterator<Item = SpirvType>,
spirv_output: &[ast::Variable<spirv::Word>],
) -> (spirv::Word, spirv::Word) {
map.get_or_add_fn(
builder,
spirv_input
.iter()
.map(|var| SpirvType::new(var.v_type.clone())),
spirv_input,
spirv_output
.iter()
.map(|var| SpirvType::new(var.v_type.clone())),
@ -7542,6 +7552,23 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
}
}
impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
fn effective_input_arguments(
&self,
) -> impl ExactSizeIterator<Item = (spirv::Word, SpirvType)> + '_ {
let is_kernel = self.name.is_kernel();
self.input_arguments.iter().map(move |arg| {
if !is_kernel {
let spirv_type =
SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
(arg.name, spirv_type)
} else {
(arg.name, SpirvType::new(arg.v_type.clone()))
}
})
}
}
impl<'input, ID> ast::MethodName<'input, ID> {
fn is_kernel(&self) -> bool {
match self {