Continue attempts at fixing code emission for method args

This commit is contained in:
Andrzej Janik 2021-06-04 00:48:51 +02:00
parent 2e6f7e3fdc
commit f70abd065b
3 changed files with 224 additions and 104 deletions

View file

@ -96,6 +96,7 @@ pub struct MethodDeclaration<'input, ID> {
pub return_arguments: Vec<Variable<ID>>,
pub name: MethodName<'input, ID>,
pub input_arguments: Vec<Variable<ID>>,
pub shared_mem: Option<Variable<ID>>,
}
pub struct Function<'a, ID, S> {

View file

@ -392,12 +392,12 @@ MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = {
".entry" <name:ExtendedID> <input_arguments:KernelArguments> => {
let return_arguments = Vec::new();
let name = ast::MethodName::Kernel(name);
ast::MethodDeclaration{ return_arguments, name, input_arguments }
ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None }
},
".func" <return_arguments:FnArguments?> <name:ExtendedID> <input_arguments:FnArguments> => {
let return_arguments = return_arguments.unwrap_or_else(|| Vec::new());
let name = ast::MethodName::Func(name);
ast::MethodDeclaration{ return_arguments, name, input_arguments }
ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None }
}
};

View file

@ -562,7 +562,6 @@ fn emit_directives<'input>(
call_map,
&directives,
kernel_info,
f.uses_shared_mem,
)?;
for t in f.tuning.iter() {
match *t {
@ -1038,10 +1037,9 @@ fn emit_function_header<'a>(
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
direcitves: &[Directive],
kernel_info: &mut HashMap<String, KernelInfo>,
uses_shared_mem: bool,
) -> Result<spirv::Word, TranslateError> {
if let ast::MethodName::Kernel(name) = func_decl.name {
let input_args = if !uses_shared_mem {
let input_args = if func_decl.shared_mem.is_none() {
func_decl.input_arguments.as_slice()
} else {
&func_decl.input_arguments[0..func_decl.input_arguments.len() - 1]
@ -1054,7 +1052,7 @@ fn emit_function_header<'a>(
name.to_string(),
KernelInfo {
arguments_sizes: args_lens,
uses_shared_mem: uses_shared_mem,
uses_shared_mem: func_decl.shared_mem.is_some(),
},
);
}
@ -1218,7 +1216,7 @@ fn rename_fn_params<'a, 'b>(
) -> Vec<ast::Variable<spirv::Word>> {
args.iter()
.map(|a| ast::Variable {
name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), false),
name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true),
v_type: a.v_type.clone(),
state_space: a.state_space,
align: a.align,
@ -1245,7 +1243,6 @@ fn to_ssa<'input, 'b>(
globals: Vec::new(),
import_as: None,
tuning,
uses_shared_mem: false,
})
}
};
@ -1276,7 +1273,6 @@ fn to_ssa<'input, 'b>(
body: Some(f_body),
import_as: None,
tuning,
uses_shared_mem: false,
})
}
@ -1529,18 +1525,8 @@ fn convert_to_typed_statements(
match s {
Statement::Instruction(inst) => match inst {
ast::Instruction::Call(call) => {
// TODO: error out if lengths don't match
let fn_def = (**fn_defs.get_fn_decl(call.func)?).borrow();
let return_arguments =
to_resolved_fn_args(call.ret_params, &*fn_def.return_arguments);
let input_arguments =
to_resolved_fn_args(call.param_list, &*fn_def.input_arguments);
let resolved_call = ResolvedCall {
uniform: call.uniform,
return_arguments,
name: call.func,
input_arguments,
};
let resolver = fn_defs.get_fn_sig_resolver(call.func)?;
let resolved_call = resolver.resolve_in_spirv_repr(call)?;
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let reresolved_call = resolved_call.visit(&mut visitor)?;
visitor.func.push(reresolved_call);
@ -1683,6 +1669,7 @@ fn to_ptx_impl_atomic_call(
array_init: Vec::new(),
},
],
shared_mem: None,
};
let func = Function {
func_decl: Rc::new(RefCell::new(func_decl)),
@ -1690,7 +1677,6 @@ fn to_ptx_impl_atomic_call(
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
uses_shared_mem: false,
};
entry.insert(Directive::Method(func));
fn_id
@ -1772,6 +1758,7 @@ fn to_ptx_impl_bfe_call(
array_init: Vec::new(),
},
],
shared_mem: None,
};
let func = Function {
func_decl: Rc::new(RefCell::new(func_decl)),
@ -1779,7 +1766,6 @@ fn to_ptx_impl_bfe_call(
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
uses_shared_mem: false,
};
entry.insert(Directive::Method(func));
fn_id
@ -1871,6 +1857,7 @@ fn to_ptx_impl_bfi_call(
array_init: Vec::new(),
},
],
shared_mem: None,
};
let func = Function {
func_decl: Rc::new(RefCell::new(func_decl)),
@ -1878,7 +1865,6 @@ fn to_ptx_impl_bfi_call(
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
uses_shared_mem: false,
};
entry.insert(Directive::Method(func));
fn_id
@ -2009,42 +1995,44 @@ fn normalize_predicates(
Ok(result)
}
/*
How do we handle arguments:
- input .params
.param .b64 in_arg
get turned into this SPIR-V:
%1 = OpFunctionParameter %ulong
%2 = OpVariable %%_ptr_Function_ulong Function
OpStore %2 %1
We do this for two reasons. One, common treatment for argument-declared
.param variables and .param variables inside function (we assume that
at SPIR-V level every .param is a pointer in Function storage class). Two,
PTX devs in their infinite wisdom decided that .reg arguments are writable
- input .regs
.reg .b64 in_arg
get turned into this SPIR-V:
%1 = OpFunctionParameter %ulong
%2 = OpVariable %%_ptr_Function_ulong Function
OpStore %2 %1
with the difference that %2 is defined as a variable and not temp
- output .regs
.reg .b64 out_arg
get just a variable declaration:
%2 = OpVariable %%_ptr_Function_ulong Function
- output .params
.param .b64 out_arg
get treated the same as input .params, because there's no difference
*/
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.return_arguments.iter() {
result.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: arg.v_type.clone(),
state_space: arg.state_space,
name: arg.name,
array_init: arg.array_init.clone(),
}));
}
for arg in fn_decl.input_arguments.iter_mut() {
let typ = arg.v_type.clone();
let state_space = arg.state_space;
let new_id = id_def.register_intermediate(Some((typ.clone(), state_space)));
result.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: arg.v_type.clone(),
state_space: arg.state_space,
name: arg.name,
array_init: Vec::new(),
}));
result.push(Statement::StoreVar(StoreVarDetails {
arg: ast::Arg2St {
src1: arg.name,
src2: new_id,
},
state_space,
typ,
member_index: None,
}));
arg.name = new_id;
insert_mem_ssa_argument(id_def, &mut result, arg);
}
for arg in fn_decl.return_arguments.iter() {
insert_mem_ssa_argument_reg_return(&mut result, arg);
}
for s in func {
match s {
@ -2054,22 +2042,26 @@ fn insert_mem_ssa_statements<'a, 'b>(
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
// TODO: handle multiple output args
if let &[out_param] = &fn_decl.return_arguments.as_slice() {
let (typ, space, _) = id_def.get_typed(out_param.name)?;
let new_id = id_def.register_intermediate(Some((typ.clone(), space)));
result.push(Statement::LoadVar(LoadVarDetails {
arg: ast::Arg2 {
dst: new_id,
src: out_param.name,
},
// TODO: ret with stateful conversion
state_space: new_todo!(),
typ: typ.clone(),
member_index: None,
}));
result.push(Statement::RetValue(d, new_id));
} else {
result.push(Statement::Instruction(ast::Instruction::Ret(d)))
match &fn_decl.return_arguments[..] {
[return_reg] => {
let new_id = id_def.register_intermediate(Some((
return_reg.v_type.clone(),
ast::StateSpace::Reg,
)));
result.push(Statement::LoadVar(LoadVarDetails {
arg: ast::Arg2 {
dst: new_id,
src: return_reg.name,
},
// TODO: ret with stateful conversion
state_space: ast::StateSpace::Reg,
typ: return_reg.v_type.clone(),
member_index: None,
}));
result.push(Statement::RetValue(d, new_id));
}
[] => result.push(Statement::Instruction(ast::Instruction::Ret(d))),
_ => unimplemented!(),
}
}
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
@ -2107,6 +2099,43 @@ fn insert_mem_ssa_statements<'a, 'b>(
Ok(result)
}
fn insert_mem_ssa_argument(
id_def: &mut NumericIdResolver,
func: &mut Vec<TypedStatement>,
arg: &mut ast::Variable<spirv::Word>,
) {
let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
func.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: arg.v_type.clone(),
state_space: ast::StateSpace::Reg,
name: arg.name,
array_init: Vec::new(),
}));
func.push(Statement::StoreVar(StoreVarDetails {
arg: ast::Arg2St {
src1: arg.name,
src2: new_id,
},
typ: arg.v_type.clone(),
member_index: None,
}));
arg.name = new_id;
}
fn insert_mem_ssa_argument_reg_return(
func: &mut Vec<TypedStatement>,
arg: &ast::Variable<spirv::Word>,
) {
func.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: arg.v_type.clone(),
state_space: arg.state_space,
name: arg.name,
array_init: arg.array_init.clone(),
}));
}
trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
fn visit(
self,
@ -2202,7 +2231,6 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
src1: symbol,
src2: generated_id,
},
state_space: ast::StateSpace::Reg,
typ: var_type,
member_index: member_index.map(|(idx, _)| idx),
}));
@ -4162,10 +4190,10 @@ fn emit_load_var(
Ok(())
}
fn normalize_identifiers<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>,
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
fn normalize_identifiers<'input, 'b>(
id_defs: &mut FnStringIdResolver<'input, 'b>,
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
func: Vec<ast::Statement<ast::ParsedArgParams<'input>>>,
) -> Result<Vec<NormalizedStatement>, TranslateError> {
for s in func.iter() {
match s {
@ -4796,12 +4824,92 @@ impl SpecialRegistersMap {
}
}
struct FnSigMapper<'input> {
// true - stays as return argument
// false - is moved to input argument
return_param_args: Vec<bool>,
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
}
impl<'input> FnSigMapper<'input> {
fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, spirv::Word>) -> Self {
let return_param_args = method
.return_arguments
.iter()
.map(|a| a.state_space != ast::StateSpace::Param)
.collect::<Vec<_>>();
let mut new_return_arguments = Vec::new();
for arg in method.return_arguments.into_iter() {
if arg.state_space == ast::StateSpace::Param {
method.input_arguments.push(arg);
} else {
new_return_arguments.push(arg);
}
}
method.return_arguments = new_return_arguments;
FnSigMapper {
return_param_args,
func_decl: Rc::new(RefCell::new(method)),
}
}
fn resolve_in_spirv_repr(
&self,
call_inst: ast::CallInst<NormalizedArgParams>,
) -> Result<ResolvedCall<NormalizedArgParams>, TranslateError> {
let func_decl = (*self.func_decl).borrow();
let mut return_arguments = Vec::new();
let mut input_arguments = call_inst
.param_list
.into_iter()
.zip(func_decl.input_arguments.iter())
.map(|(id, var)| (id, var.v_type.clone(), var.state_space))
.collect::<Vec<_>>();
let mut func_decl_return_iter = func_decl.return_arguments.iter();
let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter();
for (idx, id) in call_inst.ret_params.iter().enumerate() {
let stays_as_return = match self.return_param_args.get(idx) {
Some(x) => *x,
None => return Err(TranslateError::MismatchedType),
};
if stays_as_return {
if let Some(var) = func_decl_return_iter.next() {
return_arguments.push((*id, var.v_type.clone(), var.state_space));
} else {
return Err(TranslateError::MismatchedType);
}
} else {
if let Some(var) = func_decl_input_iter.next() {
input_arguments.push((
ast::Operand::Reg(*id),
var.v_type.clone(),
var.state_space,
));
} else {
return Err(TranslateError::MismatchedType);
}
}
}
if return_arguments.len() != func_decl.return_arguments.len()
|| input_arguments.len() != func_decl.input_arguments.len()
{
return Err(TranslateError::MismatchedType);
}
Ok(ResolvedCall {
return_arguments,
input_arguments,
uniform: call_inst.uniform,
name: call_inst.func,
})
}
}
struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>,
variables_type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: SpecialRegistersMap,
fns: HashMap<spirv::Word, Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>>,
fns: HashMap<spirv::Word, FnSigMapper<'input>>,
}
impl<'input> GlobalStringIdResolver<'input> {
@ -4885,45 +4993,36 @@ impl<'input> GlobalStringIdResolver<'input> {
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
};
let new_fn_decl = Rc::new(RefCell::new(ast::MethodDeclaration {
let fn_decl = ast::MethodDeclaration {
return_arguments,
name,
input_arguments,
}));
self.fns.insert(name_id, Rc::clone(&new_fn_decl));
shared_mem: None,
};
let new_fn_decl = if !fn_decl.name.is_kernel() {
let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl);
let new_fn_decl = resolver.func_decl.clone();
self.fns.insert(name_id, resolver);
new_fn_decl
} else {
Rc::new(RefCell::new(fn_decl))
};
Ok((
fn_resolver,
GlobalFnDeclResolver {
variables: &self.variables,
fns: &self.fns,
},
GlobalFnDeclResolver { fns: &self.fns },
new_fn_decl,
))
}
}
pub struct GlobalFnDeclResolver<'input, 'a> {
variables: &'a HashMap<Cow<'input, str>, spirv::Word>,
fns: &'a HashMap<spirv::Word, Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>>,
fns: &'a HashMap<spirv::Word, FnSigMapper<'input>>,
}
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
fn get_fn_decl(
&self,
id: spirv::Word,
) -> Result<&Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>, TranslateError> {
fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> {
self.fns.get(&id).ok_or(TranslateError::UnknownSymbol)
}
fn get_fn_decl_str(
&self,
id: &str,
) -> Result<&'a Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>, TranslateError> {
match self.variables.get(id).map(|var_id| self.fns.get(var_id)) {
Some(Some(fn_d)) => Ok(fn_d),
_ => Err(TranslateError::UnknownSymbol),
}
}
}
struct FnStringIdResolver<'input, 'b> {
@ -5209,7 +5308,6 @@ struct LoadVarDetails {
struct StoreVarDetails {
arg: ast::Arg2St<ExpandedArgParams>,
typ: ast::Type,
state_space: ast::StateSpace,
member_index: Option<u8>,
}
@ -5300,7 +5398,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
self,
visitor: &mut V,
) -> Result<ResolvedCall<To>, TranslateError> {
let ret_params = self
let return_arguments = self
.return_arguments
.into_iter()
.map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
@ -5324,7 +5422,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
},
None,
)?;
let param_list = self
let input_arguments = self
.input_arguments
.into_iter()
.map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
@ -5342,9 +5440,9 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
.collect::<Result<Vec<_>, _>>()?;
Ok(ResolvedCall {
uniform: self.uniform,
return_arguments: ret_params,
return_arguments,
name: func,
input_arguments: param_list,
input_arguments,
})
}
}
@ -5485,7 +5583,6 @@ struct Function<'input> {
pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
pub globals: Vec<ast::Variable<spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>,
pub uses_shared_mem: bool,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
}
@ -7185,6 +7282,19 @@ fn default_implicit_conversion_space(
},
_ => Err(TranslateError::MismatchedType),
}
} else if instruction_space.is_compatible(ast::StateSpace::Reg) {
match instruction_type {
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
if operand_space == *instruction_ptr_space =>
{
if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
Ok(Some(ConversionKind::PtrToPtr))
} else {
Ok(None)
}
}
_ => Err(TranslateError::MismatchedType),
}
} else {
Err(TranslateError::MismatchedType)
}
@ -7432,6 +7542,15 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
}
}
impl<'input, ID> ast::MethodName<'input, ID> {
fn is_kernel(&self) -> bool {
match self {
ast::MethodName::Kernel(..) => true,
ast::MethodName::Func(..) => false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;