Convert enumes to 1TT

This commit is contained in:
Andrzej Janik 2021-04-15 19:10:45 +02:00
parent a005c92c61
commit a0baad9456
4 changed files with 332 additions and 462 deletions

View file

@ -210,20 +210,6 @@ sub_enum!(LdStScalarType {
F64,
});
sub_enum!(SelpType {
B16,
B32,
B64,
U16,
U32,
U64,
S16,
S32,
S64,
F32,
F64,
});
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum BarDetails {
SyncAligned,
@ -425,52 +411,6 @@ pub enum ScalarType {
Pred,
}
sub_enum!(IntType {
U8,
U16,
U32,
U64,
S8,
S16,
S32,
S64
});
sub_enum!(BitType { B8, B16, B32, B64 });
sub_enum!(UIntType { U8, U16, U32, U64 });
sub_enum!(SIntType { S8, S16, S32, S64 });
impl IntType {
pub fn is_signed(self) -> bool {
match self {
IntType::U8 | IntType::U16 | IntType::U32 | IntType::U64 => false,
IntType::S8 | IntType::S16 | IntType::S32 | IntType::S64 => true,
}
}
pub fn width(self) -> u8 {
match self {
IntType::U8 => 1,
IntType::U16 => 2,
IntType::U32 => 4,
IntType::U64 => 8,
IntType::S8 => 1,
IntType::S16 => 2,
IntType::S32 => 4,
IntType::S64 => 8,
}
}
}
sub_enum!(FloatType {
F16,
F16x2,
F32,
F64
});
impl ScalarType {
pub fn size_of(self) -> u8 {
match self {
@ -576,24 +516,24 @@ pub enum Instruction<P: ArgParams> {
Add(ArithDetails, Arg3<P>),
Setp(SetpData, Arg4Setp<P>),
SetpBool(SetpBoolData, Arg5Setp<P>),
Not(BooleanType, Arg2<P>),
Not(ScalarType, Arg2<P>),
Bra(BraData, Arg1<P>),
Cvt(CvtDetails, Arg2<P>),
Cvta(CvtaDetails, Arg2<P>),
Shl(ShlType, Arg3<P>),
Shr(ShrType, Arg3<P>),
Shl(ScalarType, Arg3<P>),
Shr(ScalarType, Arg3<P>),
St(StData, Arg2St<P>),
Ret(RetData),
Call(CallInst<P>),
Abs(AbsDetails, Arg2<P>),
Mad(MulDetails, Arg4<P>),
Or(BooleanType, Arg3<P>),
Or(ScalarType, Arg3<P>),
Sub(ArithDetails, Arg3<P>),
Min(MinMaxDetails, Arg3<P>),
Max(MinMaxDetails, Arg3<P>),
Rcp(RcpDetails, Arg2<P>),
And(BooleanType, Arg3<P>),
Selp(SelpType, Arg4<P>),
And(ScalarType, Arg3<P>),
Selp(ScalarType, Arg4<P>),
Bar(BarDetails, Arg1Bar<P>),
Atom(AtomDetails, Arg3<P>),
AtomCas(AtomCasDetails, Arg4<P>),
@ -605,13 +545,13 @@ pub enum Instruction<P: ArgParams> {
Cos { flush_to_zero: bool, arg: Arg2<P> },
Lg2 { flush_to_zero: bool, arg: Arg2<P> },
Ex2 { flush_to_zero: bool, arg: Arg2<P> },
Clz { typ: BitType, arg: Arg2<P> },
Brev { typ: BitType, arg: Arg2<P> },
Popc { typ: BitType, arg: Arg2<P> },
Xor { typ: BooleanType, arg: Arg3<P> },
Bfe { typ: IntType, arg: Arg4<P> },
Bfi { typ: BitType, arg: Arg5<P> },
Rem { typ: IntType, arg: Arg3<P> },
Clz { typ: ScalarType, arg: Arg2<P> },
Brev { typ: ScalarType, arg: Arg2<P> },
Popc { typ: ScalarType, arg: Arg2<P> },
Xor { typ: ScalarType, arg: Arg3<P> },
Bfe { typ: ScalarType, arg: Arg4<P> },
Bfi { typ: ScalarType, arg: Arg5<P> },
Rem { typ: ScalarType, arg: Arg3<P> },
}
#[derive(Copy, Clone)]
@ -825,7 +765,7 @@ impl MovDetails {
#[derive(Copy, Clone)]
pub struct MulIntDesc {
pub typ: IntType,
pub typ: ScalarType,
pub control: MulIntControl,
}
@ -845,7 +785,7 @@ pub enum RoundingMode {
}
pub struct AddIntDesc {
pub typ: IntType,
pub typ: ScalarType,
pub saturate: bool,
}
@ -892,39 +832,39 @@ pub struct BraData {
pub enum CvtDetails {
IntFromInt(CvtIntToIntDesc),
FloatFromFloat(CvtDesc<FloatType, FloatType>),
IntFromFloat(CvtDesc<IntType, FloatType>),
FloatFromInt(CvtDesc<FloatType, IntType>),
FloatFromFloat(CvtDesc),
IntFromFloat(CvtDesc),
FloatFromInt(CvtDesc),
}
pub struct CvtIntToIntDesc {
pub dst: IntType,
pub src: IntType,
pub dst: ScalarType,
pub src: ScalarType,
pub saturate: bool,
}
pub struct CvtDesc<Dst, Src> {
pub struct CvtDesc {
pub rounding: Option<RoundingMode>,
pub flush_to_zero: Option<bool>,
pub saturate: bool,
pub dst: Dst,
pub src: Src,
pub dst: ScalarType,
pub src: ScalarType,
}
impl CvtDetails {
pub fn new_int_from_int_checked<'err, 'input>(
saturate: bool,
dst: IntType,
src: IntType,
dst: ScalarType,
src: ScalarType,
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self {
if saturate {
if src.is_signed() {
if dst.is_signed() && dst.width() >= src.width() {
if src.kind() == ScalarKind::Signed {
if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() {
err.push(ParseError::from(PtxError::SyntaxError));
}
} else {
if dst == src || dst.width() >= src.width() {
if dst == src || dst.size_of() >= src.size_of() {
err.push(ParseError::from(PtxError::SyntaxError));
}
}
@ -936,11 +876,11 @@ impl CvtDetails {
rounding: RoundingMode,
flush_to_zero: bool,
saturate: bool,
dst: FloatType,
src: IntType,
dst: ScalarType,
src: ScalarType,
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self {
if flush_to_zero && dst != FloatType::F32 {
if flush_to_zero && dst != ScalarType::F32 {
err.push(ParseError::from(PtxError::NonF32Ftz));
}
CvtDetails::FloatFromInt(CvtDesc {
@ -956,11 +896,11 @@ impl CvtDetails {
rounding: RoundingMode,
flush_to_zero: bool,
saturate: bool,
dst: IntType,
src: FloatType,
dst: ScalarType,
src: ScalarType,
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self {
if flush_to_zero && src != FloatType::F32 {
if flush_to_zero && src != ScalarType::F32 {
err.push(ParseError::from(PtxError::NonF32Ftz));
}
CvtDetails::IntFromFloat(CvtDesc {
@ -993,25 +933,6 @@ pub enum CvtaSize {
U64,
}
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum ShlType {
B16,
B32,
B64,
}
sub_enum!(ShrType {
B16,
B32,
B64,
U16,
U32,
U64,
S16,
S32,
S64,
});
pub struct StData {
pub qualifier: LdStQualifier,
pub state_space: StStateSpace,
@ -1040,13 +961,6 @@ pub struct RetData {
pub uniform: bool,
}
sub_enum!(BooleanType {
Pred,
B16,
B32,
B64,
});
#[derive(Copy, Clone)]
pub enum MulDetails {
Unsigned(MulUInt),
@ -1056,32 +970,32 @@ pub enum MulDetails {
#[derive(Copy, Clone)]
pub struct MulUInt {
pub typ: UIntType,
pub typ: ScalarType,
pub control: MulIntControl,
}
#[derive(Copy, Clone)]
pub struct MulSInt {
pub typ: SIntType,
pub typ: ScalarType,
pub control: MulIntControl,
}
#[derive(Copy, Clone)]
pub enum ArithDetails {
Unsigned(UIntType),
Unsigned(ScalarType),
Signed(ArithSInt),
Float(ArithFloat),
}
#[derive(Copy, Clone)]
pub struct ArithSInt {
pub typ: SIntType,
pub typ: ScalarType,
pub saturate: bool,
}
#[derive(Copy, Clone)]
pub struct ArithFloat {
pub typ: FloatType,
pub typ: ScalarType,
pub rounding: Option<RoundingMode>,
pub flush_to_zero: Option<bool>,
pub saturate: bool,
@ -1089,8 +1003,8 @@ pub struct ArithFloat {
#[derive(Copy, Clone)]
pub enum MinMaxDetails {
Signed(SIntType),
Unsigned(UIntType),
Signed(ScalarType),
Unsigned(ScalarType),
Float(MinMaxFloat),
}
@ -1098,7 +1012,7 @@ pub enum MinMaxDetails {
pub struct MinMaxFloat {
pub flush_to_zero: Option<bool>,
pub nan: bool,
pub typ: FloatType,
pub typ: ScalarType,
}
#[derive(Copy, Clone)]
@ -1126,10 +1040,10 @@ pub enum AtomSpace {
#[derive(Copy, Clone)]
pub enum AtomInnerDetails {
Bit { op: AtomBitOp, typ: BitType },
Unsigned { op: AtomUIntOp, typ: UIntType },
Signed { op: AtomSIntOp, typ: SIntType },
Float { op: AtomFloatOp, typ: FloatType },
Bit { op: AtomBitOp, typ: ScalarType },
Unsigned { op: AtomUIntOp, typ: ScalarType },
Signed { op: AtomSIntOp, typ: ScalarType },
Float { op: AtomFloatOp, typ: ScalarType },
}
#[derive(Copy, Clone, Eq, PartialEq)]
@ -1166,19 +1080,19 @@ pub struct AtomCasDetails {
pub semantics: AtomSemantics,
pub scope: MemScope,
pub space: AtomSpace,
pub typ: BitType,
pub typ: ScalarType,
}
#[derive(Copy, Clone)]
pub enum DivDetails {
Unsigned(UIntType),
Signed(SIntType),
Unsigned(ScalarType),
Signed(ScalarType),
Float(DivFloatDetails),
}
#[derive(Copy, Clone)]
pub struct DivFloatDetails {
pub typ: FloatType,
pub typ: ScalarType,
pub flush_to_zero: Option<bool>,
pub kind: DivFloatKind,
}
@ -1197,7 +1111,7 @@ pub enum NumsOrArrays<'a> {
#[derive(Copy, Clone)]
pub struct SqrtDetails {
pub typ: FloatType,
pub typ: ScalarType,
pub flush_to_zero: Option<bool>,
pub kind: SqrtKind,
}
@ -1210,7 +1124,7 @@ pub enum SqrtKind {
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct RsqrtDetails {
pub typ: FloatType,
pub typ: ScalarType,
pub flush_to_zero: bool,
}
@ -1379,6 +1293,40 @@ pub enum TuningDirective {
MinNCtaPerSm(u32),
}
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum ScalarKind {
Bit,
Unsigned,
Signed,
Float,
Float2,
Pred,
}
impl ScalarType {
pub fn kind(self) -> ScalarKind {
match self {
ScalarType::U8 => ScalarKind::Unsigned,
ScalarType::U16 => ScalarKind::Unsigned,
ScalarType::U32 => ScalarKind::Unsigned,
ScalarType::U64 => ScalarKind::Unsigned,
ScalarType::S8 => ScalarKind::Signed,
ScalarType::S16 => ScalarKind::Signed,
ScalarType::S32 => ScalarKind::Signed,
ScalarType::S64 => ScalarKind::Signed,
ScalarType::B8 => ScalarKind::Bit,
ScalarType::B16 => ScalarKind::Bit,
ScalarType::B32 => ScalarKind::Bit,
ScalarType::B64 => ScalarKind::Bit,
ScalarType::F16 => ScalarKind::Float,
ScalarType::F32 => ScalarKind::Float,
ScalarType::F64 => ScalarKind::Float,
ScalarType::F16x2 => ScalarKind::Float2,
ScalarType::Pred => ScalarKind::Pred,
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -899,39 +899,39 @@ RoundingModeInt : ast::RoundingMode = {
".rpi" => ast::RoundingMode::PositiveInf,
};
IntType : ast::IntType = {
".u16" => ast::IntType::U16,
".u32" => ast::IntType::U32,
".u64" => ast::IntType::U64,
".s16" => ast::IntType::S16,
".s32" => ast::IntType::S32,
".s64" => ast::IntType::S64,
IntType : ast::ScalarType = {
".u16" => ast::ScalarType::U16,
".u32" => ast::ScalarType::U32,
".u64" => ast::ScalarType::U64,
".s16" => ast::ScalarType::S16,
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
};
IntType3264: ast::IntType = {
".u32" => ast::IntType::U32,
".u64" => ast::IntType::U64,
".s32" => ast::IntType::S32,
".s64" => ast::IntType::S64,
IntType3264: ast::ScalarType = {
".u32" => ast::ScalarType::U32,
".u64" => ast::ScalarType::U64,
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
}
UIntType: ast::UIntType = {
".u16" => ast::UIntType::U16,
".u32" => ast::UIntType::U32,
".u64" => ast::UIntType::U64,
UIntType: ast::ScalarType = {
".u16" => ast::ScalarType::U16,
".u32" => ast::ScalarType::U32,
".u64" => ast::ScalarType::U64,
};
SIntType: ast::SIntType = {
".s16" => ast::SIntType::S16,
".s32" => ast::SIntType::S32,
".s64" => ast::SIntType::S64,
SIntType: ast::ScalarType = {
".s16" => ast::ScalarType::S16,
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
};
FloatType: ast::FloatType = {
".f16" => ast::FloatType::F16,
".f16x2" => ast::FloatType::F16x2,
".f32" => ast::FloatType::F32,
".f64" => ast::FloatType::F64,
FloatType: ast::ScalarType = {
".f16" => ast::ScalarType::F16,
".f16x2" => ast::ScalarType::F16x2,
".f32" => ast::ScalarType::F32,
".f64" => ast::ScalarType::F64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
@ -1023,11 +1023,11 @@ InstNot: ast::Instruction<ast::ParsedArgParams<'input>> = {
"not" <t:BooleanType> <a:Arg2> => ast::Instruction::Not(t, a)
};
BooleanType: ast::BooleanType = {
".pred" => ast::BooleanType::Pred,
".b16" => ast::BooleanType::B16,
".b32" => ast::BooleanType::B32,
".b64" => ast::BooleanType::B64,
BooleanType: ast::ScalarType = {
".pred" => ast::ScalarType::Pred,
".b16" => ast::ScalarType::B16,
".b32" => ast::ScalarType::B32,
".b64" => ast::ScalarType::B64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at
@ -1080,8 +1080,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: r,
flush_to_zero: None,
saturate: s.is_some(),
dst: ast::FloatType::F16,
src: ast::FloatType::F16
dst: ast::ScalarType::F16,
src: ast::ScalarType::F16
}
), a)
},
@ -1091,8 +1091,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: None,
flush_to_zero: Some(f.is_some()),
saturate: s.is_some(),
dst: ast::FloatType::F32,
src: ast::FloatType::F16
dst: ast::ScalarType::F32,
src: ast::ScalarType::F16
}
), a)
},
@ -1102,8 +1102,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: None,
flush_to_zero: None,
saturate: s.is_some(),
dst: ast::FloatType::F64,
src: ast::FloatType::F16
dst: ast::ScalarType::F64,
src: ast::ScalarType::F16
}
), a)
},
@ -1113,8 +1113,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: Some(r),
flush_to_zero: Some(f.is_some()),
saturate: s.is_some(),
dst: ast::FloatType::F16,
src: ast::FloatType::F32
dst: ast::ScalarType::F16,
src: ast::ScalarType::F32
}
), a)
},
@ -1124,8 +1124,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: r,
flush_to_zero: Some(f.is_some()),
saturate: s.is_some(),
dst: ast::FloatType::F32,
src: ast::FloatType::F32
dst: ast::ScalarType::F32,
src: ast::ScalarType::F32
}
), a)
},
@ -1135,8 +1135,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: None,
flush_to_zero: None,
saturate: s.is_some(),
dst: ast::FloatType::F64,
src: ast::FloatType::F32
dst: ast::ScalarType::F64,
src: ast::ScalarType::F32
}
), a)
},
@ -1146,8 +1146,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: Some(r),
flush_to_zero: None,
saturate: s.is_some(),
dst: ast::FloatType::F16,
src: ast::FloatType::F64
dst: ast::ScalarType::F16,
src: ast::ScalarType::F64
}
), a)
},
@ -1157,8 +1157,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: Some(r),
flush_to_zero: Some(s.is_some()),
saturate: s.is_some(),
dst: ast::FloatType::F32,
src: ast::FloatType::F64
dst: ast::ScalarType::F32,
src: ast::ScalarType::F64
}
), a)
},
@ -1168,28 +1168,28 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: r,
flush_to_zero: None,
saturate: s.is_some(),
dst: ast::FloatType::F64,
src: ast::FloatType::F64
dst: ast::ScalarType::F64,
src: ast::ScalarType::F64
}
), a)
},
};
CvtTypeInt: ast::IntType = {
".u8" => ast::IntType::U8,
".u16" => ast::IntType::U16,
".u32" => ast::IntType::U32,
".u64" => ast::IntType::U64,
".s8" => ast::IntType::S8,
".s16" => ast::IntType::S16,
".s32" => ast::IntType::S32,
".s64" => ast::IntType::S64,
CvtTypeInt: ast::ScalarType = {
".u8" => ast::ScalarType::U8,
".u16" => ast::ScalarType::U16,
".u32" => ast::ScalarType::U32,
".u64" => ast::ScalarType::U64,
".s8" => ast::ScalarType::S8,
".s16" => ast::ScalarType::S16,
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
};
CvtTypeFloat: ast::FloatType = {
".f16" => ast::FloatType::F16,
".f32" => ast::FloatType::F32,
".f64" => ast::FloatType::F64,
CvtTypeFloat: ast::ScalarType = {
".f16" => ast::ScalarType::F16,
".f32" => ast::ScalarType::F32,
".f64" => ast::ScalarType::F64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
@ -1197,10 +1197,10 @@ InstShl: ast::Instruction<ast::ParsedArgParams<'input>> = {
"shl" <t:ShlType> <a:Arg3> => ast::Instruction::Shl(t, a)
};
ShlType: ast::ShlType = {
".b16" => ast::ShlType::B16,
".b32" => ast::ShlType::B32,
".b64" => ast::ShlType::B64,
ShlType: ast::ScalarType = {
".b16" => ast::ScalarType::B16,
".b32" => ast::ScalarType::B32,
".b64" => ast::ScalarType::B64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr
@ -1208,16 +1208,16 @@ InstShr: ast::Instruction<ast::ParsedArgParams<'input>> = {
"shr" <t:ShrType> <a:Arg3> => ast::Instruction::Shr(t, a)
};
ShrType: ast::ShrType = {
".b16" => ast::ShrType::B16,
".b32" => ast::ShrType::B32,
".b64" => ast::ShrType::B64,
".u16" => ast::ShrType::U16,
".u32" => ast::ShrType::U32,
".u64" => ast::ShrType::U64,
".s16" => ast::ShrType::S16,
".s32" => ast::ShrType::S32,
".s64" => ast::ShrType::S64,
ShrType: ast::ScalarType = {
".b16" => ast::ScalarType::B16,
".b32" => ast::ScalarType::B32,
".b64" => ast::ScalarType::B64,
".u16" => ast::ScalarType::U16,
".u32" => ast::ScalarType::U32,
".u64" => ast::ScalarType::U64,
".s16" => ast::ScalarType::S16,
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
@ -1393,16 +1393,16 @@ MinMaxDetails: ast::MinMaxDetails = {
<t:UIntType> => ast::MinMaxDetails::Unsigned(t),
<t:SIntType> => ast::MinMaxDetails::Signed(t),
<ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F32 }
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F32 }
),
".f64" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::FloatType::F64 }
ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::ScalarType::F64 }
),
<ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16 }
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16 }
),
<ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float(
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16x2 }
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16x2 }
)
}
@ -1411,18 +1411,18 @@ InstSelp: ast::Instruction<ast::ParsedArgParams<'input>> = {
"selp" <t:SelpType> <a:Arg4> => ast::Instruction::Selp(t, a),
};
SelpType: ast::SelpType = {
".b16" => ast::SelpType::B16,
".b32" => ast::SelpType::B32,
".b64" => ast::SelpType::B64,
".u16" => ast::SelpType::U16,
".u32" => ast::SelpType::U32,
".u64" => ast::SelpType::U64,
".s16" => ast::SelpType::S16,
".s32" => ast::SelpType::S32,
".s64" => ast::SelpType::S64,
".f32" => ast::SelpType::F32,
".f64" => ast::SelpType::F64,
SelpType: ast::ScalarType = {
".b16" => ast::ScalarType::B16,
".b32" => ast::ScalarType::B32,
".b64" => ast::ScalarType::B64,
".u16" => ast::ScalarType::U16,
".u32" => ast::ScalarType::U32,
".u64" => ast::ScalarType::U64,
".s16" => ast::ScalarType::S16,
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
".f32" => ast::ScalarType::F32,
".f64" => ast::ScalarType::F64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar
@ -1454,7 +1454,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
space: space.unwrap_or(ast::AtomSpace::Generic),
inner: ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Inc,
typ: ast::UIntType::U32
typ: ast::ScalarType::U32
}
};
ast::Instruction::Atom(details,a)
@ -1466,7 +1466,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
space: space.unwrap_or(ast::AtomSpace::Generic),
inner: ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Dec,
typ: ast::UIntType::U32
typ: ast::ScalarType::U32
}
};
ast::Instruction::Atom(details,a)
@ -1544,19 +1544,19 @@ AtomSIntOp: ast::AtomSIntOp = {
".max" => ast::AtomSIntOp::Max,
}
BitType: ast::BitType = {
".b32" => ast::BitType::B32,
".b64" => ast::BitType::B64,
BitType: ast::ScalarType = {
".b32" => ast::ScalarType::B32,
".b64" => ast::ScalarType::B64,
}
UIntType3264: ast::UIntType = {
".u32" => ast::UIntType::U32,
".u64" => ast::UIntType::U64,
UIntType3264: ast::ScalarType = {
".u32" => ast::ScalarType::U32,
".u64" => ast::ScalarType::U64,
}
SIntType3264: ast::SIntType = {
".s32" => ast::SIntType::S32,
".s64" => ast::SIntType::S64,
SIntType3264: ast::ScalarType = {
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div
@ -1566,7 +1566,7 @@ InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = {
"div" <t:SIntType> <a:Arg3> => ast::Instruction::Div(ast::DivDetails::Signed(t), a),
"div" <kind:DivFloatKind> <ftz:".ftz"?> ".f32" <a:Arg3> => {
let inner = ast::DivFloatDetails {
typ: ast::FloatType::F32,
typ: ast::ScalarType::F32,
flush_to_zero: Some(ftz.is_some()),
kind
};
@ -1574,7 +1574,7 @@ InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = {
},
"div" <rnd:RoundingModeFloat> ".f64" <a:Arg3> => {
let inner = ast::DivFloatDetails {
typ: ast::FloatType::F64,
typ: ast::ScalarType::F64,
flush_to_zero: None,
kind: ast::DivFloatKind::Rounding(rnd)
};
@ -1592,7 +1592,7 @@ DivFloatKind: ast::DivFloatKind = {
InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"sqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => {
let details = ast::SqrtDetails {
typ: ast::FloatType::F32,
typ: ast::ScalarType::F32,
flush_to_zero: Some(ftz.is_some()),
kind: ast::SqrtKind::Approx,
};
@ -1600,7 +1600,7 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
},
"sqrt" <rnd:RoundingModeFloat> <ftz:".ftz"?> ".f32" <a:Arg2> => {
let details = ast::SqrtDetails {
typ: ast::FloatType::F32,
typ: ast::ScalarType::F32,
flush_to_zero: Some(ftz.is_some()),
kind: ast::SqrtKind::Rounding(rnd),
};
@ -1608,7 +1608,7 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
},
"sqrt" <rnd:RoundingModeFloat> ".f64" <a:Arg2> => {
let details = ast::SqrtDetails {
typ: ast::FloatType::F64,
typ: ast::ScalarType::F64,
flush_to_zero: None,
kind: ast::SqrtKind::Rounding(rnd),
};
@ -1621,14 +1621,14 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"rsqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => {
let details = ast::RsqrtDetails {
typ: ast::FloatType::F32,
typ: ast::ScalarType::F32,
flush_to_zero: ftz.is_some(),
};
ast::Instruction::Rsqrt(details, a)
},
"rsqrt" ".approx" <ftz:".ftz"?> ".f64" <a:Arg2> => {
let details = ast::RsqrtDetails {
typ: ast::FloatType::F64,
typ: ast::ScalarType::F64,
flush_to_zero: ftz.is_some(),
};
ast::Instruction::Rsqrt(details, a)
@ -1739,7 +1739,7 @@ ArithDetails: ast::ArithDetails = {
saturate: false,
}),
".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::S32,
typ: ast::ScalarType::S32,
saturate: true,
}),
<f:ArithFloat> => ast::ArithDetails::Float(f)
@ -1747,25 +1747,25 @@ ArithDetails: ast::ArithDetails = {
ArithFloat: ast::ArithFloat = {
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
typ: ast::FloatType::F32,
typ: ast::ScalarType::F32,
rounding: rn,
flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(),
},
<rn:RoundingModeFloat?> ".f64" => ast::ArithFloat {
typ: ast::FloatType::F64,
typ: ast::ScalarType::F64,
rounding: rn,
flush_to_zero: None,
saturate: false,
},
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
typ: ast::FloatType::F16,
typ: ast::ScalarType::F16,
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(),
},
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
typ: ast::FloatType::F16x2,
typ: ast::ScalarType::F16x2,
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(),
@ -1774,25 +1774,25 @@ ArithFloat: ast::ArithFloat = {
ArithFloatMustRound: ast::ArithFloat = {
<rn:RoundingModeFloat> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
typ: ast::FloatType::F32,
typ: ast::ScalarType::F32,
rounding: Some(rn),
flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(),
},
<rn:RoundingModeFloat> ".f64" => ast::ArithFloat {
typ: ast::FloatType::F64,
typ: ast::ScalarType::F64,
rounding: Some(rn),
flush_to_zero: None,
saturate: false,
},
".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
typ: ast::FloatType::F16,
typ: ast::ScalarType::F16,
rounding: Some(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(),
},
".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
typ: ast::FloatType::F16x2,
typ: ast::ScalarType::F16x2,
rounding: Some(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(),

View file

@ -71,7 +71,7 @@
%26 = OpLoad %uint %9
%40 = OpCopyObject %uint %23
%41 = OpCopyObject %uint %24
%39 = OpFunctionCall %uint %44 %41 %40 %25 %26
%39 = OpFunctionCall %uint %44 %40 %41 %25 %26
%22 = OpCopyObject %uint %39
OpStore %6 %22
%27 = OpLoad %ulong %5

View file

@ -1553,10 +1553,9 @@ fn extract_globals<'input, 'b>(
space,
};
let (op, typ) = match typ {
ast::FloatType::F32 => ("add_f32", ast::SizedScalarType::F32),
ast::FloatType::F64 => ("add_f64", ast::SizedScalarType::F64),
ast::FloatType::F16 => unreachable!(),
ast::FloatType::F16x2 => unreachable!(),
ast::ScalarType::F32 => ("add_f32", ast::SizedScalarType::F32),
ast::ScalarType::F64 => ("add_f64", ast::SizedScalarType::F64),
_ => unreachable!(),
};
local.push(to_ptx_impl_atomic_call(
id_def,
@ -1822,15 +1821,15 @@ fn to_ptx_impl_atomic_call(
fn to_ptx_impl_bfe_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
typ: ast::IntType,
typ: ast::ScalarType,
arg: ast::Arg4<ExpandedArgParams>,
) -> ExpandedStatement {
let prefix = "__zluda_ptx_impl__";
let suffix = match typ {
ast::IntType::U32 => "bfe_u32",
ast::IntType::U64 => "bfe_u64",
ast::IntType::S32 => "bfe_s32",
ast::IntType::S64 => "bfe_s64",
ast::ScalarType::U32 => "bfe_u32",
ast::ScalarType::U64 => "bfe_u64",
ast::ScalarType::S32 => "bfe_s32",
ast::ScalarType::S64 => "bfe_s64",
_ => unreachable!(),
};
let fn_name = format!("{}{}", prefix, suffix);
@ -1917,14 +1916,14 @@ fn to_ptx_impl_bfe_call(
fn to_ptx_impl_bfi_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
typ: ast::BitType,
typ: ast::ScalarType,
arg: ast::Arg5<ExpandedArgParams>,
) -> ExpandedStatement {
let prefix = "__zluda_ptx_impl__";
let suffix = match typ {
ast::BitType::B32 => "bfi_b32",
ast::BitType::B64 => "bfi_b64",
ast::BitType::B8 | ast::BitType::B16 => unreachable!(),
ast::ScalarType::B32 => "bfi_b32",
ast::ScalarType::B64 => "bfi_b64",
_ => unreachable!(),
};
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
@ -2506,29 +2505,32 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
let (width, kind) = match add_type {
ast::Type::Scalar(scalar_t) => {
let kind = match scalar_t.kind() {
kind @ ScalarKind::Bit
| kind @ ScalarKind::Unsigned
| kind @ ScalarKind::Signed => kind,
ScalarKind::Float => return Err(TranslateError::MismatchedType),
ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
ScalarKind::Pred => return Err(TranslateError::MismatchedType),
kind @ ast::ScalarKind::Bit
| kind @ ast::ScalarKind::Unsigned
| kind @ ast::ScalarKind::Signed => kind,
ast::ScalarKind::Float => return Err(TranslateError::MismatchedType),
ast::ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
ast::ScalarKind::Pred => return Err(TranslateError::MismatchedType),
};
(scalar_t.size_of(), kind)
}
_ => return Err(TranslateError::MismatchedType),
};
let arith_detail = if kind == ScalarKind::Signed {
let arith_detail = if kind == ast::ScalarKind::Signed {
ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::from_size(width),
typ: ast::ScalarType::from_parts(width, ast::ScalarKind::Signed),
saturate: false,
})
} else {
ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
ast::ArithDetails::Unsigned(ast::ScalarType::from_parts(
width,
ast::ScalarKind::Unsigned,
))
};
let id_constant_stmt = self.id_def.new_non_variable(add_type.clone());
let result_id = self.id_def.new_non_variable(add_type);
// TODO: check for edge cases around min value/max value/wrapping
if offset < 0 && kind != ScalarKind::Signed {
if offset < 0 && kind != ast::ScalarKind::Signed {
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::from_parts(width, kind),
@ -3026,18 +3028,18 @@ fn emit_function_body_ops(
emit_setp(builder, map, setp, arg)?;
}
ast::Instruction::Not(t, a) => {
let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
let result_type = map.get_or_add(builder, SpirvType::from(*t));
let result_id = Some(a.dst);
let operand = a.src;
match t {
ast::BooleanType::Pred => {
ast::ScalarType::Pred => {
logical_not(builder, result_type, result_id, operand)
}
_ => builder.not(result_type, result_id, operand),
}?;
}
ast::Instruction::Shl(t, a) => {
let full_type = t.to_type();
let full_type = ast::Type::Scalar(*t);
let size_of = full_type.size_of();
let result_type = map.get_or_add(builder, SpirvType::from(full_type));
let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?;
@ -3048,7 +3050,7 @@ fn emit_function_body_ops(
let size_of = full_type.size_of();
let result_type = map.get_or_add_scalar(builder, full_type);
let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?;
if t.signed() {
if t.kind() == ast::ScalarKind::Signed {
builder.shift_right_arithmetic(
result_type,
Some(a.dst),
@ -3088,7 +3090,7 @@ fn emit_function_body_ops(
},
ast::Instruction::Or(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
if *t == ast::BooleanType::Pred {
if *t == ast::ScalarType::Pred {
builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
@ -3116,7 +3118,7 @@ fn emit_function_body_ops(
}
ast::Instruction::And(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
if *t == ast::BooleanType::Pred {
if *t == ast::ScalarType::Pred {
builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?;
@ -3202,7 +3204,7 @@ fn emit_function_body_ops(
}
ast::Instruction::Neg(details, arg) => {
let result_type = map.get_or_add_scalar(builder, details.typ);
let negate_func = if details.typ.kind() == ScalarKind::Float {
let negate_func = if details.typ.kind() == ast::ScalarKind::Float {
dr::Builder::f_negate
} else {
dr::Builder::s_negate
@ -3269,7 +3271,7 @@ fn emit_function_body_ops(
}
ast::Instruction::Xor { typ, arg } => {
let builder_fn = match typ {
ast::BooleanType::Pred => emit_logical_xor_spirv,
ast::ScalarType::Pred => emit_logical_xor_spirv,
_ => dr::Builder::bitwise_xor,
};
let result_type = map.get_or_add_scalar(builder, (*typ).into());
@ -3284,7 +3286,7 @@ fn emit_function_body_ops(
return Err(error_unreachable());
}
ast::Instruction::Rem { typ, arg } => {
let builder_fn = if typ.is_signed() {
let builder_fn = if typ.kind() == ast::ScalarKind::Signed {
dr::Builder::s_mod
} else {
dr::Builder::u_mod
@ -3882,7 +3884,7 @@ fn emit_cvt(
}
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.src.is_signed() {
if desc.src.kind() == ast::ScalarKind::Signed {
builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
} else {
builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?;
@ -3892,7 +3894,7 @@ fn emit_cvt(
ast::CvtDetails::IntFromFloat(desc) => {
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.dst.is_signed() {
if desc.dst.kind() == ast::ScalarKind::Signed {
builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?;
} else {
builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?;
@ -3904,7 +3906,7 @@ fn emit_cvt(
let dest_t: ast::ScalarType = desc.dst.into();
let src_t: ast::ScalarType = desc.src.into();
// first do shortening/widening
let src = if desc.dst.width() != desc.src.width() {
let src = if desc.dst.size_of() != desc.src.size_of() {
let new_dst = if dest_t.kind() == src_t.kind() {
arg.dst
} else {
@ -3933,7 +3935,7 @@ fn emit_cvt(
// now do actual conversion
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.saturate {
if desc.dst.is_signed() {
if desc.dst.kind() == ast::ScalarKind::Signed {
builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?;
} else {
builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?;
@ -3989,60 +3991,60 @@ fn emit_setp(
let operand_1 = arg.src1;
let operand_2 = arg.src2;
match (setp.cmp_op, setp.typ.kind()) {
(ast::SetpCompareOp::Eq, ScalarKind::Signed)
| (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
| (ast::SetpCompareOp::Eq, ScalarKind::Bit) => {
(ast::SetpCompareOp::Eq, ast::ScalarKind::Signed)
| (ast::SetpCompareOp::Eq, ast::ScalarKind::Unsigned)
| (ast::SetpCompareOp::Eq, ast::ScalarKind::Bit) => {
builder.i_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Eq, ScalarKind::Float) => {
(ast::SetpCompareOp::Eq, ast::ScalarKind::Float) => {
builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NotEq, ScalarKind::Signed)
| (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
| (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => {
(ast::SetpCompareOp::NotEq, ast::ScalarKind::Signed)
| (ast::SetpCompareOp::NotEq, ast::ScalarKind::Unsigned)
| (ast::SetpCompareOp::NotEq, ast::ScalarKind::Bit) => {
builder.i_not_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
(ast::SetpCompareOp::NotEq, ast::ScalarKind::Float) => {
builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Less, ScalarKind::Unsigned)
| (ast::SetpCompareOp::Less, ScalarKind::Bit) => {
(ast::SetpCompareOp::Less, ast::ScalarKind::Unsigned)
| (ast::SetpCompareOp::Less, ast::ScalarKind::Bit) => {
builder.u_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Less, ScalarKind::Signed) => {
(ast::SetpCompareOp::Less, ast::ScalarKind::Signed) => {
builder.s_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Less, ScalarKind::Float) => {
(ast::SetpCompareOp::Less, ast::ScalarKind::Float) => {
builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
| (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => {
(ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Unsigned)
| (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Bit) => {
builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
(ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Signed) => {
builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => {
(ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Float) => {
builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
| (ast::SetpCompareOp::Greater, ScalarKind::Bit) => {
(ast::SetpCompareOp::Greater, ast::ScalarKind::Unsigned)
| (ast::SetpCompareOp::Greater, ast::ScalarKind::Bit) => {
builder.u_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
(ast::SetpCompareOp::Greater, ast::ScalarKind::Signed) => {
builder.s_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Greater, ScalarKind::Float) => {
(ast::SetpCompareOp::Greater, ast::ScalarKind::Float) => {
builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
| (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => {
(ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Unsigned)
| (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Bit) => {
builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
(ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Signed) => {
builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
(ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Float) => {
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NanEq, _) => {
@ -4222,7 +4224,7 @@ fn emit_abs(
) -> Result<(), dr::Error> {
let scalar_t = ast::ScalarType::from(d.typ);
let result_type = map.get_or_add(builder, SpirvType::from(scalar_t));
let cl_abs = if scalar_t.kind() == ScalarKind::Signed {
let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed {
spirv::CLOp::s_abs
} else {
spirv::CLOp::fabs
@ -4286,8 +4288,8 @@ fn emit_implicit_conversion(
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
if from_parts.width == to_parts.width {
let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
if from_parts.scalar_kind != ScalarKind::Float
&& to_parts.scalar_kind != ScalarKind::Float
if from_parts.scalar_kind != ast::ScalarKind::Float
&& to_parts.scalar_kind != ast::ScalarKind::Float
{
// It is noop, but another instruction expects result of this conversion
builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
@ -4299,24 +4301,24 @@ fn emit_implicit_conversion(
let same_width_bit_type = map.get_or_add(
builder,
SpirvType::from(ast::Type::from_parts(TypeParts {
scalar_kind: ScalarKind::Bit,
scalar_kind: ast::ScalarKind::Bit,
..from_parts
})),
);
let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
let wide_bit_type = ast::Type::from_parts(TypeParts {
scalar_kind: ScalarKind::Bit,
scalar_kind: ast::ScalarKind::Bit,
..to_parts
});
let wide_bit_type_spirv =
map.get_or_add(builder, SpirvType::from(wide_bit_type.clone()));
if to_parts.scalar_kind == ScalarKind::Unsigned
|| to_parts.scalar_kind == ScalarKind::Bit
if to_parts.scalar_kind == ast::ScalarKind::Unsigned
|| to_parts.scalar_kind == ast::ScalarKind::Bit
{
builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?;
} else {
let conversion_fn = if from_parts.scalar_kind == ScalarKind::Signed
&& to_parts.scalar_kind == ScalarKind::Signed
let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed
&& to_parts.scalar_kind == ast::ScalarKind::Signed
{
dr::Builder::s_convert
} else {
@ -4614,23 +4616,23 @@ fn convert_to_stateful_memory_access<'a>(
for statement in func_body.iter() {
match statement {
Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Unsigned(ast::UIntType::U64),
ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::S64,
typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
))
| Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Unsigned(ast::UIntType::U64),
ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::S64,
typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
@ -4686,12 +4688,12 @@ fn convert_to_stateful_memory_access<'a>(
}
}
Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Unsigned(ast::UIntType::U64),
ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::S64,
typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
@ -4715,12 +4717,12 @@ fn convert_to_stateful_memory_access<'a>(
}))
}
Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Unsigned(ast::UIntType::U64),
ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Signed(ast::ArithSInt {
typ: ast::SIntType::S64,
typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
@ -4867,7 +4869,7 @@ fn convert_to_stateful_memory_access_postprocess(
ast::LdStateSpace::Global,
),
to: old_type,
kind: ConversionKind::PtrToBit(ast::UIntType::U64),
kind: ConversionKind::PtrToBit(ast::ScalarType::U64),
src_sema: arg_desc.sema,
dst_sema: ArgumentSemantics::Default,
}));
@ -5903,7 +5905,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let inst_type = d.typ;
ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?),
ast::Instruction::Not(t, a) => {
ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?)
}
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (
@ -5926,7 +5930,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?)
}
ast::Instruction::Shl(t, a) => {
ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?)
ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?)
}
ast::Instruction::Shr(t, a) => {
ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?)
@ -6176,9 +6180,9 @@ impl ast::Type {
ast::Type::Scalar(scalar) => {
let kind = scalar.kind();
let width = scalar.size_of();
if (kind != ScalarKind::Signed
&& kind != ScalarKind::Unsigned
&& kind != ScalarKind::Bit)
if (kind != ast::ScalarKind::Signed
&& kind != ast::ScalarKind::Unsigned
&& kind != ast::ScalarKind::Bit)
|| (width == 8)
{
return Err(TranslateError::MismatchedType);
@ -6306,7 +6310,7 @@ impl ast::Type {
#[derive(Eq, PartialEq, Clone)]
struct TypeParts {
kind: TypeKind,
scalar_kind: ScalarKind,
scalar_kind: ast::ScalarKind,
width: u8,
components: Vec<u32>,
state_space: ast::LdStateSpace,
@ -6461,7 +6465,7 @@ enum ConversionKind {
// zero-extend/chop/bitcast depending on types
SignExtend,
BitToPtr(ast::LdStateSpace),
PtrToBit(ast::UIntType),
PtrToBit(ast::ScalarType),
PtrToPtr { spirv_ptr: bool },
}
@ -6859,7 +6863,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
fn map_selp<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::SelpType,
t: ast::ScalarType,
) -> Result<ast::Arg4<U>, TranslateError> {
let dst = visitor.operand(
ArgumentDescriptor {
@ -6904,7 +6908,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::BitType,
t: ast::ScalarType,
state_space: ast::AtomSpace,
) -> Result<ast::Arg4<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
@ -7205,103 +7209,41 @@ impl ast::StStateSpace {
}
}
#[derive(Clone, Copy, PartialEq, Eq)]
enum ScalarKind {
Bit,
Unsigned,
Signed,
Float,
Float2,
Pred,
}
impl ast::ScalarType {
fn kind(self) -> ScalarKind {
match self {
ast::ScalarType::U8 => ScalarKind::Unsigned,
ast::ScalarType::U16 => ScalarKind::Unsigned,
ast::ScalarType::U32 => ScalarKind::Unsigned,
ast::ScalarType::U64 => ScalarKind::Unsigned,
ast::ScalarType::S8 => ScalarKind::Signed,
ast::ScalarType::S16 => ScalarKind::Signed,
ast::ScalarType::S32 => ScalarKind::Signed,
ast::ScalarType::S64 => ScalarKind::Signed,
ast::ScalarType::B8 => ScalarKind::Bit,
ast::ScalarType::B16 => ScalarKind::Bit,
ast::ScalarType::B32 => ScalarKind::Bit,
ast::ScalarType::B64 => ScalarKind::Bit,
ast::ScalarType::F16 => ScalarKind::Float,
ast::ScalarType::F32 => ScalarKind::Float,
ast::ScalarType::F64 => ScalarKind::Float,
ast::ScalarType::F16x2 => ScalarKind::Float2,
ast::ScalarType::Pred => ScalarKind::Pred,
}
}
fn from_parts(width: u8, kind: ScalarKind) -> Self {
fn from_parts(width: u8, kind: ast::ScalarKind) -> Self {
match kind {
ScalarKind::Float => match width {
ast::ScalarKind::Float => match width {
2 => ast::ScalarType::F16,
4 => ast::ScalarType::F32,
8 => ast::ScalarType::F64,
_ => unreachable!(),
},
ScalarKind::Bit => match width {
ast::ScalarKind::Bit => match width {
1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32,
8 => ast::ScalarType::B64,
_ => unreachable!(),
},
ScalarKind::Signed => match width {
ast::ScalarKind::Signed => match width {
1 => ast::ScalarType::S8,
2 => ast::ScalarType::S16,
4 => ast::ScalarType::S32,
8 => ast::ScalarType::S64,
_ => unreachable!(),
},
ScalarKind::Unsigned => match width {
ast::ScalarKind::Unsigned => match width {
1 => ast::ScalarType::U8,
2 => ast::ScalarType::U16,
4 => ast::ScalarType::U32,
8 => ast::ScalarType::U64,
_ => unreachable!(),
},
ScalarKind::Float2 => match width {
ast::ScalarKind::Float2 => match width {
4 => ast::ScalarType::F16x2,
_ => unreachable!(),
},
ScalarKind::Pred => ast::ScalarType::Pred,
}
}
}
impl ast::BooleanType {
fn to_type(self) -> ast::Type {
match self {
ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
}
}
}
impl ast::ShlType {
fn to_type(self) -> ast::Type {
match self {
ast::ShlType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
ast::ShlType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
ast::ShlType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
}
}
}
impl ast::ShrType {
fn signed(&self) -> bool {
match self {
ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true,
_ => false,
ast::ScalarKind::Pred => ast::ScalarType::Pred,
}
}
}
@ -7357,30 +7299,6 @@ impl ast::AtomInnerDetails {
}
}
impl ast::SIntType {
fn from_size(width: u8) -> Self {
match width {
1 => ast::SIntType::S8,
2 => ast::SIntType::S16,
4 => ast::SIntType::S32,
8 => ast::SIntType::S64,
_ => unreachable!(),
}
}
}
impl ast::UIntType {
fn from_size(width: u8) -> Self {
match width {
1 => ast::UIntType::U8,
2 => ast::UIntType::U16,
4 => ast::UIntType::U32,
8 => ast::UIntType::U64,
_ => unreachable!(),
}
}
}
impl ast::LdStateSpace {
fn to_spirv(self) -> spirv::StorageClass {
match self {
@ -7568,16 +7486,18 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
return false;
}
match inst.kind() {
ScalarKind::Bit => operand.kind() != ScalarKind::Bit,
ScalarKind::Float => operand.kind() == ScalarKind::Bit,
ScalarKind::Signed => {
operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
ast::ScalarKind::Signed => {
operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Unsigned
}
ScalarKind::Unsigned => {
operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed
ast::ScalarKind::Unsigned => {
operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Signed
}
ScalarKind::Float2 => false,
ScalarKind::Pred => false,
ast::ScalarKind::Float2 => false,
ast::ScalarKind::Pred => false,
}
}
(ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
@ -7596,7 +7516,7 @@ fn should_bitcast_packed(
if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
(operand, instr)
{
if scalar.kind() == ScalarKind::Bit
if scalar.kind() == ast::ScalarKind::Bit
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
{
return Ok(Some(ConversionKind::Default));
@ -7644,32 +7564,33 @@ fn should_convert_relaxed_src(
}
match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ScalarKind::Bit => {
ast::ScalarKind::Bit => {
if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Signed | ScalarKind::Unsigned => {
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() != ScalarKind::Float
&& src_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Float => {
if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit
ast::ScalarKind::Float => {
if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None,
ast::ScalarKind::Float2 => todo!(),
ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
@ -7706,15 +7627,15 @@ fn should_convert_relaxed_dst(
}
match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ScalarKind::Bit => {
ast::ScalarKind::Bit => {
if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Signed => {
if dst_type.kind() != ScalarKind::Float {
ast::ScalarKind::Signed => {
if dst_type.kind() != ast::ScalarKind::Float {
if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default)
} else if instr_type.size_of() < dst_type.size_of() {
@ -7726,25 +7647,26 @@ fn should_convert_relaxed_dst(
None
}
}
ScalarKind::Unsigned => {
ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() != ScalarKind::Float
&& dst_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Float => {
if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit
ast::ScalarKind::Float => {
if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None,
ast::ScalarKind::Float2 => todo!(),
ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {