diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 3ad61e5..a0bb023 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -96,6 +96,7 @@ pub struct MethodDeclaration<'input, ID> { pub return_arguments: Vec>, pub name: MethodName<'input, ID>, pub input_arguments: Vec>, + pub shared_mem: Option>, } pub struct Function<'a, ID, S> { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 2253f85..e8370cd 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -392,12 +392,12 @@ MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = { ".entry" => { let return_arguments = Vec::new(); let name = ast::MethodName::Kernel(name); - ast::MethodDeclaration{ return_arguments, name, input_arguments } + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } }, ".func" => { let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); let name = ast::MethodName::Func(name); - ast::MethodDeclaration{ return_arguments, name, input_arguments } + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } } }; diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 90a28b7..6d5d5bc 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -562,7 +562,6 @@ fn emit_directives<'input>( call_map, &directives, kernel_info, - f.uses_shared_mem, )?; for t in f.tuning.iter() { match *t { @@ -1038,10 +1037,9 @@ fn emit_function_header<'a>( call_map: &HashMap<&'a str, HashSet>, direcitves: &[Directive], kernel_info: &mut HashMap, - uses_shared_mem: bool, ) -> Result { if let ast::MethodName::Kernel(name) = func_decl.name { - let input_args = if !uses_shared_mem { + let input_args = if func_decl.shared_mem.is_none() { func_decl.input_arguments.as_slice() } else { &func_decl.input_arguments[0..func_decl.input_arguments.len() - 1] @@ -1054,7 +1052,7 @@ fn emit_function_header<'a>( name.to_string(), KernelInfo { arguments_sizes: args_lens, - uses_shared_mem: uses_shared_mem, + uses_shared_mem: func_decl.shared_mem.is_some(), }, ); } @@ -1218,7 +1216,7 @@ fn rename_fn_params<'a, 'b>( ) -> Vec> { args.iter() .map(|a| ast::Variable { - name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), false), + name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true), v_type: a.v_type.clone(), state_space: a.state_space, align: a.align, @@ -1245,7 +1243,6 @@ fn to_ssa<'input, 'b>( globals: Vec::new(), import_as: None, tuning, - uses_shared_mem: false, }) } }; @@ -1276,7 +1273,6 @@ fn to_ssa<'input, 'b>( body: Some(f_body), import_as: None, tuning, - uses_shared_mem: false, }) } @@ -1529,18 +1525,8 @@ fn convert_to_typed_statements( match s { Statement::Instruction(inst) => match inst { ast::Instruction::Call(call) => { - // TODO: error out if lengths don't match - let fn_def = (**fn_defs.get_fn_decl(call.func)?).borrow(); - let return_arguments = - to_resolved_fn_args(call.ret_params, &*fn_def.return_arguments); - let input_arguments = - to_resolved_fn_args(call.param_list, &*fn_def.input_arguments); - let resolved_call = ResolvedCall { - uniform: call.uniform, - return_arguments, - name: call.func, - input_arguments, - }; + let resolver = fn_defs.get_fn_sig_resolver(call.func)?; + let resolved_call = resolver.resolve_in_spirv_repr(call)?; let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let reresolved_call = resolved_call.visit(&mut visitor)?; visitor.func.push(reresolved_call); @@ -1683,6 +1669,7 @@ fn to_ptx_impl_atomic_call( array_init: Vec::new(), }, ], + shared_mem: None, }; let func = Function { func_decl: Rc::new(RefCell::new(func_decl)), @@ -1690,7 +1677,6 @@ fn to_ptx_impl_atomic_call( body: None, import_as: Some(entry.key().clone()), tuning: Vec::new(), - uses_shared_mem: false, }; entry.insert(Directive::Method(func)); fn_id @@ -1772,6 +1758,7 @@ fn to_ptx_impl_bfe_call( array_init: Vec::new(), }, ], + shared_mem: None, }; let func = Function { func_decl: Rc::new(RefCell::new(func_decl)), @@ -1779,7 +1766,6 @@ fn to_ptx_impl_bfe_call( body: None, import_as: Some(entry.key().clone()), tuning: Vec::new(), - uses_shared_mem: false, }; entry.insert(Directive::Method(func)); fn_id @@ -1871,6 +1857,7 @@ fn to_ptx_impl_bfi_call( array_init: Vec::new(), }, ], + shared_mem: None, }; let func = Function { func_decl: Rc::new(RefCell::new(func_decl)), @@ -1878,7 +1865,6 @@ fn to_ptx_impl_bfi_call( body: None, import_as: Some(entry.key().clone()), tuning: Vec::new(), - uses_shared_mem: false, }; entry.insert(Directive::Method(func)); fn_id @@ -2009,42 +1995,44 @@ fn normalize_predicates( Ok(result) } +/* + How do we handle arguments: + - input .params + .param .b64 in_arg + get turned into this SPIR-V: + %1 = OpFunctionParameter %ulong + %2 = OpVariable %%_ptr_Function_ulong Function + OpStore %2 %1 + We do this for two reasons. One, common treatment for argument-declared + .param variables and .param variables inside function (we assume that + at SPIR-V level every .param is a pointer in Function storage class). Two, + PTX devs in their infinite wisdom decided that .reg arguments are writable + - input .regs + .reg .b64 in_arg + get turned into this SPIR-V: + %1 = OpFunctionParameter %ulong + %2 = OpVariable %%_ptr_Function_ulong Function + OpStore %2 %1 + with the difference that %2 is defined as a variable and not temp + - output .regs + .reg .b64 out_arg + get just a variable declaration: + %2 = OpVariable %%_ptr_Function_ulong Function + - output .params + .param .b64 out_arg + get treated the same as input .params, because there's no difference +*/ fn insert_mem_ssa_statements<'a, 'b>( func: Vec, id_def: &mut NumericIdResolver, fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>, ) -> Result, TranslateError> { let mut result = Vec::with_capacity(func.len()); - for arg in fn_decl.return_arguments.iter() { - result.push(Statement::Variable(ast::Variable { - align: arg.align, - v_type: arg.v_type.clone(), - state_space: arg.state_space, - name: arg.name, - array_init: arg.array_init.clone(), - })); - } for arg in fn_decl.input_arguments.iter_mut() { - let typ = arg.v_type.clone(); - let state_space = arg.state_space; - let new_id = id_def.register_intermediate(Some((typ.clone(), state_space))); - result.push(Statement::Variable(ast::Variable { - align: arg.align, - v_type: arg.v_type.clone(), - state_space: arg.state_space, - name: arg.name, - array_init: Vec::new(), - })); - result.push(Statement::StoreVar(StoreVarDetails { - arg: ast::Arg2St { - src1: arg.name, - src2: new_id, - }, - state_space, - typ, - member_index: None, - })); - arg.name = new_id; + insert_mem_ssa_argument(id_def, &mut result, arg); + } + for arg in fn_decl.return_arguments.iter() { + insert_mem_ssa_argument_reg_return(&mut result, arg); } for s in func { match s { @@ -2054,22 +2042,26 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Instruction(inst) => match inst { ast::Instruction::Ret(d) => { // TODO: handle multiple output args - if let &[out_param] = &fn_decl.return_arguments.as_slice() { - let (typ, space, _) = id_def.get_typed(out_param.name)?; - let new_id = id_def.register_intermediate(Some((typ.clone(), space))); - result.push(Statement::LoadVar(LoadVarDetails { - arg: ast::Arg2 { - dst: new_id, - src: out_param.name, - }, - // TODO: ret with stateful conversion - state_space: new_todo!(), - typ: typ.clone(), - member_index: None, - })); - result.push(Statement::RetValue(d, new_id)); - } else { - result.push(Statement::Instruction(ast::Instruction::Ret(d))) + match &fn_decl.return_arguments[..] { + [return_reg] => { + let new_id = id_def.register_intermediate(Some(( + return_reg.v_type.clone(), + ast::StateSpace::Reg, + ))); + result.push(Statement::LoadVar(LoadVarDetails { + arg: ast::Arg2 { + dst: new_id, + src: return_reg.name, + }, + // TODO: ret with stateful conversion + state_space: ast::StateSpace::Reg, + typ: return_reg.v_type.clone(), + member_index: None, + })); + result.push(Statement::RetValue(d, new_id)); + } + [] => result.push(Statement::Instruction(ast::Instruction::Ret(d))), + _ => unimplemented!(), } } inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?, @@ -2107,6 +2099,43 @@ fn insert_mem_ssa_statements<'a, 'b>( Ok(result) } +fn insert_mem_ssa_argument( + id_def: &mut NumericIdResolver, + func: &mut Vec, + arg: &mut ast::Variable, +) { + let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space))); + func.push(Statement::Variable(ast::Variable { + align: arg.align, + v_type: arg.v_type.clone(), + state_space: ast::StateSpace::Reg, + name: arg.name, + array_init: Vec::new(), + })); + func.push(Statement::StoreVar(StoreVarDetails { + arg: ast::Arg2St { + src1: arg.name, + src2: new_id, + }, + typ: arg.v_type.clone(), + member_index: None, + })); + arg.name = new_id; +} + +fn insert_mem_ssa_argument_reg_return( + func: &mut Vec, + arg: &ast::Variable, +) { + func.push(Statement::Variable(ast::Variable { + align: arg.align, + v_type: arg.v_type.clone(), + state_space: arg.state_space, + name: arg.name, + array_init: arg.array_init.clone(), + })); +} + trait Visitable: Sized { fn visit( self, @@ -2202,7 +2231,6 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { src1: symbol, src2: generated_id, }, - state_space: ast::StateSpace::Reg, typ: var_type, member_index: member_index.map(|(idx, _)| idx), })); @@ -4162,10 +4190,10 @@ fn emit_load_var( Ok(()) } -fn normalize_identifiers<'a, 'b>( - id_defs: &mut FnStringIdResolver<'a, 'b>, - fn_defs: &GlobalFnDeclResolver<'a, 'b>, - func: Vec>>, +fn normalize_identifiers<'input, 'b>( + id_defs: &mut FnStringIdResolver<'input, 'b>, + fn_defs: &GlobalFnDeclResolver<'input, 'b>, + func: Vec>>, ) -> Result, TranslateError> { for s in func.iter() { match s { @@ -4796,12 +4824,92 @@ impl SpecialRegistersMap { } } +struct FnSigMapper<'input> { + // true - stays as return argument + // false - is moved to input argument + return_param_args: Vec, + func_decl: Rc>>, +} + +impl<'input> FnSigMapper<'input> { + fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, spirv::Word>) -> Self { + let return_param_args = method + .return_arguments + .iter() + .map(|a| a.state_space != ast::StateSpace::Param) + .collect::>(); + let mut new_return_arguments = Vec::new(); + for arg in method.return_arguments.into_iter() { + if arg.state_space == ast::StateSpace::Param { + method.input_arguments.push(arg); + } else { + new_return_arguments.push(arg); + } + } + method.return_arguments = new_return_arguments; + FnSigMapper { + return_param_args, + func_decl: Rc::new(RefCell::new(method)), + } + } + + fn resolve_in_spirv_repr( + &self, + call_inst: ast::CallInst, + ) -> Result, TranslateError> { + let func_decl = (*self.func_decl).borrow(); + let mut return_arguments = Vec::new(); + let mut input_arguments = call_inst + .param_list + .into_iter() + .zip(func_decl.input_arguments.iter()) + .map(|(id, var)| (id, var.v_type.clone(), var.state_space)) + .collect::>(); + let mut func_decl_return_iter = func_decl.return_arguments.iter(); + let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter(); + for (idx, id) in call_inst.ret_params.iter().enumerate() { + let stays_as_return = match self.return_param_args.get(idx) { + Some(x) => *x, + None => return Err(TranslateError::MismatchedType), + }; + if stays_as_return { + if let Some(var) = func_decl_return_iter.next() { + return_arguments.push((*id, var.v_type.clone(), var.state_space)); + } else { + return Err(TranslateError::MismatchedType); + } + } else { + if let Some(var) = func_decl_input_iter.next() { + input_arguments.push(( + ast::Operand::Reg(*id), + var.v_type.clone(), + var.state_space, + )); + } else { + return Err(TranslateError::MismatchedType); + } + } + } + if return_arguments.len() != func_decl.return_arguments.len() + || input_arguments.len() != func_decl.input_arguments.len() + { + return Err(TranslateError::MismatchedType); + } + Ok(ResolvedCall { + return_arguments, + input_arguments, + uniform: call_inst.uniform, + name: call_inst.func, + }) + } +} + struct GlobalStringIdResolver<'input> { current_id: spirv::Word, variables: HashMap, spirv::Word>, variables_type_check: HashMap>, special_registers: SpecialRegistersMap, - fns: HashMap>>>, + fns: HashMap>, } impl<'input> GlobalStringIdResolver<'input> { @@ -4885,45 +4993,36 @@ impl<'input> GlobalStringIdResolver<'input> { ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), ast::MethodName::Func(_) => ast::MethodName::Func(name_id), }; - let new_fn_decl = Rc::new(RefCell::new(ast::MethodDeclaration { + let fn_decl = ast::MethodDeclaration { return_arguments, name, input_arguments, - })); - self.fns.insert(name_id, Rc::clone(&new_fn_decl)); + shared_mem: None, + }; + let new_fn_decl = if !fn_decl.name.is_kernel() { + let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl); + let new_fn_decl = resolver.func_decl.clone(); + self.fns.insert(name_id, resolver); + new_fn_decl + } else { + Rc::new(RefCell::new(fn_decl)) + }; Ok(( fn_resolver, - GlobalFnDeclResolver { - variables: &self.variables, - fns: &self.fns, - }, + GlobalFnDeclResolver { fns: &self.fns }, new_fn_decl, )) } } pub struct GlobalFnDeclResolver<'input, 'a> { - variables: &'a HashMap, spirv::Word>, - fns: &'a HashMap>>>, + fns: &'a HashMap>, } impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { - fn get_fn_decl( - &self, - id: spirv::Word, - ) -> Result<&Rc>>, TranslateError> { + fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> { self.fns.get(&id).ok_or(TranslateError::UnknownSymbol) } - - fn get_fn_decl_str( - &self, - id: &str, - ) -> Result<&'a Rc>>, TranslateError> { - match self.variables.get(id).map(|var_id| self.fns.get(var_id)) { - Some(Some(fn_d)) => Ok(fn_d), - _ => Err(TranslateError::UnknownSymbol), - } - } } struct FnStringIdResolver<'input, 'b> { @@ -5209,7 +5308,6 @@ struct LoadVarDetails { struct StoreVarDetails { arg: ast::Arg2St, typ: ast::Type, - state_space: ast::StateSpace, member_index: Option, } @@ -5300,7 +5398,7 @@ impl> ResolvedCall { self, visitor: &mut V, ) -> Result, TranslateError> { - let ret_params = self + let return_arguments = self .return_arguments .into_iter() .map::, _>(|(id, typ, space)| { @@ -5324,7 +5422,7 @@ impl> ResolvedCall { }, None, )?; - let param_list = self + let input_arguments = self .input_arguments .into_iter() .map::, _>(|(id, typ, space)| { @@ -5342,9 +5440,9 @@ impl> ResolvedCall { .collect::, _>>()?; Ok(ResolvedCall { uniform: self.uniform, - return_arguments: ret_params, + return_arguments, name: func, - input_arguments: param_list, + input_arguments, }) } } @@ -5485,7 +5583,6 @@ struct Function<'input> { pub func_decl: Rc>>, pub globals: Vec>, pub body: Option>, - pub uses_shared_mem: bool, import_as: Option, tuning: Vec, } @@ -7185,6 +7282,19 @@ fn default_implicit_conversion_space( }, _ => Err(TranslateError::MismatchedType), } + } else if instruction_space.is_compatible(ast::StateSpace::Reg) { + match instruction_type { + ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) + if operand_space == *instruction_ptr_space => + { + if operand_type != &ast::Type::Scalar(*instruction_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + _ => Err(TranslateError::MismatchedType), + } } else { Err(TranslateError::MismatchedType) } @@ -7432,6 +7542,15 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> { } } +impl<'input, ID> ast::MethodName<'input, ID> { + fn is_kernel(&self) -> bool { + match self { + ast::MethodName::Kernel(..) => true, + ast::MethodName::Func(..) => false, + } + } +} + #[cfg(test)] mod tests { use super::*;