diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c31e0a2..db062db 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -981,7 +981,6 @@ fn compute_denorm_information<'input>( Statement::Conversion(_) => {} Statement::Constant(_) => {} Statement::RetValue(_, _) => {} - Statement::Undef(_, _) => {} Statement::Label(_) => {} Statement::Variable(_) => {} Statement::PtrAccess { .. } => {} @@ -1845,7 +1844,6 @@ fn normalize_labels( | Statement::Conversion(..) | Statement::Constant(..) | Statement::Label(..) - | Statement::Undef(..) | Statement::PtrAccess { .. } | Statement::RepackVector(..) => {} } @@ -2071,46 +2069,158 @@ impl< } } -fn insert_mem_ssa_statement_default<'a, S: Visitable>( - id_def: &mut NumericIdResolver, - result: &mut Vec, - stmt: S, -) -> Result<(), TranslateError> { - let mut post_statements = Vec::new(); - let new_statement = - stmt.visit(&mut |desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>| { - if expected_type.is_none() { - return Ok(desc.op); - }; - let (var_type, is_variable) = id_def.get_typed(desc.op)?; - if !is_variable { - return Ok(desc.op); +struct InsertMemSSAVisitor<'a, 'input> { + id_def: &'a mut NumericIdResolver<'input>, + func: &'a mut Vec, + post_statements: Vec, +} + +impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { + fn symbol( + &mut self, + desc: ArgumentDescriptor<(spirv::Word, Option)>, + expected_type: Option<&ast::Type>, + ) -> Result { + let symbol = desc.op.0; + if expected_type.is_none() { + return Ok(symbol); + }; + let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?; + if !is_variable { + return Ok(symbol); + }; + let member_index = match desc.op.1 { + Some(idx) => { + match var_type { + ast::Type::Vector(scalar_t, _) => { + var_type = ast::Type::Scalar(scalar_t); + } + _ => return Err(TranslateError::MismatchedType), + } + Some((idx, self.id_def.special_registers.contains_key(&symbol))) } - let generated_id = id_def.new_non_variable(Some(var_type.clone())); - if !desc.is_dst { - result.push(Statement::LoadVar(LoadVarDetails { - arg: Arg2 { - dst: generated_id, - src: desc.op, - }, - typ: var_type, - member_index: None, - })); - } else { - post_statements.push(Statement::StoreVar(StoreVarDetails { + None => None, + }; + let generated_id = self.id_def.new_non_variable(Some(var_type.clone())); + if !desc.is_dst { + self.func.push(Statement::LoadVar(LoadVarDetails { + arg: Arg2 { + dst: generated_id, + src: symbol, + }, + typ: var_type, + member_index, + })); + } else { + self.post_statements + .push(Statement::StoreVar(StoreVarDetails { arg: Arg2St { - src1: desc.op, + src1: symbol, src2: generated_id, }, typ: var_type, - member_index: None, + member_index, })); + } + Ok(generated_id) + } +} + +impl<'a, 'input> ArgumentMapVisitor + for InsertMemSSAVisitor<'a, 'input> +{ + fn id( + &mut self, + desc: ArgumentDescriptor, + typ: Option<&ast::Type>, + ) -> Result { + self.symbol(desc.new_op((desc.op, None)), typ) + } + + fn dst_operand( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::DstOperand::Reg(reg) => { + ast::DstOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) } - Ok(generated_id) - })?; - result.push(new_statement); - result.append(&mut post_statements); + ast::DstOperand::VecMember(symbol, index) => { + ast::DstOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) + } + }) + } + + fn src_operand( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::SrcOperand::Reg(reg) => { + ast::SrcOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) + } + ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset( + self.symbol(desc.new_op((reg, None)), Some(typ))?, + offset, + ), + op @ ast::SrcOperand::Imm(..) => op, + ast::SrcOperand::VecMember(symbol, index) => { + ast::SrcOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) + } + }) + } + + fn dst_operand_vec( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::DstOperand::Reg(reg) => { + ast::DstOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) + } + ast::DstOperand::VecMember(symbol, index) => { + ast::DstOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) + } + }) + } + + fn src_operand_vec( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result, TranslateError> { + Ok(match desc.op { + ast::SrcOperand::Reg(reg) => { + ast::SrcOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) + } + ast::SrcOperand::RegOffset(reg, offset) => ast::SrcOperand::RegOffset( + self.symbol(desc.new_op((reg, None)), Some(typ))?, + offset, + ), + op @ ast::SrcOperand::Imm(..) => op, + ast::SrcOperand::VecMember(symbol, index) => { + ast::SrcOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) + } + }) + } +} + +fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable>( + id_def: &'a mut NumericIdResolver<'input>, + func: &'a mut Vec, + stmt: S, +) -> Result<(), TranslateError> { + let mut visitor = InsertMemSSAVisitor { + id_def, + func, + post_statements: Vec::new(), + }; + let new_stmt = stmt.visit(&mut visitor)?; + visitor.func.push(new_stmt); + visitor.func.extend(visitor.post_statements); Ok(()) } @@ -2162,7 +2272,7 @@ fn expand_arguments<'a, 'b>( Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), - Statement::Constant(_) | Statement::Undef(_, _) => return Err(error_unreachable()), + Statement::Constant(_) => return Err(error_unreachable()), } } Ok(result) @@ -2472,7 +2582,6 @@ fn insert_implicit_conversions( | s @ Statement::Variable(_) | s @ Statement::LoadVar(..) | s @ Statement::StoreVar(..) - | s @ Statement::Undef(_, _) | s @ Statement::RetValue(_, _) => result.push(s), } } @@ -3049,26 +3158,66 @@ fn emit_function_body_ops( builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; } }, - Statement::LoadVar(LoadVarDetails { - arg, - typ, - member_index, - }) => { - let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); - builder.load(type_id, Some(arg.dst), arg.src, None, [])?; + Statement::LoadVar(details) => { + let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone())); + let src = match details.member_index { + Some((index, is_sreg)) => { + let storage_class = if is_sreg { + spirv::StorageClass::Input + } else { + spirv::StorageClass::Function + }; + let result_ptr_type = map.get_or_add( + builder, + SpirvType::new_pointer(details.typ.clone(), storage_class), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + builder.in_bounds_access_chain( + result_ptr_type, + None, + details.arg.src, + &[index_spirv], + )? + } + None => details.arg.src, + }; + builder.load(result_type, Some(details.arg.dst), src, None, [])?; } - Statement::StoreVar(StoreVarDetails { - arg, member_index, .. - }) => { - builder.store(arg.src1, arg.src2, None, [])?; + Statement::StoreVar(details) => { + let dst_ptr = match details.member_index { + Some((index, is_sreg)) => { + let storage_class = if is_sreg { + spirv::StorageClass::Input + } else { + spirv::StorageClass::Function + }; + let result_ptr_type = map.get_or_add( + builder, + SpirvType::new_pointer(details.typ.clone(), storage_class), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + builder.in_bounds_access_chain( + result_ptr_type, + None, + details.arg.src1, + &[index_spirv], + )? + } + None => details.arg.src1, + }; + builder.store(dst_ptr, details.arg.src2, None, [])?; } Statement::RetValue(_, id) => { builder.ret_value(*id)?; } - Statement::Undef(t, id) => { - let result_type = map.get_or_add(builder, SpirvType::from(t.clone())); - builder.undef(result_type, Some(*id)); - } Statement::PtrAccess(PtrAccess { underlying_type, state_space, @@ -4870,7 +5019,6 @@ enum Statement { Conversion(ImplicitConversion), Constant(ConstantDefinition), RetValue(ast::RetData, spirv::Word), - Undef(ast::Type, spirv::Word), PtrAccess(PtrAccess

), RepackVector(RepackVectorDetails), } @@ -4932,10 +5080,6 @@ impl ExpandedStatement { let id = f(id, false); Statement::RetValue(data, id) } - Statement::Undef(typ, id) => { - let id = f(id, true); - Statement::Undef(typ, id) - } Statement::PtrAccess(PtrAccess { underlying_type, state_space, @@ -4962,13 +5106,15 @@ impl ExpandedStatement { struct LoadVarDetails { arg: ast::Arg2, typ: ast::Type, - member_index: Option, + // (index, is_sreg) + member_index: Option<(u8, bool)>, } struct StoreVarDetails { arg: ast::Arg2St, typ: ast::Type, - member_index: Option, + // (index, is_sreg) + member_index: Option<(u8, bool)>, } struct RepackVectorDetails { @@ -5261,29 +5407,6 @@ impl ArgParamsEx for ExpandedArgParams { } } -#[derive(Copy, Clone)] -pub enum StateSpace { - Reg, - Const, - Global, - Local, - Shared, - Param, -} - -impl From for StateSpace { - fn from(ss: ast::StateSpace) -> Self { - match ss { - ast::StateSpace::Reg => StateSpace::Reg, - ast::StateSpace::Const => StateSpace::Const, - ast::StateSpace::Global => StateSpace::Global, - ast::StateSpace::Local => StateSpace::Local, - ast::StateSpace::Shared => StateSpace::Shared, - ast::StateSpace::Param => StateSpace::Param, - } - } -} - enum Directive<'input> { Variable(ast::Variable), Method(Function<'input>), @@ -5388,7 +5511,7 @@ where fn dst_operand( &mut self, desc: ArgumentDescriptor>, - typ: &ast::Type, + _: &ast::Type, ) -> Result, TranslateError> { Ok(match desc.op { ast::DstOperand::Reg(id) => ast::DstOperand::Reg(self(id)?), @@ -5399,7 +5522,7 @@ where fn src_operand( &mut self, desc: ArgumentDescriptor>, - typ: &ast::Type, + _: &ast::Type, ) -> Result, TranslateError> { Ok(match desc.op { ast::SrcOperand::Reg(id) => ast::SrcOperand::Reg(self(id)?), @@ -6810,6 +6933,7 @@ impl ast::DstOperand { } } } + impl ast::DstOperand { fn unwrap_reg(&self) -> Result { match self {