diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c6b7f01..15163fc 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2349,98 +2349,29 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { state_space: ast::StateSpace, ) -> Result { let (reg, offset) = desc.op; - let add_type; match typ { ast::Type::Scalar(underlying_type) => { - let (reg_typ, space) = self.id_def.get_typed(reg)?; - if let ast::Type::Pointer(..) = reg_typ { - let id_constant_stmt = self.id_def.register_intermediate(typ.clone(), space); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: ast::ScalarType::S64, - value: ast::ImmediateValue::S64(offset as i64), - })); - let dst = self.id_def.register_intermediate(typ.clone(), space); - self.func.push(Statement::PtrAccess(PtrAccess { - underlying_type: *underlying_type, - state_space: state_space, - dst, - ptr_src: reg, - offset_src: id_constant_stmt, - })); - return Ok(dst); - } else { - add_type = reg_typ; - } + let id_constant_stmt = self.id_def.register_intermediate( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: ast::ScalarType::S64, + value: ast::ImmediateValue::S64(offset as i64), + })); + let dst = self.id_def.register_intermediate(typ.clone(), state_space); + self.func.push(Statement::PtrAccess(PtrAccess { + underlying_type: *underlying_type, + state_space: state_space, + dst, + ptr_src: reg, + offset_src: id_constant_stmt, + })); + Ok(dst) } - _ => return Err(error_unreachable()), - }; - let (width, kind) = match add_type { - ast::Type::Scalar(scalar_t) => { - let kind = match scalar_t.kind() { - kind @ ast::ScalarKind::Bit - | kind @ ast::ScalarKind::Unsigned - | kind @ ast::ScalarKind::Signed => kind, - ast::ScalarKind::Float => return Err(TranslateError::MismatchedType), - ast::ScalarKind::Float2 => return Err(TranslateError::MismatchedType), - ast::ScalarKind::Pred => return Err(TranslateError::MismatchedType), - }; - (scalar_t.size_of(), kind) - } - _ => return Err(TranslateError::MismatchedType), - }; - let arith_detail = if kind == ast::ScalarKind::Signed { - ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::ScalarType::from_parts(width, ast::ScalarKind::Signed), - saturate: false, - }) - } else { - ast::ArithDetails::Unsigned(ast::ScalarType::from_parts( - width, - ast::ScalarKind::Unsigned, - )) - }; - let id_constant_stmt = self - .id_def - .register_intermediate(add_type.clone(), ast::StateSpace::Reg); - let result_id = self - .id_def - .register_intermediate(add_type, ast::StateSpace::Reg); - // TODO: check for edge cases around min value/max value/wrapping - if offset < 0 && kind != ast::ScalarKind::Signed { - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: ast::ScalarType::from_parts(width, kind), - value: ast::ImmediateValue::U64(-(offset as i64) as u64), - })); - self.func.push(Statement::Instruction( - ast::Instruction::::Sub( - arith_detail, - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); - } else { - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: ast::ScalarType::from_parts(width, kind), - value: ast::ImmediateValue::S64(offset as i64), - })); - self.func.push(Statement::Instruction( - ast::Instruction::::Add( - arith_detail, - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); + _ => Err(error_unreachable()), } - Ok(result_id) } fn immediate( @@ -2519,32 +2450,8 @@ fn insert_implicit_conversions( Statement::Instruction(inst) => { insert_implicit_conversions_impl(&mut result, id_def, inst)?; } - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src, - offset_src: constant_src, - }) => { - let visit_desc = VisitArgumentDescriptor { - desc: ArgumentDescriptor { - op: ptr_src, - is_dst: false, - non_default_implicit_conversion: None, - }, - typ: &ast::Type::Pointer(underlying_type, state_space), - state_space: new_todo!(), - stmt_ctor: |new_ptr_src| { - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src: new_ptr_src, - offset_src: constant_src, - }) - }, - }; - insert_implicit_conversions_impl(&mut result, id_def, visit_desc)?; + Statement::PtrAccess(access) => { + insert_implicit_conversions_impl(&mut result, id_def, access)?; } Statement::RepackVector(repack) => { insert_implicit_conversions_impl(&mut result, id_def, repack)?; @@ -5458,16 +5365,7 @@ impl> PtrAccess

{ self, visitor: &mut V, ) -> Result, TranslateError> { - let sema = match self.state_space { - ast::StateSpace::Const - | ast::StateSpace::Global - | ast::StateSpace::Shared - | ast::StateSpace::Generic => ArgumentSemantics::PhysicalPointer, - ast::StateSpace::Local | ast::StateSpace::Param => ArgumentSemantics::RegisterPointer, - ast::StateSpace::Reg => new_todo!(), - ast::StateSpace::Sreg => new_todo!(), - }; - let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), new_todo!()); + let ptr_type = ast::Type::Scalar(self.underlying_type.clone()); let new_dst = visitor.id( ArgumentDescriptor { op: self.dst, @@ -5492,7 +5390,7 @@ impl> PtrAccess

{ non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::S64), - self.state_space, + ast::StateSpace::Reg, )?; Ok(PtrAccess { underlying_type: self.underlying_type,