diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 1c6d2fb..5432207 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,7 +1,6 @@ use half::f16; use lalrpop_util::{lexer::Token, ParseError}; -use std::convert::TryInto; -use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr}; +use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; #[derive(Debug, thiserror::Error)] @@ -34,195 +33,12 @@ pub enum PtxError { NonExternPointer, } -macro_rules! sub_enum { - ($name:ident { $($variant:ident),+ $(,)? }) => { - sub_enum!{ $name : ScalarType { $($variant),+ } } - }; - ($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => { - #[derive(PartialEq, Eq, Clone, Copy)] - pub enum $name { - $( - $variant, - )+ - } - - impl From<$name> for $base_type { - fn from(t: $name) -> $base_type { - match t { - $( - $name::$variant => $base_type::$variant, - )+ - } - } - } - - impl std::convert::TryFrom<$base_type> for $name { - type Error = (); - - fn try_from(t: $base_type) -> Result { - match t { - $( - $base_type::$variant => Ok($name::$variant), - )+ - _ => Err(()), - } - } - } - }; -} - -macro_rules! sub_type { - ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { - sub_type! { $type_name : Type { - $( - $variant ($($field_type),+), - )+ - }} - }; - ($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? 
} ) => { - #[derive(PartialEq, Eq, Clone)] - pub enum $type_name { - $( - $variant ($($field_type),+), - )+ - } - - impl From<$type_name> for $base_type { - #[allow(non_snake_case)] - fn from(t: $type_name) -> $base_type { - match t { - $( - $type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+), - )+ - } - } - } - - impl std::convert::TryFrom<$base_type> for $type_name { - type Error = (); - - #[allow(non_snake_case)] - #[allow(unreachable_patterns)] - fn try_from(t: $base_type) -> Result { - match t { - $( - $base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), - )+ - _ => Err(()), - } - } - } - }; -} - -sub_type! { - VariableRegType { - Scalar(ScalarType), - Vector(SizedScalarType, u8), - // Array type is used when emiting SSA statements at the start of a method - Array(ScalarType, VecU32), - // Pointer variant is used when passing around SLM pointer between - // function calls for dynamic SLM - Pointer(SizedScalarType, PointerStateSpace) - } -} - -type VecU32 = Vec; - -sub_type! { - VariableLocalType { - Scalar(SizedScalarType), - Vector(SizedScalarType, u8), - Array(SizedScalarType, VecU32), - } -} - -impl TryFrom for VariableLocalType { - type Error = PtxError; - - fn try_from(value: VariableGlobalType) -> Result { - match value { - VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)), - VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)), - VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)), - VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray), - } - } -} - -sub_type! 
{ - VariableGlobalType { - Scalar(SizedScalarType), - Vector(SizedScalarType, u8), - Array(SizedScalarType, VecU32), - Pointer(SizedScalarType, PointerStateSpace), - } -} - // For some weird reson this is illegal: // .param .f16x2 foobar; // but this is legal: // .param .f16x2 foobar[1]; // even more interestingly this is legal, but only in .func (not in .entry): // .param .b32 foobar[] -sub_type! { - VariableParamType { - Scalar(LdStScalarType), - Array(SizedScalarType, VecU32), - Pointer(SizedScalarType, PointerStateSpace), - } -} - -sub_enum!(SizedScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F16x2, - F32, - F64, -}); - -sub_enum!(LdStScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F32, - F64, -}); - -sub_enum!(SelpType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, - F32, - F64, -}); #[derive(Copy, Clone, Eq, PartialEq)] pub enum BarDetails { @@ -266,142 +82,76 @@ pub struct Module<'a> { } pub enum Directive<'a, P: ArgParams> { - Variable(Variable), - Method(Function<'a, &'a str, Statement

>), + Variable(LinkingDirective, Variable), + Method(LinkingDirective, Function<'a, &'a str, Statement

>), } -pub enum MethodDecl<'a, ID> { - Func(Vec>, ID, Vec>), - Kernel { - name: &'a str, - in_args: Vec>, - }, +#[derive(Hash, PartialEq, Eq, Copy, Clone)] +pub enum MethodName<'input, ID> { + Kernel(&'input str), + Func(ID), } -pub type FnArgument = Variable; -pub type KernelArgument = Variable; +pub struct MethodDeclaration<'input, ID> { + pub return_arguments: Vec>, + pub name: MethodName<'input, ID>, + pub input_arguments: Vec>, + pub shared_mem: Option, +} pub struct Function<'a, ID, S> { - pub func_directive: MethodDecl<'a, ID>, + pub func_directive: MethodDeclaration<'a, ID>, pub tuning: Vec, pub body: Option>, } pub type ParsedFunction<'a> = Function<'a, &'a str, Statement>>; -#[derive(PartialEq, Eq, Clone)] -pub enum FnArgumentType { - Reg(VariableRegType), - Param(VariableParamType), - Shared, -} -#[derive(PartialEq, Eq, Clone)] -pub enum KernelArgumentType { - Normal(VariableParamType), - Shared, -} - -impl From for Type { - fn from(this: KernelArgumentType) -> Self { - match this { - KernelArgumentType::Normal(typ) => typ.into(), - KernelArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } -} - -impl FnArgumentType { - pub fn to_type(&self, is_kernel: bool) -> Type { - if is_kernel { - self.to_kernel_type() - } else { - self.to_func_type() - } - } - - pub fn to_kernel_type(&self) -> Type { - match self { - FnArgumentType::Reg(x) => x.clone().into(), - FnArgumentType::Param(x) => x.clone().into(), - FnArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } - - pub fn to_func_type(&self) -> Type { - match self { - FnArgumentType::Reg(x) => x.clone().into(), - FnArgumentType::Param(VariableParamType::Scalar(t)) => { - Type::Pointer(PointerType::Scalar((*t).into()), LdStateSpace::Param) - } - FnArgumentType::Param(VariableParamType::Array(t, dims)) => Type::Pointer( - PointerType::Array((*t).into(), dims.clone()), - LdStateSpace::Param, - ), 
- FnArgumentType::Param(VariableParamType::Pointer(t, space)) => Type::Pointer( - PointerType::Pointer((*t).into(), (*space).into()), - LdStateSpace::Param, - ), - FnArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } - - pub fn is_param(&self) -> bool { - match self { - FnArgumentType::Param(_) => true, - _ => false, - } - } -} - -sub_enum!( - PointerStateSpace : LdStateSpace { - Generic, - Global, - Const, - Shared, - Param, - } -); - #[derive(PartialEq, Eq, Clone)] pub enum Type { + // .param.b32 foo; + // -> OpTypeInt Scalar(ScalarType), + // .param.v2.b32 foo; + // -> OpTypeVector Vector(ScalarType, u8), + // .param.b32 foo[4]; + // -> OpTypeArray Array(ScalarType, Vec), - Pointer(PointerType, LdStateSpace), -} - -#[derive(PartialEq, Eq, Clone)] -pub enum PointerType { - Scalar(ScalarType), - Vector(ScalarType, u8), - Array(ScalarType, VecU32), - Pointer(ScalarType, LdStateSpace), -} - -impl From for PointerType { - fn from(t: SizedScalarType) -> Self { - PointerType::Scalar(t.into()) - } -} - -impl TryFrom for SizedScalarType { - type Error = (); - - fn try_from(value: PointerType) -> Result { - match value { - PointerType::Scalar(t) => Ok(t.try_into()?), - PointerType::Vector(_, _) => Err(()), - PointerType::Array(_, _) => Err(()), - PointerType::Pointer(_, _) => Err(()), - } - } + /* + Variables of this type almost never exist in the original .ptx and are + usually artificially created. 
Some examples below: + - extern pointers to the .shared memory in the form: + .extern .shared .b32 shared_mem[]; + which we first parse as + .extern .shared .b32 shared_mem; + and then convert to an additional function parameter: + .param .ptr<.b32.shared> shared_mem; + and do a load at the start of the function (and renames inside fn): + .reg .ptr<.b32.shared> temp; + ld.param.ptr<.b32.shared> temp, [shared_mem]; + note, we don't support non-.shared extern pointers, because there's + zero use for them in the ptxas + - artificial pointers created by stateful conversion, which work + similarly to the above + - function parameters: + foobar(.param .align 4 .b8 numbers[]) + which get parsed to + foobar(.param .align 4 .b8 numbers) + and then converted to + foobar(.reg .align 4 .ptr<.b8.param> numbers) + - ld/st with offset: + .reg.b32 x; + .param.b64 arg0; + st.param.b32 [arg0+4], x; + Yes, this code is legal and actually emitted by the NV compiler! + We convert the st to: + .reg ptr<.b64.param> temp = ptr_offset(arg0, 4); + st.param.b32 [temp], x; + */ + // .reg ptr<.b64.param> + // -> OpTypePointer Function + Pointer(ScalarType, StateSpace), } #[derive(PartialEq, Eq, Hash, Clone, Copy)] @@ -425,52 +175,6 @@ pub enum ScalarType { Pred, } -sub_enum!(IntType { - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64 -}); - -sub_enum!(BitType { B8, B16, B32, B64 }); - -sub_enum!(UIntType { U8, U16, U32, U64 }); - -sub_enum!(SIntType { S8, S16, S32, S64 }); - -impl IntType { - pub fn is_signed(self) -> bool { - match self { - IntType::U8 | IntType::U16 | IntType::U32 | IntType::U64 => false, - IntType::S8 | IntType::S16 | IntType::S32 | IntType::S64 => true, - } - } - - pub fn width(self) -> u8 { - match self { - IntType::U8 => 1, - IntType::U16 => 2, - IntType::U32 => 4, - IntType::U64 => 8, - IntType::S8 => 1, - IntType::S16 => 2, - IntType::S32 => 4, - IntType::S64 => 8, - } - } -} - -sub_enum!(FloatType { - F16, - F16x2, - F32, - F64 -}); - impl ScalarType { pub fn 
size_of(self) -> u8 { match self { @@ -509,51 +213,19 @@ pub enum Statement { } pub struct MultiVariable { - pub var: Variable, + pub var: Variable, pub count: Option, } #[derive(Clone)] -pub struct Variable { +pub struct Variable { pub align: Option, - pub v_type: T, + pub v_type: Type, + pub state_space: StateSpace, pub name: ID, pub array_init: Vec, } -#[derive(Eq, PartialEq, Clone)] -pub enum VariableType { - Reg(VariableRegType), - Local(VariableLocalType), - Param(VariableParamType), - Global(VariableGlobalType), - Shared(VariableGlobalType), -} - -impl VariableType { - pub fn to_type(&self) -> (StateSpace, Type) { - match self { - VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()), - VariableType::Local(t) => (StateSpace::Local, t.clone().into()), - VariableType::Param(t) => (StateSpace::Param, t.clone().into()), - VariableType::Global(t) => (StateSpace::Global, t.clone().into()), - VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()), - } - } -} - -impl From for Type { - fn from(t: VariableType) -> Self { - match t { - VariableType::Reg(t) => t.into(), - VariableType::Local(t) => t.into(), - VariableType::Param(t) => t.into(), - VariableType::Global(t) => t.into(), - VariableType::Shared(t) => t.into(), - } - } -} - #[derive(Copy, Clone, PartialEq, Eq)] pub enum StateSpace { Reg, @@ -562,6 +234,8 @@ pub enum StateSpace { Local, Shared, Param, + Generic, + Sreg, } pub struct PredAt { @@ -576,24 +250,24 @@ pub enum Instruction { Add(ArithDetails, Arg3

), Setp(SetpData, Arg4Setp

), SetpBool(SetpBoolData, Arg5Setp

), - Not(BooleanType, Arg2

), + Not(ScalarType, Arg2

), Bra(BraData, Arg1

), Cvt(CvtDetails, Arg2

), Cvta(CvtaDetails, Arg2

), - Shl(ShlType, Arg3

), - Shr(ShrType, Arg3

), + Shl(ScalarType, Arg3

), + Shr(ScalarType, Arg3

), St(StData, Arg2St

), Ret(RetData), Call(CallInst

), Abs(AbsDetails, Arg2

), Mad(MulDetails, Arg4

), - Or(BooleanType, Arg3

), + Or(ScalarType, Arg3

), Sub(ArithDetails, Arg3

), Min(MinMaxDetails, Arg3

), Max(MinMaxDetails, Arg3

), Rcp(RcpDetails, Arg2

), - And(BooleanType, Arg3

), - Selp(SelpType, Arg4

), + And(ScalarType, Arg3

), + Selp(ScalarType, Arg4

), Bar(BarDetails, Arg1Bar

), Atom(AtomDetails, Arg3

), AtomCas(AtomCasDetails, Arg4

), @@ -605,13 +279,13 @@ pub enum Instruction { Cos { flush_to_zero: bool, arg: Arg2

}, Lg2 { flush_to_zero: bool, arg: Arg2

}, Ex2 { flush_to_zero: bool, arg: Arg2

}, - Clz { typ: BitType, arg: Arg2

}, - Brev { typ: BitType, arg: Arg2

}, - Popc { typ: BitType, arg: Arg2

}, - Xor { typ: BooleanType, arg: Arg3

}, - Bfe { typ: IntType, arg: Arg4

}, - Bfi { typ: BitType, arg: Arg5

}, - Rem { typ: IntType, arg: Arg3

}, + Clz { typ: ScalarType, arg: Arg2

}, + Brev { typ: ScalarType, arg: Arg2

}, + Popc { typ: ScalarType, arg: Arg2

}, + Xor { typ: ScalarType, arg: Arg3

}, + Bfe { typ: ScalarType, arg: Arg4

}, + Bfi { typ: ScalarType, arg: Arg5

}, + Rem { typ: ScalarType, arg: Arg3

}, } #[derive(Copy, Clone)] @@ -737,34 +411,12 @@ pub enum VectorPrefix { pub struct LdDetails { pub qualifier: LdStQualifier, - pub state_space: LdStateSpace, + pub state_space: StateSpace, pub caching: LdCacheOperator, - pub typ: LdStType, + pub typ: Type, pub non_coherent: bool, } -sub_type! { - LdStType { - Scalar(LdStScalarType), - Vector(LdStScalarType, u8), - // Used in generated code - Pointer(PointerType, LdStateSpace), - } -} - -impl From for PointerType { - fn from(t: LdStType) -> Self { - match t { - LdStType::Scalar(t) => PointerType::Scalar(t.into()), - LdStType::Vector(t, len) => PointerType::Vector(t.into(), len), - LdStType::Pointer(PointerType::Scalar(scalar_type), space) => { - PointerType::Pointer(scalar_type, space) - } - LdStType::Pointer(..) => unreachable!(), - } - } -} - #[derive(Copy, Clone, PartialEq, Eq)] pub enum LdStQualifier { Weak, @@ -780,17 +432,6 @@ pub enum MemScope { Sys, } -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -#[repr(u8)] -pub enum LdStateSpace { - Generic, - Const, - Global, - Local, - Param, - Shared, -} - #[derive(Copy, Clone, PartialEq, Eq)] pub enum LdCacheOperator { Cached, @@ -825,7 +466,7 @@ impl MovDetails { #[derive(Copy, Clone)] pub struct MulIntDesc { - pub typ: IntType, + pub typ: ScalarType, pub control: MulIntControl, } @@ -845,7 +486,7 @@ pub enum RoundingMode { } pub struct AddIntDesc { - pub typ: IntType, + pub typ: ScalarType, pub saturate: bool, } @@ -892,39 +533,39 @@ pub struct BraData { pub enum CvtDetails { IntFromInt(CvtIntToIntDesc), - FloatFromFloat(CvtDesc), - IntFromFloat(CvtDesc), - FloatFromInt(CvtDesc), + FloatFromFloat(CvtDesc), + IntFromFloat(CvtDesc), + FloatFromInt(CvtDesc), } pub struct CvtIntToIntDesc { - pub dst: IntType, - pub src: IntType, + pub dst: ScalarType, + pub src: ScalarType, pub saturate: bool, } -pub struct CvtDesc { +pub struct CvtDesc { pub rounding: Option, pub flush_to_zero: Option, pub saturate: bool, - pub dst: Dst, - pub src: Src, + pub dst: ScalarType, + pub 
src: ScalarType, } impl CvtDetails { pub fn new_int_from_int_checked<'err, 'input>( saturate: bool, - dst: IntType, - src: IntType, + dst: ScalarType, + src: ScalarType, err: &'err mut Vec, PtxError>>, ) -> Self { if saturate { - if src.is_signed() { - if dst.is_signed() && dst.width() >= src.width() { + if src.kind() == ScalarKind::Signed { + if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() { err.push(ParseError::from(PtxError::SyntaxError)); } } else { - if dst == src || dst.width() >= src.width() { + if dst == src || dst.size_of() >= src.size_of() { err.push(ParseError::from(PtxError::SyntaxError)); } } @@ -936,11 +577,11 @@ impl CvtDetails { rounding: RoundingMode, flush_to_zero: bool, saturate: bool, - dst: FloatType, - src: IntType, + dst: ScalarType, + src: ScalarType, err: &'err mut Vec, PtxError>>, ) -> Self { - if flush_to_zero && dst != FloatType::F32 { + if flush_to_zero && dst != ScalarType::F32 { err.push(ParseError::from(PtxError::NonF32Ftz)); } CvtDetails::FloatFromInt(CvtDesc { @@ -956,11 +597,11 @@ impl CvtDetails { rounding: RoundingMode, flush_to_zero: bool, saturate: bool, - dst: IntType, - src: FloatType, + dst: ScalarType, + src: ScalarType, err: &'err mut Vec, PtxError>>, ) -> Self { - if flush_to_zero && src != FloatType::F32 { + if flush_to_zero && src != ScalarType::F32 { err.push(ParseError::from(PtxError::NonF32Ftz)); } CvtDetails::IntFromFloat(CvtDesc { @@ -974,58 +615,21 @@ impl CvtDetails { } pub struct CvtaDetails { - pub to: CvtaStateSpace, - pub from: CvtaStateSpace, + pub to: StateSpace, + pub from: StateSpace, pub size: CvtaSize, } -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum CvtaStateSpace { - Generic, - Const, - Global, - Local, - Shared, -} - pub enum CvtaSize { U32, U64, } -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum ShlType { - B16, - B32, - B64, -} - -sub_enum!(ShrType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, -}); - pub struct StData { pub qualifier: LdStQualifier, - pub 
state_space: StStateSpace, + pub state_space: StateSpace, pub caching: StCacheOperator, - pub typ: LdStType, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum StStateSpace { - Generic, - Global, - Local, - Param, - Shared, + pub typ: Type, } #[derive(PartialEq, Eq)] @@ -1040,13 +644,6 @@ pub struct RetData { pub uniform: bool, } -sub_enum!(BooleanType { - Pred, - B16, - B32, - B64, -}); - #[derive(Copy, Clone)] pub enum MulDetails { Unsigned(MulUInt), @@ -1056,32 +653,32 @@ pub enum MulDetails { #[derive(Copy, Clone)] pub struct MulUInt { - pub typ: UIntType, + pub typ: ScalarType, pub control: MulIntControl, } #[derive(Copy, Clone)] pub struct MulSInt { - pub typ: SIntType, + pub typ: ScalarType, pub control: MulIntControl, } #[derive(Copy, Clone)] pub enum ArithDetails { - Unsigned(UIntType), + Unsigned(ScalarType), Signed(ArithSInt), Float(ArithFloat), } #[derive(Copy, Clone)] pub struct ArithSInt { - pub typ: SIntType, + pub typ: ScalarType, pub saturate: bool, } #[derive(Copy, Clone)] pub struct ArithFloat { - pub typ: FloatType, + pub typ: ScalarType, pub rounding: Option, pub flush_to_zero: Option, pub saturate: bool, @@ -1089,8 +686,8 @@ pub struct ArithFloat { #[derive(Copy, Clone)] pub enum MinMaxDetails { - Signed(SIntType), - Unsigned(UIntType), + Signed(ScalarType), + Unsigned(ScalarType), Float(MinMaxFloat), } @@ -1098,14 +695,14 @@ pub enum MinMaxDetails { pub struct MinMaxFloat { pub flush_to_zero: Option, pub nan: bool, - pub typ: FloatType, + pub typ: ScalarType, } #[derive(Copy, Clone)] pub struct AtomDetails { pub semantics: AtomSemantics, pub scope: MemScope, - pub space: AtomSpace, + pub space: StateSpace, pub inner: AtomInnerDetails, } @@ -1117,19 +714,12 @@ pub enum AtomSemantics { AcquireRelease, } -#[derive(Copy, Clone)] -pub enum AtomSpace { - Generic, - Global, - Shared, -} - #[derive(Copy, Clone)] pub enum AtomInnerDetails { - Bit { op: AtomBitOp, typ: BitType }, - Unsigned { op: AtomUIntOp, typ: UIntType }, - Signed { op: 
AtomSIntOp, typ: SIntType }, - Float { op: AtomFloatOp, typ: FloatType }, + Bit { op: AtomBitOp, typ: ScalarType }, + Unsigned { op: AtomUIntOp, typ: ScalarType }, + Signed { op: AtomSIntOp, typ: ScalarType }, + Float { op: AtomFloatOp, typ: ScalarType }, } #[derive(Copy, Clone, Eq, PartialEq)] @@ -1165,20 +755,20 @@ pub enum AtomFloatOp { pub struct AtomCasDetails { pub semantics: AtomSemantics, pub scope: MemScope, - pub space: AtomSpace, - pub typ: BitType, + pub space: StateSpace, + pub typ: ScalarType, } #[derive(Copy, Clone)] pub enum DivDetails { - Unsigned(UIntType), - Signed(SIntType), + Unsigned(ScalarType), + Signed(ScalarType), Float(DivFloatDetails), } #[derive(Copy, Clone)] pub struct DivFloatDetails { - pub typ: FloatType, + pub typ: ScalarType, pub flush_to_zero: Option, pub kind: DivFloatKind, } @@ -1197,7 +787,7 @@ pub enum NumsOrArrays<'a> { #[derive(Copy, Clone)] pub struct SqrtDetails { - pub typ: FloatType, + pub typ: ScalarType, pub flush_to_zero: Option, pub kind: SqrtKind, } @@ -1210,7 +800,7 @@ pub enum SqrtKind { #[derive(Copy, Clone, Eq, PartialEq)] pub struct RsqrtDetails { - pub typ: FloatType, + pub typ: ScalarType, pub flush_to_zero: bool, } @@ -1221,7 +811,7 @@ pub struct NegDetails { } impl<'a> NumsOrArrays<'a> { - pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result, PtxError> { + pub fn to_vec(self, typ: ScalarType, dimensions: &mut [u32]) -> Result, PtxError> { self.normalize_dimensions(dimensions)?; let sizeof_t = ScalarType::from(typ).size_of() as usize; let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize)); @@ -1252,7 +842,7 @@ impl<'a> NumsOrArrays<'a> { fn parse_and_copy( &self, - t: SizedScalarType, + t: ScalarType, size_of_t: usize, dimensions: &[u32], result: &mut [u8], @@ -1292,47 +882,48 @@ impl<'a> NumsOrArrays<'a> { } fn parse_and_copy_single( - t: SizedScalarType, + t: ScalarType, idx: usize, str_val: &str, radix: u32, output: &mut [u8], ) -> Result<(), PtxError> { match 
t { - SizedScalarType::B8 | SizedScalarType::U8 => { + ScalarType::B8 | ScalarType::U8 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::B16 | SizedScalarType::U16 => { + ScalarType::B16 | ScalarType::U16 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::B32 | SizedScalarType::U32 => { + ScalarType::B32 | ScalarType::U32 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::B64 | SizedScalarType::U64 => { + ScalarType::B64 | ScalarType::U64 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::S8 => { + ScalarType::S8 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::S16 => { + ScalarType::S16 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::S32 => { + ScalarType::S32 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::S64 => { + ScalarType::S64 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::F16 => { + ScalarType::F16 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::F16x2 => todo!(), - SizedScalarType::F32 => { + ScalarType::F16x2 => todo!(), + ScalarType::F32 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::F64 => { + ScalarType::F64 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } + ScalarType::Pred => todo!(), } Ok(()) } @@ -1379,6 +970,40 @@ pub enum TuningDirective { MinNCtaPerSm(u32), } +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum ScalarKind { + Bit, + Unsigned, + Signed, + Float, + Float2, + Pred, +} + +impl ScalarType { + pub fn kind(self) -> ScalarKind { + match self { + ScalarType::U8 => ScalarKind::Unsigned, + ScalarType::U16 => ScalarKind::Unsigned, + ScalarType::U32 => ScalarKind::Unsigned, + ScalarType::U64 => ScalarKind::Unsigned, + ScalarType::S8 
=> ScalarKind::Signed, + ScalarType::S16 => ScalarKind::Signed, + ScalarType::S32 => ScalarKind::Signed, + ScalarType::S64 => ScalarKind::Signed, + ScalarType::B8 => ScalarKind::Bit, + ScalarType::B16 => ScalarKind::Bit, + ScalarType::B32 => ScalarKind::Bit, + ScalarType::B64 => ScalarKind::Bit, + ScalarType::F16 => ScalarKind::Float, + ScalarType::F32 => ScalarKind::Float, + ScalarType::F64 => ScalarKind::Float, + ScalarType::F16x2 => ScalarKind::Float2, + ScalarType::Pred => ScalarKind::Pred, + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -1386,13 +1011,13 @@ mod tests { #[test] fn array_fails_multiple_0_dmiensions() { let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0, 0]).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut vec![0, 0]).is_err()); } #[test] fn array_fails_on_empty() { let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0]).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut vec![0]).is_err()); } #[test] @@ -1404,7 +1029,7 @@ mod tests { let mut dimensions = vec![0u32, 2]; assert_eq!( vec![1u8, 2, 3, 4], - inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap() + inp.to_vec(ScalarType::B8, &mut dimensions).unwrap() ); assert_eq!(dimensions, vec![2u32, 2]); } @@ -1416,7 +1041,7 @@ mod tests { NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]), ]); let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); } #[test] @@ -1426,6 +1051,6 @@ mod tests { NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]), ]); let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); } } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 423fd57..b697317 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -343,10 +343,16 @@ 
TargetSpecifier = { Directive: Option>> = { AddressSize => None, - => Some(ast::Directive::Method(f)), + => { + let (linking, func) = f; + Some(ast::Directive::Method(linking, func)) + }, File => None, Section => None, - ";" => Some(ast::Directive::Variable(v)), + ";" => { + let (linking, var) = v; + Some(ast::Directive::Variable(linking, var)) + }, ! => { let err = <>; errors.push(err.error); @@ -358,11 +364,13 @@ AddressSize = { ".address_size" U8Num }; -Function: ast::Function<'input, &'input str, ast::Statement>> = { - LinkingDirectives - +Function: (ast::LinkingDirective, ast::Function<'input, &'input str, ast::Statement>>) = { + + - => ast::Function{<>} + => { + (linking, ast::Function{func_directive, tuning, body}) + } }; LinkingDirective: ast::LinkingDirective = { @@ -388,44 +396,50 @@ LinkingDirectives: ast::LinkingDirective = { } } -MethodDecl: ast::MethodDecl<'input, &'input str> = { - ".entry" => - ast::MethodDecl::Kernel{ name, in_args }, - ".func" => { - ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) +MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = { + ".entry" => { + let return_arguments = Vec::new(); + let name = ast::MethodName::Kernel(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } + }, + ".func" => { + let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); + let name = ast::MethodName::Func(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } } }; -KernelArguments: Vec> = { +KernelArguments: Vec> = { "(" > ")" => args }; -FnArguments: Vec> = { +FnArguments: Vec> = { "(" > ")" => args }; -KernelInput: ast::Variable = { +KernelInput: ast::Variable<&'input str> = { => { let (align, v_type, name) = v; ast::Variable { align, - v_type: ast::KernelArgumentType::Normal(v_type), + v_type, + state_space: ast::StateSpace::Param, name, array_init: Vec::new() } } } -FnInput: ast::Variable = { +FnInput: 
ast::Variable<&'input str> = { => { let (align, v_type, name) = v; - let v_type = ast::FnArgumentType::Reg(v_type); - ast::Variable{ align, v_type, name, array_init: Vec::new() } + let state_space = ast::StateSpace::Reg; + ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() } }, => { let (align, v_type, name) = v; - let v_type = ast::FnArgumentType::Param(v_type); - ast::Variable{ align, v_type, name, array_init: Vec::new() } + let state_space = ast::StateSpace::Param; + ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() } } } @@ -508,141 +522,148 @@ VariableParam: u32 = { "<" ">" => n } -Variable: ast::Variable = { +Variable: ast::Variable<&'input str> = { => { let (align, v_type, name) = v; - let v_type = ast::VariableType::Reg(v_type); - ast::Variable {align, v_type, name, array_init: Vec::new()} + let state_space = ast::StateSpace::Reg; + ast::Variable {align, v_type, state_space, name, array_init: Vec::new()} }, LocalVariable, => { let (align, array_init, v_type, name) = v; - let v_type = ast::VariableType::Param(v_type); - ast::Variable {align, v_type, name, array_init} + let state_space = ast::StateSpace::Param; + ast::Variable {align, v_type, state_space, name, array_init} }, SharedVariable, }; -RegVariable: (Option, ast::VariableRegType, &'input str) = { +RegVariable: (Option, ast::Type, &'input str) = { ".reg" > => { let (align, t, name) = var; - let v_type = ast::VariableRegType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, v_type, name) }, ".reg" > => { let (align, v_len, t, name) = var; - let v_type = ast::VariableRegType::Vector(t, v_len); + let v_type = ast::Type::Vector(t, v_len); (align, v_type, name) } } -LocalVariable: ast::Variable = { +LocalVariable: ast::Variable<&'input str> = { ".local" > => { let (align, t, name) = var; - let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t)); - ast::Variable { align, v_type, name, array_init: Vec::new() } + let v_type = 
ast::Type::Scalar(t); + let state_space = ast::StateSpace::Local; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".local" > => { let (align, v_len, t, name) = var; - let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len)); - ast::Variable { align, v_type, name, array_init: Vec::new() } + let v_type = ast::Type::Vector(t, v_len); + let state_space = ast::StateSpace::Local; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".local" > =>? { let (align, t, name, arr_or_ptr) = var; + let state_space = ast::StateSpace::Local; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableLocalType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); } }; - Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init }) + Ok(ast::Variable { align, v_type, state_space, name, array_init }) } } -SharedVariable: ast::Variable = { +SharedVariable: ast::Variable<&'input str> = { ".shared" > => { let (align, t, name) = var; - let v_type = ast::VariableGlobalType::Scalar(t); - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + let v_type = ast::Type::Scalar(t); + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".shared" > => { let (align, v_len, t, name) = var; - let v_type = ast::VariableGlobalType::Vector(t, v_len); - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + let v_type = ast::Type::Vector(t, v_len); + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".shared" > =>? 
{ let (align, t, name, arr_or_ptr) = var; + let state_space = ast::StateSpace::Shared; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableGlobalType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); } }; - Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init }) + Ok(ast::Variable { align, v_type, state_space, name, array_init }) } } - -ModuleVariable: ast::Variable = { - LinkingDirectives ".global" => { +ModuleVariable: (ast::LinkingDirective, ast::Variable<&'input str>) = { + ".global" => { let (align, v_type, name, array_init) = def; - ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init } + let state_space = ast::StateSpace::Global; + (linking, ast::Variable { align, v_type, state_space, name, array_init }) }, - LinkingDirectives ".shared" => { + ".shared" => { let (align, v_type, name, array_init) = def; - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + (linking, ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }) }, - > > =>? { + > > =>? 
{ let (align, t, name, arr_or_ptr) = var; - let (v_type, array_init) = match arr_or_ptr { + let (v_type, state_space, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { if space == ".global" { - (ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init) + (ast::Type::Array(t, dimensions), ast::StateSpace::Global, init) } else { - (ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init) + (ast::Type::Array(t, dimensions), ast::StateSpace::Shared, init) } } ast::ArrayOrPointer::Pointer => { - if !ldirs.contains(ast::LinkingDirective::EXTERN) { + if !linking.contains(ast::LinkingDirective::EXTERN) { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Global)), Vec::new()) + (ast::Type::Array(t, Vec::new()), ast::StateSpace::Global, Vec::new()) } else { - (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Shared)), Vec::new()) + (ast::Type::Array(t, Vec::new()), ast::StateSpace::Shared, Vec::new()) } } }; - Ok(ast::Variable{ align, array_init, v_type, name }) + Ok((linking, ast::Variable{ align, v_type, state_space, name, array_init })) } } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space -ParamVariable: (Option, Vec, ast::VariableParamType, &'input str) = { +ParamVariable: (Option, Vec, ast::Type, &'input str) = { ".param" > => { let (align, t, name) = var; - let v_type = ast::VariableParamType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, Vec::new(), v_type, name) }, ".param" > => { let (align, t, name, arr_or_ptr) = var; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableParamType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { - 
(ast::VariableParamType::Pointer(t, ast::PointerStateSpace::Param), Vec::new()) + (ast::Type::Scalar(t), Vec::new()) } }; (align, array_init, v_type, name) } } -ParamDeclaration: (Option, ast::VariableParamType, &'input str) = { +ParamDeclaration: (Option, ast::Type, &'input str) = { =>? { let (align, array_init, v_type, name) = var; if array_init.len() > 0 { @@ -653,56 +674,56 @@ ParamDeclaration: (Option, ast::VariableParamType, &'input str) = { } } -GlobalVariableDefinitionNoArray: (Option, ast::VariableGlobalType, &'input str, Vec) = { +GlobalVariableDefinitionNoArray: (Option, ast::Type, &'input str, Vec) = { > => { let (align, t, name) = scalar; - let v_type = ast::VariableGlobalType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, v_type, name, Vec::new()) }, > => { let (align, v_len, t, name) = var; - let v_type = ast::VariableGlobalType::Vector(t, v_len); + let v_type = ast::Type::Vector(t, v_len); (align, v_type, name, Vec::new()) }, } #[inline] -SizedScalarType: ast::SizedScalarType = { - ".b8" => ast::SizedScalarType::B8, - ".b16" => ast::SizedScalarType::B16, - ".b32" => ast::SizedScalarType::B32, - ".b64" => ast::SizedScalarType::B64, - ".u8" => ast::SizedScalarType::U8, - ".u16" => ast::SizedScalarType::U16, - ".u32" => ast::SizedScalarType::U32, - ".u64" => ast::SizedScalarType::U64, - ".s8" => ast::SizedScalarType::S8, - ".s16" => ast::SizedScalarType::S16, - ".s32" => ast::SizedScalarType::S32, - ".s64" => ast::SizedScalarType::S64, - ".f16" => ast::SizedScalarType::F16, - ".f16x2" => ast::SizedScalarType::F16x2, - ".f32" => ast::SizedScalarType::F32, - ".f64" => ast::SizedScalarType::F64, +SizedScalarType: ast::ScalarType = { + ".b8" => ast::ScalarType::B8, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => 
ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f16" => ast::ScalarType::F16, + ".f16x2" => ast::ScalarType::F16x2, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, } #[inline] -LdStScalarType: ast::LdStScalarType = { - ".b8" => ast::LdStScalarType::B8, - ".b16" => ast::LdStScalarType::B16, - ".b32" => ast::LdStScalarType::B32, - ".b64" => ast::LdStScalarType::B64, - ".u8" => ast::LdStScalarType::U8, - ".u16" => ast::LdStScalarType::U16, - ".u32" => ast::LdStScalarType::U32, - ".u64" => ast::LdStScalarType::U64, - ".s8" => ast::LdStScalarType::S8, - ".s16" => ast::LdStScalarType::S16, - ".s32" => ast::LdStScalarType::S32, - ".s64" => ast::LdStScalarType::S64, - ".f16" => ast::LdStScalarType::F16, - ".f32" => ast::LdStScalarType::F32, - ".f64" => ast::LdStScalarType::F64, +LdStScalarType: ast::ScalarType = { + ".b8" => ast::ScalarType::B8, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f16" => ast::ScalarType::F16, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, } Instruction: ast::Instruction> = { @@ -755,7 +776,7 @@ InstLd: ast::Instruction> = { ast::Instruction::Ld( ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ss.unwrap_or(ast::LdStateSpace::Generic), + state_space: ss.unwrap_or(ast::StateSpace::Generic), caching: cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t, non_coherent: false @@ -767,7 +788,7 @@ InstLd: ast::Instruction> = { ast::Instruction::Ld( ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ast::LdStateSpace::Global, + state_space: ast::StateSpace::Global, caching: 
cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t, non_coherent: false @@ -779,7 +800,7 @@ InstLd: ast::Instruction> = { ast::Instruction::Ld( ast::LdDetails { qualifier: ast::LdStQualifier::Weak, - state_space: ast::LdStateSpace::Global, + state_space: ast::StateSpace::Global, caching: cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t, non_coherent: true @@ -789,9 +810,9 @@ InstLd: ast::Instruction> = { } }; -LdStType: ast::LdStType = { - => ast::LdStType::Vector(t, v), - => ast::LdStType::Scalar(t), +LdStType: ast::Type = { + => ast::Type::Vector(t, v), + => ast::Type::Scalar(t), } LdStQualifier: ast::LdStQualifier = { @@ -807,11 +828,11 @@ MemScope: ast::MemScope = { ".sys" => ast::MemScope::Sys }; -LdNonGlobalStateSpace: ast::LdStateSpace = { - ".const" => ast::LdStateSpace::Const, - ".local" => ast::LdStateSpace::Local, - ".param" => ast::LdStateSpace::Param, - ".shared" => ast::LdStateSpace::Shared, +LdNonGlobalStateSpace: ast::StateSpace = { + ".const" => ast::StateSpace::Const, + ".local" => ast::StateSpace::Local, + ".param" => ast::StateSpace::Param, + ".shared" => ast::StateSpace::Shared, }; LdCacheOperator: ast::LdCacheOperator = { @@ -899,39 +920,39 @@ RoundingModeInt : ast::RoundingMode = { ".rpi" => ast::RoundingMode::PositiveInf, }; -IntType : ast::IntType = { - ".u16" => ast::IntType::U16, - ".u32" => ast::IntType::U32, - ".u64" => ast::IntType::U64, - ".s16" => ast::IntType::S16, - ".s32" => ast::IntType::S32, - ".s64" => ast::IntType::S64, +IntType : ast::ScalarType = { + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; -IntType3264: ast::IntType = { - ".u32" => ast::IntType::U32, - ".u64" => ast::IntType::U64, - ".s32" => ast::IntType::S32, - ".s64" => ast::IntType::S64, +IntType3264: ast::ScalarType = { + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s32" => 
ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, } -UIntType: ast::UIntType = { - ".u16" => ast::UIntType::U16, - ".u32" => ast::UIntType::U32, - ".u64" => ast::UIntType::U64, +UIntType: ast::ScalarType = { + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, }; -SIntType: ast::SIntType = { - ".s16" => ast::SIntType::S16, - ".s32" => ast::SIntType::S32, - ".s64" => ast::SIntType::S64, +SIntType: ast::ScalarType = { + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; -FloatType: ast::FloatType = { - ".f16" => ast::FloatType::F16, - ".f16x2" => ast::FloatType::F16x2, - ".f32" => ast::FloatType::F32, - ".f64" => ast::FloatType::F64, +FloatType: ast::ScalarType = { + ".f16" => ast::ScalarType::F16, + ".f16x2" => ast::ScalarType::F16x2, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add @@ -1023,11 +1044,11 @@ InstNot: ast::Instruction> = { "not" => ast::Instruction::Not(t, a) }; -BooleanType: ast::BooleanType = { - ".pred" => ast::BooleanType::Pred, - ".b16" => ast::BooleanType::B16, - ".b32" => ast::BooleanType::B32, - ".b64" => ast::BooleanType::B64, +BooleanType: ast::ScalarType = { + ".pred" => ast::ScalarType::Pred, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at @@ -1080,8 +1101,8 @@ InstCvt: ast::Instruction> = { rounding: r, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F16, - src: ast::FloatType::F16 + dst: ast::ScalarType::F16, + src: ast::ScalarType::F16 } ), a) }, @@ -1091,8 +1112,8 @@ InstCvt: ast::Instruction> = { rounding: None, flush_to_zero: Some(f.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F32, - src: ast::FloatType::F16 + dst: 
ast::ScalarType::F32, + src: ast::ScalarType::F16 } ), a) }, @@ -1102,8 +1123,8 @@ InstCvt: ast::Instruction> = { rounding: None, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F64, - src: ast::FloatType::F16 + dst: ast::ScalarType::F64, + src: ast::ScalarType::F16 } ), a) }, @@ -1113,8 +1134,8 @@ InstCvt: ast::Instruction> = { rounding: Some(r), flush_to_zero: Some(f.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F16, - src: ast::FloatType::F32 + dst: ast::ScalarType::F16, + src: ast::ScalarType::F32 } ), a) }, @@ -1124,8 +1145,8 @@ InstCvt: ast::Instruction> = { rounding: r, flush_to_zero: Some(f.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F32, - src: ast::FloatType::F32 + dst: ast::ScalarType::F32, + src: ast::ScalarType::F32 } ), a) }, @@ -1135,8 +1156,8 @@ InstCvt: ast::Instruction> = { rounding: None, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F64, - src: ast::FloatType::F32 + dst: ast::ScalarType::F64, + src: ast::ScalarType::F32 } ), a) }, @@ -1146,8 +1167,8 @@ InstCvt: ast::Instruction> = { rounding: Some(r), flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F16, - src: ast::FloatType::F64 + dst: ast::ScalarType::F16, + src: ast::ScalarType::F64 } ), a) }, @@ -1157,8 +1178,8 @@ InstCvt: ast::Instruction> = { rounding: Some(r), flush_to_zero: Some(s.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F32, - src: ast::FloatType::F64 + dst: ast::ScalarType::F32, + src: ast::ScalarType::F64 } ), a) }, @@ -1168,28 +1189,28 @@ InstCvt: ast::Instruction> = { rounding: r, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F64, - src: ast::FloatType::F64 + dst: ast::ScalarType::F64, + src: ast::ScalarType::F64 } ), a) }, }; -CvtTypeInt: ast::IntType = { - ".u8" => ast::IntType::U8, - ".u16" => ast::IntType::U16, - ".u32" => ast::IntType::U32, - ".u64" => ast::IntType::U64, - ".s8" => ast::IntType::S8, - ".s16" => ast::IntType::S16, - ".s32" => 
ast::IntType::S32, - ".s64" => ast::IntType::S64, +CvtTypeInt: ast::ScalarType = { + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; -CvtTypeFloat: ast::FloatType = { - ".f16" => ast::FloatType::F16, - ".f32" => ast::FloatType::F32, - ".f64" => ast::FloatType::F64, +CvtTypeFloat: ast::ScalarType = { + ".f16" => ast::ScalarType::F16, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl @@ -1197,10 +1218,10 @@ InstShl: ast::Instruction> = { "shl" => ast::Instruction::Shl(t, a) }; -ShlType: ast::ShlType = { - ".b16" => ast::ShlType::B16, - ".b32" => ast::ShlType::B32, - ".b64" => ast::ShlType::B64, +ShlType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr @@ -1208,16 +1229,16 @@ InstShr: ast::Instruction> = { "shr" => ast::Instruction::Shr(t, a) }; -ShrType: ast::ShrType = { - ".b16" => ast::ShrType::B16, - ".b32" => ast::ShrType::B32, - ".b64" => ast::ShrType::B64, - ".u16" => ast::ShrType::U16, - ".u32" => ast::ShrType::U32, - ".u64" => ast::ShrType::U64, - ".s16" => ast::ShrType::S16, - ".s32" => ast::ShrType::S32, - ".s64" => ast::ShrType::S64, +ShrType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; // 
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st @@ -1227,7 +1248,7 @@ InstSt: ast::Instruction> = { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ss.unwrap_or(ast::StStateSpace::Generic), + state_space: ss.unwrap_or(ast::StateSpace::Generic), caching: cop.unwrap_or(ast::StCacheOperator::Writeback), typ: t }, @@ -1241,11 +1262,11 @@ MemoryOperand: ast::Operand<&'input str> = { "[" "]" => o } -StStateSpace: ast::StStateSpace = { - ".global" => ast::StStateSpace::Global, - ".local" => ast::StStateSpace::Local, - ".param" => ast::StStateSpace::Param, - ".shared" => ast::StStateSpace::Shared, +StStateSpace: ast::StateSpace = { + ".global" => ast::StateSpace::Global, + ".local" => ast::StateSpace::Local, + ".param" => ast::StateSpace::Param, + ".shared" => ast::StateSpace::Shared, }; StCacheOperator: ast::StCacheOperator = { @@ -1264,7 +1285,7 @@ InstRet: ast::Instruction> = { InstCvta: ast::Instruction> = { "cvta" => { ast::Instruction::Cvta(ast::CvtaDetails { - to: ast::CvtaStateSpace::Generic, + to: ast::StateSpace::Generic, from, size: s }, @@ -1273,18 +1294,18 @@ InstCvta: ast::Instruction> = { "cvta" ".to" => { ast::Instruction::Cvta(ast::CvtaDetails { to, - from: ast::CvtaStateSpace::Generic, + from: ast::StateSpace::Generic, size: s }, a) } } -CvtaStateSpace: ast::CvtaStateSpace = { - ".const" => ast::CvtaStateSpace::Const, - ".global" => ast::CvtaStateSpace::Global, - ".local" => ast::CvtaStateSpace::Local, - ".shared" => ast::CvtaStateSpace::Shared, +CvtaStateSpace: ast::StateSpace = { + ".const" => ast::StateSpace::Const, + ".global" => ast::StateSpace::Global, + ".local" => ast::StateSpace::Local, + ".shared" => ast::StateSpace::Shared, } CvtaSize: ast::CvtaSize = { @@ -1393,16 +1414,16 @@ MinMaxDetails: ast::MinMaxDetails = { => ast::MinMaxDetails::Unsigned(t), => ast::MinMaxDetails::Signed(t), ".f32" => ast::MinMaxDetails::Float( - 
ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F32 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F32 } ), ".f64" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::FloatType::F64 } + ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::ScalarType::F64 } ), ".f16" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16 } ), ".f16x2" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16x2 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16x2 } ) } @@ -1411,18 +1432,18 @@ InstSelp: ast::Instruction> = { "selp" => ast::Instruction::Selp(t, a), }; -SelpType: ast::SelpType = { - ".b16" => ast::SelpType::B16, - ".b32" => ast::SelpType::B32, - ".b64" => ast::SelpType::B64, - ".u16" => ast::SelpType::U16, - ".u32" => ast::SelpType::U32, - ".u64" => ast::SelpType::U64, - ".s16" => ast::SelpType::S16, - ".s32" => ast::SelpType::S32, - ".s64" => ast::SelpType::S64, - ".f32" => ast::SelpType::F32, - ".f64" => ast::SelpType::F64, +SelpType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar @@ -1442,7 +1463,7 @@ InstAtom: ast::Instruction> = { let details = 
ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Bit { op, typ } }; ast::Instruction::Atom(details,a) @@ -1451,10 +1472,10 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Inc, - typ: ast::UIntType::U32 + typ: ast::ScalarType::U32 } }; ast::Instruction::Atom(details,a) @@ -1463,10 +1484,10 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Dec, - typ: ast::UIntType::U32 + typ: ast::ScalarType::U32 } }; ast::Instruction::Atom(details,a) @@ -1476,7 +1497,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Float { op, typ } }; ast::Instruction::Atom(details,a) @@ -1485,7 +1506,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op, typ } }; ast::Instruction::Atom(details,a) @@ -1494,7 +1515,7 @@ InstAtom: 
ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Signed { op, typ } }; ast::Instruction::Atom(details,a) @@ -1506,7 +1527,7 @@ InstAtomCas: ast::Instruction> = { let details = ast::AtomCasDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), typ, }; ast::Instruction::AtomCas(details,a) @@ -1520,9 +1541,9 @@ AtomSemantics: ast::AtomSemantics = { ".acq_rel" => ast::AtomSemantics::AcquireRelease } -AtomSpace: ast::AtomSpace = { - ".global" => ast::AtomSpace::Global, - ".shared" => ast::AtomSpace::Shared +AtomSpace: ast::StateSpace = { + ".global" => ast::StateSpace::Global, + ".shared" => ast::StateSpace::Shared } AtomBitOp: ast::AtomBitOp = { @@ -1544,19 +1565,19 @@ AtomSIntOp: ast::AtomSIntOp = { ".max" => ast::AtomSIntOp::Max, } -BitType: ast::BitType = { - ".b32" => ast::BitType::B32, - ".b64" => ast::BitType::B64, +BitType: ast::ScalarType = { + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, } -UIntType3264: ast::UIntType = { - ".u32" => ast::UIntType::U32, - ".u64" => ast::UIntType::U64, +UIntType3264: ast::ScalarType = { + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, } -SIntType3264: ast::SIntType = { - ".s32" => ast::SIntType::S32, - ".s64" => ast::SIntType::S64, +SIntType3264: ast::ScalarType = { + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div @@ -1566,7 +1587,7 @@ InstDiv: ast::Instruction> = { "div" => ast::Instruction::Div(ast::DivDetails::Signed(t), a), "div" ".f32" => { let inner = ast::DivFloatDetails { 
- typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: Some(ftz.is_some()), kind }; @@ -1574,7 +1595,7 @@ InstDiv: ast::Instruction> = { }, "div" ".f64" => { let inner = ast::DivFloatDetails { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, flush_to_zero: None, kind: ast::DivFloatKind::Rounding(rnd) }; @@ -1592,7 +1613,7 @@ DivFloatKind: ast::DivFloatKind = { InstSqrt: ast::Instruction> = { "sqrt" ".approx" ".f32" => { let details = ast::SqrtDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: Some(ftz.is_some()), kind: ast::SqrtKind::Approx, }; @@ -1600,7 +1621,7 @@ InstSqrt: ast::Instruction> = { }, "sqrt" ".f32" => { let details = ast::SqrtDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: Some(ftz.is_some()), kind: ast::SqrtKind::Rounding(rnd), }; @@ -1608,7 +1629,7 @@ InstSqrt: ast::Instruction> = { }, "sqrt" ".f64" => { let details = ast::SqrtDetails { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, flush_to_zero: None, kind: ast::SqrtKind::Rounding(rnd), }; @@ -1621,14 +1642,14 @@ InstSqrt: ast::Instruction> = { InstRsqrt: ast::Instruction> = { "rsqrt" ".approx" ".f32" => { let details = ast::RsqrtDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: ftz.is_some(), }; ast::Instruction::Rsqrt(details, a) }, "rsqrt" ".approx" ".f64" => { let details = ast::RsqrtDetails { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, flush_to_zero: ftz.is_some(), }; ast::Instruction::Rsqrt(details, a) @@ -1739,7 +1760,7 @@ ArithDetails: ast::ArithDetails = { saturate: false, }), ".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S32, + typ: ast::ScalarType::S32, saturate: true, }), => ast::ArithDetails::Float(f) @@ -1747,25 +1768,25 @@ ArithDetails: ast::ArithDetails = { ArithFloat: ast::ArithFloat = { ".f32" => ast::ArithFloat { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, rounding: rn, flush_to_zero: 
Some(ftz.is_some()), saturate: sat.is_some(), }, ".f64" => ast::ArithFloat { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, rounding: rn, flush_to_zero: None, saturate: false, }, ".f16" => ast::ArithFloat { - typ: ast::FloatType::F16, + typ: ast::ScalarType::F16, rounding: rn.map(|_| ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".f16x2" => ast::ArithFloat { - typ: ast::FloatType::F16x2, + typ: ast::ScalarType::F16x2, rounding: rn.map(|_| ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), @@ -1774,25 +1795,25 @@ ArithFloat: ast::ArithFloat = { ArithFloatMustRound: ast::ArithFloat = { ".f32" => ast::ArithFloat { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, rounding: Some(rn), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".f64" => ast::ArithFloat { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, rounding: Some(rn), flush_to_zero: None, saturate: false, }, ".rn" ".f16" => ast::ArithFloat { - typ: ast::FloatType::F16, + typ: ast::ScalarType::F16, rounding: Some(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".rn" ".f16x2" => ast::ArithFloat { - typ: ast::FloatType::F16x2, + typ: ast::ScalarType::F16x2, rounding: Some(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), diff --git a/ptx/src/test/spirv_run/and.spvtxt b/ptx/src/test/spirv_run/and.spvtxt index a378602..f66639a 100644 --- a/ptx/src/test/spirv_run/and.spvtxt +++ b/ptx/src/test/spirv_run/and.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %34 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 
%15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %41 = OpBitcast %_ptr_Generic_uchar %24 + %42 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %41 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %42 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/atom_add.spvtxt b/ptx/src/test/spirv_run/atom_add.spvtxt index 3966da6..b4de00a 100644 --- a/ptx/src/test/spirv_run/atom_add.spvtxt +++ b/ptx/src/test/spirv_run/atom_add.spvtxt @@ -24,6 +24,7 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %_ptr_Workgroup_uint = OpTypePointer Workgroup %uint %uint_1 = OpConstant %uint 1 %uint_0 = OpConstant %uint 0 @@ -49,9 +50,11 @@ %13 = OpLoad %uint %29 Aligned 4 OpStore %7 %13 %16 = OpLoad %ulong %5 - %26 = OpIAdd %ulong %16 %ulong_4 - %30 = OpConvertUToPtr %_ptr_Generic_uint %26 - %15 = OpLoad %uint %30 Aligned 4 + %30 = OpConvertUToPtr %_ptr_Generic_uint %16 + %51 = OpBitcast %_ptr_Generic_uchar %30 + %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4 + %26 = OpBitcast %_ptr_Generic_uint %52 + %15 = OpLoad %uint %26 Aligned 4 OpStore %8 %15 %17 = OpLoad %uint %7 %31 = OpBitcast %_ptr_Workgroup_uint %4 @@ -69,8 +72,10 @@ OpStore %34 %22 Aligned 4 %23 = OpLoad %ulong %6 %24 = OpLoad %uint %8 - %28 = OpIAdd %ulong %23 %ulong_4_0 - %35 = OpConvertUToPtr %_ptr_Generic_uint %28 - OpStore %35 %24 Aligned 4 + %35 = OpConvertUToPtr %_ptr_Generic_uint %23 + %56 = OpBitcast %_ptr_Generic_uchar %35 + %57 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4_0 + %28 = OpBitcast %_ptr_Generic_uint %57 + OpStore %28 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_add_float.spvtxt 
b/ptx/src/test/spirv_run/atom_add_float.spvtxt index c2292f1..7d25632 100644 --- a/ptx/src/test/spirv_run/atom_add_float.spvtxt +++ b/ptx/src/test/spirv_run/atom_add_float.spvtxt @@ -28,6 +28,7 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_4_0 = OpConstant %ulong 4 %37 = OpFunction %float None %46 %39 = OpFunctionParameter %_ptr_Workgroup_float @@ -54,9 +55,11 @@ %13 = OpLoad %float %29 Aligned 4 OpStore %7 %13 %16 = OpLoad %ulong %5 - %26 = OpIAdd %ulong %16 %ulong_4 - %30 = OpConvertUToPtr %_ptr_Generic_float %26 - %15 = OpLoad %float %30 Aligned 4 + %30 = OpConvertUToPtr %_ptr_Generic_float %16 + %58 = OpBitcast %_ptr_Generic_uchar %30 + %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_4 + %26 = OpBitcast %_ptr_Generic_float %59 + %15 = OpLoad %float %26 Aligned 4 OpStore %8 %15 %17 = OpLoad %float %7 %31 = OpBitcast %_ptr_Workgroup_float %4 @@ -74,8 +77,10 @@ OpStore %34 %22 Aligned 4 %23 = OpLoad %ulong %6 %24 = OpLoad %float %8 - %28 = OpIAdd %ulong %23 %ulong_4_0 - %35 = OpConvertUToPtr %_ptr_Generic_float %28 - OpStore %35 %24 Aligned 4 + %35 = OpConvertUToPtr %_ptr_Generic_float %23 + %60 = OpBitcast %_ptr_Generic_uchar %35 + %61 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %60 %ulong_4_0 + %28 = OpBitcast %_ptr_Generic_float %61 + OpStore %28 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_cas.spvtxt b/ptx/src/test/spirv_run/atom_cas.spvtxt index e1feb0a..7c2f4fa 100644 --- a/ptx/src/test/spirv_run/atom_cas.spvtxt +++ b/ptx/src/test/spirv_run/atom_cas.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %uint_100 = OpConstant %uint 100 %uint_1 = OpConstant %uint 1 %uint_0 = OpConstant 
%uint 0 @@ -45,16 +47,20 @@ OpStore %6 %12 %15 = OpLoad %ulong %4 %16 = OpLoad %uint %6 - %24 = OpIAdd %ulong %15 %ulong_4 - %32 = OpConvertUToPtr %_ptr_Generic_uint %24 + %31 = OpConvertUToPtr %_ptr_Generic_uint %15 + %49 = OpBitcast %_ptr_Generic_uchar %31 + %50 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %49 %ulong_4 + %24 = OpBitcast %_ptr_Generic_uint %50 %33 = OpCopyObject %uint %16 - %31 = OpAtomicCompareExchange %uint %32 %uint_1 %uint_0 %uint_0 %uint_100 %33 - %14 = OpCopyObject %uint %31 + %32 = OpAtomicCompareExchange %uint %24 %uint_1 %uint_0 %uint_0 %uint_100 %33 + %14 = OpCopyObject %uint %32 OpStore %6 %14 %18 = OpLoad %ulong %4 - %27 = OpIAdd %ulong %18 %ulong_4_0 - %34 = OpConvertUToPtr %_ptr_Generic_uint %27 - %17 = OpLoad %uint %34 Aligned 4 + %34 = OpConvertUToPtr %_ptr_Generic_uint %18 + %53 = OpBitcast %_ptr_Generic_uchar %34 + %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_4_0 + %27 = OpBitcast %_ptr_Generic_uint %54 + %17 = OpLoad %uint %27 Aligned 4 OpStore %7 %17 %19 = OpLoad %ulong %5 %20 = OpLoad %uint %6 @@ -62,8 +68,10 @@ OpStore %35 %20 Aligned 4 %21 = OpLoad %ulong %5 %22 = OpLoad %uint %7 - %29 = OpIAdd %ulong %21 %ulong_4_1 - %36 = OpConvertUToPtr %_ptr_Generic_uint %29 - OpStore %36 %22 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_uint %21 + %55 = OpBitcast %_ptr_Generic_uchar %36 + %56 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %55 %ulong_4_1 + %29 = OpBitcast %_ptr_Generic_uint %56 + OpStore %29 %22 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_inc.spvtxt b/ptx/src/test/spirv_run/atom_inc.spvtxt index 11b4243..4855cd4 100644 --- a/ptx/src/test/spirv_run/atom_inc.spvtxt +++ b/ptx/src/test/spirv_run/atom_inc.spvtxt @@ -10,14 +10,14 @@ %47 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "atom_inc" - OpDecorate %42 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_global_inc" Import OpDecorate %38 LinkageAttributes 
"__zluda_ptx_impl__atom_relaxed_gpu_generic_inc" Import + OpDecorate %42 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_global_inc" Import %void = OpTypeVoid %uint = OpTypeInt 32 0 -%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint - %51 = OpTypeFunction %uint %_ptr_CrossWorkgroup_uint %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %53 = OpTypeFunction %uint %_ptr_Generic_uint %uint + %51 = OpTypeFunction %uint %_ptr_Generic_uint %uint +%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint + %53 = OpTypeFunction %uint %_ptr_CrossWorkgroup_uint %uint %ulong = OpTypeInt 64 0 %55 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong @@ -25,15 +25,17 @@ %uint_101 = OpConstant %uint 101 %uint_101_0 = OpConstant %uint 101 %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 - %42 = OpFunction %uint None %51 - %44 = OpFunctionParameter %_ptr_CrossWorkgroup_uint - %45 = OpFunctionParameter %uint - OpFunctionEnd - %38 = OpFunction %uint None %53 + %38 = OpFunction %uint None %51 %40 = OpFunctionParameter %_ptr_Generic_uint %41 = OpFunctionParameter %uint OpFunctionEnd + %42 = OpFunction %uint None %53 + %44 = OpFunctionParameter %_ptr_CrossWorkgroup_uint + %45 = OpFunctionParameter %uint + OpFunctionEnd %1 = OpFunction %void None %55 %9 = OpFunctionParameter %ulong %10 = OpFunctionParameter %ulong @@ -69,13 +71,17 @@ OpStore %34 %20 Aligned 4 %21 = OpLoad %ulong %5 %22 = OpLoad %uint %7 - %28 = OpIAdd %ulong %21 %ulong_4 - %35 = OpConvertUToPtr %_ptr_Generic_uint %28 - OpStore %35 %22 Aligned 4 + %35 = OpConvertUToPtr %_ptr_Generic_uint %21 + %60 = OpBitcast %_ptr_Generic_uchar %35 + %61 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %60 %ulong_4 + %28 = OpBitcast %_ptr_Generic_uint %61 + OpStore %28 %22 Aligned 4 %23 = OpLoad %ulong %5 %24 = OpLoad %uint %8 - %30 = OpIAdd %ulong %23 %ulong_8 - %36 = OpConvertUToPtr 
%_ptr_Generic_uint %30 - OpStore %36 %24 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_uint %23 + %62 = OpBitcast %_ptr_Generic_uchar %36 + %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_8 + %30 = OpBitcast %_ptr_Generic_uint %63 + OpStore %30 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/bfe.spvtxt b/ptx/src/test/spirv_run/bfe.spvtxt index 535ede9..0001808 100644 --- a/ptx/src/test/spirv_run/bfe.spvtxt +++ b/ptx/src/test/spirv_run/bfe.spvtxt @@ -20,6 +20,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %34 = OpFunction %uint None %43 %36 = OpFunctionParameter %uint @@ -48,14 +50,18 @@ %13 = OpLoad %uint %29 Aligned 4 OpStore %6 %13 %16 = OpLoad %ulong %4 - %26 = OpIAdd %ulong %16 %ulong_4 - %30 = OpConvertUToPtr %_ptr_Generic_uint %26 - %15 = OpLoad %uint %30 Aligned 4 + %30 = OpConvertUToPtr %_ptr_Generic_uint %16 + %51 = OpBitcast %_ptr_Generic_uchar %30 + %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4 + %26 = OpBitcast %_ptr_Generic_uint %52 + %15 = OpLoad %uint %26 Aligned 4 OpStore %7 %15 %18 = OpLoad %ulong %4 - %28 = OpIAdd %ulong %18 %ulong_8 - %31 = OpConvertUToPtr %_ptr_Generic_uint %28 - %17 = OpLoad %uint %31 Aligned 4 + %31 = OpConvertUToPtr %_ptr_Generic_uint %18 + %53 = OpBitcast %_ptr_Generic_uchar %31 + %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_8 + %28 = OpBitcast %_ptr_Generic_uint %54 + %17 = OpLoad %uint %28 Aligned 4 OpStore %8 %17 %20 = OpLoad %uint %6 %21 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/bfi.spvtxt b/ptx/src/test/spirv_run/bfi.spvtxt index a226f78..1979939 100644 --- a/ptx/src/test/spirv_run/bfi.spvtxt +++ b/ptx/src/test/spirv_run/bfi.spvtxt @@ -20,6 +20,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer 
Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %ulong_12 = OpConstant %ulong 12 %44 = OpFunction %uint None %54 @@ -51,19 +53,25 @@ %14 = OpLoad %uint %35 Aligned 4 OpStore %6 %14 %17 = OpLoad %ulong %4 - %30 = OpIAdd %ulong %17 %ulong_4 - %36 = OpConvertUToPtr %_ptr_Generic_uint %30 - %16 = OpLoad %uint %36 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_uint %17 + %62 = OpBitcast %_ptr_Generic_uchar %36 + %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_4 + %30 = OpBitcast %_ptr_Generic_uint %63 + %16 = OpLoad %uint %30 Aligned 4 OpStore %7 %16 %19 = OpLoad %ulong %4 - %32 = OpIAdd %ulong %19 %ulong_8 - %37 = OpConvertUToPtr %_ptr_Generic_uint %32 - %18 = OpLoad %uint %37 Aligned 4 + %37 = OpConvertUToPtr %_ptr_Generic_uint %19 + %64 = OpBitcast %_ptr_Generic_uchar %37 + %65 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %64 %ulong_8 + %32 = OpBitcast %_ptr_Generic_uint %65 + %18 = OpLoad %uint %32 Aligned 4 OpStore %8 %18 %21 = OpLoad %ulong %4 - %34 = OpIAdd %ulong %21 %ulong_12 - %38 = OpConvertUToPtr %_ptr_Generic_uint %34 - %20 = OpLoad %uint %38 Aligned 4 + %38 = OpConvertUToPtr %_ptr_Generic_uint %21 + %66 = OpBitcast %_ptr_Generic_uchar %38 + %67 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %66 %ulong_12 + %34 = OpBitcast %_ptr_Generic_uint %67 + %20 = OpLoad %uint %34 Aligned 4 OpStore %9 %20 %23 = OpLoad %uint %6 %24 = OpLoad %uint %7 @@ -71,7 +79,7 @@ %26 = OpLoad %uint %9 %40 = OpCopyObject %uint %23 %41 = OpCopyObject %uint %24 - %39 = OpFunctionCall %uint %44 %41 %40 %25 %26 + %39 = OpFunctionCall %uint %44 %40 %41 %25 %26 %22 = OpCopyObject %uint %39 OpStore %6 %22 %27 = OpLoad %ulong %5 diff --git a/ptx/src/test/spirv_run/call.spvtxt b/ptx/src/test/spirv_run/call.spvtxt index 5473234..6929b1e 100644 --- a/ptx/src/test/spirv_run/call.spvtxt +++ b/ptx/src/test/spirv_run/call.spvtxt @@ -42,7 +42,7 @@ %23 = OpBitcast 
%_ptr_Function_ulong %10 %24 = OpCopyObject %ulong %18 OpStore %23 %24 Aligned 8 - %43 = OpFunctionCall %void %1 %11 %10 + %43 = OpFunctionCall %void %1 %10 %11 %19 = OpLoad %ulong %11 Aligned 8 OpStore %9 %19 %20 = OpLoad %ulong %8 @@ -52,8 +52,8 @@ OpReturn OpFunctionEnd %1 = OpFunction %void None %44 - %27 = OpFunctionParameter %_ptr_Function_ulong %28 = OpFunctionParameter %_ptr_Function_ulong + %27 = OpFunctionParameter %_ptr_Function_ulong %35 = OpLabel %29 = OpVariable %_ptr_Function_ulong Function %30 = OpLoad %ulong %28 Aligned 8 diff --git a/ptx/src/test/spirv_run/cvt_rni.spvtxt b/ptx/src/test/spirv_run/cvt_rni.spvtxt index 288a939..e10999c 100644 --- a/ptx/src/test/spirv_run/cvt_rni.spvtxt +++ b/ptx/src/test/spirv_run/cvt_rni.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_4_0 = OpConstant %ulong 4 %1 = OpFunction %void None %37 %8 = OpFunctionParameter %ulong @@ -40,9 +42,11 @@ %12 = OpLoad %float %28 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %25 = OpIAdd %ulong %15 %ulong_4 - %29 = OpConvertUToPtr %_ptr_Generic_float %25 - %14 = OpLoad %float %29 Aligned 4 + %29 = OpConvertUToPtr %_ptr_Generic_float %15 + %44 = OpBitcast %_ptr_Generic_uchar %29 + %45 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %44 %ulong_4 + %25 = OpBitcast %_ptr_Generic_float %45 + %14 = OpLoad %float %25 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %16 = OpExtInst %float %34 rint %17 @@ -56,8 +60,10 @@ OpStore %30 %21 Aligned 4 %22 = OpLoad %ulong %5 %23 = OpLoad %float %7 - %27 = OpIAdd %ulong %22 %ulong_4_0 - %31 = OpConvertUToPtr %_ptr_Generic_float %27 - OpStore %31 %23 Aligned 4 + %31 = OpConvertUToPtr %_ptr_Generic_float %22 + %46 = OpBitcast %_ptr_Generic_uchar %31 + %47 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %46 %ulong_4_0 + %27 = OpBitcast %_ptr_Generic_float 
%47 + OpStore %27 %23 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_rzi.spvtxt b/ptx/src/test/spirv_run/cvt_rzi.spvtxt index 68c12c6..7dda454 100644 --- a/ptx/src/test/spirv_run/cvt_rzi.spvtxt +++ b/ptx/src/test/spirv_run/cvt_rzi.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_4_0 = OpConstant %ulong 4 %1 = OpFunction %void None %37 %8 = OpFunctionParameter %ulong @@ -40,9 +42,11 @@ %12 = OpLoad %float %28 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %25 = OpIAdd %ulong %15 %ulong_4 - %29 = OpConvertUToPtr %_ptr_Generic_float %25 - %14 = OpLoad %float %29 Aligned 4 + %29 = OpConvertUToPtr %_ptr_Generic_float %15 + %44 = OpBitcast %_ptr_Generic_uchar %29 + %45 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %44 %ulong_4 + %25 = OpBitcast %_ptr_Generic_float %45 + %14 = OpLoad %float %25 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %16 = OpExtInst %float %34 trunc %17 @@ -56,8 +60,10 @@ OpStore %30 %21 Aligned 4 %22 = OpLoad %ulong %5 %23 = OpLoad %float %7 - %27 = OpIAdd %ulong %22 %ulong_4_0 - %31 = OpConvertUToPtr %_ptr_Generic_float %27 - OpStore %31 %23 Aligned 4 + %31 = OpConvertUToPtr %_ptr_Generic_float %22 + %46 = OpBitcast %_ptr_Generic_uchar %31 + %47 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %46 %ulong_4_0 + %27 = OpBitcast %_ptr_Generic_float %47 + OpStore %27 %23 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt b/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt index d9ae053..c1229d4 100644 --- a/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt +++ b/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt @@ -21,8 +21,11 @@ %float = OpTypeFloat 32 %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar 
%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint %ulong_4_0 = OpConstant %ulong 4 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar %1 = OpFunction %void None %45 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -45,10 +48,12 @@ %12 = OpBitcast %uint %28 OpStore %6 %12 %15 = OpLoad %ulong %4 - %25 = OpIAdd %ulong %15 %ulong_4 - %31 = OpConvertUToPtr %_ptr_Generic_float %25 - %30 = OpLoad %float %31 Aligned 4 - %14 = OpBitcast %uint %30 + %30 = OpConvertUToPtr %_ptr_Generic_float %15 + %53 = OpBitcast %_ptr_Generic_uchar %30 + %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_4 + %25 = OpBitcast %_ptr_Generic_float %54 + %31 = OpLoad %float %25 Aligned 4 + %14 = OpBitcast %uint %31 OpStore %7 %14 %17 = OpLoad %uint %6 %33 = OpBitcast %float %17 @@ -67,9 +72,11 @@ OpStore %36 %37 Aligned 4 %22 = OpLoad %ulong %5 %23 = OpLoad %uint %7 - %27 = OpIAdd %ulong %22 %ulong_4_0 - %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %27 + %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %22 + %57 = OpBitcast %_ptr_CrossWorkgroup_uchar %38 + %58 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %57 %ulong_4_0 + %27 = OpBitcast %_ptr_CrossWorkgroup_uint %58 %39 = OpCopyObject %uint %23 - OpStore %38 %39 Aligned 4 + OpStore %27 %39 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/div_approx.spvtxt b/ptx/src/test/spirv_run/div_approx.spvtxt index 274f73e..858ec8d 100644 --- a/ptx/src/test/spirv_run/div_approx.spvtxt +++ b/ptx/src/test/spirv_run/div_approx.spvtxt @@ -19,6 +19,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -40,9 +42,11 @@ %12 = OpLoad %float %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 
%ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_float %22 - %14 = OpLoad %float %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_float %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_float %39 + %14 = OpLoad %float %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %18 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/extern_shared.spvtxt b/ptx/src/test/spirv_run/extern_shared.spvtxt index fb2987e..13587d5 100644 --- a/ptx/src/test/spirv_run/extern_shared.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared.spvtxt @@ -7,37 +7,30 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %30 = OpExtInstImport "OpenCL.std" + %27 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %2 "extern_shared" %1 %void = OpTypeVoid %uint = OpTypeInt 32 0 %_ptr_Workgroup_uint = OpTypePointer Workgroup %uint -%_ptr_Workgroup__ptr_Workgroup_uint = OpTypePointer Workgroup %_ptr_Workgroup_uint - %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup + %1 = OpVariable %_ptr_Workgroup_uint Workgroup %ulong = OpTypeInt 64 0 %uchar = OpTypeInt 8 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %38 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar -%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar + %34 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong -%_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %38 + %2 = OpFunction %void None %34 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong - %26 = OpFunctionParameter %_ptr_Workgroup_uchar - %39 = OpLabel - %27 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %24 = OpFunctionParameter 
%_ptr_Workgroup_uchar + %22 = OpLabel %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function %5 = OpVariable %_ptr_Function_ulong Function %6 = OpVariable %_ptr_Function_ulong Function %7 = OpVariable %_ptr_Function_ulong Function - OpStore %27 %26 - OpBranch %24 - %24 = OpLabel OpStore %3 %8 OpStore %4 %9 %10 = OpLoad %ulong %3 Aligned 8 @@ -45,22 +38,20 @@ %11 = OpLoad %ulong %4 Aligned 8 OpStore %6 %11 %13 = OpLoad %ulong %5 - %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %13 - %12 = OpLoad %ulong %20 Aligned 8 + %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %13 + %12 = OpLoad %ulong %18 Aligned 8 OpStore %7 %12 - %28 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %27 - %14 = OpLoad %_ptr_Workgroup_uint %28 - %15 = OpLoad %ulong %7 - %21 = OpBitcast %_ptr_Workgroup_ulong %14 - OpStore %21 %15 Aligned 8 - %29 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %27 - %17 = OpLoad %_ptr_Workgroup_uint %29 - %22 = OpBitcast %_ptr_Workgroup_ulong %17 - %16 = OpLoad %ulong %22 Aligned 8 - OpStore %7 %16 - %18 = OpLoad %ulong %6 - %19 = OpLoad %ulong %7 - %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %18 - OpStore %23 %19 Aligned 8 + %14 = OpLoad %ulong %7 + %25 = OpBitcast %_ptr_Workgroup_uint %24 + %19 = OpBitcast %_ptr_Workgroup_ulong %25 + OpStore %19 %14 Aligned 8 + %26 = OpBitcast %_ptr_Workgroup_uint %24 + %20 = OpBitcast %_ptr_Workgroup_ulong %26 + %15 = OpLoad %ulong %20 Aligned 8 + OpStore %7 %15 + %16 = OpLoad %ulong %6 + %17 = OpLoad %ulong %7 + %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 + OpStore %21 %17 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt index 7043172..5af7168 100644 --- a/ptx/src/test/spirv_run/extern_shared_call.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared_call.spvtxt @@ -7,87 +7,72 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %46 = OpExtInstImport "OpenCL.std" + %40 = 
OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %14 "extern_shared_call" %1 + OpEntryPoint Kernel %12 "extern_shared_call" %1 OpDecorate %1 Alignment 4 %void = OpTypeVoid %uint = OpTypeInt 32 0 %_ptr_Workgroup_uint = OpTypePointer Workgroup %uint -%_ptr_Workgroup__ptr_Workgroup_uint = OpTypePointer Workgroup %_ptr_Workgroup_uint - %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup + %1 = OpVariable %_ptr_Workgroup_uint Workgroup %uchar = OpTypeInt 8 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %53 = OpTypeFunction %void %_ptr_Workgroup_uchar -%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar + %46 = OpTypeFunction %void %_ptr_Workgroup_uchar %ulong = OpTypeInt 64 0 %_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong %ulong_2 = OpConstant %ulong 2 - %60 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar + %50 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %2 = OpFunction %void None %53 - %38 = OpFunctionParameter %_ptr_Workgroup_uchar - %54 = OpLabel - %39 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %2 = OpFunction %void None %46 + %34 = OpFunctionParameter %_ptr_Workgroup_uchar + %11 = OpLabel %3 = OpVariable %_ptr_Function_ulong Function - OpStore %39 %38 - OpBranch %13 - %13 = OpLabel - %40 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39 - %5 = OpLoad %_ptr_Workgroup_uint %40 - %11 = OpBitcast %_ptr_Workgroup_ulong %5 - %4 = OpLoad %ulong %11 Aligned 8 + %35 = OpBitcast %_ptr_Workgroup_uint %34 + %9 = OpBitcast %_ptr_Workgroup_ulong %35 + %4 = OpLoad %ulong %9 Aligned 8 OpStore %3 %4 + %6 = OpLoad %ulong %3 + %5 = OpIAdd %ulong %6 %ulong_2 + OpStore %3 %5 %7 = OpLoad %ulong %3 - %6 = OpIAdd %ulong %7 %ulong_2 - OpStore %3 %6 - %41 = 
OpBitcast %_ptr_Function__ptr_Workgroup_uint %39 - %8 = OpLoad %_ptr_Workgroup_uint %41 - %9 = OpLoad %ulong %3 - %12 = OpBitcast %_ptr_Workgroup_ulong %8 - OpStore %12 %9 Aligned 8 + %36 = OpBitcast %_ptr_Workgroup_uint %34 + %10 = OpBitcast %_ptr_Workgroup_ulong %36 + OpStore %10 %7 Aligned 8 OpReturn OpFunctionEnd - %14 = OpFunction %void None %60 - %20 = OpFunctionParameter %ulong - %21 = OpFunctionParameter %ulong - %42 = OpFunctionParameter %_ptr_Workgroup_uchar - %61 = OpLabel - %43 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %12 = OpFunction %void None %50 + %18 = OpFunctionParameter %ulong + %19 = OpFunctionParameter %ulong + %37 = OpFunctionParameter %_ptr_Workgroup_uchar + %32 = OpLabel + %13 = OpVariable %_ptr_Function_ulong Function + %14 = OpVariable %_ptr_Function_ulong Function %15 = OpVariable %_ptr_Function_ulong Function %16 = OpVariable %_ptr_Function_ulong Function %17 = OpVariable %_ptr_Function_ulong Function - %18 = OpVariable %_ptr_Function_ulong Function - %19 = OpVariable %_ptr_Function_ulong Function - OpStore %43 %42 - OpBranch %36 - %36 = OpLabel + OpStore %13 %18 + OpStore %14 %19 + %20 = OpLoad %ulong %13 Aligned 8 OpStore %15 %20 + %21 = OpLoad %ulong %14 Aligned 8 OpStore %16 %21 - %22 = OpLoad %ulong %15 Aligned 8 + %23 = OpLoad %ulong %15 + %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %23 + %22 = OpLoad %ulong %28 Aligned 8 OpStore %17 %22 - %23 = OpLoad %ulong %16 Aligned 8 - OpStore %18 %23 - %25 = OpLoad %ulong %17 - %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %25 - %24 = OpLoad %ulong %32 Aligned 8 - OpStore %19 %24 - %44 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43 - %26 = OpLoad %_ptr_Workgroup_uint %44 - %27 = OpLoad %ulong %19 - %33 = OpBitcast %_ptr_Workgroup_ulong %26 - OpStore %33 %27 Aligned 8 - %63 = OpFunctionCall %void %2 %42 - %45 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43 - %29 = OpLoad %_ptr_Workgroup_uint %45 - %34 = OpBitcast %_ptr_Workgroup_ulong %29 - %28 = OpLoad 
%ulong %34 Aligned 8 - OpStore %19 %28 - %30 = OpLoad %ulong %18 - %31 = OpLoad %ulong %19 - %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %30 - OpStore %35 %31 Aligned 8 + %24 = OpLoad %ulong %17 + %38 = OpBitcast %_ptr_Workgroup_uint %37 + %29 = OpBitcast %_ptr_Workgroup_ulong %38 + OpStore %29 %24 Aligned 8 + %52 = OpFunctionCall %void %2 %37 + %39 = OpBitcast %_ptr_Workgroup_uint %37 + %30 = OpBitcast %_ptr_Workgroup_ulong %39 + %25 = OpLoad %ulong %30 Aligned 8 + OpStore %17 %25 + %26 = OpLoad %ulong %16 + %27 = OpLoad %ulong %17 + %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26 + OpStore %31 %27 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/fma.spvtxt b/ptx/src/test/spirv_run/fma.spvtxt index 300a328..8cc0e16 100644 --- a/ptx/src/test/spirv_run/fma.spvtxt +++ b/ptx/src/test/spirv_run/fma.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %1 = OpFunction %void None %38 %9 = OpFunctionParameter %ulong @@ -41,14 +43,18 @@ %13 = OpLoad %float %29 Aligned 4 OpStore %6 %13 %16 = OpLoad %ulong %4 - %26 = OpIAdd %ulong %16 %ulong_4 - %30 = OpConvertUToPtr %_ptr_Generic_float %26 - %15 = OpLoad %float %30 Aligned 4 + %30 = OpConvertUToPtr %_ptr_Generic_float %16 + %45 = OpBitcast %_ptr_Generic_uchar %30 + %46 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %45 %ulong_4 + %26 = OpBitcast %_ptr_Generic_float %46 + %15 = OpLoad %float %26 Aligned 4 OpStore %7 %15 %18 = OpLoad %ulong %4 - %28 = OpIAdd %ulong %18 %ulong_8 - %31 = OpConvertUToPtr %_ptr_Generic_float %28 - %17 = OpLoad %float %31 Aligned 4 + %31 = OpConvertUToPtr %_ptr_Generic_float %18 + %47 = OpBitcast %_ptr_Generic_uchar %31 + %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_8 + %28 = OpBitcast %_ptr_Generic_float %48 + %17 = OpLoad %float %28 
Aligned 4 OpStore %8 %17 %20 = OpLoad %float %6 %21 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/ld_st_offset.spvtxt b/ptx/src/test/spirv_run/ld_st_offset.spvtxt index 5e314a0..ea97222 100644 --- a/ptx/src/test/spirv_run/ld_st_offset.spvtxt +++ b/ptx/src/test/spirv_run/ld_st_offset.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_4_0 = OpConstant %ulong 4 %1 = OpFunction %void None %33 %8 = OpFunctionParameter %ulong @@ -40,9 +42,11 @@ %12 = OpLoad %uint %24 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %21 = OpIAdd %ulong %15 %ulong_4 - %25 = OpConvertUToPtr %_ptr_Generic_uint %21 - %14 = OpLoad %uint %25 Aligned 4 + %25 = OpConvertUToPtr %_ptr_Generic_uint %15 + %40 = OpBitcast %_ptr_Generic_uchar %25 + %41 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %40 %ulong_4 + %21 = OpBitcast %_ptr_Generic_uint %41 + %14 = OpLoad %uint %21 Aligned 4 OpStore %7 %14 %16 = OpLoad %ulong %5 %17 = OpLoad %uint %7 @@ -50,8 +54,10 @@ OpStore %26 %17 Aligned 4 %18 = OpLoad %ulong %5 %19 = OpLoad %uint %6 - %23 = OpIAdd %ulong %18 %ulong_4_0 - %27 = OpConvertUToPtr %_ptr_Generic_uint %23 - OpStore %27 %19 Aligned 4 + %27 = OpConvertUToPtr %_ptr_Generic_uint %18 + %42 = OpBitcast %_ptr_Generic_uchar %27 + %43 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %42 %ulong_4_0 + %23 = OpBitcast %_ptr_Generic_uint %43 + OpStore %23 %19 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mad_s32.spvtxt b/ptx/src/test/spirv_run/mad_s32.spvtxt index bb44af0..0ee3ca7 100644 --- a/ptx/src/test/spirv_run/mad_s32.spvtxt +++ b/ptx/src/test/spirv_run/mad_s32.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic 
%uchar %ulong_8 = OpConstant %ulong 8 %ulong_4_0 = OpConstant %ulong 4 %ulong_8_0 = OpConstant %ulong 8 @@ -44,20 +46,24 @@ %14 = OpLoad %uint %38 Aligned 4 OpStore %7 %14 %17 = OpLoad %ulong %4 - %31 = OpIAdd %ulong %17 %ulong_4 - %39 = OpConvertUToPtr %_ptr_Generic_uint %31 - %16 = OpLoad %uint %39 Aligned 4 + %39 = OpConvertUToPtr %_ptr_Generic_uint %17 + %56 = OpBitcast %_ptr_Generic_uchar %39 + %57 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4 + %31 = OpBitcast %_ptr_Generic_uint %57 + %16 = OpLoad %uint %31 Aligned 4 OpStore %8 %16 %19 = OpLoad %ulong %4 - %33 = OpIAdd %ulong %19 %ulong_8 - %40 = OpConvertUToPtr %_ptr_Generic_uint %33 - %18 = OpLoad %uint %40 Aligned 4 + %40 = OpConvertUToPtr %_ptr_Generic_uint %19 + %58 = OpBitcast %_ptr_Generic_uchar %40 + %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_8 + %33 = OpBitcast %_ptr_Generic_uint %59 + %18 = OpLoad %uint %33 Aligned 4 OpStore %9 %18 %21 = OpLoad %uint %7 %22 = OpLoad %uint %8 %23 = OpLoad %uint %9 - %54 = OpIMul %uint %21 %22 - %20 = OpIAdd %uint %23 %54 + %60 = OpIMul %uint %21 %22 + %20 = OpIAdd %uint %23 %60 OpStore %6 %20 %24 = OpLoad %ulong %5 %25 = OpLoad %uint %6 @@ -65,13 +71,17 @@ OpStore %41 %25 Aligned 4 %26 = OpLoad %ulong %5 %27 = OpLoad %uint %6 - %35 = OpIAdd %ulong %26 %ulong_4_0 - %42 = OpConvertUToPtr %_ptr_Generic_uint %35 - OpStore %42 %27 Aligned 4 + %42 = OpConvertUToPtr %_ptr_Generic_uint %26 + %61 = OpBitcast %_ptr_Generic_uchar %42 + %62 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %61 %ulong_4_0 + %35 = OpBitcast %_ptr_Generic_uint %62 + OpStore %35 %27 Aligned 4 %28 = OpLoad %ulong %5 %29 = OpLoad %uint %6 - %37 = OpIAdd %ulong %28 %ulong_8_0 - %43 = OpConvertUToPtr %_ptr_Generic_uint %37 - OpStore %43 %29 Aligned 4 + %43 = OpConvertUToPtr %_ptr_Generic_uint %28 + %63 = OpBitcast %_ptr_Generic_uchar %43 + %64 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %63 %ulong_8_0 + %37 = OpBitcast %_ptr_Generic_uint %64 + OpStore %37 %29 Aligned 4 
OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/max.spvtxt b/ptx/src/test/spirv_run/max.spvtxt index d3ffa2f..86b732a 100644 --- a/ptx/src/test/spirv_run/max.spvtxt +++ b/ptx/src/test/spirv_run/max.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %39 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/min.spvtxt b/ptx/src/test/spirv_run/min.spvtxt index de2e35e..a187376 100644 --- a/ptx/src/test/spirv_run/min.spvtxt +++ b/ptx/src/test/spirv_run/min.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %39 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 
%14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/mul_ftz.spvtxt b/ptx/src/test/spirv_run/mul_ftz.spvtxt index ed268fb..e7a4a56 100644 --- a/ptx/src/test/spirv_run/mul_ftz.spvtxt +++ b/ptx/src/test/spirv_run/mul_ftz.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %float %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_float %22 - %14 = OpLoad %float %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_float %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_float %39 + %14 = OpLoad %float %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %18 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/mul_non_ftz.spvtxt b/ptx/src/test/spirv_run/mul_non_ftz.spvtxt index 436aca1..5326baa 100644 --- a/ptx/src/test/spirv_run/mul_non_ftz.spvtxt +++ b/ptx/src/test/spirv_run/mul_non_ftz.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %float %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_float %22 - %14 = OpLoad %float %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_float %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 
%ulong_4 + %22 = OpBitcast %_ptr_Generic_float %39 + %14 = OpLoad %float %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %18 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/mul_wide.spvtxt b/ptx/src/test/spirv_run/mul_wide.spvtxt index 7ac81cf..e96a964 100644 --- a/ptx/src/test/spirv_run/mul_wide.spvtxt +++ b/ptx/src/test/spirv_run/mul_wide.spvtxt @@ -18,7 +18,9 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint %ulong_4 = OpConstant %ulong 4 - %_struct_38 = OpTypeStruct %uint %uint + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %_struct_42 = OpTypeStruct %uint %uint %v2uint = OpTypeVector %uint 2 %_ptr_Generic_ulong = OpTypePointer Generic %ulong %1 = OpFunction %void None %33 @@ -43,17 +45,19 @@ %13 = OpLoad %uint %24 Aligned 4 OpStore %6 %13 %16 = OpLoad %ulong %4 - %23 = OpIAdd %ulong %16 %ulong_4 - %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %23 - %15 = OpLoad %uint %25 Aligned 4 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %16 + %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %25 + %41 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %40 %ulong_4 + %23 = OpBitcast %_ptr_CrossWorkgroup_uint %41 + %15 = OpLoad %uint %23 Aligned 4 OpStore %7 %15 %18 = OpLoad %uint %6 %19 = OpLoad %uint %7 - %39 = OpSMulExtended %_struct_38 %18 %19 - %40 = OpCompositeExtract %uint %39 0 - %41 = OpCompositeExtract %uint %39 1 - %43 = OpCompositeConstruct %v2uint %40 %41 - %17 = OpBitcast %ulong %43 + %43 = OpSMulExtended %_struct_42 %18 %19 + %44 = OpCompositeExtract %uint %43 0 + %45 = OpCompositeExtract %uint %43 1 + %47 = OpCompositeConstruct %v2uint %44 %45 + %17 = OpBitcast %ulong %47 OpStore %8 %17 %20 = OpLoad %ulong %5 %21 = OpLoad %ulong %8 diff --git a/ptx/src/test/spirv_run/or.spvtxt b/ptx/src/test/spirv_run/or.spvtxt index fef3f40..82db00c 100644 --- a/ptx/src/test/spirv_run/or.spvtxt +++ b/ptx/src/test/spirv_run/or.spvtxt @@ -16,6 +16,8 
@@ %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_ulong = OpTypePointer Generic %ulong %ulong_8 = OpConstant %ulong 8 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %34 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -37,9 +39,11 @@ %12 = OpLoad %ulong %23 Aligned 8 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_8 - %24 = OpConvertUToPtr %_ptr_Generic_ulong %22 - %14 = OpLoad %ulong %24 Aligned 8 + %24 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %39 = OpBitcast %_ptr_Generic_uchar %24 + %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_8 + %22 = OpBitcast %_ptr_Generic_ulong %40 + %14 = OpLoad %ulong %22 Aligned 8 OpStore %7 %14 %17 = OpLoad %ulong %6 %18 = OpLoad %ulong %7 diff --git a/ptx/src/test/spirv_run/pred_not.spvtxt b/ptx/src/test/spirv_run/pred_not.spvtxt index 18fde05..644731b 100644 --- a/ptx/src/test/spirv_run/pred_not.spvtxt +++ b/ptx/src/test/spirv_run/pred_not.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_ulong = OpTypePointer Generic %ulong %ulong_8 = OpConstant %ulong 8 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %true = OpConstantTrue %bool %false = OpConstantFalse %bool %ulong_1 = OpConstant %ulong 1 @@ -45,9 +47,11 @@ %18 = OpLoad %ulong %37 Aligned 8 OpStore %6 %18 %21 = OpLoad %ulong %4 - %34 = OpIAdd %ulong %21 %ulong_8 - %38 = OpConvertUToPtr %_ptr_Generic_ulong %34 - %20 = OpLoad %ulong %38 Aligned 8 + %38 = OpConvertUToPtr %_ptr_Generic_ulong %21 + %52 = OpBitcast %_ptr_Generic_uchar %38 + %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_8 + %34 = OpBitcast %_ptr_Generic_ulong %53 + %20 = OpLoad %ulong %34 Aligned 8 OpStore %7 %20 %23 = OpLoad %ulong %6 %24 = OpLoad %ulong %7 diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt index 7bb5bd9..a0b957a 100644 --- 
a/ptx/src/test/spirv_run/reg_local.spvtxt +++ b/ptx/src/test/spirv_run/reg_local.spvtxt @@ -26,6 +26,7 @@ %ulong_0 = OpConstant %ulong 0 %_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_0_0 = OpConstant %ulong 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar %1 = OpFunction %void None %37 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -48,10 +49,10 @@ %12 = OpCopyObject %ulong %24 OpStore %7 %12 %14 = OpLoad %ulong %7 - %26 = OpCopyObject %ulong %14 - %19 = OpIAdd %ulong %26 %ulong_1 - %27 = OpBitcast %_ptr_Generic_ulong %4 - OpStore %27 %19 Aligned 8 + %19 = OpIAdd %ulong %14 %ulong_1 + %26 = OpBitcast %_ptr_Generic_ulong %4 + %27 = OpCopyObject %ulong %19 + OpStore %26 %27 Aligned 8 %28 = OpBitcast %_ptr_Generic_ulong %4 %47 = OpBitcast %_ptr_Generic_uchar %28 %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_0 @@ -61,9 +62,11 @@ OpStore %7 %15 %16 = OpLoad %ulong %6 %17 = OpLoad %ulong %7 - %23 = OpIAdd %ulong %16 %ulong_0_0 - %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %23 + %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 + %50 = OpBitcast %_ptr_CrossWorkgroup_uchar %30 + %51 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %50 %ulong_0_0 + %23 = OpBitcast %_ptr_CrossWorkgroup_ulong %51 %31 = OpCopyObject %ulong %17 - OpStore %30 %31 Aligned 8 + OpStore %23 %31 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/rem.spvtxt b/ptx/src/test/spirv_run/rem.spvtxt index ce1d3e6..2184523 100644 --- a/ptx/src/test/spirv_run/rem.spvtxt +++ b/ptx/src/test/spirv_run/rem.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - 
%22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %39 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/selp.spvtxt b/ptx/src/test/spirv_run/selp.spvtxt index 9798758..40c0bce 100644 --- a/ptx/src/test/spirv_run/selp.spvtxt +++ b/ptx/src/test/spirv_run/selp.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_ushort = OpTypePointer Function %ushort %_ptr_Generic_ushort = OpTypePointer Generic %ushort %ulong_2 = OpConstant %ulong 2 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %bool = OpTypeBool %false = OpConstantFalse %bool %1 = OpFunction %void None %32 @@ -41,9 +43,11 @@ %12 = OpLoad %ushort %24 Aligned 2 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_2 - %25 = OpConvertUToPtr %_ptr_Generic_ushort %22 - %14 = OpLoad %ushort %25 Aligned 2 + %25 = OpConvertUToPtr %_ptr_Generic_ushort %15 + %39 = OpBitcast %_ptr_Generic_uchar %25 + %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_2 + %22 = OpBitcast %_ptr_Generic_ushort %40 + %14 = OpLoad %ushort %22 Aligned 2 OpStore %7 %14 %17 = OpLoad %ushort %6 %18 = OpLoad %ushort %7 diff --git a/ptx/src/test/spirv_run/selp_true.spvtxt b/ptx/src/test/spirv_run/selp_true.spvtxt index f7038e0..81b3b5f 100644 --- a/ptx/src/test/spirv_run/selp_true.spvtxt +++ b/ptx/src/test/spirv_run/selp_true.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_ushort = OpTypePointer Function %ushort %_ptr_Generic_ushort = OpTypePointer Generic %ushort %ulong_2 = OpConstant %ulong 2 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %bool = OpTypeBool %true = OpConstantTrue %bool %1 = OpFunction %void None %32 @@ -41,9 +43,11 @@ %12 = OpLoad %ushort 
%24 Aligned 2 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_2 - %25 = OpConvertUToPtr %_ptr_Generic_ushort %22 - %14 = OpLoad %ushort %25 Aligned 2 + %25 = OpConvertUToPtr %_ptr_Generic_ushort %15 + %39 = OpBitcast %_ptr_Generic_uchar %25 + %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_2 + %22 = OpBitcast %_ptr_Generic_ushort %40 + %14 = OpLoad %ushort %22 Aligned 2 OpStore %7 %14 %17 = OpLoad %ushort %6 %18 = OpLoad %ushort %7 diff --git a/ptx/src/test/spirv_run/setp.spvtxt b/ptx/src/test/spirv_run/setp.spvtxt index c3129e3..5868881 100644 --- a/ptx/src/test/spirv_run/setp.spvtxt +++ b/ptx/src/test/spirv_run/setp.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_ulong = OpTypePointer Generic %ulong %ulong_8 = OpConstant %ulong 8 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_1 = OpConstant %ulong 1 %ulong_2 = OpConstant %ulong 2 %1 = OpFunction %void None %43 @@ -43,9 +45,11 @@ %18 = OpLoad %ulong %35 Aligned 8 OpStore %6 %18 %21 = OpLoad %ulong %4 - %32 = OpIAdd %ulong %21 %ulong_8 - %36 = OpConvertUToPtr %_ptr_Generic_ulong %32 - %20 = OpLoad %ulong %36 Aligned 8 + %36 = OpConvertUToPtr %_ptr_Generic_ulong %21 + %50 = OpBitcast %_ptr_Generic_uchar %36 + %51 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %50 %ulong_8 + %32 = OpBitcast %_ptr_Generic_ulong %51 + %20 = OpLoad %ulong %32 Aligned 8 OpStore %7 %20 %23 = OpLoad %ulong %6 %24 = OpLoad %ulong %7 diff --git a/ptx/src/test/spirv_run/setp_gt.spvtxt b/ptx/src/test/spirv_run/setp_gt.spvtxt index 77f6546..e9783f5 100644 --- a/ptx/src/test/spirv_run/setp_gt.spvtxt +++ b/ptx/src/test/spirv_run/setp_gt.spvtxt @@ -20,6 +20,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %43 %14 = OpFunctionParameter %ulong 
%15 = OpFunctionParameter %ulong @@ -43,9 +45,11 @@ %18 = OpLoad %float %35 Aligned 4 OpStore %6 %18 %21 = OpLoad %ulong %4 - %34 = OpIAdd %ulong %21 %ulong_4 - %36 = OpConvertUToPtr %_ptr_Generic_float %34 - %20 = OpLoad %float %36 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_float %21 + %52 = OpBitcast %_ptr_Generic_uchar %36 + %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_4 + %34 = OpBitcast %_ptr_Generic_float %53 + %20 = OpLoad %float %34 Aligned 4 OpStore %7 %20 %23 = OpLoad %float %6 %24 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/setp_leu.spvtxt b/ptx/src/test/spirv_run/setp_leu.spvtxt index f80880a..1d2d781 100644 --- a/ptx/src/test/spirv_run/setp_leu.spvtxt +++ b/ptx/src/test/spirv_run/setp_leu.spvtxt @@ -20,6 +20,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %43 %14 = OpFunctionParameter %ulong %15 = OpFunctionParameter %ulong @@ -43,9 +45,11 @@ %18 = OpLoad %float %35 Aligned 4 OpStore %6 %18 %21 = OpLoad %ulong %4 - %34 = OpIAdd %ulong %21 %ulong_4 - %36 = OpConvertUToPtr %_ptr_Generic_float %34 - %20 = OpLoad %float %36 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_float %21 + %52 = OpBitcast %_ptr_Generic_uchar %36 + %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_4 + %34 = OpBitcast %_ptr_Generic_float %53 + %20 = OpLoad %float %34 Aligned 4 OpStore %7 %20 %23 = OpLoad %float %6 %24 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/setp_nan.spvtxt b/ptx/src/test/spirv_run/setp_nan.spvtxt index 4a9fe11..2ee333a 100644 --- a/ptx/src/test/spirv_run/setp_nan.spvtxt +++ b/ptx/src/test/spirv_run/setp_nan.spvtxt @@ -22,6 +22,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = 
OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %ulong_12 = OpConstant %ulong 12 %ulong_16 = OpConstant %ulong 16 @@ -69,45 +71,59 @@ %36 = OpLoad %float %116 Aligned 4 OpStore %6 %36 %39 = OpLoad %ulong %4 - %89 = OpIAdd %ulong %39 %ulong_4 - %117 = OpConvertUToPtr %_ptr_Generic_float %89 - %38 = OpLoad %float %117 Aligned 4 + %117 = OpConvertUToPtr %_ptr_Generic_float %39 + %144 = OpBitcast %_ptr_Generic_uchar %117 + %145 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %144 %ulong_4 + %89 = OpBitcast %_ptr_Generic_float %145 + %38 = OpLoad %float %89 Aligned 4 OpStore %7 %38 %41 = OpLoad %ulong %4 - %91 = OpIAdd %ulong %41 %ulong_8 - %118 = OpConvertUToPtr %_ptr_Generic_float %91 - %40 = OpLoad %float %118 Aligned 4 + %118 = OpConvertUToPtr %_ptr_Generic_float %41 + %146 = OpBitcast %_ptr_Generic_uchar %118 + %147 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %146 %ulong_8 + %91 = OpBitcast %_ptr_Generic_float %147 + %40 = OpLoad %float %91 Aligned 4 OpStore %8 %40 %43 = OpLoad %ulong %4 - %93 = OpIAdd %ulong %43 %ulong_12 - %119 = OpConvertUToPtr %_ptr_Generic_float %93 - %42 = OpLoad %float %119 Aligned 4 + %119 = OpConvertUToPtr %_ptr_Generic_float %43 + %148 = OpBitcast %_ptr_Generic_uchar %119 + %149 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %148 %ulong_12 + %93 = OpBitcast %_ptr_Generic_float %149 + %42 = OpLoad %float %93 Aligned 4 OpStore %9 %42 %45 = OpLoad %ulong %4 - %95 = OpIAdd %ulong %45 %ulong_16 - %120 = OpConvertUToPtr %_ptr_Generic_float %95 - %44 = OpLoad %float %120 Aligned 4 + %120 = OpConvertUToPtr %_ptr_Generic_float %45 + %150 = OpBitcast %_ptr_Generic_uchar %120 + %151 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %150 %ulong_16 + %95 = OpBitcast %_ptr_Generic_float %151 + %44 = OpLoad %float %95 Aligned 4 OpStore %10 %44 %47 = OpLoad %ulong %4 - %97 = OpIAdd %ulong %47 %ulong_20 - %121 = OpConvertUToPtr %_ptr_Generic_float %97 - %46 = OpLoad %float %121 Aligned 4 + %121 = OpConvertUToPtr %_ptr_Generic_float %47 + %152 = 
OpBitcast %_ptr_Generic_uchar %121 + %153 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %152 %ulong_20 + %97 = OpBitcast %_ptr_Generic_float %153 + %46 = OpLoad %float %97 Aligned 4 OpStore %11 %46 %49 = OpLoad %ulong %4 - %99 = OpIAdd %ulong %49 %ulong_24 - %122 = OpConvertUToPtr %_ptr_Generic_float %99 - %48 = OpLoad %float %122 Aligned 4 + %122 = OpConvertUToPtr %_ptr_Generic_float %49 + %154 = OpBitcast %_ptr_Generic_uchar %122 + %155 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %154 %ulong_24 + %99 = OpBitcast %_ptr_Generic_float %155 + %48 = OpLoad %float %99 Aligned 4 OpStore %12 %48 %51 = OpLoad %ulong %4 - %101 = OpIAdd %ulong %51 %ulong_28 - %123 = OpConvertUToPtr %_ptr_Generic_float %101 - %50 = OpLoad %float %123 Aligned 4 + %123 = OpConvertUToPtr %_ptr_Generic_float %51 + %156 = OpBitcast %_ptr_Generic_uchar %123 + %157 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %156 %ulong_28 + %101 = OpBitcast %_ptr_Generic_float %157 + %50 = OpLoad %float %101 Aligned 4 OpStore %13 %50 %53 = OpLoad %float %6 %54 = OpLoad %float %7 - %142 = OpIsNan %bool %53 - %143 = OpIsNan %bool %54 - %52 = OpLogicalOr %bool %142 %143 + %158 = OpIsNan %bool %53 + %159 = OpIsNan %bool %54 + %52 = OpLogicalOr %bool %158 %159 OpStore %15 %52 %55 = OpLoad %bool %15 OpBranchConditional %55 %16 %17 @@ -129,9 +145,9 @@ OpStore %124 %60 Aligned 4 %62 = OpLoad %float %8 %63 = OpLoad %float %9 - %145 = OpIsNan %bool %62 - %146 = OpIsNan %bool %63 - %61 = OpLogicalOr %bool %145 %146 + %161 = OpIsNan %bool %62 + %162 = OpIsNan %bool %63 + %61 = OpLogicalOr %bool %161 %162 OpStore %15 %61 %64 = OpLoad %bool %15 OpBranchConditional %64 %20 %21 @@ -149,14 +165,16 @@ %23 = OpLabel %68 = OpLoad %ulong %5 %69 = OpLoad %uint %14 - %107 = OpIAdd %ulong %68 %ulong_4_0 - %125 = OpConvertUToPtr %_ptr_Generic_uint %107 - OpStore %125 %69 Aligned 4 + %125 = OpConvertUToPtr %_ptr_Generic_uint %68 + %163 = OpBitcast %_ptr_Generic_uchar %125 + %164 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %163 
%ulong_4_0 + %107 = OpBitcast %_ptr_Generic_uint %164 + OpStore %107 %69 Aligned 4 %71 = OpLoad %float %10 %72 = OpLoad %float %11 - %147 = OpIsNan %bool %71 - %148 = OpIsNan %bool %72 - %70 = OpLogicalOr %bool %147 %148 + %165 = OpIsNan %bool %71 + %166 = OpIsNan %bool %72 + %70 = OpLogicalOr %bool %165 %166 OpStore %15 %70 %73 = OpLoad %bool %15 OpBranchConditional %73 %24 %25 @@ -174,14 +192,16 @@ %27 = OpLabel %77 = OpLoad %ulong %5 %78 = OpLoad %uint %14 - %111 = OpIAdd %ulong %77 %ulong_8_0 - %126 = OpConvertUToPtr %_ptr_Generic_uint %111 - OpStore %126 %78 Aligned 4 + %126 = OpConvertUToPtr %_ptr_Generic_uint %77 + %167 = OpBitcast %_ptr_Generic_uchar %126 + %168 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %167 %ulong_8_0 + %111 = OpBitcast %_ptr_Generic_uint %168 + OpStore %111 %78 Aligned 4 %80 = OpLoad %float %12 %81 = OpLoad %float %13 - %149 = OpIsNan %bool %80 - %150 = OpIsNan %bool %81 - %79 = OpLogicalOr %bool %149 %150 + %169 = OpIsNan %bool %80 + %170 = OpIsNan %bool %81 + %79 = OpLogicalOr %bool %169 %170 OpStore %15 %79 %82 = OpLoad %bool %15 OpBranchConditional %82 %28 %29 @@ -199,8 +219,10 @@ %31 = OpLabel %86 = OpLoad %ulong %5 %87 = OpLoad %uint %14 - %115 = OpIAdd %ulong %86 %ulong_12_0 - %127 = OpConvertUToPtr %_ptr_Generic_uint %115 - OpStore %127 %87 Aligned 4 + %127 = OpConvertUToPtr %_ptr_Generic_uint %86 + %171 = OpBitcast %_ptr_Generic_uchar %127 + %172 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %171 %ulong_12_0 + %115 = OpBitcast %_ptr_Generic_uint %172 + OpStore %115 %87 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/setp_num.spvtxt b/ptx/src/test/spirv_run/setp_num.spvtxt index 3ac6eab..c576a50 100644 --- a/ptx/src/test/spirv_run/setp_num.spvtxt +++ b/ptx/src/test/spirv_run/setp_num.spvtxt @@ -22,6 +22,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer 
Generic %uchar %ulong_8 = OpConstant %ulong 8 %ulong_12 = OpConstant %ulong 12 %ulong_16 = OpConstant %ulong 16 @@ -77,46 +79,60 @@ %36 = OpLoad %float %116 Aligned 4 OpStore %6 %36 %39 = OpLoad %ulong %4 - %89 = OpIAdd %ulong %39 %ulong_4 - %117 = OpConvertUToPtr %_ptr_Generic_float %89 - %38 = OpLoad %float %117 Aligned 4 + %117 = OpConvertUToPtr %_ptr_Generic_float %39 + %144 = OpBitcast %_ptr_Generic_uchar %117 + %145 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %144 %ulong_4 + %89 = OpBitcast %_ptr_Generic_float %145 + %38 = OpLoad %float %89 Aligned 4 OpStore %7 %38 %41 = OpLoad %ulong %4 - %91 = OpIAdd %ulong %41 %ulong_8 - %118 = OpConvertUToPtr %_ptr_Generic_float %91 - %40 = OpLoad %float %118 Aligned 4 + %118 = OpConvertUToPtr %_ptr_Generic_float %41 + %146 = OpBitcast %_ptr_Generic_uchar %118 + %147 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %146 %ulong_8 + %91 = OpBitcast %_ptr_Generic_float %147 + %40 = OpLoad %float %91 Aligned 4 OpStore %8 %40 %43 = OpLoad %ulong %4 - %93 = OpIAdd %ulong %43 %ulong_12 - %119 = OpConvertUToPtr %_ptr_Generic_float %93 - %42 = OpLoad %float %119 Aligned 4 + %119 = OpConvertUToPtr %_ptr_Generic_float %43 + %148 = OpBitcast %_ptr_Generic_uchar %119 + %149 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %148 %ulong_12 + %93 = OpBitcast %_ptr_Generic_float %149 + %42 = OpLoad %float %93 Aligned 4 OpStore %9 %42 %45 = OpLoad %ulong %4 - %95 = OpIAdd %ulong %45 %ulong_16 - %120 = OpConvertUToPtr %_ptr_Generic_float %95 - %44 = OpLoad %float %120 Aligned 4 + %120 = OpConvertUToPtr %_ptr_Generic_float %45 + %150 = OpBitcast %_ptr_Generic_uchar %120 + %151 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %150 %ulong_16 + %95 = OpBitcast %_ptr_Generic_float %151 + %44 = OpLoad %float %95 Aligned 4 OpStore %10 %44 %47 = OpLoad %ulong %4 - %97 = OpIAdd %ulong %47 %ulong_20 - %121 = OpConvertUToPtr %_ptr_Generic_float %97 - %46 = OpLoad %float %121 Aligned 4 + %121 = OpConvertUToPtr %_ptr_Generic_float %47 + %152 = OpBitcast 
%_ptr_Generic_uchar %121 + %153 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %152 %ulong_20 + %97 = OpBitcast %_ptr_Generic_float %153 + %46 = OpLoad %float %97 Aligned 4 OpStore %11 %46 %49 = OpLoad %ulong %4 - %99 = OpIAdd %ulong %49 %ulong_24 - %122 = OpConvertUToPtr %_ptr_Generic_float %99 - %48 = OpLoad %float %122 Aligned 4 + %122 = OpConvertUToPtr %_ptr_Generic_float %49 + %154 = OpBitcast %_ptr_Generic_uchar %122 + %155 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %154 %ulong_24 + %99 = OpBitcast %_ptr_Generic_float %155 + %48 = OpLoad %float %99 Aligned 4 OpStore %12 %48 %51 = OpLoad %ulong %4 - %101 = OpIAdd %ulong %51 %ulong_28 - %123 = OpConvertUToPtr %_ptr_Generic_float %101 - %50 = OpLoad %float %123 Aligned 4 + %123 = OpConvertUToPtr %_ptr_Generic_float %51 + %156 = OpBitcast %_ptr_Generic_uchar %123 + %157 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %156 %ulong_28 + %101 = OpBitcast %_ptr_Generic_float %157 + %50 = OpLoad %float %101 Aligned 4 OpStore %13 %50 %53 = OpLoad %float %6 %54 = OpLoad %float %7 - %142 = OpIsNan %bool %53 - %143 = OpIsNan %bool %54 - %144 = OpLogicalOr %bool %142 %143 - %52 = OpSelect %bool %144 %false %true + %158 = OpIsNan %bool %53 + %159 = OpIsNan %bool %54 + %160 = OpLogicalOr %bool %158 %159 + %52 = OpSelect %bool %160 %false %true OpStore %15 %52 %55 = OpLoad %bool %15 OpBranchConditional %55 %16 %17 @@ -138,10 +154,10 @@ OpStore %124 %60 Aligned 4 %62 = OpLoad %float %8 %63 = OpLoad %float %9 - %148 = OpIsNan %bool %62 - %149 = OpIsNan %bool %63 - %150 = OpLogicalOr %bool %148 %149 - %61 = OpSelect %bool %150 %false_0 %true_0 + %164 = OpIsNan %bool %62 + %165 = OpIsNan %bool %63 + %166 = OpLogicalOr %bool %164 %165 + %61 = OpSelect %bool %166 %false_0 %true_0 OpStore %15 %61 %64 = OpLoad %bool %15 OpBranchConditional %64 %20 %21 @@ -159,15 +175,17 @@ %23 = OpLabel %68 = OpLoad %ulong %5 %69 = OpLoad %uint %14 - %107 = OpIAdd %ulong %68 %ulong_4_0 - %125 = OpConvertUToPtr %_ptr_Generic_uint %107 - OpStore %125 
%69 Aligned 4 + %125 = OpConvertUToPtr %_ptr_Generic_uint %68 + %169 = OpBitcast %_ptr_Generic_uchar %125 + %170 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %169 %ulong_4_0 + %107 = OpBitcast %_ptr_Generic_uint %170 + OpStore %107 %69 Aligned 4 %71 = OpLoad %float %10 %72 = OpLoad %float %11 - %153 = OpIsNan %bool %71 - %154 = OpIsNan %bool %72 - %155 = OpLogicalOr %bool %153 %154 - %70 = OpSelect %bool %155 %false_1 %true_1 + %171 = OpIsNan %bool %71 + %172 = OpIsNan %bool %72 + %173 = OpLogicalOr %bool %171 %172 + %70 = OpSelect %bool %173 %false_1 %true_1 OpStore %15 %70 %73 = OpLoad %bool %15 OpBranchConditional %73 %24 %25 @@ -185,15 +203,17 @@ %27 = OpLabel %77 = OpLoad %ulong %5 %78 = OpLoad %uint %14 - %111 = OpIAdd %ulong %77 %ulong_8_0 - %126 = OpConvertUToPtr %_ptr_Generic_uint %111 - OpStore %126 %78 Aligned 4 + %126 = OpConvertUToPtr %_ptr_Generic_uint %77 + %176 = OpBitcast %_ptr_Generic_uchar %126 + %177 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %176 %ulong_8_0 + %111 = OpBitcast %_ptr_Generic_uint %177 + OpStore %111 %78 Aligned 4 %80 = OpLoad %float %12 %81 = OpLoad %float %13 - %158 = OpIsNan %bool %80 - %159 = OpIsNan %bool %81 - %160 = OpLogicalOr %bool %158 %159 - %79 = OpSelect %bool %160 %false_2 %true_2 + %178 = OpIsNan %bool %80 + %179 = OpIsNan %bool %81 + %180 = OpLogicalOr %bool %178 %179 + %79 = OpSelect %bool %180 %false_2 %true_2 OpStore %15 %79 %82 = OpLoad %bool %15 OpBranchConditional %82 %28 %29 @@ -211,8 +231,10 @@ %31 = OpLabel %86 = OpLoad %ulong %5 %87 = OpLoad %uint %14 - %115 = OpIAdd %ulong %86 %ulong_12_0 - %127 = OpConvertUToPtr %_ptr_Generic_uint %115 - OpStore %127 %87 Aligned 4 + %127 = OpConvertUToPtr %_ptr_Generic_uint %86 + %183 = OpBitcast %_ptr_Generic_uchar %127 + %184 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %183 %ulong_12_0 + %115 = OpBitcast %_ptr_Generic_uint %184 + OpStore %115 %87 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt 
b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt index 2ea964c..1b2e3dd 100644 --- a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt +++ b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt @@ -24,7 +24,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %uint_0 = OpConstant %uint 0 + %ulong_0 = OpConstant %ulong 0 +%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar %1 = OpFunction %void None %40 %10 = OpFunctionParameter %ulong %11 = OpFunctionParameter %ulong @@ -54,9 +55,11 @@ %27 = OpConvertUToPtr %_ptr_Workgroup_ulong %17 OpStore %27 %18 Aligned 8 %20 = OpLoad %uint %7 - %24 = OpIAdd %uint %20 %uint_0 - %28 = OpConvertUToPtr %_ptr_Workgroup_ulong %24 - %19 = OpLoad %ulong %28 Aligned 8 + %28 = OpConvertUToPtr %_ptr_Workgroup_ulong %20 + %46 = OpBitcast %_ptr_Workgroup_uchar %28 + %47 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uchar %46 %ulong_0 + %24 = OpBitcast %_ptr_Workgroup_ulong %47 + %19 = OpLoad %ulong %24 Aligned 8 OpStore %9 %19 %21 = OpLoad %ulong %6 %22 = OpLoad %ulong %9 diff --git a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt index 19d5a5a..fd4f893 100644 --- a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt +++ b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt @@ -7,27 +7,24 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %33 = OpExtInstImport "OpenCL.std" + %31 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %2 "shared_ptr_take_address" %1 OpDecorate %1 Alignment 4 %void = OpTypeVoid %uchar = OpTypeInt 8 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar -%_ptr_Workgroup__ptr_Workgroup_uchar = OpTypePointer Workgroup %_ptr_Workgroup_uchar - %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uchar Workgroup + %1 = OpVariable %_ptr_Workgroup_uchar Workgroup %ulong = OpTypeInt 64 0 - %39 = OpTypeFunction 
%void %ulong %ulong %_ptr_Workgroup_uchar -%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar + %36 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %39 + %2 = OpFunction %void None %36 %10 = OpFunctionParameter %ulong %11 = OpFunctionParameter %ulong - %31 = OpFunctionParameter %_ptr_Workgroup_uchar - %40 = OpLabel - %32 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %30 = OpFunctionParameter %_ptr_Workgroup_uchar + %28 = OpLabel %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function %5 = OpVariable %_ptr_Function_ulong Function @@ -35,34 +32,30 @@ %7 = OpVariable %_ptr_Function_ulong Function %8 = OpVariable %_ptr_Function_ulong Function %9 = OpVariable %_ptr_Function_ulong Function - OpStore %32 %31 - OpBranch %29 - %29 = OpLabel OpStore %3 %10 OpStore %4 %11 %12 = OpLoad %ulong %3 Aligned 8 OpStore %5 %12 %13 = OpLoad %ulong %4 Aligned 8 OpStore %6 %13 - %15 = OpLoad %_ptr_Workgroup_uchar %32 - %24 = OpConvertPtrToU %ulong %15 - %14 = OpCopyObject %ulong %24 + %23 = OpConvertPtrToU %ulong %30 + %14 = OpCopyObject %ulong %23 OpStore %7 %14 - %17 = OpLoad %ulong %5 - %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %17 - %16 = OpLoad %ulong %25 Aligned 8 - OpStore %8 %16 - %18 = OpLoad %ulong %7 - %19 = OpLoad %ulong %8 - %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %18 - OpStore %26 %19 Aligned 8 - %21 = OpLoad %ulong %7 - %27 = OpConvertUToPtr %_ptr_Workgroup_ulong %21 - %20 = OpLoad %ulong %27 Aligned 8 - OpStore %9 %20 - %22 = OpLoad %ulong %6 - %23 = OpLoad %ulong %9 - %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %22 - OpStore %28 %23 Aligned 8 + %16 = OpLoad %ulong %5 + %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 + %15 = OpLoad %ulong %24 Aligned 8 + 
OpStore %8 %15 + %17 = OpLoad %ulong %7 + %18 = OpLoad %ulong %8 + %25 = OpConvertUToPtr %_ptr_Workgroup_ulong %17 + OpStore %25 %18 Aligned 8 + %20 = OpLoad %ulong %7 + %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %20 + %19 = OpLoad %ulong %26 Aligned 8 + OpStore %9 %19 + %21 = OpLoad %ulong %6 + %22 = OpLoad %ulong %9 + %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21 + OpStore %27 %22 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt index 33812f6..cf0d86e 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt @@ -7,7 +7,7 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %50 = OpExtInstImport "OpenCL.std" + %54 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "stateful_ld_st_ntid" %gl_LocalInvocationID OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId @@ -18,34 +18,34 @@ %gl_LocalInvocationID = OpVariable %_ptr_Input_v3ulong Input %uchar = OpTypeInt 8 0 %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %57 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar + %61 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar %_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %1 = OpFunction %void None %57 + %1 = OpFunction %void None %61 %20 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar %21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %48 = OpLabel - %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %52 = OpLabel + %12 = OpVariable 
%_ptr_Function__ptr_CrossWorkgroup_uchar Function + %13 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %6 = OpVariable %_ptr_Function_uint Function %7 = OpVariable %_ptr_Function_ulong Function %8 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %20 - OpStore %3 %21 - %13 = OpBitcast %_ptr_Function_ulong %2 - %44 = OpLoad %ulong %13 Aligned 8 - %12 = OpCopyObject %ulong %44 - %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %12 + OpStore %12 %20 + OpStore %13 %21 + %45 = OpBitcast %_ptr_Function_ulong %12 + %44 = OpLoad %ulong %45 Aligned 8 + %14 = OpCopyObject %ulong %44 + %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14 OpStore %10 %22 - %15 = OpBitcast %_ptr_Function_ulong %3 - %45 = OpLoad %ulong %15 Aligned 8 - %14 = OpCopyObject %ulong %45 - %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14 + %47 = OpBitcast %_ptr_Function_ulong %13 + %46 = OpLoad %ulong %47 Aligned 8 + %15 = OpCopyObject %ulong %46 + %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15 OpStore %11 %23 %24 = OpLoad %_ptr_CrossWorkgroup_uchar %10 %17 = OpConvertPtrToU %ulong %24 @@ -57,35 +57,37 @@ %18 = OpCopyObject %ulong %19 %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18 OpStore %11 %27 - %62 = OpLoad %v3ulong %gl_LocalInvocationID - %43 = OpCompositeExtract %ulong %62 0 - %63 = OpBitcast %ulong %43 - %29 = OpUConvert %uint %63 + %66 = OpLoad %v3ulong %gl_LocalInvocationID + %43 = OpCompositeExtract %ulong %66 0 + %67 = OpBitcast %ulong %43 + %29 = OpUConvert %uint %67 %28 = OpCopyObject %uint %29 OpStore %6 %28 %31 = OpLoad %uint %6 - %64 = OpBitcast %uint %31 - %30 = OpUConvert %ulong %64 + %68 = OpBitcast %uint %31 + %30 = OpUConvert %ulong %68 OpStore %7 %30 %33 = OpLoad %_ptr_CrossWorkgroup_uchar %10 %34 = OpLoad %ulong %7 - %65 = OpBitcast %_ptr_CrossWorkgroup_uchar %33 - %66 = OpInBoundsPtrAccessChain 
%_ptr_CrossWorkgroup_uchar %65 %34 - %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %66 + %48 = OpCopyObject %ulong %34 + %69 = OpBitcast %_ptr_CrossWorkgroup_uchar %33 + %70 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %69 %48 + %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %70 OpStore %10 %32 %36 = OpLoad %_ptr_CrossWorkgroup_uchar %11 %37 = OpLoad %ulong %7 - %67 = OpBitcast %_ptr_CrossWorkgroup_uchar %36 - %68 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %67 %37 - %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %68 + %49 = OpCopyObject %ulong %37 + %71 = OpBitcast %_ptr_CrossWorkgroup_uchar %36 + %72 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %71 %49 + %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %72 OpStore %11 %35 %39 = OpLoad %_ptr_CrossWorkgroup_uchar %10 - %46 = OpBitcast %_ptr_CrossWorkgroup_ulong %39 - %38 = OpLoad %ulong %46 Aligned 8 + %50 = OpBitcast %_ptr_CrossWorkgroup_ulong %39 + %38 = OpLoad %ulong %50 Aligned 8 OpStore %8 %38 %40 = OpLoad %_ptr_CrossWorkgroup_uchar %11 %41 = OpLoad %ulong %8 - %47 = OpBitcast %_ptr_CrossWorkgroup_ulong %40 - OpStore %47 %41 Aligned 8 + %51 = OpBitcast %_ptr_CrossWorkgroup_ulong %40 + OpStore %51 %41 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt index cb77d14..97bf000 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt @@ -7,7 +7,7 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %58 = OpExtInstImport "OpenCL.std" + %62 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "stateful_ld_st_ntid_chain" %gl_LocalInvocationID OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId @@ -18,18 +18,18 @@ %gl_LocalInvocationID = OpVariable %_ptr_Input_v3ulong Input %uchar = OpTypeInt 8 0 %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %65 = 
OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar + %69 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar %_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %1 = OpFunction %void None %65 + %1 = OpFunction %void None %69 %28 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar %29 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %56 = OpLabel - %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %60 = OpLabel + %20 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %21 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function @@ -39,17 +39,17 @@ %10 = OpVariable %_ptr_Function_uint Function %11 = OpVariable %_ptr_Function_ulong Function %12 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %28 - OpStore %3 %29 - %21 = OpBitcast %_ptr_Function_ulong %2 - %52 = OpLoad %ulong %21 Aligned 8 - %20 = OpCopyObject %ulong %52 - %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20 + OpStore %20 %28 + OpStore %21 %29 + %53 = OpBitcast %_ptr_Function_ulong %20 + %52 = OpLoad %ulong %53 Aligned 8 + %22 = OpCopyObject %ulong %52 + %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22 OpStore %14 %30 - %23 = OpBitcast %_ptr_Function_ulong %3 - %53 = OpLoad %ulong %23 Aligned 8 - %22 = OpCopyObject %ulong %53 - %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22 + %55 = OpBitcast %_ptr_Function_ulong %21 + %54 = OpLoad %ulong %55 Aligned 8 + %23 = OpCopyObject %ulong %54 + %31 = OpConvertUToPtr 
%_ptr_CrossWorkgroup_uchar %23 OpStore %17 %31 %32 = OpLoad %_ptr_CrossWorkgroup_uchar %14 %25 = OpConvertPtrToU %ulong %32 @@ -61,35 +61,37 @@ %26 = OpCopyObject %ulong %27 %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26 OpStore %18 %35 - %70 = OpLoad %v3ulong %gl_LocalInvocationID - %51 = OpCompositeExtract %ulong %70 0 - %71 = OpBitcast %ulong %51 - %37 = OpUConvert %uint %71 + %74 = OpLoad %v3ulong %gl_LocalInvocationID + %51 = OpCompositeExtract %ulong %74 0 + %75 = OpBitcast %ulong %51 + %37 = OpUConvert %uint %75 %36 = OpCopyObject %uint %37 OpStore %10 %36 %39 = OpLoad %uint %10 - %72 = OpBitcast %uint %39 - %38 = OpUConvert %ulong %72 + %76 = OpBitcast %uint %39 + %38 = OpUConvert %ulong %76 OpStore %11 %38 %41 = OpLoad %_ptr_CrossWorkgroup_uchar %15 %42 = OpLoad %ulong %11 - %73 = OpBitcast %_ptr_CrossWorkgroup_uchar %41 - %74 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %73 %42 - %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %74 + %56 = OpCopyObject %ulong %42 + %77 = OpBitcast %_ptr_CrossWorkgroup_uchar %41 + %78 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %77 %56 + %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %78 OpStore %16 %40 %44 = OpLoad %_ptr_CrossWorkgroup_uchar %18 %45 = OpLoad %ulong %11 - %75 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 - %76 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %75 %45 - %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %76 + %57 = OpCopyObject %ulong %45 + %79 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 + %80 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %79 %57 + %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %80 OpStore %19 %43 %47 = OpLoad %_ptr_CrossWorkgroup_uchar %16 - %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %47 - %46 = OpLoad %ulong %54 Aligned 8 + %58 = OpBitcast %_ptr_CrossWorkgroup_ulong %47 + %46 = OpLoad %ulong %58 Aligned 8 OpStore %12 %46 %48 = OpLoad %_ptr_CrossWorkgroup_uchar %19 %49 = OpLoad %ulong %12 - %55 = OpBitcast %_ptr_CrossWorkgroup_ulong %48 - OpStore %55 %49 Aligned 
8 + %59 = OpBitcast %_ptr_CrossWorkgroup_ulong %48 + OpStore %59 %49 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt index ecf2858..8253bf9 100644 --- a/ptx/src/test/spirv_run/vector.spvtxt +++ b/ptx/src/test/spirv_run/vector.spvtxt @@ -25,8 +25,8 @@ %1 = OpFunction %v2uint None %55 %7 = OpFunctionParameter %v2uint %24 = OpLabel - %2 = OpVariable %_ptr_Function_v2uint Function %3 = OpVariable %_ptr_Function_v2uint Function + %2 = OpVariable %_ptr_Function_v2uint Function %4 = OpVariable %_ptr_Function_v2uint Function %5 = OpVariable %_ptr_Function_uint Function %6 = OpVariable %_ptr_Function_uint Function diff --git a/ptx/src/test/spirv_run/verify.py b/ptx/src/test/spirv_run/verify.py new file mode 100644 index 0000000..dbfab00 --- /dev/null +++ b/ptx/src/test/spirv_run/verify.py @@ -0,0 +1,21 @@ +import os, sys, subprocess + +def main(path): + dirs = os.listdir(path) + for file in dirs: + if not file.endswith(".spvtxt"): + continue + full_file = os.path.join(path, file) + print(file) + spv_file = f"/tmp/{file}.spv" + # We nominally emit spv1.3, but use spv1.4 feature (OpEntryPoint interface changes in 1.4) + proc1 = subprocess.run(["spirv-as", "--target-env", "spv1.4", full_file, "-o", spv_file]) + proc2 = subprocess.run(["spirv-dis", spv_file, "-o", f"{spv_file}.dis.txt"]) + proc3 = subprocess.run(["spirv-val", spv_file ]) + if proc1.returncode != 0 or proc2.returncode != 0 or proc3.returncode != 0: + print(proc1.returncode) + print(proc2.returncode) + print(proc3.returncode) + +if __name__ == "__main__": + main(sys.argv[1]) diff --git a/ptx/src/test/spirv_run/xor.spvtxt b/ptx/src/test/spirv_run/xor.spvtxt index 4cc8968..c3a1f6f 100644 --- a/ptx/src/test/spirv_run/xor.spvtxt +++ b/ptx/src/test/spirv_run/xor.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %39 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7170950..c2562c3 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,11 +1,9 @@ use crate::ast; use half::f16; use rspirv::dr; -use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem}; -use std::{ - collections::{hash_map, HashMap, HashSet}, - convert::TryInto, -}; +use std::cell::RefCell; +use std::collections::{hash_map, HashMap, HashSet}; +use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc}; use rspirv::binary::Assemble; @@ -48,64 +46,21 @@ enum SpirvType { } impl SpirvType { - fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self { - let key = t.into(); - SpirvType::Pointer(Box::new(key), sc) - } -} - -impl From for SpirvType { - fn from(t: ast::Type) -> Self { + fn new(t: ast::Type) -> Self { match t { ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), - ast::Type::Pointer(pointer_t, state_space) => SpirvType::Pointer( - Box::new(SpirvType::from(ast::Type::from(pointer_t))), - state_space.to_spirv(), + ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( + Box::new(SpirvType::Base(pointer_t.into())), + space.to_spirv(), ), } } -} -impl From for 
ast::Type { - fn from(t: ast::PointerType) -> Self { - match t { - ast::PointerType::Scalar(t) => ast::Type::Scalar(t), - ast::PointerType::Vector(t, len) => ast::Type::Vector(t, len), - ast::PointerType::Array(t, dims) => ast::Type::Array(t, dims), - ast::PointerType::Pointer(t, space) => { - ast::Type::Pointer(ast::PointerType::Scalar(t), space) - } - } - } -} - -impl ast::Type { - fn param_pointer_to(self, space: ast::LdStateSpace) -> Result { - Ok(match self { - ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space), - ast::Type::Vector(t, len) => { - ast::Type::Pointer(ast::PointerType::Vector(t, len), space) - } - ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space), - ast::Type::Pointer(ast::PointerType::Scalar(t), space) => { - ast::Type::Pointer(ast::PointerType::Pointer(t, space), space) - } - ast::Type::Pointer(_, _) => return Err(error_unreachable()), - }) - } -} - -impl Into for ast::PointerStateSpace { - fn into(self) -> spirv::StorageClass { - match self { - ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant, - ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup, - ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup, - ast::PointerStateSpace::Param => spirv::StorageClass::Function, - ast::PointerStateSpace::Generic => spirv::StorageClass::Generic, - } + fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self { + let key = Self::new(t); + SpirvType::Pointer(Box::new(key), outer_space) } } @@ -213,14 +168,18 @@ impl TypeWordMap { .or_insert_with(|| b.type_vector(None, base, len as u32)) } SpirvType::Array(typ, array_dimensions) => { - let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); let (base_type, length) = match &*array_dimensions { + &[] => { + return self.get_or_add(b, SpirvType::Base(typ)); + } &[len] => { + let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); let base = self.get_or_add_spirv_scalar(b, typ); 
let len_const = b.constant_u32(u32_type, None, len); (base, len_const) } array_dimensions => { + let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); let base = self .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec())); let len_const = b.constant_u32(u32_type, None, array_dimensions[0]); @@ -262,7 +221,7 @@ impl TypeWordMap { fn get_or_add_fn( &mut self, b: &mut dr::Builder, - in_params: impl ExactSizeIterator, + in_params: impl Iterator, mut out_params: impl ExactSizeIterator, ) -> (spirv::Word, spirv::Word) { let (out_args, out_spirv_type) = if out_params.len() == 0 { @@ -274,6 +233,7 @@ impl TypeWordMap { self.get_or_add(b, arg_as_key), ) } else { + // TODO: support multiple return values todo!() }; ( @@ -410,18 +370,7 @@ impl TypeWordMap { b.constant_composite(result_type, None, components.into_iter()) } }, - ast::Type::Pointer(typ, state_space) => { - let base_t = typ.clone().into(); - let base = self.get_or_add_constant(b, &base_t, &[])?; - let result_type = self.get_or_add( - b, - SpirvType::Pointer( - Box::new(SpirvType::from(base_t)), - (*state_space).to_spirv(), - ), - ); - b.variable(result_type, None, (*state_space).to_spirv(), Some(base)) - } + ast::Type::Pointer(..) 
=> return Err(error_unreachable()), }) } @@ -487,7 +436,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result>(); let mut builder = dr::Builder::new(); builder.reserve_ids(id_defs.current_id()); - let call_map = get_call_map(&directives); + let call_map = get_kernels_call_map(&directives); let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id()); normalize_variable_decls(&mut directives); let denorm_information = compute_denorm_information(&directives); @@ -525,9 +474,12 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result( call_map: &HashMap<&str, HashSet>, - denorm_information: &HashMap>, + denorm_information: &HashMap< + ast::MethodName<'input, spirv::Word>, + HashMap, + >, ) -> CString { let denorm_counts = denorm_information .iter() @@ -545,10 +497,12 @@ fn emit_denorm_build_string( .collect::>(); let mut flush_over_preserve = 0; for (kernel, children) in call_map { - flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0); + flush_over_preserve += *denorm_counts + .get(&ast::MethodName::Kernel(kernel)) + .unwrap_or(&0); for child_fn in children { flush_over_preserve += *denorm_counts - .get(&MethodName::Func(*child_fn)) + .get(&ast::MethodName::Func(*child_fn)) .unwrap_or(&0); } } @@ -564,15 +518,18 @@ fn emit_directives<'input>( map: &mut TypeWordMap, id_defs: &GlobalStringIdResolver<'input>, opencl_id: spirv::Word, - denorm_information: &HashMap, HashMap>, + denorm_information: &HashMap< + ast::MethodName<'input, spirv::Word>, + HashMap, + >, call_map: &HashMap<&'input str, HashSet>, - directives: Vec, + directives: Vec>, kernel_info: &mut HashMap, ) -> Result<(), TranslateError> { let empty_body = Vec::new(); for d in directives.iter() { match d { - Directive::Variable(var) => { + Directive::Variable(_, var) => { emit_variable(builder, map, &var)?; } Directive::Method(f) => { @@ -589,12 +546,13 @@ fn emit_directives<'input>( for var in f.globals.iter() { 
emit_variable(builder, map, var)?; } + let func_decl = (*f.func_decl).borrow(); let fn_id = emit_function_header( builder, map, &id_defs, &f.globals, - &f.spirv_decl, + &*func_decl, &denorm_information, call_map, &directives, @@ -623,8 +581,13 @@ fn emit_directives<'input>( } emit_function_body_ops(builder, map, opencl_id, &f_body)?; builder.end_function()?; - if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) = - (&f.func_decl, &f.import_as) + if let ( + ast::MethodDeclaration { + name: ast::MethodName::Func(fn_id), + .. + }, + Some(name), + ) = (&*func_decl, &f.import_as) { builder.decorate( *fn_id, @@ -643,7 +606,7 @@ fn emit_directives<'input>( Ok(()) } -fn get_call_map<'input>( +fn get_kernels_call_map<'input>( module: &[Directive<'input>], ) -> HashMap<&'input str, HashSet> { let mut directly_called_by = HashMap::new(); @@ -654,14 +617,14 @@ fn get_call_map<'input>( body: Some(statements), .. }) => { - let call_key = MethodName::new(&func_decl); + let call_key: ast::MethodName<_> = (**func_decl).borrow().name; if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) { entry.insert(Vec::new()); } for statement in statements { match statement { Statement::Call(call) => { - multi_hash_map_append(&mut directly_called_by, call_key, call.func); + multi_hash_map_append(&mut directly_called_by, call_key, call.name); } _ => {} } @@ -673,28 +636,28 @@ fn get_call_map<'input>( let mut result = HashMap::new(); for (method_key, children) in directly_called_by.iter() { match method_key { - MethodName::Kernel(name) => { + ast::MethodName::Kernel(name) => { let mut visited = HashSet::new(); for child in children { add_call_map_single(&directly_called_by, &mut visited, *child); } result.insert(*name, visited); } - MethodName::Func(_) => {} + ast::MethodName::Func(_) => {} } } result } fn add_call_map_single<'input>( - directly_called_by: &MultiHashMap, spirv::Word>, + directly_called_by: &MultiHashMap, spirv::Word>, visited: &mut HashSet, current: 
spirv::Word, ) { if !visited.insert(current) { return; } - if let Some(children) = directly_called_by.get(&MethodName::Func(current)) { + if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) { for child in children { add_call_map_single(directly_called_by, visited, *child); } @@ -714,11 +677,29 @@ fn multi_hash_map_append(m: &mut MultiHashMap, } } -// PTX represents dynamically allocated shared local memory as -// .extern .shared .align 4 .b8 shared_mem[]; -// In SPIRV/OpenCL world this is expressed as an additional argument -// This pass looks for all uses of .extern .shared and converts them to -// an additional method argument +/* + PTX represents dynamically allocated shared local memory as + .extern .shared .b32 shared_mem[]; + In SPIRV/OpenCL world this is expressed as an additional argument + This pass looks for all uses of .extern .shared and converts them to + an additional method argument + The question is how this artificial argument should be expressed. There are + several options: + * Straight conversion: + .shared .b32 shared_mem[] + * Introduce .param_shared statespace: + .param_shared .b32 shared_mem + or + .param_shared .b32 shared_mem[] + * Introduce .shared_ptr type: + .param .shared_ptr .b32 shared_mem + * Reuse .ptr hint: + .param .u64 .ptr shared_mem + This is the most tempting, but also the most nonsensical, .ptr is just a + hint, which has no semantical meaning (and the output of our + transformation has a semantical meaning - we emit additional + "OpFunctionParameter ..." 
with type "OpTypePointer Workgroup ...") +*/ fn convert_dynamic_shared_memory_usage<'input>( module: Vec>, new_id: &mut impl FnMut() -> spirv::Word, @@ -726,12 +707,16 @@ fn convert_dynamic_shared_memory_usage<'input>( let mut extern_shared_decls = HashMap::new(); for dir in module.iter() { match dir { - Directive::Variable(var) => { - if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) = - var.v_type - { - extern_shared_decls.insert(var.name, p_type); - } + Directive::Variable( + linking, + ast::Variable { + v_type: ast::Type::Array(p_type, dims), + state_space: ast::StateSpace::Shared, + name, + .. + }, + ) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => { + extern_shared_decls.insert(*name, *p_type); } _ => {} } @@ -749,15 +734,14 @@ fn convert_dynamic_shared_memory_usage<'input>( globals, body: Some(statements), import_as, - spirv_decl, tuning, }) => { - let call_key = MethodName::new(&func_decl); + let call_key = (*func_decl).borrow().name; let statements = statements .into_iter() .map(|statement| match statement { Statement::Call(call) => { - multi_hash_map_append(&mut directly_called_by, call.func, call_key); + multi_hash_map_append(&mut directly_called_by, call.name, call_key); Statement::Call(call) } statement => statement.map_id(&mut |id, _| { @@ -773,7 +757,6 @@ fn convert_dynamic_shared_memory_usage<'input>( globals, body: Some(statements), import_as, - spirv_decl, tuning, }) } @@ -792,66 +775,34 @@ fn convert_dynamic_shared_memory_usage<'input>( globals, body: Some(statements), import_as, - mut spirv_decl, tuning, }) => { - if !methods_using_extern_shared.contains(&spirv_decl.name) { + if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) { return Directive::Method(Function { func_decl, globals, body: Some(statements), import_as, - spirv_decl, tuning, }); } let shared_id_param = new_id(); - spirv_decl.input.push({ - ast::Variable { - align: None, - v_type: ast::Type::Pointer( - 
ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Shared, - ), - array_init: Vec::new(), - name: shared_id_param, - } - }); - spirv_decl.uses_shared_mem = true; - let shared_var_id = new_id(); - let shared_var = ExpandedStatement::Variable(ast::Variable { - align: None, - name: shared_var_id, - array_init: Vec::new(), - v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( - ast::SizedScalarType::B8, - ast::PointerStateSpace::Shared, - )), - }); - let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails { - arg: ast::Arg2St { - src1: shared_var_id, - src2: shared_id_param, - }, - typ: ast::Type::Scalar(ast::ScalarType::B8), - member_index: None, - }); - let mut new_statements = vec![shared_var, shared_var_st]; - replace_uses_of_shared_memory( - &mut new_statements, + { + let mut func_decl = (*func_decl).borrow_mut(); + func_decl.shared_mem = Some(shared_id_param); + } + let statements = replace_uses_of_shared_memory( new_id, &extern_shared_decls, &mut methods_using_extern_shared, shared_id_param, - shared_var_id, statements, ); Directive::Method(Function { func_decl, globals, - body: Some(new_statements), + body: Some(statements), import_as, - spirv_decl, tuning, }) } @@ -861,47 +812,43 @@ fn convert_dynamic_shared_memory_usage<'input>( } fn replace_uses_of_shared_memory<'a>( - result: &mut Vec, new_id: &mut impl FnMut() -> spirv::Word, - extern_shared_decls: &HashMap, - methods_using_extern_shared: &mut HashSet>, + extern_shared_decls: &HashMap, + methods_using_extern_shared: &mut HashSet>, shared_id_param: spirv::Word, - shared_var_id: spirv::Word, statements: Vec, -) { +) -> Vec { + let mut result = Vec::with_capacity(statements.len()); for statement in statements { match statement { Statement::Call(mut call) => { // We can safely skip checking call arguments, // because there's simply no way to pass shared ptr // without converting it to .b64 first - if methods_using_extern_shared.contains(&MethodName::Func(call.func)) { - 
call.param_list - .push((shared_id_param, ast::FnArgumentType::Shared)); + if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) { + call.input_arguments.push(( + shared_id_param, + ast::Type::Scalar(ast::ScalarType::B8), + ast::StateSpace::Shared, + )); } result.push(Statement::Call(call)) } statement => { let new_statement = statement.map_id(&mut |id, _| { - if let Some(typ) = extern_shared_decls.get(&id) { - if *typ == ast::SizedScalarType::B8 { - return shared_var_id; + if let Some(scalar_type) = extern_shared_decls.get(&id) { + if *scalar_type == ast::ScalarType::B8 { + return shared_id_param; } let replacement_id = new_id(); result.push(Statement::Conversion(ImplicitConversion { - src: shared_var_id, + src: shared_id_param, dst: replacement_id, - from: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::B8), - ast::LdStateSpace::Shared, - ), - to: ast::Type::Pointer( - ast::PointerType::Scalar((*typ).into()), - ast::LdStateSpace::Shared, - ), - kind: ConversionKind::PtrToPtr { spirv_ptr: true }, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, + from_type: ast::Type::Scalar(ast::ScalarType::B8), + from_space: ast::StateSpace::Shared, + to_type: ast::Type::Scalar(*scalar_type), + to_space: ast::StateSpace::Shared, + kind: ConversionKind::PtrToPtr, })); replacement_id } else { @@ -912,16 +859,17 @@ fn replace_uses_of_shared_memory<'a>( } } } + result } fn get_callers_of_extern_shared<'a>( - methods_using_extern_shared: &mut HashSet>, - directly_called_by: &MultiHashMap>, + methods_using_extern_shared: &mut HashSet>, + directly_called_by: &MultiHashMap>, ) { let direct_uses_of_extern_shared = methods_using_extern_shared .iter() .filter_map(|method| { - if let MethodName::Func(f_id) = method { + if let ast::MethodName::Func(f_id) = method { Some(*f_id) } else { None @@ -934,14 +882,14 @@ fn get_callers_of_extern_shared<'a>( } fn get_callers_of_extern_shared_single<'a>( - methods_using_extern_shared: 
&mut HashSet>, - directly_called_by: &MultiHashMap>, + methods_using_extern_shared: &mut HashSet>, + directly_called_by: &MultiHashMap>, fn_id: spirv::Word, ) { if let Some(callers) = directly_called_by.get(&fn_id) { for caller in callers { if methods_using_extern_shared.insert(*caller) { - if let MethodName::Func(caller_fn) = caller { + if let ast::MethodName::Func(caller_fn) = caller { get_callers_of_extern_shared_single( methods_using_extern_shared, directly_called_by, @@ -983,18 +931,18 @@ fn denorm_count_map_update_impl( // and emit suitable execution mode fn compute_denorm_information<'input>( module: &[Directive<'input>], -) -> HashMap, HashMap> { +) -> HashMap, HashMap> { let mut denorm_methods = HashMap::new(); for directive in module { match directive { - Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {} + Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {} Directive::Method(Function { func_decl, body: Some(statements), .. }) => { let mut flush_counter = DenormCountMap::new(); - let method_key = MethodName::new(func_decl); + let method_key = (**func_decl).borrow().name; for statement in statements { match statement { Statement::Instruction(inst) => { @@ -1038,21 +986,6 @@ fn compute_denorm_information<'input>( .collect() } -#[derive(Hash, PartialEq, Eq, Copy, Clone)] -enum MethodName<'input> { - Kernel(&'input str), - Func(spirv::Word), -} - -impl<'input> MethodName<'input> { - fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self { - match decl { - ast::MethodDecl::Kernel { name, .. 
} => MethodName::Kernel(name), - ast::MethodDecl::Func(_, id, _) => MethodName::Func(*id), - } - } -} - fn emit_builtins( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -1061,10 +994,7 @@ fn emit_builtins( for (reg, id) in id_defs.special_registers.builtins() { let result_type = map.get_or_add( builder, - SpirvType::Pointer( - Box::new(SpirvType::from(reg.get_type())), - spirv::StorageClass::Input, - ), + SpirvType::pointer_to(reg.get_type(), spirv::StorageClass::Input), ); builder.variable(result_type, Some(id), spirv::StorageClass::Input, None); builder.decorate( @@ -1079,18 +1009,21 @@ fn emit_function_header<'a>( builder: &mut dr::Builder, map: &mut TypeWordMap, defined_globals: &GlobalStringIdResolver<'a>, - synthetic_globals: &[ast::Variable], - func_decl: &SpirvMethodDecl<'a>, - _denorm_information: &HashMap, HashMap>, + synthetic_globals: &[ast::Variable], + func_decl: &ast::MethodDeclaration<'a, spirv::Word>, + _denorm_information: &HashMap< + ast::MethodName<'a, spirv::Word>, + HashMap, + >, call_map: &HashMap<&'a str, HashSet>, direcitves: &[Directive], kernel_info: &mut HashMap, ) -> Result { - if let MethodName::Kernel(name) = func_decl.name { - let input_args = if !func_decl.uses_shared_mem { - func_decl.input.as_slice() + if let ast::MethodName::Kernel(name) = func_decl.name { + let input_args = if func_decl.shared_mem.is_none() { + func_decl.input_arguments.as_slice() } else { - &func_decl.input[0..func_decl.input.len() - 1] + &func_decl.input_arguments[0..func_decl.input_arguments.len() - 1] }; let args_lens = input_args .iter() @@ -1100,14 +1033,18 @@ fn emit_function_header<'a>( name.to_string(), KernelInfo { arguments_sizes: args_lens, - uses_shared_mem: func_decl.uses_shared_mem, + uses_shared_mem: func_decl.shared_mem.is_some(), }, ); } - let (ret_type, func_type) = - get_function_type(builder, map, &func_decl.input, &func_decl.output); + let (ret_type, func_type) = get_function_type( + builder, + map, + 
func_decl.effective_input_arguments().map(|(_, typ)| typ), + &func_decl.return_arguments, + ); let fn_id = match func_decl.name { - MethodName::Kernel(name) => { + ast::MethodName::Kernel(name) => { let fn_id = defined_globals.get_id(name)?; let mut global_variables = defined_globals .variables_type_check @@ -1123,15 +1060,18 @@ fn emit_function_header<'a>( for directive in direcitves { match directive { Directive::Method(Function { - func_decl: ast::MethodDecl::Func(_, name, _), - globals, - .. + func_decl, globals, .. }) => { - if child_fns.contains(name) { - for var in globals { - interface.push(var.name); + match (**func_decl).borrow().name { + ast::MethodName::Func(name) => { + if child_fns.contains(&name) { + for var in globals { + interface.push(var.name); + } + } } - } + ast::MethodName::Kernel(_) => {} + }; } _ => {} } @@ -1140,7 +1080,7 @@ fn emit_function_header<'a>( builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables); fn_id } - MethodName::Func(name) => name, + ast::MethodName::Func(name) => name, }; builder.begin_function( ret_type, @@ -1163,9 +1103,9 @@ fn emit_function_header<'a>( } } */ - for input in &func_decl.input { - let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone())); - builder.function_parameter(Some(input.name), result_type)?; + for (name, typ) in func_decl.effective_input_arguments() { + let result_type = map.get_or_add(builder, typ); + builder.function_parameter(Some(name), result_type)?; } Ok(fn_id) } @@ -1207,55 +1147,32 @@ fn translate_directive<'input>( d: ast::Directive<'input, ast::ParsedArgParams<'input>>, ) -> Result>, TranslateError> { Ok(match d { - ast::Directive::Variable(v) => Some(Directive::Variable(translate_variable(id_defs, v)?)), - ast::Directive::Method(f) => { + ast::Directive::Variable(linking, var) => Some(Directive::Variable( + linking, + ast::Variable { + align: var.align, + v_type: var.v_type.clone(), + state_space: var.state_space, + name: 
id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true), + array_init: var.array_init, + }, + )), + ast::Directive::Method(_, f) => { translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method) } }) } -fn translate_variable<'a>( - id_defs: &mut GlobalStringIdResolver<'a>, - var: ast::Variable, -) -> Result, TranslateError> { - let (space, var_type) = var.v_type.to_type(); - let mut is_variable = false; - let var_type = match space { - ast::StateSpace::Reg => { - is_variable = true; - var_type - } - ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?, - ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?, - ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?, - ast::StateSpace::Shared => { - // If it's a pointer it will be translated to a method parameter later - if let ast::Type::Pointer(..) = var_type { - is_variable = true; - var_type - } else { - var_type.param_pointer_to(ast::LdStateSpace::Shared)? - } - } - ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?, - }; - Ok(ast::Variable { - align: var.align, - v_type: var.v_type, - name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable), - array_init: var.array_init, - }) -} - fn translate_function<'a>( id_defs: &mut GlobalStringIdResolver<'a>, ptx_impl_imports: &mut HashMap>, f: ast::ParsedFunction<'a>, ) -> Result>, TranslateError> { let import_as = match &f.func_directive { - ast::MethodDecl::Func(_, "__assertfail", _) => { - Some("__zluda_ptx_impl____assertfail".to_owned()) - } + ast::MethodDeclaration { + name: ast::MethodName::Func("__assertfail"), + .. 
+ } => Some("__zluda_ptx_impl____assertfail".to_owned()), _ => None, }; let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?; @@ -1279,63 +1196,38 @@ fn translate_function<'a>( } } -fn expand_kernel_params<'a, 'b>( +fn rename_fn_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: impl Iterator>, -) -> Result>, TranslateError> { - args.map(|a| { - Ok(ast::KernelArgument { - name: fn_resolver.add_def( - a.name, - Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?), - false, - ), + args: &'b [ast::Variable<&'a str>], +) -> Vec> { + args.iter() + .map(|a| ast::Variable { + name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true), v_type: a.v_type.clone(), + state_space: a.state_space, align: a.align, - array_init: Vec::new(), + array_init: a.array_init.clone(), }) - }) - .collect::>() -} - -fn expand_fn_params<'a, 'b>( - fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: impl Iterator>, -) -> Result>, TranslateError> { - args.map(|a| { - let is_variable = match a.v_type { - ast::FnArgumentType::Reg(_) => true, - _ => false, - }; - let var_type = a.v_type.to_func_type(); - Ok(ast::FnArgument { - name: fn_resolver.add_def(a.name, Some(var_type), is_variable), - v_type: a.v_type.clone(), - align: a.align, - array_init: Vec::new(), - }) - }) - .collect() + .collect() } fn to_ssa<'input, 'b>( ptx_impl_imports: &mut HashMap, mut id_defs: FnStringIdResolver<'input, 'b>, fn_defs: GlobalFnDeclResolver<'input, 'b>, - f_args: ast::MethodDecl<'input, spirv::Word>, + func_decl: Rc>>, f_body: Option>>>, tuning: Vec, ) -> Result, TranslateError> { - let mut spirv_decl = SpirvMethodDecl::new(&f_args); + //deparamize_function_decl(&func_decl)?; let f_body = match f_body { Some(vec) => vec, None => { return Ok(Function { - func_decl: f_args, + func_decl: func_decl, body: None, globals: Vec::new(), import_as: None, - spirv_decl, tuning, }) } @@ -1345,15 +1237,14 @@ fn to_ssa<'input, 
'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; - let typed_statements = - convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; + let (func_decl, typed_statements) = + convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, - &f_args, - &mut spirv_decl, + &mut (*func_decl).borrow_mut(), )?; - let ssa_statements = fix_builtins(ssa_statements, &mut numeric_id_defs)?; + let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = @@ -1363,16 +1254,15 @@ fn to_ssa<'input, 'b>( let (f_body, globals) = extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs); Ok(Function { - func_decl: f_args, + func_decl: func_decl, globals: globals, body: Some(f_body), import_as: None, - spirv_decl, tuning, }) } -fn fix_builtins( +fn fix_special_registers( typed_statements: Vec, numeric_id_defs: &mut NumericIdResolver, ) -> Result, TranslateError> { @@ -1408,7 +1298,8 @@ fn fix_builtins( continue; } }; - let temp_id = numeric_id_defs.new_non_variable(Some(details.typ.clone())); + let temp_id = numeric_id_defs + .register_intermediate(Some((details.typ.clone(), details.state_space))); let real_dst = details.arg.dst; details.arg.dst = temp_id; result.push(Statement::LoadVar(LoadVarDetails { @@ -1416,17 +1307,18 @@ fn fix_builtins( src: sreg_src, dst: temp_id, }, + state_space: ast::StateSpace::Sreg, typ: ast::Type::Scalar(scalar_typ), member_index: Some((index, Some(vector_width))), })); result.push(Statement::Conversion(ImplicitConversion { src: temp_id, dst: real_dst, - from: 
ast::Type::Scalar(scalar_typ), - to: ast::Type::Scalar(ast::ScalarType::U32), + from_type: ast::Type::Scalar(scalar_typ), + from_space: ast::StateSpace::Sreg, + to_type: ast::Type::Scalar(ast::ScalarType::U32), + to_space: ast::StateSpace::Sreg, kind: ConversionKind::Default, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, })); } } @@ -1456,10 +1348,7 @@ fn extract_globals<'input, 'b>( sorted_statements: Vec, ptx_impl_imports: &mut HashMap, id_def: &mut NumericIdResolver, -) -> ( - Vec, - Vec>, -) { +) -> (Vec, Vec>) { let mut local = Vec::with_capacity(sorted_statements.len()); let mut global = Vec::new(); for statement in sorted_statements { @@ -1468,7 +1357,7 @@ fn extract_globals<'input, 'b>( var @ ast::Variable { - v_type: ast::VariableType::Shared(_), + state_space: ast::StateSpace::Shared, .. }, ) @@ -1476,7 +1365,7 @@ fn extract_globals<'input, 'b>( var @ ast::Variable { - v_type: ast::VariableType::Global(_), + state_space: ast::StateSpace::Global, .. 
}, ) => global.push(var), @@ -1505,7 +1394,7 @@ fn extract_globals<'input, 'b>( d, a, "inc", - ast::SizedScalarType::U32, + ast::ScalarType::U32, )); } Statement::Instruction(ast::Instruction::Atom( @@ -1527,7 +1416,7 @@ fn extract_globals<'input, 'b>( d, a, "dec", - ast::SizedScalarType::U32, + ast::ScalarType::U32, )); } Statement::Instruction(ast::Instruction::Atom( @@ -1553,10 +1442,9 @@ fn extract_globals<'input, 'b>( space, }; let (op, typ) = match typ { - ast::FloatType::F32 => ("add_f32", ast::SizedScalarType::F32), - ast::FloatType::F64 => ("add_f64", ast::SizedScalarType::F64), - ast::FloatType::F16 => unreachable!(), - ast::FloatType::F16x2 => unreachable!(), + ast::ScalarType::F32 => ("add_f32", ast::ScalarType::F32), + ast::ScalarType::F64 => ("add_f64", ast::ScalarType::F64), + _ => unreachable!(), }; local.push(to_ptx_impl_atomic_call( id_def, @@ -1599,47 +1487,13 @@ fn convert_to_typed_statements( match s { Statement::Instruction(inst) => match inst { ast::Instruction::Call(call) => { - // TODO: error out if lengths don't match - let fn_def = fn_defs.get_fn_decl(call.func)?; - let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals); - let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params); - let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args - .into_iter() - .partition(|(_, arg_type)| arg_type.is_param()); - let normalized_input_args = out_params - .into_iter() - .map(|(id, typ)| (ast::Operand::Reg(id), typ)) - .chain(in_args.into_iter()) - .collect(); - let resolved_call = ResolvedCall { - uniform: call.uniform, - ret_params: out_non_params, - func: call.func, - param_list: normalized_input_args, - }; + let resolver = fn_defs.get_fn_sig_resolver(call.func)?; + let resolved_call = resolver.resolve_in_spirv_repr(call)?; let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let reresolved_call = resolved_call.visit(&mut visitor)?; visitor.func.push(reresolved_call); 
visitor.func.extend(visitor.post_stmts); } - ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => { - if let Some(src_id) = src.underlying() { - let (typ, _) = id_defs.get_typed(*src_id)?; - let take_address = match typ { - ast::Type::Scalar(_) => false, - ast::Type::Vector(_, _) => false, - ast::Type::Array(_, _) => true, - ast::Type::Pointer(_, _) => true, - }; - d.src_is_address = take_address; - } - let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); - let instruction = Statement::Instruction( - ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?, - ); - visitor.func.push(instruction); - visitor.func.extend(visitor.post_stmts); - } inst => { let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let instruction = Statement::Instruction(inst.map(&mut visitor)?); @@ -1674,8 +1528,14 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { fn convert_vector( &mut self, is_dst: bool, - vector_sema: ArgumentSemantics, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, typ: &ast::Type, + state_space: ast::StateSpace, idx: Vec, ) -> Result { // mov.u32 foobar, {a,b}; @@ -1683,13 +1543,15 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { ast::Type::Vector(scalar_t, _) => *scalar_t, _ => return Err(TranslateError::MismatchedType), }; - let temp_vec = self.id_def.new_non_variable(Some(typ.clone())); + let temp_vec = self + .id_def + .register_intermediate(Some((typ.clone(), state_space))); let statement = Statement::RepackVector(RepackVectorDetails { is_extract: is_dst, typ: scalar_t, packed: temp_vec, unpacked: idx, - vector_sema, + non_default_implicit_conversion, }); if is_dst { self.post_stmts = Some(statement); @@ -1706,7 +1568,7 @@ impl<'a, 'b> ArgumentMapVisitor fn id( &mut self, desc: ArgumentDescriptor, - _: Option<&ast::Type>, + _: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { Ok(desc.op) } @@ -1715,15 +1577,20 @@ 
impl<'a, 'b> ArgumentMapVisitor &mut self, desc: ArgumentDescriptor>, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { Ok(match desc.op { ast::Operand::Reg(reg) => TypedOperand::Reg(reg), ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), ast::Operand::Imm(x) => TypedOperand::Imm(x), ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), - ast::Operand::VecPack(vec) => { - TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?) - } + ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector( + desc.is_dst, + desc.non_default_implicit_conversion, + typ, + state_space, + vec, + )?), }) } } @@ -1735,7 +1602,7 @@ fn to_ptx_impl_atomic_call( details: ast::AtomDetails, arg: ast::Arg3, op: &'static str, - typ: ast::SizedScalarType, + typ: ast::ScalarType, ) -> ExpandedStatement { let semantics = ptx_semantics_name(details.semantics); let scope = ptx_scope_name(details.scope); @@ -1745,75 +1612,70 @@ fn to_ptx_impl_atomic_call( semantics, scope, space, op ); // TODO: extract to a function - let ptr_space = match details.space { - ast::AtomSpace::Generic => ast::PointerStateSpace::Generic, - ast::AtomSpace::Global => ast::PointerStateSpace::Global, - ast::AtomSpace::Shared => ast::PointerStateSpace::Shared, - }; + let ptr_space = details.space; let scalar_typ = ast::ScalarType::from(typ); let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.new_non_variable(None); - let func_decl = ast::MethodDecl::Func::( - vec![ast::FnArgument { + let fn_id = id_defs.register_intermediate(None); + let func_decl = ast::MethodDeclaration:: { + return_arguments: vec![ast::Variable { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), - name: id_defs.new_non_variable(None), + v_type: ast::Type::Scalar(scalar_typ), + state_space: ast::StateSpace::Reg, + name: id_defs.register_intermediate(None), array_init: 
Vec::new(), }], - fn_id, - vec![ - ast::FnArgument { + name: ast::MethodName::Func(fn_id), + input_arguments: vec![ + ast::Variable { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer( - typ, ptr_space, - )), - name: id_defs.new_non_variable(None), + v_type: ast::Type::Pointer(typ, ptr_space), + state_space: ast::StateSpace::Reg, + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), - name: id_defs.new_non_variable(None), + v_type: ast::Type::Scalar(scalar_typ), + state_space: ast::StateSpace::Reg, + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ], - ); - let spirv_decl = SpirvMethodDecl::new(&func_decl); + shared_mem: None, + }; let func = Function { - func_decl, + func_decl: Rc::new(RefCell::new(func_decl)), globals: Vec::new(), body: None, import_as: Some(entry.key().clone()), - spirv_decl, tuning: Vec::new(), }; entry.insert(Directive::Method(func)); fn_id } hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { - func_decl: ast::MethodDecl::Func(_, name, _), - .. - }) => *name, + Directive::Method(Function { func_decl, .. 
}) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => fn_id, + ast::MethodName::Kernel(_) => unreachable!(), + }, _ => unreachable!(), }, }; Statement::Call(ResolvedCall { uniform: false, - func: fn_id, - ret_params: vec![( - arg.dst, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), - )], - param_list: vec![ + name: fn_id, + return_arguments: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], + input_arguments: vec![ ( arg.src1, - ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(typ, ptr_space)), + ast::Type::Pointer(typ, ptr_space), + ast::StateSpace::Reg, ), ( arg.src2, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), + ast::Type::Scalar(scalar_typ), + ast::StateSpace::Reg, ), ], }) @@ -1822,93 +1684,92 @@ fn to_ptx_impl_atomic_call( fn to_ptx_impl_bfe_call( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, - typ: ast::IntType, + typ: ast::ScalarType, arg: ast::Arg4, ) -> ExpandedStatement { let prefix = "__zluda_ptx_impl__"; let suffix = match typ { - ast::IntType::U32 => "bfe_u32", - ast::IntType::U64 => "bfe_u64", - ast::IntType::S32 => "bfe_s32", - ast::IntType::S64 => "bfe_s64", + ast::ScalarType::U32 => "bfe_u32", + ast::ScalarType::U64 => "bfe_u64", + ast::ScalarType::S32 => "bfe_s32", + ast::ScalarType::S64 => "bfe_s64", _ => unreachable!(), }; let fn_name = format!("{}{}", prefix, suffix); let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.new_non_variable(None); - let func_decl = ast::MethodDecl::Func::( - vec![ast::FnArgument { + let fn_id = id_defs.register_intermediate(None); + let func_decl = ast::MethodDeclaration:: { + return_arguments: vec![ast::Variable { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - name: id_defs.new_non_variable(None), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, + name: 
id_defs.register_intermediate(None), array_init: Vec::new(), }], - fn_id, - vec![ - ast::FnArgument { + name: ast::MethodName::Func(fn_id), + input_arguments: vec![ + ast::Variable { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - name: id_defs.new_non_variable(None), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( - ast::ScalarType::U32, - )), - name: id_defs.new_non_variable(None), + v_type: ast::Type::Scalar(ast::ScalarType::U32), + state_space: ast::StateSpace::Reg, + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( - ast::ScalarType::U32, - )), - name: id_defs.new_non_variable(None), + v_type: ast::Type::Scalar(ast::ScalarType::U32), + state_space: ast::StateSpace::Reg, + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ], - ); - let spirv_decl = SpirvMethodDecl::new(&func_decl); + shared_mem: None, + }; let func = Function { - func_decl, + func_decl: Rc::new(RefCell::new(func_decl)), globals: Vec::new(), body: None, import_as: Some(entry.key().clone()), - spirv_decl, tuning: Vec::new(), }; entry.insert(Directive::Method(func)); fn_id } hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { - func_decl: ast::MethodDecl::Func(_, name, _), - .. - }) => *name, + Directive::Method(Function { func_decl, .. 
}) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => fn_id, + ast::MethodName::Kernel(_) => unreachable!(), + }, _ => unreachable!(), }, }; Statement::Call(ResolvedCall { uniform: false, - func: fn_id, - ret_params: vec![( - arg.dst, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - )], - param_list: vec![ + name: fn_id, + return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], + input_arguments: vec![ ( arg.src1, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + ast::Type::Scalar(typ.into()), + ast::StateSpace::Reg, ), ( arg.src2, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, ), ( arg.src3, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, ), ], }) @@ -1917,117 +1778,107 @@ fn to_ptx_impl_bfe_call( fn to_ptx_impl_bfi_call( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, - typ: ast::BitType, + typ: ast::ScalarType, arg: ast::Arg5, ) -> ExpandedStatement { let prefix = "__zluda_ptx_impl__"; let suffix = match typ { - ast::BitType::B32 => "bfi_b32", - ast::BitType::B64 => "bfi_b64", - ast::BitType::B8 | ast::BitType::B16 => unreachable!(), + ast::ScalarType::B32 => "bfi_b32", + ast::ScalarType::B64 => "bfi_b64", + _ => unreachable!(), }; let fn_name = format!("{}{}", prefix, suffix); let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.new_non_variable(None); - let func_decl = ast::MethodDecl::Func::( - vec![ast::FnArgument { + let fn_id = id_defs.register_intermediate(None); + let func_decl = ast::MethodDeclaration:: { + return_arguments: vec![ast::Variable { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - name: id_defs.new_non_variable(None), + v_type: 
ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, + name: id_defs.register_intermediate(None), array_init: Vec::new(), }], - fn_id, - vec![ - ast::FnArgument { + name: ast::MethodName::Func(fn_id), + input_arguments: vec![ + ast::Variable { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - name: id_defs.new_non_variable(None), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - name: id_defs.new_non_variable(None), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( - ast::ScalarType::U32, - )), - name: id_defs.new_non_variable(None), + v_type: ast::Type::Scalar(ast::ScalarType::U32), + state_space: ast::StateSpace::Reg, + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, - ast::FnArgument { + ast::Variable { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( - ast::ScalarType::U32, - )), - name: id_defs.new_non_variable(None), + v_type: ast::Type::Scalar(ast::ScalarType::U32), + state_space: ast::StateSpace::Reg, + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ], - ); - let spirv_decl = SpirvMethodDecl::new(&func_decl); + shared_mem: None, + }; let func = Function { - func_decl, + func_decl: Rc::new(RefCell::new(func_decl)), globals: Vec::new(), body: None, import_as: Some(entry.key().clone()), - spirv_decl, tuning: Vec::new(), }; entry.insert(Directive::Method(func)); fn_id } hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { - func_decl: ast::MethodDecl::Func(_, 
name, _), - .. - }) => *name, + Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => fn_id, + ast::MethodName::Kernel(_) => unreachable!(), + }, _ => unreachable!(), }, }; Statement::Call(ResolvedCall { uniform: false, - func: fn_id, - ret_params: vec![( - arg.dst, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - )], - param_list: vec![ + name: fn_id, + return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], + input_arguments: vec![ ( arg.src1, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + ast::Type::Scalar(typ.into()), + ast::StateSpace::Reg, ), ( arg.src2, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + ast::Type::Scalar(typ.into()), + ast::StateSpace::Reg, ), ( arg.src3, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, ), ( arg.src4, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, ), ], }) } -fn to_resolved_fn_args( - params: Vec, - params_decl: &[ast::FnArgumentType], -) -> Vec<(T, ast::FnArgumentType)> { - params - .into_iter() - .zip(params_decl.iter()) - .map(|(id, typ)| (id, typ.clone())) - .collect::>() -} - fn normalize_labels( func: Vec, id_def: &mut NumericIdResolver, @@ -2056,7 +1907,7 @@ fn normalize_labels( | Statement::RepackVector(..) 
=> {} } } - iter::once(Statement::Label(id_def.new_non_variable(None))) + iter::once(Statement::Label(id_def.register_intermediate(None))) .chain(func.into_iter().filter(|s| match s { Statement::Label(i) => labels_in_use.contains(i), _ => true, @@ -2074,8 +1925,8 @@ fn normalize_predicates( Statement::Label(id) => result.push(Statement::Label(id)), Statement::Instruction((pred, inst)) => { if let Some(pred) = pred { - let if_true = id_def.new_non_variable(None); - let if_false = id_def.new_non_variable(None); + let if_true = id_def.register_intermediate(None); + let if_false = id_def.register_intermediate(None); let folded_bra = match &inst { ast::Instruction::Bra(_, arg) => Some(arg.src), _ => None, @@ -2106,53 +1957,52 @@ fn normalize_predicates( Ok(result) } +/* + How do we handle arguments: + - input .params in kernels + .param .b64 in_arg + get turned into this SPIR-V: + %1 = OpFunctionParameter %ulong + %2 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %1 + We do this for two reasons. One, common treatment for argument-declared + .param variables and .param variables inside function (we assume that + at SPIR-V level every .param is a pointer in Function storage class) + - input .params in functions + .param .b64 in_arg + get turned into this SPIR-V: + %1 = OpFunctionParameter %_ptr_Function_ulong + - input .regs + .reg .b64 in_arg + get turned into the same SPIR-V as kernel .params: + %1 = OpFunctionParameter %ulong + %2 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %1 + - output .regs + .reg .b64 out_arg + get just a variable declaration: + %2 = OpVariable %%_ptr_Function_ulong Function + - output .params don't exist, they have been moved to input positions + by an earlier pass + Distinguishing betweem kernel .params and function .params is not the + cleanest solution. 
Alternatively, we could "deparamize" all kernel .param + arguments by turning them into .reg arguments like this: + .param .b64 arg -> .reg ptr<.b64,.param> arg + This has the massive downside that this transformation would have to run + very early and would muddy up already difficult code. It's simpler to just + have an if here +*/ fn insert_mem_ssa_statements<'a, 'b>( func: Vec, id_def: &mut NumericIdResolver, - ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>, - fn_decl: &mut SpirvMethodDecl, + fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>, ) -> Result, TranslateError> { - let is_func = match ast_fn_decl { - ast::MethodDecl::Func(..) => true, - ast::MethodDecl::Kernel { .. } => false, - }; let mut result = Vec::with_capacity(func.len()); - for arg in fn_decl.output.iter() { - match type_to_variable_type(&arg.v_type, is_func)? { - Some(var_type) => { - result.push(Statement::Variable(ast::Variable { - align: arg.align, - v_type: var_type, - name: arg.name, - array_init: arg.array_init.clone(), - })); - } - None => return Err(error_unreachable()), - } + for arg in fn_decl.input_arguments.iter_mut() { + insert_mem_ssa_argument(id_def, &mut result, arg, fn_decl.name.is_kernel()); } - for spirv_arg in fn_decl.input.iter_mut() { - match type_to_variable_type(&spirv_arg.v_type, is_func)? 
{ - Some(var_type) => { - let typ = spirv_arg.v_type.clone(); - let new_id = id_def.new_non_variable(Some(typ.clone())); - result.push(Statement::Variable(ast::Variable { - align: spirv_arg.align, - v_type: var_type, - name: spirv_arg.name, - array_init: spirv_arg.array_init.clone(), - })); - result.push(Statement::StoreVar(StoreVarDetails { - arg: ast::Arg2St { - src1: spirv_arg.name, - src2: new_id, - }, - typ, - member_index: None, - })); - spirv_arg.name = new_id; - } - None => {} - } + for arg in fn_decl.return_arguments.iter() { + insert_mem_ssa_argument_reg_return(&mut result, arg); } for s in func { match s { @@ -2162,32 +2012,41 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Instruction(inst) => match inst { ast::Instruction::Ret(d) => { // TODO: handle multiple output args - if let &[out_param] = &fn_decl.output.as_slice() { - let (typ, _) = id_def.get_typed(out_param.name)?; - let new_id = id_def.new_non_variable(Some(typ.clone())); - result.push(Statement::LoadVar(LoadVarDetails { - arg: ast::Arg2 { - dst: new_id, - src: out_param.name, - }, - typ: typ.clone(), - member_index: None, - })); - result.push(Statement::RetValue(d, new_id)); - } else { - result.push(Statement::Instruction(ast::Instruction::Ret(d))) + match &fn_decl.return_arguments[..] 
{ + [return_reg] => { + let new_id = id_def.register_intermediate(Some(( + return_reg.v_type.clone(), + ast::StateSpace::Reg, + ))); + result.push(Statement::LoadVar(LoadVarDetails { + arg: ast::Arg2 { + dst: new_id, + src: return_reg.name, + }, + // TODO: ret with stateful conversion + state_space: ast::StateSpace::Reg, + typ: return_reg.v_type.clone(), + member_index: None, + })); + result.push(Statement::RetValue(d, new_id)); + } + [] => result.push(Statement::Instruction(ast::Instruction::Ret(d))), + _ => unimplemented!(), } } inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?, }, Statement::Conditional(mut bra) => { - let generated_id = - id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred))); + let generated_id = id_def.register_intermediate(Some(( + ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + ))); result.push(Statement::LoadVar(LoadVarDetails { arg: Arg2 { dst: generated_id, src: bra.predicate, }, + state_space: ast::StateSpace::Reg, typ: ast::Type::Scalar(ast::ScalarType::Pred), member_index: None, })); @@ -2210,39 +2069,45 @@ fn insert_mem_ssa_statements<'a, 'b>( Ok(result) } -fn type_to_variable_type( - t: &ast::Type, - is_func: bool, -) -> Result, TranslateError> { - Ok(match t { - ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))), - ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector( - (*typ) - .try_into() - .map_err(|_| TranslateError::MismatchedType)?, - *len, - ))), - ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array( - (*typ) - .try_into() - .map_err(|_| TranslateError::MismatchedType)?, - len.clone(), - ))), - ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => { - if is_func { - return Ok(None); - } - Some(ast::VariableType::Reg(ast::VariableRegType::Pointer( - scalar_type - .clone() - .try_into() - .map_err(|_| error_unreachable())?, - 
(*space).try_into().map_err(|_| error_unreachable())?, - ))) - } - ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None, - _ => return Err(error_unreachable()), - }) +fn insert_mem_ssa_argument( + id_def: &mut NumericIdResolver, + func: &mut Vec, + arg: &mut ast::Variable, + is_kernel: bool, +) { + if !is_kernel && arg.state_space == ast::StateSpace::Param { + return; + } + let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space))); + func.push(Statement::Variable(ast::Variable { + align: arg.align, + v_type: arg.v_type.clone(), + state_space: ast::StateSpace::Reg, + name: arg.name, + array_init: Vec::new(), + })); + func.push(Statement::StoreVar(StoreVarDetails { + arg: ast::Arg2St { + src1: arg.name, + src2: new_id, + }, + typ: arg.v_type.clone(), + member_index: None, + })); + arg.name = new_id; +} + +fn insert_mem_ssa_argument_reg_return( + func: &mut Vec, + arg: &ast::Variable, +) { + func.push(Statement::Variable(ast::Variable { + align: arg.align, + v_type: arg.v_type.clone(), + state_space: arg.state_space, + name: arg.name, + array_init: arg.array_init.clone(), + })); } trait Visitable: Sized { @@ -2259,6 +2124,7 @@ struct VisitArgumentDescriptor< > { desc: ArgumentDescriptor, typ: &'a ast::Type, + state_space: ast::StateSpace, stmt_ctor: Ctor, } @@ -2273,7 +2139,9 @@ impl< self, visitor: &mut impl ArgumentMapVisitor, ) -> Result, U>, TranslateError> { - Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?)) + Ok((self.stmt_ctor)( + visitor.id(self.desc, Some((self.typ, self.state_space)))?, + )) } } @@ -2287,14 +2155,14 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { fn symbol( &mut self, desc: ArgumentDescriptor<(spirv::Word, Option)>, - expected_type: Option<&ast::Type>, + expected: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { let symbol = desc.op.0; - if expected_type.is_none() { + if expected.is_none() { return Ok(symbol); }; - let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?; - if 
!is_variable { + let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?; + if !var_space.is_compatible(ast::StateSpace::Reg) || !is_variable { return Ok(symbol); }; let member_index = match desc.op.1 { @@ -2317,13 +2185,16 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { } None => None, }; - let generated_id = self.id_def.new_non_variable(Some(var_type.clone())); + let generated_id = self + .id_def + .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg))); if !desc.is_dst { self.func.push(Statement::LoadVar(LoadVarDetails { arg: Arg2 { dst: generated_id, src: symbol, }, + state_space: ast::StateSpace::Reg, typ: var_type, member_index, })); @@ -2348,7 +2219,7 @@ impl<'a, 'input> ArgumentMapVisitor fn id( &mut self, desc: ArgumentDescriptor, - typ: Option<&ast::Type>, + typ: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self.symbol(desc.new_op((desc.op, None)), typ) } @@ -2357,18 +2228,20 @@ impl<'a, 'input> ArgumentMapVisitor &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { Ok(match desc.op { TypedOperand::Reg(reg) => { - TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) - } - TypedOperand::RegOffset(reg, offset) => { - TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset) + TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?) } + TypedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset( + self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?, + offset, + ), op @ TypedOperand::Imm(..) => op, - TypedOperand::VecMember(symbol, index) => { - TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) 
- } + TypedOperand::VecMember(symbol, index) => TypedOperand::Reg( + self.symbol(desc.new_op((symbol, Some(index))), Some((typ, state_space)))?, + ), }) } } @@ -2411,11 +2284,13 @@ fn expand_arguments<'a, 'b>( Statement::Variable(ast::Variable { align, v_type, + state_space, name, array_init, }) => result.push(Statement::Variable(ast::Variable { align, v_type, + state_space, name, array_init, })), @@ -2464,7 +2339,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn reg( &mut self, desc: ArgumentDescriptor, - _: Option<&ast::Type>, + _: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { Ok(desc.op) } @@ -2473,108 +2348,86 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { &mut self, desc: ArgumentDescriptor<(spirv::Word, i32)>, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { let (reg, offset) = desc.op; - let add_type; - match typ { - ast::Type::Pointer(underlying_type, state_space) => { - let reg_typ = self.id_def.get_typed(reg)?; - if let ast::Type::Pointer(_, _) = reg_typ { - let id_constant_stmt = self.id_def.new_non_variable(typ.clone()); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: ast::ScalarType::S64, - value: ast::ImmediateValue::S64(offset as i64), - })); - let dst = self.id_def.new_non_variable(typ.clone()); - self.func.push(Statement::PtrAccess(PtrAccess { - underlying_type: underlying_type.clone(), - state_space: *state_space, - dst, - ptr_src: reg, - offset_src: id_constant_stmt, - })); - return Ok(dst); - } else { - add_type = self.id_def.get_typed(reg)?; - } + if !desc.is_memory_access { + let (reg_type, reg_space) = self.id_def.get_typed(reg)?; + if !reg_space.is_compatible(ast::StateSpace::Reg) { + return Err(TranslateError::MismatchedType); } - _ => { - add_type = typ.clone(); - } - }; - let (width, kind) = match add_type { - ast::Type::Scalar(scalar_t) => { - let kind = match scalar_t.kind() { - kind @ ScalarKind::Bit - | kind @ ScalarKind::Unsigned - | kind @ ScalarKind::Signed => kind, - 
ScalarKind::Float => return Err(TranslateError::MismatchedType), - ScalarKind::Float2 => return Err(TranslateError::MismatchedType), - ScalarKind::Pred => return Err(TranslateError::MismatchedType), - }; - (scalar_t.size_of(), kind) - } - _ => return Err(TranslateError::MismatchedType), - }; - let arith_detail = if kind == ScalarKind::Signed { - ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::from_size(width), - saturate: false, - }) - } else { - ast::ArithDetails::Unsigned(ast::UIntType::from_size(width)) - }; - let id_constant_stmt = self.id_def.new_non_variable(add_type.clone()); - let result_id = self.id_def.new_non_variable(add_type); - // TODO: check for edge cases around min value/max value/wrapping - if offset < 0 && kind != ScalarKind::Signed { + let reg_scalar_type = match reg_type { + ast::Type::Scalar(underlying_type) => underlying_type, + _ => return Err(TranslateError::MismatchedType), + }; + let id_constant_stmt = self + .id_def + .register_intermediate(reg_type.clone(), ast::StateSpace::Reg); self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, - typ: ast::ScalarType::from_parts(width, kind), - value: ast::ImmediateValue::U64(-(offset as i64) as u64), - })); - self.func.push(Statement::Instruction( - ast::Instruction::::Sub( - arith_detail, - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); - } else { - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: ast::ScalarType::from_parts(width, kind), + typ: reg_scalar_type, value: ast::ImmediateValue::S64(offset as i64), })); - self.func.push(Statement::Instruction( - ast::Instruction::::Add( - arith_detail, - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); + let arith_details = match reg_scalar_type.kind() { + ast::ScalarKind::Signed => ast::ArithDetails::Signed(ast::ArithSInt { + typ: reg_scalar_type, + saturate: false, + }), + ast::ScalarKind::Unsigned 
| ast::ScalarKind::Bit => { + ast::ArithDetails::Unsigned(reg_scalar_type) + } + _ => return Err(error_unreachable()), + }; + let id_add_result = self.id_def.register_intermediate(reg_type, state_space); + self.func.push(Statement::Instruction(ast::Instruction::Add( + arith_details, + ast::Arg3 { + dst: id_add_result, + src1: reg, + src2: id_constant_stmt, + }, + ))); + Ok(id_add_result) + } else { + let scalar_type = match typ { + ast::Type::Scalar(underlying_type) => *underlying_type, + _ => return Err(error_unreachable()), + }; + let id_constant_stmt = self.id_def.register_intermediate( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: ast::ScalarType::S64, + value: ast::ImmediateValue::S64(offset as i64), + })); + let dst = self.id_def.register_intermediate(typ.clone(), state_space); + self.func.push(Statement::PtrAccess(PtrAccess { + underlying_type: scalar_type, + state_space: state_space, + dst, + ptr_src: reg, + offset_src: id_constant_stmt, + })); + Ok(dst) } - Ok(result_id) } fn immediate( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { let scalar_t = if let ast::Type::Scalar(scalar) = typ { *scalar } else { todo!() }; - let id = self.id_def.new_non_variable(ast::Type::Scalar(scalar_t)); + let id = self + .id_def + .register_intermediate(ast::Type::Scalar(scalar_t), state_space); self.func.push(Statement::Constant(ConstantDefinition { dst: id, typ: scalar_t, @@ -2588,7 +2441,7 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr fn id( &mut self, desc: ArgumentDescriptor, - t: Option<&ast::Type>, + t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self.reg(desc, t) } @@ -2597,12 +2450,13 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { match desc.op { - TypedOperand::Reg(r) => 
self.reg(desc.new_op(r), Some(typ)), - TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ), + TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some((typ, state_space))), + TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ, state_space), TypedOperand::RegOffset(reg, offset) => { - self.reg_offset(desc.new_op((reg, offset)), typ) + self.reg_offset(desc.new_op((reg, offset)), typ, state_space) } TypedOperand::VecMember(..) => Err(error_unreachable()), } @@ -2630,79 +2484,18 @@ fn insert_implicit_conversions( let mut result = Vec::with_capacity(func.len()); for s in func.into_iter() { match s { - Statement::Call(call) => insert_implicit_conversions_impl( - &mut result, - id_def, - call, - should_bitcast_wrapper, - None, - )?, + Statement::Call(call) => { + insert_implicit_conversions_impl(&mut result, id_def, call)?; + } Statement::Instruction(inst) => { - let mut default_conversion_fn = - should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _; - let mut state_space = None; - if let ast::Instruction::Ld(d, _) = &inst { - state_space = Some(d.state_space); - } - if let ast::Instruction::St(d, _) = &inst { - state_space = Some(d.state_space.to_ld_ss()); - } - if let ast::Instruction::Atom(d, _) = &inst { - state_space = Some(d.space.to_ld_ss()); - } - if let ast::Instruction::AtomCas(d, _) = &inst { - state_space = Some(d.space.to_ld_ss()); - } - if let ast::Instruction::Mov(..) 
= &inst { - default_conversion_fn = should_bitcast_packed; - } - insert_implicit_conversions_impl( - &mut result, - id_def, - inst, - default_conversion_fn, - state_space, - )?; + insert_implicit_conversions_impl(&mut result, id_def, inst)?; } - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src, - offset_src: constant_src, - }) => { - let visit_desc = VisitArgumentDescriptor { - desc: ArgumentDescriptor { - op: ptr_src, - is_dst: false, - sema: ArgumentSemantics::PhysicalPointer, - }, - typ: &ast::Type::Pointer(underlying_type.clone(), state_space), - stmt_ctor: |new_ptr_src| { - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src: new_ptr_src, - offset_src: constant_src, - }) - }, - }; - insert_implicit_conversions_impl( - &mut result, - id_def, - visit_desc, - bitcast_physical_pointer, - Some(state_space), - )?; + Statement::PtrAccess(access) => { + insert_implicit_conversions_impl(&mut result, id_def, access)?; + } + Statement::RepackVector(repack) => { + insert_implicit_conversions_impl(&mut result, id_def, repack)?; } - Statement::RepackVector(repack) => insert_implicit_conversions_impl( - &mut result, - id_def, - repack, - should_bitcast_wrapper, - None, - )?, s @ Statement::Conditional(_) | s @ Statement::Conversion(_) | s @ Statement::Label(_) @@ -2720,72 +2513,56 @@ fn insert_implicit_conversions_impl( func: &mut Vec, id_def: &mut MutableNumericIdResolver, stmt: impl Visitable, - default_conversion_fn: for<'a> fn( - &'a ast::Type, - &'a ast::Type, - Option, - ) -> Result, TranslateError>, - state_space: Option, ) -> Result<(), TranslateError> { let mut post_conv = Vec::new(); - let statement = stmt.visit( - &mut |desc: ArgumentDescriptor, typ: Option<&ast::Type>| { - let instr_type = match typ { + let statement = + stmt.visit(&mut |desc: ArgumentDescriptor, + typ: Option<(&ast::Type, ast::StateSpace)>| { + let (instr_type, instruction_space) = match typ { None => return Ok(desc.op), Some(t) 
=> t, }; - let operand_type = id_def.get_typed(desc.op)?; - let mut conversion_fn = default_conversion_fn; - match desc.sema { - ArgumentSemantics::Default => {} - ArgumentSemantics::DefaultRelaxed => { - if desc.is_dst { - conversion_fn = should_convert_relaxed_dst_wrapper; - } else { - conversion_fn = should_convert_relaxed_src_wrapper; - } - } - ArgumentSemantics::PhysicalPointer => { - conversion_fn = bitcast_physical_pointer; - } - ArgumentSemantics::RegisterPointer => { - conversion_fn = bitcast_register_pointer; - } - ArgumentSemantics::Address => { - conversion_fn = force_bitcast_ptr_to_bit; - } - }; - match conversion_fn(&operand_type, instr_type, state_space)? { + let (operand_type, operand_space) = id_def.get_typed(desc.op)?; + let conversion_fn = desc + .non_default_implicit_conversion + .unwrap_or(default_implicit_conversion); + match conversion_fn( + (operand_space, &operand_type), + (instruction_space, instr_type), + )? { Some(conv_kind) => { let conv_output = if desc.is_dst { &mut post_conv } else { &mut *func }; - let mut from = instr_type.clone(); - let mut to = operand_type; - let mut src = id_def.new_non_variable(instr_type.clone()); + let mut from_type = instr_type.clone(); + let mut from_space = instruction_space; + let mut to_type = operand_type; + let mut to_space = operand_space; + let mut src = + id_def.register_intermediate(instr_type.clone(), instruction_space); let mut dst = desc.op; let result = Ok(src); if !desc.is_dst { mem::swap(&mut src, &mut dst); - mem::swap(&mut from, &mut to); + mem::swap(&mut from_type, &mut to_type); + mem::swap(&mut from_space, &mut to_space); } conv_output.push(Statement::Conversion(ImplicitConversion { src, dst, - from, - to, + from_type, + from_space, + to_type, + to_space, kind: conv_kind, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, })); result } None => Ok(desc.op), } - }, - )?; + })?; func.push(statement); func.append(&mut post_conv); Ok(()) @@ -2794,17 +2571,15 @@ 
fn insert_implicit_conversions_impl( fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, - spirv_input: &[ast::Variable], - spirv_output: &[ast::Variable], + spirv_input: impl Iterator, + spirv_output: &[ast::Variable], ) -> (spirv::Word, spirv::Word) { map.get_or_add_fn( builder, - spirv_input - .iter() - .map(|var| SpirvType::from(var.v_type.clone())), + spirv_input, spirv_output .iter() - .map(|var| SpirvType::from(var.v_type.clone())), + .map(|var| SpirvType::new(var.v_type.clone())), ) } @@ -2831,20 +2606,25 @@ fn emit_function_body_ops( match s { Statement::Label(_) => (), Statement::Call(call) => { - let (result_type, result_id) = match &*call.ret_params { - [(id, typ)] => ( - map.get_or_add(builder, SpirvType::from(typ.to_func_type())), - Some(*id), - ), + let (result_type, result_id) = match &*call.return_arguments { + [(id, typ, space)] => { + if *space != ast::StateSpace::Reg { + return Err(error_unreachable()); + } + ( + map.get_or_add(builder, SpirvType::new(typ.clone())), + Some(*id), + ) + } [] => (map.void(), None), _ => todo!(), }; let arg_list = call - .param_list + .input_arguments .iter() - .map(|(id, _)| *id) + .map(|(id, _, _)| *id) .collect::>(); - builder.function_call(result_type, result_id, call.func, arg_list)?; + builder.function_call(result_type, result_id, call.name, arg_list)?; } Statement::Variable(var) => { emit_variable(builder, map, var)?; @@ -2966,7 +2746,7 @@ fn emit_function_body_ops( todo!() } let result_type = - map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone()))); + map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); builder.load( result_type, Some(arg.dst), @@ -2998,7 +2778,7 @@ fn emit_function_body_ops( ast::Instruction::Ret(_) => builder.ret()?, ast::Instruction::Mov(d, arg) => { let result_type = - map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone()))); + map.get_or_add(builder, SpirvType::new(ast::Type::from(d.typ.clone()))); 
builder.copy_object(result_type, Some(arg.dst), arg.src)?; } ast::Instruction::Mul(mul, arg) => match mul { @@ -3026,20 +2806,20 @@ fn emit_function_body_ops( emit_setp(builder, map, setp, arg)?; } ast::Instruction::Not(t, a) => { - let result_type = map.get_or_add(builder, SpirvType::from(t.to_type())); + let result_type = map.get_or_add(builder, SpirvType::from(*t)); let result_id = Some(a.dst); let operand = a.src; match t { - ast::BooleanType::Pred => { + ast::ScalarType::Pred => { logical_not(builder, result_type, result_id, operand) } _ => builder.not(result_type, result_id, operand), }?; } ast::Instruction::Shl(t, a) => { - let full_type = t.to_type(); + let full_type = ast::Type::Scalar(*t); let size_of = full_type.size_of(); - let result_type = map.get_or_add(builder, SpirvType::from(full_type)); + let result_type = map.get_or_add(builder, SpirvType::new(full_type)); let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?; builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?; } @@ -3048,7 +2828,7 @@ fn emit_function_body_ops( let size_of = full_type.size_of(); let result_type = map.get_or_add_scalar(builder, full_type); let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?; - if t.signed() { + if t.kind() == ast::ScalarKind::Signed { builder.shift_right_arithmetic( result_type, Some(a.dst), @@ -3088,7 +2868,7 @@ fn emit_function_body_ops( }, ast::Instruction::Or(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - if *t == ast::BooleanType::Pred { + if *t == ast::ScalarType::Pred { builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?; } else { builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?; @@ -3116,7 +2896,7 @@ fn emit_function_body_ops( } ast::Instruction::And(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - if *t == ast::BooleanType::Pred { + if *t == ast::ScalarType::Pred { 
builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?; } else { builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?; @@ -3202,7 +2982,7 @@ fn emit_function_body_ops( } ast::Instruction::Neg(details, arg) => { let result_type = map.get_or_add_scalar(builder, details.typ); - let negate_func = if details.typ.kind() == ScalarKind::Float { + let negate_func = if details.typ.kind() == ast::ScalarKind::Float { dr::Builder::f_negate } else { dr::Builder::s_negate @@ -3269,7 +3049,7 @@ fn emit_function_body_ops( } ast::Instruction::Xor { typ, arg } => { let builder_fn = match typ { - ast::BooleanType::Pred => emit_logical_xor_spirv, + ast::ScalarType::Pred => emit_logical_xor_spirv, _ => dr::Builder::bitwise_xor, }; let result_type = map.get_or_add_scalar(builder, (*typ).into()); @@ -3284,7 +3064,7 @@ fn emit_function_body_ops( return Err(error_unreachable()); } ast::Instruction::Rem { typ, arg } => { - let builder_fn = if typ.is_signed() { + let builder_fn = if typ.kind() == ast::ScalarKind::Signed { dr::Builder::s_mod } else { dr::Builder::u_mod @@ -3301,7 +3081,7 @@ fn emit_function_body_ops( Some(index) => { let result_ptr_type = map.get_or_add( builder, - SpirvType::new_pointer( + SpirvType::pointer_to( details.typ.clone(), spirv::StorageClass::Function, ), @@ -3334,14 +3114,11 @@ fn emit_function_body_ops( }) => { let u8_pointer = map.get_or_add( builder, - SpirvType::from(ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - *state_space, - )), + SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)), ); let result_type = map.get_or_add( builder, - SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)), + SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)), ); let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?; let temp = builder.in_bounds_ptr_access_chain( @@ -3553,11 +3330,16 @@ fn ptx_scope_name(scope: ast::MemScope) -> &'static str { } } -fn ptx_space_name(space: 
ast::AtomSpace) -> &'static str { +fn ptx_space_name(space: ast::StateSpace) -> &'static str { match space { - ast::AtomSpace::Generic => "generic", - ast::AtomSpace::Global => "global", - ast::AtomSpace::Shared => "shared", + ast::StateSpace::Generic => "generic", + ast::StateSpace::Global => "global", + ast::StateSpace::Shared => "shared", + ast::StateSpace::Reg => "reg", + ast::StateSpace::Const => "const", + ast::StateSpace::Local => "local", + ast::StateSpace::Param => "param", + ast::StateSpace::Sreg => "sreg", } } @@ -3612,14 +3394,17 @@ fn vec_repr(t: T) -> Vec { fn emit_variable( builder: &mut dr::Builder, map: &mut TypeWordMap, - var: &ast::Variable, + var: &ast::Variable, ) -> Result<(), TranslateError> { - let (must_init, st_class) = match var.v_type { - ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => { + let (must_init, st_class) = match var.state_space { + ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { (false, spirv::StorageClass::Function) } - ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup), - ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup), + ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup), + ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), + ast::StateSpace::Const => todo!(), + ast::StateSpace::Generic => todo!(), + ast::StateSpace::Sreg => todo!(), }; let initalizer = if var.array_init.len() > 0 { Some(map.get_or_add_constant( @@ -3628,18 +3413,12 @@ fn emit_variable( &*var.array_init, )?) 
} else if must_init { - let type_id = map.get_or_add( - builder, - SpirvType::from(ast::Type::from(var.v_type.clone())), - ); + let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone())); Some(builder.constant_null(type_id, None)) } else { None }; - let ptr_type_id = map.get_or_add( - builder, - SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class), - ); + let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class)); builder.variable(ptr_type_id, Some(var.name), st_class, initalizer); if let Some(align) = var.align { builder.decorate( @@ -3777,7 +3556,7 @@ fn emit_min( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, }; - let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type())); + let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type())); builder.ext_inst( inst_type, Some(arg.dst), @@ -3802,7 +3581,7 @@ fn emit_max( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, }; - let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type())); + let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type())); builder.ext_inst( inst_type, Some(arg.dst), @@ -3882,7 +3661,7 @@ fn emit_cvt( } let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - if desc.src.is_signed() { + if desc.src.kind() == ast::ScalarKind::Signed { builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?; } else { builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?; @@ -3892,7 +3671,7 @@ fn emit_cvt( ast::CvtDetails::IntFromFloat(desc) => { let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - if desc.dst.is_signed() { + if desc.dst.kind() == ast::ScalarKind::Signed { builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?; } else { 
builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?; @@ -3904,7 +3683,7 @@ fn emit_cvt( let dest_t: ast::ScalarType = desc.dst.into(); let src_t: ast::ScalarType = desc.src.into(); // first do shortening/widening - let src = if desc.dst.width() != desc.src.width() { + let src = if desc.dst.size_of() != desc.src.size_of() { let new_dst = if dest_t.kind() == src_t.kind() { arg.dst } else { @@ -3913,14 +3692,14 @@ fn emit_cvt( let cv = ImplicitConversion { src: arg.src, dst: new_dst, - from: ast::Type::Scalar(src_t), - to: ast::Type::Scalar(ast::ScalarType::from_parts( + from_type: ast::Type::Scalar(src_t), + from_space: ast::StateSpace::Reg, + to_type: ast::Type::Scalar(ast::ScalarType::from_parts( dest_t.size_of(), src_t.kind(), )), + to_space: ast::StateSpace::Reg, kind: ConversionKind::Default, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, }; emit_implicit_conversion(builder, map, &cv)?; new_dst @@ -3933,7 +3712,7 @@ fn emit_cvt( // now do actual conversion let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); if desc.saturate { - if desc.dst.is_signed() { + if desc.dst.kind() == ast::ScalarKind::Signed { builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?; } else { builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?; @@ -3989,60 +3768,60 @@ fn emit_setp( let operand_1 = arg.src1; let operand_2 = arg.src2; match (setp.cmp_op, setp.typ.kind()) { - (ast::SetpCompareOp::Eq, ScalarKind::Signed) - | (ast::SetpCompareOp::Eq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => { + (ast::SetpCompareOp::Eq, ast::ScalarKind::Signed) + | (ast::SetpCompareOp::Eq, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::Eq, ast::ScalarKind::Bit) => { builder.i_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Eq, ScalarKind::Float) => { + (ast::SetpCompareOp::Eq, ast::ScalarKind::Float) => { builder.f_ord_equal(result_type, result_id, operand_1, operand_2) } - 
(ast::SetpCompareOp::NotEq, ScalarKind::Signed) - | (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => { + (ast::SetpCompareOp::NotEq, ast::ScalarKind::Signed) + | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Bit) => { builder.i_not_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::NotEq, ScalarKind::Float) => { + (ast::SetpCompareOp::NotEq, ast::ScalarKind::Float) => { builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Less, ScalarKind::Unsigned) - | (ast::SetpCompareOp::Less, ScalarKind::Bit) => { + (ast::SetpCompareOp::Less, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::Less, ast::ScalarKind::Bit) => { builder.u_less_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Less, ScalarKind::Signed) => { + (ast::SetpCompareOp::Less, ast::ScalarKind::Signed) => { builder.s_less_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Less, ScalarKind::Float) => { + (ast::SetpCompareOp::Less, ast::ScalarKind::Float) => { builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => { + (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Bit) => { builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => { + (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Signed) => { builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => { + (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Float) => { builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Greater, ScalarKind::Unsigned) - | 
(ast::SetpCompareOp::Greater, ScalarKind::Bit) => { + (ast::SetpCompareOp::Greater, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::Greater, ast::ScalarKind::Bit) => { builder.u_greater_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Greater, ScalarKind::Signed) => { + (ast::SetpCompareOp::Greater, ast::ScalarKind::Signed) => { builder.s_greater_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::Greater, ScalarKind::Float) => { + (ast::SetpCompareOp::Greater, ast::ScalarKind::Float) => { builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => { + (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Unsigned) + | (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Bit) => { builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => { + (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Signed) => { builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) } - (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => { + (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Float) => { builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::NanEq, _) => { @@ -4222,7 +4001,7 @@ fn emit_abs( ) -> Result<(), dr::Error> { let scalar_t = ast::ScalarType::from(d.typ); let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); - let cl_abs = if scalar_t.kind() == ScalarKind::Signed { + let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed { spirv::CLOp::s_abs } else { spirv::CLOp::fabs @@ -4272,22 +4051,21 @@ fn emit_implicit_conversion( map: &mut TypeWordMap, cv: &ImplicitConversion, ) -> Result<(), TranslateError> { - let from_parts = cv.from.to_parts(); - let to_parts = cv.to.to_parts(); - match (from_parts.kind, to_parts.kind, cv.kind) 
{ - (_, _, ConversionKind::PtrToBit(typ)) => { - let dst_type = map.get_or_add_scalar(builder, typ.into()); - builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; - } - (_, _, ConversionKind::BitToPtr(_)) => { - let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); + let from_parts = cv.from_type.to_parts(); + let to_parts = cv.to_type.to_parts(); + match (from_parts.kind, to_parts.kind, &cv.kind) { + (_, _, &ConversionKind::BitToPtr) => { + let dst_type = map.get_or_add( + builder, + SpirvType::pointer_to(cv.to_type.clone(), cv.to_space.to_spirv()), + ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } - (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { + (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => { if from_parts.width == to_parts.width { - let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); - if from_parts.scalar_kind != ScalarKind::Float - && to_parts.scalar_kind != ScalarKind::Float + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + if from_parts.scalar_kind != ast::ScalarKind::Float + && to_parts.scalar_kind != ast::ScalarKind::Float { // It is noop, but another instruction expects result of this conversion builder.copy_object(dst_type, Some(cv.dst), cv.src)?; @@ -4295,28 +4073,28 @@ fn emit_implicit_conversion( builder.bitcast(dst_type, Some(cv.dst), cv.src)?; } } else { - // This block is safe because it's illegal to implictly convert between floating point instructions + // This block is safe because it's illegal to implictly convert between floating point values let same_width_bit_type = map.get_or_add( builder, - SpirvType::from(ast::Type::from_parts(TypeParts { - scalar_kind: ScalarKind::Bit, + SpirvType::new(ast::Type::from_parts(TypeParts { + scalar_kind: ast::ScalarKind::Bit, ..from_parts })), ); let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?; let wide_bit_type = ast::Type::from_parts(TypeParts { 
- scalar_kind: ScalarKind::Bit, + scalar_kind: ast::ScalarKind::Bit, ..to_parts }); let wide_bit_type_spirv = - map.get_or_add(builder, SpirvType::from(wide_bit_type.clone())); - if to_parts.scalar_kind == ScalarKind::Unsigned - || to_parts.scalar_kind == ScalarKind::Bit + map.get_or_add(builder, SpirvType::new(wide_bit_type.clone())); + if to_parts.scalar_kind == ast::ScalarKind::Unsigned + || to_parts.scalar_kind == ast::ScalarKind::Bit { builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?; } else { - let conversion_fn = if from_parts.scalar_kind == ScalarKind::Signed - && to_parts.scalar_kind == ScalarKind::Signed + let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed + && to_parts.scalar_kind == ast::ScalarKind::Signed { dr::Builder::s_convert } else { @@ -4330,40 +4108,48 @@ fn emit_implicit_conversion( &ImplicitConversion { src: wide_bit_value, dst: cv.dst, - from: wide_bit_type, - to: cv.to.clone(), + from_type: wide_bit_type, + from_space: cv.from_space, + to_type: cv.to_type.clone(), + to_space: cv.to_space, kind: ConversionKind::Default, - src_sema: cv.src_sema, - dst_sema: cv.dst_sema, }, )?; } } } - (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => { - let result_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); + (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.s_convert(result_type, Some(cv.dst), cv.src)?; } - (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) - | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default) - | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { - let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); + (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default) + | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default) + | (TypeKind::Array, TypeKind::Scalar, 
&ConversionKind::Default) => { + let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.bitcast(into_type, Some(cv.dst), cv.src)?; } - (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => { - let result_type = if spirv_ptr { - map.get_or_add( - builder, - SpirvType::Pointer( - Box::new(SpirvType::from(cv.to.clone())), - spirv::StorageClass::Function, - ), - ) - } else { - map.get_or_add(builder, SpirvType::from(cv.to.clone())) - }; + (_, _, &ConversionKind::PtrToPtr) => { + let result_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + cv.to_space.to_spirv(), + ), + ); builder.bitcast(result_type, Some(cv.dst), cv.src)?; } + (_, _, &ConversionKind::AddressOf) => { + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; + } + (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_ptr_to_u(result_type, Some(cv.dst), cv.src)?; + } + (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_u_to_ptr(result_type, Some(cv.dst), cv.src)?; + } _ => unreachable!(), } Ok(()) @@ -4374,14 +4160,14 @@ fn emit_load_var( map: &mut TypeWordMap, details: &LoadVarDetails, ) -> Result<(), TranslateError> { - let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone())); + let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone())); match details.member_index { Some((index, Some(width))) => { let vector_type = match details.typ { ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), _ => return Err(TranslateError::MismatchedType), }; - let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type)); + let vector_type_spirv = 
map.get_or_add(builder, SpirvType::new(vector_type)); let vector_temp = builder.load( vector_type_spirv, None, @@ -4399,7 +4185,7 @@ fn emit_load_var( Some((index, None)) => { let result_ptr_type = map.get_or_add( builder, - SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function), + SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function), ); let index_spirv = map.get_or_add_constant( builder, @@ -4427,10 +4213,10 @@ fn emit_load_var( Ok(()) } -fn normalize_identifiers<'a, 'b>( - id_defs: &mut FnStringIdResolver<'a, 'b>, - fn_defs: &GlobalFnDeclResolver<'a, 'b>, - func: Vec>>, +fn normalize_identifiers<'input, 'b>( + id_defs: &mut FnStringIdResolver<'input, 'b>, + fn_defs: &GlobalFnDeclResolver<'input, 'b>, + func: Vec>>, ) -> Result, TranslateError> { for s in func.iter() { match s { @@ -4468,48 +4254,28 @@ fn expand_map_variables<'a, 'b>( i.map_variable(&mut |id| id_defs.get_id(id))?, ))), ast::Statement::Variable(var) => { - let mut var_type = ast::Type::from(var.var.v_type.clone()); - let mut is_variable = false; - var_type = match var.var.v_type { - ast::VariableType::Reg(_) => { - is_variable = true; - var_type - } - ast::VariableType::Shared(_) => { - // If it's a pointer it will be translated to a method parameter later - if let ast::Type::Pointer(..) = var_type { - is_variable = true; - var_type - } else { - var_type.param_pointer_to(ast::LdStateSpace::Shared)? - } - } - ast::VariableType::Global(_) => { - var_type.param_pointer_to(ast::LdStateSpace::Global)? - } - ast::VariableType::Param(_) => { - var_type.param_pointer_to(ast::LdStateSpace::Param)? - } - ast::VariableType::Local(_) => { - var_type.param_pointer_to(ast::LdStateSpace::Local)? 
- } - }; + let var_type = var.var.v_type.clone(); match var.count { Some(count) => { - for new_id in id_defs.add_defs(var.var.name, count, var_type, is_variable) { + for new_id in + id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true) + { result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), + state_space: var.var.state_space, name: new_id, array_init: var.var.array_init.clone(), })) } } None => { - let new_id = id_defs.add_def(var.var.name, Some(var_type), is_variable); + let new_id = + id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true); result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), + state_space: var.var.state_space, name: new_id, array_init: var.var.array_init, })); @@ -4520,18 +4286,62 @@ fn expand_map_variables<'a, 'b>( Ok(()) } +/* + Our goal here is to transform + .visible .entry foobar(.param .u64 input) { + .reg .b64 in_addr; + .reg .b64 in_addr2; + ld.param.u64 in_addr, [input]; + cvta.to.global.u64 in_addr2, in_addr; + } + into: + .visible .entry foobar(.param .u8 input[]) { + .reg .u8 in_addr[]; + .reg .u8 in_addr2[]; + ld.param.u8[] in_addr, [input]; + mov.u8[] in_addr2, in_addr; + } + or: + .visible .entry foobar(.reg .u8 input[]) { + .reg .u8 in_addr[]; + .reg .u8 in_addr2[]; + mov.u8[] in_addr, input; + mov.u8[] in_addr2, in_addr; + } + or: + .visible .entry foobar(.param ptr input) { + .reg ptr in_addr; + .reg ptr in_addr2; + ld.param.ptr in_addr, [input]; + mov.ptr in_addr2, in_addr; + } +*/ // TODO: detect more patterns (mov, call via reg, call via param) // TODO: don't convert to ptr if the register is not ultimately used for ld/st // TODO: once insert_mem_ssa_statements is moved to later, move this pass after // argument expansion -// TODO: propagate through calls? 
-fn convert_to_stateful_memory_access<'a>( - func_args: &mut SpirvMethodDecl, +// TODO: propagate out of calls and into calls +fn convert_to_stateful_memory_access<'a, 'input>( + func_args: Rc>>, func_body: Vec, id_defs: &mut NumericIdResolver<'a>, -) -> Result, TranslateError> { - let func_args_64bit = func_args - .input +) -> Result< + ( + Rc>>, + Vec, + ), + TranslateError, +> { + let mut method_decl = func_args.borrow_mut(); + if !method_decl.name.is_kernel() { + drop(method_decl); + return Ok((func_args, func_body)); + } + if Rc::strong_count(&func_args) != 1 { + return Err(error_unreachable()); + } + let func_args_64bit = (*method_decl) + .input_arguments .iter() .filter_map(|arg| match arg.v_type { ast::Type::Scalar(ast::ScalarType::U64) @@ -4546,9 +4356,9 @@ fn convert_to_stateful_memory_access<'a>( match statement { Statement::Instruction(ast::Instruction::Cvta( ast::CvtaDetails { - to: ast::CvtaStateSpace::Global, + to: ast::StateSpace::Global, size: ast::CvtaSize::U64, - from: ast::CvtaStateSpace::Generic, + from: ast::StateSpace::Generic, }, arg, )) => { @@ -4562,24 +4372,24 @@ fn convert_to_stateful_memory_access<'a>( } Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { - state_space: ast::LdStateSpace::Param, - typ: ast::LdStType::Scalar(ast::LdStScalarType::U64), + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::U64), .. }, arg, )) | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { - state_space: ast::LdStateSpace::Param, - typ: ast::LdStType::Scalar(ast::LdStScalarType::S64), + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::S64), .. }, arg, )) | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { - state_space: ast::LdStateSpace::Param, - typ: ast::LdStType::Scalar(ast::LdStScalarType::B64), + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::B64), .. 
}, arg, @@ -4595,6 +4405,10 @@ fn convert_to_stateful_memory_access<'a>( _ => {} } } + if stateful_markers.len() == 0 { + drop(method_decl); + return Ok((func_args, func_body)); + } let mut func_args_ptr = HashSet::new(); let mut regs_ptr_current = HashSet::new(); for (dst, src) in stateful_markers { @@ -4614,23 +4428,23 @@ fn convert_to_stateful_memory_access<'a>( for statement in func_body.iter() { match statement { Statement::Instruction(ast::Instruction::Add( - ast::ArithDetails::Unsigned(ast::UIntType::U64), + ast::ArithDetails::Unsigned(ast::ScalarType::U64), arg, )) | Statement::Instruction(ast::Instruction::Add( ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S64, + typ: ast::ScalarType::S64, saturate: false, }), arg, )) | Statement::Instruction(ast::Instruction::Sub( - ast::ArithDetails::Unsigned(ast::UIntType::U64), + ast::ArithDetails::Unsigned(ast::ScalarType::U64), arg, )) | Statement::Instruction(ast::Instruction::Sub( ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S64, + typ: ast::ScalarType::S64, saturate: false, }), arg, @@ -4661,21 +4475,32 @@ fn convert_to_stateful_memory_access<'a>( let mut remapped_ids = HashMap::new(); let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); for reg in regs_ptr_seen { - let new_id = id_defs.new_variable(ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, - )); + let new_id = id_defs.register_variable( + ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Reg, + ); result.push(Statement::Variable(ast::Variable { align: None, name: new_id, array_init: Vec::new(), - v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( - ast::SizedScalarType::U8, - ast::PointerStateSpace::Global, - )), + v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + state_space: ast::StateSpace::Reg, })); remapped_ids.insert(reg, new_id); } + for arg in 
(*method_decl).input_arguments.iter_mut() { + if !func_args_ptr.contains(&arg.name) { + continue; + } + let new_id = id_defs.register_variable( + ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Param, + ); + let old_name = arg.name; + arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); + arg.name = new_id; + remapped_ids.insert(old_name, new_id); + } for statement in func_body { match statement { l @ Statement::Label(_) => result.push(l), @@ -4686,12 +4511,12 @@ fn convert_to_stateful_memory_access<'a>( } } Statement::Instruction(ast::Instruction::Add( - ast::ArithDetails::Unsigned(ast::UIntType::U64), + ast::ArithDetails::Unsigned(ast::ScalarType::U64), arg, )) | Statement::Instruction(ast::Instruction::Add( ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S64, + typ: ast::ScalarType::S64, saturate: false, }), arg, @@ -4707,20 +4532,20 @@ fn convert_to_stateful_memory_access<'a>( }; let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { - underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), - state_space: ast::LdStateSpace::Global, + underlying_type: ast::ScalarType::U8, + state_space: ast::StateSpace::Global, dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, offset_src: offset, })) } Statement::Instruction(ast::Instruction::Sub( - ast::ArithDetails::Unsigned(ast::UIntType::U64), + ast::ArithDetails::Unsigned(ast::ScalarType::U64), arg, )) | Statement::Instruction(ast::Instruction::Sub( ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S64, + typ: ast::ScalarType::S64, saturate: false, }), arg, @@ -4734,8 +4559,10 @@ fn convert_to_stateful_memory_access<'a>( } _ => return Err(error_unreachable()), }; - let offset_neg = - id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64))); + let offset_neg = id_defs.register_intermediate(Some(( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ))); 
result.push(Statement::Instruction(ast::Instruction::Neg( ast::NegDetails { typ: ast::ScalarType::S64, @@ -4748,8 +4575,8 @@ fn convert_to_stateful_memory_access<'a>( ))); let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { - underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), - state_space: ast::LdStateSpace::Global, + underlying_type: ast::ScalarType::U8, + state_space: ast::StateSpace::Global, dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, offset_src: TypedOperand::Reg(offset_neg), @@ -4757,151 +4584,116 @@ fn convert_to_stateful_memory_access<'a>( } Statement::Instruction(inst) => { let mut post_statements = Vec::new(); - let new_statement = inst.visit( - &mut |arg_desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>| { + let new_statement = + inst.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, - &func_args_ptr, &mut result, &mut post_statements, arg_desc, expected_type, ) - }, - )?; + })?; result.push(new_statement); result.extend(post_statements); } Statement::Call(call) => { let mut post_statements = Vec::new(); - let new_statement = call.visit( - &mut |arg_desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>| { + let new_statement = + call.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, - &func_args_ptr, &mut result, &mut post_statements, arg_desc, expected_type, ) - }, - )?; + })?; result.push(new_statement); result.extend(post_statements); } Statement::RepackVector(pack) => { let mut post_statements = Vec::new(); - let new_statement = pack.visit( - &mut |arg_desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>| { + let new_statement = + pack.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, - &func_args_ptr, &mut result, 
&mut post_statements, arg_desc, expected_type, ) - }, - )?; + })?; result.push(new_statement); result.extend(post_statements); } _ => return Err(error_unreachable()), } } - for arg in func_args.input.iter_mut() { - if func_args_ptr.contains(&arg.name) { - arg.v_type = ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, - ); - } - } - Ok(result) + drop(method_decl); + Ok((func_args, result)) } fn convert_to_stateful_memory_access_postprocess( id_defs: &mut NumericIdResolver, remapped_ids: &HashMap, - func_args_ptr: &HashSet, result: &mut Vec, post_statements: &mut Vec, arg_desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>, + expected_type: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { Ok(match remapped_ids.get(&arg_desc.op) { Some(new_id) => { - // We skip conversion here to trigger PtrAcces in a later pass - let old_type = match expected_type { - Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id), - _ => id_defs.get_typed(arg_desc.op)?.0, + let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?; + if let Some((expected_type, expected_space)) = expected_type { + let implicit_conversion = arg_desc + .non_default_implicit_conversion + .unwrap_or(default_implicit_conversion); + if implicit_conversion( + (new_operand_space, &new_operand_type), + (expected_space, expected_type), + ) + .is_ok() + { + return Ok(*new_id); + } + } + let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?; + let converting_id = + id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space))); + let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) { + ConversionKind::Default + } else { + ConversionKind::PtrToPtr }; - let old_type_clone = old_type.clone(); - let converting_id = id_defs.new_non_variable(Some(old_type_clone)); if arg_desc.is_dst { post_statements.push(Statement::Conversion(ImplicitConversion { src: converting_id, 
dst: *new_id, - from: old_type, - to: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, - ), - kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global), - src_sema: ArgumentSemantics::Default, - dst_sema: arg_desc.sema, + from_type: old_operand_type, + from_space: old_operand_space, + to_type: new_operand_type, + to_space: new_operand_space, + kind, })); converting_id } else { result.push(Statement::Conversion(ImplicitConversion { src: *new_id, dst: converting_id, - from: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, - ), - to: old_type, - kind: ConversionKind::PtrToBit(ast::UIntType::U64), - src_sema: arg_desc.sema, - dst_sema: ArgumentSemantics::Default, + from_type: new_operand_type, + from_space: new_operand_space, + to_type: old_operand_type, + to_space: old_operand_space, + kind, })); converting_id } } - None => match func_args_ptr.get(&arg_desc.op) { - Some(new_id) => { - if arg_desc.is_dst { - return Err(error_unreachable()); - } - // We skip conversion here to trigger PtrAcces in a later pass - let old_type = match expected_type { - Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id), - _ => id_defs.get_typed(arg_desc.op)?.0, - }; - let old_type_clone = old_type.clone(); - let converting_id = id_defs.new_non_variable(Some(old_type)); - result.push(Statement::Conversion(ImplicitConversion { - src: *new_id, - dst: converting_id, - from: ast::Type::Pointer( - ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global), - ast::LdStateSpace::Param, - ), - to: old_type_clone, - kind: ConversionKind::PtrToPtr { spirv_ptr: false }, - src_sema: arg_desc.sema, - dst_sema: ArgumentSemantics::Default, - })); - converting_id - } - None => arg_desc.op, - }, + None => arg_desc.op, }) } @@ -4925,9 +4717,9 @@ fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3 bool { match id_defs.get_typed(id) { - 
Ok((ast::Type::Scalar(ast::ScalarType::U64), _)) - | Ok((ast::Type::Scalar(ast::ScalarType::S64), _)) - | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true, + Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _)) + | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _)) + | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true, _ => false, } } @@ -5055,20 +4847,95 @@ impl SpecialRegistersMap { } } +struct FnSigMapper<'input> { + // true - stays as return argument + // false - is moved to input argument + return_param_args: Vec, + func_decl: Rc>>, +} + +impl<'input> FnSigMapper<'input> { + fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, spirv::Word>) -> Self { + let return_param_args = method + .return_arguments + .iter() + .map(|a| a.state_space != ast::StateSpace::Param) + .collect::>(); + let mut new_return_arguments = Vec::new(); + for arg in method.return_arguments.into_iter() { + if arg.state_space == ast::StateSpace::Param { + method.input_arguments.push(arg); + } else { + new_return_arguments.push(arg); + } + } + method.return_arguments = new_return_arguments; + FnSigMapper { + return_param_args, + func_decl: Rc::new(RefCell::new(method)), + } + } + + fn resolve_in_spirv_repr( + &self, + call_inst: ast::CallInst, + ) -> Result, TranslateError> { + let func_decl = (*self.func_decl).borrow(); + let mut return_arguments = Vec::new(); + let mut input_arguments = call_inst + .param_list + .into_iter() + .zip(func_decl.input_arguments.iter()) + .map(|(id, var)| (id, var.v_type.clone(), var.state_space)) + .collect::>(); + let mut func_decl_return_iter = func_decl.return_arguments.iter(); + let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter(); + for (idx, id) in call_inst.ret_params.iter().enumerate() { + let stays_as_return = match self.return_param_args.get(idx) { + Some(x) => *x, + None => return Err(TranslateError::MismatchedType), + }; + if stays_as_return { + if let Some(var) = 
func_decl_return_iter.next() { + return_arguments.push((*id, var.v_type.clone(), var.state_space)); + } else { + return Err(TranslateError::MismatchedType); + } + } else { + if let Some(var) = func_decl_input_iter.next() { + input_arguments.push(( + ast::Operand::Reg(*id), + var.v_type.clone(), + var.state_space, + )); + } else { + return Err(TranslateError::MismatchedType); + } + } + } + if return_arguments.len() != func_decl.return_arguments.len() + || input_arguments.len() != func_decl.input_arguments.len() + { + return Err(TranslateError::MismatchedType); + } + Ok(ResolvedCall { + return_arguments, + input_arguments, + uniform: call_inst.uniform, + name: call_inst.func, + }) + } +} + struct GlobalStringIdResolver<'input> { current_id: spirv::Word, variables: HashMap, spirv::Word>, - variables_type_check: HashMap>, + variables_type_check: HashMap>, special_registers: SpecialRegistersMap, - fns: HashMap, + fns: HashMap>, } -pub struct FnDecl { - ret_vals: Vec, - params: Vec, -} - -impl<'a> GlobalStringIdResolver<'a> { +impl<'input> GlobalStringIdResolver<'input> { fn new(start_id: spirv::Word) -> Self { Self { current_id: start_id, @@ -5079,20 +4946,25 @@ impl<'a> GlobalStringIdResolver<'a> { } } - fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word { + fn get_or_add_def(&mut self, id: &'input str) -> spirv::Word { self.get_or_add_impl(id, None) } fn get_or_add_def_typed( &mut self, - id: &'a str, + id: &'input str, typ: ast::Type, + state_space: ast::StateSpace, is_variable: bool, ) -> spirv::Word { - self.get_or_add_impl(id, Some((typ, is_variable))) + self.get_or_add_impl(id, Some((typ, state_space, is_variable))) } - fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word { + fn get_or_add_impl( + &mut self, + id: &'input str, + typ: Option<(ast::Type, ast::StateSpace, bool)>, + ) -> spirv::Word { let id = match self.variables.entry(Cow::Borrowed(id)) { hash_map::Entry::Occupied(e) => *(e.get()), 
hash_map::Entry::Vacant(e) => { @@ -5119,12 +4991,12 @@ impl<'a> GlobalStringIdResolver<'a> { fn start_fn<'b>( &'b mut self, - header: &'b ast::MethodDecl<'a, &'a str>, + header: &'b ast::MethodDeclaration<'input, &'input str>, ) -> Result< ( - FnStringIdResolver<'a, 'b>, - GlobalFnDeclResolver<'a, 'b>, - ast::MethodDecl<'a, spirv::Word>, + FnStringIdResolver<'input, 'b>, + GlobalFnDeclResolver<'input, 'b>, + Rc>>, ), TranslateError, > { @@ -5138,60 +5010,51 @@ impl<'a> GlobalStringIdResolver<'a> { variables: vec![HashMap::new(); 1], type_check: HashMap::new(), }; - let new_fn_decl = match header { - ast::MethodDecl::Kernel { name, in_args } => ast::MethodDecl::Kernel { - name, - in_args: expand_kernel_params(&mut fn_resolver, in_args.iter())?, - }, - ast::MethodDecl::Func(ret_params, _, params) => { - let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter())?; - let params_ids = expand_fn_params(&mut fn_resolver, params.iter())?; - self.fns.insert( - name_id, - FnDecl { - ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(), - params: params_ids.iter().map(|p| p.v_type.clone()).collect(), - }, - ); - ast::MethodDecl::Func(ret_params_ids, name_id, params_ids) - } + let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments); + let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments); + let name = match header.name { + ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), + ast::MethodName::Func(_) => ast::MethodName::Func(name_id), + }; + let fn_decl = ast::MethodDeclaration { + return_arguments, + name, + input_arguments, + shared_mem: None, + }; + let new_fn_decl = if !fn_decl.name.is_kernel() { + let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl); + let new_fn_decl = resolver.func_decl.clone(); + self.fns.insert(name_id, resolver); + new_fn_decl + } else { + Rc::new(RefCell::new(fn_decl)) }; Ok(( fn_resolver, - GlobalFnDeclResolver { - variables: &self.variables, - 
fns: &self.fns, - }, + GlobalFnDeclResolver { fns: &self.fns }, new_fn_decl, )) } } pub struct GlobalFnDeclResolver<'input, 'a> { - variables: &'a HashMap, spirv::Word>, - fns: &'a HashMap, + fns: &'a HashMap>, } impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { - fn get_fn_decl(&self, id: spirv::Word) -> Result<&FnDecl, TranslateError> { + fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> { self.fns.get(&id).ok_or(TranslateError::UnknownSymbol) } - - fn get_fn_decl_str(&self, id: &str) -> Result<&'a FnDecl, TranslateError> { - match self.variables.get(id).map(|var_id| self.fns.get(var_id)) { - Some(Some(fn_d)) => Ok(fn_d), - _ => Err(TranslateError::UnknownSymbol), - } - } } struct FnStringIdResolver<'input, 'b> { current_id: &'b mut spirv::Word, global_variables: &'b HashMap, spirv::Word>, - global_type_check: &'b HashMap>, + global_type_check: &'b HashMap>, special_registers: &'b mut SpecialRegistersMap, variables: Vec, spirv::Word>>, - type_check: HashMap>, + type_check: HashMap>, } impl<'a, 'b> FnStringIdResolver<'a, 'b> { @@ -5229,14 +5092,21 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { } } - fn add_def(&mut self, id: &'a str, typ: Option, is_variable: bool) -> spirv::Word { + fn add_def( + &mut self, + id: &'a str, + typ: Option<(ast::Type, ast::StateSpace)>, + is_variable: bool, + ) -> spirv::Word { let numeric_id = *self.current_id; self.variables .last_mut() .unwrap() .insert(Cow::Borrowed(id), numeric_id); - self.type_check - .insert(numeric_id, typ.map(|t| (t, is_variable))); + self.type_check.insert( + numeric_id, + typ.map(|(typ, space)| (typ, space, is_variable)), + ); *self.current_id += 1; numeric_id } @@ -5247,6 +5117,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { base_id: &'a str, count: u32, typ: ast::Type, + state_space: ast::StateSpace, is_variable: bool, ) -> impl Iterator { let numeric_id = *self.current_id; @@ -5255,8 +5126,10 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .last_mut() .unwrap() 
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); - self.type_check - .insert(numeric_id + i, Some((typ.clone(), is_variable))); + self.type_check.insert( + numeric_id + i, + Some((typ.clone(), state_space, is_variable)), + ); } *self.current_id += count; (0..count).into_iter().map(move |i| i + numeric_id) @@ -5265,8 +5138,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> { current_id: &'b mut spirv::Word, - global_type_check: &'b HashMap>, - type_check: HashMap>, + global_type_check: &'b HashMap>, + type_check: HashMap>, special_registers: &'b mut SpecialRegistersMap, } @@ -5275,12 +5148,15 @@ impl<'b> NumericIdResolver<'b> { MutableNumericIdResolver { base: self } } - fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, bool), TranslateError> { + fn get_typed( + &self, + id: spirv::Word, + ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> { match self.type_check.get(&id) { Some(Some(x)) => Ok(x.clone()), Some(None) => Err(TranslateError::UntypedSymbol), None => match self.special_registers.get(id) { - Some(x) => Ok((x.get_type(), true)), + Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)), None => match self.global_type_check.get(&id) { Some(Some(result)) => Ok(result.clone()), Some(None) | None => Err(TranslateError::UntypedSymbol), @@ -5291,16 +5167,18 @@ impl<'b> NumericIdResolver<'b> { // This is for identifiers which will be emitted later as OpVariable // They are candidates for insertion of LoadVar/StoreVar - fn new_variable(&mut self, typ: ast::Type) -> spirv::Word { + fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> spirv::Word { let new_id = *self.current_id; - self.type_check.insert(new_id, Some((typ, true))); + self.type_check + .insert(new_id, Some((typ, state_space, true))); *self.current_id += 1; new_id } - fn new_non_variable(&mut self, typ: Option) -> spirv::Word { + fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) 
-> spirv::Word { let new_id = *self.current_id; - self.type_check.insert(new_id, typ.map(|t| (t, false))); + self.type_check + .insert(new_id, typ.map(|(t, space)| (t, space, false))); *self.current_id += 1; new_id } @@ -5315,18 +5193,22 @@ impl<'b> MutableNumericIdResolver<'b> { self.base } - fn get_typed(&self, id: spirv::Word) -> Result { - self.base.get_typed(id).map(|(t, _)| t) + fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, ast::StateSpace), TranslateError> { + self.base.get_typed(id).map(|(t, space, _)| (t, space)) } - fn new_non_variable(&mut self, typ: ast::Type) -> spirv::Word { - self.base.new_non_variable(Some(typ)) + fn register_intermediate( + &mut self, + typ: ast::Type, + state_space: ast::StateSpace, + ) -> spirv::Word { + self.base.register_intermediate(Some((typ, state_space))) } } enum Statement { Label(u32), - Variable(ast::Variable), + Variable(ast::Variable), Instruction(I), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), @@ -5349,7 +5231,8 @@ impl ExpandedStatement { Statement::Variable(var) } Statement::Instruction(inst) => inst - .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| { + .visit(&mut |arg: ArgumentDescriptor<_>, + _: Option<(&ast::Type, ast::StateSpace)>| { Ok(f(arg.op, arg.is_dst)) }) .unwrap(), @@ -5364,16 +5247,17 @@ impl ExpandedStatement { Statement::StoreVar(details) } Statement::Call(mut call) => { - for (id, typ) in call.ret_params.iter_mut() { - let is_dst = match typ { - ast::FnArgumentType::Reg(_) => true, - ast::FnArgumentType::Param(_) => false, - ast::FnArgumentType::Shared => false, + for (id, _, space) in call.return_arguments.iter_mut() { + let is_dst = match space { + ast::StateSpace::Reg => true, + ast::StateSpace::Param => false, + ast::StateSpace::Shared => false, + _ => todo!(), }; *id = f(*id, is_dst); } - call.func = f(call.func, false); - for (id, _) in call.param_list.iter_mut() { + call.name = f(call.name, false); + for (id, _, _) in 
call.input_arguments.iter_mut() { *id = f(*id, false); } Statement::Call(call) @@ -5435,6 +5319,7 @@ impl ExpandedStatement { struct LoadVarDetails { arg: ast::Arg2, typ: ast::Type, + state_space: ast::StateSpace, // (index, vector_width) // HACK ALERT // For some reason IGC explodes when you try to load from builtin vectors @@ -5454,7 +5339,12 @@ struct RepackVectorDetails { typ: ast::ScalarType, packed: spirv::Word, unpacked: Vec, - vector_sema: ArgumentSemantics, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, } impl RepackVectorDetails { @@ -5470,13 +5360,17 @@ impl RepackVectorDetails { ArgumentDescriptor { op: self.packed, is_dst: !self.is_extract, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, - Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)), + Some(( + &ast::Type::Vector(self.typ, self.unpacked.len() as u8), + ast::StateSpace::Reg, + )), )?; let scalar_type = self.typ; let is_extract = self.is_extract; - let vector_sema = self.vector_sema; + let non_default_implicit_conversion = self.non_default_implicit_conversion; let vector = self .unpacked .into_iter() @@ -5485,9 +5379,10 @@ impl RepackVectorDetails { ArgumentDescriptor { op: id, is_dst: is_extract, - sema: vector_sema, + is_memory_access: false, + non_default_implicit_conversion, }, - Some(&ast::Type::Scalar(scalar_type)), + Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)), ) }) .collect::>()?; @@ -5496,7 +5391,7 @@ impl RepackVectorDetails { typ: self.typ, packed: scalar, unpacked: vector, - vector_sema, + non_default_implicit_conversion, }) } } @@ -5514,18 +5409,18 @@ impl, U: ArgParamsEx> Visitab struct ResolvedCall { pub uniform: bool, - pub ret_params: Vec<(P::Id, ast::FnArgumentType)>, - pub func: P::Id, - pub param_list: Vec<(P::Operand, ast::FnArgumentType)>, + pub return_arguments: Vec<(P::Id, ast::Type, 
ast::StateSpace)>, + pub name: P::Id, + pub input_arguments: Vec<(P::Operand, ast::Type, ast::StateSpace)>, } impl ResolvedCall { fn cast>(self) -> ResolvedCall { ResolvedCall { uniform: self.uniform, - ret_params: self.ret_params, - func: self.func, - param_list: self.param_list, + return_arguments: self.return_arguments, + name: self.name, + input_arguments: self.input_arguments, } } } @@ -5535,49 +5430,53 @@ impl> ResolvedCall { self, visitor: &mut V, ) -> Result, TranslateError> { - let ret_params = self - .ret_params + let return_arguments = self + .return_arguments .into_iter() - .map::, _>(|(id, typ)| { + .map::, _>(|(id, typ, space)| { let new_id = visitor.id( ArgumentDescriptor { op: id, - is_dst: !typ.is_param(), - sema: typ.semantics(), + is_dst: space != ast::StateSpace::Param, + is_memory_access: false, + non_default_implicit_conversion: None, }, - Some(&typ.to_func_type()), + Some((&typ, space)), )?; - Ok((new_id, typ)) + Ok((new_id, typ, space)) }) .collect::, _>>()?; let func = visitor.id( ArgumentDescriptor { - op: self.func, + op: self.name, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, None, )?; - let param_list = self - .param_list + let input_arguments = self + .input_arguments .into_iter() - .map::, _>(|(id, typ)| { + .map::, _>(|(id, typ, space)| { let new_id = visitor.operand( ArgumentDescriptor { op: id, is_dst: false, - sema: typ.semantics(), + is_memory_access: false, + non_default_implicit_conversion: None, }, - &typ.to_func_type(), + &typ, + space, )?; - Ok((new_id, typ)) + Ok((new_id, typ, space)) }) .collect::, _>>()?; Ok(ResolvedCall { uniform: self.uniform, - ret_params, - func, - param_list, + return_arguments, + name: func, + input_arguments, }) } } @@ -5598,39 +5497,34 @@ impl> PtrAccess

{ self, visitor: &mut V, ) -> Result, TranslateError> { - let sema = match self.state_space { - ast::LdStateSpace::Const - | ast::LdStateSpace::Global - | ast::LdStateSpace::Shared - | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer, - ast::LdStateSpace::Local | ast::LdStateSpace::Param => { - ArgumentSemantics::RegisterPointer - } - }; - let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space); + let ptr_type = ast::Type::Scalar(self.underlying_type.clone()); let new_dst = visitor.id( ArgumentDescriptor { op: self.dst, is_dst: true, - sema, + is_memory_access: false, + non_default_implicit_conversion: None, }, - Some(&ptr_type), + Some((&ptr_type, self.state_space)), )?; let new_ptr_src = visitor.id( ArgumentDescriptor { op: self.ptr_src, is_dst: false, - sema, + is_memory_access: false, + non_default_implicit_conversion: None, }, - Some(&ptr_type), + Some((&ptr_type, self.state_space)), )?; let new_constant_src = visitor.operand( ArgumentDescriptor { op: self.offset_src, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, )?; Ok(PtrAccess { underlying_type: self.underlying_type, @@ -5653,21 +5547,9 @@ impl, U: ArgParamsEx> Visitab } } -pub trait ArgParamsEx: ast::ArgParams + Sized { - fn get_fn_decl<'x, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'x, 'b>, - ) -> Result<&'b FnDecl, TranslateError>; -} +pub trait ArgParamsEx: ast::ArgParams + Sized {} -impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> { - fn get_fn_decl<'x, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'x, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl_str(id) - } -} +impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {} enum NormalizedArgParams {} @@ -5676,14 +5558,7 @@ impl ast::ArgParams for NormalizedArgParams { type Operand = ast::Operand; } -impl ArgParamsEx for 
NormalizedArgParams { - fn get_fn_decl<'a, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'a, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl(*id) - } -} +impl ArgParamsEx for NormalizedArgParams {} type NormalizedStatement = Statement< ( @@ -5702,14 +5577,7 @@ impl ast::ArgParams for TypedArgParams { type Operand = TypedOperand; } -impl ArgParamsEx for TypedArgParams { - fn get_fn_decl<'a, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'a, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl(*id) - } -} +impl ArgParamsEx for TypedArgParams {} #[derive(Copy, Clone)] enum TypedOperand { @@ -5740,24 +5608,16 @@ impl ast::ArgParams for ExpandedArgParams { type Operand = spirv::Word; } -impl ArgParamsEx for ExpandedArgParams { - fn get_fn_decl<'a, 'b>( - id: &Self::Id, - decl: &'b GlobalFnDeclResolver<'a, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl(*id) - } -} +impl ArgParamsEx for ExpandedArgParams {} enum Directive<'input> { - Variable(ast::Variable), + Variable(ast::LinkingDirective, ast::Variable), Method(Function<'input>), } struct Function<'input> { - pub func_decl: ast::MethodDecl<'input, spirv::Word>, - pub spirv_decl: SpirvMethodDecl<'input>, - pub globals: Vec>, + pub func_decl: Rc>>, + pub globals: Vec>, pub body: Option>, import_as: Option, tuning: Vec, @@ -5767,12 +5627,13 @@ pub trait ArgumentMapVisitor { fn id( &mut self, desc: ArgumentDescriptor, - typ: Option<&ast::Type>, + typ: Option<(&ast::Type, ast::StateSpace)>, ) -> Result; fn operand( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result; } @@ -5780,13 +5641,13 @@ impl ArgumentMapVisitor for T where T: FnMut( ArgumentDescriptor, - Option<&ast::Type>, + Option<(&ast::Type, ast::StateSpace)>, ) -> Result, { fn id( &mut self, desc: ArgumentDescriptor, - t: Option<&ast::Type>, + t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self(desc, t) } @@ -5795,8 +5656,9 @@ where &mut 
self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { - self(desc, Some(typ)) + self(desc, Some((typ, state_space))) } } @@ -5807,7 +5669,7 @@ where fn id( &mut self, desc: ArgumentDescriptor<&str>, - _: Option<&ast::Type>, + _: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self(desc.op) } @@ -5816,6 +5678,7 @@ where &mut self, desc: ArgumentDescriptor>, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result, TranslateError> { Ok(match desc.op { ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?), @@ -5824,7 +5687,7 @@ where ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member), ast::Operand::VecPack(ref ids) => ast::Operand::VecPack( ids.into_iter() - .map(|id| self.id(desc.new_op(id), Some(typ))) + .map(|id| self.id(desc.new_op(id), Some((typ, state_space)))) .collect::, _>>()?, ), }) @@ -5834,37 +5697,30 @@ where pub struct ArgumentDescriptor { op: Op, is_dst: bool, - sema: ArgumentSemantics, + is_memory_access: bool, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, } pub struct PtrAccess { - underlying_type: ast::PointerType, - state_space: ast::LdStateSpace, + underlying_type: ast::ScalarType, + state_space: ast::StateSpace, dst: spirv::Word, ptr_src: spirv::Word, offset_src: P::Operand, } -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub enum ArgumentSemantics { - // normal register access - Default, - // normal register access with relaxed conversion rules (ld/st) - DefaultRelaxed, - // st/ld global - PhysicalPointer, - // st/ld .param, .local - RegisterPointer, - // mov of .local/.global variables - Address, -} - impl ArgumentDescriptor { fn new_op(&self, u: U) -> ArgumentDescriptor { ArgumentDescriptor { op: u, is_dst: self.is_dst, - sema: self.sema, + is_memory_access: self.is_memory_access, + non_default_implicit_conversion: self.non_default_implicit_conversion, } } } 
@@ -5905,7 +5761,9 @@ impl ast::Instruction { let inst_type = d.typ; ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) } - ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?), + ast::Instruction::Not(t, a) => { + ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?) + } ast::Instruction::Cvt(d, a) => { let (dst_t, src_t) = match &d { ast::CvtDetails::FloatFromFloat(desc) => ( @@ -5928,7 +5786,7 @@ impl ast::Instruction { ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?) } ast::Instruction::Shl(t, a) => { - ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?) + ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?) } ast::Instruction::Shr(t, a) => { ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?) @@ -6101,17 +5959,19 @@ impl ImplicitConversion { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: self.dst_sema, + is_memory_access: false, + non_default_implicit_conversion: None, }, - Some(&self.to), + Some((&self.to_type, self.to_space)), )?; let new_src = visitor.id( ArgumentDescriptor { op: self.src, is_dst: false, - sema: self.src_sema, + is_memory_access: false, + non_default_implicit_conversion: None, }, - Some(&self.from), + Some((&self.from_type, self.from_space)), )?; Ok(Statement::Conversion({ ImplicitConversion { @@ -6138,13 +5998,13 @@ impl ArgumentMapVisitor for T where T: FnMut( ArgumentDescriptor, - Option<&ast::Type>, + Option<(&ast::Type, ast::StateSpace)>, ) -> Result, { fn id( &mut self, desc: ArgumentDescriptor, - t: Option<&ast::Type>, + t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self(desc, t) } @@ -6153,12 +6013,15 @@ where &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { Ok(match desc.op { - TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?), + TypedOperand::Reg(id) => { + 
TypedOperand::Reg(self(desc.new_op(id), Some((typ, state_space)))?) + } TypedOperand::Imm(imm) => TypedOperand::Imm(imm), TypedOperand::RegOffset(id, imm) => { - TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) + TypedOperand::RegOffset(self(desc.new_op(id), Some((typ, state_space)))?, imm) } TypedOperand::VecMember(reg, index) => { let scalar_type = match typ { @@ -6166,7 +6029,10 @@ where _ => return Err(error_unreachable()), }; let vec_type = ast::Type::Vector(scalar_type, index + 1); - TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) + TypedOperand::VecMember( + self(desc.new_op(reg), Some((&vec_type, state_space)))?, + index, + ) } }) } @@ -6178,9 +6044,9 @@ impl ast::Type { ast::Type::Scalar(scalar) => { let kind = scalar.kind(); let width = scalar.size_of(); - if (kind != ScalarKind::Signed - && kind != ScalarKind::Unsigned - && kind != ScalarKind::Bit) + if (kind != ast::ScalarKind::Signed + && kind != ast::ScalarKind::Unsigned + && kind != ast::ScalarKind::Bit) || (width == 8) { return Err(TranslateError::MismatchedType); @@ -6198,57 +6064,32 @@ impl ast::Type { match self { ast::Type::Scalar(scalar) => TypeParts { kind: TypeKind::Scalar, + state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), - state_space: ast::LdStateSpace::Global, }, ast::Type::Vector(scalar, components) => TypeParts { kind: TypeKind::Vector, + state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), width: scalar.size_of(), components: vec![*components as u32], - state_space: ast::LdStateSpace::Global, }, ast::Type::Array(scalar, components) => TypeParts { kind: TypeKind::Array, + state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), width: scalar.size_of(), components: components.clone(), - state_space: ast::LdStateSpace::Global, }, - ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts { - kind: TypeKind::PointerScalar, + 
ast::Type::Pointer(scalar, space) => TypeParts { + kind: TypeKind::Pointer, + state_space: *space, scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), - state_space: *state_space, }, - ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts { - kind: TypeKind::PointerVector, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: vec![*len as u32], - state_space: *state_space, - }, - ast::Type::Pointer(ast::PointerType::Array(scalar, components), state_space) => { - TypeParts { - kind: TypeKind::PointerArray, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: components.clone(), - state_space: *state_space, - } - } - ast::Type::Pointer(ast::PointerType::Pointer(scalar, inner_space), state_space) => { - TypeParts { - kind: TypeKind::PointerPointer, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: vec![*inner_space as u32], - state_space: *state_space, - } - } } } @@ -6265,29 +6106,8 @@ impl ast::Type { ast::ScalarType::from_parts(t.width, t.scalar_kind), t.components, ), - TypeKind::PointerScalar => ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)), - t.state_space, - ), - TypeKind::PointerVector => ast::Type::Pointer( - ast::PointerType::Vector( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.components[0] as u8, - ), - t.state_space, - ), - TypeKind::PointerArray => ast::Type::Pointer( - ast::PointerType::Array( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.components, - ), - t.state_space, - ), - TypeKind::PointerPointer => ast::Type::Pointer( - ast::PointerType::Pointer( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - unsafe { mem::transmute::<_, ast::LdStateSpace>(t.components[0] as u8) }, - ), + TypeKind::Pointer => ast::Type::Pointer( + ast::ScalarType::from_parts(t.width, t.scalar_kind), t.state_space, ), } @@ -6300,7 +6120,7 @@ impl ast::Type { ast::Type::Array(typ, 
len) => len .iter() .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), - ast::Type::Pointer(_, _) => mem::size_of::(), + ast::Type::Pointer(..) => mem::size_of::(), } } } @@ -6308,10 +6128,10 @@ impl ast::Type { #[derive(Eq, PartialEq, Clone)] struct TypeParts { kind: TypeKind, - scalar_kind: ScalarKind, + scalar_kind: ast::ScalarKind, width: u8, + state_space: ast::StateSpace, components: Vec, - state_space: ast::LdStateSpace, } #[derive(Eq, PartialEq, Copy, Clone)] @@ -6319,10 +6139,7 @@ enum TypeKind { Scalar, Vector, Array, - PointerScalar, - PointerVector, - PointerArray, - PointerPointer, + Pointer, } impl ast::Instruction { @@ -6450,21 +6267,21 @@ struct BrachCondition { struct ImplicitConversion { src: spirv::Word, dst: spirv::Word, - from: ast::Type, - to: ast::Type, + from_type: ast::Type, + to_type: ast::Type, + from_space: ast::StateSpace, + to_space: ast::StateSpace, kind: ConversionKind, - src_sema: ArgumentSemantics, - dst_sema: ArgumentSemantics, } -#[derive(PartialEq, Copy, Clone)] +#[derive(PartialEq, Clone)] enum ConversionKind { Default, // zero-extend/chop/bitcast depending on types SignExtend, - BitToPtr(ast::LdStateSpace), - PtrToBit(ast::UIntType), - PtrToPtr { spirv_ptr: bool }, + BitToPtr, + PtrToPtr, + AddressOf, } impl ast::PredAt { @@ -6512,13 +6329,14 @@ impl ast::Arg1 { fn map>( self, visitor: &mut V, - t: Option<&ast::Type>, + t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result, TranslateError> { let new_src = visitor.id( ArgumentDescriptor { op: self.src, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, t, )?; @@ -6535,9 +6353,11 @@ impl ast::Arg1Bar { ArgumentDescriptor { op: self.src, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, )?; Ok(ast::Arg1Bar { src: new_src }) } @@ -6553,17 +6373,21 @@ impl 
ast::Arg2 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, t, + ast::StateSpace::Reg, )?; let new_src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, t, + ast::StateSpace::Reg, )?; Ok(ast::Arg2 { dst: new_dst, @@ -6581,17 +6405,21 @@ impl ast::Arg2 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, dst_t, + ast::StateSpace::Reg, )?; let src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, src_t, + ast::StateSpace::Reg, )?; Ok(ast::Arg2 { dst, src }) } @@ -6607,26 +6435,21 @@ impl ast::Arg2Ld { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::DefaultRelaxed, + is_memory_access: false, + non_default_implicit_conversion: Some(should_convert_relaxed_dst_wrapper), }, &ast::Type::from(details.typ.clone()), + ast::StateSpace::Reg, )?; - let is_logical_ptr = details.state_space == ast::LdStateSpace::Param - || details.state_space == ast::LdStateSpace::Local; let src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, - sema: if is_logical_ptr { - ArgumentSemantics::RegisterPointer - } else { - ArgumentSemantics::PhysicalPointer - }, + is_memory_access: true, + non_default_implicit_conversion: None, }, - &ast::Type::Pointer( - ast::PointerType::from(details.typ.clone()), - details.state_space, - ), + &details.typ, + details.state_space, )?; Ok(ast::Arg2Ld { dst, src }) } @@ -6638,30 +6461,25 @@ impl ast::Arg2St { visitor: &mut V, details: &ast::StData, ) -> Result, TranslateError> { - let is_logical_ptr = details.state_space == ast::StStateSpace::Param - || 
details.state_space == ast::StStateSpace::Local; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, - sema: if is_logical_ptr { - ArgumentSemantics::RegisterPointer - } else { - ArgumentSemantics::PhysicalPointer - }, + is_memory_access: true, + non_default_implicit_conversion: None, }, - &ast::Type::Pointer( - ast::PointerType::from(details.typ.clone()), - details.state_space.to_ld_ss(), - ), + &details.typ, + details.state_space, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::DefaultRelaxed, + is_memory_access: false, + non_default_implicit_conversion: Some(should_convert_relaxed_src_wrapper), }, &details.typ.clone().into(), + ast::StateSpace::Reg, )?; Ok(ast::Arg2St { src1, src2 }) } @@ -6677,21 +6495,21 @@ impl ast::Arg2Mov { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &details.typ.clone().into(), + ast::StateSpace::Reg, )?; let src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, - sema: if details.src_is_address { - ArgumentSemantics::Address - } else { - ArgumentSemantics::Default - }, + is_memory_access: false, + non_default_implicit_conversion: Some(implicit_conversion_mov), }, &details.typ.clone().into(), + ast::StateSpace::Reg, )?; Ok(ast::Arg2Mov { dst, src }) } @@ -6713,25 +6531,31 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, wide_type.as_ref().unwrap_or(typ), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, typ, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, - sema: 
ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, typ, + ast::StateSpace::Reg, )?; Ok(ast::Arg3 { dst, src1, src2 }) } @@ -6745,25 +6569,31 @@ impl ast::Arg3 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, t, + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, t, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, )?; Ok(ast::Arg3 { dst, src1, src2 }) } @@ -6772,35 +6602,38 @@ impl ast::Arg3 { self, visitor: &mut V, t: ast::ScalarType, - state_space: ast::AtomSpace, + state_space: ast::StateSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::PhysicalPointer, + is_memory_access: true, + non_default_implicit_conversion: None, }, - &ast::Type::Pointer( - ast::PointerType::Scalar(scalar_type), - state_space.to_ld_ss(), - ), + &ast::Type::Scalar(scalar_type), + state_space, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; Ok(ast::Arg3 { dst, src1, src2 }) } @@ 
-6822,33 +6655,41 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, wide_type.as_ref().unwrap_or(t), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, t, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, t, + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, t, + ast::StateSpace::Reg, )?; Ok(ast::Arg4 { dst, @@ -6861,39 +6702,47 @@ impl ast::Arg4 { fn map_selp>( self, visitor: &mut V, - t: ast::SelpType, + t: ast::ScalarType, ) -> Result, TranslateError> { let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(t.into()), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(t.into()), + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(t.into()), + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, 
&ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, )?; Ok(ast::Arg4 { dst, @@ -6906,44 +6755,49 @@ impl ast::Arg4 { fn map_atom>( self, visitor: &mut V, - t: ast::BitType, - state_space: ast::AtomSpace, + t: ast::ScalarType, + state_space: ast::StateSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::PhysicalPointer, + is_memory_access: true, + non_default_implicit_conversion: None, }, - &ast::Type::Pointer( - ast::PointerType::Scalar(scalar_type), - state_space.to_ld_ss(), - ), + &ast::Type::Scalar(scalar_type), + state_space, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; Ok(ast::Arg4 { dst, @@ -6962,34 +6816,42 @@ impl ast::Arg4 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, typ, + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, typ, + ast::StateSpace::Reg, )?; let u32_type = ast::Type::Scalar(ast::ScalarType::U32); let src2 = visitor.operand( ArgumentDescriptor 
{ op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &u32_type, + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &u32_type, + ast::StateSpace::Reg, )?; Ok(ast::Arg4 { dst, @@ -7010,9 +6872,13 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: self.dst1, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, - Some(&ast::Type::Scalar(ast::ScalarType::Pred)), + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), )?; let dst2 = self .dst2 @@ -7021,9 +6887,13 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: dst2, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, - Some(&ast::Type::Scalar(ast::ScalarType::Pred)), + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), ) }) .transpose()?; @@ -7031,17 +6901,21 @@ impl ast::Arg4Setp { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, t, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, t, + ast::StateSpace::Reg, )?; Ok(ast::Arg4Setp { dst1, @@ -7062,41 +6936,51 @@ impl ast::Arg5 { ArgumentDescriptor { op: self.dst, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, base_type, + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + 
non_default_implicit_conversion: None, }, base_type, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, base_type, + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, )?; let src4 = visitor.operand( ArgumentDescriptor { op: self.src4, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, )?; Ok(ast::Arg5 { dst, @@ -7118,9 +7002,13 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.dst1, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, - Some(&ast::Type::Scalar(ast::ScalarType::Pred)), + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), )?; let dst2 = self .dst2 @@ -7129,9 +7017,13 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: dst2, is_dst: true, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, - Some(&ast::Type::Scalar(ast::ScalarType::Pred)), + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), ) }) .transpose()?; @@ -7139,25 +7031,31 @@ impl ast::Arg5Setp { ArgumentDescriptor { op: self.src1, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, t, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, t, + ast::StateSpace::Reg, 
)?; let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, - sema: ArgumentSemantics::Default, + is_memory_access: false, + non_default_implicit_conversion: None, }, &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, )?; Ok(ast::Arg5Setp { dst1, @@ -7195,115 +7093,41 @@ impl ast::Operand { } } -impl ast::StStateSpace { - fn to_ld_ss(self) -> ast::LdStateSpace { - match self { - ast::StStateSpace::Generic => ast::LdStateSpace::Generic, - ast::StStateSpace::Global => ast::LdStateSpace::Global, - ast::StStateSpace::Local => ast::LdStateSpace::Local, - ast::StStateSpace::Param => ast::LdStateSpace::Param, - ast::StStateSpace::Shared => ast::LdStateSpace::Shared, - } - } -} - -#[derive(Clone, Copy, PartialEq, Eq)] -enum ScalarKind { - Bit, - Unsigned, - Signed, - Float, - Float2, - Pred, -} - impl ast::ScalarType { - fn kind(self) -> ScalarKind { - match self { - ast::ScalarType::U8 => ScalarKind::Unsigned, - ast::ScalarType::U16 => ScalarKind::Unsigned, - ast::ScalarType::U32 => ScalarKind::Unsigned, - ast::ScalarType::U64 => ScalarKind::Unsigned, - ast::ScalarType::S8 => ScalarKind::Signed, - ast::ScalarType::S16 => ScalarKind::Signed, - ast::ScalarType::S32 => ScalarKind::Signed, - ast::ScalarType::S64 => ScalarKind::Signed, - ast::ScalarType::B8 => ScalarKind::Bit, - ast::ScalarType::B16 => ScalarKind::Bit, - ast::ScalarType::B32 => ScalarKind::Bit, - ast::ScalarType::B64 => ScalarKind::Bit, - ast::ScalarType::F16 => ScalarKind::Float, - ast::ScalarType::F32 => ScalarKind::Float, - ast::ScalarType::F64 => ScalarKind::Float, - ast::ScalarType::F16x2 => ScalarKind::Float2, - ast::ScalarType::Pred => ScalarKind::Pred, - } - } - - fn from_parts(width: u8, kind: ScalarKind) -> Self { + fn from_parts(width: u8, kind: ast::ScalarKind) -> Self { match kind { - ScalarKind::Float => match width { + ast::ScalarKind::Float => match width { 2 => ast::ScalarType::F16, 4 => ast::ScalarType::F32, 8 => ast::ScalarType::F64, _ => 
unreachable!(), }, - ScalarKind::Bit => match width { + ast::ScalarKind::Bit => match width { 1 => ast::ScalarType::B8, 2 => ast::ScalarType::B16, 4 => ast::ScalarType::B32, 8 => ast::ScalarType::B64, _ => unreachable!(), }, - ScalarKind::Signed => match width { + ast::ScalarKind::Signed => match width { 1 => ast::ScalarType::S8, 2 => ast::ScalarType::S16, 4 => ast::ScalarType::S32, 8 => ast::ScalarType::S64, _ => unreachable!(), }, - ScalarKind::Unsigned => match width { + ast::ScalarKind::Unsigned => match width { 1 => ast::ScalarType::U8, 2 => ast::ScalarType::U16, 4 => ast::ScalarType::U32, 8 => ast::ScalarType::U64, _ => unreachable!(), }, - ScalarKind::Float2 => match width { + ast::ScalarKind::Float2 => match width { 4 => ast::ScalarType::F16x2, _ => unreachable!(), }, - ScalarKind::Pred => ast::ScalarType::Pred, - } - } -} - -impl ast::BooleanType { - fn to_type(self) -> ast::Type { - match self { - ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred), - ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16), - ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32), - ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64), - } - } -} - -impl ast::ShlType { - fn to_type(self) -> ast::Type { - match self { - ast::ShlType::B16 => ast::Type::Scalar(ast::ScalarType::B16), - ast::ShlType::B32 => ast::Type::Scalar(ast::ScalarType::B32), - ast::ShlType::B64 => ast::Type::Scalar(ast::ScalarType::B64), - } - } -} - -impl ast::ShrType { - fn signed(&self) -> bool { - match self { - ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true, - _ => false, + ast::ScalarKind::Pred => ast::ScalarType::Pred, } } } @@ -7359,49 +7183,47 @@ impl ast::AtomInnerDetails { } } -impl ast::SIntType { - fn from_size(width: u8) -> Self { - match width { - 1 => ast::SIntType::S8, - 2 => ast::SIntType::S16, - 4 => ast::SIntType::S32, - 8 => ast::SIntType::S64, - _ => unreachable!(), - } - } -} - -impl ast::UIntType { - fn 
from_size(width: u8) -> Self { - match width { - 1 => ast::UIntType::U8, - 2 => ast::UIntType::U16, - 4 => ast::UIntType::U32, - 8 => ast::UIntType::U64, - _ => unreachable!(), - } - } -} - -impl ast::LdStateSpace { +impl ast::StateSpace { fn to_spirv(self) -> spirv::StorageClass { match self { - ast::LdStateSpace::Const => spirv::StorageClass::UniformConstant, - ast::LdStateSpace::Generic => spirv::StorageClass::Generic, - ast::LdStateSpace::Global => spirv::StorageClass::CrossWorkgroup, - ast::LdStateSpace::Local => spirv::StorageClass::Function, - ast::LdStateSpace::Shared => spirv::StorageClass::Workgroup, - ast::LdStateSpace::Param => spirv::StorageClass::Function, + ast::StateSpace::Const => spirv::StorageClass::UniformConstant, + ast::StateSpace::Generic => spirv::StorageClass::Generic, + ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, + ast::StateSpace::Local => spirv::StorageClass::Function, + ast::StateSpace::Shared => spirv::StorageClass::Workgroup, + ast::StateSpace::Param => spirv::StorageClass::Function, + ast::StateSpace::Reg => spirv::StorageClass::Function, + ast::StateSpace::Sreg => spirv::StorageClass::Input, } } -} -impl From for ast::VariableType { - fn from(t: ast::FnArgumentType) -> Self { - match t { - ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t), - ast::FnArgumentType::Param(t) => ast::VariableType::Param(t), - ast::FnArgumentType::Shared => todo!(), + fn is_compatible(self, other: ast::StateSpace) -> bool { + self == other + || self == ast::StateSpace::Reg && other == ast::StateSpace::Sreg + || self == ast::StateSpace::Sreg && other == ast::StateSpace::Reg + } + + fn coerces_to_generic(self) -> bool { + match self { + ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Reg + | ast::StateSpace::Param + | ast::StateSpace::Generic + | ast::StateSpace::Sreg => false, + } + } + + fn is_addressable(self) -> bool { + match self { + 
ast::StateSpace::Const + | ast::StateSpace::Generic + | ast::StateSpace::Global + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false, } } } @@ -7427,16 +7249,6 @@ impl ast::MulDetails { } } -impl ast::AtomSpace { - fn to_ld_ss(self) -> ast::LdStateSpace { - match self { - ast::AtomSpace::Generic => ast::LdStateSpace::Generic, - ast::AtomSpace::Global => ast::LdStateSpace::Global, - ast::AtomSpace::Shared => ast::LdStateSpace::Shared, - } - } -} - impl ast::MemScope { fn to_spirv(self) -> spirv::Scope { match self { @@ -7458,109 +7270,96 @@ impl ast::AtomSemantics { } } -impl ast::FnArgumentType { - fn semantics(&self) -> ArgumentSemantics { - match self { - ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default, - ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer, - ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer, - } +fn default_implicit_conversion( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if !instruction_space.is_compatible(operand_space) { + default_implicit_conversion_space( + (operand_space, operand_type), + (instruction_space, instruction_type), + ) + } else if instruction_type != operand_type { + default_implicit_conversion_type(instruction_space, operand_type, instruction_type) + } else { + Ok(None) } } -fn bitcast_register_pointer( - operand_type: &ast::Type, - instr_type: &ast::Type, - ss: Option, +// Space is different +fn default_implicit_conversion_space( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - bitcast_physical_pointer(operand_type, instr_type, ss) -} - -fn bitcast_physical_pointer( - operand_type: &ast::Type, - instr_type: &ast::Type, - ss: Option, -) -> Result, 
TranslateError> { - match operand_type { - // array decays to a pointer - ast::Type::Array(op_scalar_t, _) => { - if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type { - if ss == Some(*instr_space) { - if ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_scalar_t.clone()) { - Ok(None) - } else { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } + if (instruction_space == ast::StateSpace::Generic && operand_space.coerces_to_generic()) + || (operand_space == ast::StateSpace::Generic && instruction_space.coerces_to_generic()) + { + Ok(Some(ConversionKind::PtrToPtr)) + } else if operand_space.is_compatible(ast::StateSpace::Reg) { + match operand_type { + ast::Type::Pointer(operand_ptr_type, operand_ptr_space) + if *operand_ptr_space == instruction_space => + { + if instruction_type != &ast::Type::Scalar(*operand_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) } else { - if ss == Some(ast::LdStateSpace::Generic) - || *instr_space == ast::LdStateSpace::Generic - { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } else { - Err(TranslateError::MismatchedType) - } + Ok(None) } - } else { - Err(TranslateError::MismatchedType) } + // TODO: 32 bit + ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { + ast::StateSpace::Global + | ast::StateSpace::Generic + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), + _ => Err(TranslateError::MismatchedType), + }, + ast::Type::Scalar(ast::ScalarType::B32) + | ast::Type::Scalar(ast::ScalarType::U32) + | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { + ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { + Ok(Some(ConversionKind::BitToPtr)) + } + _ => Err(TranslateError::MismatchedType), + }, + _ => Err(TranslateError::MismatchedType), } - 
ast::Type::Scalar(ast::ScalarType::B64) - | ast::Type::Scalar(ast::ScalarType::U64) - | ast::Type::Scalar(ast::ScalarType::S64) => { - if let Some(space) = ss { - Ok(Some(ConversionKind::BitToPtr(space))) - } else { - Err(error_unreachable()) - } - } - ast::Type::Scalar(ast::ScalarType::B32) - | ast::Type::Scalar(ast::ScalarType::U32) - | ast::Type::Scalar(ast::ScalarType::S32) => match ss { - Some(ast::LdStateSpace::Shared) - | Some(ast::LdStateSpace::Generic) - | Some(ast::LdStateSpace::Param) - | Some(ast::LdStateSpace::Local) => { - Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared))) + } else if instruction_space.is_compatible(ast::StateSpace::Reg) { + match instruction_type { + ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) + if operand_space == *instruction_ptr_space => + { + if operand_type != &ast::Type::Scalar(*instruction_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } } _ => Err(TranslateError::MismatchedType), - }, - ast::Type::Pointer(op_scalar_t, op_space) => { - if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type { - if op_space == instr_space { - if op_scalar_t == instr_scalar_t { - Ok(None) - } else { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } - } else { - if *op_space == ast::LdStateSpace::Generic - || *instr_space == ast::LdStateSpace::Generic - { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } else { - Err(TranslateError::MismatchedType) - } - } - } else { - Err(TranslateError::MismatchedType) - } } - _ => Err(TranslateError::MismatchedType), + } else { + Err(TranslateError::MismatchedType) } } -fn force_bitcast_ptr_to_bit( - _: &ast::Type, - instr_type: &ast::Type, - _: Option, +// Space is same, but type is different +fn default_implicit_conversion_type( + space: ast::StateSpace, + operand_type: &ast::Type, + instruction_type: &ast::Type, ) -> Result, TranslateError> { - // TODO: verify this on f32, u16 and the like - if let 
ast::Type::Scalar(scalar_t) = instr_type { - if let Ok(int_type) = (*scalar_t).try_into() { - return Ok(Some(ConversionKind::PtrToBit(int_type))); + if space.is_compatible(ast::StateSpace::Reg) { + if should_bitcast(instruction_type, operand_type) { + Ok(Some(ConversionKind::Default)) + } else { + Err(TranslateError::MismatchedType) } + } else { + Ok(Some(ConversionKind::PtrToPtr)) } - Err(TranslateError::MismatchedType) } fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { @@ -7570,16 +7369,18 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { return false; } match inst.kind() { - ScalarKind::Bit => operand.kind() != ScalarKind::Bit, - ScalarKind::Float => operand.kind() == ScalarKind::Bit, - ScalarKind::Signed => { - operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned + ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, + ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, + ast::ScalarKind::Signed => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Unsigned } - ScalarKind::Unsigned => { - operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed + ast::ScalarKind::Unsigned => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Signed } - ScalarKind::Float2 => false, - ScalarKind::Pred => false, + ast::ScalarKind::Float2 => false, + ast::ScalarKind::Pred => false, } } (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _)) @@ -7590,47 +7391,45 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { } } -fn should_bitcast_packed( - operand: &ast::Type, - instr: &ast::Type, - ss: Option, +fn implicit_conversion_mov( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = - (operand, 
instr) - { - if scalar.kind() == ScalarKind::Bit - && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) + // instruction_space is always reg + if operand_space.is_compatible(ast::StateSpace::Reg) { + if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = + (operand_type, instruction_type) { - return Ok(Some(ConversionKind::Default)); + if scalar.kind() == ast::ScalarKind::Bit + && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) + { + return Ok(Some(ConversionKind::Default)); + } } + // TODO: verify .params addressability: + // * kernel arg + // * func arg + // * variable + } else if operand_space.is_addressable() { + return Ok(Some(ConversionKind::AddressOf)); } - should_bitcast_wrapper(operand, instr, ss) -} - -fn should_bitcast_wrapper( - operand: &ast::Type, - instr: &ast::Type, - _: Option, -) -> Result, TranslateError> { - if instr == operand { - return Ok(None); - } - if should_bitcast(instr, operand) { - Ok(Some(ConversionKind::Default)) - } else { - Err(TranslateError::MismatchedType) - } + default_implicit_conversion( + (operand_space, operand_type), + (instruction_space, instruction_type), + ) } fn should_convert_relaxed_src_wrapper( - src_type: &ast::Type, - instr_type: &ast::Type, - _: Option, + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if src_type == instr_type { + if !operand_space.is_compatible(instruction_space) { + return Err(TranslateError::MismatchedType); + } + if operand_type == instruction_type { return Ok(None); } - match should_convert_relaxed_src(src_type, instr_type) { + match should_convert_relaxed_src(operand_type, instruction_type) { conv @ Some(_) => Ok(conv), None => Err(TranslateError::MismatchedType), } @@ -7646,32 +7445,33 @@ fn should_convert_relaxed_src( } match (src_type, instr_type) { (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => 
match instr_type.kind() { - ScalarKind::Bit => { + ast::ScalarKind::Bit => { if instr_type.size_of() <= src_type.size_of() { Some(ConversionKind::Default) } else { None } } - ScalarKind::Signed | ScalarKind::Unsigned => { + ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { if instr_type.size_of() <= src_type.size_of() - && src_type.kind() != ScalarKind::Float + && src_type.kind() != ast::ScalarKind::Float { Some(ConversionKind::Default) } else { None } } - ScalarKind::Float => { - if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit + ast::ScalarKind::Float => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() == ast::ScalarKind::Bit { Some(ConversionKind::Default) } else { None } } - ScalarKind::Float2 => todo!(), - ScalarKind::Pred => None, + ast::ScalarKind::Float2 => todo!(), + ast::ScalarKind::Pred => None, }, (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { @@ -7685,14 +7485,16 @@ fn should_convert_relaxed_src( } fn should_convert_relaxed_dst_wrapper( - dst_type: &ast::Type, - instr_type: &ast::Type, - _: Option, + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if dst_type == instr_type { + if !operand_space.is_compatible(instruction_space) { + return Err(TranslateError::MismatchedType); + } + if operand_type == instruction_type { return Ok(None); } - match should_convert_relaxed_dst(dst_type, instr_type) { + match should_convert_relaxed_dst(operand_type, instruction_type) { conv @ Some(_) => Ok(conv), None => Err(TranslateError::MismatchedType), } @@ -7708,15 +7510,15 @@ fn should_convert_relaxed_dst( } match (dst_type, instr_type) { (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { - ScalarKind::Bit => { + ast::ScalarKind::Bit => { if instr_type.size_of() 
<= dst_type.size_of() { Some(ConversionKind::Default) } else { None } } - ScalarKind::Signed => { - if dst_type.kind() != ScalarKind::Float { + ast::ScalarKind::Signed => { + if dst_type.kind() != ast::ScalarKind::Float { if instr_type.size_of() == dst_type.size_of() { Some(ConversionKind::Default) } else if instr_type.size_of() < dst_type.size_of() { @@ -7728,25 +7530,26 @@ fn should_convert_relaxed_dst( None } } - ScalarKind::Unsigned => { + ast::ScalarKind::Unsigned => { if instr_type.size_of() <= dst_type.size_of() - && dst_type.kind() != ScalarKind::Float + && dst_type.kind() != ast::ScalarKind::Float { Some(ConversionKind::Default) } else { None } } - ScalarKind::Float => { - if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit + ast::ScalarKind::Float => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() == ast::ScalarKind::Bit { Some(ConversionKind::Default) } else { None } } - ScalarKind::Float2 => todo!(), - ScalarKind::Pred => None, + ast::ScalarKind::Float2 => todo!(), + ast::ScalarKind::Pred => None, }, (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { @@ -7759,77 +7562,46 @@ fn should_convert_relaxed_dst( } } -impl<'a> ast::MethodDecl<'a, &'a str> { +impl<'a> ast::MethodDeclaration<'a, &'a str> { fn name(&self) -> &'a str { - match self { - ast::MethodDecl::Kernel { name, .. 
} => name, - ast::MethodDecl::Func(_, name, _) => name, + match self.name { + ast::MethodName::Kernel(name) => name, + ast::MethodName::Func(name) => name, } } } -struct SpirvMethodDecl<'input> { - input: Vec>, - output: Vec>, - name: MethodName<'input>, - uses_shared_mem: bool, +impl<'a> ast::MethodDeclaration<'a, spirv::Word> { + fn effective_input_arguments(&self) -> impl Iterator + '_ { + let is_kernel = self.name.is_kernel(); + self.input_arguments + .iter() + .map(move |arg| { + if !is_kernel && arg.state_space != ast::StateSpace::Reg { + let spirv_type = + SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv()); + (arg.name, spirv_type) + } else { + (arg.name, SpirvType::new(arg.v_type.clone())) + } + }) + .chain(self.shared_mem.iter().map(|id| { + ( + *id, + SpirvType::Pointer( + Box::new(SpirvType::Base(SpirvScalarKey::B8)), + spirv::StorageClass::Workgroup, + ), + ) + })) + } } -impl<'input> SpirvMethodDecl<'input> { - fn new(ast_decl: &ast::MethodDecl<'input, spirv::Word>) -> Self { - let (input, output) = match ast_decl { - ast::MethodDecl::Kernel { in_args, .. 
} => { - let spirv_input = in_args - .iter() - .map(|var| { - let v_type = match &var.v_type { - ast::KernelArgumentType::Normal(t) => { - ast::FnArgumentType::Param(t.clone()) - } - ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared, - }; - ast::Variable { - name: var.name, - align: var.align, - v_type: v_type.to_kernel_type(), - array_init: var.array_init.clone(), - } - }) - .collect(); - (spirv_input, Vec::new()) - } - ast::MethodDecl::Func(out_args, _, in_args) => { - let (param_output, non_param_output): (Vec<_>, Vec<_>) = - out_args.iter().partition(|var| var.v_type.is_param()); - let spirv_output = non_param_output - .into_iter() - .cloned() - .map(|var| ast::Variable { - name: var.name, - align: var.align, - v_type: var.v_type.to_func_type(), - array_init: var.array_init.clone(), - }) - .collect(); - let spirv_input = param_output - .into_iter() - .cloned() - .chain(in_args.iter().cloned()) - .map(|var| ast::Variable { - name: var.name, - align: var.align, - v_type: var.v_type.to_func_type(), - array_init: var.array_init.clone(), - }) - .collect(); - (spirv_input, spirv_output) - } - }; - SpirvMethodDecl { - input, - output, - name: MethodName::new(ast_decl), - uses_shared_mem: false, +impl<'input, ID> ast::MethodName<'input, ID> { + fn is_kernel(&self) -> bool { + match self { + ast::MethodName::Kernel(..) => true, + ast::MethodName::Func(..) => false, } } } diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs index 5b00844..b5a1e3a 100644 --- a/zluda_dump/src/lib.rs +++ b/zluda_dump/src/lib.rs @@ -219,11 +219,19 @@ unsafe fn to_str(image: *const T) -> Option<&'static str> { fn directive_to_kernel(dir: &ast::Directive) -> Option<(String, Vec)> { match dir { - ast::Directive::Method(ast::Function { - func_directive: ast::MethodDecl::Kernel { name, in_args }, - .. 
- }) => { - let arg_sizes = in_args + ast::Directive::Method( + _, + ast::Function { + func_directive: + ast::MethodDeclaration { + name: ast::MethodName::Kernel(name), + input_arguments, + .. + }, + .. + }, + ) => { + let arg_sizes = input_arguments .iter() .map(|arg| ast::Type::from(arg.v_type.clone()).size_of()) .collect();