From a0baad94562bf305c0a3f478c00848c5982a7a05 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 15 Apr 2021 19:10:45 +0200 Subject: [PATCH 01/25] Convert enumes to 1TT --- ptx/src/ast.rs | 224 ++++++++------------ ptx/src/ptx.lalrpop | 236 ++++++++++----------- ptx/src/test/spirv_run/bfi.spvtxt | 2 +- ptx/src/translate.rs | 332 ++++++++++++------------------ 4 files changed, 332 insertions(+), 462 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 1c6d2fb..bc2fa4c 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -210,20 +210,6 @@ sub_enum!(LdStScalarType { F64, }); -sub_enum!(SelpType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, - F32, - F64, -}); - #[derive(Copy, Clone, Eq, PartialEq)] pub enum BarDetails { SyncAligned, @@ -425,52 +411,6 @@ pub enum ScalarType { Pred, } -sub_enum!(IntType { - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64 -}); - -sub_enum!(BitType { B8, B16, B32, B64 }); - -sub_enum!(UIntType { U8, U16, U32, U64 }); - -sub_enum!(SIntType { S8, S16, S32, S64 }); - -impl IntType { - pub fn is_signed(self) -> bool { - match self { - IntType::U8 | IntType::U16 | IntType::U32 | IntType::U64 => false, - IntType::S8 | IntType::S16 | IntType::S32 | IntType::S64 => true, - } - } - - pub fn width(self) -> u8 { - match self { - IntType::U8 => 1, - IntType::U16 => 2, - IntType::U32 => 4, - IntType::U64 => 8, - IntType::S8 => 1, - IntType::S16 => 2, - IntType::S32 => 4, - IntType::S64 => 8, - } - } -} - -sub_enum!(FloatType { - F16, - F16x2, - F32, - F64 -}); - impl ScalarType { pub fn size_of(self) -> u8 { match self { @@ -576,24 +516,24 @@ pub enum Instruction { Add(ArithDetails, Arg3

), Setp(SetpData, Arg4Setp

), SetpBool(SetpBoolData, Arg5Setp

), - Not(BooleanType, Arg2

), + Not(ScalarType, Arg2

), Bra(BraData, Arg1

), Cvt(CvtDetails, Arg2

), Cvta(CvtaDetails, Arg2

), - Shl(ShlType, Arg3

), - Shr(ShrType, Arg3

), + Shl(ScalarType, Arg3

), + Shr(ScalarType, Arg3

), St(StData, Arg2St

), Ret(RetData), Call(CallInst

), Abs(AbsDetails, Arg2

), Mad(MulDetails, Arg4

), - Or(BooleanType, Arg3

), + Or(ScalarType, Arg3

), Sub(ArithDetails, Arg3

), Min(MinMaxDetails, Arg3

), Max(MinMaxDetails, Arg3

), Rcp(RcpDetails, Arg2

), - And(BooleanType, Arg3

), - Selp(SelpType, Arg4

), + And(ScalarType, Arg3

), + Selp(ScalarType, Arg4

), Bar(BarDetails, Arg1Bar

), Atom(AtomDetails, Arg3

), AtomCas(AtomCasDetails, Arg4

), @@ -605,13 +545,13 @@ pub enum Instruction { Cos { flush_to_zero: bool, arg: Arg2

}, Lg2 { flush_to_zero: bool, arg: Arg2

}, Ex2 { flush_to_zero: bool, arg: Arg2

}, - Clz { typ: BitType, arg: Arg2

}, - Brev { typ: BitType, arg: Arg2

}, - Popc { typ: BitType, arg: Arg2

}, - Xor { typ: BooleanType, arg: Arg3

}, - Bfe { typ: IntType, arg: Arg4

}, - Bfi { typ: BitType, arg: Arg5

}, - Rem { typ: IntType, arg: Arg3

}, + Clz { typ: ScalarType, arg: Arg2

}, + Brev { typ: ScalarType, arg: Arg2

}, + Popc { typ: ScalarType, arg: Arg2

}, + Xor { typ: ScalarType, arg: Arg3

}, + Bfe { typ: ScalarType, arg: Arg4

}, + Bfi { typ: ScalarType, arg: Arg5

}, + Rem { typ: ScalarType, arg: Arg3

}, } #[derive(Copy, Clone)] @@ -825,7 +765,7 @@ impl MovDetails { #[derive(Copy, Clone)] pub struct MulIntDesc { - pub typ: IntType, + pub typ: ScalarType, pub control: MulIntControl, } @@ -845,7 +785,7 @@ pub enum RoundingMode { } pub struct AddIntDesc { - pub typ: IntType, + pub typ: ScalarType, pub saturate: bool, } @@ -892,39 +832,39 @@ pub struct BraData { pub enum CvtDetails { IntFromInt(CvtIntToIntDesc), - FloatFromFloat(CvtDesc), - IntFromFloat(CvtDesc), - FloatFromInt(CvtDesc), + FloatFromFloat(CvtDesc), + IntFromFloat(CvtDesc), + FloatFromInt(CvtDesc), } pub struct CvtIntToIntDesc { - pub dst: IntType, - pub src: IntType, + pub dst: ScalarType, + pub src: ScalarType, pub saturate: bool, } -pub struct CvtDesc { +pub struct CvtDesc { pub rounding: Option, pub flush_to_zero: Option, pub saturate: bool, - pub dst: Dst, - pub src: Src, + pub dst: ScalarType, + pub src: ScalarType, } impl CvtDetails { pub fn new_int_from_int_checked<'err, 'input>( saturate: bool, - dst: IntType, - src: IntType, + dst: ScalarType, + src: ScalarType, err: &'err mut Vec, PtxError>>, ) -> Self { if saturate { - if src.is_signed() { - if dst.is_signed() && dst.width() >= src.width() { + if src.kind() == ScalarKind::Signed { + if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() { err.push(ParseError::from(PtxError::SyntaxError)); } } else { - if dst == src || dst.width() >= src.width() { + if dst == src || dst.size_of() >= src.size_of() { err.push(ParseError::from(PtxError::SyntaxError)); } } @@ -936,11 +876,11 @@ impl CvtDetails { rounding: RoundingMode, flush_to_zero: bool, saturate: bool, - dst: FloatType, - src: IntType, + dst: ScalarType, + src: ScalarType, err: &'err mut Vec, PtxError>>, ) -> Self { - if flush_to_zero && dst != FloatType::F32 { + if flush_to_zero && dst != ScalarType::F32 { err.push(ParseError::from(PtxError::NonF32Ftz)); } CvtDetails::FloatFromInt(CvtDesc { @@ -956,11 +896,11 @@ impl CvtDetails { rounding: RoundingMode, flush_to_zero: bool, saturate: bool, - dst: IntType, - src: FloatType, + dst: ScalarType, + src: ScalarType, err: &'err mut Vec, PtxError>>, ) -> Self { - if flush_to_zero && src != FloatType::F32 { + if flush_to_zero && src != ScalarType::F32 { err.push(ParseError::from(PtxError::NonF32Ftz)); } CvtDetails::IntFromFloat(CvtDesc { @@ -993,25 +933,6 @@ pub enum CvtaSize { U64, } -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum ShlType { - B16, - B32, - B64, -} - -sub_enum!(ShrType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, -}); - pub struct StData { pub qualifier: LdStQualifier, pub state_space: StStateSpace, @@ -1040,13 +961,6 @@ pub struct RetData { pub uniform: bool, } -sub_enum!(BooleanType { - Pred, - B16, - B32, - B64, -}); - #[derive(Copy, Clone)] pub enum MulDetails { Unsigned(MulUInt), @@ -1056,32 +970,32 @@ pub enum MulDetails { #[derive(Copy, Clone)] pub struct MulUInt { - pub typ: UIntType, + pub typ: ScalarType, pub control: MulIntControl, } #[derive(Copy, Clone)] pub struct MulSInt { - pub typ: SIntType, + pub typ: ScalarType, pub control: MulIntControl, } #[derive(Copy, Clone)] pub enum ArithDetails { - Unsigned(UIntType), + Unsigned(ScalarType), Signed(ArithSInt), Float(ArithFloat), } #[derive(Copy, Clone)] pub struct ArithSInt { - pub typ: SIntType, + pub typ: ScalarType, pub saturate: bool, } #[derive(Copy, Clone)] pub struct ArithFloat { - pub typ: FloatType, + pub typ: ScalarType, pub rounding: Option, pub flush_to_zero: Option, pub saturate: bool, @@ -1089,8 +1003,8 @@ pub struct ArithFloat { #[derive(Copy, Clone)] pub enum MinMaxDetails { - Signed(SIntType), - Unsigned(UIntType), + Signed(ScalarType), + Unsigned(ScalarType), Float(MinMaxFloat), } @@ -1098,7 +1012,7 @@ pub enum MinMaxDetails { pub struct MinMaxFloat { pub flush_to_zero: Option, pub nan: bool, - pub typ: FloatType, + pub typ: ScalarType, } #[derive(Copy, Clone)] @@ -1126,10 +1040,10 @@ pub enum AtomSpace { #[derive(Copy, Clone)] pub enum AtomInnerDetails { - Bit { op: AtomBitOp, typ: BitType }, - Unsigned { op: AtomUIntOp, typ: UIntType }, - Signed { op: AtomSIntOp, typ: SIntType }, - Float { op: AtomFloatOp, typ: FloatType }, + Bit { op: AtomBitOp, typ: ScalarType }, + Unsigned { op: AtomUIntOp, typ: ScalarType }, + Signed { op: AtomSIntOp, typ: ScalarType }, + Float { op: AtomFloatOp, typ: ScalarType }, } #[derive(Copy, Clone, Eq, PartialEq)] @@ -1166,19 +1080,19 @@ pub struct AtomCasDetails { pub semantics: AtomSemantics, pub scope: MemScope, pub space: AtomSpace, - pub typ: BitType, + pub typ: ScalarType, } #[derive(Copy, Clone)] pub enum DivDetails { - Unsigned(UIntType), - Signed(SIntType), + Unsigned(ScalarType), + Signed(ScalarType), Float(DivFloatDetails), } #[derive(Copy, Clone)] pub struct DivFloatDetails { - pub typ: FloatType, + pub typ: ScalarType, pub flush_to_zero: Option, pub kind: DivFloatKind, } @@ -1197,7 +1111,7 @@ pub enum NumsOrArrays<'a> { #[derive(Copy, Clone)] pub struct SqrtDetails { - pub typ: FloatType, + pub typ: ScalarType, pub flush_to_zero: Option, pub kind: SqrtKind, } @@ -1210,7 +1124,7 @@ pub enum SqrtKind { #[derive(Copy, Clone, Eq, PartialEq)] pub struct RsqrtDetails { - pub typ: FloatType, + pub typ: ScalarType, pub flush_to_zero: bool, } @@ -1379,6 +1293,40 @@ pub enum TuningDirective { MinNCtaPerSm(u32), } +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum ScalarKind { + Bit, + Unsigned, + Signed, + Float, + Float2, + Pred, +} + +impl ScalarType { + pub fn kind(self) -> ScalarKind { + match self { + ScalarType::U8 => ScalarKind::Unsigned, + ScalarType::U16 => ScalarKind::Unsigned, + ScalarType::U32 => ScalarKind::Unsigned, + ScalarType::U64 => ScalarKind::Unsigned, + ScalarType::S8 => ScalarKind::Signed, + ScalarType::S16 => ScalarKind::Signed, + ScalarType::S32 => ScalarKind::Signed, + ScalarType::S64 => ScalarKind::Signed, + ScalarType::B8 => ScalarKind::Bit, + ScalarType::B16 => ScalarKind::Bit, + ScalarType::B32 => ScalarKind::Bit, + ScalarType::B64 => ScalarKind::Bit, + ScalarType::F16 => ScalarKind::Float, + ScalarType::F32 => ScalarKind::Float, + ScalarType::F64 => ScalarKind::Float, + ScalarType::F16x2 => ScalarKind::Float2, + ScalarType::Pred => ScalarKind::Pred, + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 423fd57..41c1d73 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -899,39 +899,39 @@ RoundingModeInt : ast::RoundingMode = { ".rpi" => ast::RoundingMode::PositiveInf, }; -IntType : ast::IntType = { - ".u16" => ast::IntType::U16, - ".u32" => ast::IntType::U32, - ".u64" => ast::IntType::U64, - ".s16" => ast::IntType::S16, - ".s32" => ast::IntType::S32, - ".s64" => ast::IntType::S64, +IntType : ast::ScalarType = { + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; -IntType3264: ast::IntType = { - ".u32" => ast::IntType::U32, - ".u64" => ast::IntType::U64, - ".s32" => ast::IntType::S32, - ".s64" => ast::IntType::S64, +IntType3264: ast::ScalarType = { + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, } -UIntType: ast::UIntType = { - ".u16" => ast::UIntType::U16, - ".u32" => ast::UIntType::U32, - ".u64" => ast::UIntType::U64, +UIntType: ast::ScalarType = { + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, }; -SIntType: ast::SIntType = { - ".s16" => ast::SIntType::S16, - ".s32" => ast::SIntType::S32, - ".s64" => ast::SIntType::S64, +SIntType: ast::ScalarType = { + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; -FloatType: ast::FloatType = { - ".f16" => ast::FloatType::F16, - ".f16x2" => ast::FloatType::F16x2, - ".f32" => ast::FloatType::F32, - ".f64" => ast::FloatType::F64, +FloatType: ast::ScalarType = { + ".f16" => ast::ScalarType::F16, + ".f16x2" => ast::ScalarType::F16x2, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add @@ -1023,11 +1023,11 @@ InstNot: ast::Instruction> = { "not" => ast::Instruction::Not(t, a) }; -BooleanType: ast::BooleanType = { - ".pred" => ast::BooleanType::Pred, - ".b16" => ast::BooleanType::B16, - ".b32" => ast::BooleanType::B32, - ".b64" => ast::BooleanType::B64, +BooleanType: ast::ScalarType = { + ".pred" => ast::ScalarType::Pred, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at @@ -1080,8 +1080,8 @@ InstCvt: ast::Instruction> = { rounding: r, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F16, - src: ast::FloatType::F16 + dst: ast::ScalarType::F16, + src: ast::ScalarType::F16 } ), a) }, @@ -1091,8 +1091,8 @@ InstCvt: ast::Instruction> = { rounding: None, flush_to_zero: Some(f.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F32, - src: ast::FloatType::F16 + dst: ast::ScalarType::F32, + src: ast::ScalarType::F16 } ), a) }, @@ -1102,8 +1102,8 @@ InstCvt: ast::Instruction> = { rounding: None, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F64, - src: ast::FloatType::F16 + dst: ast::ScalarType::F64, + src: ast::ScalarType::F16 } ), a) }, @@ -1113,8 +1113,8 @@ InstCvt: ast::Instruction> = { rounding: Some(r), flush_to_zero: Some(f.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F16, - src: ast::FloatType::F32 + dst: ast::ScalarType::F16, + src: ast::ScalarType::F32 } ), a) }, @@ -1124,8 +1124,8 @@ InstCvt: ast::Instruction> = { rounding: r, flush_to_zero: Some(f.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F32, - src: ast::FloatType::F32 + dst: ast::ScalarType::F32, + src: ast::ScalarType::F32 } ), a) }, @@ -1135,8 +1135,8 @@ InstCvt: ast::Instruction> = { rounding: None, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F64, - src: ast::FloatType::F32 + dst: ast::ScalarType::F64, + src: ast::ScalarType::F32 } ), a) }, @@ -1146,8 +1146,8 @@ InstCvt: ast::Instruction> = { rounding: Some(r), flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F16, - src: ast::FloatType::F64 + dst: ast::ScalarType::F16, + src: ast::ScalarType::F64 } ), a) }, @@ -1157,8 +1157,8 @@ InstCvt: ast::Instruction> = { rounding: Some(r), flush_to_zero: Some(s.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F32, - src: ast::FloatType::F64 + dst: ast::ScalarType::F32, + src: ast::ScalarType::F64 } ), a) }, @@ -1168,28 +1168,28 @@ InstCvt: ast::Instruction> = { rounding: r, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F64, - src: ast::FloatType::F64 + dst: ast::ScalarType::F64, + src: ast::ScalarType::F64 } ), a) }, }; -CvtTypeInt: ast::IntType = { - ".u8" => ast::IntType::U8, - ".u16" => ast::IntType::U16, - ".u32" => ast::IntType::U32, - ".u64" => ast::IntType::U64, - ".s8" => ast::IntType::S8, - ".s16" => ast::IntType::S16, - ".s32" => ast::IntType::S32, - ".s64" => ast::IntType::S64, +CvtTypeInt: ast::ScalarType = { + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; -CvtTypeFloat: ast::FloatType = { - ".f16" => ast::FloatType::F16, - ".f32" => ast::FloatType::F32, - ".f64" => ast::FloatType::F64, +CvtTypeFloat: ast::ScalarType = { + ".f16" => ast::ScalarType::F16, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl @@ -1197,10 +1197,10 @@ InstShl: ast::Instruction> = { "shl" => ast::Instruction::Shl(t, a) }; -ShlType: ast::ShlType = { - ".b16" => ast::ShlType::B16, - ".b32" => ast::ShlType::B32, - ".b64" => ast::ShlType::B64, +ShlType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr @@ -1208,16 +1208,16 @@ InstShr: ast::Instruction> = { "shr" => ast::Instruction::Shr(t, a) }; -ShrType: ast::ShrType = { - ".b16" => ast::ShrType::B16, - ".b32" => ast::ShrType::B32, - ".b64" => ast::ShrType::B64, - ".u16" => ast::ShrType::U16, - ".u32" => ast::ShrType::U32, - ".u64" => ast::ShrType::U64, - ".s16" => ast::ShrType::S16, - ".s32" => ast::ShrType::S32, - ".s64" => ast::ShrType::S64, +ShrType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st @@ -1393,16 +1393,16 @@ MinMaxDetails: ast::MinMaxDetails = { => ast::MinMaxDetails::Unsigned(t), => ast::MinMaxDetails::Signed(t), ".f32" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F32 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F32 } ), ".f64" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::FloatType::F64 } + ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::ScalarType::F64 } ), ".f16" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16 } ), ".f16x2" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16x2 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16x2 } ) } @@ -1411,18 +1411,18 @@ InstSelp: ast::Instruction> = { "selp" => ast::Instruction::Selp(t, a), }; -SelpType: ast::SelpType = { - ".b16" => ast::SelpType::B16, - ".b32" => ast::SelpType::B32, - ".b64" => ast::SelpType::B64, - ".u16" => ast::SelpType::U16, - ".u32" => ast::SelpType::U32, - ".u64" => ast::SelpType::U64, - ".s16" => ast::SelpType::S16, - ".s32" => ast::SelpType::S32, - ".s64" => ast::SelpType::S64, - ".f32" => ast::SelpType::F32, - ".f64" => ast::SelpType::F64, +SelpType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar @@ -1454,7 +1454,7 @@ InstAtom: ast::Instruction> = { space: space.unwrap_or(ast::AtomSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Inc, - typ: ast::UIntType::U32 + typ: ast::ScalarType::U32 } }; ast::Instruction::Atom(details,a) @@ -1466,7 +1466,7 @@ InstAtom: ast::Instruction> = { space: space.unwrap_or(ast::AtomSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Dec, - typ: ast::UIntType::U32 + typ: ast::ScalarType::U32 } }; ast::Instruction::Atom(details,a) @@ -1544,19 +1544,19 @@ AtomSIntOp: ast::AtomSIntOp = { ".max" => ast::AtomSIntOp::Max, } -BitType: ast::BitType = { - ".b32" => ast::BitType::B32, - ".b64" => ast::BitType::B64, +BitType: ast::ScalarType = { + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, } -UIntType3264: ast::UIntType = { - ".u32" => ast::UIntType::U32, - ".u64" => ast::UIntType::U64, +UIntType3264: ast::ScalarType = { + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, } -SIntType3264: ast::SIntType = { - ".s32" => ast::SIntType::S32, - ".s64" => ast::SIntType::S64, +SIntType3264: ast::ScalarType = { + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div @@ -1566,7 +1566,7 @@ InstDiv: ast::Instruction> = { "div" => ast::Instruction::Div(ast::DivDetails::Signed(t), a), "div" ".f32" => { let inner = ast::DivFloatDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: Some(ftz.is_some()), kind }; @@ -1574,7 +1574,7 @@ InstDiv: ast::Instruction> = { }, "div" ".f64" => { let inner = ast::DivFloatDetails { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, flush_to_zero: None, kind: ast::DivFloatKind::Rounding(rnd) }; @@ -1592,7 +1592,7 @@ DivFloatKind: ast::DivFloatKind = { InstSqrt: ast::Instruction> = { "sqrt" ".approx" ".f32" => { let details = ast::SqrtDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: Some(ftz.is_some()), kind: ast::SqrtKind::Approx, }; @@ -1600,7 +1600,7 @@ InstSqrt: ast::Instruction> = { }, "sqrt" ".f32" => { let details = ast::SqrtDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: Some(ftz.is_some()), kind: ast::SqrtKind::Rounding(rnd), }; @@ -1608,7 +1608,7 @@ InstSqrt: ast::Instruction> = { }, "sqrt" ".f64" => { let details = ast::SqrtDetails { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, flush_to_zero: None, kind: ast::SqrtKind::Rounding(rnd), }; @@ -1621,14 +1621,14 @@ InstSqrt: ast::Instruction> = { InstRsqrt: ast::Instruction> = { "rsqrt" ".approx" ".f32" => { let details = ast::RsqrtDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: ftz.is_some(), }; ast::Instruction::Rsqrt(details, a) }, "rsqrt" ".approx" ".f64" => { let details = ast::RsqrtDetails { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, flush_to_zero: ftz.is_some(), }; ast::Instruction::Rsqrt(details, a) @@ -1739,7 +1739,7 @@ ArithDetails: ast::ArithDetails = { saturate: false, }), ".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S32, + typ: ast::ScalarType::S32, saturate: true, }), => ast::ArithDetails::Float(f) @@ -1747,25 +1747,25 @@ ArithDetails: ast::ArithDetails = { ArithFloat: ast::ArithFloat = { ".f32" => ast::ArithFloat { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, rounding: rn, flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".f64" => ast::ArithFloat { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, rounding: rn, flush_to_zero: None, saturate: false, }, ".f16" => ast::ArithFloat { - typ: ast::FloatType::F16, + typ: ast::ScalarType::F16, rounding: rn.map(|_| ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".f16x2" => ast::ArithFloat { - typ: ast::FloatType::F16x2, + typ: ast::ScalarType::F16x2, rounding: rn.map(|_| ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), @@ -1774,25 +1774,25 @@ ArithFloat: ast::ArithFloat = { ArithFloatMustRound: ast::ArithFloat = { ".f32" => ast::ArithFloat { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, rounding: Some(rn), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".f64" => ast::ArithFloat { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, rounding: Some(rn), flush_to_zero: None, saturate: false, }, ".rn" ".f16" => ast::ArithFloat { - typ: ast::FloatType::F16, + typ: ast::ScalarType::F16, rounding: Some(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".rn" ".f16x2" => ast::ArithFloat { - typ: ast::FloatType::F16x2, + typ: ast::ScalarType::F16x2, rounding: Some(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), diff --git a/ptx/src/test/spirv_run/bfi.spvtxt b/ptx/src/test/spirv_run/bfi.spvtxt index a226f78..dc8f683 100644 --- a/ptx/src/test/spirv_run/bfi.spvtxt +++ b/ptx/src/test/spirv_run/bfi.spvtxt @@ -71,7 +71,7 @@ %26 = OpLoad %uint %9 %40 = OpCopyObject %uint %23 %41 = OpCopyObject %uint %24 - %39 = OpFunctionCall %uint %44 %41 %40 %25 %26 + %39 = OpFunctionCall %uint %44 %40 %41 %25 %26 %22 = OpCopyObject %uint %39 OpStore %6 %22 %27 = OpLoad %ulong %5 diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index e39280a..51b1dc6 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1553,10 +1553,9 @@ fn extract_globals<'input, 'b>( space, }; let (op, typ) = match typ { - ast::FloatType::F32 => ("add_f32", ast::SizedScalarType::F32), - ast::FloatType::F64 => ("add_f64", ast::SizedScalarType::F64), - ast::FloatType::F16 => unreachable!(), - ast::FloatType::F16x2 => unreachable!(), + ast::ScalarType::F32 => ("add_f32", ast::SizedScalarType::F32), + ast::ScalarType::F64 => ("add_f64", ast::SizedScalarType::F64), + _ => unreachable!(), }; local.push(to_ptx_impl_atomic_call( id_def, @@ -1822,15 +1821,15 @@ fn to_ptx_impl_atomic_call( fn to_ptx_impl_bfe_call( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, - typ: ast::IntType, + typ: ast::ScalarType, arg: ast::Arg4, ) -> ExpandedStatement { let prefix = "__zluda_ptx_impl__"; let suffix = match typ { - ast::IntType::U32 => "bfe_u32", - ast::IntType::U64 => "bfe_u64", - ast::IntType::S32 => "bfe_s32", - ast::IntType::S64 => "bfe_s64", + ast::ScalarType::U32 => "bfe_u32", + ast::ScalarType::U64 => "bfe_u64", + ast::ScalarType::S32 => "bfe_s32", + ast::ScalarType::S64 => "bfe_s64", _ => unreachable!(), }; let fn_name = format!("{}{}", prefix, suffix); @@ -1917,14 +1916,14 @@ fn to_ptx_impl_bfe_call( fn to_ptx_impl_bfi_call( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, - typ: ast::BitType, + typ: ast::ScalarType, arg: ast::Arg5, ) -> ExpandedStatement { let prefix = "__zluda_ptx_impl__"; let suffix = match typ { - ast::BitType::B32 => "bfi_b32", - ast::BitType::B64 => "bfi_b64", - ast::BitType::B8 | ast::BitType::B16 => unreachable!(), + ast::ScalarType::B32 => "bfi_b32", + ast::ScalarType::B64 => "bfi_b64", + _ => unreachable!(), }; let fn_name = format!("{}{}", prefix, suffix); let fn_id = match ptx_impl_imports.entry(fn_name) { @@ -2506,29 +2505,32 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { let (width, kind) = match add_type { ast::Type::Scalar(scalar_t) => { let kind = match scalar_t.kind() { - kind @ ScalarKind::Bit - | kind @ ScalarKind::Unsigned - | kind @ ScalarKind::Signed => kind, - ScalarKind::Float => return Err(TranslateError::MismatchedType), - ScalarKind::Float2 => return Err(TranslateError::MismatchedType), - ScalarKind::Pred => return Err(TranslateError::MismatchedType), + kind @ ast::ScalarKind::Bit + | kind @ ast::ScalarKind::Unsigned + | kind @ ast::ScalarKind::Signed => kind, + ast::ScalarKind::Float => return Err(TranslateError::MismatchedType), + ast::ScalarKind::Float2 => return Err(TranslateError::MismatchedType), + ast::ScalarKind::Pred => return Err(TranslateError::MismatchedType), }; (scalar_t.size_of(), kind) } _ => return Err(TranslateError::MismatchedType), }; - let arith_detail = if kind == ScalarKind::Signed { + let arith_detail = if kind == ast::ScalarKind::Signed { ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::from_size(width), + typ: ast::ScalarType::from_parts(width, ast::ScalarKind::Signed), saturate: false, }) } else { - ast::ArithDetails::Unsigned(ast::UIntType::from_size(width)) + ast::ArithDetails::Unsigned(ast::ScalarType::from_parts( + width, + ast::ScalarKind::Unsigned, + )) }; let id_constant_stmt = self.id_def.new_non_variable(add_type.clone()); let result_id = self.id_def.new_non_variable(add_type); // TODO: check for edge cases around min value/max value/wrapping - if offset < 0 && kind != ScalarKind::Signed { + if offset < 0 && kind != ast::ScalarKind::Signed { self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, typ: ast::ScalarType::from_parts(width, kind), @@ -3026,18 +3028,18 @@ fn emit_function_body_ops( emit_setp(builder, map, setp, arg)?; } ast::Instruction::Not(t, a) => { - let result_type = map.get_or_add(builder, SpirvType::from(t.to_type())); + let result_type = map.get_or_add(builder, SpirvType::from(*t)); let result_id = Some(a.dst); let operand = a.src; match t { - ast::BooleanType::Pred => { + ast::ScalarType::Pred => { logical_not(builder, result_type, result_id, operand) } _ => builder.not(result_type, result_id, operand), }?; } ast::Instruction::Shl(t, a) => { - let full_type = t.to_type(); + let full_type = ast::Type::Scalar(*t); let size_of = full_type.size_of(); let result_type = map.get_or_add(builder, SpirvType::from(full_type)); let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?; @@ -3048,7 +3050,7 @@ fn emit_function_body_ops( let size_of = full_type.size_of(); let result_type = map.get_or_add_scalar(builder, full_type); let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?; - if t.signed() { + if t.kind() == ast::ScalarKind::Signed { builder.shift_right_arithmetic( result_type, Some(a.dst), @@ -3088,7 +3090,7 @@ fn emit_function_body_ops( }, ast::Instruction::Or(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - if *t == ast::BooleanType::Pred { + if *t == ast::ScalarType::Pred { builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?; } else { builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?; @@ -3116,7 +3118,7 @@ fn emit_function_body_ops( } ast::Instruction::And(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - if *t == ast::BooleanType::Pred { + if *t == ast::ScalarType::Pred { builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?; } else { builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?; @@ -3202,7 +3204,7 @@ fn emit_function_body_ops( } ast::Instruction::Neg(details, arg) => { let result_type = map.get_or_add_scalar(builder, details.typ); - let negate_func = if details.typ.kind() == ScalarKind::Float { + let negate_func = if details.typ.kind() == ast::ScalarKind::Float { dr::Builder::f_negate } else { dr::Builder::s_negate @@ -3269,7 +3271,7 @@ fn emit_function_body_ops( } ast::Instruction::Xor { typ, arg } => { let builder_fn = match typ { - ast::BooleanType::Pred => emit_logical_xor_spirv, + ast::ScalarType::Pred => emit_logical_xor_spirv, _ => dr::Builder::bitwise_xor, }; let result_type = map.get_or_add_scalar(builder, (*typ).into()); @@ -3284,7 +3286,7 @@ fn emit_function_body_ops( return Err(error_unreachable()); } ast::Instruction::Rem { typ, arg } => { - let builder_fn = if typ.is_signed() { + let builder_fn = if typ.kind() == ast::ScalarKind::Signed { dr::Builder::s_mod } else { dr::Builder::u_mod @@ -3882,7 +3884,7 @@ fn emit_cvt( } let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - if desc.src.is_signed() { + if desc.src.kind() == ast::ScalarKind::Signed { builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?; } else { builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?; @@ -3892,7 +3894,7 @@ fn emit_cvt( ast::CvtDetails::IntFromFloat(desc) => { let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - if desc.dst.is_signed() { + if desc.dst.kind() == ast::ScalarKind::Signed { builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?; } else { builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?; @@ -3904,7 +3906,7 @@ fn emit_cvt( let dest_t: ast::ScalarType = desc.dst.into(); let src_t: ast::ScalarType = desc.src.into(); // first do shortening/widening - let src = if desc.dst.width() != desc.src.width() { + let src = if desc.dst.size_of() != desc.src.size_of() { let new_dst = if dest_t.kind() == src_t.kind() { arg.dst } else { @@ -3933,7 +3935,7 @@ fn emit_cvt( // now do actual conversion let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); if desc.saturate { - if desc.dst.is_signed() { + if desc.dst.kind() == ast::ScalarKind::Signed { builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?; } else { builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?; @@ -3989,60 +3991,60 @@ fn emit_setp( let operand_1 = arg.src1; let operand_2 = arg.src2; match (setp.cmp_op, setp.typ.kind()) { - (ast::SetpCompareOp::Eq, ScalarKind::Signed) - | (ast::SetpCompareOp::Eq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => { + (ast::SetpCompareOp::Eq, ast::ScalarKind::Signed) + | (ast::SetpCompareOp::Eq, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::Eq, ast::ScalarKind::Bit) => { builder.i_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Eq, ScalarKind::Float) => { + (ast::SetpCompareOp::Eq, ast::ScalarKind::Float) => { builder.f_ord_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::NotEq, ScalarKind::Signed) - | (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => { + (ast::SetpCompareOp::NotEq, ast::ScalarKind::Signed) + | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Bit) => { builder.i_not_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::NotEq, ScalarKind::Float) => { + (ast::SetpCompareOp::NotEq, ast::ScalarKind::Float) => { builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Less, ScalarKind::Unsigned) - | (ast::SetpCompareOp::Less, ScalarKind::Bit) => { + (ast::SetpCompareOp::Less, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::Less, ast::ScalarKind::Bit) => { builder.u_less_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Less, ScalarKind::Signed) => { + (ast::SetpCompareOp::Less, ast::ScalarKind::Signed) => { builder.s_less_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Less, ScalarKind::Float) => { + (ast::SetpCompareOp::Less, ast::ScalarKind::Float) => { builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => { + (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Bit) => { builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => { + (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Signed) => { builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => { + (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Float) => { builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Greater, ScalarKind::Unsigned) - | (ast::SetpCompareOp::Greater, ScalarKind::Bit) => { + (ast::SetpCompareOp::Greater, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::Greater, ast::ScalarKind::Bit) => { builder.u_greater_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Greater, ScalarKind::Signed) => { + (ast::SetpCompareOp::Greater, ast::ScalarKind::Signed) => { builder.s_greater_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Greater, ScalarKind::Float) => { + (ast::SetpCompareOp::Greater, ast::ScalarKind::Float) => { builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => { + (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Bit) => { builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => { + (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Signed) => { builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => { + (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Float) => { builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::NanEq, _) => { @@ -4222,7 +4224,7 @@ fn emit_abs( ) -> Result<(), dr::Error> { let scalar_t = ast::ScalarType::from(d.typ); let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); - let cl_abs = if scalar_t.kind() == ScalarKind::Signed { + let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed { spirv::CLOp::s_abs } else { spirv::CLOp::fabs @@ -4286,8 +4288,8 @@ fn emit_implicit_conversion( (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { if from_parts.width == to_parts.width { let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); - if from_parts.scalar_kind != ScalarKind::Float - && to_parts.scalar_kind != ScalarKind::Float + if from_parts.scalar_kind != ast::ScalarKind::Float + && to_parts.scalar_kind != ast::ScalarKind::Float { // It is noop, but another instruction expects result of this conversion builder.copy_object(dst_type, Some(cv.dst), cv.src)?; @@ -4299,24 +4301,24 @@ fn emit_implicit_conversion( let same_width_bit_type = map.get_or_add( builder, SpirvType::from(ast::Type::from_parts(TypeParts { - scalar_kind: ScalarKind::Bit, + scalar_kind: ast::ScalarKind::Bit, ..from_parts })), ); let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?; let wide_bit_type = ast::Type::from_parts(TypeParts { - scalar_kind: ScalarKind::Bit, + scalar_kind: ast::ScalarKind::Bit, ..to_parts }); let wide_bit_type_spirv = map.get_or_add(builder, SpirvType::from(wide_bit_type.clone())); - if to_parts.scalar_kind == ScalarKind::Unsigned - || to_parts.scalar_kind == ScalarKind::Bit + if to_parts.scalar_kind == ast::ScalarKind::Unsigned + || to_parts.scalar_kind == ast::ScalarKind::Bit { builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?; } else { - let conversion_fn = if from_parts.scalar_kind == ScalarKind::Signed - && to_parts.scalar_kind == ScalarKind::Signed + let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed + && to_parts.scalar_kind == ast::ScalarKind::Signed { dr::Builder::s_convert } else { @@ -4614,23 +4616,23 @@ fn convert_to_stateful_memory_access<'a>( for statement in func_body.iter() { match statement { Statement::Instruction(ast::Instruction::Add( - ast::ArithDetails::Unsigned(ast::UIntType::U64), + ast::ArithDetails::Unsigned(ast::ScalarType::U64), arg, )) | Statement::Instruction(ast::Instruction::Add( ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S64, + typ: ast::ScalarType::S64, saturate: false, }), arg, )) | Statement::Instruction(ast::Instruction::Sub( - ast::ArithDetails::Unsigned(ast::UIntType::U64), + ast::ArithDetails::Unsigned(ast::ScalarType::U64), arg, )) | Statement::Instruction(ast::Instruction::Sub( ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S64, + typ: ast::ScalarType::S64, saturate: false, }), arg, @@ -4686,12 +4688,12 @@ fn convert_to_stateful_memory_access<'a>( } } Statement::Instruction(ast::Instruction::Add( - ast::ArithDetails::Unsigned(ast::UIntType::U64), + ast::ArithDetails::Unsigned(ast::ScalarType::U64), arg, )) | Statement::Instruction(ast::Instruction::Add( ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S64, + typ: ast::ScalarType::S64, saturate: false, }), arg, @@ -4715,12 +4717,12 @@ fn convert_to_stateful_memory_access<'a>( })) } Statement::Instruction(ast::Instruction::Sub( - ast::ArithDetails::Unsigned(ast::UIntType::U64), + ast::ArithDetails::Unsigned(ast::ScalarType::U64), arg, )) | Statement::Instruction(ast::Instruction::Sub( ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S64, + typ: ast::ScalarType::S64, saturate: false, }), arg, @@ -4867,7 +4869,7 @@ fn convert_to_stateful_memory_access_postprocess( ast::LdStateSpace::Global, ), to: old_type, - kind: ConversionKind::PtrToBit(ast::UIntType::U64), + kind: ConversionKind::PtrToBit(ast::ScalarType::U64), src_sema: arg_desc.sema, dst_sema: ArgumentSemantics::Default, })); @@ -5903,7 +5905,9 @@ impl ast::Instruction { let inst_type = d.typ; ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) } - ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?), + ast::Instruction::Not(t, a) => { + ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?) + } ast::Instruction::Cvt(d, a) => { let (dst_t, src_t) = match &d { ast::CvtDetails::FloatFromFloat(desc) => ( @@ -5926,7 +5930,7 @@ impl ast::Instruction { ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?) } ast::Instruction::Shl(t, a) => { - ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?) + ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?) } ast::Instruction::Shr(t, a) => { ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?) @@ -6176,9 +6180,9 @@ impl ast::Type { ast::Type::Scalar(scalar) => { let kind = scalar.kind(); let width = scalar.size_of(); - if (kind != ScalarKind::Signed - && kind != ScalarKind::Unsigned - && kind != ScalarKind::Bit) + if (kind != ast::ScalarKind::Signed + && kind != ast::ScalarKind::Unsigned + && kind != ast::ScalarKind::Bit) || (width == 8) { return Err(TranslateError::MismatchedType); @@ -6306,7 +6310,7 @@ impl ast::Type { #[derive(Eq, PartialEq, Clone)] struct TypeParts { kind: TypeKind, - scalar_kind: ScalarKind, + scalar_kind: ast::ScalarKind, width: u8, components: Vec, state_space: ast::LdStateSpace, @@ -6461,7 +6465,7 @@ enum ConversionKind { // zero-extend/chop/bitcast depending on types SignExtend, BitToPtr(ast::LdStateSpace), - PtrToBit(ast::UIntType), + PtrToBit(ast::ScalarType), PtrToPtr { spirv_ptr: bool }, } @@ -6859,7 +6863,7 @@ impl ast::Arg4 { fn map_selp>( self, visitor: &mut V, - t: ast::SelpType, + t: ast::ScalarType, ) -> Result, TranslateError> { let dst = visitor.operand( ArgumentDescriptor { @@ -6904,7 +6908,7 @@ impl ast::Arg4 { fn map_atom>( self, visitor: &mut V, - t: ast::BitType, + t: ast::ScalarType, state_space: ast::AtomSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); @@ -7205,103 +7209,41 @@ impl ast::StStateSpace { } } -#[derive(Clone, Copy, PartialEq, Eq)] -enum ScalarKind { - Bit, - Unsigned, - Signed, - Float, - Float2, - Pred, -} - impl ast::ScalarType { - fn kind(self) -> ScalarKind { - match self { - ast::ScalarType::U8 => ScalarKind::Unsigned, - ast::ScalarType::U16 => ScalarKind::Unsigned, - ast::ScalarType::U32 => ScalarKind::Unsigned, - ast::ScalarType::U64 => ScalarKind::Unsigned, - ast::ScalarType::S8 => ScalarKind::Signed, - ast::ScalarType::S16 => ScalarKind::Signed, - ast::ScalarType::S32 => ScalarKind::Signed, - ast::ScalarType::S64 => ScalarKind::Signed, - ast::ScalarType::B8 => ScalarKind::Bit, - ast::ScalarType::B16 => ScalarKind::Bit, - ast::ScalarType::B32 => ScalarKind::Bit, - ast::ScalarType::B64 => ScalarKind::Bit, - ast::ScalarType::F16 => ScalarKind::Float, - ast::ScalarType::F32 => ScalarKind::Float, - ast::ScalarType::F64 => ScalarKind::Float, - ast::ScalarType::F16x2 => ScalarKind::Float2, - ast::ScalarType::Pred => ScalarKind::Pred, - } - } - - fn from_parts(width: u8, kind: ScalarKind) -> Self { + fn from_parts(width: u8, kind: ast::ScalarKind) -> Self { match kind { - ScalarKind::Float => match width { + ast::ScalarKind::Float => match width { 2 => ast::ScalarType::F16, 4 => ast::ScalarType::F32, 8 => ast::ScalarType::F64, _ => unreachable!(), }, - ScalarKind::Bit => match width { + ast::ScalarKind::Bit => match width { 1 => ast::ScalarType::B8, 2 => ast::ScalarType::B16, 4 => ast::ScalarType::B32, 8 => ast::ScalarType::B64, _ => unreachable!(), }, - ScalarKind::Signed => match width { + ast::ScalarKind::Signed => match width { 1 => ast::ScalarType::S8, 2 => ast::ScalarType::S16, 4 => ast::ScalarType::S32, 8 => ast::ScalarType::S64, _ => unreachable!(), }, - ScalarKind::Unsigned => match width { + ast::ScalarKind::Unsigned => match width { 1 => ast::ScalarType::U8, 2 => ast::ScalarType::U16, 4 => ast::ScalarType::U32, 8 => ast::ScalarType::U64, _ => unreachable!(), }, - ScalarKind::Float2 => match width { + ast::ScalarKind::Float2 => match width { 4 => ast::ScalarType::F16x2, _ => unreachable!(), }, - ScalarKind::Pred => ast::ScalarType::Pred, - } - } -} - -impl ast::BooleanType { - fn to_type(self) -> ast::Type { - match self { - ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred), - ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16), - ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32), - ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64), - } - } -} - -impl ast::ShlType { - fn to_type(self) -> ast::Type { - match self { - ast::ShlType::B16 => ast::Type::Scalar(ast::ScalarType::B16), - ast::ShlType::B32 => ast::Type::Scalar(ast::ScalarType::B32), - ast::ShlType::B64 => ast::Type::Scalar(ast::ScalarType::B64), - } - } -} - -impl ast::ShrType { - fn signed(&self) -> bool { - match self { - ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true, - _ => false, + ast::ScalarKind::Pred => ast::ScalarType::Pred, } } } @@ -7357,30 +7299,6 @@ impl ast::AtomInnerDetails { } } -impl ast::SIntType { - fn from_size(width: u8) -> Self { - match width { - 1 => ast::SIntType::S8, - 2 => ast::SIntType::S16, - 4 => ast::SIntType::S32, - 8 => ast::SIntType::S64, - _ => unreachable!(), - } - } -} - -impl ast::UIntType { - fn from_size(width: u8) -> Self { - match width { - 1 => ast::UIntType::U8, - 2 => ast::UIntType::U16, - 4 => ast::UIntType::U32, - 8 => ast::UIntType::U64, - _ => unreachable!(), - } - } -} - impl ast::LdStateSpace { fn to_spirv(self) -> spirv::StorageClass { match self { @@ -7568,16 +7486,18 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { return false; } match inst.kind() { - ScalarKind::Bit => operand.kind() != ScalarKind::Bit, - ScalarKind::Float => operand.kind() == ScalarKind::Bit, - ScalarKind::Signed => { - operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned + ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, + ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, + ast::ScalarKind::Signed => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Unsigned } - ScalarKind::Unsigned => { - operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed + ast::ScalarKind::Unsigned => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Signed } - ScalarKind::Float2 => false, - ScalarKind::Pred => false, + ast::ScalarKind::Float2 => false, + ast::ScalarKind::Pred => false, } } (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _)) @@ -7596,7 +7516,7 @@ fn should_bitcast_packed( if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = (operand, instr) { - if scalar.kind() == ScalarKind::Bit + if scalar.kind() == ast::ScalarKind::Bit && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) { return Ok(Some(ConversionKind::Default)); @@ -7644,32 +7564,33 @@ fn should_convert_relaxed_src( } match (src_type, instr_type) { (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { - ScalarKind::Bit => { + ast::ScalarKind::Bit => { if instr_type.size_of() <= src_type.size_of() { Some(ConversionKind::Default) } else { None } } - ScalarKind::Signed | ScalarKind::Unsigned => { + ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { if instr_type.size_of() <= src_type.size_of() - && src_type.kind() != ScalarKind::Float + && src_type.kind() != ast::ScalarKind::Float { Some(ConversionKind::Default) } else { None } } - ScalarKind::Float => { - if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit + ast::ScalarKind::Float => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() == ast::ScalarKind::Bit { Some(ConversionKind::Default) } else { None } } - ScalarKind::Float2 => todo!(), - ScalarKind::Pred => None, + ast::ScalarKind::Float2 => todo!(), + ast::ScalarKind::Pred => None, }, (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { @@ -7706,15 +7627,15 @@ fn should_convert_relaxed_dst( } match (dst_type, instr_type) { (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { - ScalarKind::Bit => { + ast::ScalarKind::Bit => { if instr_type.size_of() <= dst_type.size_of() { Some(ConversionKind::Default) } else { None } } - ScalarKind::Signed => { - if dst_type.kind() != ScalarKind::Float { + ast::ScalarKind::Signed => { + if dst_type.kind() != ast::ScalarKind::Float { if instr_type.size_of() == dst_type.size_of() { Some(ConversionKind::Default) } else if instr_type.size_of() < dst_type.size_of() { @@ -7726,25 +7647,26 @@ fn should_convert_relaxed_dst( None } } - ScalarKind::Unsigned => { + ast::ScalarKind::Unsigned => { if instr_type.size_of() <= dst_type.size_of() - && dst_type.kind() != ScalarKind::Float + && dst_type.kind() != ast::ScalarKind::Float { Some(ConversionKind::Default) } else { None } } - ScalarKind::Float => { - if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit + ast::ScalarKind::Float => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() == ast::ScalarKind::Bit { Some(ConversionKind::Default) } else { None } } - ScalarKind::Float2 => todo!(), - ScalarKind::Pred => None, + ast::ScalarKind::Float2 => todo!(), + ast::ScalarKind::Pred => None, }, (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { From 4d04fe251d4776722d8e5ee74333e8b5fa8b6931 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 15 Apr 2021 19:21:52 +0200 Subject: [PATCH 02/25] Remove all remaining subenums --- ptx/src/ast.rs | 161 +++++++++++-------------------------------- ptx/src/ptx.lalrpop | 72 +++++++++---------- ptx/src/translate.rs | 46 +++++-------- 3 files changed, 92 insertions(+), 187 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index bc2fa4c..3a7cf98 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -34,43 +34,6 @@ pub enum PtxError { NonExternPointer, } -macro_rules! sub_enum { - ($name:ident { $($variant:ident),+ $(,)? }) => { - sub_enum!{ $name : ScalarType { $($variant),+ } } - }; - ($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => { - #[derive(PartialEq, Eq, Clone, Copy)] - pub enum $name { - $( - $variant, - )+ - } - - impl From<$name> for $base_type { - fn from(t: $name) -> $base_type { - match t { - $( - $name::$variant => $base_type::$variant, - )+ - } - } - } - - impl std::convert::TryFrom<$base_type> for $name { - type Error = (); - - fn try_from(t: $base_type) -> Result { - match t { - $( - $base_type::$variant => Ok($name::$variant), - )+ - _ => Err(()), - } - } - } - }; -} - macro_rules! sub_type { ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { sub_type! { $type_name : Type { @@ -118,12 +81,12 @@ macro_rules! sub_type { sub_type! { VariableRegType { Scalar(ScalarType), - Vector(SizedScalarType, u8), + Vector(ScalarType, u8), // Array type is used when emiting SSA statements at the start of a method Array(ScalarType, VecU32), // Pointer variant is used when passing around SLM pointer between // function calls for dynamic SLM - Pointer(SizedScalarType, PointerStateSpace) + Pointer(ScalarType, LdStateSpace) } } @@ -131,9 +94,9 @@ type VecU32 = Vec; sub_type! { VariableLocalType { - Scalar(SizedScalarType), - Vector(SizedScalarType, u8), - Array(SizedScalarType, VecU32), + Scalar(ScalarType), + Vector(ScalarType, u8), + Array(ScalarType, VecU32), } } @@ -152,10 +115,10 @@ impl TryFrom for VariableLocalType { sub_type! { VariableGlobalType { - Scalar(SizedScalarType), - Vector(SizedScalarType, u8), - Array(SizedScalarType, VecU32), - Pointer(SizedScalarType, PointerStateSpace), + Scalar(ScalarType), + Vector(ScalarType, u8), + Array(ScalarType, VecU32), + Pointer(ScalarType, LdStateSpace), } } @@ -167,49 +130,12 @@ sub_type! { // .param .b32 foobar[] sub_type! { VariableParamType { - Scalar(LdStScalarType), - Array(SizedScalarType, VecU32), - Pointer(SizedScalarType, PointerStateSpace), + Scalar(ScalarType), + Array(ScalarType, VecU32), + Pointer(ScalarType, LdStateSpace), } } -sub_enum!(SizedScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F16x2, - F32, - F64, -}); - -sub_enum!(LdStScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F32, - F64, -}); - #[derive(Copy, Clone, Eq, PartialEq)] pub enum BarDetails { SyncAligned, @@ -345,16 +271,6 @@ impl FnArgumentType { } } -sub_enum!( - PointerStateSpace : LdStateSpace { - Generic, - Global, - Const, - Shared, - Param, - } -); - #[derive(PartialEq, Eq, Clone)] pub enum Type { Scalar(ScalarType), @@ -371,18 +287,18 @@ pub enum PointerType { Pointer(ScalarType, LdStateSpace), } -impl From for PointerType { - fn from(t: SizedScalarType) -> Self { +impl From for PointerType { + fn from(t: ScalarType) -> Self { PointerType::Scalar(t.into()) } } -impl TryFrom for SizedScalarType { +impl TryFrom for ScalarType { type Error = (); fn try_from(value: PointerType) -> Result { match value { - PointerType::Scalar(t) => Ok(t.try_into()?), + PointerType::Scalar(t) => Ok(t), PointerType::Vector(_, _) => Err(()), PointerType::Array(_, _) => Err(()), PointerType::Pointer(_, _) => Err(()), @@ -685,8 +601,8 @@ pub struct LdDetails { sub_type! { LdStType { - Scalar(LdStScalarType), - Vector(LdStScalarType, u8), + Scalar(ScalarType), + Vector(ScalarType, u8), // Used in generated code Pointer(PointerType, LdStateSpace), } @@ -1135,7 +1051,7 @@ pub struct NegDetails { } impl<'a> NumsOrArrays<'a> { - pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result, PtxError> { + pub fn to_vec(self, typ: ScalarType, dimensions: &mut [u32]) -> Result, PtxError> { self.normalize_dimensions(dimensions)?; let sizeof_t = ScalarType::from(typ).size_of() as usize; let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize)); @@ -1166,7 +1082,7 @@ impl<'a> NumsOrArrays<'a> { fn parse_and_copy( &self, - t: SizedScalarType, + t: ScalarType, size_of_t: usize, dimensions: &[u32], result: &mut [u8], @@ -1206,47 +1122,48 @@ impl<'a> NumsOrArrays<'a> { } fn parse_and_copy_single( - t: SizedScalarType, + t: ScalarType, idx: usize, str_val: &str, radix: u32, output: &mut [u8], ) -> Result<(), PtxError> { match t { - SizedScalarType::B8 | SizedScalarType::U8 => { + ScalarType::B8 | ScalarType::U8 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::B16 | SizedScalarType::U16 => { + ScalarType::B16 | ScalarType::U16 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::B32 | SizedScalarType::U32 => { + ScalarType::B32 | ScalarType::U32 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::B64 | SizedScalarType::U64 => { + ScalarType::B64 | ScalarType::U64 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::S8 => { + ScalarType::S8 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::S16 => { + ScalarType::S16 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::S32 => { + ScalarType::S32 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::S64 => { + ScalarType::S64 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::F16 => { + ScalarType::F16 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::F16x2 => todo!(), - SizedScalarType::F32 => { + ScalarType::F16x2 => todo!(), + ScalarType::F32 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::F64 => { + ScalarType::F64 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } + ScalarType::Pred => todo!() } Ok(()) } @@ -1334,13 +1251,13 @@ mod tests { #[test] fn array_fails_multiple_0_dmiensions() { let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0, 0]).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut vec![0, 0]).is_err()); } #[test] fn array_fails_on_empty() { let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0]).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut vec![0]).is_err()); } #[test] @@ -1352,7 +1269,7 @@ mod tests { let mut dimensions = vec![0u32, 2]; assert_eq!( vec![1u8, 2, 3, 4], - inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap() + inp.to_vec(ScalarType::B8, &mut dimensions).unwrap() ); assert_eq!(dimensions, vec![2u32, 2]); } @@ -1364,7 +1281,7 @@ mod tests { NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]), ]); let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); } #[test] @@ -1374,6 +1291,6 @@ mod tests { NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]), ]); let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); } } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 41c1d73..7bd9c4f 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -611,9 +611,9 @@ ModuleVariable: ast::Variable = { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Global)), Vec::new()) + (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Global)), Vec::new()) } else { - (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Shared)), Vec::new()) + (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Shared)), Vec::new()) } } }; @@ -635,7 +635,7 @@ ParamVariable: (Option, Vec, ast::VariableParamType, &'input str) = { (ast::VariableParamType::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { - (ast::VariableParamType::Pointer(t, ast::PointerStateSpace::Param), Vec::new()) + (ast::VariableParamType::Pointer(t, ast::LdStateSpace::Param), Vec::new()) } }; (align, array_init, v_type, name) @@ -667,42 +667,42 @@ GlobalVariableDefinitionNoArray: (Option, ast::VariableGlobalType, &'input } #[inline] -SizedScalarType: ast::SizedScalarType = { - ".b8" => ast::SizedScalarType::B8, - ".b16" => ast::SizedScalarType::B16, - ".b32" => ast::SizedScalarType::B32, - ".b64" => ast::SizedScalarType::B64, - ".u8" => ast::SizedScalarType::U8, - ".u16" => ast::SizedScalarType::U16, - ".u32" => ast::SizedScalarType::U32, - ".u64" => ast::SizedScalarType::U64, - ".s8" => ast::SizedScalarType::S8, - ".s16" => ast::SizedScalarType::S16, - ".s32" => ast::SizedScalarType::S32, - ".s64" => ast::SizedScalarType::S64, - ".f16" => ast::SizedScalarType::F16, - ".f16x2" => ast::SizedScalarType::F16x2, - ".f32" => ast::SizedScalarType::F32, - ".f64" => ast::SizedScalarType::F64, +SizedScalarType: ast::ScalarType = { + ".b8" => ast::ScalarType::B8, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f16" => ast::ScalarType::F16, + ".f16x2" => ast::ScalarType::F16x2, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, } #[inline] -LdStScalarType: ast::LdStScalarType = { - ".b8" => ast::LdStScalarType::B8, - ".b16" => ast::LdStScalarType::B16, - ".b32" => ast::LdStScalarType::B32, - ".b64" => ast::LdStScalarType::B64, - ".u8" => ast::LdStScalarType::U8, - ".u16" => ast::LdStScalarType::U16, - ".u32" => ast::LdStScalarType::U32, - ".u64" => ast::LdStScalarType::U64, - ".s8" => ast::LdStScalarType::S8, - ".s16" => ast::LdStScalarType::S16, - ".s32" => ast::LdStScalarType::S32, - ".s64" => ast::LdStScalarType::S64, - ".f16" => ast::LdStScalarType::F16, - ".f32" => ast::LdStScalarType::F32, - ".f64" => ast::LdStScalarType::F64, +LdStScalarType: ast::ScalarType = { + ".b8" => ast::ScalarType::B8, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f16" => ast::ScalarType::F16, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, } Instruction: ast::Instruction> = { diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 51b1dc6..7eec085 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -97,18 +97,6 @@ impl ast::Type { } } -impl Into for ast::PointerStateSpace { - fn into(self) -> spirv::StorageClass { - match self { - ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant, - ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup, - ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup, - ast::PointerStateSpace::Param => spirv::StorageClass::Function, - ast::PointerStateSpace::Generic => spirv::StorageClass::Generic, - } - } -} - impl From for SpirvType { fn from(t: ast::ScalarType) -> Self { SpirvType::Base(t.into()) @@ -824,8 +812,8 @@ fn convert_dynamic_shared_memory_usage<'input>( name: shared_var_id, array_init: Vec::new(), v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( - ast::SizedScalarType::B8, - ast::PointerStateSpace::Shared, + ast::ScalarType::B8, + ast::LdStateSpace::Shared, )), }); let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails { @@ -863,7 +851,7 @@ fn convert_dynamic_shared_memory_usage<'input>( fn replace_uses_of_shared_memory<'a>( result: &mut Vec, new_id: &mut impl FnMut() -> spirv::Word, - extern_shared_decls: &HashMap, + extern_shared_decls: &HashMap, methods_using_extern_shared: &mut HashSet>, shared_id_param: spirv::Word, shared_var_id: spirv::Word, @@ -884,7 +872,7 @@ fn replace_uses_of_shared_memory<'a>( statement => { let new_statement = statement.map_id(&mut |id, _| { if let Some(typ) = extern_shared_decls.get(&id) { - if *typ == ast::SizedScalarType::B8 { + if *typ == ast::ScalarType::B8 { return shared_var_id; } let replacement_id = new_id(); @@ -1505,7 +1493,7 @@ fn extract_globals<'input, 'b>( d, a, "inc", - ast::SizedScalarType::U32, + ast::ScalarType::U32, )); } Statement::Instruction(ast::Instruction::Atom( @@ -1527,7 +1515,7 @@ fn extract_globals<'input, 'b>( d, a, "dec", - ast::SizedScalarType::U32, + ast::ScalarType::U32, )); } Statement::Instruction(ast::Instruction::Atom( @@ -1553,8 +1541,8 @@ fn extract_globals<'input, 'b>( space, }; let (op, typ) = match typ { - ast::ScalarType::F32 => ("add_f32", ast::SizedScalarType::F32), - ast::ScalarType::F64 => ("add_f64", ast::SizedScalarType::F64), + ast::ScalarType::F32 => ("add_f32", ast::ScalarType::F32), + ast::ScalarType::F64 => ("add_f64", ast::ScalarType::F64), _ => unreachable!(), }; local.push(to_ptx_impl_atomic_call( @@ -1734,7 +1722,7 @@ fn to_ptx_impl_atomic_call( details: ast::AtomDetails, arg: ast::Arg3, op: &'static str, - typ: ast::SizedScalarType, + typ: ast::ScalarType, ) -> ExpandedStatement { let semantics = ptx_semantics_name(details.semantics); let scope = ptx_scope_name(details.scope); @@ -1745,9 +1733,9 @@ fn to_ptx_impl_atomic_call( ); // TODO: extract to a function let ptr_space = match details.space { - ast::AtomSpace::Generic => ast::PointerStateSpace::Generic, - ast::AtomSpace::Global => ast::PointerStateSpace::Global, - ast::AtomSpace::Shared => ast::PointerStateSpace::Shared, + ast::AtomSpace::Generic => ast::LdStateSpace::Generic, + ast::AtomSpace::Global => ast::LdStateSpace::Global, + ast::AtomSpace::Shared => ast::LdStateSpace::Shared, }; let scalar_typ = ast::ScalarType::from(typ); let fn_id = match ptx_impl_imports.entry(fn_name) { @@ -4565,7 +4553,7 @@ fn convert_to_stateful_memory_access<'a>( Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { state_space: ast::LdStateSpace::Param, - typ: ast::LdStType::Scalar(ast::LdStScalarType::U64), + typ: ast::LdStType::Scalar(ast::ScalarType::U64), .. }, arg, @@ -4573,7 +4561,7 @@ fn convert_to_stateful_memory_access<'a>( | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { state_space: ast::LdStateSpace::Param, - typ: ast::LdStType::Scalar(ast::LdStScalarType::S64), + typ: ast::LdStType::Scalar(ast::ScalarType::S64), .. }, arg, @@ -4581,7 +4569,7 @@ fn convert_to_stateful_memory_access<'a>( | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { state_space: ast::LdStateSpace::Param, - typ: ast::LdStType::Scalar(ast::LdStScalarType::B64), + typ: ast::LdStType::Scalar(ast::ScalarType::B64), .. }, arg, @@ -4672,8 +4660,8 @@ fn convert_to_stateful_memory_access<'a>( name: new_id, array_init: Vec::new(), v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( - ast::SizedScalarType::U8, - ast::PointerStateSpace::Global, + ast::ScalarType::U8, + ast::LdStateSpace::Global, )), })); remapped_ids.insert(reg, new_id); From 8cd3db66485175b7589f7d3828a147fafe5ecaed Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 15 Apr 2021 19:53:54 +0200 Subject: [PATCH 03/25] Remove LdStType --- ptx/src/ast.rs | 26 ++------------------------ ptx/src/ptx.lalrpop | 6 +++--- ptx/src/translate.rs | 6 +++--- 3 files changed, 8 insertions(+), 30 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 3a7cf98..6a01a6a 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -595,32 +595,10 @@ pub struct LdDetails { pub qualifier: LdStQualifier, pub state_space: LdStateSpace, pub caching: LdCacheOperator, - pub typ: LdStType, + pub typ: PointerType, pub non_coherent: bool, } -sub_type! { - LdStType { - Scalar(ScalarType), - Vector(ScalarType, u8), - // Used in generated code - Pointer(PointerType, LdStateSpace), - } -} - -impl From for PointerType { - fn from(t: LdStType) -> Self { - match t { - LdStType::Scalar(t) => PointerType::Scalar(t.into()), - LdStType::Vector(t, len) => PointerType::Vector(t.into(), len), - LdStType::Pointer(PointerType::Scalar(scalar_type), space) => { - PointerType::Pointer(scalar_type, space) - } - LdStType::Pointer(..) => unreachable!(), - } - } -} - #[derive(Copy, Clone, PartialEq, Eq)] pub enum LdStQualifier { Weak, @@ -853,7 +831,7 @@ pub struct StData { pub qualifier: LdStQualifier, pub state_space: StStateSpace, pub caching: StCacheOperator, - pub typ: LdStType, + pub typ: PointerType, } #[derive(PartialEq, Eq, Copy, Clone)] diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 7bd9c4f..44852a2 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -789,9 +789,9 @@ InstLd: ast::Instruction> = { } }; -LdStType: ast::LdStType = { - => ast::LdStType::Vector(t, v), - => ast::LdStType::Scalar(t), +LdStType: ast::PointerType = { + => ast::PointerType::Vector(t, v), + => ast::PointerType::Scalar(t), } LdStQualifier: ast::LdStQualifier = { diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7eec085..1f647bd 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -4553,7 +4553,7 @@ fn convert_to_stateful_memory_access<'a>( Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { state_space: ast::LdStateSpace::Param, - typ: ast::LdStType::Scalar(ast::ScalarType::U64), + typ: ast::PointerType::Scalar(ast::ScalarType::U64), .. }, arg, @@ -4561,7 +4561,7 @@ fn convert_to_stateful_memory_access<'a>( | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { state_space: ast::LdStateSpace::Param, - typ: ast::LdStType::Scalar(ast::ScalarType::S64), + typ: ast::PointerType::Scalar(ast::ScalarType::S64), .. }, arg, @@ -4569,7 +4569,7 @@ fn convert_to_stateful_memory_access<'a>( | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { state_space: ast::LdStateSpace::Param, - typ: ast::LdStType::Scalar(ast::ScalarType::B64), + typ: ast::PointerType::Scalar(ast::ScalarType::B64), .. }, arg, From a55c851eaa4ded60d5f62aba1d7da850a63163f3 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 15 Apr 2021 20:01:01 +0200 Subject: [PATCH 04/25] Add comment --- ptx/src/ast.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 6a01a6a..3e62cb1 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -284,6 +284,7 @@ pub enum PointerType { Scalar(ScalarType), Vector(ScalarType, u8), Array(ScalarType, VecU32), + // Instances of this variant are generated during stateful conversion Pointer(ScalarType, LdStateSpace), } @@ -1141,7 +1142,7 @@ impl<'a> NumsOrArrays<'a> { ScalarType::F64 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - ScalarType::Pred => todo!() + ScalarType::Pred => todo!(), } Ok(()) } From d51aaaf5529dbfec0735c73768e468728112c26b Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 17 Apr 2021 14:01:50 +0200 Subject: [PATCH 05/25] Throw away special variable types --- ptx/src/ast.rs | 215 +--------------------- ptx/src/ptx.lalrpop | 102 +++++----- ptx/src/translate.rs | 429 +++++++++++++++++++------------------------ 3 files changed, 256 insertions(+), 490 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 3e62cb1..c7b9563 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,6 +1,5 @@ use half::f16; use lalrpop_util::{lexer::Token, ParseError}; -use std::convert::TryInto; use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; @@ -34,107 +33,12 @@ pub enum PtxError { NonExternPointer, } -macro_rules! sub_type { - ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { - sub_type! { $type_name : Type { - $( - $variant ($($field_type),+), - )+ - }} - }; - ($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { - #[derive(PartialEq, Eq, Clone)] - pub enum $type_name { - $( - $variant ($($field_type),+), - )+ - } - - impl From<$type_name> for $base_type { - #[allow(non_snake_case)] - fn from(t: $type_name) -> $base_type { - match t { - $( - $type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+), - )+ - } - } - } - - impl std::convert::TryFrom<$base_type> for $type_name { - type Error = (); - - #[allow(non_snake_case)] - #[allow(unreachable_patterns)] - fn try_from(t: $base_type) -> Result { - match t { - $( - $base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), - )+ - _ => Err(()), - } - } - } - }; -} - -sub_type! { - VariableRegType { - Scalar(ScalarType), - Vector(ScalarType, u8), - // Array type is used when emiting SSA statements at the start of a method - Array(ScalarType, VecU32), - // Pointer variant is used when passing around SLM pointer between - // function calls for dynamic SLM - Pointer(ScalarType, LdStateSpace) - } -} - -type VecU32 = Vec; - -sub_type! { - VariableLocalType { - Scalar(ScalarType), - Vector(ScalarType, u8), - Array(ScalarType, VecU32), - } -} - -impl TryFrom for VariableLocalType { - type Error = PtxError; - - fn try_from(value: VariableGlobalType) -> Result { - match value { - VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)), - VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)), - VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)), - VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray), - } - } -} - -sub_type! { - VariableGlobalType { - Scalar(ScalarType), - Vector(ScalarType, u8), - Array(ScalarType, VecU32), - Pointer(ScalarType, LdStateSpace), - } -} - // For some weird reson this is illegal: // .param .f16x2 foobar; // but this is legal: // .param .f16x2 foobar[1]; // even more interestingly this is legal, but only in .func (not in .entry): // .param .b32 foobar[] -sub_type! { - VariableParamType { - Scalar(ScalarType), - Array(ScalarType, VecU32), - Pointer(ScalarType, LdStateSpace), - } -} #[derive(Copy, Clone, Eq, PartialEq)] pub enum BarDetails { @@ -178,7 +82,7 @@ pub struct Module<'a> { } pub enum Directive<'a, P: ArgParams> { - Variable(Variable), + Variable(Variable), Method(Function<'a, &'a str, Statement

>), } @@ -190,8 +94,8 @@ pub enum MethodDecl<'a, ID> { }, } -pub type FnArgument = Variable; -pub type KernelArgument = Variable; +pub type FnArgument = Variable; +pub type KernelArgument = Variable; pub struct Function<'a, ID, S> { pub func_directive: MethodDecl<'a, ID>, @@ -201,76 +105,6 @@ pub struct Function<'a, ID, S> { pub type ParsedFunction<'a> = Function<'a, &'a str, Statement>>; -#[derive(PartialEq, Eq, Clone)] -pub enum FnArgumentType { - Reg(VariableRegType), - Param(VariableParamType), - Shared, -} -#[derive(PartialEq, Eq, Clone)] -pub enum KernelArgumentType { - Normal(VariableParamType), - Shared, -} - -impl From for Type { - fn from(this: KernelArgumentType) -> Self { - match this { - KernelArgumentType::Normal(typ) => typ.into(), - KernelArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } -} - -impl FnArgumentType { - pub fn to_type(&self, is_kernel: bool) -> Type { - if is_kernel { - self.to_kernel_type() - } else { - self.to_func_type() - } - } - - pub fn to_kernel_type(&self) -> Type { - match self { - FnArgumentType::Reg(x) => x.clone().into(), - FnArgumentType::Param(x) => x.clone().into(), - FnArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } - - pub fn to_func_type(&self) -> Type { - match self { - FnArgumentType::Reg(x) => x.clone().into(), - FnArgumentType::Param(VariableParamType::Scalar(t)) => { - Type::Pointer(PointerType::Scalar((*t).into()), LdStateSpace::Param) - } - FnArgumentType::Param(VariableParamType::Array(t, dims)) => Type::Pointer( - PointerType::Array((*t).into(), dims.clone()), - LdStateSpace::Param, - ), - FnArgumentType::Param(VariableParamType::Pointer(t, space)) => Type::Pointer( - PointerType::Pointer((*t).into(), (*space).into()), - LdStateSpace::Param, - ), - FnArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } - - pub fn is_param(&self) -> bool { - match self { - FnArgumentType::Param(_) => true, - _ => false, - } - } -} - #[derive(PartialEq, Eq, Clone)] pub enum Type { Scalar(ScalarType), @@ -283,7 +117,7 @@ pub enum Type { pub enum PointerType { Scalar(ScalarType), Vector(ScalarType, u8), - Array(ScalarType, VecU32), + Array(ScalarType, Vec), // Instances of this variant are generated during stateful conversion Pointer(ScalarType, LdStateSpace), } @@ -366,51 +200,19 @@ pub enum Statement { } pub struct MultiVariable { - pub var: Variable, + pub var: Variable, pub count: Option, } #[derive(Clone)] -pub struct Variable { +pub struct Variable { pub align: Option, - pub v_type: T, + pub v_type: Type, + pub state_space: StateSpace, pub name: ID, pub array_init: Vec, } -#[derive(Eq, PartialEq, Clone)] -pub enum VariableType { - Reg(VariableRegType), - Local(VariableLocalType), - Param(VariableParamType), - Global(VariableGlobalType), - Shared(VariableGlobalType), -} - -impl VariableType { - pub fn to_type(&self) -> (StateSpace, Type) { - match self { - VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()), - VariableType::Local(t) => (StateSpace::Local, t.clone().into()), - VariableType::Param(t) => (StateSpace::Param, t.clone().into()), - VariableType::Global(t) => (StateSpace::Global, t.clone().into()), - VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()), - } - } -} - -impl From for Type { - fn from(t: VariableType) -> Self { - match t { - VariableType::Reg(t) => t.into(), - VariableType::Local(t) => t.into(), - VariableType::Param(t) => t.into(), - VariableType::Global(t) => t.into(), - VariableType::Shared(t) => t.into(), - } - } -} - #[derive(Copy, Clone, PartialEq, Eq)] pub enum StateSpace { Reg, @@ -419,6 +221,7 @@ pub enum StateSpace { Local, Shared, Param, + Generic, } pub struct PredAt { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 44852a2..dc439b7 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -404,28 +404,29 @@ FnArguments: Vec> = { "(" > ")" => args }; -KernelInput: ast::Variable = { +KernelInput: ast::Variable<&'input str> = { => { let (align, v_type, name) = v; ast::Variable { align, - v_type: ast::KernelArgumentType::Normal(v_type), + v_type, + state_space: ast::StateSpace::Param, name, array_init: Vec::new() } } } -FnInput: ast::Variable = { +FnInput: ast::Variable<&'input str> = { => { let (align, v_type, name) = v; - let v_type = ast::FnArgumentType::Reg(v_type); - ast::Variable{ align, v_type, name, array_init: Vec::new() } + let state_space = ast::StateSpace::Reg; + ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() } }, => { let (align, v_type, name) = v; - let v_type = ast::FnArgumentType::Param(v_type); - ast::Variable{ align, v_type, name, array_init: Vec::new() } + let state_space = ast::StateSpace::Param; + ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() } } } @@ -508,102 +509,109 @@ VariableParam: u32 = { "<" ">" => n } -Variable: ast::Variable = { +Variable: ast::Variable<&'input str> = { => { let (align, v_type, name) = v; - let v_type = ast::VariableType::Reg(v_type); - ast::Variable {align, v_type, name, array_init: Vec::new()} + let state_space = ast::StateSpace::Reg; + ast::Variable {align, v_type, state_space, name, array_init: Vec::new()} }, LocalVariable, => { let (align, array_init, v_type, name) = v; - let v_type = ast::VariableType::Param(v_type); - ast::Variable {align, v_type, name, array_init} + let state_space = ast::StateSpace::Param; + ast::Variable {align, v_type, state_space, name, array_init} }, SharedVariable, }; -RegVariable: (Option, ast::VariableRegType, &'input str) = { +RegVariable: (Option, ast::Type, &'input str) = { ".reg" > => { let (align, t, name) = var; - let v_type = ast::VariableRegType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, v_type, name) }, ".reg" > => { let (align, v_len, t, name) = var; - let v_type = ast::VariableRegType::Vector(t, v_len); + let v_type = ast::Type::Vector(t, v_len); (align, v_type, name) } } -LocalVariable: ast::Variable = { +LocalVariable: ast::Variable<&'input str> = { ".local" > => { let (align, t, name) = var; - let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t)); - ast::Variable { align, v_type, name, array_init: Vec::new() } + let v_type = ast::Type::Scalar(t); + let state_space = ast::StateSpace::Local; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".local" > => { let (align, v_len, t, name) = var; - let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len)); - ast::Variable { align, v_type, name, array_init: Vec::new() } + let v_type = ast::Type::Vector(t, v_len); + let state_space = ast::StateSpace::Local; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".local" > =>? { let (align, t, name, arr_or_ptr) = var; + let state_space = ast::StateSpace::Local; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableLocalType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); } }; - Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init }) + Ok(ast::Variable { align, v_type, state_space, name, array_init }) } } -SharedVariable: ast::Variable = { +SharedVariable: ast::Variable<&'input str> = { ".shared" > => { let (align, t, name) = var; - let v_type = ast::VariableGlobalType::Scalar(t); - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + let v_type = ast::Type::Scalar(t); + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".shared" > => { let (align, v_len, t, name) = var; - let v_type = ast::VariableGlobalType::Vector(t, v_len); - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + let v_type = ast::Type::Vector(t, v_len); + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".shared" > =>? { let (align, t, name, arr_or_ptr) = var; + let state_space = ast::StateSpace::Shared; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableGlobalType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); } }; - Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init }) + Ok(ast::Variable { align, v_type, state_space, name, array_init }) } } - -ModuleVariable: ast::Variable = { +ModuleVariable: ast::Variable<&'input str> = { LinkingDirectives ".global" => { let (align, v_type, name, array_init) = def; - ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init } + let state_space = ast::StateSpace::Global; + ast::Variable { align, v_type, state_space, name, array_init } }, LinkingDirectives ".shared" => { let (align, v_type, name, array_init) = def; - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, > > =>? { let (align, t, name, arr_or_ptr) = var; - let (v_type, array_init) = match arr_or_ptr { + let (v_type, state_space, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { if space == ".global" { - (ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init) + (ast::Type::Array(t, dimensions), ast::StateSpace::Global, init) } else { - (ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init) + (ast::Type::Array(t, dimensions), ast::StateSpace::Shared, init) } } ast::ArrayOrPointer::Pointer => { @@ -611,38 +619,38 @@ ModuleVariable: ast::Variable = { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Global)), Vec::new()) + (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Global), ast::StateSpace::Global, Vec::new()) } else { - (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Shared)), Vec::new()) + (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Shared), ast::StateSpace::Shared, Vec::new()) } } }; - Ok(ast::Variable{ align, array_init, v_type, name }) + Ok(ast::Variable{ align, v_type, state_space, name, array_init }) } } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space -ParamVariable: (Option, Vec, ast::VariableParamType, &'input str) = { +ParamVariable: (Option, Vec, ast::Type, &'input str) = { ".param" > => { let (align, t, name) = var; - let v_type = ast::VariableParamType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, Vec::new(), v_type, name) }, ".param" > => { let (align, t, name, arr_or_ptr) = var; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableParamType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { - (ast::VariableParamType::Pointer(t, ast::LdStateSpace::Param), Vec::new()) + (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Param), Vec::new()) } }; (align, array_init, v_type, name) } } -ParamDeclaration: (Option, ast::VariableParamType, &'input str) = { +ParamDeclaration: (Option, ast::Type, &'input str) = { =>? { let (align, array_init, v_type, name) = var; if array_init.len() > 0 { @@ -653,15 +661,15 @@ ParamDeclaration: (Option, ast::VariableParamType, &'input str) = { } } -GlobalVariableDefinitionNoArray: (Option, ast::VariableGlobalType, &'input str, Vec) = { +GlobalVariableDefinitionNoArray: (Option, ast::Type, &'input str, Vec) = { > => { let (align, t, name) = scalar; - let v_type = ast::VariableGlobalType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, v_type, name, Vec::new()) }, > => { let (align, v_len, t, name) = var; - let v_type = ast::VariableGlobalType::Vector(t, v_len); + let v_type = ast::Type::Vector(t, v_len); (align, v_type, name, Vec::new()) }, } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 1f647bd..4ba5729 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -714,12 +714,13 @@ fn convert_dynamic_shared_memory_usage<'input>( let mut extern_shared_decls = HashMap::new(); for dir in module.iter() { match dir { - Directive::Variable(var) => { - if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) = - var.v_type - { - extern_shared_decls.insert(var.name, p_type); - } + Directive::Variable(ast::Variable { + v_type: ast::Type::Pointer(p_type, ast::LdStateSpace::Shared), + state_space: ast::StateSpace::Shared, + name, + .. + }) => { + extern_shared_decls.insert(*name, p_type.clone()); } _ => {} } @@ -796,25 +797,27 @@ fn convert_dynamic_shared_memory_usage<'input>( let shared_id_param = new_id(); spirv_decl.input.push({ ast::Variable { + name: shared_id_param, align: None, v_type: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), + ast::PointerType::Scalar(ast::ScalarType::B8), ast::LdStateSpace::Shared, ), + state_space: ast::StateSpace::Param, array_init: Vec::new(), - name: shared_id_param, } }); spirv_decl.uses_shared_mem = true; let shared_var_id = new_id(); let shared_var = ExpandedStatement::Variable(ast::Variable { - align: None, name: shared_var_id, - array_init: Vec::new(), - v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( - ast::ScalarType::B8, + align: None, + v_type: ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::B8), ast::LdStateSpace::Shared, - )), + ), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), }); let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails { arg: ast::Arg2St { @@ -851,7 +854,7 @@ fn convert_dynamic_shared_memory_usage<'input>( fn replace_uses_of_shared_memory<'a>( result: &mut Vec, new_id: &mut impl FnMut() -> spirv::Word, - extern_shared_decls: &HashMap, + extern_shared_decls: &HashMap, methods_using_extern_shared: &mut HashSet>, shared_id_param: spirv::Word, shared_var_id: spirv::Word, @@ -864,14 +867,17 @@ fn replace_uses_of_shared_memory<'a>( // because there's simply no way to pass shared ptr // without converting it to .b64 first if methods_using_extern_shared.contains(&MethodName::Func(call.func)) { - call.param_list - .push((shared_id_param, ast::FnArgumentType::Shared)); + call.param_list.push(( + shared_id_param, + ast::Type::Scalar(ast::ScalarType::B8), + ast::StateSpace::Shared, + )); } result.push(Statement::Call(call)) } statement => { let new_statement = statement.map_id(&mut |id, _| { - if let Some(typ) = extern_shared_decls.get(&id) { + if let Some(ast::PointerType::Scalar(typ)) = extern_shared_decls.get(&id) { if *typ == ast::ScalarType::B8 { return shared_var_id; } @@ -1067,7 +1073,7 @@ fn emit_function_header<'a>( builder: &mut dr::Builder, map: &mut TypeWordMap, defined_globals: &GlobalStringIdResolver<'a>, - synthetic_globals: &[ast::Variable], + synthetic_globals: &[ast::Variable], func_decl: &SpirvMethodDecl<'a>, _denorm_information: &HashMap, HashMap>, call_map: &HashMap<&'a str, HashSet>, @@ -1204,9 +1210,9 @@ fn translate_directive<'input>( fn translate_variable<'a>( id_defs: &mut GlobalStringIdResolver<'a>, - var: ast::Variable, -) -> Result, TranslateError> { - let (space, var_type) = var.v_type.to_type(); + var: ast::Variable<&'a str>, +) -> Result, TranslateError> { + let (space, var_type) = (var.state_space, var.v_type.clone()); let mut is_variable = false; let var_type = match space { ast::StateSpace::Reg => { @@ -1226,10 +1232,12 @@ fn translate_variable<'a>( } } ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?, + ast::StateSpace::Generic => todo!(), }; Ok(ast::Variable { align: var.align, v_type: var.v_type, + state_space: var.state_space, name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable), array_init: var.array_init, }) @@ -1279,6 +1287,7 @@ fn expand_kernel_params<'a, 'b>( false, ), v_type: a.v_type.clone(), + state_space: a.state_space, align: a.align, array_init: Vec::new(), }) @@ -1291,14 +1300,11 @@ fn expand_fn_params<'a, 'b>( args: impl Iterator>, ) -> Result>, TranslateError> { args.map(|a| { - let is_variable = match a.v_type { - ast::FnArgumentType::Reg(_) => true, - _ => false, - }; - let var_type = a.v_type.to_func_type(); + let is_variable = a.state_space == ast::StateSpace::Reg; Ok(ast::FnArgument { - name: fn_resolver.add_def(a.name, Some(var_type), is_variable), + name: fn_resolver.add_def(a.name, Some(a.v_type.clone()), is_variable), v_type: a.v_type.clone(), + state_space: a.state_space, align: a.align, array_init: Vec::new(), }) @@ -1444,10 +1450,7 @@ fn extract_globals<'input, 'b>( sorted_statements: Vec, ptx_impl_imports: &mut HashMap, id_def: &mut NumericIdResolver, -) -> ( - Vec, - Vec>, -) { +) -> (Vec, Vec>) { let mut local = Vec::with_capacity(sorted_statements.len()); let mut global = Vec::new(); for statement in sorted_statements { @@ -1456,7 +1459,7 @@ fn extract_globals<'input, 'b>( var @ ast::Variable { - v_type: ast::VariableType::Shared(_), + state_space: ast::StateSpace::Shared, .. }, ) @@ -1464,7 +1467,7 @@ fn extract_globals<'input, 'b>( var @ ast::Variable { - v_type: ast::VariableType::Global(_), + state_space: ast::StateSpace::Global, .. }, ) => global.push(var), @@ -1592,10 +1595,10 @@ fn convert_to_typed_statements( let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params); let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args .into_iter() - .partition(|(_, arg_type)| arg_type.is_param()); + .partition(|(_, _, space)| *space == ast::StateSpace::Param); let normalized_input_args = out_params .into_iter() - .map(|(id, typ)| (ast::Operand::Reg(id), typ)) + .map(|(id, typ, space)| (ast::Operand::Reg(id), typ, space)) .chain(in_args.into_iter()) .collect(); let resolved_call = ResolvedCall { @@ -1744,7 +1747,8 @@ fn to_ptx_impl_atomic_call( let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), + v_type: ast::Type::Scalar(scalar_typ), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }], @@ -1752,15 +1756,15 @@ fn to_ptx_impl_atomic_call( vec![ ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer( - typ, ptr_space, - )), + v_type: ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), + v_type: ast::Type::Scalar(scalar_typ), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, @@ -1789,18 +1793,17 @@ fn to_ptx_impl_atomic_call( Statement::Call(ResolvedCall { uniform: false, func: fn_id, - ret_params: vec![( - arg.dst, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), - )], + ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], param_list: vec![ ( arg.src1, - ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(typ, ptr_space)), + ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space), + ast::StateSpace::Reg, ), ( arg.src2, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), + ast::Type::Scalar(scalar_typ), + ast::StateSpace::Reg, ), ], }) @@ -1827,7 +1830,8 @@ fn to_ptx_impl_bfe_call( let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }], @@ -1835,23 +1839,22 @@ fn to_ptx_impl_bfe_call( vec![ ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( - ast::ScalarType::U32, - )), + v_type: ast::Type::Scalar(ast::ScalarType::U32), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( - ast::ScalarType::U32, - )), + v_type: ast::Type::Scalar(ast::ScalarType::U32), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, @@ -1880,22 +1883,22 @@ fn to_ptx_impl_bfe_call( Statement::Call(ResolvedCall { uniform: false, func: fn_id, - ret_params: vec![( - arg.dst, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - )], + ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], param_list: vec![ ( arg.src1, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + ast::Type::Scalar(typ.into()), + ast::StateSpace::Reg, ), ( arg.src2, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, ), ( arg.src3, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, ), ], }) @@ -1920,7 +1923,8 @@ fn to_ptx_impl_bfi_call( let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }], @@ -1928,29 +1932,29 @@ fn to_ptx_impl_bfi_call( vec![ ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( - ast::ScalarType::U32, - )), + v_type: ast::Type::Scalar(ast::ScalarType::U32), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( - ast::ScalarType::U32, - )), + v_type: ast::Type::Scalar(ast::ScalarType::U32), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, @@ -1979,26 +1983,27 @@ fn to_ptx_impl_bfi_call( Statement::Call(ResolvedCall { uniform: false, func: fn_id, - ret_params: vec![( - arg.dst, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - )], + ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], param_list: vec![ ( arg.src1, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + ast::Type::Scalar(typ.into()), + ast::StateSpace::Reg, ), ( arg.src2, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + ast::Type::Scalar(typ.into()), + ast::StateSpace::Reg, ), ( arg.src3, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, ), ( arg.src4, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, ), ], }) @@ -2006,12 +2011,12 @@ fn to_ptx_impl_bfi_call( fn to_resolved_fn_args( params: Vec, - params_decl: &[ast::FnArgumentType], -) -> Vec<(T, ast::FnArgumentType)> { + params_decl: &[(ast::Type, ast::StateSpace)], +) -> Vec<(T, ast::Type, ast::StateSpace)> { params .into_iter() .zip(params_decl.iter()) - .map(|(id, typ)| (id, typ.clone())) + .map(|(id, (typ, space))| (id, typ.clone(), *space)) .collect::>() } @@ -2096,50 +2101,38 @@ fn normalize_predicates( fn insert_mem_ssa_statements<'a, 'b>( func: Vec, id_def: &mut NumericIdResolver, - ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>, + _: &'a ast::MethodDecl<'b, spirv::Word>, fn_decl: &mut SpirvMethodDecl, ) -> Result, TranslateError> { - let is_func = match ast_fn_decl { - ast::MethodDecl::Func(..) => true, - ast::MethodDecl::Kernel { .. } => false, - }; let mut result = Vec::with_capacity(func.len()); for arg in fn_decl.output.iter() { - match type_to_variable_type(&arg.v_type, is_func)? { - Some(var_type) => { - result.push(Statement::Variable(ast::Variable { - align: arg.align, - v_type: var_type, - name: arg.name, - array_init: arg.array_init.clone(), - })); - } - None => return Err(error_unreachable()), - } + result.push(Statement::Variable(ast::Variable { + align: arg.align, + v_type: arg.v_type.clone(), + state_space: arg.state_space, + name: arg.name, + array_init: arg.array_init.clone(), + })); } for spirv_arg in fn_decl.input.iter_mut() { - match type_to_variable_type(&spirv_arg.v_type, is_func)? { - Some(var_type) => { - let typ = spirv_arg.v_type.clone(); - let new_id = id_def.new_non_variable(Some(typ.clone())); - result.push(Statement::Variable(ast::Variable { - align: spirv_arg.align, - v_type: var_type, - name: spirv_arg.name, - array_init: spirv_arg.array_init.clone(), - })); - result.push(Statement::StoreVar(StoreVarDetails { - arg: ast::Arg2St { - src1: spirv_arg.name, - src2: new_id, - }, - typ, - member_index: None, - })); - spirv_arg.name = new_id; - } - None => {} - } + let typ = spirv_arg.v_type.clone(); + let new_id = id_def.new_non_variable(Some(typ.clone())); + result.push(Statement::Variable(ast::Variable { + align: spirv_arg.align, + v_type: spirv_arg.v_type.clone(), + state_space: spirv_arg.state_space, + name: spirv_arg.name, + array_init: spirv_arg.array_init.clone(), + })); + result.push(Statement::StoreVar(StoreVarDetails { + arg: ast::Arg2St { + src1: spirv_arg.name, + src2: new_id, + }, + typ, + member_index: None, + })); + spirv_arg.name = new_id; } for s in func { match s { @@ -2197,41 +2190,6 @@ fn insert_mem_ssa_statements<'a, 'b>( Ok(result) } -fn type_to_variable_type( - t: &ast::Type, - is_func: bool, -) -> Result, TranslateError> { - Ok(match t { - ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))), - ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector( - (*typ) - .try_into() - .map_err(|_| TranslateError::MismatchedType)?, - *len, - ))), - ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array( - (*typ) - .try_into() - .map_err(|_| TranslateError::MismatchedType)?, - len.clone(), - ))), - ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => { - if is_func { - return Ok(None); - } - Some(ast::VariableType::Reg(ast::VariableRegType::Pointer( - scalar_type - .clone() - .try_into() - .map_err(|_| error_unreachable())?, - (*space).try_into().map_err(|_| error_unreachable())?, - ))) - } - ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None, - _ => return Err(error_unreachable()), - }) -} - trait Visitable: Sized { fn visit( self, @@ -2398,11 +2356,13 @@ fn expand_arguments<'a, 'b>( Statement::Variable(ast::Variable { align, v_type, + state_space, name, array_init, }) => result.push(Statement::Variable(ast::Variable { align, v_type, + state_space, name, array_init, })), @@ -2784,8 +2744,8 @@ fn insert_implicit_conversions_impl( fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, - spirv_input: &[ast::Variable], - spirv_output: &[ast::Variable], + spirv_input: &[ast::Variable], + spirv_output: &[ast::Variable], ) -> (spirv::Word, spirv::Word) { map.get_or_add_fn( builder, @@ -2822,8 +2782,8 @@ fn emit_function_body_ops( Statement::Label(_) => (), Statement::Call(call) => { let (result_type, result_id) = match &*call.ret_params { - [(id, typ)] => ( - map.get_or_add(builder, SpirvType::from(typ.to_func_type())), + [(id, typ, _)] => ( + map.get_or_add(builder, SpirvType::from(typ.clone())), Some(*id), ), [] => (map.void(), None), @@ -2832,7 +2792,7 @@ fn emit_function_body_ops( let arg_list = call .param_list .iter() - .map(|(id, _)| *id) + .map(|(id, _, _)| *id) .collect::>(); builder.function_call(result_type, result_id, call.func, arg_list)?; } @@ -3602,14 +3562,16 @@ fn vec_repr(t: T) -> Vec { fn emit_variable( builder: &mut dr::Builder, map: &mut TypeWordMap, - var: &ast::Variable, + var: &ast::Variable, ) -> Result<(), TranslateError> { - let (must_init, st_class) = match var.v_type { - ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => { + let (must_init, st_class) = match var.state_space { + ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { (false, spirv::StorageClass::Function) } - ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup), - ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup), + ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup), + ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), + ast::StateSpace::Const => todo!(), + ast::StateSpace::Generic => todo!(), }; let initalizer = if var.array_init.len() > 0 { Some(map.get_or_add_constant( @@ -4460,12 +4422,12 @@ fn expand_map_variables<'a, 'b>( ast::Statement::Variable(var) => { let mut var_type = ast::Type::from(var.var.v_type.clone()); let mut is_variable = false; - var_type = match var.var.v_type { - ast::VariableType::Reg(_) => { + var_type = match var.var.state_space { + ast::StateSpace::Reg => { is_variable = true; var_type } - ast::VariableType::Shared(_) => { + ast::StateSpace::Shared => { // If it's a pointer it will be translated to a method parameter later if let ast::Type::Pointer(..) = var_type { is_variable = true; @@ -4474,15 +4436,11 @@ fn expand_map_variables<'a, 'b>( var_type.param_pointer_to(ast::LdStateSpace::Shared)? } } - ast::VariableType::Global(_) => { - var_type.param_pointer_to(ast::LdStateSpace::Global)? - } - ast::VariableType::Param(_) => { - var_type.param_pointer_to(ast::LdStateSpace::Param)? - } - ast::VariableType::Local(_) => { - var_type.param_pointer_to(ast::LdStateSpace::Local)? - } + ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?, + ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?, + ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?, + ast::StateSpace::Const => todo!(), + ast::StateSpace::Generic => todo!(), }; match var.count { Some(count) => { @@ -4490,6 +4448,7 @@ fn expand_map_variables<'a, 'b>( result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), + state_space: var.var.state_space, name: new_id, array_init: var.var.array_init.clone(), })) @@ -4500,6 +4459,7 @@ fn expand_map_variables<'a, 'b>( result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), + state_space: var.var.state_space, name: new_id, array_init: var.var.array_init, })); @@ -4659,10 +4619,11 @@ fn convert_to_stateful_memory_access<'a>( align: None, name: new_id, array_init: Vec::new(), - v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( - ast::ScalarType::U8, + v_type: ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::U8), ast::LdStateSpace::Global, - )), + ), + state_space: ast::StateSpace::Reg, })); remapped_ids.insert(reg, new_id); } @@ -5052,8 +5013,8 @@ struct GlobalStringIdResolver<'input> { } pub struct FnDecl { - ret_vals: Vec, - params: Vec, + ret_vals: Vec<(ast::Type, ast::StateSpace)>, + params: Vec<(ast::Type, ast::StateSpace)>, } impl<'a> GlobalStringIdResolver<'a> { @@ -5137,8 +5098,14 @@ impl<'a> GlobalStringIdResolver<'a> { self.fns.insert( name_id, FnDecl { - ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(), - params: params_ids.iter().map(|p| p.v_type.clone()).collect(), + ret_vals: ret_params_ids + .iter() + .map(|p| (p.v_type.clone(), p.state_space)) + .collect(), + params: params_ids + .iter() + .map(|p| (p.v_type.clone(), p.state_space)) + .collect(), }, ); ast::MethodDecl::Func(ret_params_ids, name_id, params_ids) @@ -5314,7 +5281,7 @@ impl<'b> MutableNumericIdResolver<'b> { enum Statement { Label(u32), - Variable(ast::Variable), + Variable(ast::Variable), Instruction(I), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), @@ -5352,16 +5319,17 @@ impl ExpandedStatement { Statement::StoreVar(details) } Statement::Call(mut call) => { - for (id, typ) in call.ret_params.iter_mut() { - let is_dst = match typ { - ast::FnArgumentType::Reg(_) => true, - ast::FnArgumentType::Param(_) => false, - ast::FnArgumentType::Shared => false, + for (id, _, space) in call.ret_params.iter_mut() { + let is_dst = match space { + ast::StateSpace::Reg => true, + ast::StateSpace::Param => false, + ast::StateSpace::Shared => false, + _ => todo!(), }; *id = f(*id, is_dst); } call.func = f(call.func, false); - for (id, _) in call.param_list.iter_mut() { + for (id, _, _) in call.param_list.iter_mut() { *id = f(*id, false); } Statement::Call(call) @@ -5502,9 +5470,9 @@ impl, U: ArgParamsEx> Visitab struct ResolvedCall { pub uniform: bool, - pub ret_params: Vec<(P::Id, ast::FnArgumentType)>, + pub ret_params: Vec<(P::Id, ast::Type, ast::StateSpace)>, pub func: P::Id, - pub param_list: Vec<(P::Operand, ast::FnArgumentType)>, + pub param_list: Vec<(P::Operand, ast::Type, ast::StateSpace)>, } impl ResolvedCall { @@ -5526,16 +5494,16 @@ impl> ResolvedCall { let ret_params = self .ret_params .into_iter() - .map::, _>(|(id, typ)| { + .map::, _>(|(id, typ, space)| { let new_id = visitor.id( ArgumentDescriptor { op: id, - is_dst: !typ.is_param(), - sema: typ.semantics(), + is_dst: space != ast::StateSpace::Param, + sema: space.semantics(), }, - Some(&typ.to_func_type()), + Some(&typ), )?; - Ok((new_id, typ)) + Ok((new_id, typ, space)) }) .collect::, _>>()?; let func = visitor.id( @@ -5549,16 +5517,16 @@ impl> ResolvedCall { let param_list = self .param_list .into_iter() - .map::, _>(|(id, typ)| { + .map::, _>(|(id, typ, space)| { let new_id = visitor.operand( ArgumentDescriptor { op: id, is_dst: false, - sema: typ.semantics(), + sema: space.semantics(), }, - &typ.to_func_type(), + &typ, )?; - Ok((new_id, typ)) + Ok((new_id, typ, space)) }) .collect::, _>>()?; Ok(ResolvedCall { @@ -5738,14 +5706,14 @@ impl ArgParamsEx for ExpandedArgParams { } enum Directive<'input> { - Variable(ast::Variable), + Variable(ast::Variable), Method(Function<'input>), } struct Function<'input> { pub func_decl: ast::MethodDecl<'input, spirv::Word>, pub spirv_decl: SpirvMethodDecl<'input>, - pub globals: Vec>, + pub globals: Vec>, pub body: Option>, import_as: Option, tuning: Vec, @@ -7300,16 +7268,6 @@ impl ast::LdStateSpace { } } -impl From for ast::VariableType { - fn from(t: ast::FnArgumentType) -> Self { - match t { - ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t), - ast::FnArgumentType::Param(t) => ast::VariableType::Param(t), - ast::FnArgumentType::Shared => todo!(), - } - } -} - impl ast::Operand { fn underlying(&self) -> Option<&T> { match self { @@ -7362,12 +7320,13 @@ impl ast::AtomSemantics { } } -impl ast::FnArgumentType { - fn semantics(&self) -> ArgumentSemantics { +impl ast::StateSpace { + fn semantics(self) -> ArgumentSemantics { match self { - ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default, - ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer, - ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer, + ast::StateSpace::Reg => ArgumentSemantics::Default, + ast::StateSpace::Param => ArgumentSemantics::RegisterPointer, + ast::StateSpace::Shared => ArgumentSemantics::PhysicalPointer, + _ => todo!(), } } } @@ -7677,8 +7636,8 @@ impl<'a> ast::MethodDecl<'a, &'a str> { } struct SpirvMethodDecl<'input> { - input: Vec>, - output: Vec>, + input: Vec>, + output: Vec>, name: MethodName<'input>, uses_shared_mem: bool, } @@ -7689,33 +7648,28 @@ impl<'input> SpirvMethodDecl<'input> { ast::MethodDecl::Kernel { in_args, .. } => { let spirv_input = in_args .iter() - .map(|var| { - let v_type = match &var.v_type { - ast::KernelArgumentType::Normal(t) => { - ast::FnArgumentType::Param(t.clone()) - } - ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared, - }; - ast::Variable { - name: var.name, - align: var.align, - v_type: v_type.to_kernel_type(), - array_init: var.array_init.clone(), - } + .map(|var| ast::Variable { + name: var.name, + align: var.align, + v_type: var.v_type.clone(), + state_space: var.state_space, + array_init: var.array_init.clone(), }) .collect(); (spirv_input, Vec::new()) } ast::MethodDecl::Func(out_args, _, in_args) => { - let (param_output, non_param_output): (Vec<_>, Vec<_>) = - out_args.iter().partition(|var| var.v_type.is_param()); + let (param_output, non_param_output): (Vec<_>, Vec<_>) = out_args + .iter() + .partition(|var| var.state_space == ast::StateSpace::Param); let spirv_output = non_param_output .into_iter() .cloned() .map(|var| ast::Variable { name: var.name, align: var.align, - v_type: var.v_type.to_func_type(), + v_type: var.v_type.clone(), + state_space: var.state_space, array_init: var.array_init.clone(), }) .collect(); @@ -7726,7 +7680,8 @@ impl<'input> SpirvMethodDecl<'input> { .map(|var| ast::Variable { name: var.name, align: var.align, - v_type: var.v_type.to_func_type(), + v_type: var.v_type.clone(), + state_space: var.state_space, array_init: var.array_init.clone(), }) .collect(); From 9d92a6e284dce00b0b785a50f623d3715f8aeac4 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 5 May 2021 22:56:58 +0200 Subject: [PATCH 06/25] Start converting the translation to one type type --- ptx/src/ast.rs | 85 +-- ptx/src/ptx.lalrpop | 74 +-- ptx/src/translate.rs | 1189 ++++++++++++++++++++++-------------------- 3 files changed, 666 insertions(+), 682 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index c7b9563..364ec01 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,6 +1,6 @@ use half::f16; use lalrpop_util::{lexer::Token, ParseError}; -use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr}; +use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; #[derive(Debug, thiserror::Error)] @@ -110,35 +110,7 @@ pub enum Type { Scalar(ScalarType), Vector(ScalarType, u8), Array(ScalarType, Vec), - Pointer(PointerType, LdStateSpace), -} - -#[derive(PartialEq, Eq, Clone)] -pub enum PointerType { - Scalar(ScalarType), - Vector(ScalarType, u8), - Array(ScalarType, Vec), - // Instances of this variant are generated during stateful conversion - Pointer(ScalarType, LdStateSpace), -} - -impl From for PointerType { - fn from(t: ScalarType) -> Self { - PointerType::Scalar(t.into()) - } -} - -impl TryFrom for ScalarType { - type Error = (); - - fn try_from(value: PointerType) -> Result { - match value { - PointerType::Scalar(t) => Ok(t), - PointerType::Vector(_, _) => Err(()), - PointerType::Array(_, _) => Err(()), - PointerType::Pointer(_, _) => Err(()), - } - } + Pointer(ScalarType), } #[derive(PartialEq, Eq, Hash, Clone, Copy)] @@ -222,6 +194,7 @@ pub enum StateSpace { Shared, Param, Generic, + Sreg, } pub struct PredAt { @@ -397,9 +370,9 @@ pub enum VectorPrefix { pub struct LdDetails { pub qualifier: LdStQualifier, - pub state_space: LdStateSpace, + pub state_space: StateSpace, pub caching: LdCacheOperator, - pub typ: PointerType, + pub typ: Type, pub non_coherent: bool, } @@ -418,17 +391,6 @@ pub enum MemScope { Sys, } -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -#[repr(u8)] -pub enum LdStateSpace { - Generic, - Const, - Global, - Local, - Param, - Shared, -} - #[derive(Copy, Clone, PartialEq, Eq)] pub enum LdCacheOperator { Cached, @@ -612,20 +574,11 @@ impl CvtDetails { } pub struct CvtaDetails { - pub to: CvtaStateSpace, - pub from: CvtaStateSpace, + pub to: StateSpace, + pub from: StateSpace, pub size: CvtaSize, } -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum CvtaStateSpace { - Generic, - Const, - Global, - Local, - Shared, -} - pub enum CvtaSize { U32, U64, @@ -633,18 +586,9 @@ pub enum CvtaSize { pub struct StData { pub qualifier: LdStQualifier, - pub state_space: StStateSpace, + pub state_space: StateSpace, pub caching: StCacheOperator, - pub typ: PointerType, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum StStateSpace { - Generic, - Global, - Local, - Param, - Shared, + pub typ: Type, } #[derive(PartialEq, Eq)] @@ -717,7 +661,7 @@ pub struct MinMaxFloat { pub struct AtomDetails { pub semantics: AtomSemantics, pub scope: MemScope, - pub space: AtomSpace, + pub space: StateSpace, pub inner: AtomInnerDetails, } @@ -729,13 +673,6 @@ pub enum AtomSemantics { AcquireRelease, } -#[derive(Copy, Clone)] -pub enum AtomSpace { - Generic, - Global, - Shared, -} - #[derive(Copy, Clone)] pub enum AtomInnerDetails { Bit { op: AtomBitOp, typ: ScalarType }, @@ -777,7 +714,7 @@ pub enum AtomFloatOp { pub struct AtomCasDetails { pub semantics: AtomSemantics, pub scope: MemScope, - pub space: AtomSpace, + pub space: StateSpace, pub typ: ScalarType, } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index dc439b7..8fee7c2 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -619,9 +619,9 @@ ModuleVariable: ast::Variable<&'input str> = { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Global), ast::StateSpace::Global, Vec::new()) + (ast::Type::Pointer(t), ast::StateSpace::Global, Vec::new()) } else { - (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Shared), ast::StateSpace::Shared, Vec::new()) + (ast::Type::Pointer(t), ast::StateSpace::Shared, Vec::new()) } } }; @@ -643,7 +643,7 @@ ParamVariable: (Option, Vec, ast::Type, &'input str) = { (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { - (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Param), Vec::new()) + (ast::Type::Pointer(t), Vec::new()) } }; (align, array_init, v_type, name) @@ -763,7 +763,7 @@ InstLd: ast::Instruction> = { ast::Instruction::Ld( ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ss.unwrap_or(ast::LdStateSpace::Generic), + state_space: ss.unwrap_or(ast::StateSpace::Generic), caching: cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t, non_coherent: false @@ -775,7 +775,7 @@ InstLd: ast::Instruction> = { ast::Instruction::Ld( ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ast::LdStateSpace::Global, + state_space: ast::StateSpace::Global, caching: cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t, non_coherent: false @@ -787,7 +787,7 @@ InstLd: ast::Instruction> = { ast::Instruction::Ld( ast::LdDetails { qualifier: ast::LdStQualifier::Weak, - state_space: ast::LdStateSpace::Global, + state_space: ast::StateSpace::Global, caching: cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t, non_coherent: true @@ -797,9 +797,9 @@ InstLd: ast::Instruction> = { } }; -LdStType: ast::PointerType = { - => ast::PointerType::Vector(t, v), - => ast::PointerType::Scalar(t), +LdStType: ast::Type = { + => ast::Type::Vector(t, v), + => ast::Type::Scalar(t), } LdStQualifier: ast::LdStQualifier = { @@ -815,11 +815,11 @@ MemScope: ast::MemScope = { ".sys" => ast::MemScope::Sys }; -LdNonGlobalStateSpace: ast::LdStateSpace = { - ".const" => ast::LdStateSpace::Const, - ".local" => ast::LdStateSpace::Local, - ".param" => ast::LdStateSpace::Param, - ".shared" => ast::LdStateSpace::Shared, +LdNonGlobalStateSpace: ast::StateSpace = { + ".const" => ast::StateSpace::Const, + ".local" => ast::StateSpace::Local, + ".param" => ast::StateSpace::Param, + ".shared" => ast::StateSpace::Shared, }; LdCacheOperator: ast::LdCacheOperator = { @@ -1235,7 +1235,7 @@ InstSt: ast::Instruction> = { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ss.unwrap_or(ast::StStateSpace::Generic), + state_space: ss.unwrap_or(ast::StateSpace::Generic), caching: cop.unwrap_or(ast::StCacheOperator::Writeback), typ: t }, @@ -1249,11 +1249,11 @@ MemoryOperand: ast::Operand<&'input str> = { "[" "]" => o } -StStateSpace: ast::StStateSpace = { - ".global" => ast::StStateSpace::Global, - ".local" => ast::StStateSpace::Local, - ".param" => ast::StStateSpace::Param, - ".shared" => ast::StStateSpace::Shared, +StStateSpace: ast::StateSpace = { + ".global" => ast::StateSpace::Global, + ".local" => ast::StateSpace::Local, + ".param" => ast::StateSpace::Param, + ".shared" => ast::StateSpace::Shared, }; StCacheOperator: ast::StCacheOperator = { @@ -1272,7 +1272,7 @@ InstRet: ast::Instruction> = { InstCvta: ast::Instruction> = { "cvta" => { ast::Instruction::Cvta(ast::CvtaDetails { - to: ast::CvtaStateSpace::Generic, + to: ast::StateSpace::Generic, from, size: s }, @@ -1281,18 +1281,18 @@ InstCvta: ast::Instruction> = { "cvta" ".to" => { ast::Instruction::Cvta(ast::CvtaDetails { to, - from: ast::CvtaStateSpace::Generic, + from: ast::StateSpace::Generic, size: s }, a) } } -CvtaStateSpace: ast::CvtaStateSpace = { - ".const" => ast::CvtaStateSpace::Const, - ".global" => ast::CvtaStateSpace::Global, - ".local" => ast::CvtaStateSpace::Local, - ".shared" => ast::CvtaStateSpace::Shared, +CvtaStateSpace: ast::StateSpace = { + ".const" => ast::StateSpace::Const, + ".global" => ast::StateSpace::Global, + ".local" => ast::StateSpace::Local, + ".shared" => ast::StateSpace::Shared, } CvtaSize: ast::CvtaSize = { @@ -1450,7 +1450,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Bit { op, typ } }; ast::Instruction::Atom(details,a) @@ -1459,7 +1459,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Inc, typ: ast::ScalarType::U32 @@ -1471,7 +1471,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Dec, typ: ast::ScalarType::U32 @@ -1484,7 +1484,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Float { op, typ } }; ast::Instruction::Atom(details,a) @@ -1493,7 +1493,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op, typ } }; ast::Instruction::Atom(details,a) @@ -1502,7 +1502,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Signed { op, typ } }; ast::Instruction::Atom(details,a) @@ -1514,7 +1514,7 @@ InstAtomCas: ast::Instruction> = { let details = ast::AtomCasDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), typ, }; ast::Instruction::AtomCas(details,a) @@ -1528,9 +1528,9 @@ AtomSemantics: ast::AtomSemantics = { ".acq_rel" => ast::AtomSemantics::AcquireRelease } -AtomSpace: ast::AtomSpace = { - ".global" => ast::AtomSpace::Global, - ".shared" => ast::AtomSpace::Shared +AtomSpace: ast::StateSpace = { + ".global" => ast::StateSpace::Global, + ".shared" => ast::StateSpace::Shared } AtomBitOp: ast::AtomBitOp = { diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 4ba5729..a743496 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -37,6 +37,12 @@ fn error_unreachable() -> TranslateError { TranslateError::Unreachable } +macro_rules! new_todo { + () => { + todo!() + }; +} + #[derive(PartialEq, Eq, Hash, Clone)] enum SpirvType { Base(SpirvScalarKey), @@ -48,52 +54,40 @@ enum SpirvType { } impl SpirvType { - fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self { - let key = t.into(); - SpirvType::Pointer(Box::new(key), sc) - } -} - -impl From for SpirvType { - fn from(t: ast::Type) -> Self { + fn new(t: ast::Type, decl_space: ast::StateSpace) -> Self { match t { ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), - ast::Type::Pointer(pointer_t, state_space) => SpirvType::Pointer( - Box::new(SpirvType::from(ast::Type::from(pointer_t))), - state_space.to_spirv(), - ), - } - } -} - -impl From for ast::Type { - fn from(t: ast::PointerType) -> Self { - match t { - ast::PointerType::Scalar(t) => ast::Type::Scalar(t), - ast::PointerType::Vector(t, len) => ast::Type::Vector(t, len), - ast::PointerType::Array(t, dims) => ast::Type::Array(t, dims), - ast::PointerType::Pointer(t, space) => { - ast::Type::Pointer(ast::PointerType::Scalar(t), space) + ast::Type::Pointer(pointer_t) => { + let spirv_space = match decl_space { + ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { + spirv::StorageClass::Private + } + ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, + ast::StateSpace::Const => spirv::StorageClass::UniformConstant, + ast::StateSpace::Shared => spirv::StorageClass::Workgroup, + ast::StateSpace::Generic => spirv::StorageClass::Generic, + ast::StateSpace::Sreg => spirv::StorageClass::Input, + }; + SpirvType::Pointer(Box::new(SpirvType::Base(pointer_t.into())), spirv_space) } } } + + fn pointer_to( + t: ast::Type, + inner_space: ast::StateSpace, + outer_space: spirv::StorageClass, + ) -> Self { + let key = Self::new(t, inner_space); + SpirvType::Pointer(Box::new(key), outer_space) + } } impl ast::Type { - fn param_pointer_to(self, space: ast::LdStateSpace) -> Result { - Ok(match self { - ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space), - ast::Type::Vector(t, len) => { - ast::Type::Pointer(ast::PointerType::Vector(t, len), space) - } - ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space), - ast::Type::Pointer(ast::PointerType::Scalar(t), space) => { - ast::Type::Pointer(ast::PointerType::Pointer(t, space), space) - } - ast::Type::Pointer(_, _) => return Err(error_unreachable()), - }) + fn param_pointer_to(self, space: ast::StateSpace) -> Result { + Ok(self) } } @@ -398,18 +392,7 @@ impl TypeWordMap { b.constant_composite(result_type, None, components.into_iter()) } }, - ast::Type::Pointer(typ, state_space) => { - let base_t = typ.clone().into(); - let base = self.get_or_add_constant(b, &base_t, &[])?; - let result_type = self.get_or_add( - b, - SpirvType::Pointer( - Box::new(SpirvType::from(base_t)), - (*state_space).to_spirv(), - ), - ); - b.variable(result_type, None, (*state_space).to_spirv(), Some(base)) - } + ast::Type::Pointer(typ) => return Err(error_unreachable()), }) } @@ -702,11 +685,29 @@ fn multi_hash_map_append(m: &mut MultiHashMap, } } -// PTX represents dynamically allocated shared local memory as -// .extern .shared .align 4 .b8 shared_mem[]; -// In SPIRV/OpenCL world this is expressed as an additional argument -// This pass looks for all uses of .extern .shared and converts them to -// an additional method argument +/* + PTX represents dynamically allocated shared local memory as + .extern .shared .b32 shared_mem[]; + In SPIRV/OpenCL world this is expressed as an additional argument + This pass looks for all uses of .extern .shared and converts them to + an additional method argument + The question is how this artificial argument should be expressed. There are + several options: + * Straight conversion: + .shared .b32 shared_mem[] + * Introduce .param_shared statespace: + .param_shared .b32 shared_mem + or + .param_shared .b32 shared_mem[] + * Introduce .shared_ptr type: + .param .shared_ptr .b32 shared_mem + * Reuse .ptr hint: + .param .u64 .ptr shared_mem + This is the most tempting, but also the most nonsensical, .ptr is just a + hint, which has no semantical meaning (and the output of our + transformation has a semantical meaning - we emit additional + "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...") +*/ fn convert_dynamic_shared_memory_usage<'input>( module: Vec>, new_id: &mut impl FnMut() -> spirv::Word, @@ -715,7 +716,7 @@ fn convert_dynamic_shared_memory_usage<'input>( for dir in module.iter() { match dir { Directive::Variable(ast::Variable { - v_type: ast::Type::Pointer(p_type, ast::LdStateSpace::Shared), + v_type: ast::Type::Pointer(p_type), state_space: ast::StateSpace::Shared, name, .. @@ -799,48 +800,23 @@ fn convert_dynamic_shared_memory_usage<'input>( ast::Variable { name: shared_id_param, align: None, - v_type: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::B8), - ast::LdStateSpace::Shared, - ), - state_space: ast::StateSpace::Param, + v_type: ast::Type::Pointer(ast::ScalarType::B8), + state_space: ast::StateSpace::Shared, array_init: Vec::new(), } }); spirv_decl.uses_shared_mem = true; - let shared_var_id = new_id(); - let shared_var = ExpandedStatement::Variable(ast::Variable { - name: shared_var_id, - align: None, - v_type: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::B8), - ast::LdStateSpace::Shared, - ), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), - }); - let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails { - arg: ast::Arg2St { - src1: shared_var_id, - src2: shared_id_param, - }, - typ: ast::Type::Scalar(ast::ScalarType::B8), - member_index: None, - }); - let mut new_statements = vec![shared_var, shared_var_st]; - replace_uses_of_shared_memory( - &mut new_statements, + let statements = replace_uses_of_shared_memory( new_id, &extern_shared_decls, &mut methods_using_extern_shared, shared_id_param, - shared_var_id, statements, ); Directive::Method(Function { func_decl, globals, - body: Some(new_statements), + body: Some(statements), import_as, spirv_decl, tuning, @@ -852,14 +828,13 @@ fn convert_dynamic_shared_memory_usage<'input>( } fn replace_uses_of_shared_memory<'a>( - result: &mut Vec, new_id: &mut impl FnMut() -> spirv::Word, - extern_shared_decls: &HashMap, + extern_shared_decls: &HashMap, methods_using_extern_shared: &mut HashSet>, shared_id_param: spirv::Word, - shared_var_id: spirv::Word, statements: Vec, -) { +) -> Vec { + let mut result = Vec::with_capacity(statements.len()); for statement in statements { match statement { Statement::Call(mut call) => { @@ -877,22 +852,18 @@ fn replace_uses_of_shared_memory<'a>( } statement => { let new_statement = statement.map_id(&mut |id, _| { - if let Some(ast::PointerType::Scalar(typ)) = extern_shared_decls.get(&id) { - if *typ == ast::ScalarType::B8 { - return shared_var_id; + if let Some(scalar_type) = extern_shared_decls.get(&id) { + if *scalar_type == ast::ScalarType::B8 { + return shared_id_param; } let replacement_id = new_id(); result.push(Statement::Conversion(ImplicitConversion { - src: shared_var_id, + src: shared_id_param, dst: replacement_id, - from: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::B8), - ast::LdStateSpace::Shared, - ), - to: ast::Type::Pointer( - ast::PointerType::Scalar((*typ).into()), - ast::LdStateSpace::Shared, - ), + from_type: ast::Type::Pointer(ast::ScalarType::B8), + from_space: ast::StateSpace::Shared, + to_type: ast::Type::Pointer((*scalar_type).into()), + to_space: ast::StateSpace::Shared, kind: ConversionKind::PtrToPtr { spirv_ptr: true }, src_sema: ArgumentSemantics::Default, dst_sema: ArgumentSemantics::Default, @@ -906,6 +877,7 @@ fn replace_uses_of_shared_memory<'a>( } } } + result } fn get_callers_of_extern_shared<'a>( @@ -1055,8 +1027,9 @@ fn emit_builtins( for (reg, id) in id_defs.special_registers.builtins() { let result_type = map.get_or_add( builder, - SpirvType::Pointer( - Box::new(SpirvType::from(reg.get_type())), + SpirvType::pointer_to( + reg.get_type(), + ast::StateSpace::Reg, spirv::StorageClass::Input, ), ); @@ -1158,7 +1131,10 @@ fn emit_function_header<'a>( } */ for input in &func_decl.input { - let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone())); + let result_type = map.get_or_add( + builder, + SpirvType::new(input.v_type.clone(), input.state_space), + ); builder.function_parameter(Some(input.name), result_type)?; } Ok(fn_id) @@ -1219,26 +1195,26 @@ fn translate_variable<'a>( is_variable = true; var_type } - ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?, - ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?, - ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?, + ast::StateSpace::Const => var_type.param_pointer_to(ast::StateSpace::Const)?, + ast::StateSpace::Global => var_type.param_pointer_to(ast::StateSpace::Global)?, + ast::StateSpace::Local => var_type.param_pointer_to(ast::StateSpace::Local)?, ast::StateSpace::Shared => { // If it's a pointer it will be translated to a method parameter later if let ast::Type::Pointer(..) = var_type { is_variable = true; var_type } else { - var_type.param_pointer_to(ast::LdStateSpace::Shared)? + var_type.param_pointer_to(ast::StateSpace::Shared)? } } - ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?, - ast::StateSpace::Generic => todo!(), + ast::StateSpace::Param => var_type.param_pointer_to(ast::StateSpace::Param)?, + ast::StateSpace::Generic | ast::StateSpace::Sreg => return Err(error_unreachable()), }; Ok(ast::Variable { align: var.align, v_type: var.v_type, state_space: var.state_space, - name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable), + name: id_defs.get_or_add_def_typed(var.name, var_type, var.state_space, is_variable), array_init: var.array_init, }) } @@ -1283,7 +1259,10 @@ fn expand_kernel_params<'a, 'b>( Ok(ast::KernelArgument { name: fn_resolver.add_def( a.name, - Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?), + Some(( + ast::Type::from(a.v_type.clone()).param_pointer_to(ast::StateSpace::Param)?, + a.state_space, + )), false, ), v_type: a.v_type.clone(), @@ -1302,7 +1281,7 @@ fn expand_fn_params<'a, 'b>( args.map(|a| { let is_variable = a.state_space == ast::StateSpace::Reg; Ok(ast::FnArgument { - name: fn_resolver.add_def(a.name, Some(a.v_type.clone()), is_variable), + name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), is_variable), v_type: a.v_type.clone(), state_space: a.state_space, align: a.align, @@ -1339,15 +1318,15 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; - let typed_statements = - convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; + //let typed_statements = + // convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, &f_args, &mut spirv_decl, )?; - let ssa_statements = fix_builtins(ssa_statements, &mut numeric_id_defs)?; + let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = @@ -1366,7 +1345,7 @@ fn to_ssa<'input, 'b>( }) } -fn fix_builtins( +fn fix_special_registers( typed_statements: Vec, numeric_id_defs: &mut NumericIdResolver, ) -> Result, TranslateError> { @@ -1402,7 +1381,8 @@ fn fix_builtins( continue; } }; - let temp_id = numeric_id_defs.new_non_variable(Some(details.typ.clone())); + let temp_id = numeric_id_defs + .register_intermediate(Some((details.typ.clone(), details.state_space))); let real_dst = details.arg.dst; details.arg.dst = temp_id; result.push(Statement::LoadVar(LoadVarDetails { @@ -1410,14 +1390,17 @@ fn fix_builtins( src: sreg_src, dst: temp_id, }, + state_space: ast::StateSpace::Sreg, typ: ast::Type::Scalar(scalar_typ), member_index: Some((index, Some(vector_width))), })); result.push(Statement::Conversion(ImplicitConversion { src: temp_id, dst: real_dst, - from: ast::Type::Scalar(scalar_typ), - to: ast::Type::Scalar(ast::ScalarType::U32), + from_type: ast::Type::Scalar(scalar_typ), + from_space: ast::StateSpace::Sreg, + to_type: ast::Type::Scalar(ast::ScalarType::U32), + to_space: ast::StateSpace::Sreg, kind: ConversionKind::Default, src_sema: ArgumentSemantics::Default, dst_sema: ArgumentSemantics::Default, @@ -1614,12 +1597,12 @@ fn convert_to_typed_statements( } ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => { if let Some(src_id) = src.underlying() { - let (typ, _) = id_defs.get_typed(*src_id)?; + let (typ, _, _) = id_defs.get_typed(*src_id)?; let take_address = match typ { - ast::Type::Scalar(_) => false, - ast::Type::Vector(_, _) => false, - ast::Type::Array(_, _) => true, - ast::Type::Pointer(_, _) => true, + ast::Type::Scalar(..) => false, + ast::Type::Vector(..) => false, + ast::Type::Array(..) => true, + ast::Type::Pointer(..) => true, }; d.src_is_address = take_address; } @@ -1666,6 +1649,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { is_dst: bool, vector_sema: ArgumentSemantics, typ: &ast::Type, + state_space: ast::StateSpace, idx: Vec, ) -> Result { // mov.u32 foobar, {a,b}; @@ -1673,7 +1657,9 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { ast::Type::Vector(scalar_t, _) => *scalar_t, _ => return Err(TranslateError::MismatchedType), }; - let temp_vec = self.id_def.new_non_variable(Some(typ.clone())); + let temp_vec = self + .id_def + .register_intermediate(Some((typ.clone(), state_space))); let statement = Statement::RepackVector(RepackVectorDetails { is_extract: is_dst, typ: scalar_t, @@ -1696,7 +1682,7 @@ impl<'a, 'b> ArgumentMapVisitor fn id( &mut self, desc: ArgumentDescriptor, - _: Option<&ast::Type>, + _: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { Ok(desc.op) } @@ -1705,15 +1691,20 @@ impl<'a, 'b> ArgumentMapVisitor &mut self, desc: ArgumentDescriptor>, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { Ok(match desc.op { ast::Operand::Reg(reg) => TypedOperand::Reg(reg), ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), ast::Operand::Imm(x) => TypedOperand::Imm(x), ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), - ast::Operand::VecPack(vec) => { - TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?) - } + ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector( + desc.is_dst, + desc.sema, + typ, + state_space, + vec, + )?), }) } } @@ -1735,37 +1726,33 @@ fn to_ptx_impl_atomic_call( semantics, scope, space, op ); // TODO: extract to a function - let ptr_space = match details.space { - ast::AtomSpace::Generic => ast::LdStateSpace::Generic, - ast::AtomSpace::Global => ast::LdStateSpace::Global, - ast::AtomSpace::Shared => ast::LdStateSpace::Shared, - }; + let ptr_space = details.space; let scalar_typ = ast::ScalarType::from(typ); let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.new_non_variable(None); + let fn_id = id_defs.register_intermediate(None); let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, v_type: ast::Type::Scalar(scalar_typ), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }], fn_id, vec![ ast::FnArgument { align: None, - v_type: ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space), - state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + v_type: ast::Type::Pointer(typ), + state_space: ptr_space, + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::Type::Scalar(scalar_typ), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ], @@ -1795,11 +1782,7 @@ fn to_ptx_impl_atomic_call( func: fn_id, ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], param_list: vec![ - ( - arg.src1, - ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space), - ast::StateSpace::Reg, - ), + (arg.src1, ast::Type::Pointer(typ), ptr_space), ( arg.src2, ast::Type::Scalar(scalar_typ), @@ -1826,13 +1809,13 @@ fn to_ptx_impl_bfe_call( let fn_name = format!("{}{}", prefix, suffix); let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.new_non_variable(None); + let fn_id = id_defs.register_intermediate(None); let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }], fn_id, @@ -1841,21 +1824,21 @@ fn to_ptx_impl_bfe_call( align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ], @@ -1919,13 +1902,13 @@ fn to_ptx_impl_bfi_call( let fn_name = format!("{}{}", prefix, suffix); let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.new_non_variable(None); + let fn_id = id_defs.register_intermediate(None); let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }], fn_id, @@ -1934,28 +1917,28 @@ fn to_ptx_impl_bfi_call( align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ], @@ -2048,7 +2031,7 @@ fn normalize_labels( | Statement::RepackVector(..) => {} } } - iter::once(Statement::Label(id_def.new_non_variable(None))) + iter::once(Statement::Label(id_def.register_intermediate(None))) .chain(func.into_iter().filter(|s| match s { Statement::Label(i) => labels_in_use.contains(i), _ => true, @@ -2066,8 +2049,8 @@ fn normalize_predicates( Statement::Label(id) => result.push(Statement::Label(id)), Statement::Instruction((pred, inst)) => { if let Some(pred) = pred { - let if_true = id_def.new_non_variable(None); - let if_false = id_def.new_non_variable(None); + let if_true = id_def.register_intermediate(None); + let if_false = id_def.register_intermediate(None); let folded_bra = match &inst { ast::Instruction::Bra(_, arg) => Some(arg.src), _ => None, @@ -2116,7 +2099,8 @@ fn insert_mem_ssa_statements<'a, 'b>( } for spirv_arg in fn_decl.input.iter_mut() { let typ = spirv_arg.v_type.clone(); - let new_id = id_def.new_non_variable(Some(typ.clone())); + let state_space = spirv_arg.state_space; + let new_id = id_def.register_intermediate(Some((typ.clone(), state_space))); result.push(Statement::Variable(ast::Variable { align: spirv_arg.align, v_type: spirv_arg.v_type.clone(), @@ -2129,6 +2113,7 @@ fn insert_mem_ssa_statements<'a, 'b>( src1: spirv_arg.name, src2: new_id, }, + state_space, typ, member_index: None, })); @@ -2143,13 +2128,15 @@ fn insert_mem_ssa_statements<'a, 'b>( ast::Instruction::Ret(d) => { // TODO: handle multiple output args if let &[out_param] = &fn_decl.output.as_slice() { - let (typ, _) = id_def.get_typed(out_param.name)?; - let new_id = id_def.new_non_variable(Some(typ.clone())); + let (typ, space, _) = id_def.get_typed(out_param.name)?; + let new_id = id_def.register_intermediate(Some((typ.clone(), space))); result.push(Statement::LoadVar(LoadVarDetails { arg: ast::Arg2 { dst: new_id, src: out_param.name, }, + // TODO: ret with stateful conversion + state_space: new_todo!(), typ: typ.clone(), member_index: None, })); @@ -2161,13 +2148,16 @@ fn insert_mem_ssa_statements<'a, 'b>( inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?, }, Statement::Conditional(mut bra) => { - let generated_id = - id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred))); + let generated_id = id_def.register_intermediate(Some(( + ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + ))); result.push(Statement::LoadVar(LoadVarDetails { arg: Arg2 { dst: generated_id, src: bra.predicate, }, + state_space: ast::StateSpace::Reg, typ: ast::Type::Scalar(ast::ScalarType::Pred), member_index: None, })); @@ -2204,6 +2194,7 @@ struct VisitArgumentDescriptor< > { desc: ArgumentDescriptor, typ: &'a ast::Type, + state_space: ast::StateSpace, stmt_ctor: Ctor, } @@ -2218,7 +2209,9 @@ impl< self, visitor: &mut impl ArgumentMapVisitor, ) -> Result, U>, TranslateError> { - Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?)) + Ok((self.stmt_ctor)( + visitor.id(self.desc, Some((self.typ, self.state_space)))?, + )) } } @@ -2232,13 +2225,13 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { fn symbol( &mut self, desc: ArgumentDescriptor<(spirv::Word, Option)>, - expected_type: Option<&ast::Type>, + expected_type: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { let symbol = desc.op.0; if expected_type.is_none() { return Ok(symbol); }; - let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?; + let (mut var_type, _, is_variable) = self.id_def.get_typed(symbol)?; if !is_variable { return Ok(symbol); }; @@ -2262,13 +2255,16 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { } None => None, }; - let generated_id = self.id_def.new_non_variable(Some(var_type.clone())); + let generated_id = self + .id_def + .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg))); if !desc.is_dst { self.func.push(Statement::LoadVar(LoadVarDetails { arg: Arg2 { dst: generated_id, src: symbol, }, + state_space: ast::StateSpace::Reg, typ: var_type, member_index, })); @@ -2279,6 +2275,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { src1: symbol, src2: generated_id, }, + state_space: ast::StateSpace::Reg, typ: var_type, member_index: member_index.map(|(idx, _)| idx), })); @@ -2293,7 +2290,7 @@ impl<'a, 'input> ArgumentMapVisitor fn id( &mut self, desc: ArgumentDescriptor, - typ: Option<&ast::Type>, + typ: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self.symbol(desc.new_op((desc.op, None)), typ) } @@ -2302,18 +2299,20 @@ impl<'a, 'input> ArgumentMapVisitor &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { Ok(match desc.op { TypedOperand::Reg(reg) => { - TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) - } - TypedOperand::RegOffset(reg, offset) => { - TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset) + TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?) } + TypedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset( + self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?, + offset, + ), op @ TypedOperand::Imm(..) => op, - TypedOperand::VecMember(symbol, index) => { - TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) - } + TypedOperand::VecMember(symbol, index) => TypedOperand::Reg( + self.symbol(desc.new_op((symbol, Some(index))), Some((typ, state_space)))?, + ), }) } } @@ -2411,7 +2410,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn reg( &mut self, desc: ArgumentDescriptor, - _: Option<&ast::Type>, + _: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { Ok(desc.op) } @@ -2420,30 +2419,31 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { &mut self, desc: ArgumentDescriptor<(spirv::Word, i32)>, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { let (reg, offset) = desc.op; let add_type; match typ { - ast::Type::Pointer(underlying_type, state_space) => { - let reg_typ = self.id_def.get_typed(reg)?; - if let ast::Type::Pointer(_, _) = reg_typ { - let id_constant_stmt = self.id_def.new_non_variable(typ.clone()); + ast::Type::Pointer(underlying_type) => { + let (reg_typ, space) = self.id_def.get_typed(reg)?; + if let ast::Type::Pointer(..) = reg_typ { + let id_constant_stmt = self.id_def.register_intermediate(typ.clone(), space); self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, typ: ast::ScalarType::S64, value: ast::ImmediateValue::S64(offset as i64), })); - let dst = self.id_def.new_non_variable(typ.clone()); + let dst = self.id_def.register_intermediate(typ.clone(), space); self.func.push(Statement::PtrAccess(PtrAccess { - underlying_type: underlying_type.clone(), - state_space: *state_space, + underlying_type: *underlying_type, + state_space: state_space, dst, ptr_src: reg, offset_src: id_constant_stmt, })); return Ok(dst); } else { - add_type = self.id_def.get_typed(reg)?; + add_type = self.id_def.get_typed(reg)?.0; } } _ => { @@ -2475,8 +2475,12 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { ast::ScalarKind::Unsigned, )) }; - let id_constant_stmt = self.id_def.new_non_variable(add_type.clone()); - let result_id = self.id_def.new_non_variable(add_type); + let id_constant_stmt = self + .id_def + .register_intermediate(add_type.clone(), ast::StateSpace::Reg); + let result_id = self + .id_def + .register_intermediate(add_type, ast::StateSpace::Reg); // TODO: check for edge cases around min value/max value/wrapping if offset < 0 && kind != ast::ScalarKind::Signed { self.func.push(Statement::Constant(ConstantDefinition { @@ -2518,13 +2522,16 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { let scalar_t = if let ast::Type::Scalar(scalar) = typ { *scalar } else { todo!() }; - let id = self.id_def.new_non_variable(ast::Type::Scalar(scalar_t)); + let id = self + .id_def + .register_intermediate(ast::Type::Scalar(scalar_t), state_space); self.func.push(Statement::Constant(ConstantDefinition { dst: id, typ: scalar_t, @@ -2538,7 +2545,7 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr fn id( &mut self, desc: ArgumentDescriptor, - t: Option<&ast::Type>, + t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self.reg(desc, t) } @@ -2547,12 +2554,13 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { match desc.op { - TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ), + TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some((typ, state_space))), + TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ, state_space), TypedOperand::RegOffset(reg, offset) => { - self.reg_offset(desc.new_op((reg, offset)), typ) + self.reg_offset(desc.new_op((reg, offset)), typ, state_space) } TypedOperand::VecMember(..) => Err(error_unreachable()), } @@ -2580,39 +2588,29 @@ fn insert_implicit_conversions( let mut result = Vec::with_capacity(func.len()); for s in func.into_iter() { match s { - Statement::Call(call) => insert_implicit_conversions_impl( - &mut result, - id_def, - call, - should_bitcast_wrapper, - None, - )?, + Statement::Call(call) => { + insert_implicit_conversions_impl(&mut result, id_def, call, should_bitcast_wrapper)? + } Statement::Instruction(inst) => { let mut default_conversion_fn = - should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _; + should_bitcast_wrapper as for<'a> fn(&'a _, _, &'a _, _) -> _; let mut state_space = None; if let ast::Instruction::Ld(d, _) = &inst { state_space = Some(d.state_space); } if let ast::Instruction::St(d, _) = &inst { - state_space = Some(d.state_space.to_ld_ss()); + state_space = Some(d.state_space); } if let ast::Instruction::Atom(d, _) = &inst { - state_space = Some(d.space.to_ld_ss()); + state_space = Some(d.space); } if let ast::Instruction::AtomCas(d, _) = &inst { - state_space = Some(d.space.to_ld_ss()); + state_space = Some(d.space); } if let ast::Instruction::Mov(..) = &inst { default_conversion_fn = should_bitcast_packed; } - insert_implicit_conversions_impl( - &mut result, - id_def, - inst, - default_conversion_fn, - state_space, - )?; + insert_implicit_conversions_impl(&mut result, id_def, inst, default_conversion_fn)?; } Statement::PtrAccess(PtrAccess { underlying_type, @@ -2627,7 +2625,8 @@ fn insert_implicit_conversions( is_dst: false, sema: ArgumentSemantics::PhysicalPointer, }, - typ: &ast::Type::Pointer(underlying_type.clone(), state_space), + typ: &ast::Type::Pointer(underlying_type), + state_space, stmt_ctor: |new_ptr_src| { Statement::PtrAccess(PtrAccess { underlying_type, @@ -2643,7 +2642,6 @@ fn insert_implicit_conversions( id_def, visit_desc, bitcast_physical_pointer, - Some(state_space), )?; } Statement::RepackVector(repack) => insert_implicit_conversions_impl( @@ -2651,7 +2649,6 @@ fn insert_implicit_conversions( id_def, repack, should_bitcast_wrapper, - None, )?, s @ Statement::Conditional(_) | s @ Statement::Conversion(_) @@ -2672,19 +2669,20 @@ fn insert_implicit_conversions_impl( stmt: impl Visitable, default_conversion_fn: for<'a> fn( &'a ast::Type, + ast::StateSpace, &'a ast::Type, - Option, + ast::StateSpace, ) -> Result, TranslateError>, - state_space: Option, ) -> Result<(), TranslateError> { let mut post_conv = Vec::new(); - let statement = stmt.visit( - &mut |desc: ArgumentDescriptor, typ: Option<&ast::Type>| { - let instr_type = match typ { + let statement = + stmt.visit(&mut |desc: ArgumentDescriptor, + typ: Option<(&ast::Type, ast::StateSpace)>| { + let (instr_type, instruction_space) = match typ { None => return Ok(desc.op), Some(t) => t, }; - let operand_type = id_def.get_typed(desc.op)?; + let (operand_type, operand_space) = id_def.get_typed(desc.op)?; let mut conversion_fn = default_conversion_fn; match desc.sema { ArgumentSemantics::Default => {} @@ -2705,27 +2703,33 @@ fn insert_implicit_conversions_impl( conversion_fn = force_bitcast_ptr_to_bit; } }; - match conversion_fn(&operand_type, instr_type, state_space)? { + match conversion_fn(&operand_type, operand_space, instr_type, instruction_space)? { Some(conv_kind) => { let conv_output = if desc.is_dst { &mut post_conv } else { &mut *func }; - let mut from = instr_type.clone(); - let mut to = operand_type; - let mut src = id_def.new_non_variable(instr_type.clone()); + let mut from_type = instr_type.clone(); + let mut from_space = instruction_space; + let mut to_type = operand_type; + let mut to_space = operand_space; + let mut src = + id_def.register_intermediate(instr_type.clone(), instruction_space); let mut dst = desc.op; let result = Ok(src); if !desc.is_dst { mem::swap(&mut src, &mut dst); - mem::swap(&mut from, &mut to); + mem::swap(&mut from_type, &mut to_type); + mem::swap(&mut from_space, &mut to_space); } conv_output.push(Statement::Conversion(ImplicitConversion { src, dst, - from, - to, + from_type, + from_space, + to_type, + to_space, kind: conv_kind, src_sema: ArgumentSemantics::Default, dst_sema: ArgumentSemantics::Default, @@ -2734,8 +2738,7 @@ fn insert_implicit_conversions_impl( } None => Ok(desc.op), } - }, - )?; + })?; func.push(statement); func.append(&mut post_conv); Ok(()) @@ -2751,10 +2754,10 @@ fn get_function_type( builder, spirv_input .iter() - .map(|var| SpirvType::from(var.v_type.clone())), + .map(|var| SpirvType::new(var.v_type.clone(), var.state_space)), spirv_output .iter() - .map(|var| SpirvType::from(var.v_type.clone())), + .map(|var| SpirvType::new(var.v_type.clone(), var.state_space)), ) } @@ -2782,8 +2785,8 @@ fn emit_function_body_ops( Statement::Label(_) => (), Statement::Call(call) => { let (result_type, result_id) = match &*call.ret_params { - [(id, typ, _)] => ( - map.get_or_add(builder, SpirvType::from(typ.clone())), + [(id, typ, space)] => ( + map.get_or_add(builder, SpirvType::new(typ.clone(), *space)), Some(*id), ), [] => (map.void(), None), @@ -2915,8 +2918,10 @@ fn emit_function_body_ops( if data.qualifier != ast::LdStQualifier::Weak { todo!() } - let result_type = - map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone()))); + let result_type = map.get_or_add( + builder, + SpirvType::new(ast::Type::from(data.typ.clone()), data.state_space), + ); builder.load( result_type, Some(arg.dst), @@ -2947,8 +2952,10 @@ fn emit_function_body_ops( // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, ast::Instruction::Mov(d, arg) => { - let result_type = - map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone()))); + let result_type = map.get_or_add( + builder, + SpirvType::new(ast::Type::from(d.typ.clone()), ast::StateSpace::Reg), + ); builder.copy_object(result_type, Some(arg.dst), arg.src)?; } ast::Instruction::Mul(mul, arg) => match mul { @@ -2989,7 +2996,8 @@ fn emit_function_body_ops( ast::Instruction::Shl(t, a) => { let full_type = ast::Type::Scalar(*t); let size_of = full_type.size_of(); - let result_type = map.get_or_add(builder, SpirvType::from(full_type)); + let result_type = + map.get_or_add(builder, SpirvType::new(full_type, ast::StateSpace::Reg)); let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?; builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?; } @@ -3251,8 +3259,9 @@ fn emit_function_body_ops( Some(index) => { let result_ptr_type = map.get_or_add( builder, - SpirvType::new_pointer( + SpirvType::pointer_to( details.typ.clone(), + details.state_space, spirv::StorageClass::Function, ), ); @@ -3284,14 +3293,11 @@ fn emit_function_body_ops( }) => { let u8_pointer = map.get_or_add( builder, - SpirvType::from(ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - *state_space, - )), + SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8), *state_space), ); let result_type = map.get_or_add( builder, - SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)), + SpirvType::new(ast::Type::Pointer(*underlying_type), *state_space), ); let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?; let temp = builder.in_bounds_ptr_access_chain( @@ -3503,11 +3509,16 @@ fn ptx_scope_name(scope: ast::MemScope) -> &'static str { } } -fn ptx_space_name(space: ast::AtomSpace) -> &'static str { +fn ptx_space_name(space: ast::StateSpace) -> &'static str { match space { - ast::AtomSpace::Generic => "generic", - ast::AtomSpace::Global => "global", - ast::AtomSpace::Shared => "shared", + ast::StateSpace::Generic => "generic", + ast::StateSpace::Global => "global", + ast::StateSpace::Shared => "shared", + ast::StateSpace::Reg => "reg", + ast::StateSpace::Const => "const", + ast::StateSpace::Local => "local", + ast::StateSpace::Param => "param", + ast::StateSpace::Sreg => "sreg", } } @@ -3572,6 +3583,7 @@ fn emit_variable( ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), ast::StateSpace::Const => todo!(), ast::StateSpace::Generic => todo!(), + ast::StateSpace::Sreg => todo!(), }; let initalizer = if var.array_init.len() > 0 { Some(map.get_or_add_constant( @@ -3580,17 +3592,14 @@ fn emit_variable( &*var.array_init, )?) } else if must_init { - let type_id = map.get_or_add( - builder, - SpirvType::from(ast::Type::from(var.v_type.clone())), - ); + let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone(), var.state_space)); Some(builder.constant_null(type_id, None)) } else { None }; let ptr_type_id = map.get_or_add( builder, - SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class), + SpirvType::pointer_to(var.v_type.clone(), var.state_space, st_class), ); builder.variable(ptr_type_id, Some(var.name), st_class, initalizer); if let Some(align) = var.align { @@ -3729,7 +3738,10 @@ fn emit_min( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, }; - let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type())); + let inst_type = map.get_or_add( + builder, + SpirvType::new(desc.get_type(), ast::StateSpace::Reg), + ); builder.ext_inst( inst_type, Some(arg.dst), @@ -3754,7 +3766,10 @@ fn emit_max( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, }; - let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type())); + let inst_type = map.get_or_add( + builder, + SpirvType::new(desc.get_type(), ast::StateSpace::Reg), + ); builder.ext_inst( inst_type, Some(arg.dst), @@ -3865,11 +3880,13 @@ fn emit_cvt( let cv = ImplicitConversion { src: arg.src, dst: new_dst, - from: ast::Type::Scalar(src_t), - to: ast::Type::Scalar(ast::ScalarType::from_parts( + from_type: ast::Type::Scalar(src_t), + from_space: ast::StateSpace::Reg, + to_type: ast::Type::Scalar(ast::ScalarType::from_parts( dest_t.size_of(), src_t.kind(), )), + to_space: ast::StateSpace::Reg, kind: ConversionKind::Default, src_sema: ArgumentSemantics::Default, dst_sema: ArgumentSemantics::Default, @@ -4224,20 +4241,24 @@ fn emit_implicit_conversion( map: &mut TypeWordMap, cv: &ImplicitConversion, ) -> Result<(), TranslateError> { - let from_parts = cv.from.to_parts(); - let to_parts = cv.to.to_parts(); + let from_parts = cv.from_type.to_parts(); + let to_parts = cv.to_type.to_parts(); match (from_parts.kind, to_parts.kind, cv.kind) { (_, _, ConversionKind::PtrToBit(typ)) => { let dst_type = map.get_or_add_scalar(builder, typ.into()); builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; } - (_, _, ConversionKind::BitToPtr(_)) => { - let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); + (_, _, ConversionKind::BitToPtr) => { + let dst_type = map.get_or_add( + builder, + SpirvType::pointer_to(cv.to_type.clone(), cv.from_space, cv.to_space.to_spirv()), + ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { if from_parts.width == to_parts.width { - let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); + let dst_type = + map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)); if from_parts.scalar_kind != ast::ScalarKind::Float && to_parts.scalar_kind != ast::ScalarKind::Float { @@ -4247,13 +4268,16 @@ fn emit_implicit_conversion( builder.bitcast(dst_type, Some(cv.dst), cv.src)?; } } else { - // This block is safe because it's illegal to implictly convert between floating point instructions + // This block is safe because it's illegal to implictly convert between floating point values let same_width_bit_type = map.get_or_add( builder, - SpirvType::from(ast::Type::from_parts(TypeParts { - scalar_kind: ast::ScalarKind::Bit, - ..from_parts - })), + SpirvType::new( + ast::Type::from_parts(TypeParts { + scalar_kind: ast::ScalarKind::Bit, + ..from_parts + }), + cv.from_space, + ), ); let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?; let wide_bit_type = ast::Type::from_parts(TypeParts { @@ -4261,7 +4285,7 @@ fn emit_implicit_conversion( ..to_parts }); let wide_bit_type_spirv = - map.get_or_add(builder, SpirvType::from(wide_bit_type.clone())); + map.get_or_add(builder, SpirvType::new(wide_bit_type.clone(), cv.to_space)); if to_parts.scalar_kind == ast::ScalarKind::Unsigned || to_parts.scalar_kind == ast::ScalarKind::Bit { @@ -4282,8 +4306,10 @@ fn emit_implicit_conversion( &ImplicitConversion { src: wide_bit_value, dst: cv.dst, - from: wide_bit_type, - to: cv.to.clone(), + from_type: wide_bit_type, + from_space: new_todo!(), + to_type: cv.to_type.clone(), + to_space: new_todo!(), kind: ConversionKind::Default, src_sema: cv.src_sema, dst_sema: cv.dst_sema, @@ -4293,13 +4319,15 @@ fn emit_implicit_conversion( } } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => { - let result_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); + let result_type = + map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)); builder.s_convert(result_type, Some(cv.dst), cv.src)?; } (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default) | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { - let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); + let into_type = + map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)); builder.bitcast(into_type, Some(cv.dst), cv.src)?; } (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => { @@ -4307,12 +4335,12 @@ fn emit_implicit_conversion( map.get_or_add( builder, SpirvType::Pointer( - Box::new(SpirvType::from(cv.to.clone())), + Box::new(SpirvType::new(cv.to_type.clone(), cv.to_space)), spirv::StorageClass::Function, ), ) } else { - map.get_or_add(builder, SpirvType::from(cv.to.clone())) + map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)) }; builder.bitcast(result_type, Some(cv.dst), cv.src)?; } @@ -4326,14 +4354,18 @@ fn emit_load_var( map: &mut TypeWordMap, details: &LoadVarDetails, ) -> Result<(), TranslateError> { - let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone())); + let result_type = map.get_or_add( + builder, + SpirvType::new(details.typ.clone(), details.state_space), + ); match details.member_index { Some((index, Some(width))) => { let vector_type = match details.typ { ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), _ => return Err(TranslateError::MismatchedType), }; - let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type)); + let vector_type_spirv = + map.get_or_add(builder, SpirvType::new(vector_type, details.state_space)); let vector_temp = builder.load( vector_type_spirv, None, @@ -4351,7 +4383,11 @@ fn emit_load_var( Some((index, None)) => { let result_ptr_type = map.get_or_add( builder, - SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function), + SpirvType::pointer_to( + details.typ.clone(), + details.state_space, + spirv::StorageClass::Function, + ), ); let index_spirv = map.get_or_add_constant( builder, @@ -4433,18 +4469,25 @@ fn expand_map_variables<'a, 'b>( is_variable = true; var_type } else { - var_type.param_pointer_to(ast::LdStateSpace::Shared)? + var_type.param_pointer_to(ast::StateSpace::Shared)? } } - ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?, - ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?, - ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?, - ast::StateSpace::Const => todo!(), - ast::StateSpace::Generic => todo!(), + ast::StateSpace::Global => var_type.param_pointer_to(ast::StateSpace::Global)?, + ast::StateSpace::Param => var_type.param_pointer_to(ast::StateSpace::Param)?, + ast::StateSpace::Local => var_type.param_pointer_to(ast::StateSpace::Local)?, + ast::StateSpace::Const => new_todo!(), + ast::StateSpace::Generic => new_todo!(), + ast::StateSpace::Sreg => new_todo!(), }; match var.count { Some(count) => { - for new_id in id_defs.add_defs(var.var.name, count, var_type, is_variable) { + for new_id in id_defs.add_defs( + var.var.name, + count, + var_type, + var.var.state_space, + is_variable, + ) { result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), @@ -4455,7 +4498,11 @@ fn expand_map_variables<'a, 'b>( } } None => { - let new_id = id_defs.add_def(var.var.name, Some(var_type), is_variable); + let new_id = id_defs.add_def( + var.var.name, + Some((var_type, var.var.state_space)), + is_variable, + ); result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), @@ -4470,11 +4517,42 @@ fn expand_map_variables<'a, 'b>( Ok(()) } +/* + Our goal here is to transform + .visible .entry foobar(.param .u64 input) { + .reg .b64 in_addr; + .reg .b64 in_addr2; + ld.param.u64 in_addr, [input]; + cvta.to.global.u64 in_addr2, in_addr; + } + into: + .visible .entry foobar(.param .u8 input[]) { + .reg .u8 in_addr[]; + .reg .u8 in_addr2[]; + ld.param.u8[] in_addr, [input]; + mov.u8[] in_addr2, in_addr; + } + or: + .visible .entry foobar(.reg .u8 input[]) { + .reg .u8 in_addr[]; + .reg .u8 in_addr2[]; + mov.u8[] in_addr, input; + mov.u8[] in_addr2, in_addr; + } + or: + .visible .entry foobar(.param ptr input) { + .reg ptr in_addr; + .reg ptr in_addr2; + ld.param.ptr in_addr, [input]; + mov.ptr in_addr2, in_addr; + } +*/ // TODO: detect more patterns (mov, call via reg, call via param) // TODO: don't convert to ptr if the register is not ultimately used for ld/st // TODO: once insert_mem_ssa_statements is moved to later, move this pass after // argument expansion // TODO: propagate through calls? +/* fn convert_to_stateful_memory_access<'a>( func_args: &mut SpirvMethodDecl, func_body: Vec, @@ -4496,9 +4574,9 @@ fn convert_to_stateful_memory_access<'a>( match statement { Statement::Instruction(ast::Instruction::Cvta( ast::CvtaDetails { - to: ast::CvtaStateSpace::Global, + to: ast::StateSpace::Global, size: ast::CvtaSize::U64, - from: ast::CvtaStateSpace::Generic, + from: ast::StateSpace::Generic, }, arg, )) => { @@ -4512,24 +4590,24 @@ fn convert_to_stateful_memory_access<'a>( } Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { - state_space: ast::LdStateSpace::Param, - typ: ast::PointerType::Scalar(ast::ScalarType::U64), + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::U64), .. }, arg, )) | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { - state_space: ast::LdStateSpace::Param, - typ: ast::PointerType::Scalar(ast::ScalarType::S64), + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::S64), .. }, arg, )) | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { - state_space: ast::LdStateSpace::Param, - typ: ast::PointerType::Scalar(ast::ScalarType::B64), + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::B64), .. }, arg, @@ -4611,19 +4689,16 @@ fn convert_to_stateful_memory_access<'a>( let mut remapped_ids = HashMap::new(); let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); for reg in regs_ptr_seen { - let new_id = id_defs.new_variable(ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, - )); + let new_id = id_defs.register_variable( + ast::Type::Pointer(ast::ScalarType::U8), + ast::StateSpace::Global, + ); result.push(Statement::Variable(ast::Variable { align: None, name: new_id, array_init: Vec::new(), - v_type: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, - ), - state_space: ast::StateSpace::Reg, + v_type: ast::Type::Pointer(ast::ScalarType::U8), + state_space: ast::StateSpace::Global, })); remapped_ids.insert(reg, new_id); } @@ -4658,8 +4733,8 @@ fn convert_to_stateful_memory_access<'a>( }; let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { - underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), - state_space: ast::LdStateSpace::Global, + underlying_type: ast::ScalarType::U8, + state_space: ast::StateSpace::Global, dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, offset_src: offset, @@ -4686,7 +4761,7 @@ fn convert_to_stateful_memory_access<'a>( _ => return Err(error_unreachable()), }; let offset_neg = - id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64))); + id_defs.register_intermediate(Some(ast::Type::Scalar(ast::ScalarType::S64))); result.push(Statement::Instruction(ast::Instruction::Neg( ast::NegDetails { typ: ast::ScalarType::S64, @@ -4699,8 +4774,8 @@ fn convert_to_stateful_memory_access<'a>( ))); let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { - underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), - state_space: ast::LdStateSpace::Global, + underlying_type: ast::ScalarType::U8, + state_space: ast::StateSpace::Global, dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, offset_src: TypedOperand::Reg(offset_neg), @@ -4768,10 +4843,8 @@ fn convert_to_stateful_memory_access<'a>( } for arg in func_args.input.iter_mut() { if func_args_ptr.contains(&arg.name) { - arg.v_type = ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, - ); + arg.v_type = ast::Type::Pointer(ast::ScalarType::U8); + arg.state_space = ast::StateSpace::Global; } } Ok(result) @@ -4790,21 +4863,21 @@ fn convert_to_stateful_memory_access_postprocess( Some(new_id) => { // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { - Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id), + Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id), _ => id_defs.get_typed(arg_desc.op)?.0, }; let old_type_clone = old_type.clone(); - let converting_id = id_defs.new_non_variable(Some(old_type_clone)); + let converting_id = id_defs.register_intermediate(Some(old_type_clone)); if arg_desc.is_dst { post_statements.push(Statement::Conversion(ImplicitConversion { src: converting_id, dst: *new_id, - from: old_type, - to: ast::Type::Pointer( + from_type: old_type, + to_type: ast::Type::Pointer( ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, + ast::StateSpace::Global, ), - kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global), + kind: ConversionKind::BitToPtr(ast::StateSpace::Global), src_sema: ArgumentSemantics::Default, dst_sema: arg_desc.sema, })); @@ -4813,11 +4886,11 @@ fn convert_to_stateful_memory_access_postprocess( result.push(Statement::Conversion(ImplicitConversion { src: *new_id, dst: converting_id, - from: ast::Type::Pointer( + from_type: ast::Type::Pointer( ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, + ast::StateSpace::Global, ), - to: old_type, + to_type: old_type, kind: ConversionKind::PtrToBit(ast::ScalarType::U64), src_sema: arg_desc.sema, dst_sema: ArgumentSemantics::Default, @@ -4832,19 +4905,19 @@ fn convert_to_stateful_memory_access_postprocess( } // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { - Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id), + Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id), _ => id_defs.get_typed(arg_desc.op)?.0, }; let old_type_clone = old_type.clone(); - let converting_id = id_defs.new_non_variable(Some(old_type)); + let converting_id = id_defs.register_intermediate(Some(old_type)); result.push(Statement::Conversion(ImplicitConversion { src: *new_id, dst: converting_id, - from: ast::Type::Pointer( - ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global), - ast::LdStateSpace::Param, + from_type: ast::Type::Pointer( + ast::PointerType::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Param, ), - to: old_type_clone, + to_type: old_type_clone, kind: ConversionKind::PtrToPtr { spirv_ptr: false }, src_sema: arg_desc.sema, dst_sema: ArgumentSemantics::Default, @@ -4855,6 +4928,7 @@ fn convert_to_stateful_memory_access_postprocess( }, }) } +*/ fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { match arg.dst { @@ -4876,9 +4950,9 @@ fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3 bool { match id_defs.get_typed(id) { - Ok((ast::Type::Scalar(ast::ScalarType::U64), _)) - | Ok((ast::Type::Scalar(ast::ScalarType::S64), _)) - | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true, + Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _)) + | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _)) + | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true, _ => false, } } @@ -5007,7 +5081,7 @@ impl SpecialRegistersMap { struct GlobalStringIdResolver<'input> { current_id: spirv::Word, variables: HashMap, spirv::Word>, - variables_type_check: HashMap>, + variables_type_check: HashMap>, special_registers: SpecialRegistersMap, fns: HashMap, } @@ -5036,12 +5110,17 @@ impl<'a> GlobalStringIdResolver<'a> { &mut self, id: &'a str, typ: ast::Type, + state_space: ast::StateSpace, is_variable: bool, ) -> spirv::Word { - self.get_or_add_impl(id, Some((typ, is_variable))) + self.get_or_add_impl(id, Some((typ, state_space, is_variable))) } - fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word { + fn get_or_add_impl( + &mut self, + id: &'a str, + typ: Option<(ast::Type, ast::StateSpace, bool)>, + ) -> spirv::Word { let id = match self.variables.entry(Cow::Borrowed(id)) { hash_map::Entry::Occupied(e) => *(e.get()), hash_map::Entry::Vacant(e) => { @@ -5143,10 +5222,10 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { struct FnStringIdResolver<'input, 'b> { current_id: &'b mut spirv::Word, global_variables: &'b HashMap, spirv::Word>, - global_type_check: &'b HashMap>, + global_type_check: &'b HashMap>, special_registers: &'b mut SpecialRegistersMap, variables: Vec, spirv::Word>>, - type_check: HashMap>, + type_check: HashMap>, } impl<'a, 'b> FnStringIdResolver<'a, 'b> { @@ -5184,14 +5263,21 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { } } - fn add_def(&mut self, id: &'a str, typ: Option, is_variable: bool) -> spirv::Word { + fn add_def( + &mut self, + id: &'a str, + typ: Option<(ast::Type, ast::StateSpace)>, + is_variable: bool, + ) -> spirv::Word { let numeric_id = *self.current_id; self.variables .last_mut() .unwrap() .insert(Cow::Borrowed(id), numeric_id); - self.type_check - .insert(numeric_id, typ.map(|t| (t, is_variable))); + self.type_check.insert( + numeric_id, + typ.map(|(typ, space)| (typ, space, is_variable)), + ); *self.current_id += 1; numeric_id } @@ -5202,6 +5288,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { base_id: &'a str, count: u32, typ: ast::Type, + state_space: ast::StateSpace, is_variable: bool, ) -> impl Iterator { let numeric_id = *self.current_id; @@ -5210,8 +5297,10 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .last_mut() .unwrap() .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); - self.type_check - .insert(numeric_id + i, Some((typ.clone(), is_variable))); + self.type_check.insert( + numeric_id + i, + Some((typ.clone(), state_space, is_variable)), + ); } *self.current_id += count; (0..count).into_iter().map(move |i| i + numeric_id) @@ -5220,8 +5309,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> { current_id: &'b mut spirv::Word, - global_type_check: &'b HashMap>, - type_check: HashMap>, + global_type_check: &'b HashMap>, + type_check: HashMap>, special_registers: &'b mut SpecialRegistersMap, } @@ -5230,12 +5319,15 @@ impl<'b> NumericIdResolver<'b> { MutableNumericIdResolver { base: self } } - fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, bool), TranslateError> { + fn get_typed( + &self, + id: spirv::Word, + ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> { match self.type_check.get(&id) { Some(Some(x)) => Ok(x.clone()), Some(None) => Err(TranslateError::UntypedSymbol), None => match self.special_registers.get(id) { - Some(x) => Ok((x.get_type(), true)), + Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)), None => match self.global_type_check.get(&id) { Some(Some(result)) => Ok(result.clone()), Some(None) | None => Err(TranslateError::UntypedSymbol), @@ -5246,16 +5338,18 @@ impl<'b> NumericIdResolver<'b> { // This is for identifiers which will be emitted later as OpVariable // They are candidates for insertion of LoadVar/StoreVar - fn new_variable(&mut self, typ: ast::Type) -> spirv::Word { + fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> spirv::Word { let new_id = *self.current_id; - self.type_check.insert(new_id, Some((typ, true))); + self.type_check + .insert(new_id, Some((typ, state_space, true))); *self.current_id += 1; new_id } - fn new_non_variable(&mut self, typ: Option) -> spirv::Word { + fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> spirv::Word { let new_id = *self.current_id; - self.type_check.insert(new_id, typ.map(|t| (t, false))); + self.type_check + .insert(new_id, typ.map(|(t, space)| (t, space, false))); *self.current_id += 1; new_id } @@ -5270,12 +5364,16 @@ impl<'b> MutableNumericIdResolver<'b> { self.base } - fn get_typed(&self, id: spirv::Word) -> Result { - self.base.get_typed(id).map(|(t, _)| t) + fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, ast::StateSpace), TranslateError> { + self.base.get_typed(id).map(|(t, space, _)| (t, space)) } - fn new_non_variable(&mut self, typ: ast::Type) -> spirv::Word { - self.base.new_non_variable(Some(typ)) + fn register_intermediate( + &mut self, + typ: ast::Type, + state_space: ast::StateSpace, + ) -> spirv::Word { + self.base.register_intermediate(Some((typ, state_space))) } } @@ -5304,7 +5402,8 @@ impl ExpandedStatement { Statement::Variable(var) } Statement::Instruction(inst) => inst - .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| { + .visit(&mut |arg: ArgumentDescriptor<_>, + _: Option<(&ast::Type, ast::StateSpace)>| { Ok(f(arg.op, arg.is_dst)) }) .unwrap(), @@ -5391,6 +5490,7 @@ impl ExpandedStatement { struct LoadVarDetails { arg: ast::Arg2, typ: ast::Type, + state_space: ast::StateSpace, // (index, vector_width) // HACK ALERT // For some reason IGC explodes when you try to load from builtin vectors @@ -5402,6 +5502,7 @@ struct LoadVarDetails { struct StoreVarDetails { arg: ast::Arg2St, typ: ast::Type, + state_space: ast::StateSpace, member_index: Option, } @@ -5428,7 +5529,10 @@ impl RepackVectorDetails { is_dst: !self.is_extract, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)), + Some(( + &ast::Type::Vector(self.typ, self.unpacked.len() as u8), + ast::StateSpace::Reg, + )), )?; let scalar_type = self.typ; let is_extract = self.is_extract; @@ -5443,7 +5547,7 @@ impl RepackVectorDetails { is_dst: is_extract, sema: vector_sema, }, - Some(&ast::Type::Scalar(scalar_type)), + Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)), ) }) .collect::>()?; @@ -5501,7 +5605,7 @@ impl> ResolvedCall { is_dst: space != ast::StateSpace::Param, sema: space.semantics(), }, - Some(&typ), + Some((&typ, space)), )?; Ok((new_id, typ, space)) }) @@ -5525,6 +5629,7 @@ impl> ResolvedCall { sema: space.semantics(), }, &typ, + space, )?; Ok((new_id, typ, space)) }) @@ -5555,22 +5660,22 @@ impl> PtrAccess

{ visitor: &mut V, ) -> Result, TranslateError> { let sema = match self.state_space { - ast::LdStateSpace::Const - | ast::LdStateSpace::Global - | ast::LdStateSpace::Shared - | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer, - ast::LdStateSpace::Local | ast::LdStateSpace::Param => { - ArgumentSemantics::RegisterPointer - } + ast::StateSpace::Const + | ast::StateSpace::Global + | ast::StateSpace::Shared + | ast::StateSpace::Generic => ArgumentSemantics::PhysicalPointer, + ast::StateSpace::Local | ast::StateSpace::Param => ArgumentSemantics::RegisterPointer, + ast::StateSpace::Reg => new_todo!(), + ast::StateSpace::Sreg => new_todo!(), }; - let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space); + let ptr_type = ast::Type::Pointer(self.underlying_type.clone()); let new_dst = visitor.id( ArgumentDescriptor { op: self.dst, is_dst: true, sema, }, - Some(&ptr_type), + Some((&ptr_type, self.state_space)), )?; let new_ptr_src = visitor.id( ArgumentDescriptor { @@ -5578,7 +5683,7 @@ impl> PtrAccess

{ is_dst: false, sema, }, - Some(&ptr_type), + Some((&ptr_type, self.state_space)), )?; let new_constant_src = visitor.operand( ArgumentDescriptor { @@ -5587,6 +5692,7 @@ impl> PtrAccess

{ sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::S64), + self.state_space, )?; Ok(PtrAccess { underlying_type: self.underlying_type, @@ -5723,12 +5829,13 @@ pub trait ArgumentMapVisitor { fn id( &mut self, desc: ArgumentDescriptor, - typ: Option<&ast::Type>, + typ: Option<(&ast::Type, ast::StateSpace)>, ) -> Result; fn operand( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result; } @@ -5736,13 +5843,13 @@ impl ArgumentMapVisitor for T where T: FnMut( ArgumentDescriptor, - Option<&ast::Type>, + Option<(&ast::Type, ast::StateSpace)>, ) -> Result, { fn id( &mut self, desc: ArgumentDescriptor, - t: Option<&ast::Type>, + t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self(desc, t) } @@ -5751,8 +5858,9 @@ where &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { - self(desc, Some(typ)) + self(desc, Some((typ, state_space))) } } @@ -5763,7 +5871,7 @@ where fn id( &mut self, desc: ArgumentDescriptor<&str>, - _: Option<&ast::Type>, + _: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self(desc.op) } @@ -5772,6 +5880,7 @@ where &mut self, desc: ArgumentDescriptor>, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result, TranslateError> { Ok(match desc.op { ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?), @@ -5780,7 +5889,7 @@ where ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member), ast::Operand::VecPack(ref ids) => ast::Operand::VecPack( ids.into_iter() - .map(|id| self.id(desc.new_op(id), Some(typ))) + .map(|id| self.id(desc.new_op(id), Some((typ, state_space)))) .collect::, _>>()?, ), }) @@ -5794,8 +5903,8 @@ pub struct ArgumentDescriptor { } pub struct PtrAccess { - underlying_type: ast::PointerType, - state_space: ast::LdStateSpace, + underlying_type: ast::ScalarType, + state_space: ast::StateSpace, dst: spirv::Word, ptr_src: spirv::Word, offset_src: P::Operand, @@ -6061,7 +6170,7 @@ impl ImplicitConversion { is_dst: true, sema: self.dst_sema, }, - Some(&self.to), + Some((&self.to_type, self.to_space)), )?; let new_src = visitor.id( ArgumentDescriptor { @@ -6069,7 +6178,7 @@ impl ImplicitConversion { is_dst: false, sema: self.src_sema, }, - Some(&self.from), + Some((&self.from_type, self.from_space)), )?; Ok(Statement::Conversion({ ImplicitConversion { @@ -6096,13 +6205,13 @@ impl ArgumentMapVisitor for T where T: FnMut( ArgumentDescriptor, - Option<&ast::Type>, + Option<(&ast::Type, ast::StateSpace)>, ) -> Result, { fn id( &mut self, desc: ArgumentDescriptor, - t: Option<&ast::Type>, + t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self(desc, t) } @@ -6111,12 +6220,15 @@ where &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { Ok(match desc.op { - TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?), + TypedOperand::Reg(id) => { + TypedOperand::Reg(self(desc.new_op(id), Some((typ, state_space)))?) + } TypedOperand::Imm(imm) => TypedOperand::Imm(imm), TypedOperand::RegOffset(id, imm) => { - TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) + TypedOperand::RegOffset(self(desc.new_op(id), Some((typ, state_space)))?, imm) } TypedOperand::VecMember(reg, index) => { let scalar_type = match typ { @@ -6124,7 +6236,10 @@ where _ => return Err(error_unreachable()), }; let vec_type = ast::Type::Vector(scalar_type, index + 1); - TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) + TypedOperand::VecMember( + self(desc.new_op(reg), Some((&vec_type, state_space)))?, + index, + ) } }) } @@ -6159,54 +6274,25 @@ impl ast::Type { scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), - state_space: ast::LdStateSpace::Global, }, ast::Type::Vector(scalar, components) => TypeParts { kind: TypeKind::Vector, scalar_kind: scalar.kind(), width: scalar.size_of(), components: vec![*components as u32], - state_space: ast::LdStateSpace::Global, }, ast::Type::Array(scalar, components) => TypeParts { kind: TypeKind::Array, scalar_kind: scalar.kind(), width: scalar.size_of(), components: components.clone(), - state_space: ast::LdStateSpace::Global, }, - ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts { + ast::Type::Pointer(scalar) => TypeParts { kind: TypeKind::PointerScalar, scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), - state_space: *state_space, }, - ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts { - kind: TypeKind::PointerVector, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: vec![*len as u32], - state_space: *state_space, - }, - ast::Type::Pointer(ast::PointerType::Array(scalar, components), state_space) => { - TypeParts { - kind: TypeKind::PointerArray, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: components.clone(), - state_space: *state_space, - } - } - ast::Type::Pointer(ast::PointerType::Pointer(scalar, inner_space), state_space) => { - TypeParts { - kind: TypeKind::PointerPointer, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: vec![*inner_space as u32], - state_space: *state_space, - } - } } } @@ -6223,31 +6309,9 @@ impl ast::Type { ast::ScalarType::from_parts(t.width, t.scalar_kind), t.components, ), - TypeKind::PointerScalar => ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)), - t.state_space, - ), - TypeKind::PointerVector => ast::Type::Pointer( - ast::PointerType::Vector( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.components[0] as u8, - ), - t.state_space, - ), - TypeKind::PointerArray => ast::Type::Pointer( - ast::PointerType::Array( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.components, - ), - t.state_space, - ), - TypeKind::PointerPointer => ast::Type::Pointer( - ast::PointerType::Pointer( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - unsafe { mem::transmute::<_, ast::LdStateSpace>(t.components[0] as u8) }, - ), - t.state_space, - ), + TypeKind::PointerScalar => { + ast::Type::Pointer(ast::ScalarType::from_parts(t.width, t.scalar_kind)) + } } } @@ -6258,7 +6322,7 @@ impl ast::Type { ast::Type::Array(typ, len) => len .iter() .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), - ast::Type::Pointer(_, _) => mem::size_of::(), + ast::Type::Pointer(..) => mem::size_of::(), } } } @@ -6269,7 +6333,6 @@ struct TypeParts { scalar_kind: ast::ScalarKind, width: u8, components: Vec, - state_space: ast::LdStateSpace, } #[derive(Eq, PartialEq, Copy, Clone)] @@ -6278,9 +6341,6 @@ enum TypeKind { Vector, Array, PointerScalar, - PointerVector, - PointerArray, - PointerPointer, } impl ast::Instruction { @@ -6408,8 +6468,10 @@ struct BrachCondition { struct ImplicitConversion { src: spirv::Word, dst: spirv::Word, - from: ast::Type, - to: ast::Type, + from_type: ast::Type, + to_type: ast::Type, + from_space: ast::StateSpace, + to_space: ast::StateSpace, kind: ConversionKind, src_sema: ArgumentSemantics, dst_sema: ArgumentSemantics, @@ -6420,7 +6482,7 @@ enum ConversionKind { Default, // zero-extend/chop/bitcast depending on types SignExtend, - BitToPtr(ast::LdStateSpace), + BitToPtr, PtrToBit(ast::ScalarType), PtrToPtr { spirv_ptr: bool }, } @@ -6470,7 +6532,7 @@ impl ast::Arg1 { fn map>( self, visitor: &mut V, - t: Option<&ast::Type>, + t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result, TranslateError> { let new_src = visitor.id( ArgumentDescriptor { @@ -6496,6 +6558,7 @@ impl ast::Arg1Bar { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, )?; Ok(ast::Arg1Bar { src: new_src }) } @@ -6514,6 +6577,7 @@ impl ast::Arg2 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let new_src = visitor.operand( ArgumentDescriptor { @@ -6522,6 +6586,7 @@ impl ast::Arg2 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; Ok(ast::Arg2 { dst: new_dst, @@ -6542,6 +6607,7 @@ impl ast::Arg2 { sema: ArgumentSemantics::Default, }, dst_t, + ast::StateSpace::Reg, )?; let src = visitor.operand( ArgumentDescriptor { @@ -6550,6 +6616,7 @@ impl ast::Arg2 { sema: ArgumentSemantics::Default, }, src_t, + ast::StateSpace::Reg, )?; Ok(ast::Arg2 { dst, src }) } @@ -6568,9 +6635,10 @@ impl ast::Arg2Ld { sema: ArgumentSemantics::DefaultRelaxed, }, &ast::Type::from(details.typ.clone()), + ast::StateSpace::Reg, )?; - let is_logical_ptr = details.state_space == ast::LdStateSpace::Param - || details.state_space == ast::LdStateSpace::Local; + let is_logical_ptr = details.state_space == ast::StateSpace::Param + || details.state_space == ast::StateSpace::Local; let src = visitor.operand( ArgumentDescriptor { op: self.src, @@ -6581,10 +6649,8 @@ impl ast::Arg2Ld { ArgumentSemantics::PhysicalPointer }, }, - &ast::Type::Pointer( - ast::PointerType::from(details.typ.clone()), - details.state_space, - ), + &details.typ, + details.state_space, )?; Ok(ast::Arg2Ld { dst, src }) } @@ -6596,8 +6662,8 @@ impl ast::Arg2St { visitor: &mut V, details: &ast::StData, ) -> Result, TranslateError> { - let is_logical_ptr = details.state_space == ast::StStateSpace::Param - || details.state_space == ast::StStateSpace::Local; + let is_logical_ptr = details.state_space == ast::StateSpace::Param + || details.state_space == ast::StateSpace::Local; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, @@ -6608,10 +6674,8 @@ impl ast::Arg2St { ArgumentSemantics::PhysicalPointer }, }, - &ast::Type::Pointer( - ast::PointerType::from(details.typ.clone()), - details.state_space.to_ld_ss(), - ), + &details.typ, + details.state_space, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6620,6 +6684,7 @@ impl ast::Arg2St { sema: ArgumentSemantics::DefaultRelaxed, }, &details.typ.clone().into(), + ast::StateSpace::Reg, )?; Ok(ast::Arg2St { src1, src2 }) } @@ -6638,6 +6703,7 @@ impl ast::Arg2Mov { sema: ArgumentSemantics::Default, }, &details.typ.clone().into(), + ast::StateSpace::Reg, )?; let src = visitor.operand( ArgumentDescriptor { @@ -6650,6 +6716,7 @@ impl ast::Arg2Mov { }, }, &details.typ.clone().into(), + ast::StateSpace::Reg, )?; Ok(ast::Arg2Mov { dst, src }) } @@ -6674,6 +6741,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, wide_type.as_ref().unwrap_or(typ), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6682,6 +6750,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, typ, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6690,6 +6759,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, typ, + ast::StateSpace::Reg, )?; Ok(ast::Arg3 { dst, src1, src2 }) } @@ -6706,6 +6776,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6714,6 +6785,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6722,6 +6794,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, )?; Ok(ast::Arg3 { dst, src1, src2 }) } @@ -6730,7 +6803,7 @@ impl ast::Arg3 { self, visitor: &mut V, t: ast::ScalarType, - state_space: ast::AtomSpace, + state_space: ast::StateSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); let dst = visitor.operand( @@ -6740,6 +6813,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6747,10 +6821,8 @@ impl ast::Arg3 { is_dst: false, sema: ArgumentSemantics::PhysicalPointer, }, - &ast::Type::Pointer( - ast::PointerType::Scalar(scalar_type), - state_space.to_ld_ss(), - ), + &ast::Type::Scalar(scalar_type), + state_space, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6759,6 +6831,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; Ok(ast::Arg3 { dst, src1, src2 }) } @@ -6783,6 +6856,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, wide_type.as_ref().unwrap_or(t), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6791,6 +6865,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6799,6 +6874,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { @@ -6807,6 +6883,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; Ok(ast::Arg4 { dst, @@ -6828,6 +6905,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(t.into()), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6836,6 +6914,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(t.into()), + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6844,6 +6923,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(t.into()), + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { @@ -6852,6 +6932,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, )?; Ok(ast::Arg4 { dst, @@ -6865,7 +6946,7 @@ impl ast::Arg4 { self, visitor: &mut V, t: ast::ScalarType, - state_space: ast::AtomSpace, + state_space: ast::StateSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); let dst = visitor.operand( @@ -6875,6 +6956,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6882,10 +6964,8 @@ impl ast::Arg4 { is_dst: false, sema: ArgumentSemantics::PhysicalPointer, }, - &ast::Type::Pointer( - ast::PointerType::Scalar(scalar_type), - state_space.to_ld_ss(), - ), + &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6894,6 +6974,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { @@ -6902,6 +6983,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; Ok(ast::Arg4 { dst, @@ -6923,6 +7005,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, typ, + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6931,6 +7014,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, typ, + ast::StateSpace::Reg, )?; let u32_type = ast::Type::Scalar(ast::ScalarType::U32); let src2 = visitor.operand( @@ -6940,6 +7024,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &u32_type, + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { @@ -6948,6 +7033,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &u32_type, + ast::StateSpace::Reg, )?; Ok(ast::Arg4 { dst, @@ -6970,7 +7056,10 @@ impl ast::Arg4Setp { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(ast::ScalarType::Pred)), + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), )?; let dst2 = self .dst2 @@ -6981,7 +7070,10 @@ impl ast::Arg4Setp { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(ast::ScalarType::Pred)), + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), ) }) .transpose()?; @@ -6992,6 +7084,7 @@ impl ast::Arg4Setp { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -7000,6 +7093,7 @@ impl ast::Arg4Setp { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; Ok(ast::Arg4Setp { dst1, @@ -7023,6 +7117,7 @@ impl ast::Arg5 { sema: ArgumentSemantics::Default, }, base_type, + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -7031,6 +7126,7 @@ impl ast::Arg5 { sema: ArgumentSemantics::Default, }, base_type, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -7039,6 +7135,7 @@ impl ast::Arg5 { sema: ArgumentSemantics::Default, }, base_type, + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { @@ -7047,6 +7144,7 @@ impl ast::Arg5 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, )?; let src4 = visitor.operand( ArgumentDescriptor { @@ -7055,6 +7153,7 @@ impl ast::Arg5 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, )?; Ok(ast::Arg5 { dst, @@ -7078,7 +7177,10 @@ impl ast::Arg5Setp { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(ast::ScalarType::Pred)), + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), )?; let dst2 = self .dst2 @@ -7089,7 +7191,10 @@ impl ast::Arg5Setp { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(ast::ScalarType::Pred)), + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), ) }) .transpose()?; @@ -7100,6 +7205,7 @@ impl ast::Arg5Setp { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -7108,6 +7214,7 @@ impl ast::Arg5Setp { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { @@ -7116,6 +7223,7 @@ impl ast::Arg5Setp { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, )?; Ok(ast::Arg5Setp { dst1, @@ -7153,18 +7261,6 @@ impl ast::Operand { } } -impl ast::StStateSpace { - fn to_ld_ss(self) -> ast::LdStateSpace { - match self { - ast::StStateSpace::Generic => ast::LdStateSpace::Generic, - ast::StStateSpace::Global => ast::LdStateSpace::Global, - ast::StStateSpace::Local => ast::LdStateSpace::Local, - ast::StStateSpace::Param => ast::LdStateSpace::Param, - ast::StStateSpace::Shared => ast::LdStateSpace::Shared, - } - } -} - impl ast::ScalarType { fn from_parts(width: u8, kind: ast::ScalarKind) -> Self { match kind { @@ -7255,15 +7351,17 @@ impl ast::AtomInnerDetails { } } -impl ast::LdStateSpace { +impl ast::StateSpace { fn to_spirv(self) -> spirv::StorageClass { match self { - ast::LdStateSpace::Const => spirv::StorageClass::UniformConstant, - ast::LdStateSpace::Generic => spirv::StorageClass::Generic, - ast::LdStateSpace::Global => spirv::StorageClass::CrossWorkgroup, - ast::LdStateSpace::Local => spirv::StorageClass::Function, - ast::LdStateSpace::Shared => spirv::StorageClass::Workgroup, - ast::LdStateSpace::Param => spirv::StorageClass::Function, + ast::StateSpace::Const => spirv::StorageClass::UniformConstant, + ast::StateSpace::Generic => spirv::StorageClass::Generic, + ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, + ast::StateSpace::Local => spirv::StorageClass::Function, + ast::StateSpace::Shared => spirv::StorageClass::Workgroup, + ast::StateSpace::Param => spirv::StorageClass::Function, + ast::StateSpace::Reg => spirv::StorageClass::Function, + ast::StateSpace::Sreg => spirv::StorageClass::Input, } } } @@ -7289,16 +7387,6 @@ impl ast::MulDetails { } } -impl ast::AtomSpace { - fn to_ld_ss(self) -> ast::LdStateSpace { - match self { - ast::AtomSpace::Generic => ast::LdStateSpace::Generic, - ast::AtomSpace::Global => ast::LdStateSpace::Global, - ast::AtomSpace::Shared => ast::LdStateSpace::Shared, - } - } -} - impl ast::MemScope { fn to_spirv(self) -> spirv::Scope { match self { @@ -7333,89 +7421,44 @@ impl ast::StateSpace { fn bitcast_register_pointer( operand_type: &ast::Type, + operand_space: ast::StateSpace, instr_type: &ast::Type, - ss: Option, + instruction_space: ast::StateSpace, ) -> Result, TranslateError> { - bitcast_physical_pointer(operand_type, instr_type, ss) + bitcast_physical_pointer(operand_type, operand_space, instr_type, instruction_space) } fn bitcast_physical_pointer( operand_type: &ast::Type, - instr_type: &ast::Type, - ss: Option, + operand_space: ast::StateSpace, + instruction_type: &ast::Type, + instruction_space: ast::StateSpace, ) -> Result, TranslateError> { - match operand_type { - // array decays to a pointer - ast::Type::Array(op_scalar_t, _) => { - if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type { - if ss == Some(*instr_space) { - if ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_scalar_t.clone()) { - Ok(None) - } else { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } - } else { - if ss == Some(ast::LdStateSpace::Generic) - || *instr_space == ast::LdStateSpace::Generic - { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } else { - Err(TranslateError::MismatchedType) - } - } - } else { - Err(TranslateError::MismatchedType) - } + if operand_space == instruction_space { + if operand_type != instruction_type { + Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) + } else { + Ok(None) } - ast::Type::Scalar(ast::ScalarType::B64) - | ast::Type::Scalar(ast::ScalarType::U64) - | ast::Type::Scalar(ast::ScalarType::S64) => { - if let Some(space) = ss { - Ok(Some(ConversionKind::BitToPtr(space))) - } else { - Err(error_unreachable()) - } - } - ast::Type::Scalar(ast::ScalarType::B32) - | ast::Type::Scalar(ast::ScalarType::U32) - | ast::Type::Scalar(ast::ScalarType::S32) => match ss { - Some(ast::LdStateSpace::Shared) - | Some(ast::LdStateSpace::Generic) - | Some(ast::LdStateSpace::Param) - | Some(ast::LdStateSpace::Local) => { - Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared))) - } + } else { + match operand_space { + ast::StateSpace::Reg | ast::StateSpace::Sreg => match instruction_space { + ast::StateSpace::Generic + | ast::StateSpace::Global + | ast::StateSpace::Shared + | ast::StateSpace::Local => Ok(Some(ConversionKind::BitToPtr)), + _ => Err(TranslateError::MismatchedType), + }, _ => Err(TranslateError::MismatchedType), - }, - ast::Type::Pointer(op_scalar_t, op_space) => { - if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type { - if op_space == instr_space { - if op_scalar_t == instr_scalar_t { - Ok(None) - } else { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } - } else { - if *op_space == ast::LdStateSpace::Generic - || *instr_space == ast::LdStateSpace::Generic - { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } else { - Err(TranslateError::MismatchedType) - } - } - } else { - Err(TranslateError::MismatchedType) - } } - _ => Err(TranslateError::MismatchedType), } } fn force_bitcast_ptr_to_bit( _: &ast::Type, + _: ast::StateSpace, instr_type: &ast::Type, - _: Option, + _: ast::StateSpace, ) -> Result, TranslateError> { // TODO: verify this on f32, u16 and the like if let ast::Type::Scalar(scalar_t) = instr_type { @@ -7457,11 +7500,12 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { fn should_bitcast_packed( operand: &ast::Type, - instr: &ast::Type, - ss: Option, + operand_space: ast::StateSpace, + instruction: &ast::Type, + instruction_space: ast::StateSpace, ) -> Result, TranslateError> { if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = - (operand, instr) + (operand, instruction) { if scalar.kind() == ast::ScalarKind::Bit && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) @@ -7469,13 +7513,14 @@ fn should_bitcast_packed( return Ok(Some(ConversionKind::Default)); } } - should_bitcast_wrapper(operand, instr, ss) + should_bitcast_wrapper(operand, operand_space, instruction, instruction_space) } fn should_bitcast_wrapper( operand: &ast::Type, + _: ast::StateSpace, instr: &ast::Type, - _: Option, + _: ast::StateSpace, ) -> Result, TranslateError> { if instr == operand { return Ok(None); @@ -7489,8 +7534,9 @@ fn should_bitcast_wrapper( fn should_convert_relaxed_src_wrapper( src_type: &ast::Type, + _: ast::StateSpace, instr_type: &ast::Type, - _: Option, + _: ast::StateSpace, ) -> Result, TranslateError> { if src_type == instr_type { return Ok(None); @@ -7552,8 +7598,9 @@ fn should_convert_relaxed_src( fn should_convert_relaxed_dst_wrapper( dst_type: &ast::Type, + _: ast::StateSpace, instr_type: &ast::Type, - _: Option, + _: ast::StateSpace, ) -> Result, TranslateError> { if dst_type == instr_type { return Ok(None); From 7f051ad20ec933f78ce4539020a25fab3503011c Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 6 May 2021 01:32:45 +0200 Subject: [PATCH 07/25] Fix and test --- ptx/src/translate.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a743496..1a2eda3 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2424,7 +2424,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { let (reg, offset) = desc.op; let add_type; match typ { - ast::Type::Pointer(underlying_type) => { + ast::Type::Scalar(underlying_type) => { let (reg_typ, space) = self.id_def.get_typed(reg)?; if let ast::Type::Pointer(..) = reg_typ { let id_constant_stmt = self.id_def.register_intermediate(typ.clone(), space); @@ -2443,12 +2443,10 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { })); return Ok(dst); } else { - add_type = self.id_def.get_typed(reg)?.0; + add_type = reg_typ; } } - _ => { - add_type = typ.clone(); - } + _ => return Err(error_unreachable()), }; let (width, kind) = match add_type { ast::Type::Scalar(scalar_t) => { From 425edfcdd49a4fa49d480f1b078c55dba4709e29 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 7 May 2021 18:22:09 +0200 Subject: [PATCH 08/25] Simplify typing --- ptx/src/ast.rs | 21 +- ptx/src/ptx.lalrpop | 21 +- ptx/src/translate.rs | 524 +++++++++++++++++------------------------- zluda_dump/src/lib.rs | 14 +- 4 files changed, 247 insertions(+), 333 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 364ec01..e45a6fb 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,6 +1,6 @@ use half::f16; use lalrpop_util::{lexer::Token, ParseError}; -use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; +use std::{convert::From, mem, num::ParseFloatError, rc::Rc, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; #[derive(Debug, thiserror::Error)] @@ -86,19 +86,20 @@ pub enum Directive<'a, P: ArgParams> { Method(Function<'a, &'a str, Statement

>), } -pub enum MethodDecl<'a, ID> { - Func(Vec>, ID, Vec>), - Kernel { - name: &'a str, - in_args: Vec>, - }, +#[derive(Hash, PartialEq, Eq, Copy, Clone)] +pub enum MethodName<'input, ID> { + Kernel(&'input str), + Func(ID), } -pub type FnArgument = Variable; -pub type KernelArgument = Variable; +pub struct MethodDeclaration<'input, ID> { + pub return_arguments: Vec>, + pub name: MethodName<'input, ID>, + pub input_arguments: Vec>, +} pub struct Function<'a, ID, S> { - pub func_directive: MethodDecl<'a, ID>, + pub func_directive: MethodDeclaration<'a, ID>, pub tuning: Vec, pub body: Option>, } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 8fee7c2..78ebf1d 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -360,7 +360,7 @@ AddressSize = { Function: ast::Function<'input, &'input str, ast::Statement>> = { LinkingDirectives - + => ast::Function{<>} }; @@ -388,19 +388,24 @@ LinkingDirectives: ast::LinkingDirective = { } } -MethodDecl: ast::MethodDecl<'input, &'input str> = { - ".entry" => - ast::MethodDecl::Kernel{ name, in_args }, - ".func" => { - ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) +MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = { + ".entry" => { + let return_arguments = Vec::new(); + let name = ast::MethodName::Kernel(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments } + }, + ".func" => { + let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); + let name = ast::MethodName::Func(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments } } }; -KernelArguments: Vec> = { +KernelArguments: Vec> = { "(" > ")" => args }; -FnArguments: Vec> = { +FnArguments: Vec> = { "(" > ")" => args }; diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 1a2eda3..88ef51b 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,9 @@ use crate::ast; +use core::borrow; use half::f16; use rspirv::dr; -use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem}; +use std::{borrow::Borrow, cell::RefCell}; +use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc}; use std::{ collections::{hash_map, HashMap, HashSet}, convert::TryInto, @@ -458,7 +460,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result>(); let mut builder = dr::Builder::new(); builder.reserve_ids(id_defs.current_id()); - let call_map = get_call_map(&directives); + let call_map = get_kernels_call_map(&directives); let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id()); normalize_variable_decls(&mut directives); let denorm_information = compute_denorm_information(&directives); @@ -496,9 +498,12 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result( call_map: &HashMap<&str, HashSet>, - denorm_information: &HashMap>, + denorm_information: &HashMap< + ast::MethodName<'input, spirv::Word>, + HashMap, + >, ) -> CString { let denorm_counts = denorm_information .iter() @@ -516,10 +521,12 @@ fn emit_denorm_build_string( .collect::>(); let mut flush_over_preserve = 0; for (kernel, children) in call_map { - flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0); + flush_over_preserve += *denorm_counts + .get(&ast::MethodName::Kernel(kernel)) + .unwrap_or(&0); for child_fn in children { flush_over_preserve += *denorm_counts - .get(&MethodName::Func(*child_fn)) + .get(&ast::MethodName::Func(*child_fn)) .unwrap_or(&0); } } @@ -535,9 +542,12 @@ fn emit_directives<'input>( map: &mut TypeWordMap, id_defs: &GlobalStringIdResolver<'input>, opencl_id: spirv::Word, - denorm_information: &HashMap, HashMap>, + denorm_information: &HashMap< + ast::MethodName<'input, spirv::Word>, + HashMap, + >, call_map: &HashMap<&'input str, HashSet>, - directives: Vec, + directives: Vec>, kernel_info: &mut HashMap, ) -> Result<(), TranslateError> { let empty_body = Vec::new(); @@ -560,16 +570,18 @@ fn emit_directives<'input>( for var in f.globals.iter() { emit_variable(builder, map, var)?; } + let func_decl = (*f.func_decl).borrow(); let fn_id = emit_function_header( builder, map, &id_defs, &f.globals, - &f.spirv_decl, + &*func_decl, &denorm_information, call_map, &directives, kernel_info, + f.uses_shared_mem, )?; for t in f.tuning.iter() { match *t { @@ -594,8 +606,13 @@ fn emit_directives<'input>( } emit_function_body_ops(builder, map, opencl_id, &f_body)?; builder.end_function()?; - if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) = - (&f.func_decl, &f.import_as) + if let ( + ast::MethodDeclaration { + name: ast::MethodName::Func(fn_id), + .. + }, + Some(name), + ) = (&*func_decl, &f.import_as) { builder.decorate( *fn_id, @@ -614,7 +631,7 @@ fn emit_directives<'input>( Ok(()) } -fn get_call_map<'input>( +fn get_kernels_call_map<'input>( module: &[Directive<'input>], ) -> HashMap<&'input str, HashSet> { let mut directly_called_by = HashMap::new(); @@ -625,7 +642,7 @@ fn get_call_map<'input>( body: Some(statements), .. }) => { - let call_key = MethodName::new(&func_decl); + let call_key: ast::MethodName<_> = (**func_decl).borrow().name; if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) { entry.insert(Vec::new()); } @@ -644,28 +661,28 @@ fn get_call_map<'input>( let mut result = HashMap::new(); for (method_key, children) in directly_called_by.iter() { match method_key { - MethodName::Kernel(name) => { + ast::MethodName::Kernel(name) => { let mut visited = HashSet::new(); for child in children { add_call_map_single(&directly_called_by, &mut visited, *child); } result.insert(*name, visited); } - MethodName::Func(_) => {} + ast::MethodName::Func(_) => {} } } result } fn add_call_map_single<'input>( - directly_called_by: &MultiHashMap, spirv::Word>, + directly_called_by: &MultiHashMap, spirv::Word>, visited: &mut HashSet, current: spirv::Word, ) { if !visited.insert(current) { return; } - if let Some(children) = directly_called_by.get(&MethodName::Func(current)) { + if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) { for child in children { add_call_map_single(directly_called_by, visited, *child); } @@ -739,10 +756,10 @@ fn convert_dynamic_shared_memory_usage<'input>( globals, body: Some(statements), import_as, - spirv_decl, tuning, + uses_shared_mem, }) => { - let call_key = MethodName::new(&func_decl); + let call_key = (*func_decl).borrow().name; let statements = statements .into_iter() .map(|statement| match statement { @@ -763,8 +780,8 @@ fn convert_dynamic_shared_memory_usage<'input>( globals, body: Some(statements), import_as, - spirv_decl, tuning, + uses_shared_mem, }) } directive => directive, @@ -782,30 +799,32 @@ fn convert_dynamic_shared_memory_usage<'input>( globals, body: Some(statements), import_as, - mut spirv_decl, tuning, + uses_shared_mem, }) => { - if !methods_using_extern_shared.contains(&spirv_decl.name) { + if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) { return Directive::Method(Function { func_decl, globals, body: Some(statements), import_as, - spirv_decl, tuning, + uses_shared_mem, }); } let shared_id_param = new_id(); - spirv_decl.input.push({ - ast::Variable { - name: shared_id_param, - align: None, - v_type: ast::Type::Pointer(ast::ScalarType::B8), - state_space: ast::StateSpace::Shared, - array_init: Vec::new(), - } - }); - spirv_decl.uses_shared_mem = true; + { + let mut func_decl = (*func_decl).borrow_mut(); + func_decl.input_arguments.push({ + ast::Variable { + name: shared_id_param, + align: None, + v_type: ast::Type::Pointer(ast::ScalarType::B8), + state_space: ast::StateSpace::Shared, + array_init: Vec::new(), + } + }); + } let statements = replace_uses_of_shared_memory( new_id, &extern_shared_decls, @@ -818,8 +837,8 @@ fn convert_dynamic_shared_memory_usage<'input>( globals, body: Some(statements), import_as, - spirv_decl, tuning, + uses_shared_mem: true, }) } directive => directive, @@ -830,7 +849,7 @@ fn convert_dynamic_shared_memory_usage<'input>( fn replace_uses_of_shared_memory<'a>( new_id: &mut impl FnMut() -> spirv::Word, extern_shared_decls: &HashMap, - methods_using_extern_shared: &mut HashSet>, + methods_using_extern_shared: &mut HashSet>, shared_id_param: spirv::Word, statements: Vec, ) -> Vec { @@ -841,7 +860,7 @@ fn replace_uses_of_shared_memory<'a>( // We can safely skip checking call arguments, // because there's simply no way to pass shared ptr // without converting it to .b64 first - if methods_using_extern_shared.contains(&MethodName::Func(call.func)) { + if methods_using_extern_shared.contains(&ast::MethodName::Func(call.func)) { call.param_list.push(( shared_id_param, ast::Type::Scalar(ast::ScalarType::B8), @@ -881,13 +900,13 @@ fn replace_uses_of_shared_memory<'a>( } fn get_callers_of_extern_shared<'a>( - methods_using_extern_shared: &mut HashSet>, - directly_called_by: &MultiHashMap>, + methods_using_extern_shared: &mut HashSet>, + directly_called_by: &MultiHashMap>, ) { let direct_uses_of_extern_shared = methods_using_extern_shared .iter() .filter_map(|method| { - if let MethodName::Func(f_id) = method { + if let ast::MethodName::Func(f_id) = method { Some(*f_id) } else { None @@ -900,14 +919,14 @@ fn get_callers_of_extern_shared<'a>( } fn get_callers_of_extern_shared_single<'a>( - methods_using_extern_shared: &mut HashSet>, - directly_called_by: &MultiHashMap>, + methods_using_extern_shared: &mut HashSet>, + directly_called_by: &MultiHashMap>, fn_id: spirv::Word, ) { if let Some(callers) = directly_called_by.get(&fn_id) { for caller in callers { if methods_using_extern_shared.insert(*caller) { - if let MethodName::Func(caller_fn) = caller { + if let ast::MethodName::Func(caller_fn) = caller { get_callers_of_extern_shared_single( methods_using_extern_shared, directly_called_by, @@ -949,7 +968,7 @@ fn denorm_count_map_update_impl( // and emit suitable execution mode fn compute_denorm_information<'input>( module: &[Directive<'input>], -) -> HashMap, HashMap> { +) -> HashMap, HashMap> { let mut denorm_methods = HashMap::new(); for directive in module { match directive { @@ -960,7 +979,7 @@ fn compute_denorm_information<'input>( .. }) => { let mut flush_counter = DenormCountMap::new(); - let method_key = MethodName::new(func_decl); + let method_key = (**func_decl).borrow().name; for statement in statements { match statement { Statement::Instruction(inst) => { @@ -1004,21 +1023,6 @@ fn compute_denorm_information<'input>( .collect() } -#[derive(Hash, PartialEq, Eq, Copy, Clone)] -enum MethodName<'input> { - Kernel(&'input str), - Func(spirv::Word), -} - -impl<'input> MethodName<'input> { - fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self { - match decl { - ast::MethodDecl::Kernel { name, .. } => MethodName::Kernel(name), - ast::MethodDecl::Func(_, id, _) => MethodName::Func(*id), - } - } -} - fn emit_builtins( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -1047,17 +1051,21 @@ fn emit_function_header<'a>( map: &mut TypeWordMap, defined_globals: &GlobalStringIdResolver<'a>, synthetic_globals: &[ast::Variable], - func_decl: &SpirvMethodDecl<'a>, - _denorm_information: &HashMap, HashMap>, + func_decl: &ast::MethodDeclaration<'a, spirv::Word>, + _denorm_information: &HashMap< + ast::MethodName<'a, spirv::Word>, + HashMap, + >, call_map: &HashMap<&'a str, HashSet>, direcitves: &[Directive], kernel_info: &mut HashMap, + uses_shared_mem: bool, ) -> Result { - if let MethodName::Kernel(name) = func_decl.name { - let input_args = if !func_decl.uses_shared_mem { - func_decl.input.as_slice() + if let ast::MethodName::Kernel(name) = func_decl.name { + let input_args = if !uses_shared_mem { + func_decl.input_arguments.as_slice() } else { - &func_decl.input[0..func_decl.input.len() - 1] + &func_decl.input_arguments[0..func_decl.input_arguments.len() - 1] }; let args_lens = input_args .iter() @@ -1067,14 +1075,18 @@ fn emit_function_header<'a>( name.to_string(), KernelInfo { arguments_sizes: args_lens, - uses_shared_mem: func_decl.uses_shared_mem, + uses_shared_mem: uses_shared_mem, }, ); } - let (ret_type, func_type) = - get_function_type(builder, map, &func_decl.input, &func_decl.output); + let (ret_type, func_type) = get_function_type( + builder, + map, + &func_decl.input_arguments, + &func_decl.return_arguments, + ); let fn_id = match func_decl.name { - MethodName::Kernel(name) => { + ast::MethodName::Kernel(name) => { let fn_id = defined_globals.get_id(name)?; let mut global_variables = defined_globals .variables_type_check @@ -1090,15 +1102,16 @@ fn emit_function_header<'a>( for directive in direcitves { match directive { Directive::Method(Function { - func_decl: ast::MethodDecl::Func(_, name, _), - globals, - .. + func_decl, globals, .. }) => { - if child_fns.contains(name) { - for var in globals { - interface.push(var.name); + match (**func_decl).borrow().name { + ast::MethodName::Func(name) => { + for var in globals { + interface.push(var.name); + } } - } + ast::MethodName::Kernel(_) => {} + }; } _ => {} } @@ -1107,7 +1120,7 @@ fn emit_function_header<'a>( builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables); fn_id } - MethodName::Func(name) => name, + ast::MethodName::Func(name) => name, }; builder.begin_function( ret_type, @@ -1130,7 +1143,7 @@ fn emit_function_header<'a>( } } */ - for input in &func_decl.input { + for input in &func_decl.input_arguments { let result_type = map.get_or_add( builder, SpirvType::new(input.v_type.clone(), input.state_space), @@ -1225,9 +1238,10 @@ fn translate_function<'a>( f: ast::ParsedFunction<'a>, ) -> Result>, TranslateError> { let import_as = match &f.func_directive { - ast::MethodDecl::Func(_, "__assertfail", _) => { - Some("__zluda_ptx_impl____assertfail".to_owned()) - } + ast::MethodDeclaration { + name: ast::MethodName::Func("__assertfail"), + .. + } => Some("__zluda_ptx_impl____assertfail".to_owned()), _ => None, }; let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?; @@ -1253,10 +1267,10 @@ fn translate_function<'a>( fn expand_kernel_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: impl Iterator>, -) -> Result>, TranslateError> { + args: impl Iterator>, +) -> Result>, TranslateError> { args.map(|a| { - Ok(ast::KernelArgument { + Ok(ast::Variable { name: fn_resolver.add_def( a.name, Some(( @@ -1274,42 +1288,39 @@ fn expand_kernel_params<'a, 'b>( .collect::>() } -fn expand_fn_params<'a, 'b>( +fn rename_fn_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: impl Iterator>, -) -> Result>, TranslateError> { - args.map(|a| { - let is_variable = a.state_space == ast::StateSpace::Reg; - Ok(ast::FnArgument { - name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), is_variable), + args: &'b [ast::Variable<&'a str>], +) -> Vec> { + args.iter() + .map(|a| ast::Variable { + name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), false), v_type: a.v_type.clone(), state_space: a.state_space, align: a.align, - array_init: Vec::new(), + array_init: a.array_init.clone(), }) - }) - .collect() + .collect() } fn to_ssa<'input, 'b>( ptx_impl_imports: &mut HashMap, mut id_defs: FnStringIdResolver<'input, 'b>, fn_defs: GlobalFnDeclResolver<'input, 'b>, - f_args: ast::MethodDecl<'input, spirv::Word>, + func_decl: Rc>>, f_body: Option>>>, tuning: Vec, ) -> Result, TranslateError> { - let mut spirv_decl = SpirvMethodDecl::new(&f_args); let f_body = match f_body { Some(vec) => vec, None => { return Ok(Function { - func_decl: f_args, + func_decl: func_decl, body: None, globals: Vec::new(), import_as: None, - spirv_decl, tuning, + uses_shared_mem: false, }) } }; @@ -1323,8 +1334,7 @@ fn to_ssa<'input, 'b>( let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, - &f_args, - &mut spirv_decl, + &mut (*func_decl).borrow_mut(), )?; let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.finish(); @@ -1336,12 +1346,12 @@ fn to_ssa<'input, 'b>( let (f_body, globals) = extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs); Ok(Function { - func_decl: f_args, + func_decl: func_decl, globals: globals, body: Some(f_body), import_as: None, - spirv_decl, tuning, + uses_shared_mem: false, }) } @@ -1573,9 +1583,9 @@ fn convert_to_typed_statements( Statement::Instruction(inst) => match inst { ast::Instruction::Call(call) => { // TODO: error out if lengths don't match - let fn_def = fn_defs.get_fn_decl(call.func)?; - let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals); - let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params); + let fn_def = (**fn_defs.get_fn_decl(call.func)?).borrow(); + let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.return_arguments); + let in_args = to_resolved_fn_args(call.param_list, &*fn_def.input_arguments); let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args .into_iter() .partition(|(_, _, space)| *space == ast::StateSpace::Param); @@ -1731,24 +1741,24 @@ fn to_ptx_impl_atomic_call( let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { let fn_id = id_defs.register_intermediate(None); - let func_decl = ast::MethodDecl::Func::( - vec![ast::FnArgument { + let func_decl = ast::MethodDeclaration:: { + return_arguments: vec![ast::Variable { align: None, v_type: ast::Type::Scalar(scalar_typ), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }], - fn_id, - vec![ - ast::FnArgument { + name: ast::MethodName::Func(fn_id), + input_arguments: vec![ + ast::Variable { align: None, v_type: ast::Type::Pointer(typ), state_space: ptr_space, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, v_type: ast::Type::Scalar(scalar_typ), state_space: ast::StateSpace::Reg, @@ -1756,24 +1766,23 @@ fn to_ptx_impl_atomic_call( array_init: Vec::new(), }, ], - ); - let spirv_decl = SpirvMethodDecl::new(&func_decl); + }; let func = Function { - func_decl, + func_decl: Rc::new(RefCell::new(func_decl)), globals: Vec::new(), body: None, import_as: Some(entry.key().clone()), - spirv_decl, tuning: Vec::new(), + uses_shared_mem: false, }; entry.insert(Directive::Method(func)); fn_id } hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { - func_decl: ast::MethodDecl::Func(_, name, _), - .. - }) => *name, + Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => fn_id, + ast::MethodName::Kernel(_) => unreachable!(), + }, _ => unreachable!(), }, }; @@ -1810,31 +1819,31 @@ fn to_ptx_impl_bfe_call( let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { let fn_id = id_defs.register_intermediate(None); - let func_decl = ast::MethodDecl::Func::( - vec![ast::FnArgument { + let func_decl = ast::MethodDeclaration:: { + return_arguments: vec![ast::Variable { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }], - fn_id, - vec![ - ast::FnArgument { + name: ast::MethodName::Func(fn_id), + input_arguments: vec![ + ast::Variable { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, @@ -1842,24 +1851,23 @@ fn to_ptx_impl_bfe_call( array_init: Vec::new(), }, ], - ); - let spirv_decl = SpirvMethodDecl::new(&func_decl); + }; let func = Function { - func_decl, + func_decl: Rc::new(RefCell::new(func_decl)), globals: Vec::new(), body: None, import_as: Some(entry.key().clone()), - spirv_decl, tuning: Vec::new(), + uses_shared_mem: false, }; entry.insert(Directive::Method(func)); fn_id } hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { - func_decl: ast::MethodDecl::Func(_, name, _), - .. - }) => *name, + Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => fn_id, + ast::MethodName::Kernel(_) => unreachable!(), + }, _ => unreachable!(), }, }; @@ -1903,38 +1911,38 @@ fn to_ptx_impl_bfi_call( let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { let fn_id = id_defs.register_intermediate(None); - let func_decl = ast::MethodDecl::Func::( - vec![ast::FnArgument { + let func_decl = ast::MethodDeclaration:: { + return_arguments: vec![ast::Variable { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }], - fn_id, - vec![ - ast::FnArgument { + name: ast::MethodName::Func(fn_id), + input_arguments: vec![ + ast::Variable { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, @@ -1942,24 +1950,23 @@ fn to_ptx_impl_bfi_call( array_init: Vec::new(), }, ], - ); - let spirv_decl = SpirvMethodDecl::new(&func_decl); + }; let func = Function { - func_decl, + func_decl: Rc::new(RefCell::new(func_decl)), globals: Vec::new(), body: None, import_as: Some(entry.key().clone()), - spirv_decl, tuning: Vec::new(), + uses_shared_mem: false, }; entry.insert(Directive::Method(func)); fn_id } hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { - func_decl: ast::MethodDecl::Func(_, name, _), - .. - }) => *name, + Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => fn_id, + ast::MethodName::Kernel(_) => unreachable!(), + }, _ => unreachable!(), }, }; @@ -1994,12 +2001,12 @@ fn to_ptx_impl_bfi_call( fn to_resolved_fn_args( params: Vec, - params_decl: &[(ast::Type, ast::StateSpace)], + params_decl: &[ast::Variable], ) -> Vec<(T, ast::Type, ast::StateSpace)> { params .into_iter() .zip(params_decl.iter()) - .map(|(id, (typ, space))| (id, typ.clone(), *space)) + .map(|(id, var)| (id, var.v_type.clone(), var.state_space)) .collect::>() } @@ -2084,11 +2091,10 @@ fn normalize_predicates( fn insert_mem_ssa_statements<'a, 'b>( func: Vec, id_def: &mut NumericIdResolver, - _: &'a ast::MethodDecl<'b, spirv::Word>, - fn_decl: &mut SpirvMethodDecl, + fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>, ) -> Result, TranslateError> { let mut result = Vec::with_capacity(func.len()); - for arg in fn_decl.output.iter() { + for arg in fn_decl.return_arguments.iter() { result.push(Statement::Variable(ast::Variable { align: arg.align, v_type: arg.v_type.clone(), @@ -2097,27 +2103,27 @@ fn insert_mem_ssa_statements<'a, 'b>( array_init: arg.array_init.clone(), })); } - for spirv_arg in fn_decl.input.iter_mut() { - let typ = spirv_arg.v_type.clone(); - let state_space = spirv_arg.state_space; + for arg in fn_decl.input_arguments.iter_mut() { + let typ = arg.v_type.clone(); + let state_space = arg.state_space; let new_id = id_def.register_intermediate(Some((typ.clone(), state_space))); result.push(Statement::Variable(ast::Variable { - align: spirv_arg.align, - v_type: spirv_arg.v_type.clone(), - state_space: spirv_arg.state_space, - name: spirv_arg.name, - array_init: spirv_arg.array_init.clone(), + align: arg.align, + v_type: arg.v_type.clone(), + state_space: arg.state_space, + name: arg.name, + array_init: Vec::new(), })); result.push(Statement::StoreVar(StoreVarDetails { arg: ast::Arg2St { - src1: spirv_arg.name, + src1: arg.name, src2: new_id, }, state_space, typ, member_index: None, })); - spirv_arg.name = new_id; + arg.name = new_id; } for s in func { match s { @@ -2127,7 +2133,7 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Instruction(inst) => match inst { ast::Instruction::Ret(d) => { // TODO: handle multiple output args - if let &[out_param] = &fn_decl.output.as_slice() { + if let &[out_param] = &fn_decl.return_arguments.as_slice() { let (typ, space, _) = id_def.get_typed(out_param.name)?; let new_id = id_def.register_intermediate(Some((typ.clone(), space))); result.push(Statement::LoadVar(LoadVarDetails { @@ -5081,15 +5087,10 @@ struct GlobalStringIdResolver<'input> { variables: HashMap, spirv::Word>, variables_type_check: HashMap>, special_registers: SpecialRegistersMap, - fns: HashMap, + fns: HashMap>>>, } -pub struct FnDecl { - ret_vals: Vec<(ast::Type, ast::StateSpace)>, - params: Vec<(ast::Type, ast::StateSpace)>, -} - -impl<'a> GlobalStringIdResolver<'a> { +impl<'input> GlobalStringIdResolver<'input> { fn new(start_id: spirv::Word) -> Self { Self { current_id: start_id, @@ -5100,13 +5101,13 @@ impl<'a> GlobalStringIdResolver<'a> { } } - fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word { + fn get_or_add_def(&mut self, id: &'input str) -> spirv::Word { self.get_or_add_impl(id, None) } fn get_or_add_def_typed( &mut self, - id: &'a str, + id: &'input str, typ: ast::Type, state_space: ast::StateSpace, is_variable: bool, @@ -5116,7 +5117,7 @@ impl<'a> GlobalStringIdResolver<'a> { fn get_or_add_impl( &mut self, - id: &'a str, + id: &'input str, typ: Option<(ast::Type, ast::StateSpace, bool)>, ) -> spirv::Word { let id = match self.variables.entry(Cow::Borrowed(id)) { @@ -5145,12 +5146,12 @@ impl<'a> GlobalStringIdResolver<'a> { fn start_fn<'b>( &'b mut self, - header: &'b ast::MethodDecl<'a, &'a str>, + header: &'b ast::MethodDeclaration<'input, &'input str>, ) -> Result< ( - FnStringIdResolver<'a, 'b>, - GlobalFnDeclResolver<'a, 'b>, - ast::MethodDecl<'a, spirv::Word>, + FnStringIdResolver<'input, 'b>, + GlobalFnDeclResolver<'input, 'b>, + Rc>>, ), TranslateError, > { @@ -5164,30 +5165,18 @@ impl<'a> GlobalStringIdResolver<'a> { variables: vec![HashMap::new(); 1], type_check: HashMap::new(), }; - let new_fn_decl = match header { - ast::MethodDecl::Kernel { name, in_args } => ast::MethodDecl::Kernel { - name, - in_args: expand_kernel_params(&mut fn_resolver, in_args.iter())?, - }, - ast::MethodDecl::Func(ret_params, _, params) => { - let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter())?; - let params_ids = expand_fn_params(&mut fn_resolver, params.iter())?; - self.fns.insert( - name_id, - FnDecl { - ret_vals: ret_params_ids - .iter() - .map(|p| (p.v_type.clone(), p.state_space)) - .collect(), - params: params_ids - .iter() - .map(|p| (p.v_type.clone(), p.state_space)) - .collect(), - }, - ); - ast::MethodDecl::Func(ret_params_ids, name_id, params_ids) - } + let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments); + let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments); + let name = match header.name { + ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), + ast::MethodName::Func(_) => ast::MethodName::Func(name_id), }; + let new_fn_decl = Rc::new(RefCell::new(ast::MethodDeclaration { + return_arguments, + name, + input_arguments, + })); + self.fns.insert(name_id, Rc::clone(&new_fn_decl)); Ok(( fn_resolver, GlobalFnDeclResolver { @@ -5201,15 +5190,21 @@ impl<'a> GlobalStringIdResolver<'a> { pub struct GlobalFnDeclResolver<'input, 'a> { variables: &'a HashMap, spirv::Word>, - fns: &'a HashMap, + fns: &'a HashMap>>>, } impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { - fn get_fn_decl(&self, id: spirv::Word) -> Result<&FnDecl, TranslateError> { + fn get_fn_decl( + &self, + id: spirv::Word, + ) -> Result<&Rc>>, TranslateError> { self.fns.get(&id).ok_or(TranslateError::UnknownSymbol) } - fn get_fn_decl_str(&self, id: &str) -> Result<&'a FnDecl, TranslateError> { + fn get_fn_decl_str( + &self, + id: &str, + ) -> Result<&'a Rc>>, TranslateError> { match self.variables.get(id).map(|var_id| self.fns.get(var_id)) { Some(Some(fn_d)) => Ok(fn_d), _ => Err(TranslateError::UnknownSymbol), @@ -5713,21 +5708,9 @@ impl, U: ArgParamsEx> Visitab } } -pub trait ArgParamsEx: ast::ArgParams + Sized { - fn get_fn_decl<'x, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'x, 'b>, - ) -> Result<&'b FnDecl, TranslateError>; -} +pub trait ArgParamsEx: ast::ArgParams + Sized {} -impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> { - fn get_fn_decl<'x, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'x, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl_str(id) - } -} +impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {} enum NormalizedArgParams {} @@ -5736,14 +5719,7 @@ impl ast::ArgParams for NormalizedArgParams { type Operand = ast::Operand; } -impl ArgParamsEx for NormalizedArgParams { - fn get_fn_decl<'a, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'a, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl(*id) - } -} +impl ArgParamsEx for NormalizedArgParams {} type NormalizedStatement = Statement< ( @@ -5762,14 +5738,7 @@ impl ast::ArgParams for TypedArgParams { type Operand = TypedOperand; } -impl ArgParamsEx for TypedArgParams { - fn get_fn_decl<'a, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'a, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl(*id) - } -} +impl ArgParamsEx for TypedArgParams {} #[derive(Copy, Clone)] enum TypedOperand { @@ -5800,14 +5769,7 @@ impl ast::ArgParams for ExpandedArgParams { type Operand = spirv::Word; } -impl ArgParamsEx for ExpandedArgParams { - fn get_fn_decl<'a, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'a, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl(*id) - } -} +impl ArgParamsEx for ExpandedArgParams {} enum Directive<'input> { Variable(ast::Variable), @@ -5815,10 +5777,10 @@ enum Directive<'input> { } struct Function<'input> { - pub func_decl: ast::MethodDecl<'input, spirv::Word>, - pub spirv_decl: SpirvMethodDecl<'input>, + pub func_decl: Rc>>, pub globals: Vec>, pub body: Option>, + pub uses_shared_mem: bool, import_as: Option, tuning: Vec, } @@ -7671,73 +7633,11 @@ fn should_convert_relaxed_dst( } } -impl<'a> ast::MethodDecl<'a, &'a str> { +impl<'a> ast::MethodDeclaration<'a, &'a str> { fn name(&self) -> &'a str { - match self { - ast::MethodDecl::Kernel { name, .. } => name, - ast::MethodDecl::Func(_, name, _) => name, - } - } -} - -struct SpirvMethodDecl<'input> { - input: Vec>, - output: Vec>, - name: MethodName<'input>, - uses_shared_mem: bool, -} - -impl<'input> SpirvMethodDecl<'input> { - fn new(ast_decl: &ast::MethodDecl<'input, spirv::Word>) -> Self { - let (input, output) = match ast_decl { - ast::MethodDecl::Kernel { in_args, .. } => { - let spirv_input = in_args - .iter() - .map(|var| ast::Variable { - name: var.name, - align: var.align, - v_type: var.v_type.clone(), - state_space: var.state_space, - array_init: var.array_init.clone(), - }) - .collect(); - (spirv_input, Vec::new()) - } - ast::MethodDecl::Func(out_args, _, in_args) => { - let (param_output, non_param_output): (Vec<_>, Vec<_>) = out_args - .iter() - .partition(|var| var.state_space == ast::StateSpace::Param); - let spirv_output = non_param_output - .into_iter() - .cloned() - .map(|var| ast::Variable { - name: var.name, - align: var.align, - v_type: var.v_type.clone(), - state_space: var.state_space, - array_init: var.array_init.clone(), - }) - .collect(); - let spirv_input = param_output - .into_iter() - .cloned() - .chain(in_args.iter().cloned()) - .map(|var| ast::Variable { - name: var.name, - align: var.align, - v_type: var.v_type.clone(), - state_space: var.state_space, - array_init: var.array_init.clone(), - }) - .collect(); - (spirv_input, spirv_output) - } - }; - SpirvMethodDecl { - input, - output, - name: MethodName::new(ast_decl), - uses_shared_mem: false, + match self.name { + ast::MethodName::Kernel(name) => name, + ast::MethodName::Func(name) => name, } } } diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs index 4ea449c..f168930 100644 --- a/zluda_dump/src/lib.rs +++ b/zluda_dump/src/lib.rs @@ -191,7 +191,10 @@ unsafe fn record_module_image(module: CUmodule, image: &str) { unsafe fn try_dump_module_image(image: &str) -> Result<(), Box> { let mut dump_path = get_dump_dir()?; - dump_path.push(format!("module_{:04}.ptx", MODULES.as_ref().unwrap().len() - 1)); + dump_path.push(format!( + "module_{:04}.ptx", + MODULES.as_ref().unwrap().len() - 1 + )); let mut file = File::create(dump_path)?; file.write_all(image.as_bytes())?; Ok(()) @@ -217,10 +220,15 @@ unsafe fn to_str(image: *const T) -> Option<&'static str> { fn directive_to_kernel(dir: &ast::Directive) -> Option<(String, Vec)> { match dir { ast::Directive::Method(ast::Function { - func_directive: ast::MethodDecl::Kernel { name, in_args }, + func_directive: + ast::MethodDeclaration { + name: ast::MethodName::Kernel(name), + input_arguments, + .. + }, .. }) => { - let arg_sizes = in_args + let arg_sizes = input_arguments .iter() .map(|arg| ast::Type::from(arg.v_type.clone()).size_of()) .collect(); From 82b5cef0bd03fd395dd213ea8386c26d16671894 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 15 May 2021 15:58:11 +0200 Subject: [PATCH 09/25] Carry state space with pointer --- ptx/src/ast.rs | 41 ++++++++++- ptx/src/ptx.lalrpop | 6 +- ptx/src/translate.rs | 167 +++++++++++++++++-------------------------- 3 files changed, 108 insertions(+), 106 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index e45a6fb..e49e489 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -108,10 +108,49 @@ pub type ParsedFunction<'a> = Function<'a, &'a str, Statement OpTypeInt Scalar(ScalarType), + // .param.v2.b32 foo; + // -> OpTypeVector Vector(ScalarType, u8), + // .param.b32 foo[4]; + // -> OpTypeArray Array(ScalarType, Vec), - Pointer(ScalarType), + /* + Variables of this type almost never exist in the original .ptx and are + usually artificially created. Some examples below: + - extern pointers to the .shared memory in the form: + .extern .shared .b32 shared_mem[]; + which we first parse as + .extern .shared .b32 shared_mem; + and then convert to an additional function parameter: + .param .ptr<.b32.shared> shared_mem; + and do a load at the start of the function (and renames inside fn): + .reg .ptr<.b32.shared> temp; + ld.param.ptr<.b32.shared> temp, [shared_mem]; + note, we don't support non-.shared extern pointers, because there's + zero use for them in the ptxas + - artifical pointers created by stateful conversion, which work + similiarly to the above + - function parameters: + foobar(.param .align 4 .b8 numbers[]) + which get parsed to + foobar(.param .align 4 .b8 numbers) + and then converted to + foobar(.reg .align 4 .ptr<.b8.param> numbers) + - ld/st with offset: + .reg.b32 x; + .param.b64 arg0; + st.param.b32 [arg0+4], x; + Yes, this code is legal and actually emitted by the NV compiler! + We convert the st to: + .reg ptr<.b64.param> temp = ptr_offset(arg0, 4); + st.param.b32 [temp], x; + */ + // .reg ptr<.b64.param> + // -> OpTypePointer Function + Pointer(ScalarType, StateSpace), } #[derive(PartialEq, Eq, Hash, Clone, Copy)] diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 78ebf1d..2253f85 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -624,9 +624,9 @@ ModuleVariable: ast::Variable<&'input str> = { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::Type::Pointer(t), ast::StateSpace::Global, Vec::new()) + (ast::Type::Scalar(t), ast::StateSpace::Global, Vec::new()) } else { - (ast::Type::Pointer(t), ast::StateSpace::Shared, Vec::new()) + (ast::Type::Scalar(t), ast::StateSpace::Shared, Vec::new()) } } }; @@ -648,7 +648,7 @@ ParamVariable: (Option, Vec, ast::Type, &'input str) = { (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { - (ast::Type::Pointer(t), Vec::new()) + (ast::Type::Scalar(t), Vec::new()) } }; (align, array_init, v_type, name) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 88ef51b..ea6451e 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -56,33 +56,20 @@ enum SpirvType { } impl SpirvType { - fn new(t: ast::Type, decl_space: ast::StateSpace) -> Self { + fn new(t: ast::Type) -> Self { match t { ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), - ast::Type::Pointer(pointer_t) => { - let spirv_space = match decl_space { - ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { - spirv::StorageClass::Private - } - ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, - ast::StateSpace::Const => spirv::StorageClass::UniformConstant, - ast::StateSpace::Shared => spirv::StorageClass::Workgroup, - ast::StateSpace::Generic => spirv::StorageClass::Generic, - ast::StateSpace::Sreg => spirv::StorageClass::Input, - }; - SpirvType::Pointer(Box::new(SpirvType::Base(pointer_t.into())), spirv_space) - } + ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( + Box::new(SpirvType::Base(pointer_t.into())), + space.to_spirv(), + ), } } - fn pointer_to( - t: ast::Type, - inner_space: ast::StateSpace, - outer_space: spirv::StorageClass, - ) -> Self { - let key = Self::new(t, inner_space); + fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self { + let key = Self::new(t); SpirvType::Pointer(Box::new(key), outer_space) } } @@ -394,7 +381,7 @@ impl TypeWordMap { b.constant_composite(result_type, None, components.into_iter()) } }, - ast::Type::Pointer(typ) => return Err(error_unreachable()), + ast::Type::Pointer(..) => return Err(error_unreachable()), }) } @@ -453,7 +440,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result, _>>()?; let must_link_ptx_impl = ptx_impl_imports.len() > 0; - let directives = ptx_impl_imports + let mut directives = ptx_impl_imports .into_iter() .map(|(_, v)| v) .chain(directives.into_iter()) @@ -461,7 +448,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result(m: &mut MultiHashMap, transformation has a semantical meaning - we emit additional "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...") */ +/* fn convert_dynamic_shared_memory_usage<'input>( module: Vec>, new_id: &mut impl FnMut() -> spirv::Word, @@ -819,7 +807,7 @@ fn convert_dynamic_shared_memory_usage<'input>( ast::Variable { name: shared_id_param, align: None, - v_type: ast::Type::Pointer(ast::ScalarType::B8), + v_type: ast::Type::Pointer(ast::ScalarType::B8, new_todo!()), state_space: ast::StateSpace::Shared, array_init: Vec::new(), } @@ -937,6 +925,7 @@ fn get_callers_of_extern_shared_single<'a>( } } } +*/ type DenormCountMap = HashMap; @@ -1031,11 +1020,7 @@ fn emit_builtins( for (reg, id) in id_defs.special_registers.builtins() { let result_type = map.get_or_add( builder, - SpirvType::pointer_to( - reg.get_type(), - ast::StateSpace::Reg, - spirv::StorageClass::Input, - ), + SpirvType::pointer_to(reg.get_type(), spirv::StorageClass::Input), ); builder.variable(result_type, Some(id), spirv::StorageClass::Input, None); builder.decorate( @@ -1144,10 +1129,7 @@ fn emit_function_header<'a>( } */ for input in &func_decl.input_arguments { - let result_type = map.get_or_add( - builder, - SpirvType::new(input.v_type.clone(), input.state_space), - ); + let result_type = map.get_or_add(builder, SpirvType::new(input.v_type.clone())); builder.function_parameter(Some(input.name), result_type)?; } Ok(fn_id) @@ -1753,8 +1735,8 @@ fn to_ptx_impl_atomic_call( input_arguments: vec![ ast::Variable { align: None, - v_type: ast::Type::Pointer(typ), - state_space: ptr_space, + v_type: ast::Type::Pointer(typ, ptr_space), + state_space: ast::StateSpace::Reg, name: id_defs.register_intermediate(None), array_init: Vec::new(), }, @@ -1791,7 +1773,11 @@ fn to_ptx_impl_atomic_call( func: fn_id, ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], param_list: vec![ - (arg.src1, ast::Type::Pointer(typ), ptr_space), + ( + arg.src1, + ast::Type::Pointer(typ, ptr_space), + ast::StateSpace::Reg, + ), ( arg.src2, ast::Type::Scalar(scalar_typ), @@ -2629,8 +2615,8 @@ fn insert_implicit_conversions( is_dst: false, sema: ArgumentSemantics::PhysicalPointer, }, - typ: &ast::Type::Pointer(underlying_type), - state_space, + typ: &ast::Type::Pointer(underlying_type, state_space), + state_space: new_todo!(), stmt_ctor: |new_ptr_src| { Statement::PtrAccess(PtrAccess { underlying_type, @@ -2758,10 +2744,10 @@ fn get_function_type( builder, spirv_input .iter() - .map(|var| SpirvType::new(var.v_type.clone(), var.state_space)), + .map(|var| SpirvType::new(var.v_type.clone())), spirv_output .iter() - .map(|var| SpirvType::new(var.v_type.clone(), var.state_space)), + .map(|var| SpirvType::new(var.v_type.clone())), ) } @@ -2790,7 +2776,7 @@ fn emit_function_body_ops( Statement::Call(call) => { let (result_type, result_id) = match &*call.ret_params { [(id, typ, space)] => ( - map.get_or_add(builder, SpirvType::new(typ.clone(), *space)), + map.get_or_add(builder, SpirvType::new(typ.clone())), Some(*id), ), [] => (map.void(), None), @@ -2922,10 +2908,8 @@ fn emit_function_body_ops( if data.qualifier != ast::LdStQualifier::Weak { todo!() } - let result_type = map.get_or_add( - builder, - SpirvType::new(ast::Type::from(data.typ.clone()), data.state_space), - ); + let result_type = + map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); builder.load( result_type, Some(arg.dst), @@ -2956,10 +2940,8 @@ fn emit_function_body_ops( // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, ast::Instruction::Mov(d, arg) => { - let result_type = map.get_or_add( - builder, - SpirvType::new(ast::Type::from(d.typ.clone()), ast::StateSpace::Reg), - ); + let result_type = + map.get_or_add(builder, SpirvType::new(ast::Type::from(d.typ.clone()))); builder.copy_object(result_type, Some(arg.dst), arg.src)?; } ast::Instruction::Mul(mul, arg) => match mul { @@ -3000,8 +2982,7 @@ fn emit_function_body_ops( ast::Instruction::Shl(t, a) => { let full_type = ast::Type::Scalar(*t); let size_of = full_type.size_of(); - let result_type = - map.get_or_add(builder, SpirvType::new(full_type, ast::StateSpace::Reg)); + let result_type = map.get_or_add(builder, SpirvType::new(full_type)); let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?; builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?; } @@ -3265,7 +3246,6 @@ fn emit_function_body_ops( builder, SpirvType::pointer_to( details.typ.clone(), - details.state_space, spirv::StorageClass::Function, ), ); @@ -3297,11 +3277,11 @@ fn emit_function_body_ops( }) => { let u8_pointer = map.get_or_add( builder, - SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8), *state_space), + SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)), ); let result_type = map.get_or_add( builder, - SpirvType::new(ast::Type::Pointer(*underlying_type), *state_space), + SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)), ); let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?; let temp = builder.in_bounds_ptr_access_chain( @@ -3596,15 +3576,12 @@ fn emit_variable( &*var.array_init, )?) } else if must_init { - let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone(), var.state_space)); + let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone())); Some(builder.constant_null(type_id, None)) } else { None }; - let ptr_type_id = map.get_or_add( - builder, - SpirvType::pointer_to(var.v_type.clone(), var.state_space, st_class), - ); + let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class)); builder.variable(ptr_type_id, Some(var.name), st_class, initalizer); if let Some(align) = var.align { builder.decorate( @@ -3742,10 +3719,7 @@ fn emit_min( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, }; - let inst_type = map.get_or_add( - builder, - SpirvType::new(desc.get_type(), ast::StateSpace::Reg), - ); + let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type())); builder.ext_inst( inst_type, Some(arg.dst), @@ -3770,10 +3744,7 @@ fn emit_max( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, }; - let inst_type = map.get_or_add( - builder, - SpirvType::new(desc.get_type(), ast::StateSpace::Reg), - ); + let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type())); builder.ext_inst( inst_type, Some(arg.dst), @@ -4255,14 +4226,13 @@ fn emit_implicit_conversion( (_, _, ConversionKind::BitToPtr) => { let dst_type = map.get_or_add( builder, - SpirvType::pointer_to(cv.to_type.clone(), cv.from_space, cv.to_space.to_spirv()), + SpirvType::pointer_to(cv.to_type.clone(), cv.to_space.to_spirv()), ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { if from_parts.width == to_parts.width { - let dst_type = - map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)); + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); if from_parts.scalar_kind != ast::ScalarKind::Float && to_parts.scalar_kind != ast::ScalarKind::Float { @@ -4275,13 +4245,10 @@ fn emit_implicit_conversion( // This block is safe because it's illegal to implictly convert between floating point values let same_width_bit_type = map.get_or_add( builder, - SpirvType::new( - ast::Type::from_parts(TypeParts { - scalar_kind: ast::ScalarKind::Bit, - ..from_parts - }), - cv.from_space, - ), + SpirvType::new(ast::Type::from_parts(TypeParts { + scalar_kind: ast::ScalarKind::Bit, + ..from_parts + })), ); let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?; let wide_bit_type = ast::Type::from_parts(TypeParts { @@ -4289,7 +4256,7 @@ fn emit_implicit_conversion( ..to_parts }); let wide_bit_type_spirv = - map.get_or_add(builder, SpirvType::new(wide_bit_type.clone(), cv.to_space)); + map.get_or_add(builder, SpirvType::new(wide_bit_type.clone())); if to_parts.scalar_kind == ast::ScalarKind::Unsigned || to_parts.scalar_kind == ast::ScalarKind::Bit { @@ -4323,15 +4290,13 @@ fn emit_implicit_conversion( } } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => { - let result_type = - map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)); + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.s_convert(result_type, Some(cv.dst), cv.src)?; } (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default) | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { - let into_type = - map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)); + let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.bitcast(into_type, Some(cv.dst), cv.src)?; } (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => { @@ -4339,12 +4304,12 @@ fn emit_implicit_conversion( map.get_or_add( builder, SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone(), cv.to_space)), + Box::new(SpirvType::new(cv.to_type.clone())), spirv::StorageClass::Function, ), ) } else { - map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)) + map.get_or_add(builder, SpirvType::new(cv.to_type.clone())) }; builder.bitcast(result_type, Some(cv.dst), cv.src)?; } @@ -4358,18 +4323,14 @@ fn emit_load_var( map: &mut TypeWordMap, details: &LoadVarDetails, ) -> Result<(), TranslateError> { - let result_type = map.get_or_add( - builder, - SpirvType::new(details.typ.clone(), details.state_space), - ); + let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone())); match details.member_index { Some((index, Some(width))) => { let vector_type = match details.typ { ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), _ => return Err(TranslateError::MismatchedType), }; - let vector_type_spirv = - map.get_or_add(builder, SpirvType::new(vector_type, details.state_space)); + let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); let vector_temp = builder.load( vector_type_spirv, None, @@ -4387,11 +4348,7 @@ fn emit_load_var( Some((index, None)) => { let result_ptr_type = map.get_or_add( builder, - SpirvType::pointer_to( - details.typ.clone(), - details.state_space, - spirv::StorageClass::Function, - ), + SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function), ); let index_spirv = map.get_or_add_constant( builder, @@ -5661,7 +5618,7 @@ impl> PtrAccess

{ ast::StateSpace::Reg => new_todo!(), ast::StateSpace::Sreg => new_todo!(), }; - let ptr_type = ast::Type::Pointer(self.underlying_type.clone()); + let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), new_todo!()); let new_dst = visitor.id( ArgumentDescriptor { op: self.dst, @@ -6231,24 +6188,28 @@ impl ast::Type { match self { ast::Type::Scalar(scalar) => TypeParts { kind: TypeKind::Scalar, + state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), }, ast::Type::Vector(scalar, components) => TypeParts { kind: TypeKind::Vector, + state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), width: scalar.size_of(), components: vec![*components as u32], }, ast::Type::Array(scalar, components) => TypeParts { kind: TypeKind::Array, + state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), width: scalar.size_of(), components: components.clone(), }, - ast::Type::Pointer(scalar) => TypeParts { - kind: TypeKind::PointerScalar, + ast::Type::Pointer(scalar, space) => TypeParts { + kind: TypeKind::Pointer, + state_space: *space, scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), @@ -6269,9 +6230,10 @@ impl ast::Type { ast::ScalarType::from_parts(t.width, t.scalar_kind), t.components, ), - TypeKind::PointerScalar => { - ast::Type::Pointer(ast::ScalarType::from_parts(t.width, t.scalar_kind)) - } + TypeKind::Pointer => ast::Type::Pointer( + ast::ScalarType::from_parts(t.width, t.scalar_kind), + t.state_space, + ), } } @@ -6292,6 +6254,7 @@ struct TypeParts { kind: TypeKind, scalar_kind: ast::ScalarKind, width: u8, + state_space: ast::StateSpace, components: Vec, } @@ -6300,7 +6263,7 @@ enum TypeKind { Scalar, Vector, Array, - PointerScalar, + Pointer, } impl ast::Instruction { From 8d74c16c8697b36b5a93484008457b5fcfa7b7b9 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 30 May 2021 16:08:18 +0200 Subject: [PATCH 10/25] Refactor implicit conversions --- ptx/src/ast.rs | 2 +- ptx/src/translate.rs | 713 +++++++++++++++++++------------------------ 2 files changed, 320 insertions(+), 395 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index e49e489..3ad61e5 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,6 +1,6 @@ use half::f16; use lalrpop_util::{lexer::Token, ParseError}; -use std::{convert::From, mem, num::ParseFloatError, rc::Rc, str::FromStr}; +use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; #[derive(Debug, thiserror::Error)] diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index ea6451e..c6b7f01 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -74,12 +74,6 @@ impl SpirvType { } } -impl ast::Type { - fn param_pointer_to(self, space: ast::StateSpace) -> Result { - Ok(self) - } -} - impl From for SpirvType { fn from(t: ast::ScalarType) -> Self { SpirvType::Base(t.into()) @@ -636,7 +630,7 @@ fn get_kernels_call_map<'input>( for statement in statements { match statement { Statement::Call(call) => { - multi_hash_map_append(&mut directly_called_by, call_key, call.func); + multi_hash_map_append(&mut directly_called_by, call_key, call.name); } _ => {} } @@ -872,8 +866,8 @@ fn replace_uses_of_shared_memory<'a>( to_type: ast::Type::Pointer((*scalar_type).into()), to_space: ast::StateSpace::Shared, kind: ConversionKind::PtrToPtr { spirv_ptr: true }, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, + src_ + dst_ })); replacement_id } else { @@ -1172,48 +1166,19 @@ fn translate_directive<'input>( d: ast::Directive<'input, ast::ParsedArgParams<'input>>, ) -> Result>, TranslateError> { Ok(match d { - ast::Directive::Variable(v) => Some(Directive::Variable(translate_variable(id_defs, v)?)), + ast::Directive::Variable(var) => Some(Directive::Variable(ast::Variable { + align: var.align, + v_type: var.v_type.clone(), + state_space: var.state_space, + name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true), + array_init: var.array_init, + })), ast::Directive::Method(f) => { translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method) } }) } -fn translate_variable<'a>( - id_defs: &mut GlobalStringIdResolver<'a>, - var: ast::Variable<&'a str>, -) -> Result, TranslateError> { - let (space, var_type) = (var.state_space, var.v_type.clone()); - let mut is_variable = false; - let var_type = match space { - ast::StateSpace::Reg => { - is_variable = true; - var_type - } - ast::StateSpace::Const => var_type.param_pointer_to(ast::StateSpace::Const)?, - ast::StateSpace::Global => var_type.param_pointer_to(ast::StateSpace::Global)?, - ast::StateSpace::Local => var_type.param_pointer_to(ast::StateSpace::Local)?, - ast::StateSpace::Shared => { - // If it's a pointer it will be translated to a method parameter later - if let ast::Type::Pointer(..) = var_type { - is_variable = true; - var_type - } else { - var_type.param_pointer_to(ast::StateSpace::Shared)? - } - } - ast::StateSpace::Param => var_type.param_pointer_to(ast::StateSpace::Param)?, - ast::StateSpace::Generic | ast::StateSpace::Sreg => return Err(error_unreachable()), - }; - Ok(ast::Variable { - align: var.align, - v_type: var.v_type, - state_space: var.state_space, - name: id_defs.get_or_add_def_typed(var.name, var_type, var.state_space, is_variable), - array_init: var.array_init, - }) -} - fn translate_function<'a>( id_defs: &mut GlobalStringIdResolver<'a>, ptx_impl_imports: &mut HashMap>, @@ -1247,29 +1212,6 @@ fn translate_function<'a>( } } -fn expand_kernel_params<'a, 'b>( - fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: impl Iterator>, -) -> Result>, TranslateError> { - args.map(|a| { - Ok(ast::Variable { - name: fn_resolver.add_def( - a.name, - Some(( - ast::Type::from(a.v_type.clone()).param_pointer_to(ast::StateSpace::Param)?, - a.state_space, - )), - false, - ), - v_type: a.v_type.clone(), - state_space: a.state_space, - align: a.align, - array_init: Vec::new(), - }) - }) - .collect::>() -} - fn rename_fn_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, args: &'b [ast::Variable<&'a str>], @@ -1293,6 +1235,7 @@ fn to_ssa<'input, 'b>( f_body: Option>>>, tuning: Vec, ) -> Result, TranslateError> { + deparamize_function_decl(&func_decl)?; let f_body = match f_body { Some(vec) => vec, None => { @@ -1337,6 +1280,30 @@ fn to_ssa<'input, 'b>( }) } +fn deparamize_function_decl( + func_decl_rc: &Rc>>, +) -> Result<(), TranslateError> { + let mut func_decl = func_decl_rc.borrow_mut(); + match func_decl.name { + ast::MethodName::Func(..) => { + for decl in func_decl.input_arguments.iter_mut() { + if decl.state_space == ast::StateSpace::Param { + decl.state_space = ast::StateSpace::Reg; + let baseline_type = match decl.v_type { + ast::Type::Scalar(t) => t, + ast::Type::Vector(t, _) => t, // TODO: write a test for this + ast::Type::Array(t, _) => t, // TODO: write a test for this + ast::Type::Pointer(_, _) => return Err(error_unreachable()), + }; + decl.v_type = ast::Type::Pointer(baseline_type, ast::StateSpace::Param); + } + } + } + ast::MethodName::Kernel(..) => {} + }; + Ok(()) +} + fn fix_special_registers( typed_statements: Vec, numeric_id_defs: &mut NumericIdResolver, @@ -1394,8 +1361,6 @@ fn fix_special_registers( to_type: ast::Type::Scalar(ast::ScalarType::U32), to_space: ast::StateSpace::Sreg, kind: ConversionKind::Default, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, })); } } @@ -1566,45 +1531,21 @@ fn convert_to_typed_statements( ast::Instruction::Call(call) => { // TODO: error out if lengths don't match let fn_def = (**fn_defs.get_fn_decl(call.func)?).borrow(); - let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.return_arguments); - let in_args = to_resolved_fn_args(call.param_list, &*fn_def.input_arguments); - let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args - .into_iter() - .partition(|(_, _, space)| *space == ast::StateSpace::Param); - let normalized_input_args = out_params - .into_iter() - .map(|(id, typ, space)| (ast::Operand::Reg(id), typ, space)) - .chain(in_args.into_iter()) - .collect(); + let return_arguments = + to_resolved_fn_args(call.ret_params, &*fn_def.return_arguments); + let input_arguments = + to_resolved_fn_args(call.param_list, &*fn_def.input_arguments); let resolved_call = ResolvedCall { uniform: call.uniform, - ret_params: out_non_params, - func: call.func, - param_list: normalized_input_args, + return_arguments, + name: call.func, + input_arguments, }; let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let reresolved_call = resolved_call.visit(&mut visitor)?; visitor.func.push(reresolved_call); visitor.func.extend(visitor.post_stmts); } - ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => { - if let Some(src_id) = src.underlying() { - let (typ, _, _) = id_defs.get_typed(*src_id)?; - let take_address = match typ { - ast::Type::Scalar(..) => false, - ast::Type::Vector(..) => false, - ast::Type::Array(..) => true, - ast::Type::Pointer(..) => true, - }; - d.src_is_address = take_address; - } - let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); - let instruction = Statement::Instruction( - ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?, - ); - visitor.func.push(instruction); - visitor.func.extend(visitor.post_stmts); - } inst => { let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let instruction = Statement::Instruction(inst.map(&mut visitor)?); @@ -1639,7 +1580,6 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { fn convert_vector( &mut self, is_dst: bool, - vector_sema: ArgumentSemantics, typ: &ast::Type, state_space: ast::StateSpace, idx: Vec, @@ -1657,7 +1597,6 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { typ: scalar_t, packed: temp_vec, unpacked: idx, - vector_sema, }); if is_dst { self.post_stmts = Some(statement); @@ -1690,13 +1629,9 @@ impl<'a, 'b> ArgumentMapVisitor ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), ast::Operand::Imm(x) => TypedOperand::Imm(x), ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), - ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector( - desc.is_dst, - desc.sema, - typ, - state_space, - vec, - )?), + ast::Operand::VecPack(vec) => { + TypedOperand::Reg(self.convert_vector(desc.is_dst, typ, state_space, vec)?) + } }) } } @@ -1770,9 +1705,9 @@ fn to_ptx_impl_atomic_call( }; Statement::Call(ResolvedCall { uniform: false, - func: fn_id, - ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], - param_list: vec![ + name: fn_id, + return_arguments: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], + input_arguments: vec![ ( arg.src1, ast::Type::Pointer(typ, ptr_space), @@ -1859,9 +1794,9 @@ fn to_ptx_impl_bfe_call( }; Statement::Call(ResolvedCall { uniform: false, - func: fn_id, - ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], - param_list: vec![ + name: fn_id, + return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], + input_arguments: vec![ ( arg.src1, ast::Type::Scalar(typ.into()), @@ -1958,9 +1893,9 @@ fn to_ptx_impl_bfi_call( }; Statement::Call(ResolvedCall { uniform: false, - func: fn_id, - ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], - param_list: vec![ + name: fn_id, + return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], + input_arguments: vec![ ( arg.src1, ast::Type::Scalar(typ.into()), @@ -2217,14 +2152,14 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { fn symbol( &mut self, desc: ArgumentDescriptor<(spirv::Word, Option)>, - expected_type: Option<(&ast::Type, ast::StateSpace)>, + expected: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { let symbol = desc.op.0; - if expected_type.is_none() { + if expected.is_none() { return Ok(symbol); }; - let (mut var_type, _, is_variable) = self.id_def.get_typed(symbol)?; - if !is_variable { + let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?; + if !var_space.is_compatible(ast::StateSpace::Reg) || !is_variable { return Ok(symbol); }; let member_index = match desc.op.1 { @@ -2579,28 +2514,10 @@ fn insert_implicit_conversions( for s in func.into_iter() { match s { Statement::Call(call) => { - insert_implicit_conversions_impl(&mut result, id_def, call, should_bitcast_wrapper)? + insert_implicit_conversions_impl(&mut result, id_def, call)?; } Statement::Instruction(inst) => { - let mut default_conversion_fn = - should_bitcast_wrapper as for<'a> fn(&'a _, _, &'a _, _) -> _; - let mut state_space = None; - if let ast::Instruction::Ld(d, _) = &inst { - state_space = Some(d.state_space); - } - if let ast::Instruction::St(d, _) = &inst { - state_space = Some(d.state_space); - } - if let ast::Instruction::Atom(d, _) = &inst { - state_space = Some(d.space); - } - if let ast::Instruction::AtomCas(d, _) = &inst { - state_space = Some(d.space); - } - if let ast::Instruction::Mov(..) = &inst { - default_conversion_fn = should_bitcast_packed; - } - insert_implicit_conversions_impl(&mut result, id_def, inst, default_conversion_fn)?; + insert_implicit_conversions_impl(&mut result, id_def, inst)?; } Statement::PtrAccess(PtrAccess { underlying_type, @@ -2613,7 +2530,7 @@ fn insert_implicit_conversions( desc: ArgumentDescriptor { op: ptr_src, is_dst: false, - sema: ArgumentSemantics::PhysicalPointer, + non_default_implicit_conversion: None, }, typ: &ast::Type::Pointer(underlying_type, state_space), state_space: new_todo!(), @@ -2627,19 +2544,11 @@ fn insert_implicit_conversions( }) }, }; - insert_implicit_conversions_impl( - &mut result, - id_def, - visit_desc, - bitcast_physical_pointer, - )?; + insert_implicit_conversions_impl(&mut result, id_def, visit_desc)?; + } + Statement::RepackVector(repack) => { + insert_implicit_conversions_impl(&mut result, id_def, repack)?; } - Statement::RepackVector(repack) => insert_implicit_conversions_impl( - &mut result, - id_def, - repack, - should_bitcast_wrapper, - )?, s @ Statement::Conditional(_) | s @ Statement::Conversion(_) | s @ Statement::Label(_) @@ -2657,12 +2566,6 @@ fn insert_implicit_conversions_impl( func: &mut Vec, id_def: &mut MutableNumericIdResolver, stmt: impl Visitable, - default_conversion_fn: for<'a> fn( - &'a ast::Type, - ast::StateSpace, - &'a ast::Type, - ast::StateSpace, - ) -> Result, TranslateError>, ) -> Result<(), TranslateError> { let mut post_conv = Vec::new(); let statement = @@ -2673,27 +2576,13 @@ fn insert_implicit_conversions_impl( Some(t) => t, }; let (operand_type, operand_space) = id_def.get_typed(desc.op)?; - let mut conversion_fn = default_conversion_fn; - match desc.sema { - ArgumentSemantics::Default => {} - ArgumentSemantics::DefaultRelaxed => { - if desc.is_dst { - conversion_fn = should_convert_relaxed_dst_wrapper; - } else { - conversion_fn = should_convert_relaxed_src_wrapper; - } - } - ArgumentSemantics::PhysicalPointer => { - conversion_fn = bitcast_physical_pointer; - } - ArgumentSemantics::RegisterPointer => { - conversion_fn = bitcast_register_pointer; - } - ArgumentSemantics::Address => { - conversion_fn = force_bitcast_ptr_to_bit; - } - }; - match conversion_fn(&operand_type, operand_space, instr_type, instruction_space)? { + let conversion_fn = desc + .non_default_implicit_conversion + .unwrap_or(default_implicit_conversion); + match conversion_fn( + (operand_space, &operand_type), + (instruction_space, instr_type), + )? { Some(conv_kind) => { let conv_output = if desc.is_dst { &mut post_conv @@ -2721,8 +2610,6 @@ fn insert_implicit_conversions_impl( to_type, to_space, kind: conv_kind, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, })); result } @@ -2774,7 +2661,7 @@ fn emit_function_body_ops( match s { Statement::Label(_) => (), Statement::Call(call) => { - let (result_type, result_id) = match &*call.ret_params { + let (result_type, result_id) = match &*call.return_arguments { [(id, typ, space)] => ( map.get_or_add(builder, SpirvType::new(typ.clone())), Some(*id), @@ -2783,11 +2670,11 @@ fn emit_function_body_ops( _ => todo!(), }; let arg_list = call - .param_list + .input_arguments .iter() .map(|(id, _, _)| *id) .collect::>(); - builder.function_call(result_type, result_id, call.func, arg_list)?; + builder.function_call(result_type, result_id, call.name, arg_list)?; } Statement::Variable(var) => { emit_variable(builder, map, var)?; @@ -3863,8 +3750,6 @@ fn emit_cvt( )), to_space: ast::StateSpace::Reg, kind: ConversionKind::Default, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, }; emit_implicit_conversion(builder, map, &cv)?; new_dst @@ -4218,19 +4103,19 @@ fn emit_implicit_conversion( ) -> Result<(), TranslateError> { let from_parts = cv.from_type.to_parts(); let to_parts = cv.to_type.to_parts(); - match (from_parts.kind, to_parts.kind, cv.kind) { - (_, _, ConversionKind::PtrToBit(typ)) => { + match (from_parts.kind, to_parts.kind, &cv.kind) { + (_, _, &ConversionKind::PtrToBit(typ)) => { let dst_type = map.get_or_add_scalar(builder, typ.into()); builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; } - (_, _, ConversionKind::BitToPtr) => { + (_, _, &ConversionKind::BitToPtr) => { let dst_type = map.get_or_add( builder, SpirvType::pointer_to(cv.to_type.clone(), cv.to_space.to_spirv()), ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } - (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { + (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => { if from_parts.width == to_parts.width { let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); if from_parts.scalar_kind != ast::ScalarKind::Float @@ -4282,35 +4167,29 @@ fn emit_implicit_conversion( to_type: cv.to_type.clone(), to_space: new_todo!(), kind: ConversionKind::Default, - src_sema: cv.src_sema, - dst_sema: cv.dst_sema, }, )?; } } } - (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => { + (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => { let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.s_convert(result_type, Some(cv.dst), cv.src)?; } - (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) - | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default) - | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { + (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default) + | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default) + | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => { let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.bitcast(into_type, Some(cv.dst), cv.src)?; } - (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => { - let result_type = if spirv_ptr { - map.get_or_add( - builder, - SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone())), - spirv::StorageClass::Function, - ), - ) - } else { - map.get_or_add(builder, SpirvType::new(cv.to_type.clone())) - }; + (_, _, &ConversionKind::PtrToPtr) => { + let result_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + cv.to_space.to_spirv(), + ), + ); builder.bitcast(result_type, Some(cv.dst), cv.src)?; } _ => unreachable!(), @@ -4417,38 +4296,12 @@ fn expand_map_variables<'a, 'b>( i.map_variable(&mut |id| id_defs.get_id(id))?, ))), ast::Statement::Variable(var) => { - let mut var_type = ast::Type::from(var.var.v_type.clone()); - let mut is_variable = false; - var_type = match var.var.state_space { - ast::StateSpace::Reg => { - is_variable = true; - var_type - } - ast::StateSpace::Shared => { - // If it's a pointer it will be translated to a method parameter later - if let ast::Type::Pointer(..) = var_type { - is_variable = true; - var_type - } else { - var_type.param_pointer_to(ast::StateSpace::Shared)? - } - } - ast::StateSpace::Global => var_type.param_pointer_to(ast::StateSpace::Global)?, - ast::StateSpace::Param => var_type.param_pointer_to(ast::StateSpace::Param)?, - ast::StateSpace::Local => var_type.param_pointer_to(ast::StateSpace::Local)?, - ast::StateSpace::Const => new_todo!(), - ast::StateSpace::Generic => new_todo!(), - ast::StateSpace::Sreg => new_todo!(), - }; + let var_type = var.var.v_type.clone(); match var.count { Some(count) => { - for new_id in id_defs.add_defs( - var.var.name, - count, - var_type, - var.var.state_space, - is_variable, - ) { + for new_id in + id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true) + { result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), @@ -4459,11 +4312,8 @@ fn expand_map_variables<'a, 'b>( } } None => { - let new_id = id_defs.add_def( - var.var.name, - Some((var_type, var.var.state_space)), - is_variable, - ); + let new_id = + id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true); result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), @@ -4839,7 +4689,7 @@ fn convert_to_stateful_memory_access_postprocess( ast::StateSpace::Global, ), kind: ConversionKind::BitToPtr(ast::StateSpace::Global), - src_sema: ArgumentSemantics::Default, + src_ dst_sema: arg_desc.sema, })); converting_id @@ -4854,7 +4704,7 @@ fn convert_to_stateful_memory_access_postprocess( to_type: old_type, kind: ConversionKind::PtrToBit(ast::ScalarType::U64), src_sema: arg_desc.sema, - dst_sema: ArgumentSemantics::Default, + dst_ })); converting_id } @@ -4881,7 +4731,7 @@ fn convert_to_stateful_memory_access_postprocess( to_type: old_type_clone, kind: ConversionKind::PtrToPtr { spirv_ptr: false }, src_sema: arg_desc.sema, - dst_sema: ArgumentSemantics::Default, + dst_ })); converting_id } @@ -4889,7 +4739,6 @@ fn convert_to_stateful_memory_access_postprocess( }, }) } -*/ fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { match arg.dst { @@ -4917,6 +4766,7 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool { _ => false, } } +*/ #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] enum PtxSpecialRegister { @@ -5368,7 +5218,7 @@ impl ExpandedStatement { Statement::StoreVar(details) } Statement::Call(mut call) => { - for (id, _, space) in call.ret_params.iter_mut() { + for (id, _, space) in call.return_arguments.iter_mut() { let is_dst = match space { ast::StateSpace::Reg => true, ast::StateSpace::Param => false, @@ -5377,8 +5227,8 @@ impl ExpandedStatement { }; *id = f(*id, is_dst); } - call.func = f(call.func, false); - for (id, _, _) in call.param_list.iter_mut() { + call.name = f(call.name, false); + for (id, _, _) in call.input_arguments.iter_mut() { *id = f(*id, false); } Statement::Call(call) @@ -5461,7 +5311,6 @@ struct RepackVectorDetails { typ: ast::ScalarType, packed: spirv::Word, unpacked: Vec, - vector_sema: ArgumentSemantics, } impl RepackVectorDetails { @@ -5477,7 +5326,8 @@ impl RepackVectorDetails { ArgumentDescriptor { op: self.packed, is_dst: !self.is_extract, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, Some(( &ast::Type::Vector(self.typ, self.unpacked.len() as u8), @@ -5486,7 +5336,6 @@ impl RepackVectorDetails { )?; let scalar_type = self.typ; let is_extract = self.is_extract; - let vector_sema = self.vector_sema; let vector = self .unpacked .into_iter() @@ -5495,7 +5344,7 @@ impl RepackVectorDetails { ArgumentDescriptor { op: id, is_dst: is_extract, - sema: vector_sema, + non_default_implicit_conversion: None, }, Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)), ) @@ -5506,7 +5355,6 @@ impl RepackVectorDetails { typ: self.typ, packed: scalar, unpacked: vector, - vector_sema, }) } } @@ -5524,18 +5372,18 @@ impl, U: ArgParamsEx> Visitab struct ResolvedCall { pub uniform: bool, - pub ret_params: Vec<(P::Id, ast::Type, ast::StateSpace)>, - pub func: P::Id, - pub param_list: Vec<(P::Operand, ast::Type, ast::StateSpace)>, + pub return_arguments: Vec<(P::Id, ast::Type, ast::StateSpace)>, + pub name: P::Id, + pub input_arguments: Vec<(P::Operand, ast::Type, ast::StateSpace)>, } impl ResolvedCall { fn cast>(self) -> ResolvedCall { ResolvedCall { uniform: self.uniform, - ret_params: self.ret_params, - func: self.func, - param_list: self.param_list, + return_arguments: self.return_arguments, + name: self.name, + input_arguments: self.input_arguments, } } } @@ -5546,14 +5394,14 @@ impl> ResolvedCall { visitor: &mut V, ) -> Result, TranslateError> { let ret_params = self - .ret_params + .return_arguments .into_iter() .map::, _>(|(id, typ, space)| { let new_id = visitor.id( ArgumentDescriptor { op: id, is_dst: space != ast::StateSpace::Param, - sema: space.semantics(), + non_default_implicit_conversion: None, }, Some((&typ, space)), )?; @@ -5562,21 +5410,22 @@ impl> ResolvedCall { .collect::, _>>()?; let func = visitor.id( ArgumentDescriptor { - op: self.func, + op: self.name, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, None, )?; let param_list = self - .param_list + .input_arguments .into_iter() .map::, _>(|(id, typ, space)| { let new_id = visitor.operand( ArgumentDescriptor { op: id, is_dst: false, - sema: space.semantics(), + non_default_implicit_conversion: None, }, &typ, space, @@ -5586,9 +5435,9 @@ impl> ResolvedCall { .collect::, _>>()?; Ok(ResolvedCall { uniform: self.uniform, - ret_params, - func, - param_list, + return_arguments: ret_params, + name: func, + input_arguments: param_list, }) } } @@ -5623,7 +5472,7 @@ impl> PtrAccess

{ ArgumentDescriptor { op: self.dst, is_dst: true, - sema, + non_default_implicit_conversion: None, }, Some((&ptr_type, self.state_space)), )?; @@ -5631,7 +5480,7 @@ impl> PtrAccess

{ ArgumentDescriptor { op: self.ptr_src, is_dst: false, - sema, + non_default_implicit_conversion: None, }, Some((&ptr_type, self.state_space)), )?; @@ -5639,7 +5488,8 @@ impl> PtrAccess

{ ArgumentDescriptor { op: self.offset_src, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::S64), self.state_space, @@ -5816,7 +5666,12 @@ where pub struct ArgumentDescriptor { op: Op, is_dst: bool, - sema: ArgumentSemantics, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, } pub struct PtrAccess { @@ -5846,7 +5701,7 @@ impl ArgumentDescriptor { ArgumentDescriptor { op: u, is_dst: self.is_dst, - sema: self.sema, + non_default_implicit_conversion: None, } } } @@ -6085,7 +5940,7 @@ impl ImplicitConversion { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: self.dst_sema, + non_default_implicit_conversion: None, }, Some((&self.to_type, self.to_space)), )?; @@ -6093,7 +5948,7 @@ impl ImplicitConversion { ArgumentDescriptor { op: self.src, is_dst: false, - sema: self.src_sema, + non_default_implicit_conversion: None, }, Some((&self.from_type, self.from_space)), )?; @@ -6396,18 +6251,16 @@ struct ImplicitConversion { from_space: ast::StateSpace, to_space: ast::StateSpace, kind: ConversionKind, - src_sema: ArgumentSemantics, - dst_sema: ArgumentSemantics, } -#[derive(PartialEq, Copy, Clone)] +#[derive(PartialEq, Clone)] enum ConversionKind { Default, // zero-extend/chop/bitcast depending on types SignExtend, BitToPtr, PtrToBit(ast::ScalarType), - PtrToPtr { spirv_ptr: bool }, + PtrToPtr, } impl ast::PredAt { @@ -6461,7 +6314,8 @@ impl ast::Arg1 { ArgumentDescriptor { op: self.src, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, )?; @@ -6478,7 +6332,8 @@ impl ast::Arg1Bar { ArgumentDescriptor { op: self.src, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), ast::StateSpace::Reg, @@ -6497,7 +6352,8 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6506,7 +6362,8 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.src, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6527,7 +6384,8 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, dst_t, ast::StateSpace::Reg, @@ -6536,7 +6394,8 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.src, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, src_t, ast::StateSpace::Reg, @@ -6555,7 +6414,7 @@ impl ast::Arg2Ld { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::DefaultRelaxed, + non_default_implicit_conversion: None, }, &ast::Type::from(details.typ.clone()), ast::StateSpace::Reg, @@ -6566,11 +6425,7 @@ impl ast::Arg2Ld { ArgumentDescriptor { op: self.src, is_dst: false, - sema: if is_logical_ptr { - ArgumentSemantics::RegisterPointer - } else { - ArgumentSemantics::PhysicalPointer - }, + non_default_implicit_conversion: None, }, &details.typ, details.state_space, @@ -6591,11 +6446,7 @@ impl ast::Arg2St { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: if is_logical_ptr { - ArgumentSemantics::RegisterPointer - } else { - ArgumentSemantics::PhysicalPointer - }, + non_default_implicit_conversion: None, }, &details.typ, details.state_space, @@ -6604,7 +6455,7 @@ impl ast::Arg2St { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::DefaultRelaxed, + non_default_implicit_conversion: None, }, &details.typ.clone().into(), ast::StateSpace::Reg, @@ -6623,7 +6474,8 @@ impl ast::Arg2Mov { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &details.typ.clone().into(), ast::StateSpace::Reg, @@ -6632,11 +6484,7 @@ impl ast::Arg2Mov { ArgumentDescriptor { op: self.src, is_dst: false, - sema: if details.src_is_address { - ArgumentSemantics::Address - } else { - ArgumentSemantics::Default - }, + non_default_implicit_conversion: None, }, &details.typ.clone().into(), ast::StateSpace::Reg, @@ -6661,7 +6509,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, wide_type.as_ref().unwrap_or(typ), ast::StateSpace::Reg, @@ -6670,7 +6519,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, typ, ast::StateSpace::Reg, @@ -6679,7 +6529,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, typ, ast::StateSpace::Reg, @@ -6696,7 +6547,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6705,7 +6557,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6714,7 +6567,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), ast::StateSpace::Reg, @@ -6733,7 +6587,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), ast::StateSpace::Reg, @@ -6742,7 +6597,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::PhysicalPointer, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), state_space, @@ -6751,7 +6606,8 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), ast::StateSpace::Reg, @@ -6776,7 +6632,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, wide_type.as_ref().unwrap_or(t), ast::StateSpace::Reg, @@ -6785,7 +6642,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6794,7 +6652,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6803,7 +6662,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -6825,7 +6684,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(t.into()), ast::StateSpace::Reg, @@ -6834,7 +6693,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(t.into()), ast::StateSpace::Reg, @@ -6843,7 +6702,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(t.into()), ast::StateSpace::Reg, @@ -6852,7 +6711,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::Pred), ast::StateSpace::Reg, @@ -6876,7 +6735,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), ast::StateSpace::Reg, @@ -6885,7 +6745,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::PhysicalPointer, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), ast::StateSpace::Reg, @@ -6894,7 +6754,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), ast::StateSpace::Reg, @@ -6903,7 +6764,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), ast::StateSpace::Reg, @@ -6925,7 +6787,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, typ, ast::StateSpace::Reg, @@ -6934,7 +6797,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, typ, ast::StateSpace::Reg, @@ -6944,7 +6808,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &u32_type, ast::StateSpace::Reg, @@ -6953,7 +6818,8 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &u32_type, ast::StateSpace::Reg, @@ -6977,7 +6843,8 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: self.dst1, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, Some(( &ast::Type::Scalar(ast::ScalarType::Pred), @@ -6991,7 +6858,8 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: dst2, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, Some(( &ast::Type::Scalar(ast::ScalarType::Pred), @@ -7004,7 +6872,8 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -7013,7 +6882,8 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -7037,7 +6907,8 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, base_type, ast::StateSpace::Reg, @@ -7046,7 +6917,8 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, base_type, ast::StateSpace::Reg, @@ -7055,7 +6927,8 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, base_type, ast::StateSpace::Reg, @@ -7064,7 +6937,8 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), ast::StateSpace::Reg, @@ -7073,7 +6947,8 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src4, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), ast::StateSpace::Reg, @@ -7098,7 +6973,8 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.dst1, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, Some(( &ast::Type::Scalar(ast::ScalarType::Pred), @@ -7112,7 +6988,8 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: dst2, is_dst: true, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, Some(( &ast::Type::Scalar(ast::ScalarType::Pred), @@ -7125,7 +7002,8 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -7134,7 +7012,8 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, t, ast::StateSpace::Reg, @@ -7143,7 +7022,8 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::Pred), ast::StateSpace::Reg, @@ -7287,6 +7167,12 @@ impl ast::StateSpace { ast::StateSpace::Sreg => spirv::StorageClass::Input, } } + + fn is_compatible(self, other: ast::StateSpace) -> bool { + self == other + || self == ast::StateSpace::Reg && other == ast::StateSpace::Sreg + || self == ast::StateSpace::Sreg && other == ast::StateSpace::Reg + } } impl ast::Operand { @@ -7342,54 +7228,89 @@ impl ast::StateSpace { } } -fn bitcast_register_pointer( - operand_type: &ast::Type, - operand_space: ast::StateSpace, - instr_type: &ast::Type, - instruction_space: ast::StateSpace, +fn default_implicit_conversion( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - bitcast_physical_pointer(operand_type, operand_space, instr_type, instruction_space) + if !instruction_space.is_compatible(operand_space) { + default_implicit_conversion_space( + (operand_space, operand_type), + (instruction_space, instruction_type), + ) + } else if instruction_type != operand_type { + default_implicit_conversion_type(instruction_space, operand_type, instruction_type) + } else { + Ok(None) + } } -fn bitcast_physical_pointer( - operand_type: &ast::Type, - operand_space: ast::StateSpace, - instruction_type: &ast::Type, - instruction_space: ast::StateSpace, +// Space is different +fn default_implicit_conversion_space( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if operand_space == instruction_space { - if operand_type != instruction_type { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } else { - Ok(None) - } - } else { - match operand_space { - ast::StateSpace::Reg | ast::StateSpace::Sreg => match instruction_space { - ast::StateSpace::Generic - | ast::StateSpace::Global - | ast::StateSpace::Shared - | ast::StateSpace::Local => Ok(Some(ConversionKind::BitToPtr)), + if operand_space.is_compatible(ast::StateSpace::Reg) { + match operand_type { + ast::Type::Pointer(operand_ptr_type, operand_ptr_space) + if *operand_ptr_space == instruction_space => + { + if instruction_type != &ast::Type::Scalar(*operand_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + // TODO: 32 bit + ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { + ast::StateSpace::Global + | ast::StateSpace::Generic + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), + _ => Err(TranslateError::MismatchedType), + }, + ast::Type::Scalar(ast::ScalarType::B32) + | ast::Type::Scalar(ast::ScalarType::U32) + | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { + ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { + Ok(Some(ConversionKind::BitToPtr)) + } _ => Err(TranslateError::MismatchedType), }, _ => Err(TranslateError::MismatchedType), } + } else if instruction_space.is_compatible(ast::StateSpace::Reg) { + if let ast::Type::Pointer(instr_ptr_type, instr_ptr_space) = instruction_type { + if operand_space != *instr_ptr_space { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } else { + Err(TranslateError::MismatchedType) + } + } else { + Err(TranslateError::MismatchedType) } } -fn force_bitcast_ptr_to_bit( - _: &ast::Type, - _: ast::StateSpace, - instr_type: &ast::Type, - _: ast::StateSpace, +// Space is same, but type is different +fn default_implicit_conversion_type( + space: ast::StateSpace, + operand_type: &ast::Type, + instruction_type: &ast::Type, ) -> Result, TranslateError> { - // TODO: verify this on f32, u16 and the like - if let ast::Type::Scalar(scalar_t) = instr_type { - if let Ok(int_type) = (*scalar_t).try_into() { - return Ok(Some(ConversionKind::PtrToBit(int_type))); + if space.is_compatible(ast::StateSpace::Reg) { + if should_bitcast(instruction_type, operand_type) { + Ok(Some(ConversionKind::Default)) + } else { + Err(TranslateError::MismatchedType) } + } else { + Ok(Some(ConversionKind::PtrToPtr)) } - Err(TranslateError::MismatchedType) } fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { @@ -7421,22 +7342,26 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { } } -fn should_bitcast_packed( - operand: &ast::Type, - operand_space: ast::StateSpace, - instruction: &ast::Type, - instruction_space: ast::StateSpace, +fn implicit_conversion_mov( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = - (operand, instruction) - { - if scalar.kind() == ast::ScalarKind::Bit - && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) + // instruction_space is always reg + if operand_space.is_compatible(ast::StateSpace::Reg) { + if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = + (operand_type, instruction_type) { - return Ok(Some(ConversionKind::Default)); + if scalar.kind() == ast::ScalarKind::Bit + && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) + { + return Ok(Some(ConversionKind::Default)); + } } } - should_bitcast_wrapper(operand, operand_space, instruction, instruction_space) + default_implicit_conversion( + (operand_space, operand_type), + (instruction_space, instruction_type), + ) } fn should_bitcast_wrapper( From 4091f658b299f297397fa8d5e4e9edb597993d5c Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 30 May 2021 20:21:43 +0200 Subject: [PATCH 11/25] Fix PtrAccess --- ptx/src/translate.rs | 148 +++++++------------------------------------ 1 file changed, 23 insertions(+), 125 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c6b7f01..15163fc 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2349,98 +2349,29 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { state_space: ast::StateSpace, ) -> Result { let (reg, offset) = desc.op; - let add_type; match typ { ast::Type::Scalar(underlying_type) => { - let (reg_typ, space) = self.id_def.get_typed(reg)?; - if let ast::Type::Pointer(..) = reg_typ { - let id_constant_stmt = self.id_def.register_intermediate(typ.clone(), space); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: ast::ScalarType::S64, - value: ast::ImmediateValue::S64(offset as i64), - })); - let dst = self.id_def.register_intermediate(typ.clone(), space); - self.func.push(Statement::PtrAccess(PtrAccess { - underlying_type: *underlying_type, - state_space: state_space, - dst, - ptr_src: reg, - offset_src: id_constant_stmt, - })); - return Ok(dst); - } else { - add_type = reg_typ; - } + let id_constant_stmt = self.id_def.register_intermediate( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: ast::ScalarType::S64, + value: ast::ImmediateValue::S64(offset as i64), + })); + let dst = self.id_def.register_intermediate(typ.clone(), state_space); + self.func.push(Statement::PtrAccess(PtrAccess { + underlying_type: *underlying_type, + state_space: state_space, + dst, + ptr_src: reg, + offset_src: id_constant_stmt, + })); + Ok(dst) } - _ => return Err(error_unreachable()), - }; - let (width, kind) = match add_type { - ast::Type::Scalar(scalar_t) => { - let kind = match scalar_t.kind() { - kind @ ast::ScalarKind::Bit - | kind @ ast::ScalarKind::Unsigned - | kind @ ast::ScalarKind::Signed => kind, - ast::ScalarKind::Float => return Err(TranslateError::MismatchedType), - ast::ScalarKind::Float2 => return Err(TranslateError::MismatchedType), - ast::ScalarKind::Pred => return Err(TranslateError::MismatchedType), - }; - (scalar_t.size_of(), kind) - } - _ => return Err(TranslateError::MismatchedType), - }; - let arith_detail = if kind == ast::ScalarKind::Signed { - ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::ScalarType::from_parts(width, ast::ScalarKind::Signed), - saturate: false, - }) - } else { - ast::ArithDetails::Unsigned(ast::ScalarType::from_parts( - width, - ast::ScalarKind::Unsigned, - )) - }; - let id_constant_stmt = self - .id_def - .register_intermediate(add_type.clone(), ast::StateSpace::Reg); - let result_id = self - .id_def - .register_intermediate(add_type, ast::StateSpace::Reg); - // TODO: check for edge cases around min value/max value/wrapping - if offset < 0 && kind != ast::ScalarKind::Signed { - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: ast::ScalarType::from_parts(width, kind), - value: ast::ImmediateValue::U64(-(offset as i64) as u64), - })); - self.func.push(Statement::Instruction( - ast::Instruction::::Sub( - arith_detail, - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); - } else { - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: ast::ScalarType::from_parts(width, kind), - value: ast::ImmediateValue::S64(offset as i64), - })); - self.func.push(Statement::Instruction( - ast::Instruction::::Add( - arith_detail, - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); + _ => Err(error_unreachable()), } - Ok(result_id) } fn immediate( @@ -2519,32 +2450,8 @@ fn insert_implicit_conversions( Statement::Instruction(inst) => { insert_implicit_conversions_impl(&mut result, id_def, inst)?; } - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src, - offset_src: constant_src, - }) => { - let visit_desc = VisitArgumentDescriptor { - desc: ArgumentDescriptor { - op: ptr_src, - is_dst: false, - non_default_implicit_conversion: None, - }, - typ: &ast::Type::Pointer(underlying_type, state_space), - state_space: new_todo!(), - stmt_ctor: |new_ptr_src| { - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src: new_ptr_src, - offset_src: constant_src, - }) - }, - }; - insert_implicit_conversions_impl(&mut result, id_def, visit_desc)?; + Statement::PtrAccess(access) => { + insert_implicit_conversions_impl(&mut result, id_def, access)?; } Statement::RepackVector(repack) => { insert_implicit_conversions_impl(&mut result, id_def, repack)?; @@ -5458,16 +5365,7 @@ impl> PtrAccess

{ self, visitor: &mut V, ) -> Result, TranslateError> { - let sema = match self.state_space { - ast::StateSpace::Const - | ast::StateSpace::Global - | ast::StateSpace::Shared - | ast::StateSpace::Generic => ArgumentSemantics::PhysicalPointer, - ast::StateSpace::Local | ast::StateSpace::Param => ArgumentSemantics::RegisterPointer, - ast::StateSpace::Reg => new_todo!(), - ast::StateSpace::Sreg => new_todo!(), - }; - let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), new_todo!()); + let ptr_type = ast::Type::Scalar(self.underlying_type.clone()); let new_dst = visitor.id( ArgumentDescriptor { op: self.dst, @@ -5492,7 +5390,7 @@ impl> PtrAccess

{ non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::S64), - self.state_space, + ast::StateSpace::Reg, )?; Ok(PtrAccess { underlying_type: self.underlying_type, From 3d9a79c41e8115e23c3d5db431c021e5a4848298 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 30 May 2021 23:06:44 +0200 Subject: [PATCH 12/25] Re-enable relaxed conversions --- ptx/src/translate.rs | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 15163fc..0f6368e 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -6312,13 +6312,11 @@ impl ast::Arg2Ld { ArgumentDescriptor { op: self.dst, is_dst: true, - non_default_implicit_conversion: None, + non_default_implicit_conversion: Some(should_convert_relaxed_dst_wrapper), }, &ast::Type::from(details.typ.clone()), ast::StateSpace::Reg, )?; - let is_logical_ptr = details.state_space == ast::StateSpace::Param - || details.state_space == ast::StateSpace::Local; let src = visitor.operand( ArgumentDescriptor { op: self.src, @@ -6338,8 +6336,6 @@ impl ast::Arg2St { visitor: &mut V, details: &ast::StData, ) -> Result, TranslateError> { - let is_logical_ptr = details.state_space == ast::StateSpace::Param - || details.state_space == ast::StateSpace::Local; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, @@ -6353,7 +6349,7 @@ impl ast::Arg2St { ArgumentDescriptor { op: self.src2, is_dst: false, - non_default_implicit_conversion: None, + non_default_implicit_conversion: Some(should_convert_relaxed_src_wrapper), }, &details.typ.clone().into(), ast::StateSpace::Reg, @@ -6427,7 +6423,6 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src2, is_dst: false, - non_default_implicit_conversion: None, }, typ, @@ -6646,7 +6641,7 @@ impl ast::Arg4 { non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), - ast::StateSpace::Reg, + state_space, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -7279,15 +7274,16 @@ fn should_bitcast_wrapper( } fn should_convert_relaxed_src_wrapper( - src_type: &ast::Type, - _: ast::StateSpace, - instr_type: &ast::Type, - _: ast::StateSpace, + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if src_type == instr_type { + if !operand_space.is_compatible(instruction_space) { + return Err(TranslateError::MismatchedType); + } + if operand_type == instruction_type { return Ok(None); } - match should_convert_relaxed_src(src_type, instr_type) { + match should_convert_relaxed_src(operand_type, instruction_type) { conv @ Some(_) => Ok(conv), None => Err(TranslateError::MismatchedType), } @@ -7343,15 +7339,16 @@ fn should_convert_relaxed_src( } fn should_convert_relaxed_dst_wrapper( - dst_type: &ast::Type, - _: ast::StateSpace, - instr_type: &ast::Type, - _: ast::StateSpace, + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if dst_type == instr_type { + if !operand_space.is_compatible(instruction_space) { + return Err(TranslateError::MismatchedType); + } + if operand_type == instruction_type { return Ok(None); } - match should_convert_relaxed_dst(dst_type, instr_type) { + match should_convert_relaxed_dst(operand_type, instruction_type) { conv @ Some(_) => Ok(conv), None => Err(TranslateError::MismatchedType), } From 2e6f7e3fdc6176279644f7bd02f8fb09195d6298 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 31 May 2021 00:00:57 +0200 Subject: [PATCH 13/25] Implement address-taking mov --- ptx/src/translate.rs | 39 +++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 0f6368e..90a28b7 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -4011,10 +4011,6 @@ fn emit_implicit_conversion( let from_parts = cv.from_type.to_parts(); let to_parts = cv.to_type.to_parts(); match (from_parts.kind, to_parts.kind, &cv.kind) { - (_, _, &ConversionKind::PtrToBit(typ)) => { - let dst_type = map.get_or_add_scalar(builder, typ.into()); - builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; - } (_, _, &ConversionKind::BitToPtr) => { let dst_type = map.get_or_add( builder, @@ -4099,6 +4095,10 @@ fn emit_implicit_conversion( ); builder.bitcast(result_type, Some(cv.dst), cv.src)?; } + (_, _, &ConversionKind::AddressOf) => { + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; + } _ => unreachable!(), } Ok(()) @@ -6157,8 +6157,8 @@ enum ConversionKind { // zero-extend/chop/bitcast depending on types SignExtend, BitToPtr, - PtrToBit(ast::ScalarType), PtrToPtr, + AddressOf, } impl ast::PredAt { @@ -6378,7 +6378,7 @@ impl ast::Arg2Mov { ArgumentDescriptor { op: self.src, is_dst: false, - non_default_implicit_conversion: None, + non_default_implicit_conversion: Some(implicit_conversion_mov), }, &details.typ.clone().into(), ast::StateSpace::Reg, @@ -7066,6 +7066,17 @@ impl ast::StateSpace { || self == ast::StateSpace::Reg && other == ast::StateSpace::Sreg || self == ast::StateSpace::Sreg && other == ast::StateSpace::Reg } + + fn is_addressable(self) -> bool { + match self { + ast::StateSpace::Const + | ast::StateSpace::Generic + | ast::StateSpace::Global + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false, + } + } } impl ast::Operand { @@ -7174,16 +7185,6 @@ fn default_implicit_conversion_space( }, _ => Err(TranslateError::MismatchedType), } - } else if instruction_space.is_compatible(ast::StateSpace::Reg) { - if let ast::Type::Pointer(instr_ptr_type, instr_ptr_space) = instruction_type { - if operand_space != *instr_ptr_space { - Ok(Some(ConversionKind::PtrToPtr)) - } else { - Ok(None) - } - } else { - Err(TranslateError::MismatchedType) - } } else { Err(TranslateError::MismatchedType) } @@ -7250,6 +7251,12 @@ fn implicit_conversion_mov( return Ok(Some(ConversionKind::Default)); } } + // TODO: verify .params addressability: + // * kernel arg + // * func arg + // * variable + } else if operand_space.is_addressable() { + return Ok(Some(ConversionKind::AddressOf)); } default_implicit_conversion( (operand_space, operand_type), From f70abd065bc7651f75b5f41475a862f509fd68bd Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 4 Jun 2021 00:48:51 +0200 Subject: [PATCH 14/25] Continue attempts at fixing code emission for method args --- ptx/src/ast.rs | 1 + ptx/src/ptx.lalrpop | 4 +- ptx/src/translate.rs | 323 +++++++++++++++++++++++++++++-------------- 3 files changed, 224 insertions(+), 104 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 3ad61e5..a0bb023 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -96,6 +96,7 @@ pub struct MethodDeclaration<'input, ID> { pub return_arguments: Vec>, pub name: MethodName<'input, ID>, pub input_arguments: Vec>, + pub shared_mem: Option>, } pub struct Function<'a, ID, S> { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 2253f85..e8370cd 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -392,12 +392,12 @@ MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = { ".entry" => { let return_arguments = Vec::new(); let name = ast::MethodName::Kernel(name); - ast::MethodDeclaration{ return_arguments, name, input_arguments } + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } }, ".func" => { let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); let name = ast::MethodName::Func(name); - ast::MethodDeclaration{ return_arguments, name, input_arguments } + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } } }; diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 90a28b7..6d5d5bc 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -562,7 +562,6 @@ fn emit_directives<'input>( call_map, &directives, kernel_info, - f.uses_shared_mem, )?; for t in f.tuning.iter() { match *t { @@ -1038,10 +1037,9 @@ fn emit_function_header<'a>( call_map: &HashMap<&'a str, HashSet>, direcitves: &[Directive], kernel_info: &mut HashMap, - uses_shared_mem: bool, ) -> Result { if let ast::MethodName::Kernel(name) = func_decl.name { - let input_args = if !uses_shared_mem { + let input_args = if func_decl.shared_mem.is_none() { func_decl.input_arguments.as_slice() } else { &func_decl.input_arguments[0..func_decl.input_arguments.len() - 1] @@ -1054,7 +1052,7 @@ fn emit_function_header<'a>( name.to_string(), KernelInfo { arguments_sizes: args_lens, - uses_shared_mem: uses_shared_mem, + uses_shared_mem: func_decl.shared_mem.is_some(), }, ); } @@ -1218,7 +1216,7 @@ fn rename_fn_params<'a, 'b>( ) -> Vec> { args.iter() .map(|a| ast::Variable { - name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), false), + name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true), v_type: a.v_type.clone(), state_space: a.state_space, align: a.align, @@ -1245,7 +1243,6 @@ fn to_ssa<'input, 'b>( globals: Vec::new(), import_as: None, tuning, - uses_shared_mem: false, }) } }; @@ -1276,7 +1273,6 @@ fn to_ssa<'input, 'b>( body: Some(f_body), import_as: None, tuning, - uses_shared_mem: false, }) } @@ -1529,18 +1525,8 @@ fn convert_to_typed_statements( match s { Statement::Instruction(inst) => match inst { ast::Instruction::Call(call) => { - // TODO: error out if lengths don't match - let fn_def = (**fn_defs.get_fn_decl(call.func)?).borrow(); - let return_arguments = - to_resolved_fn_args(call.ret_params, &*fn_def.return_arguments); - let input_arguments = - to_resolved_fn_args(call.param_list, &*fn_def.input_arguments); - let resolved_call = ResolvedCall { - uniform: call.uniform, - return_arguments, - name: call.func, - input_arguments, - }; + let resolver = fn_defs.get_fn_sig_resolver(call.func)?; + let resolved_call = resolver.resolve_in_spirv_repr(call)?; let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let reresolved_call = resolved_call.visit(&mut visitor)?; visitor.func.push(reresolved_call); @@ -1683,6 +1669,7 @@ fn to_ptx_impl_atomic_call( array_init: Vec::new(), }, ], + shared_mem: None, }; let func = Function { func_decl: Rc::new(RefCell::new(func_decl)), @@ -1690,7 +1677,6 @@ fn to_ptx_impl_atomic_call( body: None, import_as: Some(entry.key().clone()), tuning: Vec::new(), - uses_shared_mem: false, }; entry.insert(Directive::Method(func)); fn_id @@ -1772,6 +1758,7 @@ fn to_ptx_impl_bfe_call( array_init: Vec::new(), }, ], + shared_mem: None, }; let func = Function { func_decl: Rc::new(RefCell::new(func_decl)), @@ -1779,7 +1766,6 @@ fn to_ptx_impl_bfe_call( body: None, import_as: Some(entry.key().clone()), tuning: Vec::new(), - uses_shared_mem: false, }; entry.insert(Directive::Method(func)); fn_id @@ -1871,6 +1857,7 @@ fn to_ptx_impl_bfi_call( array_init: Vec::new(), }, ], + shared_mem: None, }; let func = Function { func_decl: Rc::new(RefCell::new(func_decl)), @@ -1878,7 +1865,6 @@ fn to_ptx_impl_bfi_call( body: None, import_as: Some(entry.key().clone()), tuning: Vec::new(), - uses_shared_mem: false, }; entry.insert(Directive::Method(func)); fn_id @@ -2009,42 +1995,44 @@ fn normalize_predicates( Ok(result) } +/* + How do we handle arguments: + - input .params + .param .b64 in_arg + get turned into this SPIR-V: + %1 = OpFunctionParameter %ulong + %2 = OpVariable %%_ptr_Function_ulong Function + OpStore %2 %1 + We do this for two reasons. One, common treatment for argument-declared + .param variables and .param variables inside function (we assume that + at SPIR-V level every .param is a pointer in Function storage class). Two, + PTX devs in their infinite wisdom decided that .reg arguments are writable + - input .regs + .reg .b64 in_arg + get turned into this SPIR-V: + %1 = OpFunctionParameter %ulong + %2 = OpVariable %%_ptr_Function_ulong Function + OpStore %2 %1 + with the difference that %2 is defined as a variable and not temp + - output .regs + .reg .b64 out_arg + get just a variable declaration: + %2 = OpVariable %%_ptr_Function_ulong Function + - output .params + .param .b64 out_arg + get treated the same as input .params, because there's no difference +*/ fn insert_mem_ssa_statements<'a, 'b>( func: Vec, id_def: &mut NumericIdResolver, fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>, ) -> Result, TranslateError> { let mut result = Vec::with_capacity(func.len()); - for arg in fn_decl.return_arguments.iter() { - result.push(Statement::Variable(ast::Variable { - align: arg.align, - v_type: arg.v_type.clone(), - state_space: arg.state_space, - name: arg.name, - array_init: arg.array_init.clone(), - })); - } for arg in fn_decl.input_arguments.iter_mut() { - let typ = arg.v_type.clone(); - let state_space = arg.state_space; - let new_id = id_def.register_intermediate(Some((typ.clone(), state_space))); - result.push(Statement::Variable(ast::Variable { - align: arg.align, - v_type: arg.v_type.clone(), - state_space: arg.state_space, - name: arg.name, - array_init: Vec::new(), - })); - result.push(Statement::StoreVar(StoreVarDetails { - arg: ast::Arg2St { - src1: arg.name, - src2: new_id, - }, - state_space, - typ, - member_index: None, - })); - arg.name = new_id; + insert_mem_ssa_argument(id_def, &mut result, arg); + } + for arg in fn_decl.return_arguments.iter() { + insert_mem_ssa_argument_reg_return(&mut result, arg); } for s in func { match s { @@ -2054,22 +2042,26 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Instruction(inst) => match inst { ast::Instruction::Ret(d) => { // TODO: handle multiple output args - if let &[out_param] = &fn_decl.return_arguments.as_slice() { - let (typ, space, _) = id_def.get_typed(out_param.name)?; - let new_id = id_def.register_intermediate(Some((typ.clone(), space))); - result.push(Statement::LoadVar(LoadVarDetails { - arg: ast::Arg2 { - dst: new_id, - src: out_param.name, - }, - // TODO: ret with stateful conversion - state_space: new_todo!(), - typ: typ.clone(), - member_index: None, - })); - result.push(Statement::RetValue(d, new_id)); - } else { - result.push(Statement::Instruction(ast::Instruction::Ret(d))) + match &fn_decl.return_arguments[..] { + [return_reg] => { + let new_id = id_def.register_intermediate(Some(( + return_reg.v_type.clone(), + ast::StateSpace::Reg, + ))); + result.push(Statement::LoadVar(LoadVarDetails { + arg: ast::Arg2 { + dst: new_id, + src: return_reg.name, + }, + // TODO: ret with stateful conversion + state_space: ast::StateSpace::Reg, + typ: return_reg.v_type.clone(), + member_index: None, + })); + result.push(Statement::RetValue(d, new_id)); + } + [] => result.push(Statement::Instruction(ast::Instruction::Ret(d))), + _ => unimplemented!(), } } inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?, @@ -2107,6 +2099,43 @@ fn insert_mem_ssa_statements<'a, 'b>( Ok(result) } +fn insert_mem_ssa_argument( + id_def: &mut NumericIdResolver, + func: &mut Vec, + arg: &mut ast::Variable, +) { + let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space))); + func.push(Statement::Variable(ast::Variable { + align: arg.align, + v_type: arg.v_type.clone(), + state_space: ast::StateSpace::Reg, + name: arg.name, + array_init: Vec::new(), + })); + func.push(Statement::StoreVar(StoreVarDetails { + arg: ast::Arg2St { + src1: arg.name, + src2: new_id, + }, + typ: arg.v_type.clone(), + member_index: None, + })); + arg.name = new_id; +} + +fn insert_mem_ssa_argument_reg_return( + func: &mut Vec, + arg: &ast::Variable, +) { + func.push(Statement::Variable(ast::Variable { + align: arg.align, + v_type: arg.v_type.clone(), + state_space: arg.state_space, + name: arg.name, + array_init: arg.array_init.clone(), + })); +} + trait Visitable: Sized { fn visit( self, @@ -2202,7 +2231,6 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { src1: symbol, src2: generated_id, }, - state_space: ast::StateSpace::Reg, typ: var_type, member_index: member_index.map(|(idx, _)| idx), })); @@ -4162,10 +4190,10 @@ fn emit_load_var( Ok(()) } -fn normalize_identifiers<'a, 'b>( - id_defs: &mut FnStringIdResolver<'a, 'b>, - fn_defs: &GlobalFnDeclResolver<'a, 'b>, - func: Vec>>, +fn normalize_identifiers<'input, 'b>( + id_defs: &mut FnStringIdResolver<'input, 'b>, + fn_defs: &GlobalFnDeclResolver<'input, 'b>, + func: Vec>>, ) -> Result, TranslateError> { for s in func.iter() { match s { @@ -4796,12 +4824,92 @@ impl SpecialRegistersMap { } } +struct FnSigMapper<'input> { + // true - stays as return argument + // false - is moved to input argument + return_param_args: Vec, + func_decl: Rc>>, +} + +impl<'input> FnSigMapper<'input> { + fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, spirv::Word>) -> Self { + let return_param_args = method + .return_arguments + .iter() + .map(|a| a.state_space != ast::StateSpace::Param) + .collect::>(); + let mut new_return_arguments = Vec::new(); + for arg in method.return_arguments.into_iter() { + if arg.state_space == ast::StateSpace::Param { + method.input_arguments.push(arg); + } else { + new_return_arguments.push(arg); + } + } + method.return_arguments = new_return_arguments; + FnSigMapper { + return_param_args, + func_decl: Rc::new(RefCell::new(method)), + } + } + + fn resolve_in_spirv_repr( + &self, + call_inst: ast::CallInst, + ) -> Result, TranslateError> { + let func_decl = (*self.func_decl).borrow(); + let mut return_arguments = Vec::new(); + let mut input_arguments = call_inst + .param_list + .into_iter() + .zip(func_decl.input_arguments.iter()) + .map(|(id, var)| (id, var.v_type.clone(), var.state_space)) + .collect::>(); + let mut func_decl_return_iter = func_decl.return_arguments.iter(); + let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter(); + for (idx, id) in call_inst.ret_params.iter().enumerate() { + let stays_as_return = match self.return_param_args.get(idx) { + Some(x) => *x, + None => return Err(TranslateError::MismatchedType), + }; + if stays_as_return { + if let Some(var) = func_decl_return_iter.next() { + return_arguments.push((*id, var.v_type.clone(), var.state_space)); + } else { + return Err(TranslateError::MismatchedType); + } + } else { + if let Some(var) = func_decl_input_iter.next() { + input_arguments.push(( + ast::Operand::Reg(*id), + var.v_type.clone(), + var.state_space, + )); + } else { + return Err(TranslateError::MismatchedType); + } + } + } + if return_arguments.len() != func_decl.return_arguments.len() + || input_arguments.len() != func_decl.input_arguments.len() + { + return Err(TranslateError::MismatchedType); + } + Ok(ResolvedCall { + return_arguments, + input_arguments, + uniform: call_inst.uniform, + name: call_inst.func, + }) + } +} + struct GlobalStringIdResolver<'input> { current_id: spirv::Word, variables: HashMap, spirv::Word>, variables_type_check: HashMap>, special_registers: SpecialRegistersMap, - fns: HashMap>>>, + fns: HashMap>, } impl<'input> GlobalStringIdResolver<'input> { @@ -4885,45 +4993,36 @@ impl<'input> GlobalStringIdResolver<'input> { ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), ast::MethodName::Func(_) => ast::MethodName::Func(name_id), }; - let new_fn_decl = Rc::new(RefCell::new(ast::MethodDeclaration { + let fn_decl = ast::MethodDeclaration { return_arguments, name, input_arguments, - })); - self.fns.insert(name_id, Rc::clone(&new_fn_decl)); + shared_mem: None, + }; + let new_fn_decl = if !fn_decl.name.is_kernel() { + let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl); + let new_fn_decl = resolver.func_decl.clone(); + self.fns.insert(name_id, resolver); + new_fn_decl + } else { + Rc::new(RefCell::new(fn_decl)) + }; Ok(( fn_resolver, - GlobalFnDeclResolver { - variables: &self.variables, - fns: &self.fns, - }, + GlobalFnDeclResolver { fns: &self.fns }, new_fn_decl, )) } } pub struct GlobalFnDeclResolver<'input, 'a> { - variables: &'a HashMap, spirv::Word>, - fns: &'a HashMap>>>, + fns: &'a HashMap>, } impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { - fn get_fn_decl( - &self, - id: spirv::Word, - ) -> Result<&Rc>>, TranslateError> { + fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> { self.fns.get(&id).ok_or(TranslateError::UnknownSymbol) } - - fn get_fn_decl_str( - &self, - id: &str, - ) -> Result<&'a Rc>>, TranslateError> { - match self.variables.get(id).map(|var_id| self.fns.get(var_id)) { - Some(Some(fn_d)) => Ok(fn_d), - _ => Err(TranslateError::UnknownSymbol), - } - } } struct FnStringIdResolver<'input, 'b> { @@ -5209,7 +5308,6 @@ struct LoadVarDetails { struct StoreVarDetails { arg: ast::Arg2St, typ: ast::Type, - state_space: ast::StateSpace, member_index: Option, } @@ -5300,7 +5398,7 @@ impl> ResolvedCall { self, visitor: &mut V, ) -> Result, TranslateError> { - let ret_params = self + let return_arguments = self .return_arguments .into_iter() .map::, _>(|(id, typ, space)| { @@ -5324,7 +5422,7 @@ impl> ResolvedCall { }, None, )?; - let param_list = self + let input_arguments = self .input_arguments .into_iter() .map::, _>(|(id, typ, space)| { @@ -5342,9 +5440,9 @@ impl> ResolvedCall { .collect::, _>>()?; Ok(ResolvedCall { uniform: self.uniform, - return_arguments: ret_params, + return_arguments, name: func, - input_arguments: param_list, + input_arguments, }) } } @@ -5485,7 +5583,6 @@ struct Function<'input> { pub func_decl: Rc>>, pub globals: Vec>, pub body: Option>, - pub uses_shared_mem: bool, import_as: Option, tuning: Vec, } @@ -7185,6 +7282,19 @@ fn default_implicit_conversion_space( }, _ => Err(TranslateError::MismatchedType), } + } else if instruction_space.is_compatible(ast::StateSpace::Reg) { + match instruction_type { + ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) + if operand_space == *instruction_ptr_space => + { + if operand_type != &ast::Type::Scalar(*instruction_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + _ => Err(TranslateError::MismatchedType), + } } else { Err(TranslateError::MismatchedType) } @@ -7432,6 +7542,15 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> { } } +impl<'input, ID> ast::MethodName<'input, ID> { + fn is_kernel(&self) -> bool { + match self { + ast::MethodName::Kernel(..) => true, + ast::MethodName::Func(..) => false, + } + } +} + #[cfg(test)] mod tests { use super::*; From 90960fd9239b9972dfffbff6ce26ce2642ec50af Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 5 Jun 2021 00:46:41 +0200 Subject: [PATCH 15/25] Fix method arg load generation --- ptx/src/translate.rs | 67 +++++++++++++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 6d5d5bc..c4efe55 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1059,7 +1059,7 @@ fn emit_function_header<'a>( let (ret_type, func_type) = get_function_type( builder, map, - &func_decl.input_arguments, + func_decl.effective_input_arguments().map(|(_, typ)| typ), &func_decl.return_arguments, ); let fn_id = match func_decl.name { @@ -1120,9 +1120,9 @@ fn emit_function_header<'a>( } } */ - for input in &func_decl.input_arguments { - let result_type = map.get_or_add(builder, SpirvType::new(input.v_type.clone())); - builder.function_parameter(Some(input.name), result_type)?; + for (name, typ) in func_decl.effective_input_arguments() { + let result_type = map.get_or_add(builder, typ); + builder.function_parameter(Some(name), result_type)?; } Ok(fn_id) } @@ -1233,7 +1233,7 @@ fn to_ssa<'input, 'b>( f_body: Option>>>, tuning: Vec, ) -> Result, TranslateError> { - deparamize_function_decl(&func_decl)?; + //deparamize_function_decl(&func_decl)?; let f_body = match f_body { Some(vec) => vec, None => { @@ -1997,30 +1997,38 @@ fn normalize_predicates( /* How do we handle arguments: - - input .params + - input .params in kernels .param .b64 in_arg get turned into this SPIR-V: %1 = OpFunctionParameter %ulong - %2 = OpVariable %%_ptr_Function_ulong Function + %2 = OpVariable %_ptr_Function_ulong Function OpStore %2 %1 We do this for two reasons. One, common treatment for argument-declared .param variables and .param variables inside function (we assume that - at SPIR-V level every .param is a pointer in Function storage class). Two, - PTX devs in their infinite wisdom decided that .reg arguments are writable + at SPIR-V level every .param is a pointer in Function storage class) + - input .params in functions + .param .b64 in_arg + get turned into this SPIR-V: + %1 = OpFunctionParameter %_ptr_Function_ulong - input .regs .reg .b64 in_arg - get turned into this SPIR-V: + get turned into the same SPIR-V as kernel .params: %1 = OpFunctionParameter %ulong - %2 = OpVariable %%_ptr_Function_ulong Function + %2 = OpVariable %_ptr_Function_ulong Function OpStore %2 %1 - with the difference that %2 is defined as a variable and not temp - output .regs .reg .b64 out_arg get just a variable declaration: %2 = OpVariable %%_ptr_Function_ulong Function - - output .params - .param .b64 out_arg - get treated the same as input .params, because there's no difference + - output .params don't exist, they have been moved to input positions + by an earlier pass + Distinguishing betweem kernel .params and function .params is not the + cleanest solution. Alternatively, we could "deparamize" all kernel .param + arguments by turning them into .reg arguments like this: + .param .b64 arg -> .reg ptr<.b64,.param> arg + This has the massive downside that this transformation would have to run + very early and would muddy up already difficult code. It's simpler to just + have an if here */ fn insert_mem_ssa_statements<'a, 'b>( func: Vec, @@ -2029,7 +2037,7 @@ fn insert_mem_ssa_statements<'a, 'b>( ) -> Result, TranslateError> { let mut result = Vec::with_capacity(func.len()); for arg in fn_decl.input_arguments.iter_mut() { - insert_mem_ssa_argument(id_def, &mut result, arg); + insert_mem_ssa_argument(id_def, &mut result, arg, fn_decl.name.is_kernel()); } for arg in fn_decl.return_arguments.iter() { insert_mem_ssa_argument_reg_return(&mut result, arg); @@ -2103,7 +2111,11 @@ fn insert_mem_ssa_argument( id_def: &mut NumericIdResolver, func: &mut Vec, arg: &mut ast::Variable, + is_kernel: bool, ) { + if !is_kernel && arg.state_space == ast::StateSpace::Param { + return; + } let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space))); func.push(Statement::Variable(ast::Variable { align: arg.align, @@ -2559,14 +2571,12 @@ fn insert_implicit_conversions_impl( fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, - spirv_input: &[ast::Variable], + spirv_input: impl ExactSizeIterator, spirv_output: &[ast::Variable], ) -> (spirv::Word, spirv::Word) { map.get_or_add_fn( builder, - spirv_input - .iter() - .map(|var| SpirvType::new(var.v_type.clone())), + spirv_input, spirv_output .iter() .map(|var| SpirvType::new(var.v_type.clone())), @@ -7542,6 +7552,23 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> { } } +impl<'a> ast::MethodDeclaration<'a, spirv::Word> { + fn effective_input_arguments( + &self, + ) -> impl ExactSizeIterator + '_ { + let is_kernel = self.name.is_kernel(); + self.input_arguments.iter().map(move |arg| { + if !is_kernel { + let spirv_type = + SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv()); + (arg.name, spirv_type) + } else { + (arg.name, SpirvType::new(arg.v_type.clone())) + } + }) + } +} + impl<'input, ID> ast::MethodName<'input, ID> { fn is_kernel(&self) -> bool { match self { From 83ba70bf37a47d58d2e6e2ac808ad77bd50a029d Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 5 Jun 2021 01:15:36 +0200 Subject: [PATCH 16/25] Remove last uses of new_todo --- ptx/src/translate.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c4efe55..ecc5544 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -39,12 +39,6 @@ fn error_unreachable() -> TranslateError { TranslateError::Unreachable } -macro_rules! new_todo { - () => { - todo!() - }; -} - #[derive(PartialEq, Eq, Hash, Clone)] enum SpirvType { Base(SpirvScalarKey), @@ -4104,9 +4098,9 @@ fn emit_implicit_conversion( src: wide_bit_value, dst: cv.dst, from_type: wide_bit_type, - from_space: new_todo!(), + from_space: cv.from_space, to_type: cv.to_type.clone(), - to_space: new_todo!(), + to_space: cv.to_space, kind: ConversionKind::Default, }, )?; @@ -7558,7 +7552,7 @@ impl<'a> ast::MethodDeclaration<'a, spirv::Word> { ) -> impl ExactSizeIterator + '_ { let is_kernel = self.name.is_kernel(); self.input_arguments.iter().map(move |arg| { - if !is_kernel { + if !is_kernel && arg.state_space != ast::StateSpace::Reg { let spirv_type = SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv()); (arg.name, spirv_type) From 491e71e346b267dd647e5a17d8fecca2a08e0f53 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 6 Jun 2021 00:10:26 +0200 Subject: [PATCH 17/25] Make vector extraction honor relaxed implicit conversion semantics --- ptx/src/translate.rs | 47 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index ecc5544..a0b5077 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1560,6 +1560,12 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { fn convert_vector( &mut self, is_dst: bool, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, typ: &ast::Type, state_space: ast::StateSpace, idx: Vec, @@ -1577,6 +1583,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { typ: scalar_t, packed: temp_vec, unpacked: idx, + non_default_implicit_conversion, }); if is_dst { self.post_stmts = Some(statement); @@ -1609,9 +1616,13 @@ impl<'a, 'b> ArgumentMapVisitor ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), ast::Operand::Imm(x) => TypedOperand::Imm(x), ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), - ast::Operand::VecPack(vec) => { - TypedOperand::Reg(self.convert_vector(desc.is_dst, typ, state_space, vec)?) - } + ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector( + desc.is_dst, + desc.non_default_implicit_conversion, + typ, + state_space, + vec, + )?), }) } } @@ -5320,6 +5331,12 @@ struct RepackVectorDetails { typ: ast::ScalarType, packed: spirv::Word, unpacked: Vec, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, } impl RepackVectorDetails { @@ -5335,7 +5352,6 @@ impl RepackVectorDetails { ArgumentDescriptor { op: self.packed, is_dst: !self.is_extract, - non_default_implicit_conversion: None, }, Some(( @@ -5345,6 +5361,7 @@ impl RepackVectorDetails { )?; let scalar_type = self.typ; let is_extract = self.is_extract; + let non_default_implicit_conversion = self.non_default_implicit_conversion; let vector = self .unpacked .into_iter() @@ -5353,7 +5370,7 @@ impl RepackVectorDetails { ArgumentDescriptor { op: id, is_dst: is_extract, - non_default_implicit_conversion: None, + non_default_implicit_conversion, }, Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)), ) @@ -5364,6 +5381,7 @@ impl RepackVectorDetails { typ: self.typ, packed: scalar, unpacked: vector, + non_default_implicit_conversion, }) } } @@ -7168,6 +7186,19 @@ impl ast::StateSpace { || self == ast::StateSpace::Sreg && other == ast::StateSpace::Reg } + fn coerces_to_generic(self) -> bool { + match self { + ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Reg + | ast::StateSpace::Param + | ast::StateSpace::Generic + | ast::StateSpace::Sreg => false, + } + } + fn is_addressable(self) -> bool { match self { ast::StateSpace::Const @@ -7254,7 +7285,11 @@ fn default_implicit_conversion_space( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if operand_space.is_compatible(ast::StateSpace::Reg) { + if (instruction_space == ast::StateSpace::Generic && operand_space.coerces_to_generic()) + || (operand_space == ast::StateSpace::Generic && instruction_space.coerces_to_generic()) + { + Ok(Some(ConversionKind::PtrToPtr)) + } else if operand_space.is_compatible(ast::StateSpace::Reg) { match operand_type { ast::Type::Pointer(operand_ptr_type, operand_ptr_space) if *operand_ptr_space == instruction_space => From e940b9400fe9e67bb9ffdb79ab6d81f31cf88877 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 6 Jun 2021 17:25:05 +0200 Subject: [PATCH 18/25] Bring back support for dynamic shared memory --- ptx/src/ast.rs | 6 +-- ptx/src/ptx.lalrpop | 38 ++++++++------ ptx/src/translate.rs | 119 ++++++++++++++++++++++--------------------- 3 files changed, 87 insertions(+), 76 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index a0bb023..5432207 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -82,8 +82,8 @@ pub struct Module<'a> { } pub enum Directive<'a, P: ArgParams> { - Variable(Variable), - Method(Function<'a, &'a str, Statement

>), + Variable(LinkingDirective, Variable), + Method(LinkingDirective, Function<'a, &'a str, Statement

>), } #[derive(Hash, PartialEq, Eq, Copy, Clone)] @@ -96,7 +96,7 @@ pub struct MethodDeclaration<'input, ID> { pub return_arguments: Vec>, pub name: MethodName<'input, ID>, pub input_arguments: Vec>, - pub shared_mem: Option>, + pub shared_mem: Option, } pub struct Function<'a, ID, S> { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index e8370cd..b697317 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -343,10 +343,16 @@ TargetSpecifier = { Directive: Option>> = { AddressSize => None, - => Some(ast::Directive::Method(f)), + => { + let (linking, func) = f; + Some(ast::Directive::Method(linking, func)) + }, File => None, Section => None, - ";" => Some(ast::Directive::Variable(v)), + ";" => { + let (linking, var) = v; + Some(ast::Directive::Variable(linking, var)) + }, ! => { let err = <>; errors.push(err.error); @@ -358,11 +364,13 @@ AddressSize = { ".address_size" U8Num }; -Function: ast::Function<'input, &'input str, ast::Statement>> = { - LinkingDirectives +Function: (ast::LinkingDirective, ast::Function<'input, &'input str, ast::Statement>>) = { + - => ast::Function{<>} + => { + (linking, ast::Function{func_directive, tuning, body}) + } }; LinkingDirective: ast::LinkingDirective = { @@ -598,18 +606,18 @@ SharedVariable: ast::Variable<&'input str> = { } } -ModuleVariable: ast::Variable<&'input str> = { - LinkingDirectives ".global" => { +ModuleVariable: (ast::LinkingDirective, ast::Variable<&'input str>) = { + ".global" => { let (align, v_type, name, array_init) = def; let state_space = ast::StateSpace::Global; - ast::Variable { align, v_type, state_space, name, array_init } + (linking, ast::Variable { align, v_type, state_space, name, array_init }) }, - LinkingDirectives ".shared" => { + ".shared" => { let (align, v_type, name, array_init) = def; let state_space = ast::StateSpace::Shared; - ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } + (linking, ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }) }, - > > =>? { + > > =>? { let (align, t, name, arr_or_ptr) = var; let (v_type, state_space, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { @@ -620,17 +628,17 @@ ModuleVariable: ast::Variable<&'input str> = { } } ast::ArrayOrPointer::Pointer => { - if !ldirs.contains(ast::LinkingDirective::EXTERN) { + if !linking.contains(ast::LinkingDirective::EXTERN) { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::Type::Scalar(t), ast::StateSpace::Global, Vec::new()) + (ast::Type::Array(t, Vec::new()), ast::StateSpace::Global, Vec::new()) } else { - (ast::Type::Scalar(t), ast::StateSpace::Shared, Vec::new()) + (ast::Type::Array(t, Vec::new()), ast::StateSpace::Shared, Vec::new()) } } }; - Ok(ast::Variable{ align, v_type, state_space, name, array_init }) + Ok((linking, ast::Variable{ align, v_type, state_space, name, array_init })) } } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a0b5077..6b9dcfb 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -172,14 +172,18 @@ impl TypeWordMap { .or_insert_with(|| b.type_vector(None, base, len as u32)) } SpirvType::Array(typ, array_dimensions) => { - let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); let (base_type, length) = match &*array_dimensions { + &[] => { + return self.get_or_add(b, SpirvType::Base(typ)); + } &[len] => { + let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); let base = self.get_or_add_spirv_scalar(b, typ); let len_const = b.constant_u32(u32_type, None, len); (base, len_const) } array_dimensions => { + let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); let base = self .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec())); let len_const = b.constant_u32(u32_type, None, array_dimensions[0]); @@ -221,7 +225,7 @@ impl TypeWordMap { fn get_or_add_fn( &mut self, b: &mut dr::Builder, - in_params: impl ExactSizeIterator, + in_params: impl Iterator, mut out_params: impl ExactSizeIterator, ) -> (spirv::Word, spirv::Word) { let (out_args, out_spirv_type) = if out_params.len() == 0 { @@ -233,6 +237,7 @@ impl TypeWordMap { self.get_or_add(b, arg_as_key), ) } else { + // TODO: support multiple return values todo!() }; ( @@ -436,7 +441,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result( let empty_body = Vec::new(); for d in directives.iter() { match d { - Directive::Variable(var) => { + Directive::Variable(_, var) => { emit_variable(builder, map, &var)?; } Directive::Method(f) => { @@ -699,7 +704,6 @@ fn multi_hash_map_append(m: &mut MultiHashMap, transformation has a semantical meaning - we emit additional "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...") */ -/* fn convert_dynamic_shared_memory_usage<'input>( module: Vec>, new_id: &mut impl FnMut() -> spirv::Word, @@ -707,13 +711,16 @@ fn convert_dynamic_shared_memory_usage<'input>( let mut extern_shared_decls = HashMap::new(); for dir in module.iter() { match dir { - Directive::Variable(ast::Variable { - v_type: ast::Type::Pointer(p_type), - state_space: ast::StateSpace::Shared, - name, - .. - }) => { - extern_shared_decls.insert(*name, p_type.clone()); + Directive::Variable( + linking, + ast::Variable { + v_type: ast::Type::Array(p_type, dims), + state_space: ast::StateSpace::Shared, + name, + .. + }, + ) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => { + extern_shared_decls.insert(*name, *p_type); } _ => {} } @@ -732,14 +739,13 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, tuning, - uses_shared_mem, }) => { let call_key = (*func_decl).borrow().name; let statements = statements .into_iter() .map(|statement| match statement { Statement::Call(call) => { - multi_hash_map_append(&mut directly_called_by, call.func, call_key); + multi_hash_map_append(&mut directly_called_by, call.name, call_key); Statement::Call(call) } statement => statement.map_id(&mut |id, _| { @@ -756,7 +762,6 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, tuning, - uses_shared_mem, }) } directive => directive, @@ -775,7 +780,6 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, tuning, - uses_shared_mem, }) => { if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) { return Directive::Method(Function { @@ -784,21 +788,12 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, tuning, - uses_shared_mem, }); } let shared_id_param = new_id(); { let mut func_decl = (*func_decl).borrow_mut(); - func_decl.input_arguments.push({ - ast::Variable { - name: shared_id_param, - align: None, - v_type: ast::Type::Pointer(ast::ScalarType::B8, new_todo!()), - state_space: ast::StateSpace::Shared, - array_init: Vec::new(), - } - }); + func_decl.shared_mem = Some(shared_id_param); } let statements = replace_uses_of_shared_memory( new_id, @@ -813,7 +808,6 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, tuning, - uses_shared_mem: true, }) } directive => directive, @@ -835,8 +829,8 @@ fn replace_uses_of_shared_memory<'a>( // We can safely skip checking call arguments, // because there's simply no way to pass shared ptr // without converting it to .b64 first - if methods_using_extern_shared.contains(&ast::MethodName::Func(call.func)) { - call.param_list.push(( + if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) { + call.input_arguments.push(( shared_id_param, ast::Type::Scalar(ast::ScalarType::B8), ast::StateSpace::Shared, @@ -854,13 +848,11 @@ fn replace_uses_of_shared_memory<'a>( result.push(Statement::Conversion(ImplicitConversion { src: shared_id_param, dst: replacement_id, - from_type: ast::Type::Pointer(ast::ScalarType::B8), + from_type: ast::Type::Scalar(ast::ScalarType::B8), from_space: ast::StateSpace::Shared, - to_type: ast::Type::Pointer((*scalar_type).into()), + to_type: ast::Type::Scalar(*scalar_type), to_space: ast::StateSpace::Shared, - kind: ConversionKind::PtrToPtr { spirv_ptr: true }, - src_ - dst_ + kind: ConversionKind::PtrToPtr, })); replacement_id } else { @@ -912,7 +904,6 @@ fn get_callers_of_extern_shared_single<'a>( } } } -*/ type DenormCountMap = HashMap; @@ -948,7 +939,7 @@ fn compute_denorm_information<'input>( let mut denorm_methods = HashMap::new(); for directive in module { match directive { - Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {} + Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {} Directive::Method(Function { func_decl, body: Some(statements), @@ -1158,14 +1149,17 @@ fn translate_directive<'input>( d: ast::Directive<'input, ast::ParsedArgParams<'input>>, ) -> Result>, TranslateError> { Ok(match d { - ast::Directive::Variable(var) => Some(Directive::Variable(ast::Variable { - align: var.align, - v_type: var.v_type.clone(), - state_space: var.state_space, - name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true), - array_init: var.array_init, - })), - ast::Directive::Method(f) => { + ast::Directive::Variable(linking, var) => Some(Directive::Variable( + linking, + ast::Variable { + align: var.align, + v_type: var.v_type.clone(), + state_space: var.state_space, + name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true), + array_init: var.array_init, + }, + )), + ast::Directive::Method(_, f) => { translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method) } }) @@ -2576,7 +2570,7 @@ fn insert_implicit_conversions_impl( fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, - spirv_input: impl ExactSizeIterator, + spirv_input: impl Iterator, spirv_output: &[ast::Variable], ) -> (spirv::Word, spirv::Word) { map.get_or_add_fn( @@ -5597,7 +5591,7 @@ impl ast::ArgParams for ExpandedArgParams { impl ArgParamsEx for ExpandedArgParams {} enum Directive<'input> { - Variable(ast::Variable), + Variable(ast::LinkingDirective, ast::Variable), Method(Function<'input>), } @@ -7582,19 +7576,28 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> { } impl<'a> ast::MethodDeclaration<'a, spirv::Word> { - fn effective_input_arguments( - &self, - ) -> impl ExactSizeIterator + '_ { + fn effective_input_arguments(&self) -> impl Iterator + '_ { let is_kernel = self.name.is_kernel(); - self.input_arguments.iter().map(move |arg| { - if !is_kernel && arg.state_space != ast::StateSpace::Reg { - let spirv_type = - SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv()); - (arg.name, spirv_type) - } else { - (arg.name, SpirvType::new(arg.v_type.clone())) - } - }) + self.input_arguments + .iter() + .map(move |arg| { + if !is_kernel && arg.state_space != ast::StateSpace::Reg { + let spirv_type = + SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv()); + (arg.name, spirv_type) + } else { + (arg.name, SpirvType::new(arg.v_type.clone())) + } + }) + .chain(self.shared_mem.iter().map(|id| { + ( + *id, + SpirvType::Pointer( + Box::new(SpirvType::Base(SpirvScalarKey::B8)), + spirv::StorageClass::Workgroup, + ), + ) + })) } } From 9ad88ac98298a2df76ad64570675db1d3a0c7b58 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 6 Jun 2021 18:14:49 +0200 Subject: [PATCH 19/25] Make stateful optimization build --- ptx/src/translate.rs | 109 ++++++++++++++++++++----------------------- 1 file changed, 51 insertions(+), 58 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 6b9dcfb..61b255d 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; - //let typed_statements = - // convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; + let typed_statements = + convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, @@ -4307,14 +4307,14 @@ fn expand_map_variables<'a, 'b>( // TODO: once insert_mem_ssa_statements is moved to later, move this pass after // argument expansion // TODO: propagate through calls? -/* -fn convert_to_stateful_memory_access<'a>( - func_args: &mut SpirvMethodDecl, +fn convert_to_stateful_memory_access<'a, 'input>( + func_args: &Rc>>, func_body: Vec, id_defs: &mut NumericIdResolver<'a>, ) -> Result, TranslateError> { - let func_args_64bit = func_args - .input + let mut func_args = func_args.borrow_mut(); + let func_args_64bit = (*func_args) + .input_arguments .iter() .filter_map(|arg| match arg.v_type { ast::Type::Scalar(ast::ScalarType::U64) @@ -4445,15 +4445,15 @@ fn convert_to_stateful_memory_access<'a>( let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); for reg in regs_ptr_seen { let new_id = id_defs.register_variable( - ast::Type::Pointer(ast::ScalarType::U8), - ast::StateSpace::Global, + ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Reg, ); result.push(Statement::Variable(ast::Variable { align: None, name: new_id, array_init: Vec::new(), - v_type: ast::Type::Pointer(ast::ScalarType::U8), - state_space: ast::StateSpace::Global, + v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + state_space: ast::StateSpace::Reg, })); remapped_ids.insert(reg, new_id); } @@ -4515,8 +4515,10 @@ fn convert_to_stateful_memory_access<'a>( } _ => return Err(error_unreachable()), }; - let offset_neg = - id_defs.register_intermediate(Some(ast::Type::Scalar(ast::ScalarType::S64))); + let offset_neg = id_defs.register_intermediate(Some(( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ))); result.push(Statement::Instruction(ast::Instruction::Neg( ast::NegDetails { typ: ast::ScalarType::S64, @@ -4538,9 +4540,8 @@ fn convert_to_stateful_memory_access<'a>( } Statement::Instruction(inst) => { let mut post_statements = Vec::new(); - let new_statement = inst.visit( - &mut |arg_desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>| { + let new_statement = + inst.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4550,16 +4551,14 @@ fn convert_to_stateful_memory_access<'a>( arg_desc, expected_type, ) - }, - )?; + })?; result.push(new_statement); result.extend(post_statements); } Statement::Call(call) => { let mut post_statements = Vec::new(); - let new_statement = call.visit( - &mut |arg_desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>| { + let new_statement = + call.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4569,16 +4568,14 @@ fn convert_to_stateful_memory_access<'a>( arg_desc, expected_type, ) - }, - )?; + })?; result.push(new_statement); result.extend(post_statements); } Statement::RepackVector(pack) => { let mut post_statements = Vec::new(); - let new_statement = pack.visit( - &mut |arg_desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>| { + let new_statement = + pack.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4588,18 +4585,17 @@ fn convert_to_stateful_memory_access<'a>( arg_desc, expected_type, ) - }, - )?; + })?; result.push(new_statement); result.extend(post_statements); } _ => return Err(error_unreachable()), } } - for arg in func_args.input.iter_mut() { + for arg in (*func_args).input_arguments.iter_mut() { if func_args_ptr.contains(&arg.name) { - arg.v_type = ast::Type::Pointer(ast::ScalarType::U8); - arg.state_space = ast::StateSpace::Global; + arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); + arg.state_space = ast::StateSpace::Reg; } } Ok(result) @@ -4612,43 +4608,40 @@ fn convert_to_stateful_memory_access_postprocess( result: &mut Vec, post_statements: &mut Vec, arg_desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>, + expected_type: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { Ok(match remapped_ids.get(&arg_desc.op) { Some(new_id) => { // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { - Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id), + Some((ast::Type::Pointer(_, ast::StateSpace::Global), ast::StateSpace::Reg)) => { + return Ok(*new_id) + } _ => id_defs.get_typed(arg_desc.op)?.0, }; let old_type_clone = old_type.clone(); - let converting_id = id_defs.register_intermediate(Some(old_type_clone)); + let converting_id = + id_defs.register_intermediate(Some((old_type_clone, ast::StateSpace::Reg))); if arg_desc.is_dst { post_statements.push(Statement::Conversion(ImplicitConversion { src: converting_id, dst: *new_id, from_type: old_type, - to_type: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::StateSpace::Global, - ), - kind: ConversionKind::BitToPtr(ast::StateSpace::Global), - src_ - dst_sema: arg_desc.sema, + from_space: ast::StateSpace::Reg, + to_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + to_space: ast::StateSpace::Reg, + kind: ConversionKind::BitToPtr, })); converting_id } else { result.push(Statement::Conversion(ImplicitConversion { src: *new_id, dst: converting_id, - from_type: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::StateSpace::Global, - ), + from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + from_space: ast::StateSpace::Reg, to_type: old_type, - kind: ConversionKind::PtrToBit(ast::ScalarType::U64), - src_sema: arg_desc.sema, - dst_ + to_space: ast::StateSpace::Reg, + kind: ConversionKind::AddressOf, })); converting_id } @@ -4660,22 +4653,23 @@ fn convert_to_stateful_memory_access_postprocess( } // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { - Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id), + Some(( + ast::Type::Pointer(_, ast::StateSpace::Global), + ast::StateSpace::Reg, + )) => return Ok(*new_id), _ => id_defs.get_typed(arg_desc.op)?.0, }; let old_type_clone = old_type.clone(); - let converting_id = id_defs.register_intermediate(Some(old_type)); + let converting_id = + id_defs.register_intermediate(Some((old_type, ast::StateSpace::Reg))); result.push(Statement::Conversion(ImplicitConversion { src: *new_id, dst: converting_id, - from_type: ast::Type::Pointer( - ast::PointerType::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - ast::StateSpace::Param, - ), + from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + from_space: ast::StateSpace::Reg, to_type: old_type_clone, - kind: ConversionKind::PtrToPtr { spirv_ptr: false }, - src_sema: arg_desc.sema, - dst_ + to_space: ast::StateSpace::Reg, + kind: ConversionKind::PtrToPtr, })); converting_id } @@ -4710,7 +4704,6 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool { _ => false, } } -*/ #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] enum PtxSpecialRegister { From 994cfb338655048ac274f913582aed214102b3d9 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 6 Jun 2021 21:51:40 +0200 Subject: [PATCH 20/25] Fix small bug in stateful postprocess --- ptx/src/translate.rs | 116 +++++++++++++++++++++---------------------- 1 file changed, 56 insertions(+), 60 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 61b255d..4c1c0e7 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -4136,6 +4136,11 @@ fn emit_implicit_conversion( let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; } + (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) + | (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => { + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.bitcast(dst_type, Some(cv.dst), cv.src)?; + } _ => unreachable!(), } Ok(()) @@ -4610,72 +4615,63 @@ fn convert_to_stateful_memory_access_postprocess( arg_desc: ArgumentDescriptor, expected_type: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { - Ok(match remapped_ids.get(&arg_desc.op) { - Some(new_id) => { - // We skip conversion here to trigger PtrAcces in a later pass - let old_type = match expected_type { - Some((ast::Type::Pointer(_, ast::StateSpace::Global), ast::StateSpace::Reg)) => { - return Ok(*new_id) - } - _ => id_defs.get_typed(arg_desc.op)?.0, - }; - let old_type_clone = old_type.clone(); - let converting_id = - id_defs.register_intermediate(Some((old_type_clone, ast::StateSpace::Reg))); - if arg_desc.is_dst { - post_statements.push(Statement::Conversion(ImplicitConversion { - src: converting_id, - dst: *new_id, - from_type: old_type, - from_space: ast::StateSpace::Reg, - to_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - to_space: ast::StateSpace::Reg, - kind: ConversionKind::BitToPtr, - })); - converting_id - } else { - result.push(Statement::Conversion(ImplicitConversion { - src: *new_id, - dst: converting_id, - from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - from_space: ast::StateSpace::Reg, - to_type: old_type, - to_space: ast::StateSpace::Reg, - kind: ConversionKind::AddressOf, - })); - converting_id - } - } - None => match func_args_ptr.get(&arg_desc.op) { + Ok( + match remapped_ids + .get(&arg_desc.op) + .or_else(|| func_args_ptr.get(&arg_desc.op)) + { Some(new_id) => { - if arg_desc.is_dst { - return Err(error_unreachable()); + let (new_operand_type, new_operand_space, is_variable) = + id_defs.get_typed(*new_id)?; + if let Some((expected_type, expected_space)) = expected_type { + let implicit_conversion = arg_desc + .non_default_implicit_conversion + .unwrap_or(default_implicit_conversion); + if implicit_conversion( + (new_operand_space, &new_operand_type), + (expected_space, expected_type), + ) + .is_ok() + { + return Ok(*new_id); + } } - // We skip conversion here to trigger PtrAcces in a later pass - let old_type = match expected_type { - Some(( - ast::Type::Pointer(_, ast::StateSpace::Global), - ast::StateSpace::Reg, - )) => return Ok(*new_id), - _ => id_defs.get_typed(arg_desc.op)?.0, + let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?; + let new_operand_type_clone = new_operand_type.clone(); + let converting_id = id_defs + .register_intermediate(Some((old_operand_type.clone(), old_operand_space))); + let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) { + ConversionKind::Default + } else { + ConversionKind::PtrToPtr }; - let old_type_clone = old_type.clone(); - let converting_id = - id_defs.register_intermediate(Some((old_type, ast::StateSpace::Reg))); - result.push(Statement::Conversion(ImplicitConversion { - src: *new_id, - dst: converting_id, - from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - from_space: ast::StateSpace::Reg, - to_type: old_type_clone, - to_space: ast::StateSpace::Reg, - kind: ConversionKind::PtrToPtr, - })); - converting_id + if arg_desc.is_dst { + post_statements.push(Statement::Conversion(ImplicitConversion { + src: converting_id, + dst: *new_id, + from_type: old_operand_type, + from_space: old_operand_space, + to_type: new_operand_type, + to_space: new_operand_space, + kind, + })); + converting_id + } else { + result.push(Statement::Conversion(ImplicitConversion { + src: *new_id, + dst: converting_id, + from_type: new_operand_type, + from_space: new_operand_space, + to_type: old_operand_type, + to_space: old_operand_space, + kind, + })); + converting_id + } } None => arg_desc.op, }, - }) + ) } fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { From f0771e1fb6bb95e3f22b8bfa3a9efd3bfe88c946 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 11 Jun 2021 00:00:56 +0200 Subject: [PATCH 21/25] Slightly improve stateful optimization --- ptx/src/translate.rs | 156 ++++++++++++++++++++++-------------------- zluda_dump/src/lib.rs | 21 +++--- 2 files changed, 95 insertions(+), 82 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 4c1c0e7..511d763 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; - let typed_statements = - convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?; + let (func_decl, typed_statements) = + convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, @@ -4311,14 +4311,27 @@ fn expand_map_variables<'a, 'b>( // TODO: don't convert to ptr if the register is not ultimately used for ld/st // TODO: once insert_mem_ssa_statements is moved to later, move this pass after // argument expansion -// TODO: propagate through calls? +// TODO: propagate out of calls and into calls fn convert_to_stateful_memory_access<'a, 'input>( - func_args: &Rc>>, + func_args: Rc>>, func_body: Vec, id_defs: &mut NumericIdResolver<'a>, -) -> Result, TranslateError> { - let mut func_args = func_args.borrow_mut(); - let func_args_64bit = (*func_args) +) -> Result< + ( + Rc>>, + Vec, + ), + TranslateError, +> { + let mut method_decl = func_args.borrow_mut(); + if !method_decl.name.is_kernel() { + drop(method_decl); + return Ok((func_args, func_body)); + } + if Rc::strong_count(&func_args) != 1 { + return Err(error_unreachable()); + } + let func_args_64bit = (*method_decl) .input_arguments .iter() .filter_map(|arg| match arg.v_type { @@ -4462,6 +4475,18 @@ fn convert_to_stateful_memory_access<'a, 'input>( })); remapped_ids.insert(reg, new_id); } + for arg in (*method_decl).input_arguments.iter_mut() { + let new_id = id_defs.register_variable( + ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Reg, + ); + let old_name = arg.name; + if func_args_ptr.contains(&arg.name) { + arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); + arg.name = new_id; + } + remapped_ids.insert(old_name, new_id); + } for statement in func_body { match statement { l @ Statement::Label(_) => result.push(l), @@ -4550,7 +4575,6 @@ fn convert_to_stateful_memory_access<'a, 'input>( convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, - &func_args_ptr, &mut result, &mut post_statements, arg_desc, @@ -4567,7 +4591,6 @@ fn convert_to_stateful_memory_access<'a, 'input>( convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, - &func_args_ptr, &mut result, &mut post_statements, arg_desc, @@ -4584,7 +4607,6 @@ fn convert_to_stateful_memory_access<'a, 'input>( convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, - &func_args_ptr, &mut result, &mut post_statements, arg_desc, @@ -4597,81 +4619,69 @@ fn convert_to_stateful_memory_access<'a, 'input>( _ => return Err(error_unreachable()), } } - for arg in (*func_args).input_arguments.iter_mut() { - if func_args_ptr.contains(&arg.name) { - arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); - arg.state_space = ast::StateSpace::Reg; - } - } - Ok(result) + drop(method_decl); + Ok((func_args, result)) } fn convert_to_stateful_memory_access_postprocess( id_defs: &mut NumericIdResolver, remapped_ids: &HashMap, - func_args_ptr: &HashSet, result: &mut Vec, post_statements: &mut Vec, arg_desc: ArgumentDescriptor, expected_type: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { - Ok( - match remapped_ids - .get(&arg_desc.op) - .or_else(|| func_args_ptr.get(&arg_desc.op)) - { - Some(new_id) => { - let (new_operand_type, new_operand_space, is_variable) = - id_defs.get_typed(*new_id)?; - if let Some((expected_type, expected_space)) = expected_type { - let implicit_conversion = arg_desc - .non_default_implicit_conversion - .unwrap_or(default_implicit_conversion); - if implicit_conversion( - (new_operand_space, &new_operand_type), - (expected_space, expected_type), - ) - .is_ok() - { - return Ok(*new_id); - } - } - let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?; - let new_operand_type_clone = new_operand_type.clone(); - let converting_id = id_defs - .register_intermediate(Some((old_operand_type.clone(), old_operand_space))); - let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) { - ConversionKind::Default - } else { - ConversionKind::PtrToPtr - }; - if arg_desc.is_dst { - post_statements.push(Statement::Conversion(ImplicitConversion { - src: converting_id, - dst: *new_id, - from_type: old_operand_type, - from_space: old_operand_space, - to_type: new_operand_type, - to_space: new_operand_space, - kind, - })); - converting_id - } else { - result.push(Statement::Conversion(ImplicitConversion { - src: *new_id, - dst: converting_id, - from_type: new_operand_type, - from_space: new_operand_space, - to_type: old_operand_type, - to_space: old_operand_space, - kind, - })); - converting_id + Ok(match remapped_ids.get(&arg_desc.op) { + Some(new_id) => { + let (new_operand_type, new_operand_space, is_variable) = id_defs.get_typed(*new_id)?; + if let Some((expected_type, expected_space)) = expected_type { + let implicit_conversion = arg_desc + .non_default_implicit_conversion + .unwrap_or(default_implicit_conversion); + if implicit_conversion( + (new_operand_space, &new_operand_type), + (expected_space, expected_type), + ) + .is_ok() + { + return Ok(*new_id); } } - None => arg_desc.op, - }, - ) + let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?; + let new_operand_type_clone = new_operand_type.clone(); + let converting_id = + id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space))); + let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) { + ConversionKind::Default + } else { + ConversionKind::PtrToPtr + }; + if arg_desc.is_dst { + post_statements.push(Statement::Conversion(ImplicitConversion { + src: converting_id, + dst: *new_id, + from_type: old_operand_type, + from_space: old_operand_space, + to_type: new_operand_type, + to_space: new_operand_space, + kind, + })); + converting_id + } else { + result.push(Statement::Conversion(ImplicitConversion { + src: *new_id, + dst: converting_id, + from_type: new_operand_type, + from_space: new_operand_space, + to_type: old_operand_type, + to_space: old_operand_space, + kind, + })); + converting_id + } + } + None => arg_desc.op, + }) } fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs index f168930..ffd1498 100644 --- a/zluda_dump/src/lib.rs +++ b/zluda_dump/src/lib.rs @@ -219,15 +219,18 @@ unsafe fn to_str(image: *const T) -> Option<&'static str> { fn directive_to_kernel(dir: &ast::Directive) -> Option<(String, Vec)> { match dir { - ast::Directive::Method(ast::Function { - func_directive: - ast::MethodDeclaration { - name: ast::MethodName::Kernel(name), - input_arguments, - .. - }, - .. - }) => { + ast::Directive::Method( + _, + ast::Function { + func_directive: + ast::MethodDeclaration { + name: ast::MethodName::Kernel(name), + input_arguments, + .. + }, + .. + }, + ) => { let arg_sizes = input_arguments .iter() .map(|arg| ast::Type::from(arg.v_type.clone()).size_of()) From 2198862e76d070abaf013f1a52d8a2a03649434b Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 11 Jun 2021 12:36:23 +0200 Subject: [PATCH 22/25] Fix handling of kernel args in stateful conversion --- ptx/src/translate.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 511d763..dc8cc5a 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -4136,10 +4136,13 @@ fn emit_implicit_conversion( let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; } - (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) - | (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => { - let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.bitcast(dst_type, Some(cv.dst), cv.src)?; + (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_ptr_to_u(result_type, Some(cv.dst), cv.src)?; + } + (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_u_to_ptr(result_type, Some(cv.dst), cv.src)?; } _ => unreachable!(), } @@ -4478,7 +4481,7 @@ fn convert_to_stateful_memory_access<'a, 'input>( for arg in (*method_decl).input_arguments.iter_mut() { let new_id = id_defs.register_variable( ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - ast::StateSpace::Reg, + ast::StateSpace::Param, ); let old_name = arg.name; if func_args_ptr.contains(&arg.name) { From 951c7558ccb2f8b14b31295faac7994c3ebdc4b5 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 12 Jun 2021 16:17:32 +0200 Subject: [PATCH 23/25] Fix problems with non-dereferencing inline addition --- ptx/src/test/spirv_run/verify.py | 21 ++++ ptx/src/translate.rs | 198 ++++++++++++++++++++----------- 2 files changed, 153 insertions(+), 66 deletions(-) create mode 100644 ptx/src/test/spirv_run/verify.py diff --git a/ptx/src/test/spirv_run/verify.py b/ptx/src/test/spirv_run/verify.py new file mode 100644 index 0000000..dbfab00 --- /dev/null +++ b/ptx/src/test/spirv_run/verify.py @@ -0,0 +1,21 @@ +import os, sys, subprocess + +def main(path): + dirs = os.listdir(path) + for file in dirs: + if not file.endswith(".spvtxt"): + continue + full_file = os.path.join(path, file) + print(file) + spv_file = f"/tmp/{file}.spv" + # We nominally emit spv1.3, but use spv1.4 feature (OpEntryPoint interface changes in 1.4) + proc1 = subprocess.run(["spirv-as", "--target-env", "spv1.4", full_file, "-o", spv_file]) + proc2 = subprocess.run(["spirv-dis", spv_file, "-o", f"{spv_file}.dis.txt"]) + proc3 = subprocess.run(["spirv-val", spv_file ]) + if proc1.returncode != 0 or proc2.returncode != 0 or proc3.returncode != 0: + print(proc1.returncode) + print(proc2.returncode) + print(proc3.returncode) + +if __name__ == "__main__": + main(sys.argv[1]) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index dc8cc5a..277db5c 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2388,28 +2388,66 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { state_space: ast::StateSpace, ) -> Result { let (reg, offset) = desc.op; - match typ { - ast::Type::Scalar(underlying_type) => { - let id_constant_stmt = self.id_def.register_intermediate( - ast::Type::Scalar(ast::ScalarType::S64), - ast::StateSpace::Reg, - ); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: ast::ScalarType::S64, - value: ast::ImmediateValue::S64(offset as i64), - })); - let dst = self.id_def.register_intermediate(typ.clone(), state_space); - self.func.push(Statement::PtrAccess(PtrAccess { - underlying_type: *underlying_type, - state_space: state_space, - dst, - ptr_src: reg, - offset_src: id_constant_stmt, - })); - Ok(dst) + if !desc.is_memory_access { + let (reg_type, reg_space) = self.id_def.get_typed(reg)?; + if !reg_space.is_compatible(ast::StateSpace::Reg) { + return Err(TranslateError::MismatchedType); } - _ => Err(error_unreachable()), + let reg_scalar_type = match reg_type { + ast::Type::Scalar(underlying_type) => underlying_type, + _ => return Err(TranslateError::MismatchedType), + }; + let id_constant_stmt = self + .id_def + .register_intermediate(reg_type.clone(), ast::StateSpace::Reg); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: reg_scalar_type, + value: ast::ImmediateValue::S64(offset as i64), + })); + let arith_details = match reg_scalar_type.kind() { + ast::ScalarKind::Signed => ast::ArithDetails::Signed(ast::ArithSInt { + typ: reg_scalar_type, + saturate: false, + }), + ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => { + ast::ArithDetails::Unsigned(reg_scalar_type) + } + _ => return Err(error_unreachable()), + }; + let id_add_result = self.id_def.register_intermediate(reg_type, state_space); + self.func.push(Statement::Instruction(ast::Instruction::Add( + arith_details, + ast::Arg3 { + dst: id_add_result, + src1: reg, + src2: id_constant_stmt, + }, + ))); + Ok(id_add_result) + } else { + let scalar_type = match typ { + ast::Type::Scalar(underlying_type) => *underlying_type, + _ => return Err(error_unreachable()), + }; + let id_constant_stmt = self.id_def.register_intermediate( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: ast::ScalarType::S64, + value: ast::ImmediateValue::S64(offset as i64), + })); + let dst = self.id_def.register_intermediate(typ.clone(), state_space); + self.func.push(Statement::PtrAccess(PtrAccess { + underlying_type: scalar_type, + state_space: state_space, + dst, + ptr_src: reg, + offset_src: id_constant_stmt, + })); + Ok(dst) } } @@ -4399,6 +4437,10 @@ fn convert_to_stateful_memory_access<'a, 'input>( _ => {} } } + if stateful_markers.len() == 0 { + drop(method_decl); + return Ok((func_args, func_body)); + } let mut func_args_ptr = HashSet::new(); let mut regs_ptr_current = HashSet::new(); for (dst, src) in stateful_markers { @@ -4479,15 +4521,16 @@ fn convert_to_stateful_memory_access<'a, 'input>( remapped_ids.insert(reg, new_id); } for arg in (*method_decl).input_arguments.iter_mut() { + if !func_args_ptr.contains(&arg.name) { + continue; + } let new_id = id_defs.register_variable( ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), ast::StateSpace::Param, ); let old_name = arg.name; - if func_args_ptr.contains(&arg.name) { - arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); - arg.name = new_id; - } + arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); + arg.name = new_id; remapped_ids.insert(old_name, new_id); } for statement in func_body { @@ -5348,6 +5391,7 @@ impl RepackVectorDetails { ArgumentDescriptor { op: self.packed, is_dst: !self.is_extract, + is_memory_access: false, non_default_implicit_conversion: None, }, Some(( @@ -5366,6 +5410,7 @@ impl RepackVectorDetails { ArgumentDescriptor { op: id, is_dst: is_extract, + is_memory_access: false, non_default_implicit_conversion, }, Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)), @@ -5424,6 +5469,7 @@ impl> ResolvedCall { ArgumentDescriptor { op: id, is_dst: space != ast::StateSpace::Param, + is_memory_access: false, non_default_implicit_conversion: None, }, Some((&typ, space)), @@ -5435,7 +5481,7 @@ impl> ResolvedCall { ArgumentDescriptor { op: self.name, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, None, @@ -5448,6 +5494,7 @@ impl> ResolvedCall { ArgumentDescriptor { op: id, is_dst: false, + is_memory_access: false, non_default_implicit_conversion: None, }, &typ, @@ -5486,6 +5533,7 @@ impl> PtrAccess

{ ArgumentDescriptor { op: self.dst, is_dst: true, + is_memory_access: false, non_default_implicit_conversion: None, }, Some((&ptr_type, self.state_space)), @@ -5494,6 +5542,7 @@ impl> PtrAccess

{ ArgumentDescriptor { op: self.ptr_src, is_dst: false, + is_memory_access: false, non_default_implicit_conversion: None, }, Some((&ptr_type, self.state_space)), @@ -5502,7 +5551,7 @@ impl> PtrAccess

{ ArgumentDescriptor { op: self.offset_src, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::S64), @@ -5679,6 +5728,7 @@ where pub struct ArgumentDescriptor { op: Op, is_dst: bool, + is_memory_access: bool, non_default_implicit_conversion: Option< fn( (ast::StateSpace, &ast::Type), @@ -5714,7 +5764,8 @@ impl ArgumentDescriptor { ArgumentDescriptor { op: u, is_dst: self.is_dst, - non_default_implicit_conversion: None, + is_memory_access: self.is_memory_access, + non_default_implicit_conversion: self.non_default_implicit_conversion, } } } @@ -5953,6 +6004,7 @@ impl ImplicitConversion { ArgumentDescriptor { op: self.dst, is_dst: true, + is_memory_access: false, non_default_implicit_conversion: None, }, Some((&self.to_type, self.to_space)), @@ -5961,6 +6013,7 @@ impl ImplicitConversion { ArgumentDescriptor { op: self.src, is_dst: false, + is_memory_access: false, non_default_implicit_conversion: None, }, Some((&self.from_type, self.from_space)), @@ -6327,7 +6380,7 @@ impl ast::Arg1 { ArgumentDescriptor { op: self.src, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, t, @@ -6345,7 +6398,7 @@ impl ast::Arg1Bar { ArgumentDescriptor { op: self.src, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), @@ -6365,7 +6418,7 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.dst, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, t, @@ -6375,7 +6428,7 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.src, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, t, @@ -6397,7 +6450,7 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.dst, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, dst_t, @@ -6407,7 +6460,7 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.src, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, src_t, @@ -6427,6 +6480,7 @@ impl ast::Arg2Ld { ArgumentDescriptor { op: self.dst, is_dst: true, + is_memory_access: false, non_default_implicit_conversion: Some(should_convert_relaxed_dst_wrapper), }, &ast::Type::from(details.typ.clone()), @@ -6436,6 +6490,7 @@ impl ast::Arg2Ld { ArgumentDescriptor { op: self.src, is_dst: false, + is_memory_access: true, non_default_implicit_conversion: None, }, &details.typ, @@ -6455,6 +6510,7 @@ impl ast::Arg2St { ArgumentDescriptor { op: self.src1, is_dst: false, + is_memory_access: true, non_default_implicit_conversion: None, }, &details.typ, @@ -6464,6 +6520,7 @@ impl ast::Arg2St { ArgumentDescriptor { op: self.src2, is_dst: false, + is_memory_access: false, non_default_implicit_conversion: Some(should_convert_relaxed_src_wrapper), }, &details.typ.clone().into(), @@ -6483,7 +6540,7 @@ impl ast::Arg2Mov { ArgumentDescriptor { op: self.dst, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, &details.typ.clone().into(), @@ -6493,6 +6550,7 @@ impl ast::Arg2Mov { ArgumentDescriptor { op: self.src, is_dst: false, + is_memory_access: false, non_default_implicit_conversion: Some(implicit_conversion_mov), }, &details.typ.clone().into(), @@ -6518,7 +6576,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, wide_type.as_ref().unwrap_or(typ), @@ -6528,7 +6586,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src1, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, typ, @@ -6538,6 +6596,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src2, is_dst: false, + is_memory_access: false, non_default_implicit_conversion: None, }, typ, @@ -6555,7 +6614,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, t, @@ -6565,7 +6624,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src1, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, t, @@ -6575,7 +6634,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src2, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), @@ -6595,7 +6654,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), @@ -6605,6 +6664,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src1, is_dst: false, + is_memory_access: true, non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), @@ -6614,7 +6674,7 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.src2, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), @@ -6640,7 +6700,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, wide_type.as_ref().unwrap_or(t), @@ -6650,7 +6710,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, t, @@ -6660,7 +6720,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, t, @@ -6670,6 +6730,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src3, is_dst: false, + is_memory_access: false, non_default_implicit_conversion: None, }, t, @@ -6692,6 +6753,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(t.into()), @@ -6701,6 +6763,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(t.into()), @@ -6710,6 +6773,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(t.into()), @@ -6719,6 +6783,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src3, is_dst: false, + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::Pred), @@ -6743,7 +6808,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), @@ -6753,6 +6818,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, + is_memory_access: true, non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), @@ -6762,7 +6828,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), @@ -6772,7 +6838,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src3, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), @@ -6795,7 +6861,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, typ, @@ -6805,7 +6871,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src1, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, typ, @@ -6816,7 +6882,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src2, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, &u32_type, @@ -6826,7 +6892,7 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.src3, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, &u32_type, @@ -6851,7 +6917,7 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: self.dst1, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, Some(( @@ -6866,7 +6932,7 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: dst2, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, Some(( @@ -6880,7 +6946,7 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: self.src1, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, t, @@ -6890,7 +6956,7 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: self.src2, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, t, @@ -6915,7 +6981,7 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.dst, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, base_type, @@ -6925,7 +6991,7 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src1, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, base_type, @@ -6935,7 +7001,7 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src2, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, base_type, @@ -6945,7 +7011,7 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src3, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), @@ -6955,7 +7021,7 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.src4, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), @@ -6981,7 +7047,7 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.dst1, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, Some(( @@ -6996,7 +7062,7 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: dst2, is_dst: true, - + is_memory_access: false, non_default_implicit_conversion: None, }, Some(( @@ -7010,7 +7076,7 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.src1, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, t, @@ -7020,7 +7086,7 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.src2, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, t, @@ -7030,7 +7096,7 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.src3, is_dst: false, - + is_memory_access: false, non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::Pred), From 9a568e2969abbb28e614991fca69f364f9e2354e Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 25 Jun 2021 01:08:45 +0200 Subject: [PATCH 24/25] Update tests --- ptx/src/test/spirv_run/and.spvtxt | 10 +- ptx/src/test/spirv_run/atom_add.spvtxt | 17 ++- ptx/src/test/spirv_run/atom_add_float.spvtxt | 17 ++- ptx/src/test/spirv_run/atom_cas.spvtxt | 28 +++-- ptx/src/test/spirv_run/atom_inc.spvtxt | 36 +++--- ptx/src/test/spirv_run/bfe.spvtxt | 18 ++- ptx/src/test/spirv_run/bfi.spvtxt | 26 ++-- ptx/src/test/spirv_run/call.spvtxt | 4 +- ptx/src/test/spirv_run/cvt_rni.spvtxt | 18 ++- ptx/src/test/spirv_run/cvt_rzi.spvtxt | 18 ++- ptx/src/test/spirv_run/cvt_s32_f32.spvtxt | 21 ++-- ptx/src/test/spirv_run/div_approx.spvtxt | 10 +- ptx/src/test/spirv_run/extern_shared.spvtxt | 49 +++----- .../test/spirv_run/extern_shared_call.spvtxt | 103 +++++++--------- ptx/src/test/spirv_run/fma.spvtxt | 18 ++- ptx/src/test/spirv_run/ld_st_offset.spvtxt | 18 ++- ptx/src/test/spirv_run/mad_s32.spvtxt | 38 +++--- ptx/src/test/spirv_run/max.spvtxt | 10 +- ptx/src/test/spirv_run/min.spvtxt | 10 +- ptx/src/test/spirv_run/mul_ftz.spvtxt | 10 +- ptx/src/test/spirv_run/mul_non_ftz.spvtxt | 10 +- ptx/src/test/spirv_run/mul_wide.spvtxt | 22 ++-- ptx/src/test/spirv_run/or.spvtxt | 10 +- ptx/src/test/spirv_run/pred_not.spvtxt | 10 +- ptx/src/test/spirv_run/reg_local.spvtxt | 17 +-- ptx/src/test/spirv_run/rem.spvtxt | 10 +- ptx/src/test/spirv_run/selp.spvtxt | 10 +- ptx/src/test/spirv_run/selp_true.spvtxt | 10 +- ptx/src/test/spirv_run/setp.spvtxt | 10 +- ptx/src/test/spirv_run/setp_gt.spvtxt | 10 +- ptx/src/test/spirv_run/setp_leu.spvtxt | 10 +- ptx/src/test/spirv_run/setp_nan.spvtxt | 106 +++++++++------- ptx/src/test/spirv_run/setp_num.spvtxt | 114 +++++++++++------- ptx/src/test/spirv_run/shared_ptr_32.spvtxt | 11 +- .../spirv_run/shared_ptr_take_address.spvtxt | 55 ++++----- .../test/spirv_run/stateful_ld_st_ntid.spvtxt | 66 +++++----- .../stateful_ld_st_ntid_chain.spvtxt | 66 +++++----- ptx/src/test/spirv_run/vector.spvtxt | 2 +- ptx/src/test/spirv_run/xor.spvtxt | 10 +- 39 files changed, 602 insertions(+), 436 deletions(-) diff --git a/ptx/src/test/spirv_run/and.spvtxt b/ptx/src/test/spirv_run/and.spvtxt index a378602..f66639a 100644 --- a/ptx/src/test/spirv_run/and.spvtxt +++ b/ptx/src/test/spirv_run/and.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %34 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %41 = OpBitcast %_ptr_Generic_uchar %24 + %42 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %41 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %42 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/atom_add.spvtxt b/ptx/src/test/spirv_run/atom_add.spvtxt index 3966da6..b4de00a 100644 --- a/ptx/src/test/spirv_run/atom_add.spvtxt +++ b/ptx/src/test/spirv_run/atom_add.spvtxt @@ -24,6 +24,7 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %_ptr_Workgroup_uint = OpTypePointer Workgroup %uint %uint_1 = OpConstant %uint 1 %uint_0 = OpConstant %uint 0 @@ -49,9 +50,11 @@ %13 = OpLoad %uint %29 Aligned 4 OpStore %7 %13 %16 = OpLoad %ulong %5 - %26 = OpIAdd %ulong %16 %ulong_4 - %30 = OpConvertUToPtr %_ptr_Generic_uint %26 - %15 = OpLoad %uint %30 Aligned 4 + %30 = OpConvertUToPtr %_ptr_Generic_uint %16 + %51 = OpBitcast %_ptr_Generic_uchar %30 + %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4 + %26 = OpBitcast %_ptr_Generic_uint %52 + %15 = OpLoad %uint %26 Aligned 4 OpStore %8 %15 %17 = OpLoad %uint %7 %31 = OpBitcast %_ptr_Workgroup_uint %4 @@ -69,8 +72,10 @@ OpStore %34 %22 Aligned 4 %23 = OpLoad %ulong %6 %24 = OpLoad %uint %8 - %28 = OpIAdd %ulong %23 %ulong_4_0 - %35 = OpConvertUToPtr %_ptr_Generic_uint %28 - OpStore %35 %24 Aligned 4 + %35 = OpConvertUToPtr %_ptr_Generic_uint %23 + %56 = OpBitcast %_ptr_Generic_uchar %35 + %57 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4_0 + %28 = OpBitcast %_ptr_Generic_uint %57 + OpStore %28 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_add_float.spvtxt b/ptx/src/test/spirv_run/atom_add_float.spvtxt index c2292f1..7d25632 100644 --- a/ptx/src/test/spirv_run/atom_add_float.spvtxt +++ b/ptx/src/test/spirv_run/atom_add_float.spvtxt @@ -28,6 +28,7 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_4_0 = OpConstant %ulong 4 %37 = OpFunction %float None %46 %39 = OpFunctionParameter %_ptr_Workgroup_float @@ -54,9 +55,11 @@ %13 = OpLoad %float %29 Aligned 4 OpStore %7 %13 %16 = OpLoad %ulong %5 - %26 = OpIAdd %ulong %16 %ulong_4 - %30 = OpConvertUToPtr %_ptr_Generic_float %26 - %15 = OpLoad %float %30 Aligned 4 + %30 = OpConvertUToPtr %_ptr_Generic_float %16 + %58 = OpBitcast %_ptr_Generic_uchar %30 + %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_4 + %26 = OpBitcast %_ptr_Generic_float %59 + %15 = OpLoad %float %26 Aligned 4 OpStore %8 %15 %17 = OpLoad %float %7 %31 = OpBitcast %_ptr_Workgroup_float %4 @@ -74,8 +77,10 @@ OpStore %34 %22 Aligned 4 %23 = OpLoad %ulong %6 %24 = OpLoad %float %8 - %28 = OpIAdd %ulong %23 %ulong_4_0 - %35 = OpConvertUToPtr %_ptr_Generic_float %28 - OpStore %35 %24 Aligned 4 + %35 = OpConvertUToPtr %_ptr_Generic_float %23 + %60 = OpBitcast %_ptr_Generic_uchar %35 + %61 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %60 %ulong_4_0 + %28 = OpBitcast %_ptr_Generic_float %61 + OpStore %28 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_cas.spvtxt b/ptx/src/test/spirv_run/atom_cas.spvtxt index e1feb0a..7c2f4fa 100644 --- a/ptx/src/test/spirv_run/atom_cas.spvtxt +++ b/ptx/src/test/spirv_run/atom_cas.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %uint_100 = OpConstant %uint 100 %uint_1 = OpConstant %uint 1 %uint_0 = OpConstant %uint 0 @@ -45,16 +47,20 @@ OpStore %6 %12 %15 = OpLoad %ulong %4 %16 = OpLoad %uint %6 - %24 = OpIAdd %ulong %15 %ulong_4 - %32 = OpConvertUToPtr %_ptr_Generic_uint %24 + %31 = OpConvertUToPtr %_ptr_Generic_uint %15 + %49 = OpBitcast %_ptr_Generic_uchar %31 + %50 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %49 %ulong_4 + %24 = OpBitcast %_ptr_Generic_uint %50 %33 = OpCopyObject %uint %16 - %31 = OpAtomicCompareExchange %uint %32 %uint_1 %uint_0 %uint_0 %uint_100 %33 - %14 = OpCopyObject %uint %31 + %32 = OpAtomicCompareExchange %uint %24 %uint_1 %uint_0 %uint_0 %uint_100 %33 + %14 = OpCopyObject %uint %32 OpStore %6 %14 %18 = OpLoad %ulong %4 - %27 = OpIAdd %ulong %18 %ulong_4_0 - %34 = OpConvertUToPtr %_ptr_Generic_uint %27 - %17 = OpLoad %uint %34 Aligned 4 + %34 = OpConvertUToPtr %_ptr_Generic_uint %18 + %53 = OpBitcast %_ptr_Generic_uchar %34 + %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_4_0 + %27 = OpBitcast %_ptr_Generic_uint %54 + %17 = OpLoad %uint %27 Aligned 4 OpStore %7 %17 %19 = OpLoad %ulong %5 %20 = OpLoad %uint %6 @@ -62,8 +68,10 @@ OpStore %35 %20 Aligned 4 %21 = OpLoad %ulong %5 %22 = OpLoad %uint %7 - %29 = OpIAdd %ulong %21 %ulong_4_1 - %36 = OpConvertUToPtr %_ptr_Generic_uint %29 - OpStore %36 %22 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_uint %21 + %55 = OpBitcast %_ptr_Generic_uchar %36 + %56 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %55 %ulong_4_1 + %29 = OpBitcast %_ptr_Generic_uint %56 + OpStore %29 %22 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_inc.spvtxt b/ptx/src/test/spirv_run/atom_inc.spvtxt index 11b4243..4855cd4 100644 --- a/ptx/src/test/spirv_run/atom_inc.spvtxt +++ b/ptx/src/test/spirv_run/atom_inc.spvtxt @@ -10,14 +10,14 @@ %47 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "atom_inc" - OpDecorate %42 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_global_inc" Import OpDecorate %38 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_generic_inc" Import + OpDecorate %42 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_global_inc" Import %void = OpTypeVoid %uint = OpTypeInt 32 0 -%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint - %51 = OpTypeFunction %uint %_ptr_CrossWorkgroup_uint %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %53 = OpTypeFunction %uint %_ptr_Generic_uint %uint + %51 = OpTypeFunction %uint %_ptr_Generic_uint %uint +%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint + %53 = OpTypeFunction %uint %_ptr_CrossWorkgroup_uint %uint %ulong = OpTypeInt 64 0 %55 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong @@ -25,15 +25,17 @@ %uint_101 = OpConstant %uint 101 %uint_101_0 = OpConstant %uint 101 %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 - %42 = OpFunction %uint None %51 - %44 = OpFunctionParameter %_ptr_CrossWorkgroup_uint - %45 = OpFunctionParameter %uint - OpFunctionEnd - %38 = OpFunction %uint None %53 + %38 = OpFunction %uint None %51 %40 = OpFunctionParameter %_ptr_Generic_uint %41 = OpFunctionParameter %uint OpFunctionEnd + %42 = OpFunction %uint None %53 + %44 = OpFunctionParameter %_ptr_CrossWorkgroup_uint + %45 = OpFunctionParameter %uint + OpFunctionEnd %1 = OpFunction %void None %55 %9 = OpFunctionParameter %ulong %10 = OpFunctionParameter %ulong @@ -69,13 +71,17 @@ OpStore %34 %20 Aligned 4 %21 = OpLoad %ulong %5 %22 = OpLoad %uint %7 - %28 = OpIAdd %ulong %21 %ulong_4 - %35 = OpConvertUToPtr %_ptr_Generic_uint %28 - OpStore %35 %22 Aligned 4 + %35 = OpConvertUToPtr %_ptr_Generic_uint %21 + %60 = OpBitcast %_ptr_Generic_uchar %35 + %61 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %60 %ulong_4 + %28 = OpBitcast %_ptr_Generic_uint %61 + OpStore %28 %22 Aligned 4 %23 = OpLoad %ulong %5 %24 = OpLoad %uint %8 - %30 = OpIAdd %ulong %23 %ulong_8 - %36 = OpConvertUToPtr %_ptr_Generic_uint %30 - OpStore %36 %24 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_uint %23 + %62 = OpBitcast %_ptr_Generic_uchar %36 + %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_8 + %30 = OpBitcast %_ptr_Generic_uint %63 + OpStore %30 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/bfe.spvtxt b/ptx/src/test/spirv_run/bfe.spvtxt index 535ede9..0001808 100644 --- a/ptx/src/test/spirv_run/bfe.spvtxt +++ b/ptx/src/test/spirv_run/bfe.spvtxt @@ -20,6 +20,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %34 = OpFunction %uint None %43 %36 = OpFunctionParameter %uint @@ -48,14 +50,18 @@ %13 = OpLoad %uint %29 Aligned 4 OpStore %6 %13 %16 = OpLoad %ulong %4 - %26 = OpIAdd %ulong %16 %ulong_4 - %30 = OpConvertUToPtr %_ptr_Generic_uint %26 - %15 = OpLoad %uint %30 Aligned 4 + %30 = OpConvertUToPtr %_ptr_Generic_uint %16 + %51 = OpBitcast %_ptr_Generic_uchar %30 + %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4 + %26 = OpBitcast %_ptr_Generic_uint %52 + %15 = OpLoad %uint %26 Aligned 4 OpStore %7 %15 %18 = OpLoad %ulong %4 - %28 = OpIAdd %ulong %18 %ulong_8 - %31 = OpConvertUToPtr %_ptr_Generic_uint %28 - %17 = OpLoad %uint %31 Aligned 4 + %31 = OpConvertUToPtr %_ptr_Generic_uint %18 + %53 = OpBitcast %_ptr_Generic_uchar %31 + %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_8 + %28 = OpBitcast %_ptr_Generic_uint %54 + %17 = OpLoad %uint %28 Aligned 4 OpStore %8 %17 %20 = OpLoad %uint %6 %21 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/bfi.spvtxt b/ptx/src/test/spirv_run/bfi.spvtxt index dc8f683..1979939 100644 --- a/ptx/src/test/spirv_run/bfi.spvtxt +++ b/ptx/src/test/spirv_run/bfi.spvtxt @@ -20,6 +20,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %ulong_12 = OpConstant %ulong 12 %44 = OpFunction %uint None %54 @@ -51,19 +53,25 @@ %14 = OpLoad %uint %35 Aligned 4 OpStore %6 %14 %17 = OpLoad %ulong %4 - %30 = OpIAdd %ulong %17 %ulong_4 - %36 = OpConvertUToPtr %_ptr_Generic_uint %30 - %16 = OpLoad %uint %36 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_uint %17 + %62 = OpBitcast %_ptr_Generic_uchar %36 + %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_4 + %30 = OpBitcast %_ptr_Generic_uint %63 + %16 = OpLoad %uint %30 Aligned 4 OpStore %7 %16 %19 = OpLoad %ulong %4 - %32 = OpIAdd %ulong %19 %ulong_8 - %37 = OpConvertUToPtr %_ptr_Generic_uint %32 - %18 = OpLoad %uint %37 Aligned 4 + %37 = OpConvertUToPtr %_ptr_Generic_uint %19 + %64 = OpBitcast %_ptr_Generic_uchar %37 + %65 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %64 %ulong_8 + %32 = OpBitcast %_ptr_Generic_uint %65 + %18 = OpLoad %uint %32 Aligned 4 OpStore %8 %18 %21 = OpLoad %ulong %4 - %34 = OpIAdd %ulong %21 %ulong_12 - %38 = OpConvertUToPtr %_ptr_Generic_uint %34 - %20 = OpLoad %uint %38 Aligned 4 + %38 = OpConvertUToPtr %_ptr_Generic_uint %21 + %66 = OpBitcast %_ptr_Generic_uchar %38 + %67 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %66 %ulong_12 + %34 = OpBitcast %_ptr_Generic_uint %67 + %20 = OpLoad %uint %34 Aligned 4 OpStore %9 %20 %23 = OpLoad %uint %6 %24 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/call.spvtxt b/ptx/src/test/spirv_run/call.spvtxt index 5473234..6929b1e 100644 --- a/ptx/src/test/spirv_run/call.spvtxt +++ b/ptx/src/test/spirv_run/call.spvtxt @@ -42,7 +42,7 @@ %23 = OpBitcast %_ptr_Function_ulong %10 %24 = OpCopyObject %ulong %18 OpStore %23 %24 Aligned 8 - %43 = OpFunctionCall %void %1 %11 %10 + %43 = OpFunctionCall %void %1 %10 %11 %19 = OpLoad %ulong %11 Aligned 8 OpStore %9 %19 %20 = OpLoad %ulong %8 @@ -52,8 +52,8 @@ OpReturn OpFunctionEnd %1 = OpFunction %void None %44 - %27 = OpFunctionParameter %_ptr_Function_ulong %28 = OpFunctionParameter %_ptr_Function_ulong + %27 = OpFunctionParameter %_ptr_Function_ulong %35 = OpLabel %29 = OpVariable %_ptr_Function_ulong Function %30 = OpLoad %ulong %28 Aligned 8 diff --git a/ptx/src/test/spirv_run/cvt_rni.spvtxt b/ptx/src/test/spirv_run/cvt_rni.spvtxt index 288a939..e10999c 100644 --- a/ptx/src/test/spirv_run/cvt_rni.spvtxt +++ b/ptx/src/test/spirv_run/cvt_rni.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_4_0 = OpConstant %ulong 4 %1 = OpFunction %void None %37 %8 = OpFunctionParameter %ulong @@ -40,9 +42,11 @@ %12 = OpLoad %float %28 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %25 = OpIAdd %ulong %15 %ulong_4 - %29 = OpConvertUToPtr %_ptr_Generic_float %25 - %14 = OpLoad %float %29 Aligned 4 + %29 = OpConvertUToPtr %_ptr_Generic_float %15 + %44 = OpBitcast %_ptr_Generic_uchar %29 + %45 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %44 %ulong_4 + %25 = OpBitcast %_ptr_Generic_float %45 + %14 = OpLoad %float %25 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %16 = OpExtInst %float %34 rint %17 @@ -56,8 +60,10 @@ OpStore %30 %21 Aligned 4 %22 = OpLoad %ulong %5 %23 = OpLoad %float %7 - %27 = OpIAdd %ulong %22 %ulong_4_0 - %31 = OpConvertUToPtr %_ptr_Generic_float %27 - OpStore %31 %23 Aligned 4 + %31 = OpConvertUToPtr %_ptr_Generic_float %22 + %46 = OpBitcast %_ptr_Generic_uchar %31 + %47 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %46 %ulong_4_0 + %27 = OpBitcast %_ptr_Generic_float %47 + OpStore %27 %23 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_rzi.spvtxt b/ptx/src/test/spirv_run/cvt_rzi.spvtxt index 68c12c6..7dda454 100644 --- a/ptx/src/test/spirv_run/cvt_rzi.spvtxt +++ b/ptx/src/test/spirv_run/cvt_rzi.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_4_0 = OpConstant %ulong 4 %1 = OpFunction %void None %37 %8 = OpFunctionParameter %ulong @@ -40,9 +42,11 @@ %12 = OpLoad %float %28 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %25 = OpIAdd %ulong %15 %ulong_4 - %29 = OpConvertUToPtr %_ptr_Generic_float %25 - %14 = OpLoad %float %29 Aligned 4 + %29 = OpConvertUToPtr %_ptr_Generic_float %15 + %44 = OpBitcast %_ptr_Generic_uchar %29 + %45 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %44 %ulong_4 + %25 = OpBitcast %_ptr_Generic_float %45 + %14 = OpLoad %float %25 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %16 = OpExtInst %float %34 trunc %17 @@ -56,8 +60,10 @@ OpStore %30 %21 Aligned 4 %22 = OpLoad %ulong %5 %23 = OpLoad %float %7 - %27 = OpIAdd %ulong %22 %ulong_4_0 - %31 = OpConvertUToPtr %_ptr_Generic_float %27 - OpStore %31 %23 Aligned 4 + %31 = OpConvertUToPtr %_ptr_Generic_float %22 + %46 = OpBitcast %_ptr_Generic_uchar %31 + %47 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %46 %ulong_4_0 + %27 = OpBitcast %_ptr_Generic_float %47 + OpStore %27 %23 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt b/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt index d9ae053..c1229d4 100644 --- a/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt +++ b/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt @@ -21,8 +21,11 @@ %float = OpTypeFloat 32 %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint %ulong_4_0 = OpConstant %ulong 4 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar %1 = OpFunction %void None %45 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -45,10 +48,12 @@ %12 = OpBitcast %uint %28 OpStore %6 %12 %15 = OpLoad %ulong %4 - %25 = OpIAdd %ulong %15 %ulong_4 - %31 = OpConvertUToPtr %_ptr_Generic_float %25 - %30 = OpLoad %float %31 Aligned 4 - %14 = OpBitcast %uint %30 + %30 = OpConvertUToPtr %_ptr_Generic_float %15 + %53 = OpBitcast %_ptr_Generic_uchar %30 + %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_4 + %25 = OpBitcast %_ptr_Generic_float %54 + %31 = OpLoad %float %25 Aligned 4 + %14 = OpBitcast %uint %31 OpStore %7 %14 %17 = OpLoad %uint %6 %33 = OpBitcast %float %17 @@ -67,9 +72,11 @@ OpStore %36 %37 Aligned 4 %22 = OpLoad %ulong %5 %23 = OpLoad %uint %7 - %27 = OpIAdd %ulong %22 %ulong_4_0 - %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %27 + %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %22 + %57 = OpBitcast %_ptr_CrossWorkgroup_uchar %38 + %58 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %57 %ulong_4_0 + %27 = OpBitcast %_ptr_CrossWorkgroup_uint %58 %39 = OpCopyObject %uint %23 - OpStore %38 %39 Aligned 4 + OpStore %27 %39 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/div_approx.spvtxt b/ptx/src/test/spirv_run/div_approx.spvtxt index 274f73e..858ec8d 100644 --- a/ptx/src/test/spirv_run/div_approx.spvtxt +++ b/ptx/src/test/spirv_run/div_approx.spvtxt @@ -19,6 +19,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -40,9 +42,11 @@ %12 = OpLoad %float %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_float %22 - %14 = OpLoad %float %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_float %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_float %39 + %14 = OpLoad %float %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %18 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/extern_shared.spvtxt b/ptx/src/test/spirv_run/extern_shared.spvtxt index fb2987e..13587d5 100644 --- a/ptx/src/test/spirv_run/extern_shared.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared.spvtxt @@ -7,37 +7,30 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %30 = OpExtInstImport "OpenCL.std" + %27 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %2 "extern_shared" %1 %void = OpTypeVoid %uint = OpTypeInt 32 0 %_ptr_Workgroup_uint = OpTypePointer Workgroup %uint -%_ptr_Workgroup__ptr_Workgroup_uint = OpTypePointer Workgroup %_ptr_Workgroup_uint - %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup + %1 = OpVariable %_ptr_Workgroup_uint Workgroup %ulong = OpTypeInt 64 0 %uchar = OpTypeInt 8 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %38 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar -%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar + %34 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong -%_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %38 + %2 = OpFunction %void None %34 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong - %26 = OpFunctionParameter %_ptr_Workgroup_uchar - %39 = OpLabel - %27 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %24 = OpFunctionParameter %_ptr_Workgroup_uchar + %22 = OpLabel %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function %5 = OpVariable %_ptr_Function_ulong Function %6 = OpVariable %_ptr_Function_ulong Function %7 = OpVariable %_ptr_Function_ulong Function - OpStore %27 %26 - OpBranch %24 - %24 = OpLabel OpStore %3 %8 OpStore %4 %9 %10 = OpLoad %ulong %3 Aligned 8 @@ -45,22 +38,20 @@ %11 = OpLoad %ulong %4 Aligned 8 OpStore %6 %11 %13 = OpLoad %ulong %5 - %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %13 - %12 = OpLoad %ulong %20 Aligned 8 + %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %13 + %12 = OpLoad %ulong %18 Aligned 8 OpStore %7 %12 - %28 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %27 - %14 = OpLoad %_ptr_Workgroup_uint %28 - %15 = OpLoad %ulong %7 - %21 = OpBitcast %_ptr_Workgroup_ulong %14 - OpStore %21 %15 Aligned 8 - %29 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %27 - %17 = OpLoad %_ptr_Workgroup_uint %29 - %22 = OpBitcast %_ptr_Workgroup_ulong %17 - %16 = OpLoad %ulong %22 Aligned 8 - OpStore %7 %16 - %18 = OpLoad %ulong %6 - %19 = OpLoad %ulong %7 - %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %18 - OpStore %23 %19 Aligned 8 + %14 = OpLoad %ulong %7 + %25 = OpBitcast %_ptr_Workgroup_uint %24 + %19 = OpBitcast %_ptr_Workgroup_ulong %25 + OpStore %19 %14 Aligned 8 + %26 = OpBitcast %_ptr_Workgroup_uint %24 + %20 = OpBitcast %_ptr_Workgroup_ulong %26 + %15 = OpLoad %ulong %20 Aligned 8 + OpStore %7 %15 + %16 = OpLoad %ulong %6 + %17 = OpLoad %ulong %7 + %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 + OpStore %21 %17 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt index 7043172..5af7168 100644 --- a/ptx/src/test/spirv_run/extern_shared_call.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared_call.spvtxt @@ -7,87 +7,72 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %46 = OpExtInstImport "OpenCL.std" + %40 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %14 "extern_shared_call" %1 + OpEntryPoint Kernel %12 "extern_shared_call" %1 OpDecorate %1 Alignment 4 %void = OpTypeVoid %uint = OpTypeInt 32 0 %_ptr_Workgroup_uint = OpTypePointer Workgroup %uint -%_ptr_Workgroup__ptr_Workgroup_uint = OpTypePointer Workgroup %_ptr_Workgroup_uint - %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup + %1 = OpVariable %_ptr_Workgroup_uint Workgroup %uchar = OpTypeInt 8 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %53 = OpTypeFunction %void %_ptr_Workgroup_uchar -%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar + %46 = OpTypeFunction %void %_ptr_Workgroup_uchar %ulong = OpTypeInt 64 0 %_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong %ulong_2 = OpConstant %ulong 2 - %60 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar + %50 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %2 = OpFunction %void None %53 - %38 = OpFunctionParameter %_ptr_Workgroup_uchar - %54 = OpLabel - %39 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %2 = OpFunction %void None %46 + %34 = OpFunctionParameter %_ptr_Workgroup_uchar + %11 = OpLabel %3 = OpVariable %_ptr_Function_ulong Function - OpStore %39 %38 - OpBranch %13 - %13 = OpLabel - %40 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39 - %5 = OpLoad %_ptr_Workgroup_uint %40 - %11 = OpBitcast %_ptr_Workgroup_ulong %5 - %4 = OpLoad %ulong %11 Aligned 8 + %35 = OpBitcast %_ptr_Workgroup_uint %34 + %9 = OpBitcast %_ptr_Workgroup_ulong %35 + %4 = OpLoad %ulong %9 Aligned 8 OpStore %3 %4 + %6 = OpLoad %ulong %3 + %5 = OpIAdd %ulong %6 %ulong_2 + OpStore %3 %5 %7 = OpLoad %ulong %3 - %6 = OpIAdd %ulong %7 %ulong_2 - OpStore %3 %6 - %41 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39 - %8 = OpLoad %_ptr_Workgroup_uint %41 - %9 = OpLoad %ulong %3 - %12 = OpBitcast %_ptr_Workgroup_ulong %8 - OpStore %12 %9 Aligned 8 + %36 = OpBitcast %_ptr_Workgroup_uint %34 + %10 = OpBitcast %_ptr_Workgroup_ulong %36 + OpStore %10 %7 Aligned 8 OpReturn OpFunctionEnd - %14 = OpFunction %void None %60 - %20 = OpFunctionParameter %ulong - %21 = OpFunctionParameter %ulong - %42 = OpFunctionParameter %_ptr_Workgroup_uchar - %61 = OpLabel - %43 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %12 = OpFunction %void None %50 + %18 = OpFunctionParameter %ulong + %19 = OpFunctionParameter %ulong + %37 = OpFunctionParameter %_ptr_Workgroup_uchar + %32 = OpLabel + %13 = OpVariable %_ptr_Function_ulong Function + %14 = OpVariable %_ptr_Function_ulong Function %15 = OpVariable %_ptr_Function_ulong Function %16 = OpVariable %_ptr_Function_ulong Function %17 = OpVariable %_ptr_Function_ulong Function - %18 = OpVariable %_ptr_Function_ulong Function - %19 = OpVariable %_ptr_Function_ulong Function - OpStore %43 %42 - OpBranch %36 - %36 = OpLabel + OpStore %13 %18 + OpStore %14 %19 + %20 = OpLoad %ulong %13 Aligned 8 OpStore %15 %20 + %21 = OpLoad %ulong %14 Aligned 8 OpStore %16 %21 - %22 = OpLoad %ulong %15 Aligned 8 + %23 = OpLoad %ulong %15 + %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %23 + %22 = OpLoad %ulong %28 Aligned 8 OpStore %17 %22 - %23 = OpLoad %ulong %16 Aligned 8 - OpStore %18 %23 - %25 = OpLoad %ulong %17 - %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %25 - %24 = OpLoad %ulong %32 Aligned 8 - OpStore %19 %24 - %44 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43 - %26 = OpLoad %_ptr_Workgroup_uint %44 - %27 = OpLoad %ulong %19 - %33 = OpBitcast %_ptr_Workgroup_ulong %26 - OpStore %33 %27 Aligned 8 - %63 = OpFunctionCall %void %2 %42 - %45 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43 - %29 = OpLoad %_ptr_Workgroup_uint %45 - %34 = OpBitcast %_ptr_Workgroup_ulong %29 - %28 = OpLoad %ulong %34 Aligned 8 - OpStore %19 %28 - %30 = OpLoad %ulong %18 - %31 = OpLoad %ulong %19 - %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %30 - OpStore %35 %31 Aligned 8 + %24 = OpLoad %ulong %17 + %38 = OpBitcast %_ptr_Workgroup_uint %37 + %29 = OpBitcast %_ptr_Workgroup_ulong %38 + OpStore %29 %24 Aligned 8 + %52 = OpFunctionCall %void %2 %37 + %39 = OpBitcast %_ptr_Workgroup_uint %37 + %30 = OpBitcast %_ptr_Workgroup_ulong %39 + %25 = OpLoad %ulong %30 Aligned 8 + OpStore %17 %25 + %26 = OpLoad %ulong %16 + %27 = OpLoad %ulong %17 + %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26 + OpStore %31 %27 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/fma.spvtxt b/ptx/src/test/spirv_run/fma.spvtxt index 300a328..8cc0e16 100644 --- a/ptx/src/test/spirv_run/fma.spvtxt +++ b/ptx/src/test/spirv_run/fma.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %1 = OpFunction %void None %38 %9 = OpFunctionParameter %ulong @@ -41,14 +43,18 @@ %13 = OpLoad %float %29 Aligned 4 OpStore %6 %13 %16 = OpLoad %ulong %4 - %26 = OpIAdd %ulong %16 %ulong_4 - %30 = OpConvertUToPtr %_ptr_Generic_float %26 - %15 = OpLoad %float %30 Aligned 4 + %30 = OpConvertUToPtr %_ptr_Generic_float %16 + %45 = OpBitcast %_ptr_Generic_uchar %30 + %46 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %45 %ulong_4 + %26 = OpBitcast %_ptr_Generic_float %46 + %15 = OpLoad %float %26 Aligned 4 OpStore %7 %15 %18 = OpLoad %ulong %4 - %28 = OpIAdd %ulong %18 %ulong_8 - %31 = OpConvertUToPtr %_ptr_Generic_float %28 - %17 = OpLoad %float %31 Aligned 4 + %31 = OpConvertUToPtr %_ptr_Generic_float %18 + %47 = OpBitcast %_ptr_Generic_uchar %31 + %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_8 + %28 = OpBitcast %_ptr_Generic_float %48 + %17 = OpLoad %float %28 Aligned 4 OpStore %8 %17 %20 = OpLoad %float %6 %21 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/ld_st_offset.spvtxt b/ptx/src/test/spirv_run/ld_st_offset.spvtxt index 5e314a0..ea97222 100644 --- a/ptx/src/test/spirv_run/ld_st_offset.spvtxt +++ b/ptx/src/test/spirv_run/ld_st_offset.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_4_0 = OpConstant %ulong 4 %1 = OpFunction %void None %33 %8 = OpFunctionParameter %ulong @@ -40,9 +42,11 @@ %12 = OpLoad %uint %24 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %21 = OpIAdd %ulong %15 %ulong_4 - %25 = OpConvertUToPtr %_ptr_Generic_uint %21 - %14 = OpLoad %uint %25 Aligned 4 + %25 = OpConvertUToPtr %_ptr_Generic_uint %15 + %40 = OpBitcast %_ptr_Generic_uchar %25 + %41 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %40 %ulong_4 + %21 = OpBitcast %_ptr_Generic_uint %41 + %14 = OpLoad %uint %21 Aligned 4 OpStore %7 %14 %16 = OpLoad %ulong %5 %17 = OpLoad %uint %7 @@ -50,8 +54,10 @@ OpStore %26 %17 Aligned 4 %18 = OpLoad %ulong %5 %19 = OpLoad %uint %6 - %23 = OpIAdd %ulong %18 %ulong_4_0 - %27 = OpConvertUToPtr %_ptr_Generic_uint %23 - OpStore %27 %19 Aligned 4 + %27 = OpConvertUToPtr %_ptr_Generic_uint %18 + %42 = OpBitcast %_ptr_Generic_uchar %27 + %43 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %42 %ulong_4_0 + %23 = OpBitcast %_ptr_Generic_uint %43 + OpStore %23 %19 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mad_s32.spvtxt b/ptx/src/test/spirv_run/mad_s32.spvtxt index bb44af0..0ee3ca7 100644 --- a/ptx/src/test/spirv_run/mad_s32.spvtxt +++ b/ptx/src/test/spirv_run/mad_s32.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %ulong_4_0 = OpConstant %ulong 4 %ulong_8_0 = OpConstant %ulong 8 @@ -44,20 +46,24 @@ %14 = OpLoad %uint %38 Aligned 4 OpStore %7 %14 %17 = OpLoad %ulong %4 - %31 = OpIAdd %ulong %17 %ulong_4 - %39 = OpConvertUToPtr %_ptr_Generic_uint %31 - %16 = OpLoad %uint %39 Aligned 4 + %39 = OpConvertUToPtr %_ptr_Generic_uint %17 + %56 = OpBitcast %_ptr_Generic_uchar %39 + %57 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4 + %31 = OpBitcast %_ptr_Generic_uint %57 + %16 = OpLoad %uint %31 Aligned 4 OpStore %8 %16 %19 = OpLoad %ulong %4 - %33 = OpIAdd %ulong %19 %ulong_8 - %40 = OpConvertUToPtr %_ptr_Generic_uint %33 - %18 = OpLoad %uint %40 Aligned 4 + %40 = OpConvertUToPtr %_ptr_Generic_uint %19 + %58 = OpBitcast %_ptr_Generic_uchar %40 + %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_8 + %33 = OpBitcast %_ptr_Generic_uint %59 + %18 = OpLoad %uint %33 Aligned 4 OpStore %9 %18 %21 = OpLoad %uint %7 %22 = OpLoad %uint %8 %23 = OpLoad %uint %9 - %54 = OpIMul %uint %21 %22 - %20 = OpIAdd %uint %23 %54 + %60 = OpIMul %uint %21 %22 + %20 = OpIAdd %uint %23 %60 OpStore %6 %20 %24 = OpLoad %ulong %5 %25 = OpLoad %uint %6 @@ -65,13 +71,17 @@ OpStore %41 %25 Aligned 4 %26 = OpLoad %ulong %5 %27 = OpLoad %uint %6 - %35 = OpIAdd %ulong %26 %ulong_4_0 - %42 = OpConvertUToPtr %_ptr_Generic_uint %35 - OpStore %42 %27 Aligned 4 + %42 = OpConvertUToPtr %_ptr_Generic_uint %26 + %61 = OpBitcast %_ptr_Generic_uchar %42 + %62 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %61 %ulong_4_0 + %35 = OpBitcast %_ptr_Generic_uint %62 + OpStore %35 %27 Aligned 4 %28 = OpLoad %ulong %5 %29 = OpLoad %uint %6 - %37 = OpIAdd %ulong %28 %ulong_8_0 - %43 = OpConvertUToPtr %_ptr_Generic_uint %37 - OpStore %43 %29 Aligned 4 + %43 = OpConvertUToPtr %_ptr_Generic_uint %28 + %63 = OpBitcast %_ptr_Generic_uchar %43 + %64 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %63 %ulong_8_0 + %37 = OpBitcast %_ptr_Generic_uint %64 + OpStore %37 %29 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/max.spvtxt b/ptx/src/test/spirv_run/max.spvtxt index d3ffa2f..86b732a 100644 --- a/ptx/src/test/spirv_run/max.spvtxt +++ b/ptx/src/test/spirv_run/max.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %39 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/min.spvtxt b/ptx/src/test/spirv_run/min.spvtxt index de2e35e..a187376 100644 --- a/ptx/src/test/spirv_run/min.spvtxt +++ b/ptx/src/test/spirv_run/min.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %39 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/mul_ftz.spvtxt b/ptx/src/test/spirv_run/mul_ftz.spvtxt index ed268fb..e7a4a56 100644 --- a/ptx/src/test/spirv_run/mul_ftz.spvtxt +++ b/ptx/src/test/spirv_run/mul_ftz.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %float %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_float %22 - %14 = OpLoad %float %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_float %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_float %39 + %14 = OpLoad %float %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %18 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/mul_non_ftz.spvtxt b/ptx/src/test/spirv_run/mul_non_ftz.spvtxt index 436aca1..5326baa 100644 --- a/ptx/src/test/spirv_run/mul_non_ftz.spvtxt +++ b/ptx/src/test/spirv_run/mul_non_ftz.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %float %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_float %22 - %14 = OpLoad %float %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_float %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_float %39 + %14 = OpLoad %float %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %18 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/mul_wide.spvtxt b/ptx/src/test/spirv_run/mul_wide.spvtxt index 7ac81cf..e96a964 100644 --- a/ptx/src/test/spirv_run/mul_wide.spvtxt +++ b/ptx/src/test/spirv_run/mul_wide.spvtxt @@ -18,7 +18,9 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint %ulong_4 = OpConstant %ulong 4 - %_struct_38 = OpTypeStruct %uint %uint + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %_struct_42 = OpTypeStruct %uint %uint %v2uint = OpTypeVector %uint 2 %_ptr_Generic_ulong = OpTypePointer Generic %ulong %1 = OpFunction %void None %33 @@ -43,17 +45,19 @@ %13 = OpLoad %uint %24 Aligned 4 OpStore %6 %13 %16 = OpLoad %ulong %4 - %23 = OpIAdd %ulong %16 %ulong_4 - %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %23 - %15 = OpLoad %uint %25 Aligned 4 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %16 + %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %25 + %41 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %40 %ulong_4 + %23 = OpBitcast %_ptr_CrossWorkgroup_uint %41 + %15 = OpLoad %uint %23 Aligned 4 OpStore %7 %15 %18 = OpLoad %uint %6 %19 = OpLoad %uint %7 - %39 = OpSMulExtended %_struct_38 %18 %19 - %40 = OpCompositeExtract %uint %39 0 - %41 = OpCompositeExtract %uint %39 1 - %43 = OpCompositeConstruct %v2uint %40 %41 - %17 = OpBitcast %ulong %43 + %43 = OpSMulExtended %_struct_42 %18 %19 + %44 = OpCompositeExtract %uint %43 0 + %45 = OpCompositeExtract %uint %43 1 + %47 = OpCompositeConstruct %v2uint %44 %45 + %17 = OpBitcast %ulong %47 OpStore %8 %17 %20 = OpLoad %ulong %5 %21 = OpLoad %ulong %8 diff --git a/ptx/src/test/spirv_run/or.spvtxt b/ptx/src/test/spirv_run/or.spvtxt index fef3f40..82db00c 100644 --- a/ptx/src/test/spirv_run/or.spvtxt +++ b/ptx/src/test/spirv_run/or.spvtxt @@ -16,6 +16,8 @@ %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_ulong = OpTypePointer Generic %ulong %ulong_8 = OpConstant %ulong 8 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %34 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -37,9 +39,11 @@ %12 = OpLoad %ulong %23 Aligned 8 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_8 - %24 = OpConvertUToPtr %_ptr_Generic_ulong %22 - %14 = OpLoad %ulong %24 Aligned 8 + %24 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %39 = OpBitcast %_ptr_Generic_uchar %24 + %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_8 + %22 = OpBitcast %_ptr_Generic_ulong %40 + %14 = OpLoad %ulong %22 Aligned 8 OpStore %7 %14 %17 = OpLoad %ulong %6 %18 = OpLoad %ulong %7 diff --git a/ptx/src/test/spirv_run/pred_not.spvtxt b/ptx/src/test/spirv_run/pred_not.spvtxt index 18fde05..644731b 100644 --- a/ptx/src/test/spirv_run/pred_not.spvtxt +++ b/ptx/src/test/spirv_run/pred_not.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_ulong = OpTypePointer Generic %ulong %ulong_8 = OpConstant %ulong 8 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %true = OpConstantTrue %bool %false = OpConstantFalse %bool %ulong_1 = OpConstant %ulong 1 @@ -45,9 +47,11 @@ %18 = OpLoad %ulong %37 Aligned 8 OpStore %6 %18 %21 = OpLoad %ulong %4 - %34 = OpIAdd %ulong %21 %ulong_8 - %38 = OpConvertUToPtr %_ptr_Generic_ulong %34 - %20 = OpLoad %ulong %38 Aligned 8 + %38 = OpConvertUToPtr %_ptr_Generic_ulong %21 + %52 = OpBitcast %_ptr_Generic_uchar %38 + %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_8 + %34 = OpBitcast %_ptr_Generic_ulong %53 + %20 = OpLoad %ulong %34 Aligned 8 OpStore %7 %20 %23 = OpLoad %ulong %6 %24 = OpLoad %ulong %7 diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt index 7bb5bd9..a0b957a 100644 --- a/ptx/src/test/spirv_run/reg_local.spvtxt +++ b/ptx/src/test/spirv_run/reg_local.spvtxt @@ -26,6 +26,7 @@ %ulong_0 = OpConstant %ulong 0 %_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_0_0 = OpConstant %ulong 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar %1 = OpFunction %void None %37 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -48,10 +49,10 @@ %12 = OpCopyObject %ulong %24 OpStore %7 %12 %14 = OpLoad %ulong %7 - %26 = OpCopyObject %ulong %14 - %19 = OpIAdd %ulong %26 %ulong_1 - %27 = OpBitcast %_ptr_Generic_ulong %4 - OpStore %27 %19 Aligned 8 + %19 = OpIAdd %ulong %14 %ulong_1 + %26 = OpBitcast %_ptr_Generic_ulong %4 + %27 = OpCopyObject %ulong %19 + OpStore %26 %27 Aligned 8 %28 = OpBitcast %_ptr_Generic_ulong %4 %47 = OpBitcast %_ptr_Generic_uchar %28 %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_0 @@ -61,9 +62,11 @@ OpStore %7 %15 %16 = OpLoad %ulong %6 %17 = OpLoad %ulong %7 - %23 = OpIAdd %ulong %16 %ulong_0_0 - %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %23 + %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 + %50 = OpBitcast %_ptr_CrossWorkgroup_uchar %30 + %51 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %50 %ulong_0_0 + %23 = OpBitcast %_ptr_CrossWorkgroup_ulong %51 %31 = OpCopyObject %ulong %17 - OpStore %30 %31 Aligned 8 + OpStore %23 %31 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/rem.spvtxt b/ptx/src/test/spirv_run/rem.spvtxt index ce1d3e6..2184523 100644 --- a/ptx/src/test/spirv_run/rem.spvtxt +++ b/ptx/src/test/spirv_run/rem.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %39 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/selp.spvtxt b/ptx/src/test/spirv_run/selp.spvtxt index 9798758..40c0bce 100644 --- a/ptx/src/test/spirv_run/selp.spvtxt +++ b/ptx/src/test/spirv_run/selp.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_ushort = OpTypePointer Function %ushort %_ptr_Generic_ushort = OpTypePointer Generic %ushort %ulong_2 = OpConstant %ulong 2 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %bool = OpTypeBool %false = OpConstantFalse %bool %1 = OpFunction %void None %32 @@ -41,9 +43,11 @@ %12 = OpLoad %ushort %24 Aligned 2 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_2 - %25 = OpConvertUToPtr %_ptr_Generic_ushort %22 - %14 = OpLoad %ushort %25 Aligned 2 + %25 = OpConvertUToPtr %_ptr_Generic_ushort %15 + %39 = OpBitcast %_ptr_Generic_uchar %25 + %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_2 + %22 = OpBitcast %_ptr_Generic_ushort %40 + %14 = OpLoad %ushort %22 Aligned 2 OpStore %7 %14 %17 = OpLoad %ushort %6 %18 = OpLoad %ushort %7 diff --git a/ptx/src/test/spirv_run/selp_true.spvtxt b/ptx/src/test/spirv_run/selp_true.spvtxt index f7038e0..81b3b5f 100644 --- a/ptx/src/test/spirv_run/selp_true.spvtxt +++ b/ptx/src/test/spirv_run/selp_true.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_ushort = OpTypePointer Function %ushort %_ptr_Generic_ushort = OpTypePointer Generic %ushort %ulong_2 = OpConstant %ulong 2 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %bool = OpTypeBool %true = OpConstantTrue %bool %1 = OpFunction %void None %32 @@ -41,9 +43,11 @@ %12 = OpLoad %ushort %24 Aligned 2 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_2 - %25 = OpConvertUToPtr %_ptr_Generic_ushort %22 - %14 = OpLoad %ushort %25 Aligned 2 + %25 = OpConvertUToPtr %_ptr_Generic_ushort %15 + %39 = OpBitcast %_ptr_Generic_uchar %25 + %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_2 + %22 = OpBitcast %_ptr_Generic_ushort %40 + %14 = OpLoad %ushort %22 Aligned 2 OpStore %7 %14 %17 = OpLoad %ushort %6 %18 = OpLoad %ushort %7 diff --git a/ptx/src/test/spirv_run/setp.spvtxt b/ptx/src/test/spirv_run/setp.spvtxt index c3129e3..5868881 100644 --- a/ptx/src/test/spirv_run/setp.spvtxt +++ b/ptx/src/test/spirv_run/setp.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_ulong = OpTypePointer Generic %ulong %ulong_8 = OpConstant %ulong 8 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_1 = OpConstant %ulong 1 %ulong_2 = OpConstant %ulong 2 %1 = OpFunction %void None %43 @@ -43,9 +45,11 @@ %18 = OpLoad %ulong %35 Aligned 8 OpStore %6 %18 %21 = OpLoad %ulong %4 - %32 = OpIAdd %ulong %21 %ulong_8 - %36 = OpConvertUToPtr %_ptr_Generic_ulong %32 - %20 = OpLoad %ulong %36 Aligned 8 + %36 = OpConvertUToPtr %_ptr_Generic_ulong %21 + %50 = OpBitcast %_ptr_Generic_uchar %36 + %51 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %50 %ulong_8 + %32 = OpBitcast %_ptr_Generic_ulong %51 + %20 = OpLoad %ulong %32 Aligned 8 OpStore %7 %20 %23 = OpLoad %ulong %6 %24 = OpLoad %ulong %7 diff --git a/ptx/src/test/spirv_run/setp_gt.spvtxt b/ptx/src/test/spirv_run/setp_gt.spvtxt index 77f6546..e9783f5 100644 --- a/ptx/src/test/spirv_run/setp_gt.spvtxt +++ b/ptx/src/test/spirv_run/setp_gt.spvtxt @@ -20,6 +20,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %43 %14 = OpFunctionParameter %ulong %15 = OpFunctionParameter %ulong @@ -43,9 +45,11 @@ %18 = OpLoad %float %35 Aligned 4 OpStore %6 %18 %21 = OpLoad %ulong %4 - %34 = OpIAdd %ulong %21 %ulong_4 - %36 = OpConvertUToPtr %_ptr_Generic_float %34 - %20 = OpLoad %float %36 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_float %21 + %52 = OpBitcast %_ptr_Generic_uchar %36 + %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_4 + %34 = OpBitcast %_ptr_Generic_float %53 + %20 = OpLoad %float %34 Aligned 4 OpStore %7 %20 %23 = OpLoad %float %6 %24 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/setp_leu.spvtxt b/ptx/src/test/spirv_run/setp_leu.spvtxt index f80880a..1d2d781 100644 --- a/ptx/src/test/spirv_run/setp_leu.spvtxt +++ b/ptx/src/test/spirv_run/setp_leu.spvtxt @@ -20,6 +20,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %43 %14 = OpFunctionParameter %ulong %15 = OpFunctionParameter %ulong @@ -43,9 +45,11 @@ %18 = OpLoad %float %35 Aligned 4 OpStore %6 %18 %21 = OpLoad %ulong %4 - %34 = OpIAdd %ulong %21 %ulong_4 - %36 = OpConvertUToPtr %_ptr_Generic_float %34 - %20 = OpLoad %float %36 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_float %21 + %52 = OpBitcast %_ptr_Generic_uchar %36 + %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_4 + %34 = OpBitcast %_ptr_Generic_float %53 + %20 = OpLoad %float %34 Aligned 4 OpStore %7 %20 %23 = OpLoad %float %6 %24 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/setp_nan.spvtxt b/ptx/src/test/spirv_run/setp_nan.spvtxt index 4a9fe11..2ee333a 100644 --- a/ptx/src/test/spirv_run/setp_nan.spvtxt +++ b/ptx/src/test/spirv_run/setp_nan.spvtxt @@ -22,6 +22,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %ulong_12 = OpConstant %ulong 12 %ulong_16 = OpConstant %ulong 16 @@ -69,45 +71,59 @@ %36 = OpLoad %float %116 Aligned 4 OpStore %6 %36 %39 = OpLoad %ulong %4 - %89 = OpIAdd %ulong %39 %ulong_4 - %117 = OpConvertUToPtr %_ptr_Generic_float %89 - %38 = OpLoad %float %117 Aligned 4 + %117 = OpConvertUToPtr %_ptr_Generic_float %39 + %144 = OpBitcast %_ptr_Generic_uchar %117 + %145 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %144 %ulong_4 + %89 = OpBitcast %_ptr_Generic_float %145 + %38 = OpLoad %float %89 Aligned 4 OpStore %7 %38 %41 = OpLoad %ulong %4 - %91 = OpIAdd %ulong %41 %ulong_8 - %118 = OpConvertUToPtr %_ptr_Generic_float %91 - %40 = OpLoad %float %118 Aligned 4 + %118 = OpConvertUToPtr %_ptr_Generic_float %41 + %146 = OpBitcast %_ptr_Generic_uchar %118 + %147 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %146 %ulong_8 + %91 = OpBitcast %_ptr_Generic_float %147 + %40 = OpLoad %float %91 Aligned 4 OpStore %8 %40 %43 = OpLoad %ulong %4 - %93 = OpIAdd %ulong %43 %ulong_12 - %119 = OpConvertUToPtr %_ptr_Generic_float %93 - %42 = OpLoad %float %119 Aligned 4 + %119 = OpConvertUToPtr %_ptr_Generic_float %43 + %148 = OpBitcast %_ptr_Generic_uchar %119 + %149 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %148 %ulong_12 + %93 = OpBitcast %_ptr_Generic_float %149 + %42 = OpLoad %float %93 Aligned 4 OpStore %9 %42 %45 = OpLoad %ulong %4 - %95 = OpIAdd %ulong %45 %ulong_16 - %120 = OpConvertUToPtr %_ptr_Generic_float %95 - %44 = OpLoad %float %120 Aligned 4 + %120 = OpConvertUToPtr %_ptr_Generic_float %45 + %150 = OpBitcast %_ptr_Generic_uchar %120 + %151 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %150 %ulong_16 + %95 = OpBitcast %_ptr_Generic_float %151 + %44 = OpLoad %float %95 Aligned 4 OpStore %10 %44 %47 = OpLoad %ulong %4 - %97 = OpIAdd %ulong %47 %ulong_20 - %121 = OpConvertUToPtr %_ptr_Generic_float %97 - %46 = OpLoad %float %121 Aligned 4 + %121 = OpConvertUToPtr %_ptr_Generic_float %47 + %152 = OpBitcast %_ptr_Generic_uchar %121 + %153 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %152 %ulong_20 + %97 = OpBitcast %_ptr_Generic_float %153 + %46 = OpLoad %float %97 Aligned 4 OpStore %11 %46 %49 = OpLoad %ulong %4 - %99 = OpIAdd %ulong %49 %ulong_24 - %122 = OpConvertUToPtr %_ptr_Generic_float %99 - %48 = OpLoad %float %122 Aligned 4 + %122 = OpConvertUToPtr %_ptr_Generic_float %49 + %154 = OpBitcast %_ptr_Generic_uchar %122 + %155 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %154 %ulong_24 + %99 = OpBitcast %_ptr_Generic_float %155 + %48 = OpLoad %float %99 Aligned 4 OpStore %12 %48 %51 = OpLoad %ulong %4 - %101 = OpIAdd %ulong %51 %ulong_28 - %123 = OpConvertUToPtr %_ptr_Generic_float %101 - %50 = OpLoad %float %123 Aligned 4 + %123 = OpConvertUToPtr %_ptr_Generic_float %51 + %156 = OpBitcast %_ptr_Generic_uchar %123 + %157 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %156 %ulong_28 + %101 = OpBitcast %_ptr_Generic_float %157 + %50 = OpLoad %float %101 Aligned 4 OpStore %13 %50 %53 = OpLoad %float %6 %54 = OpLoad %float %7 - %142 = OpIsNan %bool %53 - %143 = OpIsNan %bool %54 - %52 = OpLogicalOr %bool %142 %143 + %158 = OpIsNan %bool %53 + %159 = OpIsNan %bool %54 + %52 = OpLogicalOr %bool %158 %159 OpStore %15 %52 %55 = OpLoad %bool %15 OpBranchConditional %55 %16 %17 @@ -129,9 +145,9 @@ OpStore %124 %60 Aligned 4 %62 = OpLoad %float %8 %63 = OpLoad %float %9 - %145 = OpIsNan %bool %62 - %146 = OpIsNan %bool %63 - %61 = OpLogicalOr %bool %145 %146 + %161 = OpIsNan %bool %62 + %162 = OpIsNan %bool %63 + %61 = OpLogicalOr %bool %161 %162 OpStore %15 %61 %64 = OpLoad %bool %15 OpBranchConditional %64 %20 %21 @@ -149,14 +165,16 @@ %23 = OpLabel %68 = OpLoad %ulong %5 %69 = OpLoad %uint %14 - %107 = OpIAdd %ulong %68 %ulong_4_0 - %125 = OpConvertUToPtr %_ptr_Generic_uint %107 - OpStore %125 %69 Aligned 4 + %125 = OpConvertUToPtr %_ptr_Generic_uint %68 + %163 = OpBitcast %_ptr_Generic_uchar %125 + %164 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %163 %ulong_4_0 + %107 = OpBitcast %_ptr_Generic_uint %164 + OpStore %107 %69 Aligned 4 %71 = OpLoad %float %10 %72 = OpLoad %float %11 - %147 = OpIsNan %bool %71 - %148 = OpIsNan %bool %72 - %70 = OpLogicalOr %bool %147 %148 + %165 = OpIsNan %bool %71 + %166 = OpIsNan %bool %72 + %70 = OpLogicalOr %bool %165 %166 OpStore %15 %70 %73 = OpLoad %bool %15 OpBranchConditional %73 %24 %25 @@ -174,14 +192,16 @@ %27 = OpLabel %77 = OpLoad %ulong %5 %78 = OpLoad %uint %14 - %111 = OpIAdd %ulong %77 %ulong_8_0 - %126 = OpConvertUToPtr %_ptr_Generic_uint %111 - OpStore %126 %78 Aligned 4 + %126 = OpConvertUToPtr %_ptr_Generic_uint %77 + %167 = OpBitcast %_ptr_Generic_uchar %126 + %168 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %167 %ulong_8_0 + %111 = OpBitcast %_ptr_Generic_uint %168 + OpStore %111 %78 Aligned 4 %80 = OpLoad %float %12 %81 = OpLoad %float %13 - %149 = OpIsNan %bool %80 - %150 = OpIsNan %bool %81 - %79 = OpLogicalOr %bool %149 %150 + %169 = OpIsNan %bool %80 + %170 = OpIsNan %bool %81 + %79 = OpLogicalOr %bool %169 %170 OpStore %15 %79 %82 = OpLoad %bool %15 OpBranchConditional %82 %28 %29 @@ -199,8 +219,10 @@ %31 = OpLabel %86 = OpLoad %ulong %5 %87 = OpLoad %uint %14 - %115 = OpIAdd %ulong %86 %ulong_12_0 - %127 = OpConvertUToPtr %_ptr_Generic_uint %115 - OpStore %127 %87 Aligned 4 + %127 = OpConvertUToPtr %_ptr_Generic_uint %86 + %171 = OpBitcast %_ptr_Generic_uchar %127 + %172 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %171 %ulong_12_0 + %115 = OpBitcast %_ptr_Generic_uint %172 + OpStore %115 %87 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/setp_num.spvtxt b/ptx/src/test/spirv_run/setp_num.spvtxt index 3ac6eab..c576a50 100644 --- a/ptx/src/test/spirv_run/setp_num.spvtxt +++ b/ptx/src/test/spirv_run/setp_num.spvtxt @@ -22,6 +22,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %ulong_12 = OpConstant %ulong 12 %ulong_16 = OpConstant %ulong 16 @@ -77,46 +79,60 @@ %36 = OpLoad %float %116 Aligned 4 OpStore %6 %36 %39 = OpLoad %ulong %4 - %89 = OpIAdd %ulong %39 %ulong_4 - %117 = OpConvertUToPtr %_ptr_Generic_float %89 - %38 = OpLoad %float %117 Aligned 4 + %117 = OpConvertUToPtr %_ptr_Generic_float %39 + %144 = OpBitcast %_ptr_Generic_uchar %117 + %145 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %144 %ulong_4 + %89 = OpBitcast %_ptr_Generic_float %145 + %38 = OpLoad %float %89 Aligned 4 OpStore %7 %38 %41 = OpLoad %ulong %4 - %91 = OpIAdd %ulong %41 %ulong_8 - %118 = OpConvertUToPtr %_ptr_Generic_float %91 - %40 = OpLoad %float %118 Aligned 4 + %118 = OpConvertUToPtr %_ptr_Generic_float %41 + %146 = OpBitcast %_ptr_Generic_uchar %118 + %147 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %146 %ulong_8 + %91 = OpBitcast %_ptr_Generic_float %147 + %40 = OpLoad %float %91 Aligned 4 OpStore %8 %40 %43 = OpLoad %ulong %4 - %93 = OpIAdd %ulong %43 %ulong_12 - %119 = OpConvertUToPtr %_ptr_Generic_float %93 - %42 = OpLoad %float %119 Aligned 4 + %119 = OpConvertUToPtr %_ptr_Generic_float %43 + %148 = OpBitcast %_ptr_Generic_uchar %119 + %149 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %148 %ulong_12 + %93 = OpBitcast %_ptr_Generic_float %149 + %42 = OpLoad %float %93 Aligned 4 OpStore %9 %42 %45 = OpLoad %ulong %4 - %95 = OpIAdd %ulong %45 %ulong_16 - %120 = OpConvertUToPtr %_ptr_Generic_float %95 - %44 = OpLoad %float %120 Aligned 4 + %120 = OpConvertUToPtr %_ptr_Generic_float %45 + %150 = OpBitcast %_ptr_Generic_uchar %120 + %151 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %150 %ulong_16 + %95 = OpBitcast %_ptr_Generic_float %151 + %44 = OpLoad %float %95 Aligned 4 OpStore %10 %44 %47 = OpLoad %ulong %4 - %97 = OpIAdd %ulong %47 %ulong_20 - %121 = OpConvertUToPtr %_ptr_Generic_float %97 - %46 = OpLoad %float %121 Aligned 4 + %121 = OpConvertUToPtr %_ptr_Generic_float %47 + %152 = OpBitcast %_ptr_Generic_uchar %121 + %153 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %152 %ulong_20 + %97 = OpBitcast %_ptr_Generic_float %153 + %46 = OpLoad %float %97 Aligned 4 OpStore %11 %46 %49 = OpLoad %ulong %4 - %99 = OpIAdd %ulong %49 %ulong_24 - %122 = OpConvertUToPtr %_ptr_Generic_float %99 - %48 = OpLoad %float %122 Aligned 4 + %122 = OpConvertUToPtr %_ptr_Generic_float %49 + %154 = OpBitcast %_ptr_Generic_uchar %122 + %155 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %154 %ulong_24 + %99 = OpBitcast %_ptr_Generic_float %155 + %48 = OpLoad %float %99 Aligned 4 OpStore %12 %48 %51 = OpLoad %ulong %4 - %101 = OpIAdd %ulong %51 %ulong_28 - %123 = OpConvertUToPtr %_ptr_Generic_float %101 - %50 = OpLoad %float %123 Aligned 4 + %123 = OpConvertUToPtr %_ptr_Generic_float %51 + %156 = OpBitcast %_ptr_Generic_uchar %123 + %157 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %156 %ulong_28 + %101 = OpBitcast %_ptr_Generic_float %157 + %50 = OpLoad %float %101 Aligned 4 OpStore %13 %50 %53 = OpLoad %float %6 %54 = OpLoad %float %7 - %142 = OpIsNan %bool %53 - %143 = OpIsNan %bool %54 - %144 = OpLogicalOr %bool %142 %143 - %52 = OpSelect %bool %144 %false %true + %158 = OpIsNan %bool %53 + %159 = OpIsNan %bool %54 + %160 = OpLogicalOr %bool %158 %159 + %52 = OpSelect %bool %160 %false %true OpStore %15 %52 %55 = OpLoad %bool %15 OpBranchConditional %55 %16 %17 @@ -138,10 +154,10 @@ OpStore %124 %60 Aligned 4 %62 = OpLoad %float %8 %63 = OpLoad %float %9 - %148 = OpIsNan %bool %62 - %149 = OpIsNan %bool %63 - %150 = OpLogicalOr %bool %148 %149 - %61 = OpSelect %bool %150 %false_0 %true_0 + %164 = OpIsNan %bool %62 + %165 = OpIsNan %bool %63 + %166 = OpLogicalOr %bool %164 %165 + %61 = OpSelect %bool %166 %false_0 %true_0 OpStore %15 %61 %64 = OpLoad %bool %15 OpBranchConditional %64 %20 %21 @@ -159,15 +175,17 @@ %23 = OpLabel %68 = OpLoad %ulong %5 %69 = OpLoad %uint %14 - %107 = OpIAdd %ulong %68 %ulong_4_0 - %125 = OpConvertUToPtr %_ptr_Generic_uint %107 - OpStore %125 %69 Aligned 4 + %125 = OpConvertUToPtr %_ptr_Generic_uint %68 + %169 = OpBitcast %_ptr_Generic_uchar %125 + %170 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %169 %ulong_4_0 + %107 = OpBitcast %_ptr_Generic_uint %170 + OpStore %107 %69 Aligned 4 %71 = OpLoad %float %10 %72 = OpLoad %float %11 - %153 = OpIsNan %bool %71 - %154 = OpIsNan %bool %72 - %155 = OpLogicalOr %bool %153 %154 - %70 = OpSelect %bool %155 %false_1 %true_1 + %171 = OpIsNan %bool %71 + %172 = OpIsNan %bool %72 + %173 = OpLogicalOr %bool %171 %172 + %70 = OpSelect %bool %173 %false_1 %true_1 OpStore %15 %70 %73 = OpLoad %bool %15 OpBranchConditional %73 %24 %25 @@ -185,15 +203,17 @@ %27 = OpLabel %77 = OpLoad %ulong %5 %78 = OpLoad %uint %14 - %111 = OpIAdd %ulong %77 %ulong_8_0 - %126 = OpConvertUToPtr %_ptr_Generic_uint %111 - OpStore %126 %78 Aligned 4 + %126 = OpConvertUToPtr %_ptr_Generic_uint %77 + %176 = OpBitcast %_ptr_Generic_uchar %126 + %177 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %176 %ulong_8_0 + %111 = OpBitcast %_ptr_Generic_uint %177 + OpStore %111 %78 Aligned 4 %80 = OpLoad %float %12 %81 = OpLoad %float %13 - %158 = OpIsNan %bool %80 - %159 = OpIsNan %bool %81 - %160 = OpLogicalOr %bool %158 %159 - %79 = OpSelect %bool %160 %false_2 %true_2 + %178 = OpIsNan %bool %80 + %179 = OpIsNan %bool %81 + %180 = OpLogicalOr %bool %178 %179 + %79 = OpSelect %bool %180 %false_2 %true_2 OpStore %15 %79 %82 = OpLoad %bool %15 OpBranchConditional %82 %28 %29 @@ -211,8 +231,10 @@ %31 = OpLabel %86 = OpLoad %ulong %5 %87 = OpLoad %uint %14 - %115 = OpIAdd %ulong %86 %ulong_12_0 - %127 = OpConvertUToPtr %_ptr_Generic_uint %115 - OpStore %127 %87 Aligned 4 + %127 = OpConvertUToPtr %_ptr_Generic_uint %86 + %183 = OpBitcast %_ptr_Generic_uchar %127 + %184 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %183 %ulong_12_0 + %115 = OpBitcast %_ptr_Generic_uint %184 + OpStore %115 %87 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt index 2ea964c..1b2e3dd 100644 --- a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt +++ b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt @@ -24,7 +24,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %uint_0 = OpConstant %uint 0 + %ulong_0 = OpConstant %ulong 0 +%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar %1 = OpFunction %void None %40 %10 = OpFunctionParameter %ulong %11 = OpFunctionParameter %ulong @@ -54,9 +55,11 @@ %27 = OpConvertUToPtr %_ptr_Workgroup_ulong %17 OpStore %27 %18 Aligned 8 %20 = OpLoad %uint %7 - %24 = OpIAdd %uint %20 %uint_0 - %28 = OpConvertUToPtr %_ptr_Workgroup_ulong %24 - %19 = OpLoad %ulong %28 Aligned 8 + %28 = OpConvertUToPtr %_ptr_Workgroup_ulong %20 + %46 = OpBitcast %_ptr_Workgroup_uchar %28 + %47 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uchar %46 %ulong_0 + %24 = OpBitcast %_ptr_Workgroup_ulong %47 + %19 = OpLoad %ulong %24 Aligned 8 OpStore %9 %19 %21 = OpLoad %ulong %6 %22 = OpLoad %ulong %9 diff --git a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt index 19d5a5a..fd4f893 100644 --- a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt +++ b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt @@ -7,27 +7,24 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %33 = OpExtInstImport "OpenCL.std" + %31 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %2 "shared_ptr_take_address" %1 OpDecorate %1 Alignment 4 %void = OpTypeVoid %uchar = OpTypeInt 8 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar -%_ptr_Workgroup__ptr_Workgroup_uchar = OpTypePointer Workgroup %_ptr_Workgroup_uchar - %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uchar Workgroup + %1 = OpVariable %_ptr_Workgroup_uchar Workgroup %ulong = OpTypeInt 64 0 - %39 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar -%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar + %36 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %39 + %2 = OpFunction %void None %36 %10 = OpFunctionParameter %ulong %11 = OpFunctionParameter %ulong - %31 = OpFunctionParameter %_ptr_Workgroup_uchar - %40 = OpLabel - %32 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %30 = OpFunctionParameter %_ptr_Workgroup_uchar + %28 = OpLabel %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function %5 = OpVariable %_ptr_Function_ulong Function @@ -35,34 +32,30 @@ %7 = OpVariable %_ptr_Function_ulong Function %8 = OpVariable %_ptr_Function_ulong Function %9 = OpVariable %_ptr_Function_ulong Function - OpStore %32 %31 - OpBranch %29 - %29 = OpLabel OpStore %3 %10 OpStore %4 %11 %12 = OpLoad %ulong %3 Aligned 8 OpStore %5 %12 %13 = OpLoad %ulong %4 Aligned 8 OpStore %6 %13 - %15 = OpLoad %_ptr_Workgroup_uchar %32 - %24 = OpConvertPtrToU %ulong %15 - %14 = OpCopyObject %ulong %24 + %23 = OpConvertPtrToU %ulong %30 + %14 = OpCopyObject %ulong %23 OpStore %7 %14 - %17 = OpLoad %ulong %5 - %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %17 - %16 = OpLoad %ulong %25 Aligned 8 - OpStore %8 %16 - %18 = OpLoad %ulong %7 - %19 = OpLoad %ulong %8 - %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %18 - OpStore %26 %19 Aligned 8 - %21 = OpLoad %ulong %7 - %27 = OpConvertUToPtr %_ptr_Workgroup_ulong %21 - %20 = OpLoad %ulong %27 Aligned 8 - OpStore %9 %20 - %22 = OpLoad %ulong %6 - %23 = OpLoad %ulong %9 - %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %22 - OpStore %28 %23 Aligned 8 + %16 = OpLoad %ulong %5 + %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 + %15 = OpLoad %ulong %24 Aligned 8 + OpStore %8 %15 + %17 = OpLoad %ulong %7 + %18 = OpLoad %ulong %8 + %25 = OpConvertUToPtr %_ptr_Workgroup_ulong %17 + OpStore %25 %18 Aligned 8 + %20 = OpLoad %ulong %7 + %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %20 + %19 = OpLoad %ulong %26 Aligned 8 + OpStore %9 %19 + %21 = OpLoad %ulong %6 + %22 = OpLoad %ulong %9 + %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21 + OpStore %27 %22 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt index 33812f6..cf0d86e 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt @@ -7,7 +7,7 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %50 = OpExtInstImport "OpenCL.std" + %54 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "stateful_ld_st_ntid" %gl_LocalInvocationID OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId @@ -18,34 +18,34 @@ %gl_LocalInvocationID = OpVariable %_ptr_Input_v3ulong Input %uchar = OpTypeInt 8 0 %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %57 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar + %61 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar %_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %1 = OpFunction %void None %57 + %1 = OpFunction %void None %61 %20 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar %21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %48 = OpLabel - %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %52 = OpLabel + %12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %13 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %6 = OpVariable %_ptr_Function_uint Function %7 = OpVariable %_ptr_Function_ulong Function %8 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %20 - OpStore %3 %21 - %13 = OpBitcast %_ptr_Function_ulong %2 - %44 = OpLoad %ulong %13 Aligned 8 - %12 = OpCopyObject %ulong %44 - %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %12 + OpStore %12 %20 + OpStore %13 %21 + %45 = OpBitcast %_ptr_Function_ulong %12 + %44 = OpLoad %ulong %45 Aligned 8 + %14 = OpCopyObject %ulong %44 + %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14 OpStore %10 %22 - %15 = OpBitcast %_ptr_Function_ulong %3 - %45 = OpLoad %ulong %15 Aligned 8 - %14 = OpCopyObject %ulong %45 - %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14 + %47 = OpBitcast %_ptr_Function_ulong %13 + %46 = OpLoad %ulong %47 Aligned 8 + %15 = OpCopyObject %ulong %46 + %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15 OpStore %11 %23 %24 = OpLoad %_ptr_CrossWorkgroup_uchar %10 %17 = OpConvertPtrToU %ulong %24 @@ -57,35 +57,37 @@ %18 = OpCopyObject %ulong %19 %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18 OpStore %11 %27 - %62 = OpLoad %v3ulong %gl_LocalInvocationID - %43 = OpCompositeExtract %ulong %62 0 - %63 = OpBitcast %ulong %43 - %29 = OpUConvert %uint %63 + %66 = OpLoad %v3ulong %gl_LocalInvocationID + %43 = OpCompositeExtract %ulong %66 0 + %67 = OpBitcast %ulong %43 + %29 = OpUConvert %uint %67 %28 = OpCopyObject %uint %29 OpStore %6 %28 %31 = OpLoad %uint %6 - %64 = OpBitcast %uint %31 - %30 = OpUConvert %ulong %64 + %68 = OpBitcast %uint %31 + %30 = OpUConvert %ulong %68 OpStore %7 %30 %33 = OpLoad %_ptr_CrossWorkgroup_uchar %10 %34 = OpLoad %ulong %7 - %65 = OpBitcast %_ptr_CrossWorkgroup_uchar %33 - %66 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %65 %34 - %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %66 + %48 = OpCopyObject %ulong %34 + %69 = OpBitcast %_ptr_CrossWorkgroup_uchar %33 + %70 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %69 %48 + %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %70 OpStore %10 %32 %36 = OpLoad %_ptr_CrossWorkgroup_uchar %11 %37 = OpLoad %ulong %7 - %67 = OpBitcast %_ptr_CrossWorkgroup_uchar %36 - %68 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %67 %37 - %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %68 + %49 = OpCopyObject %ulong %37 + %71 = OpBitcast %_ptr_CrossWorkgroup_uchar %36 + %72 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %71 %49 + %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %72 OpStore %11 %35 %39 = OpLoad %_ptr_CrossWorkgroup_uchar %10 - %46 = OpBitcast %_ptr_CrossWorkgroup_ulong %39 - %38 = OpLoad %ulong %46 Aligned 8 + %50 = OpBitcast %_ptr_CrossWorkgroup_ulong %39 + %38 = OpLoad %ulong %50 Aligned 8 OpStore %8 %38 %40 = OpLoad %_ptr_CrossWorkgroup_uchar %11 %41 = OpLoad %ulong %8 - %47 = OpBitcast %_ptr_CrossWorkgroup_ulong %40 - OpStore %47 %41 Aligned 8 + %51 = OpBitcast %_ptr_CrossWorkgroup_ulong %40 + OpStore %51 %41 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt index cb77d14..97bf000 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt @@ -7,7 +7,7 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %58 = OpExtInstImport "OpenCL.std" + %62 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "stateful_ld_st_ntid_chain" %gl_LocalInvocationID OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId @@ -18,18 +18,18 @@ %gl_LocalInvocationID = OpVariable %_ptr_Input_v3ulong Input %uchar = OpTypeInt 8 0 %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %65 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar + %69 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar %_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %1 = OpFunction %void None %65 + %1 = OpFunction %void None %69 %28 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar %29 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %56 = OpLabel - %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %60 = OpLabel + %20 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %21 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function @@ -39,17 +39,17 @@ %10 = OpVariable %_ptr_Function_uint Function %11 = OpVariable %_ptr_Function_ulong Function %12 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %28 - OpStore %3 %29 - %21 = OpBitcast %_ptr_Function_ulong %2 - %52 = OpLoad %ulong %21 Aligned 8 - %20 = OpCopyObject %ulong %52 - %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20 + OpStore %20 %28 + OpStore %21 %29 + %53 = OpBitcast %_ptr_Function_ulong %20 + %52 = OpLoad %ulong %53 Aligned 8 + %22 = OpCopyObject %ulong %52 + %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22 OpStore %14 %30 - %23 = OpBitcast %_ptr_Function_ulong %3 - %53 = OpLoad %ulong %23 Aligned 8 - %22 = OpCopyObject %ulong %53 - %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22 + %55 = OpBitcast %_ptr_Function_ulong %21 + %54 = OpLoad %ulong %55 Aligned 8 + %23 = OpCopyObject %ulong %54 + %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %23 OpStore %17 %31 %32 = OpLoad %_ptr_CrossWorkgroup_uchar %14 %25 = OpConvertPtrToU %ulong %32 @@ -61,35 +61,37 @@ %26 = OpCopyObject %ulong %27 %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26 OpStore %18 %35 - %70 = OpLoad %v3ulong %gl_LocalInvocationID - %51 = OpCompositeExtract %ulong %70 0 - %71 = OpBitcast %ulong %51 - %37 = OpUConvert %uint %71 + %74 = OpLoad %v3ulong %gl_LocalInvocationID + %51 = OpCompositeExtract %ulong %74 0 + %75 = OpBitcast %ulong %51 + %37 = OpUConvert %uint %75 %36 = OpCopyObject %uint %37 OpStore %10 %36 %39 = OpLoad %uint %10 - %72 = OpBitcast %uint %39 - %38 = OpUConvert %ulong %72 + %76 = OpBitcast %uint %39 + %38 = OpUConvert %ulong %76 OpStore %11 %38 %41 = OpLoad %_ptr_CrossWorkgroup_uchar %15 %42 = OpLoad %ulong %11 - %73 = OpBitcast %_ptr_CrossWorkgroup_uchar %41 - %74 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %73 %42 - %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %74 + %56 = OpCopyObject %ulong %42 + %77 = OpBitcast %_ptr_CrossWorkgroup_uchar %41 + %78 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %77 %56 + %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %78 OpStore %16 %40 %44 = OpLoad %_ptr_CrossWorkgroup_uchar %18 %45 = OpLoad %ulong %11 - %75 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 - %76 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %75 %45 - %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %76 + %57 = OpCopyObject %ulong %45 + %79 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 + %80 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %79 %57 + %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %80 OpStore %19 %43 %47 = OpLoad %_ptr_CrossWorkgroup_uchar %16 - %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %47 - %46 = OpLoad %ulong %54 Aligned 8 + %58 = OpBitcast %_ptr_CrossWorkgroup_ulong %47 + %46 = OpLoad %ulong %58 Aligned 8 OpStore %12 %46 %48 = OpLoad %_ptr_CrossWorkgroup_uchar %19 %49 = OpLoad %ulong %12 - %55 = OpBitcast %_ptr_CrossWorkgroup_ulong %48 - OpStore %55 %49 Aligned 8 + %59 = OpBitcast %_ptr_CrossWorkgroup_ulong %48 + OpStore %59 %49 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt index ecf2858..8253bf9 100644 --- a/ptx/src/test/spirv_run/vector.spvtxt +++ b/ptx/src/test/spirv_run/vector.spvtxt @@ -25,8 +25,8 @@ %1 = OpFunction %v2uint None %55 %7 = OpFunctionParameter %v2uint %24 = OpLabel - %2 = OpVariable %_ptr_Function_v2uint Function %3 = OpVariable %_ptr_Function_v2uint Function + %2 = OpVariable %_ptr_Function_v2uint Function %4 = OpVariable %_ptr_Function_v2uint Function %5 = OpVariable %_ptr_Function_uint Function %6 = OpVariable %_ptr_Function_uint Function diff --git a/ptx/src/test/spirv_run/xor.spvtxt b/ptx/src/test/spirv_run/xor.spvtxt index 4cc8968..c3a1f6f 100644 --- a/ptx/src/test/spirv_run/xor.spvtxt +++ b/ptx/src/test/spirv_run/xor.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %39 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 From 2e8716bf0debf5edfecd616204d0fd2864dc2f4c Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 25 Jun 2021 01:20:16 +0200 Subject: [PATCH 25/25] Clean up warnings --- ptx/src/translate.rs | 108 +++++++------------------------------------ 1 file changed, 17 insertions(+), 91 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 277db5c..3daf937 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,13 +1,9 @@ use crate::ast; -use core::borrow; use half::f16; use rspirv::dr; -use std::{borrow::Borrow, cell::RefCell}; +use std::cell::RefCell; +use std::collections::{hash_map, HashMap, HashSet}; use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc}; -use std::{ - collections::{hash_map, HashMap, HashSet}, - convert::TryInto, -}; use rspirv::binary::Assemble; @@ -433,7 +429,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result, _>>()?; let must_link_ptx_impl = ptx_impl_imports.len() > 0; - let mut directives = ptx_impl_imports + let directives = ptx_impl_imports .into_iter() .map(|(_, v)| v) .chain(directives.into_iter()) @@ -1068,8 +1064,10 @@ fn emit_function_header<'a>( }) => { match (**func_decl).borrow().name { ast::MethodName::Func(name) => { - for var in globals { - interface.push(var.name); + if child_fns.contains(&name) { + for var in globals { + interface.push(var.name); + } } } ast::MethodName::Kernel(_) => {} @@ -1264,30 +1262,6 @@ fn to_ssa<'input, 'b>( }) } -fn deparamize_function_decl( - func_decl_rc: &Rc>>, -) -> Result<(), TranslateError> { - let mut func_decl = func_decl_rc.borrow_mut(); - match func_decl.name { - ast::MethodName::Func(..) => { - for decl in func_decl.input_arguments.iter_mut() { - if decl.state_space == ast::StateSpace::Param { - decl.state_space = ast::StateSpace::Reg; - let baseline_type = match decl.v_type { - ast::Type::Scalar(t) => t, - ast::Type::Vector(t, _) => t, // TODO: write a test for this - ast::Type::Array(t, _) => t, // TODO: write a test for this - ast::Type::Pointer(_, _) => return Err(error_unreachable()), - }; - decl.v_type = ast::Type::Pointer(baseline_type, ast::StateSpace::Param); - } - } - } - ast::MethodName::Kernel(..) => {} - }; - Ok(()) -} - fn fix_special_registers( typed_statements: Vec, numeric_id_defs: &mut NumericIdResolver, @@ -1905,17 +1879,6 @@ fn to_ptx_impl_bfi_call( }) } -fn to_resolved_fn_args( - params: Vec, - params_decl: &[ast::Variable], -) -> Vec<(T, ast::Type, ast::StateSpace)> { - params - .into_iter() - .zip(params_decl.iter()) - .map(|(id, var)| (id, var.v_type.clone(), var.state_space)) - .collect::>() -} - fn normalize_labels( func: Vec, id_def: &mut NumericIdResolver, @@ -2644,10 +2607,15 @@ fn emit_function_body_ops( Statement::Label(_) => (), Statement::Call(call) => { let (result_type, result_id) = match &*call.return_arguments { - [(id, typ, space)] => ( - map.get_or_add(builder, SpirvType::new(typ.clone())), - Some(*id), - ), + [(id, typ, space)] => { + if *space != ast::StateSpace::Reg { + return Err(error_unreachable()); + } + ( + map.get_or_add(builder, SpirvType::new(typ.clone())), + Some(*id), + ) + } [] => (map.void(), None), _ => todo!(), }; @@ -4679,7 +4647,7 @@ fn convert_to_stateful_memory_access_postprocess( ) -> Result { Ok(match remapped_ids.get(&arg_desc.op) { Some(new_id) => { - let (new_operand_type, new_operand_space, is_variable) = id_defs.get_typed(*new_id)?; + let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?; if let Some((expected_type, expected_space)) = expected_type { let implicit_conversion = arg_desc .non_default_implicit_conversion @@ -4694,7 +4662,6 @@ fn convert_to_stateful_memory_access_postprocess( } } let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?; - let new_operand_type_clone = new_operand_type.clone(); let converting_id = id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space))); let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) { @@ -5745,20 +5712,6 @@ pub struct PtrAccess { offset_src: P::Operand, } -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub enum ArgumentSemantics { - // normal register access - Default, - // normal register access with relaxed conversion rules (ld/st) - DefaultRelaxed, - // st/ld global - PhysicalPointer, - // st/ld .param, .local - RegisterPointer, - // mov of .local/.global variables - Address, -} - impl ArgumentDescriptor { fn new_op(&self, u: U) -> ArgumentDescriptor { ArgumentDescriptor { @@ -7315,17 +7268,6 @@ impl ast::AtomSemantics { } } -impl ast::StateSpace { - fn semantics(self) -> ArgumentSemantics { - match self { - ast::StateSpace::Reg => ArgumentSemantics::Default, - ast::StateSpace::Param => ArgumentSemantics::RegisterPointer, - ast::StateSpace::Shared => ArgumentSemantics::PhysicalPointer, - _ => todo!(), - } - } -} - fn default_implicit_conversion( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), @@ -7475,22 +7417,6 @@ fn implicit_conversion_mov( ) } -fn should_bitcast_wrapper( - operand: &ast::Type, - _: ast::StateSpace, - instr: &ast::Type, - _: ast::StateSpace, -) -> Result, TranslateError> { - if instr == operand { - return Ok(None); - } - if should_bitcast(instr, operand) { - Ok(Some(ConversionKind::Default)) - } else { - Err(TranslateError::MismatchedType) - } -} - fn should_convert_relaxed_src_wrapper( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),