diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index ecc5544..a0b5077 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1560,6 +1560,12 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { fn convert_vector( &mut self, is_dst: bool, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, typ: &ast::Type, state_space: ast::StateSpace, idx: Vec, @@ -1577,6 +1583,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { typ: scalar_t, packed: temp_vec, unpacked: idx, + non_default_implicit_conversion, }); if is_dst { self.post_stmts = Some(statement); @@ -1609,9 +1616,13 @@ impl<'a, 'b> ArgumentMapVisitor ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), ast::Operand::Imm(x) => TypedOperand::Imm(x), ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), - ast::Operand::VecPack(vec) => { - TypedOperand::Reg(self.convert_vector(desc.is_dst, typ, state_space, vec)?) - } + ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector( + desc.is_dst, + desc.non_default_implicit_conversion, + typ, + state_space, + vec, + )?), }) } } @@ -5320,6 +5331,12 @@ struct RepackVectorDetails { typ: ast::ScalarType, packed: spirv::Word, unpacked: Vec, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, } impl RepackVectorDetails { @@ -5335,7 +5352,6 @@ impl RepackVectorDetails { ArgumentDescriptor { op: self.packed, is_dst: !self.is_extract, - non_default_implicit_conversion: None, }, Some(( @@ -5345,6 +5361,7 @@ impl RepackVectorDetails { )?; let scalar_type = self.typ; let is_extract = self.is_extract; + let non_default_implicit_conversion = self.non_default_implicit_conversion; let vector = self .unpacked .into_iter() @@ -5353,7 +5370,7 @@ impl RepackVectorDetails { ArgumentDescriptor { op: id, is_dst: is_extract, - non_default_implicit_conversion: None, + non_default_implicit_conversion, }, Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)), ) @@ -5364,6 +5381,7 @@ impl RepackVectorDetails { typ: self.typ, packed: scalar, unpacked: vector, + non_default_implicit_conversion, }) } } @@ -7168,6 +7186,19 @@ impl ast::StateSpace { || self == ast::StateSpace::Sreg && other == ast::StateSpace::Reg } + fn coerces_to_generic(self) -> bool { + match self { + ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Reg + | ast::StateSpace::Param + | ast::StateSpace::Generic + | ast::StateSpace::Sreg => false, + } + } + fn is_addressable(self) -> bool { match self { ast::StateSpace::Const @@ -7254,7 +7285,11 @@ fn default_implicit_conversion_space( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if operand_space.is_compatible(ast::StateSpace::Reg) { + if (instruction_space == ast::StateSpace::Generic && operand_space.coerces_to_generic()) + || (operand_space == ast::StateSpace::Generic && instruction_space.coerces_to_generic()) + { + Ok(Some(ConversionKind::PtrToPtr)) + } else if operand_space.is_compatible(ast::StateSpace::Reg) { match operand_type { ast::Type::Pointer(operand_ptr_type, operand_ptr_space) if *operand_ptr_space == instruction_space =>