diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs new file mode 100644 index 0000000..eb03866 --- /dev/null +++ b/ptx/src/pass/expand_arguments.rs @@ -0,0 +1,181 @@ +use super::*; +use ptx_parser as ast; + +pub(super) fn run<'a, 'b>( + func: Vec, + id_def: &'b mut MutableNumericIdResolver<'a>, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for s in func { + match s { + Statement::Label(id) => result.push(Statement::Label(id)), + Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), + Statement::LoadVar(details) => result.push(Statement::LoadVar(details)), + Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), + Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), + Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), + Statement::Constant(c) => result.push(Statement::Constant(c)), + Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)), + s => { + let (new_statement, post_stmts) = { + let mut visitor = FlattenArguments::new(&mut result, id_def); + (s.visit_map(&mut visitor)?, visitor.post_stmts) + }; + result.push(new_statement); + result.extend(post_stmts); + } + } + } + Ok(result) +} + +struct FlattenArguments<'a, 'b> { + func: &'b mut Vec, + id_def: &'b mut MutableNumericIdResolver<'a>, + post_stmts: Vec, +} + +impl<'a, 'b> FlattenArguments<'a, 'b> { + fn new( + func: &'b mut Vec, + id_def: &'b mut MutableNumericIdResolver<'a>, + ) -> Self { + FlattenArguments { + func, + id_def, + post_stmts: Vec::new(), + } + } + + fn reg(&mut self, name: SpirvWord) -> Result { + Ok(name) + } + + fn reg_offset( + &mut self, + reg: SpirvWord, + offset: i32, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + _is_dst: bool, + ) -> Result { + let (type_, state_space) = if let Some((type_, state_space)) = type_space { + (type_, state_space) + } else { + return Err(TranslateError::UntypedSymbol); + }; + if state_space != ast::StateSpace::Reg && state_space != ast::StateSpace::Sreg { + let (reg_type, reg_space) = self.id_def.get_typed(reg)?; + if !state_is_compatible(reg_space, ast::StateSpace::Reg) { + return Err(TranslateError::MismatchedType); + } + let reg_scalar_type = match reg_type { + ast::Type::Scalar(underlying_type) => underlying_type, + _ => return Err(TranslateError::MismatchedType), + }; + let id_constant_stmt = self + .id_def + .register_intermediate(reg_type.clone(), ast::StateSpace::Reg); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: reg_scalar_type, + value: ast::ImmediateValue::S64(offset as i64), + })); + let arith_details = match reg_scalar_type.kind() { + ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, + saturate: false, + }), + ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => { + ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, + saturate: false, + }) + } + _ => return Err(error_unreachable()), + }; + let id_add_result = self.id_def.register_intermediate(reg_type, state_space); + self.func + .push(Statement::Instruction(ast::Instruction::Add { + data: arith_details, + arguments: ast::AddArgs { + dst: id_add_result, + src1: reg, + src2: id_constant_stmt, + }, + })); + Ok(id_add_result) + } else { + let id_constant_stmt = self.id_def.register_intermediate( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: ast::ScalarType::S64, + value: ast::ImmediateValue::S64(offset as i64), + })); + let dst = self + .id_def + .register_intermediate(type_.clone(), state_space); + self.func.push(Statement::PtrAccess(PtrAccess { + underlying_type: type_.clone(), + state_space: state_space, + dst, + ptr_src: reg, + offset_src: id_constant_stmt, + })); + Ok(dst) + } + } + + fn immediate( + &mut self, + value: ast::ImmediateValue, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + ) -> Result { + let (scalar_t, state_space) = + if let Some((ast::Type::Scalar(scalar), state_space)) = type_space { + (*scalar, state_space) + } else { + return Err(TranslateError::UntypedSymbol); + }; + let id = self + .id_def + .register_intermediate(ast::Type::Scalar(scalar_t), state_space); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id, + typ: scalar_t, + value, + })); + Ok(id) + } +} + +impl<'a, 'b> ast::VisitorMap for FlattenArguments<'a, 'b> { + fn visit( + &mut self, + args: TypedOperand, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + match args { + TypedOperand::Reg(r) => self.reg(r), + TypedOperand::Imm(x) => self.immediate(x, type_space), + TypedOperand::RegOffset(reg, offset) => { + self.reg_offset(reg, offset, type_space, is_dst) + } + TypedOperand::VecMember(..) => Err(error_unreachable()), + } + } + + fn visit_ident( + &mut self, + name: ::Ident, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + _is_dst: bool, + _relaxed_type_check: bool, + ) -> Result<::Ident, TranslateError> { + self.reg(name) + } +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index f6b700b..896a34a 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -11,6 +11,7 @@ use std::{ mod convert_to_stateful_memory_access; mod convert_to_typed; +mod expand_arguments; mod fix_special_registers; mod insert_mem_ssa_statements; mod normalize_identifiers; @@ -181,10 +182,10 @@ fn to_ssa<'input, 'b>( &mut numeric_id_defs, &mut (*func_decl).borrow_mut(), )?; + let mut numeric_id_defs = numeric_id_defs.finish(); + let expanded_statements = expand_arguments::run(ssa_statements, &mut numeric_id_defs)?; todo!() /* - let mut numeric_id_defs = numeric_id_defs.finish(); - let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.unmut(); @@ -743,7 +744,7 @@ impl> Statement, T> { fn visit_map, Err>( self, visitor: &mut impl ast::VisitorMap, - ) -> std::result::Result, T>, Err> { + ) -> std::result::Result, To>, Err> { Ok(match self { Statement::Instruction(i) => { return ast::visit_map(i, visitor).map(Statement::Instruction) @@ -883,6 +884,12 @@ impl> Statement, T> { false, false, )?; + let offset_src = visitor.visit( + offset_src, + Some((&underlying_type, state_space)), + false, + false, + )?; Statement::PtrAccess(PtrAccess { underlying_type, state_space,