diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 1c6d2fb..bc2fa4c 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -210,20 +210,6 @@ sub_enum!(LdStScalarType { F64, }); -sub_enum!(SelpType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, - F32, - F64, -}); - #[derive(Copy, Clone, Eq, PartialEq)] pub enum BarDetails { SyncAligned, @@ -425,52 +411,6 @@ pub enum ScalarType { Pred, } -sub_enum!(IntType { - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64 -}); - -sub_enum!(BitType { B8, B16, B32, B64 }); - -sub_enum!(UIntType { U8, U16, U32, U64 }); - -sub_enum!(SIntType { S8, S16, S32, S64 }); - -impl IntType { - pub fn is_signed(self) -> bool { - match self { - IntType::U8 | IntType::U16 | IntType::U32 | IntType::U64 => false, - IntType::S8 | IntType::S16 | IntType::S32 | IntType::S64 => true, - } - } - - pub fn width(self) -> u8 { - match self { - IntType::U8 => 1, - IntType::U16 => 2, - IntType::U32 => 4, - IntType::U64 => 8, - IntType::S8 => 1, - IntType::S16 => 2, - IntType::S32 => 4, - IntType::S64 => 8, - } - } -} - -sub_enum!(FloatType { - F16, - F16x2, - F32, - F64 -}); - impl ScalarType { pub fn size_of(self) -> u8 { match self { @@ -576,24 +516,24 @@ pub enum Instruction { Add(ArithDetails, Arg3

), Setp(SetpData, Arg4Setp

), SetpBool(SetpBoolData, Arg5Setp

), - Not(BooleanType, Arg2

), + Not(ScalarType, Arg2

), Bra(BraData, Arg1

), Cvt(CvtDetails, Arg2

), Cvta(CvtaDetails, Arg2

), - Shl(ShlType, Arg3

), - Shr(ShrType, Arg3

), + Shl(ScalarType, Arg3

), + Shr(ScalarType, Arg3

), St(StData, Arg2St

), Ret(RetData), Call(CallInst

), Abs(AbsDetails, Arg2

), Mad(MulDetails, Arg4

), - Or(BooleanType, Arg3

), + Or(ScalarType, Arg3

), Sub(ArithDetails, Arg3

), Min(MinMaxDetails, Arg3

), Max(MinMaxDetails, Arg3

), Rcp(RcpDetails, Arg2

), - And(BooleanType, Arg3

), - Selp(SelpType, Arg4

), + And(ScalarType, Arg3

), + Selp(ScalarType, Arg4

), Bar(BarDetails, Arg1Bar

), Atom(AtomDetails, Arg3

), AtomCas(AtomCasDetails, Arg4

), @@ -605,13 +545,13 @@ pub enum Instruction { Cos { flush_to_zero: bool, arg: Arg2

}, Lg2 { flush_to_zero: bool, arg: Arg2

}, Ex2 { flush_to_zero: bool, arg: Arg2

}, - Clz { typ: BitType, arg: Arg2

}, - Brev { typ: BitType, arg: Arg2

}, - Popc { typ: BitType, arg: Arg2

}, - Xor { typ: BooleanType, arg: Arg3

}, - Bfe { typ: IntType, arg: Arg4

}, - Bfi { typ: BitType, arg: Arg5

}, - Rem { typ: IntType, arg: Arg3

}, + Clz { typ: ScalarType, arg: Arg2

}, + Brev { typ: ScalarType, arg: Arg2

}, + Popc { typ: ScalarType, arg: Arg2

}, + Xor { typ: ScalarType, arg: Arg3

}, + Bfe { typ: ScalarType, arg: Arg4

}, + Bfi { typ: ScalarType, arg: Arg5

}, + Rem { typ: ScalarType, arg: Arg3

}, } #[derive(Copy, Clone)] @@ -825,7 +765,7 @@ impl MovDetails { #[derive(Copy, Clone)] pub struct MulIntDesc { - pub typ: IntType, + pub typ: ScalarType, pub control: MulIntControl, } @@ -845,7 +785,7 @@ pub enum RoundingMode { } pub struct AddIntDesc { - pub typ: IntType, + pub typ: ScalarType, pub saturate: bool, } @@ -892,39 +832,39 @@ pub struct BraData { pub enum CvtDetails { IntFromInt(CvtIntToIntDesc), - FloatFromFloat(CvtDesc), - IntFromFloat(CvtDesc), - FloatFromInt(CvtDesc), + FloatFromFloat(CvtDesc), + IntFromFloat(CvtDesc), + FloatFromInt(CvtDesc), } pub struct CvtIntToIntDesc { - pub dst: IntType, - pub src: IntType, + pub dst: ScalarType, + pub src: ScalarType, pub saturate: bool, } -pub struct CvtDesc { +pub struct CvtDesc { pub rounding: Option, pub flush_to_zero: Option, pub saturate: bool, - pub dst: Dst, - pub src: Src, + pub dst: ScalarType, + pub src: ScalarType, } impl CvtDetails { pub fn new_int_from_int_checked<'err, 'input>( saturate: bool, - dst: IntType, - src: IntType, + dst: ScalarType, + src: ScalarType, err: &'err mut Vec, PtxError>>, ) -> Self { if saturate { - if src.is_signed() { - if dst.is_signed() && dst.width() >= src.width() { + if src.kind() == ScalarKind::Signed { + if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() { err.push(ParseError::from(PtxError::SyntaxError)); } } else { - if dst == src || dst.width() >= src.width() { + if dst == src || dst.size_of() >= src.size_of() { err.push(ParseError::from(PtxError::SyntaxError)); } } @@ -936,11 +876,11 @@ impl CvtDetails { rounding: RoundingMode, flush_to_zero: bool, saturate: bool, - dst: FloatType, - src: IntType, + dst: ScalarType, + src: ScalarType, err: &'err mut Vec, PtxError>>, ) -> Self { - if flush_to_zero && dst != FloatType::F32 { + if flush_to_zero && dst != ScalarType::F32 { err.push(ParseError::from(PtxError::NonF32Ftz)); } CvtDetails::FloatFromInt(CvtDesc { @@ -956,11 +896,11 @@ impl CvtDetails { rounding: RoundingMode, flush_to_zero: bool, saturate: bool, - dst: IntType, - src: FloatType, + dst: ScalarType, + src: ScalarType, err: &'err mut Vec, PtxError>>, ) -> Self { - if flush_to_zero && src != FloatType::F32 { + if flush_to_zero && src != ScalarType::F32 { err.push(ParseError::from(PtxError::NonF32Ftz)); } CvtDetails::IntFromFloat(CvtDesc { @@ -993,25 +933,6 @@ pub enum CvtaSize { U64, } -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum ShlType { - B16, - B32, - B64, -} - -sub_enum!(ShrType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, -}); - pub struct StData { pub qualifier: LdStQualifier, pub state_space: StStateSpace, @@ -1040,13 +961,6 @@ pub struct RetData { pub uniform: bool, } -sub_enum!(BooleanType { - Pred, - B16, - B32, - B64, -}); - #[derive(Copy, Clone)] pub enum MulDetails { Unsigned(MulUInt), @@ -1056,32 +970,32 @@ pub enum MulDetails { #[derive(Copy, Clone)] pub struct MulUInt { - pub typ: UIntType, + pub typ: ScalarType, pub control: MulIntControl, } #[derive(Copy, Clone)] pub struct MulSInt { - pub typ: SIntType, + pub typ: ScalarType, pub control: MulIntControl, } #[derive(Copy, Clone)] pub enum ArithDetails { - Unsigned(UIntType), + Unsigned(ScalarType), Signed(ArithSInt), Float(ArithFloat), } #[derive(Copy, Clone)] pub struct ArithSInt { - pub typ: SIntType, + pub typ: ScalarType, pub saturate: bool, } #[derive(Copy, Clone)] pub struct ArithFloat { - pub typ: FloatType, + pub typ: ScalarType, pub rounding: Option, pub flush_to_zero: Option, pub saturate: bool, @@ -1089,8 +1003,8 @@ pub struct ArithFloat { #[derive(Copy, Clone)] pub enum MinMaxDetails { - Signed(SIntType), - Unsigned(UIntType), + Signed(ScalarType), + Unsigned(ScalarType), Float(MinMaxFloat), } @@ -1098,7 +1012,7 @@ pub enum MinMaxDetails { pub struct MinMaxFloat { pub flush_to_zero: Option, pub nan: bool, - pub typ: FloatType, + pub typ: ScalarType, } #[derive(Copy, Clone)] @@ -1126,10 +1040,10 @@ pub enum AtomSpace { #[derive(Copy, Clone)] pub enum AtomInnerDetails { - Bit { op: AtomBitOp, typ: BitType }, - Unsigned { op: AtomUIntOp, typ: UIntType }, - Signed { op: AtomSIntOp, typ: SIntType }, - Float { op: AtomFloatOp, typ: FloatType }, + Bit { op: AtomBitOp, typ: ScalarType }, + Unsigned { op: AtomUIntOp, typ: ScalarType }, + Signed { op: AtomSIntOp, typ: ScalarType }, + Float { op: AtomFloatOp, typ: ScalarType }, } #[derive(Copy, Clone, Eq, PartialEq)] @@ -1166,19 +1080,19 @@ pub struct AtomCasDetails { pub semantics: AtomSemantics, pub scope: MemScope, pub space: AtomSpace, - pub typ: BitType, + pub typ: ScalarType, } #[derive(Copy, Clone)] pub enum DivDetails { - Unsigned(UIntType), - Signed(SIntType), + Unsigned(ScalarType), + Signed(ScalarType), Float(DivFloatDetails), } #[derive(Copy, Clone)] pub struct DivFloatDetails { - pub typ: FloatType, + pub typ: ScalarType, pub flush_to_zero: Option, pub kind: DivFloatKind, } @@ -1197,7 +1111,7 @@ pub enum NumsOrArrays<'a> { #[derive(Copy, Clone)] pub struct SqrtDetails { - pub typ: FloatType, + pub typ: ScalarType, pub flush_to_zero: Option, pub kind: SqrtKind, } @@ -1210,7 +1124,7 @@ pub enum SqrtKind { #[derive(Copy, Clone, Eq, PartialEq)] pub struct RsqrtDetails { - pub typ: FloatType, + pub typ: ScalarType, pub flush_to_zero: bool, } @@ -1379,6 +1293,40 @@ pub enum TuningDirective { MinNCtaPerSm(u32), } +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum ScalarKind { + Bit, + Unsigned, + Signed, + Float, + Float2, + Pred, +} + +impl ScalarType { + pub fn kind(self) -> ScalarKind { + match self { + ScalarType::U8 => ScalarKind::Unsigned, + ScalarType::U16 => ScalarKind::Unsigned, + ScalarType::U32 => ScalarKind::Unsigned, + ScalarType::U64 => ScalarKind::Unsigned, + ScalarType::S8 => ScalarKind::Signed, + ScalarType::S16 => ScalarKind::Signed, + ScalarType::S32 => ScalarKind::Signed, + ScalarType::S64 => ScalarKind::Signed, + ScalarType::B8 => ScalarKind::Bit, + ScalarType::B16 => ScalarKind::Bit, + ScalarType::B32 => ScalarKind::Bit, + ScalarType::B64 => ScalarKind::Bit, + ScalarType::F16 => ScalarKind::Float, + ScalarType::F32 => ScalarKind::Float, + ScalarType::F64 => ScalarKind::Float, + ScalarType::F16x2 => ScalarKind::Float2, + ScalarType::Pred => ScalarKind::Pred, + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 423fd57..41c1d73 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -899,39 +899,39 @@ RoundingModeInt : ast::RoundingMode = { ".rpi" => ast::RoundingMode::PositiveInf, }; -IntType : ast::IntType = { - ".u16" => ast::IntType::U16, - ".u32" => ast::IntType::U32, - ".u64" => ast::IntType::U64, - ".s16" => ast::IntType::S16, - ".s32" => ast::IntType::S32, - ".s64" => ast::IntType::S64, +IntType : ast::ScalarType = { + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; -IntType3264: ast::IntType = { - ".u32" => ast::IntType::U32, - ".u64" => ast::IntType::U64, - ".s32" => ast::IntType::S32, - ".s64" => ast::IntType::S64, +IntType3264: ast::ScalarType = { + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, } -UIntType: ast::UIntType = { - ".u16" => ast::UIntType::U16, - ".u32" => ast::UIntType::U32, - ".u64" => ast::UIntType::U64, +UIntType: ast::ScalarType = { + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, }; -SIntType: ast::SIntType = { - ".s16" => ast::SIntType::S16, - ".s32" => ast::SIntType::S32, - ".s64" => ast::SIntType::S64, +SIntType: ast::ScalarType = { + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; -FloatType: ast::FloatType = { - ".f16" => ast::FloatType::F16, - ".f16x2" => ast::FloatType::F16x2, - ".f32" => ast::FloatType::F32, - ".f64" => ast::FloatType::F64, +FloatType: ast::ScalarType = { + ".f16" => ast::ScalarType::F16, + ".f16x2" => ast::ScalarType::F16x2, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add @@ -1023,11 +1023,11 @@ InstNot: ast::Instruction> = { "not" => ast::Instruction::Not(t, a) }; -BooleanType: ast::BooleanType = { - ".pred" => ast::BooleanType::Pred, - ".b16" => ast::BooleanType::B16, - ".b32" => ast::BooleanType::B32, - ".b64" => ast::BooleanType::B64, +BooleanType: ast::ScalarType = { + ".pred" => ast::ScalarType::Pred, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at @@ -1080,8 +1080,8 @@ InstCvt: ast::Instruction> = { rounding: r, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F16, - src: ast::FloatType::F16 + dst: ast::ScalarType::F16, + src: ast::ScalarType::F16 } ), a) }, @@ -1091,8 +1091,8 @@ InstCvt: ast::Instruction> = { rounding: None, flush_to_zero: Some(f.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F32, - src: ast::FloatType::F16 + dst: ast::ScalarType::F32, + src: ast::ScalarType::F16 } ), a) }, @@ -1102,8 +1102,8 @@ InstCvt: ast::Instruction> = { rounding: None, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F64, - src: ast::FloatType::F16 + dst: ast::ScalarType::F64, + src: ast::ScalarType::F16 } ), a) }, @@ -1113,8 +1113,8 @@ InstCvt: ast::Instruction> = { rounding: Some(r), flush_to_zero: Some(f.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F16, - src: ast::FloatType::F32 + dst: ast::ScalarType::F16, + src: ast::ScalarType::F32 } ), a) }, @@ -1124,8 +1124,8 @@ InstCvt: ast::Instruction> = { rounding: r, flush_to_zero: Some(f.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F32, - src: ast::FloatType::F32 + dst: ast::ScalarType::F32, + src: ast::ScalarType::F32 } ), a) }, @@ -1135,8 +1135,8 @@ InstCvt: ast::Instruction> = { rounding: None, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F64, - src: ast::FloatType::F32 + dst: ast::ScalarType::F64, + src: ast::ScalarType::F32 } ), a) }, @@ -1146,8 +1146,8 @@ InstCvt: ast::Instruction> = { rounding: Some(r), flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F16, - src: ast::FloatType::F64 + dst: ast::ScalarType::F16, + src: ast::ScalarType::F64 } ), a) }, @@ -1157,8 +1157,8 @@ InstCvt: ast::Instruction> = { rounding: Some(r), flush_to_zero: Some(s.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F32, - src: ast::FloatType::F64 + dst: ast::ScalarType::F32, + src: ast::ScalarType::F64 } ), a) }, @@ -1168,28 +1168,28 @@ InstCvt: ast::Instruction> = { rounding: r, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F64, - src: ast::FloatType::F64 + dst: ast::ScalarType::F64, + src: ast::ScalarType::F64 } ), a) }, }; -CvtTypeInt: ast::IntType = { - ".u8" => ast::IntType::U8, - ".u16" => ast::IntType::U16, - ".u32" => ast::IntType::U32, - ".u64" => ast::IntType::U64, - ".s8" => ast::IntType::S8, - ".s16" => ast::IntType::S16, - ".s32" => ast::IntType::S32, - ".s64" => ast::IntType::S64, +CvtTypeInt: ast::ScalarType = { + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; -CvtTypeFloat: ast::FloatType = { - ".f16" => ast::FloatType::F16, - ".f32" => ast::FloatType::F32, - ".f64" => ast::FloatType::F64, +CvtTypeFloat: ast::ScalarType = { + ".f16" => ast::ScalarType::F16, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl @@ -1197,10 +1197,10 @@ InstShl: ast::Instruction> = { "shl" => ast::Instruction::Shl(t, a) }; -ShlType: ast::ShlType = { - ".b16" => ast::ShlType::B16, - ".b32" => ast::ShlType::B32, - ".b64" => ast::ShlType::B64, +ShlType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr @@ -1208,16 +1208,16 @@ InstShr: ast::Instruction> = { "shr" => ast::Instruction::Shr(t, a) }; -ShrType: ast::ShrType = { - ".b16" => ast::ShrType::B16, - ".b32" => ast::ShrType::B32, - ".b64" => ast::ShrType::B64, - ".u16" => ast::ShrType::U16, - ".u32" => ast::ShrType::U32, - ".u64" => ast::ShrType::U64, - ".s16" => ast::ShrType::S16, - ".s32" => ast::ShrType::S32, - ".s64" => ast::ShrType::S64, +ShrType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st @@ -1393,16 +1393,16 @@ MinMaxDetails: ast::MinMaxDetails = { => ast::MinMaxDetails::Unsigned(t), => ast::MinMaxDetails::Signed(t), ".f32" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F32 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F32 } ), ".f64" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::FloatType::F64 } + ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::ScalarType::F64 } ), ".f16" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16 } ), ".f16x2" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16x2 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16x2 } ) } @@ -1411,18 +1411,18 @@ InstSelp: ast::Instruction> = { "selp" => ast::Instruction::Selp(t, a), }; -SelpType: ast::SelpType = { - ".b16" => ast::SelpType::B16, - ".b32" => ast::SelpType::B32, - ".b64" => ast::SelpType::B64, - ".u16" => ast::SelpType::U16, - ".u32" => ast::SelpType::U32, - ".u64" => ast::SelpType::U64, - ".s16" => ast::SelpType::S16, - ".s32" => ast::SelpType::S32, - ".s64" => ast::SelpType::S64, - ".f32" => ast::SelpType::F32, - ".f64" => ast::SelpType::F64, +SelpType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar @@ -1454,7 +1454,7 @@ InstAtom: ast::Instruction> = { space: space.unwrap_or(ast::AtomSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Inc, - typ: ast::UIntType::U32 + typ: ast::ScalarType::U32 } }; ast::Instruction::Atom(details,a) @@ -1466,7 +1466,7 @@ InstAtom: ast::Instruction> = { space: space.unwrap_or(ast::AtomSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Dec, - typ: ast::UIntType::U32 + typ: ast::ScalarType::U32 } }; ast::Instruction::Atom(details,a) @@ -1544,19 +1544,19 @@ AtomSIntOp: ast::AtomSIntOp = { ".max" => ast::AtomSIntOp::Max, } -BitType: ast::BitType = { - ".b32" => ast::BitType::B32, - ".b64" => ast::BitType::B64, +BitType: ast::ScalarType = { + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, } -UIntType3264: ast::UIntType = { - ".u32" => ast::UIntType::U32, - ".u64" => ast::UIntType::U64, +UIntType3264: ast::ScalarType = { + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, } -SIntType3264: ast::SIntType = { - ".s32" => ast::SIntType::S32, - ".s64" => ast::SIntType::S64, +SIntType3264: ast::ScalarType = { + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div @@ -1566,7 +1566,7 @@ InstDiv: ast::Instruction> = { "div" => ast::Instruction::Div(ast::DivDetails::Signed(t), a), "div" ".f32" => { let inner = ast::DivFloatDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: Some(ftz.is_some()), kind }; @@ -1574,7 +1574,7 @@ InstDiv: ast::Instruction> = { }, "div" ".f64" => { let inner = ast::DivFloatDetails { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, flush_to_zero: None, kind: ast::DivFloatKind::Rounding(rnd) }; @@ -1592,7 +1592,7 @@ DivFloatKind: ast::DivFloatKind = { InstSqrt: ast::Instruction> = { "sqrt" ".approx" ".f32" => { let details = ast::SqrtDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: Some(ftz.is_some()), kind: ast::SqrtKind::Approx, }; @@ -1600,7 +1600,7 @@ InstSqrt: ast::Instruction> = { }, "sqrt" ".f32" => { let details = ast::SqrtDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: Some(ftz.is_some()), kind: ast::SqrtKind::Rounding(rnd), }; @@ -1608,7 +1608,7 @@ InstSqrt: ast::Instruction> = { }, "sqrt" ".f64" => { let details = ast::SqrtDetails { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, flush_to_zero: None, kind: ast::SqrtKind::Rounding(rnd), }; @@ -1621,14 +1621,14 @@ InstSqrt: ast::Instruction> = { InstRsqrt: ast::Instruction> = { "rsqrt" ".approx" ".f32" => { let details = ast::RsqrtDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: ftz.is_some(), }; ast::Instruction::Rsqrt(details, a) }, "rsqrt" ".approx" ".f64" => { let details = ast::RsqrtDetails { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, flush_to_zero: ftz.is_some(), }; ast::Instruction::Rsqrt(details, a) @@ -1739,7 +1739,7 @@ ArithDetails: ast::ArithDetails = { saturate: false, }), ".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S32, + typ: ast::ScalarType::S32, saturate: true, }), => ast::ArithDetails::Float(f) @@ -1747,25 +1747,25 @@ ArithDetails: ast::ArithDetails = { ArithFloat: ast::ArithFloat = { ".f32" => ast::ArithFloat { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, rounding: rn, flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".f64" => ast::ArithFloat { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, rounding: rn, flush_to_zero: None, saturate: false, }, ".f16" => ast::ArithFloat { - typ: ast::FloatType::F16, + typ: ast::ScalarType::F16, rounding: rn.map(|_| ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".f16x2" => ast::ArithFloat { - typ: ast::FloatType::F16x2, + typ: ast::ScalarType::F16x2, rounding: rn.map(|_| ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), @@ -1774,25 +1774,25 @@ ArithFloat: ast::ArithFloat = { ArithFloatMustRound: ast::ArithFloat = { ".f32" => ast::ArithFloat { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, rounding: Some(rn), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".f64" => ast::ArithFloat { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, rounding: Some(rn), flush_to_zero: None, saturate: false, }, ".rn" ".f16" => ast::ArithFloat { - typ: ast::FloatType::F16, + typ: ast::ScalarType::F16, rounding: Some(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".rn" ".f16x2" => ast::ArithFloat { - typ: ast::FloatType::F16x2, + typ: ast::ScalarType::F16x2, rounding: Some(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), diff --git a/ptx/src/test/spirv_run/bfi.spvtxt b/ptx/src/test/spirv_run/bfi.spvtxt index a226f78..dc8f683 100644 --- a/ptx/src/test/spirv_run/bfi.spvtxt +++ b/ptx/src/test/spirv_run/bfi.spvtxt @@ -71,7 +71,7 @@ %26 = OpLoad %uint %9 %40 = OpCopyObject %uint %23 %41 = OpCopyObject %uint %24 - %39 = OpFunctionCall %uint %44 %41 %40 %25 %26 + %39 = OpFunctionCall %uint %44 %40 %41 %25 %26 %22 = OpCopyObject %uint %39 OpStore %6 %22 %27 = OpLoad %ulong %5 diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index e39280a..51b1dc6 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1553,10 +1553,9 @@ fn extract_globals<'input, 'b>( space, }; let (op, typ) = match typ { - ast::FloatType::F32 => ("add_f32", ast::SizedScalarType::F32), - ast::FloatType::F64 => ("add_f64", ast::SizedScalarType::F64), - ast::FloatType::F16 => unreachable!(), - ast::FloatType::F16x2 => unreachable!(), + ast::ScalarType::F32 => ("add_f32", ast::SizedScalarType::F32), + ast::ScalarType::F64 => ("add_f64", ast::SizedScalarType::F64), + _ => unreachable!(), }; local.push(to_ptx_impl_atomic_call( id_def, @@ -1822,15 +1821,15 @@ fn to_ptx_impl_atomic_call( fn to_ptx_impl_bfe_call( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, - typ: ast::IntType, + typ: ast::ScalarType, arg: ast::Arg4, ) -> ExpandedStatement { let prefix = "__zluda_ptx_impl__"; let suffix = match typ { - ast::IntType::U32 => "bfe_u32", - ast::IntType::U64 => "bfe_u64", - ast::IntType::S32 => "bfe_s32", - ast::IntType::S64 => "bfe_s64", + ast::ScalarType::U32 => "bfe_u32", + ast::ScalarType::U64 => "bfe_u64", + ast::ScalarType::S32 => "bfe_s32", + ast::ScalarType::S64 => "bfe_s64", _ => unreachable!(), }; let fn_name = format!("{}{}", prefix, suffix); @@ -1917,14 +1916,14 @@ fn to_ptx_impl_bfe_call( fn to_ptx_impl_bfi_call( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, - typ: ast::BitType, + typ: ast::ScalarType, arg: ast::Arg5, ) -> ExpandedStatement { let prefix = "__zluda_ptx_impl__"; let suffix = match typ { - ast::BitType::B32 => "bfi_b32", - ast::BitType::B64 => "bfi_b64", - ast::BitType::B8 | ast::BitType::B16 => unreachable!(), + ast::ScalarType::B32 => "bfi_b32", + ast::ScalarType::B64 => "bfi_b64", + _ => unreachable!(), }; let fn_name = format!("{}{}", prefix, suffix); let fn_id = match ptx_impl_imports.entry(fn_name) { @@ -2506,29 +2505,32 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { let (width, kind) = match add_type { ast::Type::Scalar(scalar_t) => { let kind = match scalar_t.kind() { - kind @ ScalarKind::Bit - | kind @ ScalarKind::Unsigned - | kind @ ScalarKind::Signed => kind, - ScalarKind::Float => return Err(TranslateError::MismatchedType), - ScalarKind::Float2 => return Err(TranslateError::MismatchedType), - ScalarKind::Pred => return Err(TranslateError::MismatchedType), + kind @ ast::ScalarKind::Bit + | kind @ ast::ScalarKind::Unsigned + | kind @ ast::ScalarKind::Signed => kind, + ast::ScalarKind::Float => return Err(TranslateError::MismatchedType), + ast::ScalarKind::Float2 => return Err(TranslateError::MismatchedType), + ast::ScalarKind::Pred => return Err(TranslateError::MismatchedType), }; (scalar_t.size_of(), kind) } _ => return Err(TranslateError::MismatchedType), }; - let arith_detail = if kind == ScalarKind::Signed { + let arith_detail = if kind == ast::ScalarKind::Signed { ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::from_size(width), + typ: ast::ScalarType::from_parts(width, ast::ScalarKind::Signed), saturate: false, }) } else { - ast::ArithDetails::Unsigned(ast::UIntType::from_size(width)) + ast::ArithDetails::Unsigned(ast::ScalarType::from_parts( + width, + ast::ScalarKind::Unsigned, + )) }; let id_constant_stmt = self.id_def.new_non_variable(add_type.clone()); let result_id = self.id_def.new_non_variable(add_type); // TODO: check for edge cases around min value/max value/wrapping - if offset < 0 && kind != ScalarKind::Signed { + if offset < 0 && kind != ast::ScalarKind::Signed { self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, typ: ast::ScalarType::from_parts(width, kind), @@ -3026,18 +3028,18 @@ fn emit_function_body_ops( emit_setp(builder, map, setp, arg)?; } ast::Instruction::Not(t, a) => { - let result_type = map.get_or_add(builder, SpirvType::from(t.to_type())); + let result_type = map.get_or_add(builder, SpirvType::from(*t)); let result_id = Some(a.dst); let operand = a.src; match t { - ast::BooleanType::Pred => { + ast::ScalarType::Pred => { logical_not(builder, result_type, result_id, operand) } _ => builder.not(result_type, result_id, operand), }?; } ast::Instruction::Shl(t, a) => { - let full_type = t.to_type(); + let full_type = ast::Type::Scalar(*t); let size_of = full_type.size_of(); let result_type = map.get_or_add(builder, SpirvType::from(full_type)); let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?; @@ -3048,7 +3050,7 @@ fn emit_function_body_ops( let size_of = full_type.size_of(); let result_type = map.get_or_add_scalar(builder, full_type); let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?; - if t.signed() { + if t.kind() == ast::ScalarKind::Signed { builder.shift_right_arithmetic( result_type, Some(a.dst), @@ -3088,7 +3090,7 @@ fn emit_function_body_ops( }, ast::Instruction::Or(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - if *t == ast::BooleanType::Pred { + if *t == ast::ScalarType::Pred { builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?; } else { builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?; @@ -3116,7 +3118,7 @@ fn emit_function_body_ops( } ast::Instruction::And(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - if *t == ast::BooleanType::Pred { + if *t == ast::ScalarType::Pred { builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?; } else { builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?; @@ -3202,7 +3204,7 @@ fn emit_function_body_ops( } ast::Instruction::Neg(details, arg) => { let result_type = map.get_or_add_scalar(builder, details.typ); - let negate_func = if details.typ.kind() == ScalarKind::Float { + let negate_func = if details.typ.kind() == ast::ScalarKind::Float { dr::Builder::f_negate } else { dr::Builder::s_negate @@ -3269,7 +3271,7 @@ fn emit_function_body_ops( } ast::Instruction::Xor { typ, arg } => { let builder_fn = match typ { - ast::BooleanType::Pred => emit_logical_xor_spirv, + ast::ScalarType::Pred => emit_logical_xor_spirv, _ => dr::Builder::bitwise_xor, }; let result_type = map.get_or_add_scalar(builder, (*typ).into()); @@ -3284,7 +3286,7 @@ fn emit_function_body_ops( return Err(error_unreachable()); } ast::Instruction::Rem { typ, arg } => { - let builder_fn = if typ.is_signed() { + let builder_fn = if typ.kind() == ast::ScalarKind::Signed { dr::Builder::s_mod } else { dr::Builder::u_mod @@ -3882,7 +3884,7 @@ fn emit_cvt( } let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - if desc.src.is_signed() { + if desc.src.kind() == ast::ScalarKind::Signed { builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?; } else { builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?; @@ -3892,7 +3894,7 @@ fn emit_cvt( ast::CvtDetails::IntFromFloat(desc) => { let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - if desc.dst.is_signed() { + if desc.dst.kind() == ast::ScalarKind::Signed { builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?; } else { builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?; @@ -3904,7 +3906,7 @@ fn emit_cvt( let dest_t: ast::ScalarType = desc.dst.into(); let src_t: ast::ScalarType = desc.src.into(); // first do shortening/widening - let src = if desc.dst.width() != desc.src.width() { + let src = if desc.dst.size_of() != desc.src.size_of() { let new_dst = if dest_t.kind() == src_t.kind() { arg.dst } else { @@ -3933,7 +3935,7 @@ fn emit_cvt( // now do actual conversion let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); if desc.saturate { - if desc.dst.is_signed() { + if desc.dst.kind() == ast::ScalarKind::Signed { builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?; } else { builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?; @@ -3989,60 +3991,60 @@ fn emit_setp( let operand_1 = arg.src1; let operand_2 = arg.src2; match (setp.cmp_op, setp.typ.kind()) { - (ast::SetpCompareOp::Eq, ScalarKind::Signed) - | (ast::SetpCompareOp::Eq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => { + (ast::SetpCompareOp::Eq, ast::ScalarKind::Signed) + | (ast::SetpCompareOp::Eq, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::Eq, ast::ScalarKind::Bit) => { builder.i_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Eq, ScalarKind::Float) => { + (ast::SetpCompareOp::Eq, ast::ScalarKind::Float) => { builder.f_ord_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::NotEq, ScalarKind::Signed) - | (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => { + (ast::SetpCompareOp::NotEq, ast::ScalarKind::Signed) + | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Bit) => { builder.i_not_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::NotEq, ScalarKind::Float) => { + (ast::SetpCompareOp::NotEq, ast::ScalarKind::Float) => { builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Less, ScalarKind::Unsigned) - | (ast::SetpCompareOp::Less, ScalarKind::Bit) => { + (ast::SetpCompareOp::Less, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::Less, ast::ScalarKind::Bit) => { builder.u_less_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Less, ScalarKind::Signed) => { + (ast::SetpCompareOp::Less, ast::ScalarKind::Signed) => { builder.s_less_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Less, ScalarKind::Float) => { + (ast::SetpCompareOp::Less, ast::ScalarKind::Float) => { builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => { + (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Bit) => { builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => { + (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Signed) => { builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => { + (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Float) => { builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Greater, ScalarKind::Unsigned) - | (ast::SetpCompareOp::Greater, ScalarKind::Bit) => { + (ast::SetpCompareOp::Greater, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::Greater, ast::ScalarKind::Bit) => { builder.u_greater_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Greater, ScalarKind::Signed) => { + (ast::SetpCompareOp::Greater, ast::ScalarKind::Signed) => { builder.s_greater_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Greater, ScalarKind::Float) => { + (ast::SetpCompareOp::Greater, ast::ScalarKind::Float) => { builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => { + (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Bit) => { builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => { + (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Signed) => { builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => { + (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Float) => { builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::NanEq, _) => { @@ -4222,7 +4224,7 @@ fn emit_abs( ) -> Result<(), dr::Error> { let scalar_t = ast::ScalarType::from(d.typ); let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); - let cl_abs = if scalar_t.kind() == ScalarKind::Signed { + let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed { spirv::CLOp::s_abs } else { spirv::CLOp::fabs @@ -4286,8 +4288,8 @@ fn emit_implicit_conversion( (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { if from_parts.width == to_parts.width { let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); - if from_parts.scalar_kind != ScalarKind::Float - && to_parts.scalar_kind != ScalarKind::Float + if from_parts.scalar_kind != ast::ScalarKind::Float + && to_parts.scalar_kind != ast::ScalarKind::Float { // It is noop, but another instruction expects result of this conversion builder.copy_object(dst_type, Some(cv.dst), cv.src)?; @@ -4299,24 +4301,24 @@ fn emit_implicit_conversion( let same_width_bit_type = map.get_or_add( builder, SpirvType::from(ast::Type::from_parts(TypeParts { - scalar_kind: ScalarKind::Bit, + scalar_kind: ast::ScalarKind::Bit, ..from_parts })), ); let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?; let wide_bit_type = ast::Type::from_parts(TypeParts { - scalar_kind: ScalarKind::Bit, + scalar_kind: ast::ScalarKind::Bit, ..to_parts }); let wide_bit_type_spirv = map.get_or_add(builder, SpirvType::from(wide_bit_type.clone())); - if to_parts.scalar_kind == ScalarKind::Unsigned - || to_parts.scalar_kind == ScalarKind::Bit + if to_parts.scalar_kind == ast::ScalarKind::Unsigned + || to_parts.scalar_kind == ast::ScalarKind::Bit { builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?; } else { - let conversion_fn = if from_parts.scalar_kind == ScalarKind::Signed - && to_parts.scalar_kind == ScalarKind::Signed + let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed + && to_parts.scalar_kind == ast::ScalarKind::Signed { dr::Builder::s_convert } else { @@ -4614,23 +4616,23 @@ fn convert_to_stateful_memory_access<'a>( for statement in func_body.iter() { match statement { Statement::Instruction(ast::Instruction::Add( - ast::ArithDetails::Unsigned(ast::UIntType::U64), + ast::ArithDetails::Unsigned(ast::ScalarType::U64), arg, )) | Statement::Instruction(ast::Instruction::Add( ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S64, + typ: ast::ScalarType::S64, saturate: false, }), arg, )) | Statement::Instruction(ast::Instruction::Sub( - ast::ArithDetails::Unsigned(ast::UIntType::U64), + ast::ArithDetails::Unsigned(ast::ScalarType::U64), arg, )) | Statement::Instruction(ast::Instruction::Sub( ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S64, + typ: ast::ScalarType::S64, saturate: false, }), arg, @@ -4686,12 +4688,12 @@ fn convert_to_stateful_memory_access<'a>( } } Statement::Instruction(ast::Instruction::Add( - ast::ArithDetails::Unsigned(ast::UIntType::U64), + ast::ArithDetails::Unsigned(ast::ScalarType::U64), arg, )) | Statement::Instruction(ast::Instruction::Add( ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S64, + typ: ast::ScalarType::S64, saturate: false, }), arg, @@ -4715,12 +4717,12 @@ fn convert_to_stateful_memory_access<'a>( })) } Statement::Instruction(ast::Instruction::Sub( - ast::ArithDetails::Unsigned(ast::UIntType::U64), + ast::ArithDetails::Unsigned(ast::ScalarType::U64), arg, )) | Statement::Instruction(ast::Instruction::Sub( ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S64, + typ: ast::ScalarType::S64, saturate: false, }), arg, @@ -4867,7 +4869,7 @@ fn convert_to_stateful_memory_access_postprocess( ast::LdStateSpace::Global, ), to: old_type, - kind: ConversionKind::PtrToBit(ast::UIntType::U64), + kind: ConversionKind::PtrToBit(ast::ScalarType::U64), src_sema: arg_desc.sema, dst_sema: ArgumentSemantics::Default, })); @@ -5903,7 +5905,9 @@ impl ast::Instruction { let inst_type = d.typ; ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) } - ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?), + ast::Instruction::Not(t, a) => { + ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?) + } ast::Instruction::Cvt(d, a) => { let (dst_t, src_t) = match &d { ast::CvtDetails::FloatFromFloat(desc) => ( @@ -5926,7 +5930,7 @@ impl ast::Instruction { ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?) } ast::Instruction::Shl(t, a) => { - ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?) + ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?) } ast::Instruction::Shr(t, a) => { ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?) @@ -6176,9 +6180,9 @@ impl ast::Type { ast::Type::Scalar(scalar) => { let kind = scalar.kind(); let width = scalar.size_of(); - if (kind != ScalarKind::Signed - && kind != ScalarKind::Unsigned - && kind != ScalarKind::Bit) + if (kind != ast::ScalarKind::Signed + && kind != ast::ScalarKind::Unsigned + && kind != ast::ScalarKind::Bit) || (width == 8) { return Err(TranslateError::MismatchedType); @@ -6306,7 +6310,7 @@ impl ast::Type { #[derive(Eq, PartialEq, Clone)] struct TypeParts { kind: TypeKind, - scalar_kind: ScalarKind, + scalar_kind: ast::ScalarKind, width: u8, components: Vec, state_space: ast::LdStateSpace, @@ -6461,7 +6465,7 @@ enum ConversionKind { // zero-extend/chop/bitcast depending on types SignExtend, BitToPtr(ast::LdStateSpace), - PtrToBit(ast::UIntType), + PtrToBit(ast::ScalarType), PtrToPtr { spirv_ptr: bool }, } @@ -6859,7 +6863,7 @@ impl ast::Arg4 { fn map_selp>( self, visitor: &mut V, - t: ast::SelpType, + t: ast::ScalarType, ) -> Result, TranslateError> { let dst = visitor.operand( ArgumentDescriptor { @@ -6904,7 +6908,7 @@ impl ast::Arg4 { fn map_atom>( self, visitor: &mut V, - t: ast::BitType, + t: ast::ScalarType, state_space: ast::AtomSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); @@ -7205,103 +7209,41 @@ impl ast::StStateSpace { } } -#[derive(Clone, Copy, PartialEq, Eq)] -enum ScalarKind { - Bit, - Unsigned, - Signed, - Float, - Float2, - Pred, -} - impl ast::ScalarType { - fn kind(self) -> ScalarKind { - match self { - ast::ScalarType::U8 => ScalarKind::Unsigned, - ast::ScalarType::U16 => ScalarKind::Unsigned, - ast::ScalarType::U32 => ScalarKind::Unsigned, - ast::ScalarType::U64 => ScalarKind::Unsigned, - ast::ScalarType::S8 => ScalarKind::Signed, - ast::ScalarType::S16 => ScalarKind::Signed, - ast::ScalarType::S32 => ScalarKind::Signed, - ast::ScalarType::S64 => ScalarKind::Signed, - ast::ScalarType::B8 => ScalarKind::Bit, - ast::ScalarType::B16 => ScalarKind::Bit, - ast::ScalarType::B32 => ScalarKind::Bit, - ast::ScalarType::B64 => ScalarKind::Bit, - ast::ScalarType::F16 => ScalarKind::Float, - ast::ScalarType::F32 => ScalarKind::Float, - ast::ScalarType::F64 => ScalarKind::Float, - ast::ScalarType::F16x2 => ScalarKind::Float2, - ast::ScalarType::Pred => ScalarKind::Pred, - } - } - - fn from_parts(width: u8, kind: ScalarKind) -> Self { + fn from_parts(width: u8, kind: ast::ScalarKind) -> Self { match kind { - ScalarKind::Float => match width { + ast::ScalarKind::Float => match width { 2 => ast::ScalarType::F16, 4 => ast::ScalarType::F32, 8 => ast::ScalarType::F64, _ => unreachable!(), }, - ScalarKind::Bit => match width { + ast::ScalarKind::Bit => match width { 1 => ast::ScalarType::B8, 2 => ast::ScalarType::B16, 4 => ast::ScalarType::B32, 8 => ast::ScalarType::B64, _ => unreachable!(), }, - ScalarKind::Signed => match width { + ast::ScalarKind::Signed => match width { 1 => ast::ScalarType::S8, 2 => ast::ScalarType::S16, 4 => ast::ScalarType::S32, 8 => ast::ScalarType::S64, _ => unreachable!(), }, - ScalarKind::Unsigned => match width { + ast::ScalarKind::Unsigned => match width { 1 => ast::ScalarType::U8, 2 => ast::ScalarType::U16, 4 => ast::ScalarType::U32, 8 => ast::ScalarType::U64, _ => unreachable!(), }, - ScalarKind::Float2 => match width { + ast::ScalarKind::Float2 => match width { 4 => ast::ScalarType::F16x2, _ => unreachable!(), }, - ScalarKind::Pred => ast::ScalarType::Pred, - } - } -} - -impl ast::BooleanType { - fn to_type(self) -> ast::Type { - match self { - ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred), - ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16), - ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32), - ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64), - } - } -} - -impl ast::ShlType { - fn to_type(self) -> ast::Type { - match self { - ast::ShlType::B16 => ast::Type::Scalar(ast::ScalarType::B16), - ast::ShlType::B32 => ast::Type::Scalar(ast::ScalarType::B32), - ast::ShlType::B64 => ast::Type::Scalar(ast::ScalarType::B64), - } - } -} - -impl ast::ShrType { - fn signed(&self) -> bool { - match self { - ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true, - _ => false, + ast::ScalarKind::Pred => ast::ScalarType::Pred, } } } @@ -7357,30 +7299,6 @@ impl ast::AtomInnerDetails { } } -impl ast::SIntType { - fn from_size(width: u8) -> Self { - match width { - 1 => ast::SIntType::S8, - 2 => ast::SIntType::S16, - 4 => ast::SIntType::S32, - 8 => ast::SIntType::S64, - _ => unreachable!(), - } - } -} - -impl ast::UIntType { - fn from_size(width: u8) -> Self { - match width { - 1 => ast::UIntType::U8, - 2 => ast::UIntType::U16, - 4 => ast::UIntType::U32, - 8 => ast::UIntType::U64, - _ => unreachable!(), - } - } -} - impl ast::LdStateSpace { fn to_spirv(self) -> spirv::StorageClass { match self { @@ -7568,16 +7486,18 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { return false; } match inst.kind() { - ScalarKind::Bit => operand.kind() != ScalarKind::Bit, - ScalarKind::Float => operand.kind() == ScalarKind::Bit, - ScalarKind::Signed => { - operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned + ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, + ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, + ast::ScalarKind::Signed => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Unsigned } - ScalarKind::Unsigned => { - operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed + ast::ScalarKind::Unsigned => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Signed } - ScalarKind::Float2 => false, - ScalarKind::Pred => false, + ast::ScalarKind::Float2 => false, + ast::ScalarKind::Pred => false, } } (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _)) @@ -7596,7 +7516,7 @@ fn should_bitcast_packed( if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = (operand, instr) { - if scalar.kind() == ScalarKind::Bit + if scalar.kind() == ast::ScalarKind::Bit && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) { return Ok(Some(ConversionKind::Default)); @@ -7644,32 +7564,33 @@ fn should_convert_relaxed_src( } match (src_type, instr_type) { (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { - ScalarKind::Bit => { + ast::ScalarKind::Bit => { if instr_type.size_of() <= src_type.size_of() { Some(ConversionKind::Default) } else { None } } - ScalarKind::Signed | ScalarKind::Unsigned => { + ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { if instr_type.size_of() <= src_type.size_of() - && src_type.kind() != ScalarKind::Float + && src_type.kind() != ast::ScalarKind::Float { Some(ConversionKind::Default) } else { None } } - ScalarKind::Float => { - if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit + ast::ScalarKind::Float => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() == ast::ScalarKind::Bit { Some(ConversionKind::Default) } else { None } } - ScalarKind::Float2 => todo!(), - ScalarKind::Pred => None, + ast::ScalarKind::Float2 => todo!(), + ast::ScalarKind::Pred => None, }, (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { @@ -7706,15 +7627,15 @@ fn should_convert_relaxed_dst( } match (dst_type, instr_type) { (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { - ScalarKind::Bit => { + ast::ScalarKind::Bit => { if instr_type.size_of() <= dst_type.size_of() { Some(ConversionKind::Default) } else { None } } - ScalarKind::Signed => { - if dst_type.kind() != ScalarKind::Float { + ast::ScalarKind::Signed => { + if dst_type.kind() != ast::ScalarKind::Float { if instr_type.size_of() == dst_type.size_of() { Some(ConversionKind::Default) } else if instr_type.size_of() < dst_type.size_of() { @@ -7726,25 +7647,26 @@ fn should_convert_relaxed_dst( None } } - ScalarKind::Unsigned => { + ast::ScalarKind::Unsigned => { if instr_type.size_of() <= dst_type.size_of() - && dst_type.kind() != ScalarKind::Float + && dst_type.kind() != ast::ScalarKind::Float { Some(ConversionKind::Default) } else { None } } - ScalarKind::Float => { - if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit + ast::ScalarKind::Float => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() == ast::ScalarKind::Bit { Some(ConversionKind::Default) } else { None } } - ScalarKind::Float2 => todo!(), - ScalarKind::Pred => None, + ast::ScalarKind::Float2 => todo!(), + ast::ScalarKind::Pred => None, }, (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {