diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs index 61b31ad..ad4b473 100644 --- a/ptx/src/pass/convert_to_stateful_memory_access.rs +++ b/ptx/src/pass/convert_to_stateful_memory_access.rs @@ -475,7 +475,7 @@ fn convert_to_stateful_memory_access_postprocess( let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?; let converting_id = id_defs .register_intermediate(Some((old_operand_type.clone(), old_operand_space))); - let kind = if state_is_compatible(new_operand_space, ast::StateSpace::Reg) { + let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) { ConversionKind::Default } else { ConversionKind::PtrToPtr diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs index bc01ab0..d0c7c98 100644 --- a/ptx/src/pass/expand_arguments.rs +++ b/ptx/src/pass/expand_arguments.rs @@ -65,7 +65,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { }; if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg { let (reg_type, reg_space) = self.id_def.get_typed(reg)?; - if !state_is_compatible(reg_space, ast::StateSpace::Reg) { + if !space_is_compatible(reg_space, ast::StateSpace::Reg) { return Err(error_mismatched_type()); } let reg_scalar_type = match reg_type { diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs index baf3453..0dce598 100644 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -127,7 +127,22 @@ fn default_implicit_conversion( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if !state_is_compatible(instruction_space, operand_space) { + if instruction_space == ast::StateSpace::Reg { + if space_is_compatible(operand_space, ast::StateSpace::Reg) { + if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = + (operand_type, instruction_type) + { + if scalar.kind() == ast::ScalarKind::Bit + && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) + { + return Ok(Some(ConversionKind::Default)); + } + } + } else if is_addressable(operand_space) { + return Ok(Some(ConversionKind::AddressOf)); + } + } + if !space_is_compatible(instruction_space, operand_space) { default_implicit_conversion_space( (operand_space, operand_type), (instruction_space, instruction_type), @@ -139,6 +154,21 @@ fn default_implicit_conversion( } } +fn is_addressable(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Const + | ast::StateSpace::Generic + | ast::StateSpace::Global + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false, + ast::StateSpace::SharedCluster + | ast::StateSpace::SharedCta + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc => todo!(), + } +} + // Space is different fn default_implicit_conversion_space( (operand_space, operand_type): (ast::StateSpace, &ast::Type), @@ -148,7 +178,7 @@ fn default_implicit_conversion_space( || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space)) { Ok(Some(ConversionKind::PtrToPtr)) - } else if state_is_compatible(operand_space, ast::StateSpace::Reg) { + } else if space_is_compatible(operand_space, ast::StateSpace::Reg) { match operand_type { ast::Type::Pointer(operand_ptr_type, operand_ptr_space) if *operand_ptr_space == instruction_space => @@ -180,7 +210,7 @@ fn default_implicit_conversion_space( }, _ => Err(error_mismatched_type()), } - } else if state_is_compatible(instruction_space, ast::StateSpace::Reg) { + } else if space_is_compatible(instruction_space, ast::StateSpace::Reg) { match instruction_type { ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) if operand_space == *instruction_ptr_space => @@ -204,7 +234,7 @@ fn default_implicit_conversion_type( operand_type: &ast::Type, instruction_type: &ast::Type, ) -> Result, TranslateError> { - if state_is_compatible(space, ast::StateSpace::Reg) { + if space_is_compatible(space, ast::StateSpace::Reg) { if should_bitcast(instruction_type, operand_type) { Ok(Some(ConversionKind::Default)) } else { @@ -264,7 +294,7 @@ fn should_convert_relaxed_dst_wrapper( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if !state_is_compatible(operand_space, instruction_space) { + if !space_is_compatible(operand_space, instruction_space) { return Err(error_mismatched_type()); } if operand_type == instruction_type { @@ -341,7 +371,7 @@ fn should_convert_relaxed_src_wrapper( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if !state_is_compatible(operand_space, instruction_space) { + if !space_is_compatible(operand_space, instruction_space) { return Err(error_mismatched_type()); } if operand_type == instruction_type { diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs index 7369cdb..c1e30b0 100644 --- a/ptx/src/pass/insert_mem_ssa_statements.rs +++ b/ptx/src/pass/insert_mem_ssa_statements.rs @@ -189,7 +189,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { return Ok(symbol); }; let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?; - if !state_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable { + if !space_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable { return Ok(symbol); }; let member_index = match member_index { @@ -257,10 +257,9 @@ impl<'a, 'input> ast::VisitorMap TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset) } op @ TypedOperand::Imm(..) => op, - TypedOperand::VecMember(symbol, index) => TypedOperand::VecMember( - self.symbol(symbol, Some(index), type_space, is_dst)?, - index, - ), + TypedOperand::VecMember(symbol, index) => { + TypedOperand::Reg(self.symbol(symbol, Some(index), type_space, is_dst)?) + } }) } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index d0f4dfb..4ca2f02 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1214,7 +1214,7 @@ impl< } } -fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool { +fn space_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool { this == other || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg