mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 08:24:44 +00:00
Convert enumes to 1TT
This commit is contained in:
parent
a005c92c61
commit
a0baad9456
4 changed files with 332 additions and 462 deletions
224
ptx/src/ast.rs
224
ptx/src/ast.rs
|
@ -210,20 +210,6 @@ sub_enum!(LdStScalarType {
|
|||
F64,
|
||||
});
|
||||
|
||||
sub_enum!(SelpType {
|
||||
B16,
|
||||
B32,
|
||||
B64,
|
||||
U16,
|
||||
U32,
|
||||
U64,
|
||||
S16,
|
||||
S32,
|
||||
S64,
|
||||
F32,
|
||||
F64,
|
||||
});
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub enum BarDetails {
|
||||
SyncAligned,
|
||||
|
@ -425,52 +411,6 @@ pub enum ScalarType {
|
|||
Pred,
|
||||
}
|
||||
|
||||
sub_enum!(IntType {
|
||||
U8,
|
||||
U16,
|
||||
U32,
|
||||
U64,
|
||||
S8,
|
||||
S16,
|
||||
S32,
|
||||
S64
|
||||
});
|
||||
|
||||
sub_enum!(BitType { B8, B16, B32, B64 });
|
||||
|
||||
sub_enum!(UIntType { U8, U16, U32, U64 });
|
||||
|
||||
sub_enum!(SIntType { S8, S16, S32, S64 });
|
||||
|
||||
impl IntType {
|
||||
pub fn is_signed(self) -> bool {
|
||||
match self {
|
||||
IntType::U8 | IntType::U16 | IntType::U32 | IntType::U64 => false,
|
||||
IntType::S8 | IntType::S16 | IntType::S32 | IntType::S64 => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn width(self) -> u8 {
|
||||
match self {
|
||||
IntType::U8 => 1,
|
||||
IntType::U16 => 2,
|
||||
IntType::U32 => 4,
|
||||
IntType::U64 => 8,
|
||||
IntType::S8 => 1,
|
||||
IntType::S16 => 2,
|
||||
IntType::S32 => 4,
|
||||
IntType::S64 => 8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sub_enum!(FloatType {
|
||||
F16,
|
||||
F16x2,
|
||||
F32,
|
||||
F64
|
||||
});
|
||||
|
||||
impl ScalarType {
|
||||
pub fn size_of(self) -> u8 {
|
||||
match self {
|
||||
|
@ -576,24 +516,24 @@ pub enum Instruction<P: ArgParams> {
|
|||
Add(ArithDetails, Arg3<P>),
|
||||
Setp(SetpData, Arg4Setp<P>),
|
||||
SetpBool(SetpBoolData, Arg5Setp<P>),
|
||||
Not(BooleanType, Arg2<P>),
|
||||
Not(ScalarType, Arg2<P>),
|
||||
Bra(BraData, Arg1<P>),
|
||||
Cvt(CvtDetails, Arg2<P>),
|
||||
Cvta(CvtaDetails, Arg2<P>),
|
||||
Shl(ShlType, Arg3<P>),
|
||||
Shr(ShrType, Arg3<P>),
|
||||
Shl(ScalarType, Arg3<P>),
|
||||
Shr(ScalarType, Arg3<P>),
|
||||
St(StData, Arg2St<P>),
|
||||
Ret(RetData),
|
||||
Call(CallInst<P>),
|
||||
Abs(AbsDetails, Arg2<P>),
|
||||
Mad(MulDetails, Arg4<P>),
|
||||
Or(BooleanType, Arg3<P>),
|
||||
Or(ScalarType, Arg3<P>),
|
||||
Sub(ArithDetails, Arg3<P>),
|
||||
Min(MinMaxDetails, Arg3<P>),
|
||||
Max(MinMaxDetails, Arg3<P>),
|
||||
Rcp(RcpDetails, Arg2<P>),
|
||||
And(BooleanType, Arg3<P>),
|
||||
Selp(SelpType, Arg4<P>),
|
||||
And(ScalarType, Arg3<P>),
|
||||
Selp(ScalarType, Arg4<P>),
|
||||
Bar(BarDetails, Arg1Bar<P>),
|
||||
Atom(AtomDetails, Arg3<P>),
|
||||
AtomCas(AtomCasDetails, Arg4<P>),
|
||||
|
@ -605,13 +545,13 @@ pub enum Instruction<P: ArgParams> {
|
|||
Cos { flush_to_zero: bool, arg: Arg2<P> },
|
||||
Lg2 { flush_to_zero: bool, arg: Arg2<P> },
|
||||
Ex2 { flush_to_zero: bool, arg: Arg2<P> },
|
||||
Clz { typ: BitType, arg: Arg2<P> },
|
||||
Brev { typ: BitType, arg: Arg2<P> },
|
||||
Popc { typ: BitType, arg: Arg2<P> },
|
||||
Xor { typ: BooleanType, arg: Arg3<P> },
|
||||
Bfe { typ: IntType, arg: Arg4<P> },
|
||||
Bfi { typ: BitType, arg: Arg5<P> },
|
||||
Rem { typ: IntType, arg: Arg3<P> },
|
||||
Clz { typ: ScalarType, arg: Arg2<P> },
|
||||
Brev { typ: ScalarType, arg: Arg2<P> },
|
||||
Popc { typ: ScalarType, arg: Arg2<P> },
|
||||
Xor { typ: ScalarType, arg: Arg3<P> },
|
||||
Bfe { typ: ScalarType, arg: Arg4<P> },
|
||||
Bfi { typ: ScalarType, arg: Arg5<P> },
|
||||
Rem { typ: ScalarType, arg: Arg3<P> },
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
|
@ -825,7 +765,7 @@ impl MovDetails {
|
|||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct MulIntDesc {
|
||||
pub typ: IntType,
|
||||
pub typ: ScalarType,
|
||||
pub control: MulIntControl,
|
||||
}
|
||||
|
||||
|
@ -845,7 +785,7 @@ pub enum RoundingMode {
|
|||
}
|
||||
|
||||
pub struct AddIntDesc {
|
||||
pub typ: IntType,
|
||||
pub typ: ScalarType,
|
||||
pub saturate: bool,
|
||||
}
|
||||
|
||||
|
@ -892,39 +832,39 @@ pub struct BraData {
|
|||
|
||||
pub enum CvtDetails {
|
||||
IntFromInt(CvtIntToIntDesc),
|
||||
FloatFromFloat(CvtDesc<FloatType, FloatType>),
|
||||
IntFromFloat(CvtDesc<IntType, FloatType>),
|
||||
FloatFromInt(CvtDesc<FloatType, IntType>),
|
||||
FloatFromFloat(CvtDesc),
|
||||
IntFromFloat(CvtDesc),
|
||||
FloatFromInt(CvtDesc),
|
||||
}
|
||||
|
||||
pub struct CvtIntToIntDesc {
|
||||
pub dst: IntType,
|
||||
pub src: IntType,
|
||||
pub dst: ScalarType,
|
||||
pub src: ScalarType,
|
||||
pub saturate: bool,
|
||||
}
|
||||
|
||||
pub struct CvtDesc<Dst, Src> {
|
||||
pub struct CvtDesc {
|
||||
pub rounding: Option<RoundingMode>,
|
||||
pub flush_to_zero: Option<bool>,
|
||||
pub saturate: bool,
|
||||
pub dst: Dst,
|
||||
pub src: Src,
|
||||
pub dst: ScalarType,
|
||||
pub src: ScalarType,
|
||||
}
|
||||
|
||||
impl CvtDetails {
|
||||
pub fn new_int_from_int_checked<'err, 'input>(
|
||||
saturate: bool,
|
||||
dst: IntType,
|
||||
src: IntType,
|
||||
dst: ScalarType,
|
||||
src: ScalarType,
|
||||
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
|
||||
) -> Self {
|
||||
if saturate {
|
||||
if src.is_signed() {
|
||||
if dst.is_signed() && dst.width() >= src.width() {
|
||||
if src.kind() == ScalarKind::Signed {
|
||||
if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() {
|
||||
err.push(ParseError::from(PtxError::SyntaxError));
|
||||
}
|
||||
} else {
|
||||
if dst == src || dst.width() >= src.width() {
|
||||
if dst == src || dst.size_of() >= src.size_of() {
|
||||
err.push(ParseError::from(PtxError::SyntaxError));
|
||||
}
|
||||
}
|
||||
|
@ -936,11 +876,11 @@ impl CvtDetails {
|
|||
rounding: RoundingMode,
|
||||
flush_to_zero: bool,
|
||||
saturate: bool,
|
||||
dst: FloatType,
|
||||
src: IntType,
|
||||
dst: ScalarType,
|
||||
src: ScalarType,
|
||||
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
|
||||
) -> Self {
|
||||
if flush_to_zero && dst != FloatType::F32 {
|
||||
if flush_to_zero && dst != ScalarType::F32 {
|
||||
err.push(ParseError::from(PtxError::NonF32Ftz));
|
||||
}
|
||||
CvtDetails::FloatFromInt(CvtDesc {
|
||||
|
@ -956,11 +896,11 @@ impl CvtDetails {
|
|||
rounding: RoundingMode,
|
||||
flush_to_zero: bool,
|
||||
saturate: bool,
|
||||
dst: IntType,
|
||||
src: FloatType,
|
||||
dst: ScalarType,
|
||||
src: ScalarType,
|
||||
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
|
||||
) -> Self {
|
||||
if flush_to_zero && src != FloatType::F32 {
|
||||
if flush_to_zero && src != ScalarType::F32 {
|
||||
err.push(ParseError::from(PtxError::NonF32Ftz));
|
||||
}
|
||||
CvtDetails::IntFromFloat(CvtDesc {
|
||||
|
@ -993,25 +933,6 @@ pub enum CvtaSize {
|
|||
U64,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
pub enum ShlType {
|
||||
B16,
|
||||
B32,
|
||||
B64,
|
||||
}
|
||||
|
||||
sub_enum!(ShrType {
|
||||
B16,
|
||||
B32,
|
||||
B64,
|
||||
U16,
|
||||
U32,
|
||||
U64,
|
||||
S16,
|
||||
S32,
|
||||
S64,
|
||||
});
|
||||
|
||||
pub struct StData {
|
||||
pub qualifier: LdStQualifier,
|
||||
pub state_space: StStateSpace,
|
||||
|
@ -1040,13 +961,6 @@ pub struct RetData {
|
|||
pub uniform: bool,
|
||||
}
|
||||
|
||||
sub_enum!(BooleanType {
|
||||
Pred,
|
||||
B16,
|
||||
B32,
|
||||
B64,
|
||||
});
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum MulDetails {
|
||||
Unsigned(MulUInt),
|
||||
|
@ -1056,32 +970,32 @@ pub enum MulDetails {
|
|||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct MulUInt {
|
||||
pub typ: UIntType,
|
||||
pub typ: ScalarType,
|
||||
pub control: MulIntControl,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct MulSInt {
|
||||
pub typ: SIntType,
|
||||
pub typ: ScalarType,
|
||||
pub control: MulIntControl,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum ArithDetails {
|
||||
Unsigned(UIntType),
|
||||
Unsigned(ScalarType),
|
||||
Signed(ArithSInt),
|
||||
Float(ArithFloat),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ArithSInt {
|
||||
pub typ: SIntType,
|
||||
pub typ: ScalarType,
|
||||
pub saturate: bool,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ArithFloat {
|
||||
pub typ: FloatType,
|
||||
pub typ: ScalarType,
|
||||
pub rounding: Option<RoundingMode>,
|
||||
pub flush_to_zero: Option<bool>,
|
||||
pub saturate: bool,
|
||||
|
@ -1089,8 +1003,8 @@ pub struct ArithFloat {
|
|||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum MinMaxDetails {
|
||||
Signed(SIntType),
|
||||
Unsigned(UIntType),
|
||||
Signed(ScalarType),
|
||||
Unsigned(ScalarType),
|
||||
Float(MinMaxFloat),
|
||||
}
|
||||
|
||||
|
@ -1098,7 +1012,7 @@ pub enum MinMaxDetails {
|
|||
pub struct MinMaxFloat {
|
||||
pub flush_to_zero: Option<bool>,
|
||||
pub nan: bool,
|
||||
pub typ: FloatType,
|
||||
pub typ: ScalarType,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
|
@ -1126,10 +1040,10 @@ pub enum AtomSpace {
|
|||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum AtomInnerDetails {
|
||||
Bit { op: AtomBitOp, typ: BitType },
|
||||
Unsigned { op: AtomUIntOp, typ: UIntType },
|
||||
Signed { op: AtomSIntOp, typ: SIntType },
|
||||
Float { op: AtomFloatOp, typ: FloatType },
|
||||
Bit { op: AtomBitOp, typ: ScalarType },
|
||||
Unsigned { op: AtomUIntOp, typ: ScalarType },
|
||||
Signed { op: AtomSIntOp, typ: ScalarType },
|
||||
Float { op: AtomFloatOp, typ: ScalarType },
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
|
@ -1166,19 +1080,19 @@ pub struct AtomCasDetails {
|
|||
pub semantics: AtomSemantics,
|
||||
pub scope: MemScope,
|
||||
pub space: AtomSpace,
|
||||
pub typ: BitType,
|
||||
pub typ: ScalarType,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum DivDetails {
|
||||
Unsigned(UIntType),
|
||||
Signed(SIntType),
|
||||
Unsigned(ScalarType),
|
||||
Signed(ScalarType),
|
||||
Float(DivFloatDetails),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct DivFloatDetails {
|
||||
pub typ: FloatType,
|
||||
pub typ: ScalarType,
|
||||
pub flush_to_zero: Option<bool>,
|
||||
pub kind: DivFloatKind,
|
||||
}
|
||||
|
@ -1197,7 +1111,7 @@ pub enum NumsOrArrays<'a> {
|
|||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct SqrtDetails {
|
||||
pub typ: FloatType,
|
||||
pub typ: ScalarType,
|
||||
pub flush_to_zero: Option<bool>,
|
||||
pub kind: SqrtKind,
|
||||
}
|
||||
|
@ -1210,7 +1124,7 @@ pub enum SqrtKind {
|
|||
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub struct RsqrtDetails {
|
||||
pub typ: FloatType,
|
||||
pub typ: ScalarType,
|
||||
pub flush_to_zero: bool,
|
||||
}
|
||||
|
||||
|
@ -1379,6 +1293,40 @@ pub enum TuningDirective {
|
|||
MinNCtaPerSm(u32),
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ScalarKind {
|
||||
Bit,
|
||||
Unsigned,
|
||||
Signed,
|
||||
Float,
|
||||
Float2,
|
||||
Pred,
|
||||
}
|
||||
|
||||
impl ScalarType {
|
||||
pub fn kind(self) -> ScalarKind {
|
||||
match self {
|
||||
ScalarType::U8 => ScalarKind::Unsigned,
|
||||
ScalarType::U16 => ScalarKind::Unsigned,
|
||||
ScalarType::U32 => ScalarKind::Unsigned,
|
||||
ScalarType::U64 => ScalarKind::Unsigned,
|
||||
ScalarType::S8 => ScalarKind::Signed,
|
||||
ScalarType::S16 => ScalarKind::Signed,
|
||||
ScalarType::S32 => ScalarKind::Signed,
|
||||
ScalarType::S64 => ScalarKind::Signed,
|
||||
ScalarType::B8 => ScalarKind::Bit,
|
||||
ScalarType::B16 => ScalarKind::Bit,
|
||||
ScalarType::B32 => ScalarKind::Bit,
|
||||
ScalarType::B64 => ScalarKind::Bit,
|
||||
ScalarType::F16 => ScalarKind::Float,
|
||||
ScalarType::F32 => ScalarKind::Float,
|
||||
ScalarType::F64 => ScalarKind::Float,
|
||||
ScalarType::F16x2 => ScalarKind::Float2,
|
||||
ScalarType::Pred => ScalarKind::Pred,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -899,39 +899,39 @@ RoundingModeInt : ast::RoundingMode = {
|
|||
".rpi" => ast::RoundingMode::PositiveInf,
|
||||
};
|
||||
|
||||
IntType : ast::IntType = {
|
||||
".u16" => ast::IntType::U16,
|
||||
".u32" => ast::IntType::U32,
|
||||
".u64" => ast::IntType::U64,
|
||||
".s16" => ast::IntType::S16,
|
||||
".s32" => ast::IntType::S32,
|
||||
".s64" => ast::IntType::S64,
|
||||
IntType : ast::ScalarType = {
|
||||
".u16" => ast::ScalarType::U16,
|
||||
".u32" => ast::ScalarType::U32,
|
||||
".u64" => ast::ScalarType::U64,
|
||||
".s16" => ast::ScalarType::S16,
|
||||
".s32" => ast::ScalarType::S32,
|
||||
".s64" => ast::ScalarType::S64,
|
||||
};
|
||||
|
||||
IntType3264: ast::IntType = {
|
||||
".u32" => ast::IntType::U32,
|
||||
".u64" => ast::IntType::U64,
|
||||
".s32" => ast::IntType::S32,
|
||||
".s64" => ast::IntType::S64,
|
||||
IntType3264: ast::ScalarType = {
|
||||
".u32" => ast::ScalarType::U32,
|
||||
".u64" => ast::ScalarType::U64,
|
||||
".s32" => ast::ScalarType::S32,
|
||||
".s64" => ast::ScalarType::S64,
|
||||
}
|
||||
|
||||
UIntType: ast::UIntType = {
|
||||
".u16" => ast::UIntType::U16,
|
||||
".u32" => ast::UIntType::U32,
|
||||
".u64" => ast::UIntType::U64,
|
||||
UIntType: ast::ScalarType = {
|
||||
".u16" => ast::ScalarType::U16,
|
||||
".u32" => ast::ScalarType::U32,
|
||||
".u64" => ast::ScalarType::U64,
|
||||
};
|
||||
|
||||
SIntType: ast::SIntType = {
|
||||
".s16" => ast::SIntType::S16,
|
||||
".s32" => ast::SIntType::S32,
|
||||
".s64" => ast::SIntType::S64,
|
||||
SIntType: ast::ScalarType = {
|
||||
".s16" => ast::ScalarType::S16,
|
||||
".s32" => ast::ScalarType::S32,
|
||||
".s64" => ast::ScalarType::S64,
|
||||
};
|
||||
|
||||
FloatType: ast::FloatType = {
|
||||
".f16" => ast::FloatType::F16,
|
||||
".f16x2" => ast::FloatType::F16x2,
|
||||
".f32" => ast::FloatType::F32,
|
||||
".f64" => ast::FloatType::F64,
|
||||
FloatType: ast::ScalarType = {
|
||||
".f16" => ast::ScalarType::F16,
|
||||
".f16x2" => ast::ScalarType::F16x2,
|
||||
".f32" => ast::ScalarType::F32,
|
||||
".f64" => ast::ScalarType::F64,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
|
||||
|
@ -1023,11 +1023,11 @@ InstNot: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
"not" <t:BooleanType> <a:Arg2> => ast::Instruction::Not(t, a)
|
||||
};
|
||||
|
||||
BooleanType: ast::BooleanType = {
|
||||
".pred" => ast::BooleanType::Pred,
|
||||
".b16" => ast::BooleanType::B16,
|
||||
".b32" => ast::BooleanType::B32,
|
||||
".b64" => ast::BooleanType::B64,
|
||||
BooleanType: ast::ScalarType = {
|
||||
".pred" => ast::ScalarType::Pred,
|
||||
".b16" => ast::ScalarType::B16,
|
||||
".b32" => ast::ScalarType::B32,
|
||||
".b64" => ast::ScalarType::B64,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at
|
||||
|
@ -1080,8 +1080,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
rounding: r,
|
||||
flush_to_zero: None,
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F16,
|
||||
src: ast::FloatType::F16
|
||||
dst: ast::ScalarType::F16,
|
||||
src: ast::ScalarType::F16
|
||||
}
|
||||
), a)
|
||||
},
|
||||
|
@ -1091,8 +1091,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
rounding: None,
|
||||
flush_to_zero: Some(f.is_some()),
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F32,
|
||||
src: ast::FloatType::F16
|
||||
dst: ast::ScalarType::F32,
|
||||
src: ast::ScalarType::F16
|
||||
}
|
||||
), a)
|
||||
},
|
||||
|
@ -1102,8 +1102,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
rounding: None,
|
||||
flush_to_zero: None,
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F64,
|
||||
src: ast::FloatType::F16
|
||||
dst: ast::ScalarType::F64,
|
||||
src: ast::ScalarType::F16
|
||||
}
|
||||
), a)
|
||||
},
|
||||
|
@ -1113,8 +1113,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
rounding: Some(r),
|
||||
flush_to_zero: Some(f.is_some()),
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F16,
|
||||
src: ast::FloatType::F32
|
||||
dst: ast::ScalarType::F16,
|
||||
src: ast::ScalarType::F32
|
||||
}
|
||||
), a)
|
||||
},
|
||||
|
@ -1124,8 +1124,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
rounding: r,
|
||||
flush_to_zero: Some(f.is_some()),
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F32,
|
||||
src: ast::FloatType::F32
|
||||
dst: ast::ScalarType::F32,
|
||||
src: ast::ScalarType::F32
|
||||
}
|
||||
), a)
|
||||
},
|
||||
|
@ -1135,8 +1135,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
rounding: None,
|
||||
flush_to_zero: None,
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F64,
|
||||
src: ast::FloatType::F32
|
||||
dst: ast::ScalarType::F64,
|
||||
src: ast::ScalarType::F32
|
||||
}
|
||||
), a)
|
||||
},
|
||||
|
@ -1146,8 +1146,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
rounding: Some(r),
|
||||
flush_to_zero: None,
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F16,
|
||||
src: ast::FloatType::F64
|
||||
dst: ast::ScalarType::F16,
|
||||
src: ast::ScalarType::F64
|
||||
}
|
||||
), a)
|
||||
},
|
||||
|
@ -1157,8 +1157,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
rounding: Some(r),
|
||||
flush_to_zero: Some(s.is_some()),
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F32,
|
||||
src: ast::FloatType::F64
|
||||
dst: ast::ScalarType::F32,
|
||||
src: ast::ScalarType::F64
|
||||
}
|
||||
), a)
|
||||
},
|
||||
|
@ -1168,28 +1168,28 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
rounding: r,
|
||||
flush_to_zero: None,
|
||||
saturate: s.is_some(),
|
||||
dst: ast::FloatType::F64,
|
||||
src: ast::FloatType::F64
|
||||
dst: ast::ScalarType::F64,
|
||||
src: ast::ScalarType::F64
|
||||
}
|
||||
), a)
|
||||
},
|
||||
};
|
||||
|
||||
CvtTypeInt: ast::IntType = {
|
||||
".u8" => ast::IntType::U8,
|
||||
".u16" => ast::IntType::U16,
|
||||
".u32" => ast::IntType::U32,
|
||||
".u64" => ast::IntType::U64,
|
||||
".s8" => ast::IntType::S8,
|
||||
".s16" => ast::IntType::S16,
|
||||
".s32" => ast::IntType::S32,
|
||||
".s64" => ast::IntType::S64,
|
||||
CvtTypeInt: ast::ScalarType = {
|
||||
".u8" => ast::ScalarType::U8,
|
||||
".u16" => ast::ScalarType::U16,
|
||||
".u32" => ast::ScalarType::U32,
|
||||
".u64" => ast::ScalarType::U64,
|
||||
".s8" => ast::ScalarType::S8,
|
||||
".s16" => ast::ScalarType::S16,
|
||||
".s32" => ast::ScalarType::S32,
|
||||
".s64" => ast::ScalarType::S64,
|
||||
};
|
||||
|
||||
CvtTypeFloat: ast::FloatType = {
|
||||
".f16" => ast::FloatType::F16,
|
||||
".f32" => ast::FloatType::F32,
|
||||
".f64" => ast::FloatType::F64,
|
||||
CvtTypeFloat: ast::ScalarType = {
|
||||
".f16" => ast::ScalarType::F16,
|
||||
".f32" => ast::ScalarType::F32,
|
||||
".f64" => ast::ScalarType::F64,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
|
||||
|
@ -1197,10 +1197,10 @@ InstShl: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
"shl" <t:ShlType> <a:Arg3> => ast::Instruction::Shl(t, a)
|
||||
};
|
||||
|
||||
ShlType: ast::ShlType = {
|
||||
".b16" => ast::ShlType::B16,
|
||||
".b32" => ast::ShlType::B32,
|
||||
".b64" => ast::ShlType::B64,
|
||||
ShlType: ast::ScalarType = {
|
||||
".b16" => ast::ScalarType::B16,
|
||||
".b32" => ast::ScalarType::B32,
|
||||
".b64" => ast::ScalarType::B64,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr
|
||||
|
@ -1208,16 +1208,16 @@ InstShr: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
"shr" <t:ShrType> <a:Arg3> => ast::Instruction::Shr(t, a)
|
||||
};
|
||||
|
||||
ShrType: ast::ShrType = {
|
||||
".b16" => ast::ShrType::B16,
|
||||
".b32" => ast::ShrType::B32,
|
||||
".b64" => ast::ShrType::B64,
|
||||
".u16" => ast::ShrType::U16,
|
||||
".u32" => ast::ShrType::U32,
|
||||
".u64" => ast::ShrType::U64,
|
||||
".s16" => ast::ShrType::S16,
|
||||
".s32" => ast::ShrType::S32,
|
||||
".s64" => ast::ShrType::S64,
|
||||
ShrType: ast::ScalarType = {
|
||||
".b16" => ast::ScalarType::B16,
|
||||
".b32" => ast::ScalarType::B32,
|
||||
".b64" => ast::ScalarType::B64,
|
||||
".u16" => ast::ScalarType::U16,
|
||||
".u32" => ast::ScalarType::U32,
|
||||
".u64" => ast::ScalarType::U64,
|
||||
".s16" => ast::ScalarType::S16,
|
||||
".s32" => ast::ScalarType::S32,
|
||||
".s64" => ast::ScalarType::S64,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
|
||||
|
@ -1393,16 +1393,16 @@ MinMaxDetails: ast::MinMaxDetails = {
|
|||
<t:UIntType> => ast::MinMaxDetails::Unsigned(t),
|
||||
<t:SIntType> => ast::MinMaxDetails::Signed(t),
|
||||
<ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float(
|
||||
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F32 }
|
||||
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F32 }
|
||||
),
|
||||
".f64" => ast::MinMaxDetails::Float(
|
||||
ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::FloatType::F64 }
|
||||
ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::ScalarType::F64 }
|
||||
),
|
||||
<ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float(
|
||||
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16 }
|
||||
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16 }
|
||||
),
|
||||
<ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float(
|
||||
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16x2 }
|
||||
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16x2 }
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -1411,18 +1411,18 @@ InstSelp: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
"selp" <t:SelpType> <a:Arg4> => ast::Instruction::Selp(t, a),
|
||||
};
|
||||
|
||||
SelpType: ast::SelpType = {
|
||||
".b16" => ast::SelpType::B16,
|
||||
".b32" => ast::SelpType::B32,
|
||||
".b64" => ast::SelpType::B64,
|
||||
".u16" => ast::SelpType::U16,
|
||||
".u32" => ast::SelpType::U32,
|
||||
".u64" => ast::SelpType::U64,
|
||||
".s16" => ast::SelpType::S16,
|
||||
".s32" => ast::SelpType::S32,
|
||||
".s64" => ast::SelpType::S64,
|
||||
".f32" => ast::SelpType::F32,
|
||||
".f64" => ast::SelpType::F64,
|
||||
SelpType: ast::ScalarType = {
|
||||
".b16" => ast::ScalarType::B16,
|
||||
".b32" => ast::ScalarType::B32,
|
||||
".b64" => ast::ScalarType::B64,
|
||||
".u16" => ast::ScalarType::U16,
|
||||
".u32" => ast::ScalarType::U32,
|
||||
".u64" => ast::ScalarType::U64,
|
||||
".s16" => ast::ScalarType::S16,
|
||||
".s32" => ast::ScalarType::S32,
|
||||
".s64" => ast::ScalarType::S64,
|
||||
".f32" => ast::ScalarType::F32,
|
||||
".f64" => ast::ScalarType::F64,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar
|
||||
|
@ -1454,7 +1454,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
space: space.unwrap_or(ast::AtomSpace::Generic),
|
||||
inner: ast::AtomInnerDetails::Unsigned {
|
||||
op: ast::AtomUIntOp::Inc,
|
||||
typ: ast::UIntType::U32
|
||||
typ: ast::ScalarType::U32
|
||||
}
|
||||
};
|
||||
ast::Instruction::Atom(details,a)
|
||||
|
@ -1466,7 +1466,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
space: space.unwrap_or(ast::AtomSpace::Generic),
|
||||
inner: ast::AtomInnerDetails::Unsigned {
|
||||
op: ast::AtomUIntOp::Dec,
|
||||
typ: ast::UIntType::U32
|
||||
typ: ast::ScalarType::U32
|
||||
}
|
||||
};
|
||||
ast::Instruction::Atom(details,a)
|
||||
|
@ -1544,19 +1544,19 @@ AtomSIntOp: ast::AtomSIntOp = {
|
|||
".max" => ast::AtomSIntOp::Max,
|
||||
}
|
||||
|
||||
BitType: ast::BitType = {
|
||||
".b32" => ast::BitType::B32,
|
||||
".b64" => ast::BitType::B64,
|
||||
BitType: ast::ScalarType = {
|
||||
".b32" => ast::ScalarType::B32,
|
||||
".b64" => ast::ScalarType::B64,
|
||||
}
|
||||
|
||||
UIntType3264: ast::UIntType = {
|
||||
".u32" => ast::UIntType::U32,
|
||||
".u64" => ast::UIntType::U64,
|
||||
UIntType3264: ast::ScalarType = {
|
||||
".u32" => ast::ScalarType::U32,
|
||||
".u64" => ast::ScalarType::U64,
|
||||
}
|
||||
|
||||
SIntType3264: ast::SIntType = {
|
||||
".s32" => ast::SIntType::S32,
|
||||
".s64" => ast::SIntType::S64,
|
||||
SIntType3264: ast::ScalarType = {
|
||||
".s32" => ast::ScalarType::S32,
|
||||
".s64" => ast::ScalarType::S64,
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div
|
||||
|
@ -1566,7 +1566,7 @@ InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
"div" <t:SIntType> <a:Arg3> => ast::Instruction::Div(ast::DivDetails::Signed(t), a),
|
||||
"div" <kind:DivFloatKind> <ftz:".ftz"?> ".f32" <a:Arg3> => {
|
||||
let inner = ast::DivFloatDetails {
|
||||
typ: ast::FloatType::F32,
|
||||
typ: ast::ScalarType::F32,
|
||||
flush_to_zero: Some(ftz.is_some()),
|
||||
kind
|
||||
};
|
||||
|
@ -1574,7 +1574,7 @@ InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
},
|
||||
"div" <rnd:RoundingModeFloat> ".f64" <a:Arg3> => {
|
||||
let inner = ast::DivFloatDetails {
|
||||
typ: ast::FloatType::F64,
|
||||
typ: ast::ScalarType::F64,
|
||||
flush_to_zero: None,
|
||||
kind: ast::DivFloatKind::Rounding(rnd)
|
||||
};
|
||||
|
@ -1592,7 +1592,7 @@ DivFloatKind: ast::DivFloatKind = {
|
|||
InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"sqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => {
|
||||
let details = ast::SqrtDetails {
|
||||
typ: ast::FloatType::F32,
|
||||
typ: ast::ScalarType::F32,
|
||||
flush_to_zero: Some(ftz.is_some()),
|
||||
kind: ast::SqrtKind::Approx,
|
||||
};
|
||||
|
@ -1600,7 +1600,7 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
},
|
||||
"sqrt" <rnd:RoundingModeFloat> <ftz:".ftz"?> ".f32" <a:Arg2> => {
|
||||
let details = ast::SqrtDetails {
|
||||
typ: ast::FloatType::F32,
|
||||
typ: ast::ScalarType::F32,
|
||||
flush_to_zero: Some(ftz.is_some()),
|
||||
kind: ast::SqrtKind::Rounding(rnd),
|
||||
};
|
||||
|
@ -1608,7 +1608,7 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
},
|
||||
"sqrt" <rnd:RoundingModeFloat> ".f64" <a:Arg2> => {
|
||||
let details = ast::SqrtDetails {
|
||||
typ: ast::FloatType::F64,
|
||||
typ: ast::ScalarType::F64,
|
||||
flush_to_zero: None,
|
||||
kind: ast::SqrtKind::Rounding(rnd),
|
||||
};
|
||||
|
@ -1621,14 +1621,14 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"rsqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => {
|
||||
let details = ast::RsqrtDetails {
|
||||
typ: ast::FloatType::F32,
|
||||
typ: ast::ScalarType::F32,
|
||||
flush_to_zero: ftz.is_some(),
|
||||
};
|
||||
ast::Instruction::Rsqrt(details, a)
|
||||
},
|
||||
"rsqrt" ".approx" <ftz:".ftz"?> ".f64" <a:Arg2> => {
|
||||
let details = ast::RsqrtDetails {
|
||||
typ: ast::FloatType::F64,
|
||||
typ: ast::ScalarType::F64,
|
||||
flush_to_zero: ftz.is_some(),
|
||||
};
|
||||
ast::Instruction::Rsqrt(details, a)
|
||||
|
@ -1739,7 +1739,7 @@ ArithDetails: ast::ArithDetails = {
|
|||
saturate: false,
|
||||
}),
|
||||
".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt {
|
||||
typ: ast::SIntType::S32,
|
||||
typ: ast::ScalarType::S32,
|
||||
saturate: true,
|
||||
}),
|
||||
<f:ArithFloat> => ast::ArithDetails::Float(f)
|
||||
|
@ -1747,25 +1747,25 @@ ArithDetails: ast::ArithDetails = {
|
|||
|
||||
ArithFloat: ast::ArithFloat = {
|
||||
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
|
||||
typ: ast::FloatType::F32,
|
||||
typ: ast::ScalarType::F32,
|
||||
rounding: rn,
|
||||
flush_to_zero: Some(ftz.is_some()),
|
||||
saturate: sat.is_some(),
|
||||
},
|
||||
<rn:RoundingModeFloat?> ".f64" => ast::ArithFloat {
|
||||
typ: ast::FloatType::F64,
|
||||
typ: ast::ScalarType::F64,
|
||||
rounding: rn,
|
||||
flush_to_zero: None,
|
||||
saturate: false,
|
||||
},
|
||||
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
|
||||
typ: ast::FloatType::F16,
|
||||
typ: ast::ScalarType::F16,
|
||||
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
|
||||
flush_to_zero: Some(ftz.is_some()),
|
||||
saturate: sat.is_some(),
|
||||
},
|
||||
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
|
||||
typ: ast::FloatType::F16x2,
|
||||
typ: ast::ScalarType::F16x2,
|
||||
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
|
||||
flush_to_zero: Some(ftz.is_some()),
|
||||
saturate: sat.is_some(),
|
||||
|
@ -1774,25 +1774,25 @@ ArithFloat: ast::ArithFloat = {
|
|||
|
||||
ArithFloatMustRound: ast::ArithFloat = {
|
||||
<rn:RoundingModeFloat> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
|
||||
typ: ast::FloatType::F32,
|
||||
typ: ast::ScalarType::F32,
|
||||
rounding: Some(rn),
|
||||
flush_to_zero: Some(ftz.is_some()),
|
||||
saturate: sat.is_some(),
|
||||
},
|
||||
<rn:RoundingModeFloat> ".f64" => ast::ArithFloat {
|
||||
typ: ast::FloatType::F64,
|
||||
typ: ast::ScalarType::F64,
|
||||
rounding: Some(rn),
|
||||
flush_to_zero: None,
|
||||
saturate: false,
|
||||
},
|
||||
".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
|
||||
typ: ast::FloatType::F16,
|
||||
typ: ast::ScalarType::F16,
|
||||
rounding: Some(ast::RoundingMode::NearestEven),
|
||||
flush_to_zero: Some(ftz.is_some()),
|
||||
saturate: sat.is_some(),
|
||||
},
|
||||
".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
|
||||
typ: ast::FloatType::F16x2,
|
||||
typ: ast::ScalarType::F16x2,
|
||||
rounding: Some(ast::RoundingMode::NearestEven),
|
||||
flush_to_zero: Some(ftz.is_some()),
|
||||
saturate: sat.is_some(),
|
||||
|
|
|
@ -71,7 +71,7 @@
|
|||
%26 = OpLoad %uint %9
|
||||
%40 = OpCopyObject %uint %23
|
||||
%41 = OpCopyObject %uint %24
|
||||
%39 = OpFunctionCall %uint %44 %41 %40 %25 %26
|
||||
%39 = OpFunctionCall %uint %44 %40 %41 %25 %26
|
||||
%22 = OpCopyObject %uint %39
|
||||
OpStore %6 %22
|
||||
%27 = OpLoad %ulong %5
|
||||
|
|
|
@ -1553,10 +1553,9 @@ fn extract_globals<'input, 'b>(
|
|||
space,
|
||||
};
|
||||
let (op, typ) = match typ {
|
||||
ast::FloatType::F32 => ("add_f32", ast::SizedScalarType::F32),
|
||||
ast::FloatType::F64 => ("add_f64", ast::SizedScalarType::F64),
|
||||
ast::FloatType::F16 => unreachable!(),
|
||||
ast::FloatType::F16x2 => unreachable!(),
|
||||
ast::ScalarType::F32 => ("add_f32", ast::SizedScalarType::F32),
|
||||
ast::ScalarType::F64 => ("add_f64", ast::SizedScalarType::F64),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
local.push(to_ptx_impl_atomic_call(
|
||||
id_def,
|
||||
|
@ -1822,15 +1821,15 @@ fn to_ptx_impl_atomic_call(
|
|||
fn to_ptx_impl_bfe_call(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
typ: ast::IntType,
|
||||
typ: ast::ScalarType,
|
||||
arg: ast::Arg4<ExpandedArgParams>,
|
||||
) -> ExpandedStatement {
|
||||
let prefix = "__zluda_ptx_impl__";
|
||||
let suffix = match typ {
|
||||
ast::IntType::U32 => "bfe_u32",
|
||||
ast::IntType::U64 => "bfe_u64",
|
||||
ast::IntType::S32 => "bfe_s32",
|
||||
ast::IntType::S64 => "bfe_s64",
|
||||
ast::ScalarType::U32 => "bfe_u32",
|
||||
ast::ScalarType::U64 => "bfe_u64",
|
||||
ast::ScalarType::S32 => "bfe_s32",
|
||||
ast::ScalarType::S64 => "bfe_s64",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let fn_name = format!("{}{}", prefix, suffix);
|
||||
|
@ -1917,14 +1916,14 @@ fn to_ptx_impl_bfe_call(
|
|||
fn to_ptx_impl_bfi_call(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
typ: ast::BitType,
|
||||
typ: ast::ScalarType,
|
||||
arg: ast::Arg5<ExpandedArgParams>,
|
||||
) -> ExpandedStatement {
|
||||
let prefix = "__zluda_ptx_impl__";
|
||||
let suffix = match typ {
|
||||
ast::BitType::B32 => "bfi_b32",
|
||||
ast::BitType::B64 => "bfi_b64",
|
||||
ast::BitType::B8 | ast::BitType::B16 => unreachable!(),
|
||||
ast::ScalarType::B32 => "bfi_b32",
|
||||
ast::ScalarType::B64 => "bfi_b64",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let fn_name = format!("{}{}", prefix, suffix);
|
||||
let fn_id = match ptx_impl_imports.entry(fn_name) {
|
||||
|
@ -2506,29 +2505,32 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
|||
let (width, kind) = match add_type {
|
||||
ast::Type::Scalar(scalar_t) => {
|
||||
let kind = match scalar_t.kind() {
|
||||
kind @ ScalarKind::Bit
|
||||
| kind @ ScalarKind::Unsigned
|
||||
| kind @ ScalarKind::Signed => kind,
|
||||
ScalarKind::Float => return Err(TranslateError::MismatchedType),
|
||||
ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
|
||||
ScalarKind::Pred => return Err(TranslateError::MismatchedType),
|
||||
kind @ ast::ScalarKind::Bit
|
||||
| kind @ ast::ScalarKind::Unsigned
|
||||
| kind @ ast::ScalarKind::Signed => kind,
|
||||
ast::ScalarKind::Float => return Err(TranslateError::MismatchedType),
|
||||
ast::ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
|
||||
ast::ScalarKind::Pred => return Err(TranslateError::MismatchedType),
|
||||
};
|
||||
(scalar_t.size_of(), kind)
|
||||
}
|
||||
_ => return Err(TranslateError::MismatchedType),
|
||||
};
|
||||
let arith_detail = if kind == ScalarKind::Signed {
|
||||
let arith_detail = if kind == ast::ScalarKind::Signed {
|
||||
ast::ArithDetails::Signed(ast::ArithSInt {
|
||||
typ: ast::SIntType::from_size(width),
|
||||
typ: ast::ScalarType::from_parts(width, ast::ScalarKind::Signed),
|
||||
saturate: false,
|
||||
})
|
||||
} else {
|
||||
ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
|
||||
ast::ArithDetails::Unsigned(ast::ScalarType::from_parts(
|
||||
width,
|
||||
ast::ScalarKind::Unsigned,
|
||||
))
|
||||
};
|
||||
let id_constant_stmt = self.id_def.new_non_variable(add_type.clone());
|
||||
let result_id = self.id_def.new_non_variable(add_type);
|
||||
// TODO: check for edge cases around min value/max value/wrapping
|
||||
if offset < 0 && kind != ScalarKind::Signed {
|
||||
if offset < 0 && kind != ast::ScalarKind::Signed {
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: ast::ScalarType::from_parts(width, kind),
|
||||
|
@ -3026,18 +3028,18 @@ fn emit_function_body_ops(
|
|||
emit_setp(builder, map, setp, arg)?;
|
||||
}
|
||||
ast::Instruction::Not(t, a) => {
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(*t));
|
||||
let result_id = Some(a.dst);
|
||||
let operand = a.src;
|
||||
match t {
|
||||
ast::BooleanType::Pred => {
|
||||
ast::ScalarType::Pred => {
|
||||
logical_not(builder, result_type, result_id, operand)
|
||||
}
|
||||
_ => builder.not(result_type, result_id, operand),
|
||||
}?;
|
||||
}
|
||||
ast::Instruction::Shl(t, a) => {
|
||||
let full_type = t.to_type();
|
||||
let full_type = ast::Type::Scalar(*t);
|
||||
let size_of = full_type.size_of();
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(full_type));
|
||||
let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?;
|
||||
|
@ -3048,7 +3050,7 @@ fn emit_function_body_ops(
|
|||
let size_of = full_type.size_of();
|
||||
let result_type = map.get_or_add_scalar(builder, full_type);
|
||||
let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?;
|
||||
if t.signed() {
|
||||
if t.kind() == ast::ScalarKind::Signed {
|
||||
builder.shift_right_arithmetic(
|
||||
result_type,
|
||||
Some(a.dst),
|
||||
|
@ -3088,7 +3090,7 @@ fn emit_function_body_ops(
|
|||
},
|
||||
ast::Instruction::Or(t, a) => {
|
||||
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
|
||||
if *t == ast::BooleanType::Pred {
|
||||
if *t == ast::ScalarType::Pred {
|
||||
builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?;
|
||||
} else {
|
||||
builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
|
||||
|
@ -3116,7 +3118,7 @@ fn emit_function_body_ops(
|
|||
}
|
||||
ast::Instruction::And(t, a) => {
|
||||
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
|
||||
if *t == ast::BooleanType::Pred {
|
||||
if *t == ast::ScalarType::Pred {
|
||||
builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?;
|
||||
} else {
|
||||
builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?;
|
||||
|
@ -3202,7 +3204,7 @@ fn emit_function_body_ops(
|
|||
}
|
||||
ast::Instruction::Neg(details, arg) => {
|
||||
let result_type = map.get_or_add_scalar(builder, details.typ);
|
||||
let negate_func = if details.typ.kind() == ScalarKind::Float {
|
||||
let negate_func = if details.typ.kind() == ast::ScalarKind::Float {
|
||||
dr::Builder::f_negate
|
||||
} else {
|
||||
dr::Builder::s_negate
|
||||
|
@ -3269,7 +3271,7 @@ fn emit_function_body_ops(
|
|||
}
|
||||
ast::Instruction::Xor { typ, arg } => {
|
||||
let builder_fn = match typ {
|
||||
ast::BooleanType::Pred => emit_logical_xor_spirv,
|
||||
ast::ScalarType::Pred => emit_logical_xor_spirv,
|
||||
_ => dr::Builder::bitwise_xor,
|
||||
};
|
||||
let result_type = map.get_or_add_scalar(builder, (*typ).into());
|
||||
|
@ -3284,7 +3286,7 @@ fn emit_function_body_ops(
|
|||
return Err(error_unreachable());
|
||||
}
|
||||
ast::Instruction::Rem { typ, arg } => {
|
||||
let builder_fn = if typ.is_signed() {
|
||||
let builder_fn = if typ.kind() == ast::ScalarKind::Signed {
|
||||
dr::Builder::s_mod
|
||||
} else {
|
||||
dr::Builder::u_mod
|
||||
|
@ -3882,7 +3884,7 @@ fn emit_cvt(
|
|||
}
|
||||
let dest_t: ast::ScalarType = desc.dst.into();
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
|
||||
if desc.src.is_signed() {
|
||||
if desc.src.kind() == ast::ScalarKind::Signed {
|
||||
builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
|
||||
} else {
|
||||
builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?;
|
||||
|
@ -3892,7 +3894,7 @@ fn emit_cvt(
|
|||
ast::CvtDetails::IntFromFloat(desc) => {
|
||||
let dest_t: ast::ScalarType = desc.dst.into();
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
|
||||
if desc.dst.is_signed() {
|
||||
if desc.dst.kind() == ast::ScalarKind::Signed {
|
||||
builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?;
|
||||
} else {
|
||||
builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?;
|
||||
|
@ -3904,7 +3906,7 @@ fn emit_cvt(
|
|||
let dest_t: ast::ScalarType = desc.dst.into();
|
||||
let src_t: ast::ScalarType = desc.src.into();
|
||||
// first do shortening/widening
|
||||
let src = if desc.dst.width() != desc.src.width() {
|
||||
let src = if desc.dst.size_of() != desc.src.size_of() {
|
||||
let new_dst = if dest_t.kind() == src_t.kind() {
|
||||
arg.dst
|
||||
} else {
|
||||
|
@ -3933,7 +3935,7 @@ fn emit_cvt(
|
|||
// now do actual conversion
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
|
||||
if desc.saturate {
|
||||
if desc.dst.is_signed() {
|
||||
if desc.dst.kind() == ast::ScalarKind::Signed {
|
||||
builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?;
|
||||
} else {
|
||||
builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?;
|
||||
|
@ -3989,60 +3991,60 @@ fn emit_setp(
|
|||
let operand_1 = arg.src1;
|
||||
let operand_2 = arg.src2;
|
||||
match (setp.cmp_op, setp.typ.kind()) {
|
||||
(ast::SetpCompareOp::Eq, ScalarKind::Signed)
|
||||
| (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::Eq, ScalarKind::Bit) => {
|
||||
(ast::SetpCompareOp::Eq, ast::ScalarKind::Signed)
|
||||
| (ast::SetpCompareOp::Eq, ast::ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::Eq, ast::ScalarKind::Bit) => {
|
||||
builder.i_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Eq, ScalarKind::Float) => {
|
||||
(ast::SetpCompareOp::Eq, ast::ScalarKind::Float) => {
|
||||
builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::NotEq, ScalarKind::Signed)
|
||||
| (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => {
|
||||
(ast::SetpCompareOp::NotEq, ast::ScalarKind::Signed)
|
||||
| (ast::SetpCompareOp::NotEq, ast::ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::NotEq, ast::ScalarKind::Bit) => {
|
||||
builder.i_not_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
|
||||
(ast::SetpCompareOp::NotEq, ast::ScalarKind::Float) => {
|
||||
builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Less, ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::Less, ScalarKind::Bit) => {
|
||||
(ast::SetpCompareOp::Less, ast::ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::Less, ast::ScalarKind::Bit) => {
|
||||
builder.u_less_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Less, ScalarKind::Signed) => {
|
||||
(ast::SetpCompareOp::Less, ast::ScalarKind::Signed) => {
|
||||
builder.s_less_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Less, ScalarKind::Float) => {
|
||||
(ast::SetpCompareOp::Less, ast::ScalarKind::Float) => {
|
||||
builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => {
|
||||
(ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Bit) => {
|
||||
builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
|
||||
(ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Signed) => {
|
||||
builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => {
|
||||
(ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Float) => {
|
||||
builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::Greater, ScalarKind::Bit) => {
|
||||
(ast::SetpCompareOp::Greater, ast::ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::Greater, ast::ScalarKind::Bit) => {
|
||||
builder.u_greater_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
|
||||
(ast::SetpCompareOp::Greater, ast::ScalarKind::Signed) => {
|
||||
builder.s_greater_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Greater, ScalarKind::Float) => {
|
||||
(ast::SetpCompareOp::Greater, ast::ScalarKind::Float) => {
|
||||
builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => {
|
||||
(ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Bit) => {
|
||||
builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
|
||||
(ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Signed) => {
|
||||
builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
|
||||
(ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Float) => {
|
||||
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::NanEq, _) => {
|
||||
|
@ -4222,7 +4224,7 @@ fn emit_abs(
|
|||
) -> Result<(), dr::Error> {
|
||||
let scalar_t = ast::ScalarType::from(d.typ);
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(scalar_t));
|
||||
let cl_abs = if scalar_t.kind() == ScalarKind::Signed {
|
||||
let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed {
|
||||
spirv::CLOp::s_abs
|
||||
} else {
|
||||
spirv::CLOp::fabs
|
||||
|
@ -4286,8 +4288,8 @@ fn emit_implicit_conversion(
|
|||
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
|
||||
if from_parts.width == to_parts.width {
|
||||
let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
|
||||
if from_parts.scalar_kind != ScalarKind::Float
|
||||
&& to_parts.scalar_kind != ScalarKind::Float
|
||||
if from_parts.scalar_kind != ast::ScalarKind::Float
|
||||
&& to_parts.scalar_kind != ast::ScalarKind::Float
|
||||
{
|
||||
// It is noop, but another instruction expects result of this conversion
|
||||
builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
|
||||
|
@ -4299,24 +4301,24 @@ fn emit_implicit_conversion(
|
|||
let same_width_bit_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::from(ast::Type::from_parts(TypeParts {
|
||||
scalar_kind: ScalarKind::Bit,
|
||||
scalar_kind: ast::ScalarKind::Bit,
|
||||
..from_parts
|
||||
})),
|
||||
);
|
||||
let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
|
||||
let wide_bit_type = ast::Type::from_parts(TypeParts {
|
||||
scalar_kind: ScalarKind::Bit,
|
||||
scalar_kind: ast::ScalarKind::Bit,
|
||||
..to_parts
|
||||
});
|
||||
let wide_bit_type_spirv =
|
||||
map.get_or_add(builder, SpirvType::from(wide_bit_type.clone()));
|
||||
if to_parts.scalar_kind == ScalarKind::Unsigned
|
||||
|| to_parts.scalar_kind == ScalarKind::Bit
|
||||
if to_parts.scalar_kind == ast::ScalarKind::Unsigned
|
||||
|| to_parts.scalar_kind == ast::ScalarKind::Bit
|
||||
{
|
||||
builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?;
|
||||
} else {
|
||||
let conversion_fn = if from_parts.scalar_kind == ScalarKind::Signed
|
||||
&& to_parts.scalar_kind == ScalarKind::Signed
|
||||
let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed
|
||||
&& to_parts.scalar_kind == ast::ScalarKind::Signed
|
||||
{
|
||||
dr::Builder::s_convert
|
||||
} else {
|
||||
|
@ -4614,23 +4616,23 @@ fn convert_to_stateful_memory_access<'a>(
|
|||
for statement in func_body.iter() {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Add(
|
||||
ast::ArithDetails::Unsigned(ast::UIntType::U64),
|
||||
ast::ArithDetails::Unsigned(ast::ScalarType::U64),
|
||||
arg,
|
||||
))
|
||||
| Statement::Instruction(ast::Instruction::Add(
|
||||
ast::ArithDetails::Signed(ast::ArithSInt {
|
||||
typ: ast::SIntType::S64,
|
||||
typ: ast::ScalarType::S64,
|
||||
saturate: false,
|
||||
}),
|
||||
arg,
|
||||
))
|
||||
| Statement::Instruction(ast::Instruction::Sub(
|
||||
ast::ArithDetails::Unsigned(ast::UIntType::U64),
|
||||
ast::ArithDetails::Unsigned(ast::ScalarType::U64),
|
||||
arg,
|
||||
))
|
||||
| Statement::Instruction(ast::Instruction::Sub(
|
||||
ast::ArithDetails::Signed(ast::ArithSInt {
|
||||
typ: ast::SIntType::S64,
|
||||
typ: ast::ScalarType::S64,
|
||||
saturate: false,
|
||||
}),
|
||||
arg,
|
||||
|
@ -4686,12 +4688,12 @@ fn convert_to_stateful_memory_access<'a>(
|
|||
}
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Add(
|
||||
ast::ArithDetails::Unsigned(ast::UIntType::U64),
|
||||
ast::ArithDetails::Unsigned(ast::ScalarType::U64),
|
||||
arg,
|
||||
))
|
||||
| Statement::Instruction(ast::Instruction::Add(
|
||||
ast::ArithDetails::Signed(ast::ArithSInt {
|
||||
typ: ast::SIntType::S64,
|
||||
typ: ast::ScalarType::S64,
|
||||
saturate: false,
|
||||
}),
|
||||
arg,
|
||||
|
@ -4715,12 +4717,12 @@ fn convert_to_stateful_memory_access<'a>(
|
|||
}))
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Sub(
|
||||
ast::ArithDetails::Unsigned(ast::UIntType::U64),
|
||||
ast::ArithDetails::Unsigned(ast::ScalarType::U64),
|
||||
arg,
|
||||
))
|
||||
| Statement::Instruction(ast::Instruction::Sub(
|
||||
ast::ArithDetails::Signed(ast::ArithSInt {
|
||||
typ: ast::SIntType::S64,
|
||||
typ: ast::ScalarType::S64,
|
||||
saturate: false,
|
||||
}),
|
||||
arg,
|
||||
|
@ -4867,7 +4869,7 @@ fn convert_to_stateful_memory_access_postprocess(
|
|||
ast::LdStateSpace::Global,
|
||||
),
|
||||
to: old_type,
|
||||
kind: ConversionKind::PtrToBit(ast::UIntType::U64),
|
||||
kind: ConversionKind::PtrToBit(ast::ScalarType::U64),
|
||||
src_sema: arg_desc.sema,
|
||||
dst_sema: ArgumentSemantics::Default,
|
||||
}));
|
||||
|
@ -5903,7 +5905,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||
let inst_type = d.typ;
|
||||
ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
|
||||
}
|
||||
ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?),
|
||||
ast::Instruction::Not(t, a) => {
|
||||
ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?)
|
||||
}
|
||||
ast::Instruction::Cvt(d, a) => {
|
||||
let (dst_t, src_t) = match &d {
|
||||
ast::CvtDetails::FloatFromFloat(desc) => (
|
||||
|
@ -5926,7 +5930,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||
ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?)
|
||||
}
|
||||
ast::Instruction::Shl(t, a) => {
|
||||
ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?)
|
||||
ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?)
|
||||
}
|
||||
ast::Instruction::Shr(t, a) => {
|
||||
ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?)
|
||||
|
@ -6176,9 +6180,9 @@ impl ast::Type {
|
|||
ast::Type::Scalar(scalar) => {
|
||||
let kind = scalar.kind();
|
||||
let width = scalar.size_of();
|
||||
if (kind != ScalarKind::Signed
|
||||
&& kind != ScalarKind::Unsigned
|
||||
&& kind != ScalarKind::Bit)
|
||||
if (kind != ast::ScalarKind::Signed
|
||||
&& kind != ast::ScalarKind::Unsigned
|
||||
&& kind != ast::ScalarKind::Bit)
|
||||
|| (width == 8)
|
||||
{
|
||||
return Err(TranslateError::MismatchedType);
|
||||
|
@ -6306,7 +6310,7 @@ impl ast::Type {
|
|||
#[derive(Eq, PartialEq, Clone)]
|
||||
struct TypeParts {
|
||||
kind: TypeKind,
|
||||
scalar_kind: ScalarKind,
|
||||
scalar_kind: ast::ScalarKind,
|
||||
width: u8,
|
||||
components: Vec<u32>,
|
||||
state_space: ast::LdStateSpace,
|
||||
|
@ -6461,7 +6465,7 @@ enum ConversionKind {
|
|||
// zero-extend/chop/bitcast depending on types
|
||||
SignExtend,
|
||||
BitToPtr(ast::LdStateSpace),
|
||||
PtrToBit(ast::UIntType),
|
||||
PtrToBit(ast::ScalarType),
|
||||
PtrToPtr { spirv_ptr: bool },
|
||||
}
|
||||
|
||||
|
@ -6859,7 +6863,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
|
|||
fn map_selp<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
||||
self,
|
||||
visitor: &mut V,
|
||||
t: ast::SelpType,
|
||||
t: ast::ScalarType,
|
||||
) -> Result<ast::Arg4<U>, TranslateError> {
|
||||
let dst = visitor.operand(
|
||||
ArgumentDescriptor {
|
||||
|
@ -6904,7 +6908,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
|
|||
fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
||||
self,
|
||||
visitor: &mut V,
|
||||
t: ast::BitType,
|
||||
t: ast::ScalarType,
|
||||
state_space: ast::AtomSpace,
|
||||
) -> Result<ast::Arg4<U>, TranslateError> {
|
||||
let scalar_type = ast::ScalarType::from(t);
|
||||
|
@ -7205,103 +7209,41 @@ impl ast::StStateSpace {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
enum ScalarKind {
|
||||
Bit,
|
||||
Unsigned,
|
||||
Signed,
|
||||
Float,
|
||||
Float2,
|
||||
Pred,
|
||||
}
|
||||
|
||||
impl ast::ScalarType {
|
||||
fn kind(self) -> ScalarKind {
|
||||
match self {
|
||||
ast::ScalarType::U8 => ScalarKind::Unsigned,
|
||||
ast::ScalarType::U16 => ScalarKind::Unsigned,
|
||||
ast::ScalarType::U32 => ScalarKind::Unsigned,
|
||||
ast::ScalarType::U64 => ScalarKind::Unsigned,
|
||||
ast::ScalarType::S8 => ScalarKind::Signed,
|
||||
ast::ScalarType::S16 => ScalarKind::Signed,
|
||||
ast::ScalarType::S32 => ScalarKind::Signed,
|
||||
ast::ScalarType::S64 => ScalarKind::Signed,
|
||||
ast::ScalarType::B8 => ScalarKind::Bit,
|
||||
ast::ScalarType::B16 => ScalarKind::Bit,
|
||||
ast::ScalarType::B32 => ScalarKind::Bit,
|
||||
ast::ScalarType::B64 => ScalarKind::Bit,
|
||||
ast::ScalarType::F16 => ScalarKind::Float,
|
||||
ast::ScalarType::F32 => ScalarKind::Float,
|
||||
ast::ScalarType::F64 => ScalarKind::Float,
|
||||
ast::ScalarType::F16x2 => ScalarKind::Float2,
|
||||
ast::ScalarType::Pred => ScalarKind::Pred,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_parts(width: u8, kind: ScalarKind) -> Self {
|
||||
fn from_parts(width: u8, kind: ast::ScalarKind) -> Self {
|
||||
match kind {
|
||||
ScalarKind::Float => match width {
|
||||
ast::ScalarKind::Float => match width {
|
||||
2 => ast::ScalarType::F16,
|
||||
4 => ast::ScalarType::F32,
|
||||
8 => ast::ScalarType::F64,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
ScalarKind::Bit => match width {
|
||||
ast::ScalarKind::Bit => match width {
|
||||
1 => ast::ScalarType::B8,
|
||||
2 => ast::ScalarType::B16,
|
||||
4 => ast::ScalarType::B32,
|
||||
8 => ast::ScalarType::B64,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
ScalarKind::Signed => match width {
|
||||
ast::ScalarKind::Signed => match width {
|
||||
1 => ast::ScalarType::S8,
|
||||
2 => ast::ScalarType::S16,
|
||||
4 => ast::ScalarType::S32,
|
||||
8 => ast::ScalarType::S64,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
ScalarKind::Unsigned => match width {
|
||||
ast::ScalarKind::Unsigned => match width {
|
||||
1 => ast::ScalarType::U8,
|
||||
2 => ast::ScalarType::U16,
|
||||
4 => ast::ScalarType::U32,
|
||||
8 => ast::ScalarType::U64,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
ScalarKind::Float2 => match width {
|
||||
ast::ScalarKind::Float2 => match width {
|
||||
4 => ast::ScalarType::F16x2,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
ScalarKind::Pred => ast::ScalarType::Pred,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::BooleanType {
|
||||
fn to_type(self) -> ast::Type {
|
||||
match self {
|
||||
ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
|
||||
ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
|
||||
ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
|
||||
ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::ShlType {
|
||||
fn to_type(self) -> ast::Type {
|
||||
match self {
|
||||
ast::ShlType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
|
||||
ast::ShlType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
|
||||
ast::ShlType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::ShrType {
|
||||
fn signed(&self) -> bool {
|
||||
match self {
|
||||
ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true,
|
||||
_ => false,
|
||||
ast::ScalarKind::Pred => ast::ScalarType::Pred,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7357,30 +7299,6 @@ impl ast::AtomInnerDetails {
|
|||
}
|
||||
}
|
||||
|
||||
impl ast::SIntType {
|
||||
fn from_size(width: u8) -> Self {
|
||||
match width {
|
||||
1 => ast::SIntType::S8,
|
||||
2 => ast::SIntType::S16,
|
||||
4 => ast::SIntType::S32,
|
||||
8 => ast::SIntType::S64,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::UIntType {
|
||||
fn from_size(width: u8) -> Self {
|
||||
match width {
|
||||
1 => ast::UIntType::U8,
|
||||
2 => ast::UIntType::U16,
|
||||
4 => ast::UIntType::U32,
|
||||
8 => ast::UIntType::U64,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::LdStateSpace {
|
||||
fn to_spirv(self) -> spirv::StorageClass {
|
||||
match self {
|
||||
|
@ -7568,16 +7486,18 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
|
|||
return false;
|
||||
}
|
||||
match inst.kind() {
|
||||
ScalarKind::Bit => operand.kind() != ScalarKind::Bit,
|
||||
ScalarKind::Float => operand.kind() == ScalarKind::Bit,
|
||||
ScalarKind::Signed => {
|
||||
operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned
|
||||
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
|
||||
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
|
||||
ast::ScalarKind::Signed => {
|
||||
operand.kind() == ast::ScalarKind::Bit
|
||||
|| operand.kind() == ast::ScalarKind::Unsigned
|
||||
}
|
||||
ScalarKind::Unsigned => {
|
||||
operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed
|
||||
ast::ScalarKind::Unsigned => {
|
||||
operand.kind() == ast::ScalarKind::Bit
|
||||
|| operand.kind() == ast::ScalarKind::Signed
|
||||
}
|
||||
ScalarKind::Float2 => false,
|
||||
ScalarKind::Pred => false,
|
||||
ast::ScalarKind::Float2 => false,
|
||||
ast::ScalarKind::Pred => false,
|
||||
}
|
||||
}
|
||||
(ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
|
||||
|
@ -7596,7 +7516,7 @@ fn should_bitcast_packed(
|
|||
if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
|
||||
(operand, instr)
|
||||
{
|
||||
if scalar.kind() == ScalarKind::Bit
|
||||
if scalar.kind() == ast::ScalarKind::Bit
|
||||
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
|
||||
{
|
||||
return Ok(Some(ConversionKind::Default));
|
||||
|
@ -7644,32 +7564,33 @@ fn should_convert_relaxed_src(
|
|||
}
|
||||
match (src_type, instr_type) {
|
||||
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||
ScalarKind::Bit => {
|
||||
ast::ScalarKind::Bit => {
|
||||
if instr_type.size_of() <= src_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Signed | ScalarKind::Unsigned => {
|
||||
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
|
||||
if instr_type.size_of() <= src_type.size_of()
|
||||
&& src_type.kind() != ScalarKind::Float
|
||||
&& src_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Float => {
|
||||
if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit
|
||||
ast::ScalarKind::Float => {
|
||||
if instr_type.size_of() <= src_type.size_of()
|
||||
&& src_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Float2 => todo!(),
|
||||
ScalarKind::Pred => None,
|
||||
ast::ScalarKind::Float2 => todo!(),
|
||||
ast::ScalarKind::Pred => None,
|
||||
},
|
||||
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
|
||||
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
|
||||
|
@ -7706,15 +7627,15 @@ fn should_convert_relaxed_dst(
|
|||
}
|
||||
match (dst_type, instr_type) {
|
||||
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||
ScalarKind::Bit => {
|
||||
ast::ScalarKind::Bit => {
|
||||
if instr_type.size_of() <= dst_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Signed => {
|
||||
if dst_type.kind() != ScalarKind::Float {
|
||||
ast::ScalarKind::Signed => {
|
||||
if dst_type.kind() != ast::ScalarKind::Float {
|
||||
if instr_type.size_of() == dst_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else if instr_type.size_of() < dst_type.size_of() {
|
||||
|
@ -7726,25 +7647,26 @@ fn should_convert_relaxed_dst(
|
|||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Unsigned => {
|
||||
ast::ScalarKind::Unsigned => {
|
||||
if instr_type.size_of() <= dst_type.size_of()
|
||||
&& dst_type.kind() != ScalarKind::Float
|
||||
&& dst_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Float => {
|
||||
if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit
|
||||
ast::ScalarKind::Float => {
|
||||
if instr_type.size_of() <= dst_type.size_of()
|
||||
&& dst_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Float2 => todo!(),
|
||||
ScalarKind::Pred => None,
|
||||
ast::ScalarKind::Float2 => todo!(),
|
||||
ast::ScalarKind::Pred => None,
|
||||
},
|
||||
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
|
||||
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
|
||||
|
|
Loading…
Add table
Reference in a new issue