Convert enums to 1TT

This commit is contained in:
Andrzej Janik 2021-04-15 19:10:45 +02:00
commit a0baad9456
4 changed files with 332 additions and 462 deletions

View file

@ -210,20 +210,6 @@ sub_enum!(LdStScalarType {
F64, F64,
}); });
sub_enum!(SelpType {
B16,
B32,
B64,
U16,
U32,
U64,
S16,
S32,
S64,
F32,
F64,
});
#[derive(Copy, Clone, Eq, PartialEq)] #[derive(Copy, Clone, Eq, PartialEq)]
pub enum BarDetails { pub enum BarDetails {
SyncAligned, SyncAligned,
@ -425,52 +411,6 @@ pub enum ScalarType {
Pred, Pred,
} }
sub_enum!(IntType {
U8,
U16,
U32,
U64,
S8,
S16,
S32,
S64
});
sub_enum!(BitType { B8, B16, B32, B64 });
sub_enum!(UIntType { U8, U16, U32, U64 });
sub_enum!(SIntType { S8, S16, S32, S64 });
impl IntType {
pub fn is_signed(self) -> bool {
match self {
IntType::U8 | IntType::U16 | IntType::U32 | IntType::U64 => false,
IntType::S8 | IntType::S16 | IntType::S32 | IntType::S64 => true,
}
}
pub fn width(self) -> u8 {
match self {
IntType::U8 => 1,
IntType::U16 => 2,
IntType::U32 => 4,
IntType::U64 => 8,
IntType::S8 => 1,
IntType::S16 => 2,
IntType::S32 => 4,
IntType::S64 => 8,
}
}
}
sub_enum!(FloatType {
F16,
F16x2,
F32,
F64
});
impl ScalarType { impl ScalarType {
pub fn size_of(self) -> u8 { pub fn size_of(self) -> u8 {
match self { match self {
@ -576,24 +516,24 @@ pub enum Instruction<P: ArgParams> {
Add(ArithDetails, Arg3<P>), Add(ArithDetails, Arg3<P>),
Setp(SetpData, Arg4Setp<P>), Setp(SetpData, Arg4Setp<P>),
SetpBool(SetpBoolData, Arg5Setp<P>), SetpBool(SetpBoolData, Arg5Setp<P>),
Not(BooleanType, Arg2<P>), Not(ScalarType, Arg2<P>),
Bra(BraData, Arg1<P>), Bra(BraData, Arg1<P>),
Cvt(CvtDetails, Arg2<P>), Cvt(CvtDetails, Arg2<P>),
Cvta(CvtaDetails, Arg2<P>), Cvta(CvtaDetails, Arg2<P>),
Shl(ShlType, Arg3<P>), Shl(ScalarType, Arg3<P>),
Shr(ShrType, Arg3<P>), Shr(ScalarType, Arg3<P>),
St(StData, Arg2St<P>), St(StData, Arg2St<P>),
Ret(RetData), Ret(RetData),
Call(CallInst<P>), Call(CallInst<P>),
Abs(AbsDetails, Arg2<P>), Abs(AbsDetails, Arg2<P>),
Mad(MulDetails, Arg4<P>), Mad(MulDetails, Arg4<P>),
Or(BooleanType, Arg3<P>), Or(ScalarType, Arg3<P>),
Sub(ArithDetails, Arg3<P>), Sub(ArithDetails, Arg3<P>),
Min(MinMaxDetails, Arg3<P>), Min(MinMaxDetails, Arg3<P>),
Max(MinMaxDetails, Arg3<P>), Max(MinMaxDetails, Arg3<P>),
Rcp(RcpDetails, Arg2<P>), Rcp(RcpDetails, Arg2<P>),
And(BooleanType, Arg3<P>), And(ScalarType, Arg3<P>),
Selp(SelpType, Arg4<P>), Selp(ScalarType, Arg4<P>),
Bar(BarDetails, Arg1Bar<P>), Bar(BarDetails, Arg1Bar<P>),
Atom(AtomDetails, Arg3<P>), Atom(AtomDetails, Arg3<P>),
AtomCas(AtomCasDetails, Arg4<P>), AtomCas(AtomCasDetails, Arg4<P>),
@ -605,13 +545,13 @@ pub enum Instruction<P: ArgParams> {
Cos { flush_to_zero: bool, arg: Arg2<P> }, Cos { flush_to_zero: bool, arg: Arg2<P> },
Lg2 { flush_to_zero: bool, arg: Arg2<P> }, Lg2 { flush_to_zero: bool, arg: Arg2<P> },
Ex2 { flush_to_zero: bool, arg: Arg2<P> }, Ex2 { flush_to_zero: bool, arg: Arg2<P> },
Clz { typ: BitType, arg: Arg2<P> }, Clz { typ: ScalarType, arg: Arg2<P> },
Brev { typ: BitType, arg: Arg2<P> }, Brev { typ: ScalarType, arg: Arg2<P> },
Popc { typ: BitType, arg: Arg2<P> }, Popc { typ: ScalarType, arg: Arg2<P> },
Xor { typ: BooleanType, arg: Arg3<P> }, Xor { typ: ScalarType, arg: Arg3<P> },
Bfe { typ: IntType, arg: Arg4<P> }, Bfe { typ: ScalarType, arg: Arg4<P> },
Bfi { typ: BitType, arg: Arg5<P> }, Bfi { typ: ScalarType, arg: Arg5<P> },
Rem { typ: IntType, arg: Arg3<P> }, Rem { typ: ScalarType, arg: Arg3<P> },
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
@ -825,7 +765,7 @@ impl MovDetails {
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct MulIntDesc { pub struct MulIntDesc {
pub typ: IntType, pub typ: ScalarType,
pub control: MulIntControl, pub control: MulIntControl,
} }
@ -845,7 +785,7 @@ pub enum RoundingMode {
} }
pub struct AddIntDesc { pub struct AddIntDesc {
pub typ: IntType, pub typ: ScalarType,
pub saturate: bool, pub saturate: bool,
} }
@ -892,39 +832,39 @@ pub struct BraData {
pub enum CvtDetails { pub enum CvtDetails {
IntFromInt(CvtIntToIntDesc), IntFromInt(CvtIntToIntDesc),
FloatFromFloat(CvtDesc<FloatType, FloatType>), FloatFromFloat(CvtDesc),
IntFromFloat(CvtDesc<IntType, FloatType>), IntFromFloat(CvtDesc),
FloatFromInt(CvtDesc<FloatType, IntType>), FloatFromInt(CvtDesc),
} }
pub struct CvtIntToIntDesc { pub struct CvtIntToIntDesc {
pub dst: IntType, pub dst: ScalarType,
pub src: IntType, pub src: ScalarType,
pub saturate: bool, pub saturate: bool,
} }
pub struct CvtDesc<Dst, Src> { pub struct CvtDesc {
pub rounding: Option<RoundingMode>, pub rounding: Option<RoundingMode>,
pub flush_to_zero: Option<bool>, pub flush_to_zero: Option<bool>,
pub saturate: bool, pub saturate: bool,
pub dst: Dst, pub dst: ScalarType,
pub src: Src, pub src: ScalarType,
} }
impl CvtDetails { impl CvtDetails {
pub fn new_int_from_int_checked<'err, 'input>( pub fn new_int_from_int_checked<'err, 'input>(
saturate: bool, saturate: bool,
dst: IntType, dst: ScalarType,
src: IntType, src: ScalarType,
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>, err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self { ) -> Self {
if saturate { if saturate {
if src.is_signed() { if src.kind() == ScalarKind::Signed {
if dst.is_signed() && dst.width() >= src.width() { if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() {
err.push(ParseError::from(PtxError::SyntaxError)); err.push(ParseError::from(PtxError::SyntaxError));
} }
} else { } else {
if dst == src || dst.width() >= src.width() { if dst == src || dst.size_of() >= src.size_of() {
err.push(ParseError::from(PtxError::SyntaxError)); err.push(ParseError::from(PtxError::SyntaxError));
} }
} }
@ -936,11 +876,11 @@ impl CvtDetails {
rounding: RoundingMode, rounding: RoundingMode,
flush_to_zero: bool, flush_to_zero: bool,
saturate: bool, saturate: bool,
dst: FloatType, dst: ScalarType,
src: IntType, src: ScalarType,
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>, err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self { ) -> Self {
if flush_to_zero && dst != FloatType::F32 { if flush_to_zero && dst != ScalarType::F32 {
err.push(ParseError::from(PtxError::NonF32Ftz)); err.push(ParseError::from(PtxError::NonF32Ftz));
} }
CvtDetails::FloatFromInt(CvtDesc { CvtDetails::FloatFromInt(CvtDesc {
@ -956,11 +896,11 @@ impl CvtDetails {
rounding: RoundingMode, rounding: RoundingMode,
flush_to_zero: bool, flush_to_zero: bool,
saturate: bool, saturate: bool,
dst: IntType, dst: ScalarType,
src: FloatType, src: ScalarType,
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>, err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self { ) -> Self {
if flush_to_zero && src != FloatType::F32 { if flush_to_zero && src != ScalarType::F32 {
err.push(ParseError::from(PtxError::NonF32Ftz)); err.push(ParseError::from(PtxError::NonF32Ftz));
} }
CvtDetails::IntFromFloat(CvtDesc { CvtDetails::IntFromFloat(CvtDesc {
@ -993,25 +933,6 @@ pub enum CvtaSize {
U64, U64,
} }
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum ShlType {
B16,
B32,
B64,
}
sub_enum!(ShrType {
B16,
B32,
B64,
U16,
U32,
U64,
S16,
S32,
S64,
});
pub struct StData { pub struct StData {
pub qualifier: LdStQualifier, pub qualifier: LdStQualifier,
pub state_space: StStateSpace, pub state_space: StStateSpace,
@ -1040,13 +961,6 @@ pub struct RetData {
pub uniform: bool, pub uniform: bool,
} }
sub_enum!(BooleanType {
Pred,
B16,
B32,
B64,
});
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub enum MulDetails { pub enum MulDetails {
Unsigned(MulUInt), Unsigned(MulUInt),
@ -1056,32 +970,32 @@ pub enum MulDetails {
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct MulUInt { pub struct MulUInt {
pub typ: UIntType, pub typ: ScalarType,
pub control: MulIntControl, pub control: MulIntControl,
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct MulSInt { pub struct MulSInt {
pub typ: SIntType, pub typ: ScalarType,
pub control: MulIntControl, pub control: MulIntControl,
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub enum ArithDetails { pub enum ArithDetails {
Unsigned(UIntType), Unsigned(ScalarType),
Signed(ArithSInt), Signed(ArithSInt),
Float(ArithFloat), Float(ArithFloat),
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct ArithSInt { pub struct ArithSInt {
pub typ: SIntType, pub typ: ScalarType,
pub saturate: bool, pub saturate: bool,
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct ArithFloat { pub struct ArithFloat {
pub typ: FloatType, pub typ: ScalarType,
pub rounding: Option<RoundingMode>, pub rounding: Option<RoundingMode>,
pub flush_to_zero: Option<bool>, pub flush_to_zero: Option<bool>,
pub saturate: bool, pub saturate: bool,
@ -1089,8 +1003,8 @@ pub struct ArithFloat {
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub enum MinMaxDetails { pub enum MinMaxDetails {
Signed(SIntType), Signed(ScalarType),
Unsigned(UIntType), Unsigned(ScalarType),
Float(MinMaxFloat), Float(MinMaxFloat),
} }
@ -1098,7 +1012,7 @@ pub enum MinMaxDetails {
pub struct MinMaxFloat { pub struct MinMaxFloat {
pub flush_to_zero: Option<bool>, pub flush_to_zero: Option<bool>,
pub nan: bool, pub nan: bool,
pub typ: FloatType, pub typ: ScalarType,
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
@ -1126,10 +1040,10 @@ pub enum AtomSpace {
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub enum AtomInnerDetails { pub enum AtomInnerDetails {
Bit { op: AtomBitOp, typ: BitType }, Bit { op: AtomBitOp, typ: ScalarType },
Unsigned { op: AtomUIntOp, typ: UIntType }, Unsigned { op: AtomUIntOp, typ: ScalarType },
Signed { op: AtomSIntOp, typ: SIntType }, Signed { op: AtomSIntOp, typ: ScalarType },
Float { op: AtomFloatOp, typ: FloatType }, Float { op: AtomFloatOp, typ: ScalarType },
} }
#[derive(Copy, Clone, Eq, PartialEq)] #[derive(Copy, Clone, Eq, PartialEq)]
@ -1166,19 +1080,19 @@ pub struct AtomCasDetails {
pub semantics: AtomSemantics, pub semantics: AtomSemantics,
pub scope: MemScope, pub scope: MemScope,
pub space: AtomSpace, pub space: AtomSpace,
pub typ: BitType, pub typ: ScalarType,
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub enum DivDetails { pub enum DivDetails {
Unsigned(UIntType), Unsigned(ScalarType),
Signed(SIntType), Signed(ScalarType),
Float(DivFloatDetails), Float(DivFloatDetails),
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct DivFloatDetails { pub struct DivFloatDetails {
pub typ: FloatType, pub typ: ScalarType,
pub flush_to_zero: Option<bool>, pub flush_to_zero: Option<bool>,
pub kind: DivFloatKind, pub kind: DivFloatKind,
} }
@ -1197,7 +1111,7 @@ pub enum NumsOrArrays<'a> {
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct SqrtDetails { pub struct SqrtDetails {
pub typ: FloatType, pub typ: ScalarType,
pub flush_to_zero: Option<bool>, pub flush_to_zero: Option<bool>,
pub kind: SqrtKind, pub kind: SqrtKind,
} }
@ -1210,7 +1124,7 @@ pub enum SqrtKind {
#[derive(Copy, Clone, Eq, PartialEq)] #[derive(Copy, Clone, Eq, PartialEq)]
pub struct RsqrtDetails { pub struct RsqrtDetails {
pub typ: FloatType, pub typ: ScalarType,
pub flush_to_zero: bool, pub flush_to_zero: bool,
} }
@ -1379,6 +1293,40 @@ pub enum TuningDirective {
MinNCtaPerSm(u32), MinNCtaPerSm(u32),
} }
/// Coarse family of a [`ScalarType`], as reported by `ScalarType::kind`.
///
/// The kind keeps only the interpretation of the value (untyped bits,
/// unsigned, signed, floating point, ...) and drops its width.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ScalarKind {
    /// Untyped bit types: `B8`, `B16`, `B32`, `B64`.
    Bit,
    /// Unsigned integer types: `U8`, `U16`, `U32`, `U64`.
    Unsigned,
    /// Signed integer types: `S8`, `S16`, `S32`, `S64`.
    Signed,
    /// Single-value floating point types: `F16`, `F32`, `F64`.
    Float,
    /// The `F16x2` type.
    Float2,
    /// The `Pred` type.
    Pred,
}
impl ScalarType {
    /// Returns the [`ScalarKind`] family this scalar type belongs to.
    ///
    /// Variants are grouped by the kind they map to. There is deliberately
    /// no wildcard arm, so the compiler keeps checking that every
    /// `ScalarType` variant is classified.
    pub fn kind(self) -> ScalarKind {
        match self {
            ScalarType::B8 | ScalarType::B16 | ScalarType::B32 | ScalarType::B64 => {
                ScalarKind::Bit
            }
            ScalarType::U8 | ScalarType::U16 | ScalarType::U32 | ScalarType::U64 => {
                ScalarKind::Unsigned
            }
            ScalarType::S8 | ScalarType::S16 | ScalarType::S32 | ScalarType::S64 => {
                ScalarKind::Signed
            }
            ScalarType::F16 | ScalarType::F32 | ScalarType::F64 => ScalarKind::Float,
            ScalarType::F16x2 => ScalarKind::Float2,
            ScalarType::Pred => ScalarKind::Pred,
        }
    }
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -899,39 +899,39 @@ RoundingModeInt : ast::RoundingMode = {
".rpi" => ast::RoundingMode::PositiveInf, ".rpi" => ast::RoundingMode::PositiveInf,
}; };
IntType : ast::IntType = { IntType : ast::ScalarType = {
".u16" => ast::IntType::U16, ".u16" => ast::ScalarType::U16,
".u32" => ast::IntType::U32, ".u32" => ast::ScalarType::U32,
".u64" => ast::IntType::U64, ".u64" => ast::ScalarType::U64,
".s16" => ast::IntType::S16, ".s16" => ast::ScalarType::S16,
".s32" => ast::IntType::S32, ".s32" => ast::ScalarType::S32,
".s64" => ast::IntType::S64, ".s64" => ast::ScalarType::S64,
}; };
IntType3264: ast::IntType = { IntType3264: ast::ScalarType = {
".u32" => ast::IntType::U32, ".u32" => ast::ScalarType::U32,
".u64" => ast::IntType::U64, ".u64" => ast::ScalarType::U64,
".s32" => ast::IntType::S32, ".s32" => ast::ScalarType::S32,
".s64" => ast::IntType::S64, ".s64" => ast::ScalarType::S64,
} }
UIntType: ast::UIntType = { UIntType: ast::ScalarType = {
".u16" => ast::UIntType::U16, ".u16" => ast::ScalarType::U16,
".u32" => ast::UIntType::U32, ".u32" => ast::ScalarType::U32,
".u64" => ast::UIntType::U64, ".u64" => ast::ScalarType::U64,
}; };
SIntType: ast::SIntType = { SIntType: ast::ScalarType = {
".s16" => ast::SIntType::S16, ".s16" => ast::ScalarType::S16,
".s32" => ast::SIntType::S32, ".s32" => ast::ScalarType::S32,
".s64" => ast::SIntType::S64, ".s64" => ast::ScalarType::S64,
}; };
FloatType: ast::FloatType = { FloatType: ast::ScalarType = {
".f16" => ast::FloatType::F16, ".f16" => ast::ScalarType::F16,
".f16x2" => ast::FloatType::F16x2, ".f16x2" => ast::ScalarType::F16x2,
".f32" => ast::FloatType::F32, ".f32" => ast::ScalarType::F32,
".f64" => ast::FloatType::F64, ".f64" => ast::ScalarType::F64,
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
@ -1023,11 +1023,11 @@ InstNot: ast::Instruction<ast::ParsedArgParams<'input>> = {
"not" <t:BooleanType> <a:Arg2> => ast::Instruction::Not(t, a) "not" <t:BooleanType> <a:Arg2> => ast::Instruction::Not(t, a)
}; };
BooleanType: ast::BooleanType = { BooleanType: ast::ScalarType = {
".pred" => ast::BooleanType::Pred, ".pred" => ast::ScalarType::Pred,
".b16" => ast::BooleanType::B16, ".b16" => ast::ScalarType::B16,
".b32" => ast::BooleanType::B32, ".b32" => ast::ScalarType::B32,
".b64" => ast::BooleanType::B64, ".b64" => ast::ScalarType::B64,
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at
@ -1080,8 +1080,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: r, rounding: r,
flush_to_zero: None, flush_to_zero: None,
saturate: s.is_some(), saturate: s.is_some(),
dst: ast::FloatType::F16, dst: ast::ScalarType::F16,
src: ast::FloatType::F16 src: ast::ScalarType::F16
} }
), a) ), a)
}, },
@ -1091,8 +1091,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: None, rounding: None,
flush_to_zero: Some(f.is_some()), flush_to_zero: Some(f.is_some()),
saturate: s.is_some(), saturate: s.is_some(),
dst: ast::FloatType::F32, dst: ast::ScalarType::F32,
src: ast::FloatType::F16 src: ast::ScalarType::F16
} }
), a) ), a)
}, },
@ -1102,8 +1102,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: None, rounding: None,
flush_to_zero: None, flush_to_zero: None,
saturate: s.is_some(), saturate: s.is_some(),
dst: ast::FloatType::F64, dst: ast::ScalarType::F64,
src: ast::FloatType::F16 src: ast::ScalarType::F16
} }
), a) ), a)
}, },
@ -1113,8 +1113,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: Some(r), rounding: Some(r),
flush_to_zero: Some(f.is_some()), flush_to_zero: Some(f.is_some()),
saturate: s.is_some(), saturate: s.is_some(),
dst: ast::FloatType::F16, dst: ast::ScalarType::F16,
src: ast::FloatType::F32 src: ast::ScalarType::F32
} }
), a) ), a)
}, },
@ -1124,8 +1124,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: r, rounding: r,
flush_to_zero: Some(f.is_some()), flush_to_zero: Some(f.is_some()),
saturate: s.is_some(), saturate: s.is_some(),
dst: ast::FloatType::F32, dst: ast::ScalarType::F32,
src: ast::FloatType::F32 src: ast::ScalarType::F32
} }
), a) ), a)
}, },
@ -1135,8 +1135,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: None, rounding: None,
flush_to_zero: None, flush_to_zero: None,
saturate: s.is_some(), saturate: s.is_some(),
dst: ast::FloatType::F64, dst: ast::ScalarType::F64,
src: ast::FloatType::F32 src: ast::ScalarType::F32
} }
), a) ), a)
}, },
@ -1146,8 +1146,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: Some(r), rounding: Some(r),
flush_to_zero: None, flush_to_zero: None,
saturate: s.is_some(), saturate: s.is_some(),
dst: ast::FloatType::F16, dst: ast::ScalarType::F16,
src: ast::FloatType::F64 src: ast::ScalarType::F64
} }
), a) ), a)
}, },
@ -1157,8 +1157,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: Some(r), rounding: Some(r),
flush_to_zero: Some(s.is_some()), flush_to_zero: Some(s.is_some()),
saturate: s.is_some(), saturate: s.is_some(),
dst: ast::FloatType::F32, dst: ast::ScalarType::F32,
src: ast::FloatType::F64 src: ast::ScalarType::F64
} }
), a) ), a)
}, },
@ -1168,28 +1168,28 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: r, rounding: r,
flush_to_zero: None, flush_to_zero: None,
saturate: s.is_some(), saturate: s.is_some(),
dst: ast::FloatType::F64, dst: ast::ScalarType::F64,
src: ast::FloatType::F64 src: ast::ScalarType::F64
} }
), a) ), a)
}, },
}; };
CvtTypeInt: ast::IntType = { CvtTypeInt: ast::ScalarType = {
".u8" => ast::IntType::U8, ".u8" => ast::ScalarType::U8,
".u16" => ast::IntType::U16, ".u16" => ast::ScalarType::U16,
".u32" => ast::IntType::U32, ".u32" => ast::ScalarType::U32,
".u64" => ast::IntType::U64, ".u64" => ast::ScalarType::U64,
".s8" => ast::IntType::S8, ".s8" => ast::ScalarType::S8,
".s16" => ast::IntType::S16, ".s16" => ast::ScalarType::S16,
".s32" => ast::IntType::S32, ".s32" => ast::ScalarType::S32,
".s64" => ast::IntType::S64, ".s64" => ast::ScalarType::S64,
}; };
CvtTypeFloat: ast::FloatType = { CvtTypeFloat: ast::ScalarType = {
".f16" => ast::FloatType::F16, ".f16" => ast::ScalarType::F16,
".f32" => ast::FloatType::F32, ".f32" => ast::ScalarType::F32,
".f64" => ast::FloatType::F64, ".f64" => ast::ScalarType::F64,
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
@ -1197,10 +1197,10 @@ InstShl: ast::Instruction<ast::ParsedArgParams<'input>> = {
"shl" <t:ShlType> <a:Arg3> => ast::Instruction::Shl(t, a) "shl" <t:ShlType> <a:Arg3> => ast::Instruction::Shl(t, a)
}; };
ShlType: ast::ShlType = { ShlType: ast::ScalarType = {
".b16" => ast::ShlType::B16, ".b16" => ast::ScalarType::B16,
".b32" => ast::ShlType::B32, ".b32" => ast::ScalarType::B32,
".b64" => ast::ShlType::B64, ".b64" => ast::ScalarType::B64,
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr
@ -1208,16 +1208,16 @@ InstShr: ast::Instruction<ast::ParsedArgParams<'input>> = {
"shr" <t:ShrType> <a:Arg3> => ast::Instruction::Shr(t, a) "shr" <t:ShrType> <a:Arg3> => ast::Instruction::Shr(t, a)
}; };
ShrType: ast::ShrType = { ShrType: ast::ScalarType = {
".b16" => ast::ShrType::B16, ".b16" => ast::ScalarType::B16,
".b32" => ast::ShrType::B32, ".b32" => ast::ScalarType::B32,
".b64" => ast::ShrType::B64, ".b64" => ast::ScalarType::B64,
".u16" => ast::ShrType::U16, ".u16" => ast::ScalarType::U16,
".u32" => ast::ShrType::U32, ".u32" => ast::ScalarType::U32,
".u64" => ast::ShrType::U64, ".u64" => ast::ScalarType::U64,
".s16" => ast::ShrType::S16, ".s16" => ast::ScalarType::S16,
".s32" => ast::ShrType::S32, ".s32" => ast::ScalarType::S32,
".s64" => ast::ShrType::S64, ".s64" => ast::ScalarType::S64,
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
@ -1393,16 +1393,16 @@ MinMaxDetails: ast::MinMaxDetails = {
<t:UIntType> => ast::MinMaxDetails::Unsigned(t), <t:UIntType> => ast::MinMaxDetails::Unsigned(t),
<t:SIntType> => ast::MinMaxDetails::Signed(t), <t:SIntType> => ast::MinMaxDetails::Signed(t),
<ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float( <ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F32 } ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F32 }
), ),
".f64" => ast::MinMaxDetails::Float( ".f64" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::FloatType::F64 } ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::ScalarType::F64 }
), ),
<ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float( <ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16 } ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16 }
), ),
<ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float( <ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16x2 } ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16x2 }
) )
} }
@ -1411,18 +1411,18 @@ InstSelp: ast::Instruction<ast::ParsedArgParams<'input>> = {
"selp" <t:SelpType> <a:Arg4> => ast::Instruction::Selp(t, a), "selp" <t:SelpType> <a:Arg4> => ast::Instruction::Selp(t, a),
}; };
SelpType: ast::SelpType = { SelpType: ast::ScalarType = {
".b16" => ast::SelpType::B16, ".b16" => ast::ScalarType::B16,
".b32" => ast::SelpType::B32, ".b32" => ast::ScalarType::B32,
".b64" => ast::SelpType::B64, ".b64" => ast::ScalarType::B64,
".u16" => ast::SelpType::U16, ".u16" => ast::ScalarType::U16,
".u32" => ast::SelpType::U32, ".u32" => ast::ScalarType::U32,
".u64" => ast::SelpType::U64, ".u64" => ast::ScalarType::U64,
".s16" => ast::SelpType::S16, ".s16" => ast::ScalarType::S16,
".s32" => ast::SelpType::S32, ".s32" => ast::ScalarType::S32,
".s64" => ast::SelpType::S64, ".s64" => ast::ScalarType::S64,
".f32" => ast::SelpType::F32, ".f32" => ast::ScalarType::F32,
".f64" => ast::SelpType::F64, ".f64" => ast::ScalarType::F64,
}; };
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar
@ -1454,7 +1454,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
space: space.unwrap_or(ast::AtomSpace::Generic), space: space.unwrap_or(ast::AtomSpace::Generic),
inner: ast::AtomInnerDetails::Unsigned { inner: ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Inc, op: ast::AtomUIntOp::Inc,
typ: ast::UIntType::U32 typ: ast::ScalarType::U32
} }
}; };
ast::Instruction::Atom(details,a) ast::Instruction::Atom(details,a)
@ -1466,7 +1466,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
space: space.unwrap_or(ast::AtomSpace::Generic), space: space.unwrap_or(ast::AtomSpace::Generic),
inner: ast::AtomInnerDetails::Unsigned { inner: ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Dec, op: ast::AtomUIntOp::Dec,
typ: ast::UIntType::U32 typ: ast::ScalarType::U32
} }
}; };
ast::Instruction::Atom(details,a) ast::Instruction::Atom(details,a)
@ -1544,19 +1544,19 @@ AtomSIntOp: ast::AtomSIntOp = {
".max" => ast::AtomSIntOp::Max, ".max" => ast::AtomSIntOp::Max,
} }
BitType: ast::BitType = { BitType: ast::ScalarType = {
".b32" => ast::BitType::B32, ".b32" => ast::ScalarType::B32,
".b64" => ast::BitType::B64, ".b64" => ast::ScalarType::B64,
} }
UIntType3264: ast::UIntType = { UIntType3264: ast::ScalarType = {
".u32" => ast::UIntType::U32, ".u32" => ast::ScalarType::U32,
".u64" => ast::UIntType::U64, ".u64" => ast::ScalarType::U64,
} }
SIntType3264: ast::SIntType = { SIntType3264: ast::ScalarType = {
".s32" => ast::SIntType::S32, ".s32" => ast::ScalarType::S32,
".s64" => ast::SIntType::S64, ".s64" => ast::ScalarType::S64,
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div
@ -1566,7 +1566,7 @@ InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = {
"div" <t:SIntType> <a:Arg3> => ast::Instruction::Div(ast::DivDetails::Signed(t), a), "div" <t:SIntType> <a:Arg3> => ast::Instruction::Div(ast::DivDetails::Signed(t), a),
"div" <kind:DivFloatKind> <ftz:".ftz"?> ".f32" <a:Arg3> => { "div" <kind:DivFloatKind> <ftz:".ftz"?> ".f32" <a:Arg3> => {
let inner = ast::DivFloatDetails { let inner = ast::DivFloatDetails {
typ: ast::FloatType::F32, typ: ast::ScalarType::F32,
flush_to_zero: Some(ftz.is_some()), flush_to_zero: Some(ftz.is_some()),
kind kind
}; };
@ -1574,7 +1574,7 @@ InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = {
}, },
"div" <rnd:RoundingModeFloat> ".f64" <a:Arg3> => { "div" <rnd:RoundingModeFloat> ".f64" <a:Arg3> => {
let inner = ast::DivFloatDetails { let inner = ast::DivFloatDetails {
typ: ast::FloatType::F64, typ: ast::ScalarType::F64,
flush_to_zero: None, flush_to_zero: None,
kind: ast::DivFloatKind::Rounding(rnd) kind: ast::DivFloatKind::Rounding(rnd)
}; };
@ -1592,7 +1592,7 @@ DivFloatKind: ast::DivFloatKind = {
InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = { InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"sqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => { "sqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => {
let details = ast::SqrtDetails { let details = ast::SqrtDetails {
typ: ast::FloatType::F32, typ: ast::ScalarType::F32,
flush_to_zero: Some(ftz.is_some()), flush_to_zero: Some(ftz.is_some()),
kind: ast::SqrtKind::Approx, kind: ast::SqrtKind::Approx,
}; };
@ -1600,7 +1600,7 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
}, },
"sqrt" <rnd:RoundingModeFloat> <ftz:".ftz"?> ".f32" <a:Arg2> => { "sqrt" <rnd:RoundingModeFloat> <ftz:".ftz"?> ".f32" <a:Arg2> => {
let details = ast::SqrtDetails { let details = ast::SqrtDetails {
typ: ast::FloatType::F32, typ: ast::ScalarType::F32,
flush_to_zero: Some(ftz.is_some()), flush_to_zero: Some(ftz.is_some()),
kind: ast::SqrtKind::Rounding(rnd), kind: ast::SqrtKind::Rounding(rnd),
}; };
@ -1608,7 +1608,7 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
}, },
"sqrt" <rnd:RoundingModeFloat> ".f64" <a:Arg2> => { "sqrt" <rnd:RoundingModeFloat> ".f64" <a:Arg2> => {
let details = ast::SqrtDetails { let details = ast::SqrtDetails {
typ: ast::FloatType::F64, typ: ast::ScalarType::F64,
flush_to_zero: None, flush_to_zero: None,
kind: ast::SqrtKind::Rounding(rnd), kind: ast::SqrtKind::Rounding(rnd),
}; };
@ -1621,14 +1621,14 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = { InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"rsqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => { "rsqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => {
let details = ast::RsqrtDetails { let details = ast::RsqrtDetails {
typ: ast::FloatType::F32, typ: ast::ScalarType::F32,
flush_to_zero: ftz.is_some(), flush_to_zero: ftz.is_some(),
}; };
ast::Instruction::Rsqrt(details, a) ast::Instruction::Rsqrt(details, a)
}, },
"rsqrt" ".approx" <ftz:".ftz"?> ".f64" <a:Arg2> => { "rsqrt" ".approx" <ftz:".ftz"?> ".f64" <a:Arg2> => {
let details = ast::RsqrtDetails { let details = ast::RsqrtDetails {
typ: ast::FloatType::F64, typ: ast::ScalarType::F64,
flush_to_zero: ftz.is_some(), flush_to_zero: ftz.is_some(),
}; };
ast::Instruction::Rsqrt(details, a) ast::Instruction::Rsqrt(details, a)
@ -1739,7 +1739,7 @@ ArithDetails: ast::ArithDetails = {
saturate: false, saturate: false,
}), }),
".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt { ".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::S32, typ: ast::ScalarType::S32,
saturate: true, saturate: true,
}), }),
<f:ArithFloat> => ast::ArithDetails::Float(f) <f:ArithFloat> => ast::ArithDetails::Float(f)
@ -1747,25 +1747,25 @@ ArithDetails: ast::ArithDetails = {
ArithFloat: ast::ArithFloat = { ArithFloat: ast::ArithFloat = {
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat { <rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
typ: ast::FloatType::F32, typ: ast::ScalarType::F32,
rounding: rn, rounding: rn,
flush_to_zero: Some(ftz.is_some()), flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(), saturate: sat.is_some(),
}, },
<rn:RoundingModeFloat?> ".f64" => ast::ArithFloat { <rn:RoundingModeFloat?> ".f64" => ast::ArithFloat {
typ: ast::FloatType::F64, typ: ast::ScalarType::F64,
rounding: rn, rounding: rn,
flush_to_zero: None, flush_to_zero: None,
saturate: false, saturate: false,
}, },
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat { <rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
typ: ast::FloatType::F16, typ: ast::ScalarType::F16,
rounding: rn.map(|_| ast::RoundingMode::NearestEven), rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz.is_some()), flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(), saturate: sat.is_some(),
}, },
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat { <rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
typ: ast::FloatType::F16x2, typ: ast::ScalarType::F16x2,
rounding: rn.map(|_| ast::RoundingMode::NearestEven), rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz.is_some()), flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(), saturate: sat.is_some(),
@ -1774,25 +1774,25 @@ ArithFloat: ast::ArithFloat = {
ArithFloatMustRound: ast::ArithFloat = { ArithFloatMustRound: ast::ArithFloat = {
<rn:RoundingModeFloat> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat { <rn:RoundingModeFloat> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
typ: ast::FloatType::F32, typ: ast::ScalarType::F32,
rounding: Some(rn), rounding: Some(rn),
flush_to_zero: Some(ftz.is_some()), flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(), saturate: sat.is_some(),
}, },
<rn:RoundingModeFloat> ".f64" => ast::ArithFloat { <rn:RoundingModeFloat> ".f64" => ast::ArithFloat {
typ: ast::FloatType::F64, typ: ast::ScalarType::F64,
rounding: Some(rn), rounding: Some(rn),
flush_to_zero: None, flush_to_zero: None,
saturate: false, saturate: false,
}, },
".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat { ".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
typ: ast::FloatType::F16, typ: ast::ScalarType::F16,
rounding: Some(ast::RoundingMode::NearestEven), rounding: Some(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz.is_some()), flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(), saturate: sat.is_some(),
}, },
".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat { ".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
typ: ast::FloatType::F16x2, typ: ast::ScalarType::F16x2,
rounding: Some(ast::RoundingMode::NearestEven), rounding: Some(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz.is_some()), flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(), saturate: sat.is_some(),

View file

@ -71,7 +71,7 @@
%26 = OpLoad %uint %9 %26 = OpLoad %uint %9
%40 = OpCopyObject %uint %23 %40 = OpCopyObject %uint %23
%41 = OpCopyObject %uint %24 %41 = OpCopyObject %uint %24
%39 = OpFunctionCall %uint %44 %41 %40 %25 %26 %39 = OpFunctionCall %uint %44 %40 %41 %25 %26
%22 = OpCopyObject %uint %39 %22 = OpCopyObject %uint %39
OpStore %6 %22 OpStore %6 %22
%27 = OpLoad %ulong %5 %27 = OpLoad %ulong %5

View file

@ -1553,10 +1553,9 @@ fn extract_globals<'input, 'b>(
space, space,
}; };
let (op, typ) = match typ { let (op, typ) = match typ {
ast::FloatType::F32 => ("add_f32", ast::SizedScalarType::F32), ast::ScalarType::F32 => ("add_f32", ast::SizedScalarType::F32),
ast::FloatType::F64 => ("add_f64", ast::SizedScalarType::F64), ast::ScalarType::F64 => ("add_f64", ast::SizedScalarType::F64),
ast::FloatType::F16 => unreachable!(), _ => unreachable!(),
ast::FloatType::F16x2 => unreachable!(),
}; };
local.push(to_ptx_impl_atomic_call( local.push(to_ptx_impl_atomic_call(
id_def, id_def,
@ -1822,15 +1821,15 @@ fn to_ptx_impl_atomic_call(
fn to_ptx_impl_bfe_call( fn to_ptx_impl_bfe_call(
id_defs: &mut NumericIdResolver, id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>, ptx_impl_imports: &mut HashMap<String, Directive>,
typ: ast::IntType, typ: ast::ScalarType,
arg: ast::Arg4<ExpandedArgParams>, arg: ast::Arg4<ExpandedArgParams>,
) -> ExpandedStatement { ) -> ExpandedStatement {
let prefix = "__zluda_ptx_impl__"; let prefix = "__zluda_ptx_impl__";
let suffix = match typ { let suffix = match typ {
ast::IntType::U32 => "bfe_u32", ast::ScalarType::U32 => "bfe_u32",
ast::IntType::U64 => "bfe_u64", ast::ScalarType::U64 => "bfe_u64",
ast::IntType::S32 => "bfe_s32", ast::ScalarType::S32 => "bfe_s32",
ast::IntType::S64 => "bfe_s64", ast::ScalarType::S64 => "bfe_s64",
_ => unreachable!(), _ => unreachable!(),
}; };
let fn_name = format!("{}{}", prefix, suffix); let fn_name = format!("{}{}", prefix, suffix);
@ -1917,14 +1916,14 @@ fn to_ptx_impl_bfe_call(
fn to_ptx_impl_bfi_call( fn to_ptx_impl_bfi_call(
id_defs: &mut NumericIdResolver, id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>, ptx_impl_imports: &mut HashMap<String, Directive>,
typ: ast::BitType, typ: ast::ScalarType,
arg: ast::Arg5<ExpandedArgParams>, arg: ast::Arg5<ExpandedArgParams>,
) -> ExpandedStatement { ) -> ExpandedStatement {
let prefix = "__zluda_ptx_impl__"; let prefix = "__zluda_ptx_impl__";
let suffix = match typ { let suffix = match typ {
ast::BitType::B32 => "bfi_b32", ast::ScalarType::B32 => "bfi_b32",
ast::BitType::B64 => "bfi_b64", ast::ScalarType::B64 => "bfi_b64",
ast::BitType::B8 | ast::BitType::B16 => unreachable!(), _ => unreachable!(),
}; };
let fn_name = format!("{}{}", prefix, suffix); let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) { let fn_id = match ptx_impl_imports.entry(fn_name) {
@ -2506,29 +2505,32 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
let (width, kind) = match add_type { let (width, kind) = match add_type {
ast::Type::Scalar(scalar_t) => { ast::Type::Scalar(scalar_t) => {
let kind = match scalar_t.kind() { let kind = match scalar_t.kind() {
kind @ ScalarKind::Bit kind @ ast::ScalarKind::Bit
| kind @ ScalarKind::Unsigned | kind @ ast::ScalarKind::Unsigned
| kind @ ScalarKind::Signed => kind, | kind @ ast::ScalarKind::Signed => kind,
ScalarKind::Float => return Err(TranslateError::MismatchedType), ast::ScalarKind::Float => return Err(TranslateError::MismatchedType),
ScalarKind::Float2 => return Err(TranslateError::MismatchedType), ast::ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
ScalarKind::Pred => return Err(TranslateError::MismatchedType), ast::ScalarKind::Pred => return Err(TranslateError::MismatchedType),
}; };
(scalar_t.size_of(), kind) (scalar_t.size_of(), kind)
} }
_ => return Err(TranslateError::MismatchedType), _ => return Err(TranslateError::MismatchedType),
}; };
let arith_detail = if kind == ScalarKind::Signed { let arith_detail = if kind == ast::ScalarKind::Signed {
ast::ArithDetails::Signed(ast::ArithSInt { ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::from_size(width), typ: ast::ScalarType::from_parts(width, ast::ScalarKind::Signed),
saturate: false, saturate: false,
}) })
} else { } else {
ast::ArithDetails::Unsigned(ast::UIntType::from_size(width)) ast::ArithDetails::Unsigned(ast::ScalarType::from_parts(
width,
ast::ScalarKind::Unsigned,
))
}; };
let id_constant_stmt = self.id_def.new_non_variable(add_type.clone()); let id_constant_stmt = self.id_def.new_non_variable(add_type.clone());
let result_id = self.id_def.new_non_variable(add_type); let result_id = self.id_def.new_non_variable(add_type);
// TODO: check for edge cases around min value/max value/wrapping // TODO: check for edge cases around min value/max value/wrapping
if offset < 0 && kind != ScalarKind::Signed { if offset < 0 && kind != ast::ScalarKind::Signed {
self.func.push(Statement::Constant(ConstantDefinition { self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt, dst: id_constant_stmt,
typ: ast::ScalarType::from_parts(width, kind), typ: ast::ScalarType::from_parts(width, kind),
@ -3026,18 +3028,18 @@ fn emit_function_body_ops(
emit_setp(builder, map, setp, arg)?; emit_setp(builder, map, setp, arg)?;
} }
ast::Instruction::Not(t, a) => { ast::Instruction::Not(t, a) => {
let result_type = map.get_or_add(builder, SpirvType::from(t.to_type())); let result_type = map.get_or_add(builder, SpirvType::from(*t));
let result_id = Some(a.dst); let result_id = Some(a.dst);
let operand = a.src; let operand = a.src;
match t { match t {
ast::BooleanType::Pred => { ast::ScalarType::Pred => {
logical_not(builder, result_type, result_id, operand) logical_not(builder, result_type, result_id, operand)
} }
_ => builder.not(result_type, result_id, operand), _ => builder.not(result_type, result_id, operand),
}?; }?;
} }
ast::Instruction::Shl(t, a) => { ast::Instruction::Shl(t, a) => {
let full_type = t.to_type(); let full_type = ast::Type::Scalar(*t);
let size_of = full_type.size_of(); let size_of = full_type.size_of();
let result_type = map.get_or_add(builder, SpirvType::from(full_type)); let result_type = map.get_or_add(builder, SpirvType::from(full_type));
let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?; let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?;
@ -3048,7 +3050,7 @@ fn emit_function_body_ops(
let size_of = full_type.size_of(); let size_of = full_type.size_of();
let result_type = map.get_or_add_scalar(builder, full_type); let result_type = map.get_or_add_scalar(builder, full_type);
let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?; let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?;
if t.signed() { if t.kind() == ast::ScalarKind::Signed {
builder.shift_right_arithmetic( builder.shift_right_arithmetic(
result_type, result_type,
Some(a.dst), Some(a.dst),
@ -3088,7 +3090,7 @@ fn emit_function_body_ops(
}, },
ast::Instruction::Or(t, a) => { ast::Instruction::Or(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
if *t == ast::BooleanType::Pred { if *t == ast::ScalarType::Pred {
builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?; builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?;
} else { } else {
builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?; builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
@ -3116,7 +3118,7 @@ fn emit_function_body_ops(
} }
ast::Instruction::And(t, a) => { ast::Instruction::And(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
if *t == ast::BooleanType::Pred { if *t == ast::ScalarType::Pred {
builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?; builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?;
} else { } else {
builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?; builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?;
@ -3202,7 +3204,7 @@ fn emit_function_body_ops(
} }
ast::Instruction::Neg(details, arg) => { ast::Instruction::Neg(details, arg) => {
let result_type = map.get_or_add_scalar(builder, details.typ); let result_type = map.get_or_add_scalar(builder, details.typ);
let negate_func = if details.typ.kind() == ScalarKind::Float { let negate_func = if details.typ.kind() == ast::ScalarKind::Float {
dr::Builder::f_negate dr::Builder::f_negate
} else { } else {
dr::Builder::s_negate dr::Builder::s_negate
@ -3269,7 +3271,7 @@ fn emit_function_body_ops(
} }
ast::Instruction::Xor { typ, arg } => { ast::Instruction::Xor { typ, arg } => {
let builder_fn = match typ { let builder_fn = match typ {
ast::BooleanType::Pred => emit_logical_xor_spirv, ast::ScalarType::Pred => emit_logical_xor_spirv,
_ => dr::Builder::bitwise_xor, _ => dr::Builder::bitwise_xor,
}; };
let result_type = map.get_or_add_scalar(builder, (*typ).into()); let result_type = map.get_or_add_scalar(builder, (*typ).into());
@ -3284,7 +3286,7 @@ fn emit_function_body_ops(
return Err(error_unreachable()); return Err(error_unreachable());
} }
ast::Instruction::Rem { typ, arg } => { ast::Instruction::Rem { typ, arg } => {
let builder_fn = if typ.is_signed() { let builder_fn = if typ.kind() == ast::ScalarKind::Signed {
dr::Builder::s_mod dr::Builder::s_mod
} else { } else {
dr::Builder::u_mod dr::Builder::u_mod
@ -3882,7 +3884,7 @@ fn emit_cvt(
} }
let dest_t: ast::ScalarType = desc.dst.into(); let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.src.is_signed() { if desc.src.kind() == ast::ScalarKind::Signed {
builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?; builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
} else { } else {
builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?; builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?;
@ -3892,7 +3894,7 @@ fn emit_cvt(
ast::CvtDetails::IntFromFloat(desc) => { ast::CvtDetails::IntFromFloat(desc) => {
let dest_t: ast::ScalarType = desc.dst.into(); let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.dst.is_signed() { if desc.dst.kind() == ast::ScalarKind::Signed {
builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?; builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?;
} else { } else {
builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?; builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?;
@ -3904,7 +3906,7 @@ fn emit_cvt(
let dest_t: ast::ScalarType = desc.dst.into(); let dest_t: ast::ScalarType = desc.dst.into();
let src_t: ast::ScalarType = desc.src.into(); let src_t: ast::ScalarType = desc.src.into();
// first do shortening/widening // first do shortening/widening
let src = if desc.dst.width() != desc.src.width() { let src = if desc.dst.size_of() != desc.src.size_of() {
let new_dst = if dest_t.kind() == src_t.kind() { let new_dst = if dest_t.kind() == src_t.kind() {
arg.dst arg.dst
} else { } else {
@ -3933,7 +3935,7 @@ fn emit_cvt(
// now do actual conversion // now do actual conversion
let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.saturate { if desc.saturate {
if desc.dst.is_signed() { if desc.dst.kind() == ast::ScalarKind::Signed {
builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?; builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?;
} else { } else {
builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?; builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?;
@ -3989,60 +3991,60 @@ fn emit_setp(
let operand_1 = arg.src1; let operand_1 = arg.src1;
let operand_2 = arg.src2; let operand_2 = arg.src2;
match (setp.cmp_op, setp.typ.kind()) { match (setp.cmp_op, setp.typ.kind()) {
(ast::SetpCompareOp::Eq, ScalarKind::Signed) (ast::SetpCompareOp::Eq, ast::ScalarKind::Signed)
| (ast::SetpCompareOp::Eq, ScalarKind::Unsigned) | (ast::SetpCompareOp::Eq, ast::ScalarKind::Unsigned)
| (ast::SetpCompareOp::Eq, ScalarKind::Bit) => { | (ast::SetpCompareOp::Eq, ast::ScalarKind::Bit) => {
builder.i_equal(result_type, result_id, operand_1, operand_2) builder.i_equal(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::Eq, ScalarKind::Float) => { (ast::SetpCompareOp::Eq, ast::ScalarKind::Float) => {
builder.f_ord_equal(result_type, result_id, operand_1, operand_2) builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::NotEq, ScalarKind::Signed) (ast::SetpCompareOp::NotEq, ast::ScalarKind::Signed)
| (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned) | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Unsigned)
| (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => { | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Bit) => {
builder.i_not_equal(result_type, result_id, operand_1, operand_2) builder.i_not_equal(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::NotEq, ScalarKind::Float) => { (ast::SetpCompareOp::NotEq, ast::ScalarKind::Float) => {
builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::Less, ScalarKind::Unsigned) (ast::SetpCompareOp::Less, ast::ScalarKind::Unsigned)
| (ast::SetpCompareOp::Less, ScalarKind::Bit) => { | (ast::SetpCompareOp::Less, ast::ScalarKind::Bit) => {
builder.u_less_than(result_type, result_id, operand_1, operand_2) builder.u_less_than(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::Less, ScalarKind::Signed) => { (ast::SetpCompareOp::Less, ast::ScalarKind::Signed) => {
builder.s_less_than(result_type, result_id, operand_1, operand_2) builder.s_less_than(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::Less, ScalarKind::Float) => { (ast::SetpCompareOp::Less, ast::ScalarKind::Float) => {
builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned) (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Unsigned)
| (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => { | (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Bit) => {
builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => { (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Signed) => {
builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => { (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Float) => {
builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::Greater, ScalarKind::Unsigned) (ast::SetpCompareOp::Greater, ast::ScalarKind::Unsigned)
| (ast::SetpCompareOp::Greater, ScalarKind::Bit) => { | (ast::SetpCompareOp::Greater, ast::ScalarKind::Bit) => {
builder.u_greater_than(result_type, result_id, operand_1, operand_2) builder.u_greater_than(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::Greater, ScalarKind::Signed) => { (ast::SetpCompareOp::Greater, ast::ScalarKind::Signed) => {
builder.s_greater_than(result_type, result_id, operand_1, operand_2) builder.s_greater_than(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::Greater, ScalarKind::Float) => { (ast::SetpCompareOp::Greater, ast::ScalarKind::Float) => {
builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned) (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Unsigned)
| (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => { | (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Bit) => {
builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => { (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Signed) => {
builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => { (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Float) => {
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
} }
(ast::SetpCompareOp::NanEq, _) => { (ast::SetpCompareOp::NanEq, _) => {
@ -4222,7 +4224,7 @@ fn emit_abs(
) -> Result<(), dr::Error> { ) -> Result<(), dr::Error> {
let scalar_t = ast::ScalarType::from(d.typ); let scalar_t = ast::ScalarType::from(d.typ);
let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); let result_type = map.get_or_add(builder, SpirvType::from(scalar_t));
let cl_abs = if scalar_t.kind() == ScalarKind::Signed { let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed {
spirv::CLOp::s_abs spirv::CLOp::s_abs
} else { } else {
spirv::CLOp::fabs spirv::CLOp::fabs
@ -4286,8 +4288,8 @@ fn emit_implicit_conversion(
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
if from_parts.width == to_parts.width { if from_parts.width == to_parts.width {
let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
if from_parts.scalar_kind != ScalarKind::Float if from_parts.scalar_kind != ast::ScalarKind::Float
&& to_parts.scalar_kind != ScalarKind::Float && to_parts.scalar_kind != ast::ScalarKind::Float
{ {
// It is noop, but another instruction expects result of this conversion // It is noop, but another instruction expects result of this conversion
builder.copy_object(dst_type, Some(cv.dst), cv.src)?; builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
@ -4299,24 +4301,24 @@ fn emit_implicit_conversion(
let same_width_bit_type = map.get_or_add( let same_width_bit_type = map.get_or_add(
builder, builder,
SpirvType::from(ast::Type::from_parts(TypeParts { SpirvType::from(ast::Type::from_parts(TypeParts {
scalar_kind: ScalarKind::Bit, scalar_kind: ast::ScalarKind::Bit,
..from_parts ..from_parts
})), })),
); );
let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?; let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
let wide_bit_type = ast::Type::from_parts(TypeParts { let wide_bit_type = ast::Type::from_parts(TypeParts {
scalar_kind: ScalarKind::Bit, scalar_kind: ast::ScalarKind::Bit,
..to_parts ..to_parts
}); });
let wide_bit_type_spirv = let wide_bit_type_spirv =
map.get_or_add(builder, SpirvType::from(wide_bit_type.clone())); map.get_or_add(builder, SpirvType::from(wide_bit_type.clone()));
if to_parts.scalar_kind == ScalarKind::Unsigned if to_parts.scalar_kind == ast::ScalarKind::Unsigned
|| to_parts.scalar_kind == ScalarKind::Bit || to_parts.scalar_kind == ast::ScalarKind::Bit
{ {
builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?; builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?;
} else { } else {
let conversion_fn = if from_parts.scalar_kind == ScalarKind::Signed let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed
&& to_parts.scalar_kind == ScalarKind::Signed && to_parts.scalar_kind == ast::ScalarKind::Signed
{ {
dr::Builder::s_convert dr::Builder::s_convert
} else { } else {
@ -4614,23 +4616,23 @@ fn convert_to_stateful_memory_access<'a>(
for statement in func_body.iter() { for statement in func_body.iter() {
match statement { match statement {
Statement::Instruction(ast::Instruction::Add( Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Unsigned(ast::UIntType::U64), ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg, arg,
)) ))
| Statement::Instruction(ast::Instruction::Add( | Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Signed(ast::ArithSInt { ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::S64, typ: ast::ScalarType::S64,
saturate: false, saturate: false,
}), }),
arg, arg,
)) ))
| Statement::Instruction(ast::Instruction::Sub( | Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Unsigned(ast::UIntType::U64), ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg, arg,
)) ))
| Statement::Instruction(ast::Instruction::Sub( | Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Signed(ast::ArithSInt { ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::S64, typ: ast::ScalarType::S64,
saturate: false, saturate: false,
}), }),
arg, arg,
@ -4686,12 +4688,12 @@ fn convert_to_stateful_memory_access<'a>(
} }
} }
Statement::Instruction(ast::Instruction::Add( Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Unsigned(ast::UIntType::U64), ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg, arg,
)) ))
| Statement::Instruction(ast::Instruction::Add( | Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Signed(ast::ArithSInt { ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::S64, typ: ast::ScalarType::S64,
saturate: false, saturate: false,
}), }),
arg, arg,
@ -4715,12 +4717,12 @@ fn convert_to_stateful_memory_access<'a>(
})) }))
} }
Statement::Instruction(ast::Instruction::Sub( Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Unsigned(ast::UIntType::U64), ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg, arg,
)) ))
| Statement::Instruction(ast::Instruction::Sub( | Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Signed(ast::ArithSInt { ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::S64, typ: ast::ScalarType::S64,
saturate: false, saturate: false,
}), }),
arg, arg,
@ -4867,7 +4869,7 @@ fn convert_to_stateful_memory_access_postprocess(
ast::LdStateSpace::Global, ast::LdStateSpace::Global,
), ),
to: old_type, to: old_type,
kind: ConversionKind::PtrToBit(ast::UIntType::U64), kind: ConversionKind::PtrToBit(ast::ScalarType::U64),
src_sema: arg_desc.sema, src_sema: arg_desc.sema,
dst_sema: ArgumentSemantics::Default, dst_sema: ArgumentSemantics::Default,
})); }));
@ -5903,7 +5905,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let inst_type = d.typ; let inst_type = d.typ;
ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
} }
ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?), ast::Instruction::Not(t, a) => {
ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?)
}
ast::Instruction::Cvt(d, a) => { ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d { let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => ( ast::CvtDetails::FloatFromFloat(desc) => (
@ -5926,7 +5930,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?) ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?)
} }
ast::Instruction::Shl(t, a) => { ast::Instruction::Shl(t, a) => {
ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?) ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?)
} }
ast::Instruction::Shr(t, a) => { ast::Instruction::Shr(t, a) => {
ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?) ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?)
@ -6176,9 +6180,9 @@ impl ast::Type {
ast::Type::Scalar(scalar) => { ast::Type::Scalar(scalar) => {
let kind = scalar.kind(); let kind = scalar.kind();
let width = scalar.size_of(); let width = scalar.size_of();
if (kind != ScalarKind::Signed if (kind != ast::ScalarKind::Signed
&& kind != ScalarKind::Unsigned && kind != ast::ScalarKind::Unsigned
&& kind != ScalarKind::Bit) && kind != ast::ScalarKind::Bit)
|| (width == 8) || (width == 8)
{ {
return Err(TranslateError::MismatchedType); return Err(TranslateError::MismatchedType);
@ -6306,7 +6310,7 @@ impl ast::Type {
#[derive(Eq, PartialEq, Clone)] #[derive(Eq, PartialEq, Clone)]
struct TypeParts { struct TypeParts {
kind: TypeKind, kind: TypeKind,
scalar_kind: ScalarKind, scalar_kind: ast::ScalarKind,
width: u8, width: u8,
components: Vec<u32>, components: Vec<u32>,
state_space: ast::LdStateSpace, state_space: ast::LdStateSpace,
@ -6461,7 +6465,7 @@ enum ConversionKind {
// zero-extend/chop/bitcast depending on types // zero-extend/chop/bitcast depending on types
SignExtend, SignExtend,
BitToPtr(ast::LdStateSpace), BitToPtr(ast::LdStateSpace),
PtrToBit(ast::UIntType), PtrToBit(ast::ScalarType),
PtrToPtr { spirv_ptr: bool }, PtrToPtr { spirv_ptr: bool },
} }
@ -6859,7 +6863,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
fn map_selp<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>( fn map_selp<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self, self,
visitor: &mut V, visitor: &mut V,
t: ast::SelpType, t: ast::ScalarType,
) -> Result<ast::Arg4<U>, TranslateError> { ) -> Result<ast::Arg4<U>, TranslateError> {
let dst = visitor.operand( let dst = visitor.operand(
ArgumentDescriptor { ArgumentDescriptor {
@ -6904,7 +6908,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>( fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self, self,
visitor: &mut V, visitor: &mut V,
t: ast::BitType, t: ast::ScalarType,
state_space: ast::AtomSpace, state_space: ast::AtomSpace,
) -> Result<ast::Arg4<U>, TranslateError> { ) -> Result<ast::Arg4<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t); let scalar_type = ast::ScalarType::from(t);
@ -7205,103 +7209,41 @@ impl ast::StStateSpace {
} }
} }
#[derive(Clone, Copy, PartialEq, Eq)]
enum ScalarKind {
Bit,
Unsigned,
Signed,
Float,
Float2,
Pred,
}
impl ast::ScalarType { impl ast::ScalarType {
fn kind(self) -> ScalarKind { fn from_parts(width: u8, kind: ast::ScalarKind) -> Self {
match self {
ast::ScalarType::U8 => ScalarKind::Unsigned,
ast::ScalarType::U16 => ScalarKind::Unsigned,
ast::ScalarType::U32 => ScalarKind::Unsigned,
ast::ScalarType::U64 => ScalarKind::Unsigned,
ast::ScalarType::S8 => ScalarKind::Signed,
ast::ScalarType::S16 => ScalarKind::Signed,
ast::ScalarType::S32 => ScalarKind::Signed,
ast::ScalarType::S64 => ScalarKind::Signed,
ast::ScalarType::B8 => ScalarKind::Bit,
ast::ScalarType::B16 => ScalarKind::Bit,
ast::ScalarType::B32 => ScalarKind::Bit,
ast::ScalarType::B64 => ScalarKind::Bit,
ast::ScalarType::F16 => ScalarKind::Float,
ast::ScalarType::F32 => ScalarKind::Float,
ast::ScalarType::F64 => ScalarKind::Float,
ast::ScalarType::F16x2 => ScalarKind::Float2,
ast::ScalarType::Pred => ScalarKind::Pred,
}
}
fn from_parts(width: u8, kind: ScalarKind) -> Self {
match kind { match kind {
ScalarKind::Float => match width { ast::ScalarKind::Float => match width {
2 => ast::ScalarType::F16, 2 => ast::ScalarType::F16,
4 => ast::ScalarType::F32, 4 => ast::ScalarType::F32,
8 => ast::ScalarType::F64, 8 => ast::ScalarType::F64,
_ => unreachable!(), _ => unreachable!(),
}, },
ScalarKind::Bit => match width { ast::ScalarKind::Bit => match width {
1 => ast::ScalarType::B8, 1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16, 2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32, 4 => ast::ScalarType::B32,
8 => ast::ScalarType::B64, 8 => ast::ScalarType::B64,
_ => unreachable!(), _ => unreachable!(),
}, },
ScalarKind::Signed => match width { ast::ScalarKind::Signed => match width {
1 => ast::ScalarType::S8, 1 => ast::ScalarType::S8,
2 => ast::ScalarType::S16, 2 => ast::ScalarType::S16,
4 => ast::ScalarType::S32, 4 => ast::ScalarType::S32,
8 => ast::ScalarType::S64, 8 => ast::ScalarType::S64,
_ => unreachable!(), _ => unreachable!(),
}, },
ScalarKind::Unsigned => match width { ast::ScalarKind::Unsigned => match width {
1 => ast::ScalarType::U8, 1 => ast::ScalarType::U8,
2 => ast::ScalarType::U16, 2 => ast::ScalarType::U16,
4 => ast::ScalarType::U32, 4 => ast::ScalarType::U32,
8 => ast::ScalarType::U64, 8 => ast::ScalarType::U64,
_ => unreachable!(), _ => unreachable!(),
}, },
ScalarKind::Float2 => match width { ast::ScalarKind::Float2 => match width {
4 => ast::ScalarType::F16x2, 4 => ast::ScalarType::F16x2,
_ => unreachable!(), _ => unreachable!(),
}, },
ScalarKind::Pred => ast::ScalarType::Pred, ast::ScalarKind::Pred => ast::ScalarType::Pred,
}
}
}
impl ast::BooleanType {
fn to_type(self) -> ast::Type {
match self {
ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
}
}
}
impl ast::ShlType {
fn to_type(self) -> ast::Type {
match self {
ast::ShlType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
ast::ShlType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
ast::ShlType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
}
}
}
impl ast::ShrType {
fn signed(&self) -> bool {
match self {
ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true,
_ => false,
} }
} }
} }
@ -7357,30 +7299,6 @@ impl ast::AtomInnerDetails {
} }
} }
impl ast::SIntType {
fn from_size(width: u8) -> Self {
match width {
1 => ast::SIntType::S8,
2 => ast::SIntType::S16,
4 => ast::SIntType::S32,
8 => ast::SIntType::S64,
_ => unreachable!(),
}
}
}
impl ast::UIntType {
fn from_size(width: u8) -> Self {
match width {
1 => ast::UIntType::U8,
2 => ast::UIntType::U16,
4 => ast::UIntType::U32,
8 => ast::UIntType::U64,
_ => unreachable!(),
}
}
}
impl ast::LdStateSpace { impl ast::LdStateSpace {
fn to_spirv(self) -> spirv::StorageClass { fn to_spirv(self) -> spirv::StorageClass {
match self { match self {
@ -7568,16 +7486,18 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
return false; return false;
} }
match inst.kind() { match inst.kind() {
ScalarKind::Bit => operand.kind() != ScalarKind::Bit, ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
ScalarKind::Float => operand.kind() == ScalarKind::Bit, ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
ScalarKind::Signed => { ast::ScalarKind::Signed => {
operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Unsigned
} }
ScalarKind::Unsigned => { ast::ScalarKind::Unsigned => {
operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Signed
} }
ScalarKind::Float2 => false, ast::ScalarKind::Float2 => false,
ScalarKind::Pred => false, ast::ScalarKind::Pred => false,
} }
} }
(ast::Type::Vector(inst, _), ast::Type::Vector(operand, _)) (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
@ -7596,7 +7516,7 @@ fn should_bitcast_packed(
if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
(operand, instr) (operand, instr)
{ {
if scalar.kind() == ScalarKind::Bit if scalar.kind() == ast::ScalarKind::Bit
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len) && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
{ {
return Ok(Some(ConversionKind::Default)); return Ok(Some(ConversionKind::Default));
@ -7644,32 +7564,33 @@ fn should_convert_relaxed_src(
} }
match (src_type, instr_type) { match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ScalarKind::Bit => { ast::ScalarKind::Bit => {
if instr_type.size_of() <= src_type.size_of() { if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else { } else {
None None
} }
} }
ScalarKind::Signed | ScalarKind::Unsigned => { ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= src_type.size_of() if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() != ScalarKind::Float && src_type.kind() != ast::ScalarKind::Float
{ {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else { } else {
None None
} }
} }
ScalarKind::Float => { ast::ScalarKind::Float => {
if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() == ast::ScalarKind::Bit
{ {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else { } else {
None None
} }
} }
ScalarKind::Float2 => todo!(), ast::ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None, ast::ScalarKind::Pred => None,
}, },
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
@ -7706,15 +7627,15 @@ fn should_convert_relaxed_dst(
} }
match (dst_type, instr_type) { match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ScalarKind::Bit => { ast::ScalarKind::Bit => {
if instr_type.size_of() <= dst_type.size_of() { if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else { } else {
None None
} }
} }
ScalarKind::Signed => { ast::ScalarKind::Signed => {
if dst_type.kind() != ScalarKind::Float { if dst_type.kind() != ast::ScalarKind::Float {
if instr_type.size_of() == dst_type.size_of() { if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else if instr_type.size_of() < dst_type.size_of() { } else if instr_type.size_of() < dst_type.size_of() {
@ -7726,25 +7647,26 @@ fn should_convert_relaxed_dst(
None None
} }
} }
ScalarKind::Unsigned => { ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= dst_type.size_of() if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() != ScalarKind::Float && dst_type.kind() != ast::ScalarKind::Float
{ {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else { } else {
None None
} }
} }
ScalarKind::Float => { ast::ScalarKind::Float => {
if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() == ast::ScalarKind::Bit
{ {
Some(ConversionKind::Default) Some(ConversionKind::Default)
} else { } else {
None None
} }
} }
ScalarKind::Float2 => todo!(), ast::ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None, ast::ScalarKind::Pred => None,
}, },
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {