diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 2ac1f68..d1127af 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -2,13 +2,12 @@ name = "ptx" version = "0.0.0" authors = ["Andrzej Janik "] -edition = "2018" +edition = "2021" [lib] [dependencies] -lalrpop-util = "0.19" -regex = "1" +ptx_parser = { path = "../ptx_parser" } rspirv = "0.7" spirv_headers = "1.5" quick-error = "1.2" @@ -17,10 +16,6 @@ bit-vec = "0.6" half ="1.6" bitflags = "1.2" -[build-dependencies.lalrpop] -version = "0.19" -features = ["lexer"] - [dev-dependencies] hip_runtime-sys = { path = "../hip_runtime-sys" } tempfile = "3" diff --git a/ptx/build.rs b/ptx/build.rs deleted file mode 100644 index 42c5d59..0000000 --- a/ptx/build.rs +++ /dev/null @@ -1,5 +0,0 @@ -extern crate lalrpop; - -fn main() { - lalrpop::process_root().unwrap(); -} \ No newline at end of file diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs deleted file mode 100644 index f1323be..0000000 --- a/ptx/src/ast.rs +++ /dev/null @@ -1,1072 +0,0 @@ -use half::f16; -use lalrpop_util::{lexer::Token, ParseError}; -use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; -use std::{marker::PhantomData, num::ParseIntError}; - -#[derive(Debug, thiserror::Error)] -pub enum PtxError { - #[error("{source}")] - ParseInt { - #[from] - source: ParseIntError, - }, - #[error("{source}")] - ParseFloat { - #[from] - source: ParseFloatError, - }, - #[error("")] - Unsupported32Bit, - #[error("")] - SyntaxError, - #[error("")] - NonF32Ftz, - #[error("")] - WrongArrayType, - #[error("")] - WrongVectorElement, - #[error("")] - MultiArrayVariable, - #[error("")] - ZeroDimensionArray, - #[error("")] - ArrayInitalizer, - #[error("")] - NonExternPointer, - #[error("{start}:{end}")] - UnrecognizedStatement { - start: usize, - end: usize, - }, - #[error("{start}:{end}")] - UnrecognizedDirective { - start: usize, - end: usize, - }, -} - -// For some weird reson this is illegal: -// .param .f16x2 foobar; -// but this is legal: -// .param .f16x2 foobar[1]; -// even more interestingly this is legal, but only in .func (not in .entry): -// .param .b32 foobar[] - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum BarDetails { - SyncAligned, -} - -pub trait UnwrapWithVec { - fn unwrap_with(self, errs: &mut Vec) -> To; -} - -impl, EInto> UnwrapWithVec - for Result -{ - fn unwrap_with(self, errs: &mut Vec) -> R { - self.unwrap_or_else(|e| { - errs.push(e.into()); - R::default() - }) - } -} - -impl< - R1: Default, - EFrom1: std::convert::Into, - R2: Default, - EFrom2: std::convert::Into, - EInto, - > UnwrapWithVec for (Result, Result) -{ - fn unwrap_with(self, errs: &mut Vec) -> (R1, R2) { - let (x, y) = self; - let r1 = x.unwrap_with(errs); - let r2 = y.unwrap_with(errs); - (r1, r2) - } -} - -pub struct Module<'a> { - pub version: (u8, u8), - pub directives: Vec>>, -} - -pub enum Directive<'a, P: ArgParams> { - Variable(LinkingDirective, Variable), - Method(LinkingDirective, Function<'a, &'a str, Statement

>), -} - -#[derive(Hash, PartialEq, Eq, Copy, Clone)] -pub enum MethodName<'input, ID> { - Kernel(&'input str), - Func(ID), -} - -pub struct MethodDeclaration<'input, ID> { - pub return_arguments: Vec>, - pub name: MethodName<'input, ID>, - pub input_arguments: Vec>, - pub shared_mem: Option, -} - -pub struct Function<'a, ID, S> { - pub func_directive: MethodDeclaration<'a, ID>, - pub tuning: Vec, - pub body: Option>, -} - -pub type ParsedFunction<'a> = Function<'a, &'a str, Statement>>; - -#[derive(PartialEq, Eq, Clone)] -pub enum Type { - // .param.b32 foo; - // -> OpTypeInt - Scalar(ScalarType), - // .param.v2.b32 foo; - // -> OpTypeVector - Vector(ScalarType, u8), - // .param.b32 foo[4]; - // -> OpTypeArray - Array(ScalarType, Vec), - /* - Variables of this type almost never exist in the original .ptx and are - usually artificially created. Some examples below: - - extern pointers to the .shared memory in the form: - .extern .shared .b32 shared_mem[]; - which we first parse as - .extern .shared .b32 shared_mem; - and then convert to an additional function parameter: - .param .ptr<.b32.shared> shared_mem; - and do a load at the start of the function (and renames inside fn): - .reg .ptr<.b32.shared> temp; - ld.param.ptr<.b32.shared> temp, [shared_mem]; - note, we don't support non-.shared extern pointers, because there's - zero use for them in the ptxas - - artifical pointers created by stateful conversion, which work - similiarly to the above - - function parameters: - foobar(.param .align 4 .b8 numbers[]) - which get parsed to - foobar(.param .align 4 .b8 numbers) - and then converted to - foobar(.reg .align 4 .ptr<.b8.param> numbers) - - ld/st with offset: - .reg.b32 x; - .param.b64 arg0; - st.param.b32 [arg0+4], x; - Yes, this code is legal and actually emitted by the NV compiler! - We convert the st to: - .reg ptr<.b64.param> temp = ptr_offset(arg0, 4); - st.param.b32 [temp], x; - */ - // .reg ptr<.b64.param> - // -> OpTypePointer Function - Pointer(ScalarType, StateSpace), -} - -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub enum ScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F32, - F64, - F16x2, - Pred, -} - -impl ScalarType { - pub fn size_of(self) -> u8 { - match self { - ScalarType::U8 => 1, - ScalarType::S8 => 1, - ScalarType::B8 => 1, - ScalarType::U16 => 2, - ScalarType::S16 => 2, - ScalarType::B16 => 2, - ScalarType::F16 => 2, - ScalarType::U32 => 4, - ScalarType::S32 => 4, - ScalarType::B32 => 4, - ScalarType::F32 => 4, - ScalarType::U64 => 8, - ScalarType::S64 => 8, - ScalarType::B64 => 8, - ScalarType::F64 => 8, - ScalarType::F16x2 => 4, - ScalarType::Pred => 1, - } - } -} - -impl Default for ScalarType { - fn default() -> Self { - ScalarType::B8 - } -} - -pub enum Statement { - Label(P::Id), - Variable(MultiVariable), - Instruction(Option>, Instruction

), - Block(Vec>), -} - -pub struct MultiVariable { - pub var: Variable, - pub count: Option, -} - -#[derive(Clone)] -pub struct Variable { - pub align: Option, - pub v_type: Type, - pub state_space: StateSpace, - pub name: ID, - pub array_init: Vec, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum StateSpace { - Reg, - Const, - Global, - Local, - Shared, - Param, - Generic, - Sreg, -} - -pub struct PredAt { - pub not: bool, - pub label: ID, -} - -pub enum Instruction { - Ld(LdDetails, Arg2Ld

), - Mov(MovDetails, Arg2Mov

), - Mul(MulDetails, Arg3

), - Add(ArithDetails, Arg3

), - Setp(SetpData, Arg4Setp

), - SetpBool(SetpBoolData, Arg5Setp

), - Not(ScalarType, Arg2

), - Bra(BraData, Arg1

), - Cvt(CvtDetails, Arg2

), - Cvta(CvtaDetails, Arg2

), - Shl(ScalarType, Arg3

), - Shr(ScalarType, Arg3

), - St(StData, Arg2St

), - Ret(RetData), - Call(CallInst

), - Abs(AbsDetails, Arg2

), - Mad(MulDetails, Arg4

), - Fma(ArithFloat, Arg4

), - Or(ScalarType, Arg3

), - Sub(ArithDetails, Arg3

), - Min(MinMaxDetails, Arg3

), - Max(MinMaxDetails, Arg3

), - Rcp(RcpDetails, Arg2

), - And(ScalarType, Arg3

), - Selp(ScalarType, Arg4

), - Bar(BarDetails, Arg1Bar

), - Atom(AtomDetails, Arg3

), - AtomCas(AtomCasDetails, Arg4

), - Div(DivDetails, Arg3

), - Sqrt(SqrtDetails, Arg2

), - Rsqrt(RsqrtDetails, Arg2

), - Neg(NegDetails, Arg2

), - Sin { flush_to_zero: bool, arg: Arg2

}, - Cos { flush_to_zero: bool, arg: Arg2

}, - Lg2 { flush_to_zero: bool, arg: Arg2

}, - Ex2 { flush_to_zero: bool, arg: Arg2

}, - Clz { typ: ScalarType, arg: Arg2

}, - Brev { typ: ScalarType, arg: Arg2

}, - Popc { typ: ScalarType, arg: Arg2

}, - Xor { typ: ScalarType, arg: Arg3

}, - Bfe { typ: ScalarType, arg: Arg4

}, - Bfi { typ: ScalarType, arg: Arg5

}, - Rem { typ: ScalarType, arg: Arg3

}, - Prmt { control: u16, arg: Arg3

}, - Activemask { arg: Arg1

}, - Membar { level: MemScope }, -} - -#[derive(Copy, Clone)] -pub struct MadFloatDesc {} - -#[derive(Copy, Clone)] -pub struct AbsDetails { - pub flush_to_zero: Option, - pub typ: ScalarType, -} -#[derive(Copy, Clone)] -pub struct RcpDetails { - pub rounding: Option, - pub flush_to_zero: Option, - pub is_f64: bool, -} - -pub struct CallInst { - pub uniform: bool, - pub ret_params: Vec, - pub func: P::Id, - pub param_list: Vec, -} - -pub trait ArgParams { - type Id; - type Operand; -} - -pub struct ParsedArgParams<'a> { - _marker: PhantomData<&'a ()>, -} - -impl<'a> ArgParams for ParsedArgParams<'a> { - type Id = &'a str; - type Operand = Operand<&'a str>; -} - -pub struct Arg1 { - pub src: P::Id, // it is a jump destination, but in terms of operands it is a source operand -} - -pub struct Arg1Bar { - pub src: P::Operand, -} - -pub struct Arg2 { - pub dst: P::Operand, - pub src: P::Operand, -} -pub struct Arg2Ld { - pub dst: P::Operand, - pub src: P::Operand, -} - -pub struct Arg2St { - pub src1: P::Operand, - pub src2: P::Operand, -} - -pub struct Arg2Mov { - pub dst: P::Operand, - pub src: P::Operand, -} - -pub struct Arg3 { - pub dst: P::Operand, - pub src1: P::Operand, - pub src2: P::Operand, -} - -pub struct Arg4 { - pub dst: P::Operand, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, -} - -pub struct Arg4Setp { - pub dst1: P::Id, - pub dst2: Option, - pub src1: P::Operand, - pub src2: P::Operand, -} - -pub struct Arg5 { - pub dst: P::Operand, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, - pub src4: P::Operand, -} - -pub struct Arg5Setp { - pub dst1: P::Id, - pub dst2: Option, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, -} - -#[derive(Copy, Clone)] -pub enum ImmediateValue { - U64(u64), - S64(i64), - F32(f32), - F64(f64), -} - -#[derive(Clone)] -pub enum Operand { - Reg(Id), - RegOffset(Id, i32), - Imm(ImmediateValue), - VecMember(Id, u8), - VecPack(Vec), -} - -pub enum VectorPrefix { - V2, - V4, -} - -pub struct LdDetails { - pub qualifier: LdStQualifier, - pub state_space: StateSpace, - pub caching: LdCacheOperator, - pub typ: Type, - pub non_coherent: bool, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdStQualifier { - Weak, - Volatile, - Relaxed(MemScope), - Acquire(MemScope), -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum MemScope { - Cta, - Gpu, - Sys, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdCacheOperator { - Cached, - L2Only, - Streaming, - LastUse, - Uncached, -} - -#[derive(Clone)] -pub struct MovDetails { - pub typ: Type, - pub src_is_address: bool, - // two fields below are in use by member moves - pub dst_width: u8, - pub src_width: u8, - // This is in use by auto-generated movs - pub relaxed_src2_conv: bool, -} - -impl MovDetails { - pub fn new(typ: Type) -> Self { - MovDetails { - typ, - src_is_address: false, - dst_width: 0, - src_width: 0, - relaxed_src2_conv: false, - } - } -} - -#[derive(Copy, Clone)] -pub struct MulIntDesc { - pub typ: ScalarType, - pub control: MulIntControl, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum MulIntControl { - Low, - High, - Wide, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum RoundingMode { - NearestEven, - Zero, - NegativeInf, - PositiveInf, -} - -pub struct AddIntDesc { - pub typ: ScalarType, - pub saturate: bool, -} - -pub struct SetpData { - pub typ: ScalarType, - pub flush_to_zero: Option, - pub cmp_op: SetpCompareOp, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum SetpCompareOp { - Eq, - NotEq, - Less, - LessOrEq, - Greater, - GreaterOrEq, - NanEq, - NanNotEq, - NanLess, - NanLessOrEq, - NanGreater, - NanGreaterOrEq, - IsNotNan, - IsAnyNan, -} - -pub enum SetpBoolPostOp { - And, - Or, - Xor, -} - -pub struct SetpBoolData { - pub typ: ScalarType, - pub flush_to_zero: Option, - pub cmp_op: SetpCompareOp, - pub bool_op: SetpBoolPostOp, -} - -pub struct BraData { - pub uniform: bool, -} - -pub enum CvtDetails { - IntFromInt(CvtIntToIntDesc), - FloatFromFloat(CvtDesc), - IntFromFloat(CvtDesc), - FloatFromInt(CvtDesc), -} - -pub struct CvtIntToIntDesc { - pub dst: ScalarType, - pub src: ScalarType, - pub saturate: bool, -} - -pub struct CvtDesc { - pub rounding: Option, - pub flush_to_zero: Option, - pub saturate: bool, - pub dst: ScalarType, - pub src: ScalarType, -} - -impl CvtDetails { - pub fn new_int_from_int_checked<'err, 'input>( - saturate: bool, - dst: ScalarType, - src: ScalarType, - err: &'err mut Vec, PtxError>>, - ) -> Self { - if saturate { - if src.kind() == ScalarKind::Signed { - if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() { - err.push(ParseError::from(PtxError::SyntaxError)); - } - } else { - if dst == src || dst.size_of() >= src.size_of() { - err.push(ParseError::from(PtxError::SyntaxError)); - } - } - } - CvtDetails::IntFromInt(CvtIntToIntDesc { dst, src, saturate }) - } - - pub fn new_float_from_int_checked<'err, 'input>( - rounding: RoundingMode, - flush_to_zero: bool, - saturate: bool, - dst: ScalarType, - src: ScalarType, - err: &'err mut Vec, PtxError>>, - ) -> Self { - if flush_to_zero && dst != ScalarType::F32 { - err.push(ParseError::from(PtxError::NonF32Ftz)); - } - CvtDetails::FloatFromInt(CvtDesc { - dst, - src, - saturate, - flush_to_zero: Some(flush_to_zero), - rounding: Some(rounding), - }) - } - - pub fn new_int_from_float_checked<'err, 'input>( - rounding: RoundingMode, - flush_to_zero: bool, - saturate: bool, - dst: ScalarType, - src: ScalarType, - err: &'err mut Vec, PtxError>>, - ) -> Self { - if flush_to_zero && src != ScalarType::F32 { - err.push(ParseError::from(PtxError::NonF32Ftz)); - } - CvtDetails::IntFromFloat(CvtDesc { - dst, - src, - saturate, - flush_to_zero: Some(flush_to_zero), - rounding: Some(rounding), - }) - } -} - -pub struct CvtaDetails { - pub to: StateSpace, - pub from: StateSpace, - pub size: CvtaSize, -} - -pub enum CvtaSize { - U32, - U64, -} - -pub struct StData { - pub qualifier: LdStQualifier, - pub state_space: StateSpace, - pub caching: StCacheOperator, - pub typ: Type, -} - -#[derive(PartialEq, Eq)] -pub enum StCacheOperator { - Writeback, - L2Only, - Streaming, - Writethrough, -} - -pub struct RetData { - pub uniform: bool, -} - -#[derive(Copy, Clone)] -pub enum MulDetails { - Unsigned(MulUInt), - Signed(MulSInt), - Float(ArithFloat), -} - -#[derive(Copy, Clone)] -pub struct MulUInt { - pub typ: ScalarType, - pub control: MulIntControl, -} - -#[derive(Copy, Clone)] -pub struct MulSInt { - pub typ: ScalarType, - pub control: MulIntControl, -} - -#[derive(Copy, Clone)] -pub enum ArithDetails { - Unsigned(ScalarType), - Signed(ArithSInt), - Float(ArithFloat), -} - -#[derive(Copy, Clone)] -pub struct ArithSInt { - pub typ: ScalarType, - pub saturate: bool, -} - -#[derive(Copy, Clone)] -pub struct ArithFloat { - pub typ: ScalarType, - pub rounding: Option, - pub flush_to_zero: Option, - pub saturate: bool, -} - -#[derive(Copy, Clone)] -pub enum MinMaxDetails { - Signed(ScalarType), - Unsigned(ScalarType), - Float(MinMaxFloat), -} - -#[derive(Copy, Clone)] -pub struct MinMaxFloat { - pub flush_to_zero: Option, - pub nan: bool, - pub typ: ScalarType, -} - -#[derive(Copy, Clone)] -pub struct AtomDetails { - pub semantics: AtomSemantics, - pub scope: MemScope, - pub space: StateSpace, - pub inner: AtomInnerDetails, -} - -#[derive(Copy, Clone)] -pub enum AtomSemantics { - Relaxed, - Acquire, - Release, - AcquireRelease, -} - -#[derive(Copy, Clone)] -pub enum AtomInnerDetails { - Bit { op: AtomBitOp, typ: ScalarType }, - Unsigned { op: AtomUIntOp, typ: ScalarType }, - Signed { op: AtomSIntOp, typ: ScalarType }, - Float { op: AtomFloatOp, typ: ScalarType }, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomBitOp { - And, - Or, - Xor, - Exchange, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomUIntOp { - Add, - Inc, - Dec, - Min, - Max, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomSIntOp { - Add, - Min, - Max, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomFloatOp { - Add, -} - -#[derive(Copy, Clone)] -pub struct AtomCasDetails { - pub semantics: AtomSemantics, - pub scope: MemScope, - pub space: StateSpace, - pub typ: ScalarType, -} - -#[derive(Copy, Clone)] -pub enum DivDetails { - Unsigned(ScalarType), - Signed(ScalarType), - Float(DivFloatDetails), -} - -#[derive(Copy, Clone)] -pub struct DivFloatDetails { - pub typ: ScalarType, - pub flush_to_zero: Option, - pub kind: DivFloatKind, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum DivFloatKind { - Approx, - Full, - Rounding(RoundingMode), -} - -pub enum NumsOrArrays<'a> { - Nums(Vec<(&'a str, u32)>), - Arrays(Vec>), -} - -#[derive(Copy, Clone)] -pub struct SqrtDetails { - pub typ: ScalarType, - pub flush_to_zero: Option, - pub kind: SqrtKind, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum SqrtKind { - Approx, - Rounding(RoundingMode), -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub struct RsqrtDetails { - pub typ: ScalarType, - pub flush_to_zero: bool, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub struct NegDetails { - pub typ: ScalarType, - pub flush_to_zero: Option, -} - -impl<'a> NumsOrArrays<'a> { - pub fn to_vec(self, typ: ScalarType, dimensions: &mut [u32]) -> Result, PtxError> { - self.normalize_dimensions(dimensions)?; - let sizeof_t = ScalarType::from(typ).size_of() as usize; - let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize)); - let mut result = vec![0; result_size]; - self.parse_and_copy(typ, sizeof_t, dimensions, &mut result)?; - Ok(result) - } - - fn normalize_dimensions(&self, dimensions: &mut [u32]) -> Result<(), PtxError> { - match dimensions.first_mut() { - Some(first) => { - if *first == 0 { - *first = match self { - NumsOrArrays::Nums(v) => v.len() as u32, - NumsOrArrays::Arrays(v) => v.len() as u32, - }; - } - } - None => return Err(PtxError::ZeroDimensionArray), - } - for dim in dimensions { - if *dim == 0 { - return Err(PtxError::ZeroDimensionArray); - } - } - Ok(()) - } - - fn parse_and_copy( - &self, - t: ScalarType, - size_of_t: usize, - dimensions: &[u32], - result: &mut [u8], - ) -> Result<(), PtxError> { - match dimensions { - [] => unreachable!(), - [dim] => match self { - NumsOrArrays::Nums(vec) => { - if vec.len() > *dim as usize { - return Err(PtxError::ZeroDimensionArray); - } - for (idx, (val, radix)) in vec.iter().enumerate() { - Self::parse_and_copy_single(t, idx, val, *radix, result)?; - } - } - NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray), - }, - [first_dim, rest @ ..] => match self { - NumsOrArrays::Arrays(vec) => { - if vec.len() > *first_dim as usize { - return Err(PtxError::ZeroDimensionArray); - } - let size_of_element = rest.iter().fold(size_of_t, |x, y| x * (*y as usize)); - for (idx, this) in vec.iter().enumerate() { - this.parse_and_copy( - t, - size_of_t, - rest, - &mut result[(size_of_element * idx)..], - )?; - } - } - NumsOrArrays::Nums(_) => return Err(PtxError::ZeroDimensionArray), - }, - } - Ok(()) - } - - fn parse_and_copy_single( - t: ScalarType, - idx: usize, - str_val: &str, - radix: u32, - output: &mut [u8], - ) -> Result<(), PtxError> { - match t { - ScalarType::B8 | ScalarType::U8 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::B16 | ScalarType::U16 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::B32 | ScalarType::U32 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::B64 | ScalarType::U64 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::S8 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::S16 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::S32 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::S64 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::F16 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::F16x2 => todo!(), - ScalarType::F32 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::F64 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::Pred => todo!(), - } - Ok(()) - } - - fn parse_and_copy_single_t( - idx: usize, - str_val: &str, - _radix: u32, // TODO: use this to properly support hex literals - output: &mut [u8], - ) -> Result<(), PtxError> - where - T::Err: Into, - { - let typed_output = unsafe { - std::slice::from_raw_parts_mut::( - output.as_mut_ptr() as *mut _, - output.len() / mem::size_of::(), - ) - }; - typed_output[idx] = str_val.parse::().map_err(|e| e.into())?; - Ok(()) - } -} - -pub enum ArrayOrPointer { - Array { dimensions: Vec, init: Vec }, - Pointer, -} - -bitflags! { - pub struct LinkingDirective: u8 { - const NONE = 0b000; - const EXTERN = 0b001; - const VISIBLE = 0b10; - const WEAK = 0b100; - } -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum TuningDirective { - MaxNReg(u32), - MaxNtid(u32, u32, u32), - ReqNtid(u32, u32, u32), - MinNCtaPerSm(u32), -} - -#[derive(Clone, Copy, PartialEq, Eq)] -pub enum ScalarKind { - Bit, - Unsigned, - Signed, - Float, - Float2, - Pred, -} - -impl ScalarType { - pub fn kind(self) -> ScalarKind { - match self { - ScalarType::U8 => ScalarKind::Unsigned, - ScalarType::U16 => ScalarKind::Unsigned, - ScalarType::U32 => ScalarKind::Unsigned, - ScalarType::U64 => ScalarKind::Unsigned, - ScalarType::S8 => ScalarKind::Signed, - ScalarType::S16 => ScalarKind::Signed, - ScalarType::S32 => ScalarKind::Signed, - ScalarType::S64 => ScalarKind::Signed, - ScalarType::B8 => ScalarKind::Bit, - ScalarType::B16 => ScalarKind::Bit, - ScalarType::B32 => ScalarKind::Bit, - ScalarType::B64 => ScalarKind::Bit, - ScalarType::F16 => ScalarKind::Float, - ScalarType::F32 => ScalarKind::Float, - ScalarType::F64 => ScalarKind::Float, - ScalarType::F16x2 => ScalarKind::Float2, - ScalarType::Pred => ScalarKind::Pred, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn array_fails_multiple_0_dmiensions() { - let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(ScalarType::B8, &mut vec![0, 0]).is_err()); - } - - #[test] - fn array_fails_on_empty() { - let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(ScalarType::B8, &mut vec![0]).is_err()); - } - - #[test] - fn array_auto_sizes_0_dimension() { - let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]), - NumsOrArrays::Nums(vec![("3", 10), ("4", 10)]), - ]); - let mut dimensions = vec![0u32, 2]; - assert_eq!( - vec![1u8, 2, 3, 4], - inp.to_vec(ScalarType::B8, &mut dimensions).unwrap() - ); - assert_eq!(dimensions, vec![2u32, 2]); - } - - #[test] - fn array_fails_wrong_structure() { - let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]), - NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]), - ]); - let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); - } - - #[test] - fn array_fails_too_long_component() { - let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec![("1", 10), ("2", 10), ("3", 10)]), - NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]), - ]); - let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); - } -} diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 1cb9630..3798fdd 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -1,10 +1,7 @@ #[cfg(test)] extern crate paste; #[macro_use] -extern crate lalrpop_util; -#[macro_use] extern crate quick_error; - extern crate bit_vec; extern crate half; #[cfg(test)] @@ -18,168 +15,12 @@ extern crate spirv_tools_sys as spirv_tools; #[macro_use] extern crate bitflags; -lalrpop_mod!( - #[allow(warnings)] - ptx -); - -pub mod ast; #[cfg(test)] mod test; mod translate; -use std::fmt; - -pub use crate::ptx::ModuleParser; -use ast::PtxError; -pub use lalrpop_util::lexer::Token; -pub use lalrpop_util::ParseError; pub use rspirv::dr::Error as SpirvError; pub use translate::to_spirv_module; pub use translate::KernelInfo; pub use translate::TranslateError; - -pub trait ModuleParserExt { - fn parse_checked<'input>( - txt: &'input str, - ) -> Result, Vec, ast::PtxError>>>; - - // Returned AST might be malformed. Some users, like logger, want to look at - // malformed AST to record information - list of kernels or such - fn parse_unchecked<'input>( - txt: &'input str, - ) -> ( - ast::Module<'input>, - Vec, ast::PtxError>>, - ); -} - -impl ModuleParserExt for ModuleParser { - fn parse_checked<'input>( - txt: &'input str, - ) -> Result, Vec, ast::PtxError>>> { - let mut errors = Vec::new(); - let maybe_ast = ptx::ModuleParser::new().parse(&mut errors, txt); - match (&*errors, maybe_ast) { - (&[], Ok(ast)) => Ok(ast), - (_, Err(unrecoverable)) => { - errors.push(unrecoverable); - Err(errors) - } - (_, Ok(_)) => Err(errors), - } - } - - fn parse_unchecked<'input>( - txt: &'input str, - ) -> ( - ast::Module<'input>, - Vec, ast::PtxError>>, - ) { - let mut errors = Vec::new(); - let maybe_ast = ptx::ModuleParser::new().parse(&mut errors, txt); - let ast = match maybe_ast { - Ok(ast) => ast, - Err(unrecoverable_err) => { - errors.push(unrecoverable_err); - ast::Module { - version: (0, 0), - directives: Vec::new(), - } - } - }; - (ast, errors) - } -} - -pub struct DisplayParseError<'a, Loc, Tok, Err>(&'a str, &'a ParseError); - -impl<'a, Loc: fmt::Display + Into + Copy, Tok, Err> DisplayParseError<'a, Loc, Tok, Err> { - // unsafe because there's no guarantee that the input str is the one that this error was created from - pub unsafe fn new(error: &'a ParseError, text: &'a str) -> Self { - Self(text, error) - } -} - -impl<'a, Loc, Tok> fmt::Display for DisplayParseError<'a, Loc, Tok, PtxError> -where - Loc: fmt::Display, - Tok: fmt::Display, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.1 { - ParseError::User { - error: PtxError::UnrecognizedStatement { start, end }, - } => self.fmt_unrecognized(f, *start, *end, "statement"), - ParseError::User { - error: PtxError::UnrecognizedDirective { start, end }, - } => self.fmt_unrecognized(f, *start, *end, "directive"), - _ => self.1.fmt(f), - } - } -} - -impl<'a, Loc, Tok, Err> DisplayParseError<'a, Loc, Tok, Err> { - fn fmt_unrecognized( - &self, - f: &mut fmt::Formatter, - start: usize, - end: usize, - kind: &'static str, - ) -> fmt::Result { - let full_substring = unsafe { self.0.get_unchecked(start..end) }; - write!( - f, - "Unrecognized {} `{}` found at {}:{}", - kind, full_substring, start, end - ) - } -} - -pub(crate) fn without_none(x: Vec>) -> Vec { - x.into_iter().filter_map(|x| x).collect() -} - -pub(crate) fn vector_index<'input>( - inp: &'input str, -) -> Result, ast::PtxError>> { - match inp { - "x" | "r" => Ok(0), - "y" | "g" => Ok(1), - "z" | "b" => Ok(2), - "w" | "a" => Ok(3), - _ => Err(ParseError::User { - error: ast::PtxError::WrongVectorElement, - }), - } -} - -#[cfg(test)] -mod tests { - use crate::{DisplayParseError, ModuleParser, ModuleParserExt}; - - #[test] - fn error_report_unknown_instructions() { - let module = r#" - .version 6.5 - .target sm_30 - .address_size 64 - - .visible .entry add( - .param .u64 input, - ) - { - .reg .u64 x; - does_not_exist.u64 x, x; - ret; - }"#; - let errors = match ModuleParser::parse_checked(module) { - Err(e) => e, - Ok(_) => panic!(), - }; - assert_eq!(errors.len(), 1); - let reporter = DisplayParseError(module, &errors[0]); - let build_log_string = format!("{}", reporter); - assert!(build_log_string.contains("does_not_exist")); - } -} +use ptx_parser as ast; diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index 0785f3e..32b968a 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -1,4 +1,3 @@ -use super::ptx; use super::TranslateError; mod spirv_run; diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index db1063b..0797919 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,5 +1,5 @@ -use crate::ast; use half::f16; +use ptx_parser as ast; use rspirv::dr; use std::cell::RefCell; use std::collections::{hash_map, BTreeMap, HashMap, HashSet}; @@ -57,10 +57,6 @@ impl SpirvType { ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), - ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( - Box::new(SpirvType::Base(pointer_t.into())), - space.to_spirv(), - ), } } @@ -77,9 +73,9 @@ impl From for SpirvType { } struct TypeWordMap { - void: spirv::Word, - complex: HashMap, - constants: HashMap<(SpirvType, u64), spirv::Word>, + void: SpirvWord, + complex: HashMap, + constants: HashMap<(SpirvType, u64), SpirvWord>, } // SPIR-V integer type definitions are signless, more below: @@ -116,6 +112,12 @@ impl From for SpirvScalarKey { ast::ScalarType::F64 => SpirvScalarKey::F64, ast::ScalarType::F16x2 => SpirvScalarKey::F16x2, ast::ScalarType::Pred => SpirvScalarKey::Pred, + ptx_parser::ScalarType::S16x2 => todo!(), + ptx_parser::ScalarType::B128 => todo!(), + ptx_parser::ScalarType::U16x2 => todo!(), + ptx_parser::ScalarType::F16x2 => todo!(), + ptx_parser::ScalarType::BF16 => todo!(), + ptx_parser::ScalarType::BF16x2 => todo!(), } } } @@ -124,26 +126,24 @@ impl TypeWordMap { fn new(b: &mut dr::Builder) -> TypeWordMap { let void = b.type_void(None); TypeWordMap { - void: void, - complex: HashMap::::new(), + void: SpirvWord(void), + complex: HashMap::::new(), constants: HashMap::new(), } } - fn void(&self) -> spirv::Word { + fn void(&self) -> SpirvWord { self.void } - fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word { + fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> SpirvWord { let key: SpirvScalarKey = t.into(); self.get_or_add_spirv_scalar(b, key) } - fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> spirv::Word { - *self - .complex - .entry(SpirvType::Base(key)) - .or_insert_with(|| match key { + fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> SpirvWord { + *self.complex.entry(SpirvType::Base(key)).or_insert_with(|| { + SpirvWord(match key { SpirvScalarKey::B8 => b.type_int(None, 8, 0), SpirvScalarKey::B16 => b.type_int(None, 16, 0), SpirvScalarKey::B32 => b.type_int(None, 32, 0), @@ -154,9 +154,10 @@ impl TypeWordMap { SpirvScalarKey::Pred => b.type_bool(None), SpirvScalarKey::F16x2 => todo!(), }) + }) } - fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word { + fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> SpirvWord { match t { SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key), SpirvType::Pointer(ref typ, storage) => { @@ -229,7 +230,7 @@ impl TypeWordMap { b: &mut dr::Builder, in_params: impl Iterator, mut out_params: impl ExactSizeIterator, - ) -> (spirv::Word, spirv::Word) { + ) -> (SpirvWord, SpirvWord) { let (out_args, out_spirv_type) = if out_params.len() == 0 { (None, self.void()) } else if out_params.len() == 1 { @@ -253,7 +254,7 @@ impl TypeWordMap { b: &mut dr::Builder, typ: &ast::Type, init: &[u8], - ) -> Result { + ) -> Result { Ok(match typ { ast::Type::Scalar(t) => match t { ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self @@ -323,6 +324,12 @@ impl TypeWordMap { } }, ), + ptx_parser::ScalarType::S16x2 => todo!(), + ptx_parser::ScalarType::B128 => todo!(), + ptx_parser::ScalarType::U16x2 => todo!(), + ptx_parser::ScalarType::F16x2 => todo!(), + ptx_parser::ScalarType::BF16 => todo!(), + ptx_parser::ScalarType::BF16x2 => todo!(), }, ast::Type::Vector(typ, len) => { let result_type = @@ -383,7 +390,7 @@ impl TypeWordMap { fn get_or_add_constant_single< T: Copy, CastAsU64: FnOnce(T) -> u64, - InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word, + InsertConstant: FnOnce(&mut dr::Builder, SpirvWord, T) -> SpirvWord, >( &mut self, b: &mut dr::Builder, @@ -391,7 +398,7 @@ impl TypeWordMap { init: &[u8], cast: CastAsU64, f: InsertConstant, - ) -> spirv::Word { + ) -> SpirvWord { let value = unsafe { *(init.as_ptr() as *const T) }; let value_64 = cast(value); let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64); @@ -488,7 +495,7 @@ fn get_globals_use_map<'input>( directives: Vec>, ) -> ( Vec>, - HashMap, HashSet>, + HashMap, HashSet>, ) { let mut known_globals = HashSet::new(); for directive in directives.iter() { @@ -561,7 +568,7 @@ fn hoist_function_globals(directives: Vec) -> Vec { fn emit_denorm_build_string<'input>( call_map: &MethodsCallMap, denorm_information: &HashMap< - ast::MethodName<'input, spirv::Word>, + ast::MethodName<'input, SpirvWord>, HashMap, >, ) -> (CString, bool) { @@ -604,10 +611,10 @@ fn emit_directives<'input>( builder: &mut dr::Builder, map: &mut TypeWordMap, id_defs: &GlobalStringIdResolver<'input>, - opencl_id: spirv::Word, + opencl_id: SpirvWord, should_flush_denorms: bool, call_map: &MethodsCallMap<'input>, - globals_use_map: HashMap, HashSet>, + globals_use_map: HashMap, HashSet>, directives: Vec>, kernel_info: &mut HashMap, ) -> Result<(), TranslateError> { @@ -697,7 +704,7 @@ fn emit_function_linkage<'input>( builder: &mut dr::Builder, id_defs: &GlobalStringIdResolver<'input>, f: &Function, - fn_name: spirv::Word, + fn_name: SpirvWord, ) -> Result<(), TranslateError> { if f.linkage == ast::LinkingDirective::NONE { return Ok(()); @@ -718,7 +725,7 @@ fn emit_function_linkage<'input>( } struct MethodsCallMap<'input> { - map: HashMap, HashSet>, + map: HashMap, HashSet>, } impl<'input> MethodsCallMap<'input> { @@ -759,9 +766,9 @@ impl<'input> MethodsCallMap<'input> { } fn add_call_map_single( - directly_called_by: &HashMap, Vec>, - visited: &mut HashSet, - current: spirv::Word, + directly_called_by: &HashMap, Vec>, + visited: &mut HashSet, + current: SpirvWord, ) { if !visited.insert(current) { return; @@ -773,14 +780,14 @@ impl<'input> MethodsCallMap<'input> { } } - fn get_kernel_children(&self, name: &'input str) -> impl Iterator { + fn get_kernel_children(&self, name: &'input str) -> impl Iterator { self.map .get(&ast::MethodName::Kernel(name)) .into_iter() .flatten() } - fn kernels(&self) -> impl Iterator)> { + fn kernels(&self) -> impl Iterator)> { self.map .iter() .filter_map(|(method, children)| match method { @@ -791,17 +798,13 @@ impl<'input> MethodsCallMap<'input> { fn methods( &self, - ) -> impl Iterator, &HashSet)> { + ) -> impl Iterator, &HashSet)> { self.map .iter() .map(|(method, children)| (*method, children)) } - fn visit_callees( - &self, - method: ast::MethodName<'input, spirv::Word>, - f: impl FnMut(spirv::Word), - ) { + fn visit_callees(&self, method: ast::MethodName<'input, SpirvWord>, f: impl FnMut(SpirvWord)) { self.map .get(&method) .into_iter() @@ -857,7 +860,7 @@ fn multi_hash_map_append< fn convert_dynamic_shared_memory_usage<'input>( module: Vec>, kernels_methods_call_map: &MethodsCallMap<'input>, - new_id: &mut impl FnMut() -> spirv::Word, + new_id: &mut impl FnMut() -> SpirvWord, ) -> Vec> { let mut globals_shared = HashMap::new(); for dir in module.iter() { @@ -879,7 +882,7 @@ fn convert_dynamic_shared_memory_usage<'input>( if globals_shared.len() == 0 { return module; } - let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet>::new(); + let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet>::new(); let module = module .into_iter() .map(|directive| match directive { @@ -970,14 +973,14 @@ fn insert_arguments_remap_statements<'input>( kernels_methods_call_map: &MethodsCallMap<'input>, globals_shared: &HashMap, methods_to_indirectly_used_shared_globals: &HashMap< - ast::MethodName<'input, spirv::Word>, - BTreeSet, + ast::MethodName<'input, SpirvWord>, + BTreeSet, >, method_name: ast::MethodName, result: &mut Vec, func_decl_ref: &mut std::cell::RefMut>, - statements: Vec, ExpandedArgParams>>, -) -> Vec, ExpandedArgParams>> { + statements: Vec, +) -> Vec { let remapped_globals_in_method = if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) { match method_name { @@ -1053,13 +1056,13 @@ impl GlobalSharedSize { } fn replace_uses_of_shared_memory<'input>( - new_id: &mut impl FnMut() -> spirv::Word, + new_id: &mut impl FnMut() -> SpirvWord, methods_to_indirectly_used_shared_globals: &HashMap< - ast::MethodName<'input, spirv::Word>, - BTreeSet, + ast::MethodName<'input, SpirvWord>, + BTreeSet, >, statements: Vec, - remapped_globals_in_method: BTreeMap, + remapped_globals_in_method: BTreeMap, ) -> Vec { let mut result = Vec::with_capacity(statements.len()); for statement in statements { @@ -1103,12 +1106,9 @@ fn replace_uses_of_shared_memory<'input>( // * If it's a kernel -> size of .shared globals in use (direct or indirect) // * If it's a function -> does it use .shared global (directly or indirectly) fn resolve_indirect_uses_of_globals_shared<'input>( - methods_use_of_globals_shared: HashMap< - ast::MethodName<'input, spirv::Word>, - HashSet, - >, + methods_use_of_globals_shared: HashMap, HashSet>, kernels_methods_call_map: &MethodsCallMap<'input>, -) -> HashMap, BTreeSet> { +) -> HashMap, BTreeSet> { let mut result = HashMap::new(); for (method, callees) in kernels_methods_call_map.methods() { let mut indirect_globals = methods_use_of_globals_shared @@ -1161,7 +1161,7 @@ fn denorm_count_map_update_impl( // and emit suitable execution mode fn compute_denorm_information<'input>( module: &[Directive<'input>], -) -> HashMap, HashMap> { +) -> HashMap, HashMap> { let mut denorm_methods = HashMap::new(); for directive in module { match directive { @@ -1221,11 +1221,11 @@ fn emit_function_header<'input>( builder: &mut dr::Builder, map: &mut TypeWordMap, defined_globals: &GlobalStringIdResolver<'input>, - func_decl: &ast::MethodDeclaration<'input, spirv::Word>, + func_decl: &ast::MethodDeclaration<'input, SpirvWord>, call_map: &MethodsCallMap<'input>, - globals_use_map: &HashMap, HashSet>, + globals_use_map: &HashMap, HashSet>, kernel_info: &mut HashMap, -) -> Result { +) -> Result { if let ast::MethodName::Kernel(name) = func_decl.name { let args_lens = func_decl .input_arguments @@ -1272,7 +1272,7 @@ fn emit_function_header<'input>( }) .into_iter() }) - .collect::>(); + .collect::>(); builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, interface); fn_id } @@ -1314,7 +1314,7 @@ fn emit_extensions(builder: &mut dr::Builder) { builder.extension("SPV_KHR_no_integer_wrap_decoration"); } -fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word { +fn emit_opencl_import(builder: &mut dr::Builder) -> SpirvWord { builder.ext_inst_import("OpenCL.std") } @@ -1328,7 +1328,7 @@ fn emit_memory_model(builder: &mut dr::Builder) { fn translate_directive<'input, 'a>( id_defs: &'a mut GlobalStringIdResolver<'input>, ptx_impl_imports: &'a mut HashMap>, - d: ast::Directive<'input, ast::ParsedArgParams<'input>>, + d: ast::Directive<'input, ast::ParsedOperand<&'input str>>, ) -> Result>, TranslateError> { Ok(match d { ast::Directive::Variable(linking, var) => Some(Directive::Variable( @@ -1347,11 +1347,13 @@ fn translate_directive<'input, 'a>( }) } +type ParsedFunction<'a> = ast::Function<'a, &'a str, ast::Statement>>; + fn translate_function<'input, 'a>( id_defs: &'a mut GlobalStringIdResolver<'input>, ptx_impl_imports: &'a mut HashMap>, linkage: ast::LinkingDirective, - f: ast::ParsedFunction<'input>, + f: ParsedFunction<'input>, ) -> Result>, TranslateError> { let import_as = match &f.func_directive { ast::MethodDeclaration { @@ -1387,7 +1389,7 @@ fn translate_function<'input, 'a>( fn rename_fn_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, args: &'b [ast::Variable<&'a str>], -) -> Vec> { +) -> Vec> { args.iter() .map(|a| ast::Variable { name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true), @@ -1403,8 +1405,8 @@ fn to_ssa<'input, 'b>( ptx_impl_imports: &'b mut HashMap>, mut id_defs: FnStringIdResolver<'input, 'b>, fn_defs: GlobalFnDeclResolver<'input, 'b>, - func_decl: Rc>>, - f_body: Option>>>, + func_decl: Rc>>, + f_body: Option>>>, tuning: Vec, linkage: ast::LinkingDirective, ) -> Result, TranslateError> { @@ -1423,6 +1425,8 @@ fn to_ssa<'input, 'b>( } }; let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?; + todo!() + /* let mut numeric_id_defs = id_defs.finish(); let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = @@ -1452,6 +1456,7 @@ fn to_ssa<'input, 'b>( tuning, linkage, }) + */ } fn fix_special_registers2<'a, 'b, 'input>( @@ -1509,9 +1514,9 @@ struct SpecialRegisterResolver<'a, 'b, 'input> { impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { fn replace_sreg( &mut self, - desc: ArgumentDescriptor, + desc: ArgumentDescriptor, vector_index: Option, - ) -> Result { + ) -> Result { if let Some(sreg) = self.numeric_id_defs.special_registers.get(desc.op) { if desc.is_dst { return Err(TranslateError::MismatchedType); @@ -1557,7 +1562,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { return_arguments.iter().map(|(_, typ, space)| (typ, *space)), input_arguments.iter().map(|(_, typ, space)| (typ, *space)), )?; - self.result.push(Statement::Call(ResolvedCall { + self.result.push(Statement::Instruction(ast::Instruction::Call { uniform: false, return_arguments, name: fn_call, @@ -1570,104 +1575,66 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { } } -impl<'a, 'b, 'input> ArgumentMapVisitor - for SpecialRegisterResolver<'a, 'b, 'input> -{ - fn id( - &mut self, - desc: ArgumentDescriptor, - _: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - self.replace_sreg(desc, None) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - Ok(match desc.op { - TypedOperand::Reg(reg) => TypedOperand::Reg(self.replace_sreg(desc.new_op(reg), None)?), - op @ TypedOperand::RegOffset(_, _) => op, - op @ TypedOperand::Imm(_) => op, - TypedOperand::VecMember(reg, idx) => { - TypedOperand::VecMember(self.replace_sreg(desc.new_op(reg), Some(idx))?, idx) - } - }) - } -} - fn extract_globals<'input, 'b>( sorted_statements: Vec, ptx_impl_imports: &mut HashMap, id_def: &mut NumericIdResolver, -) -> Result<(Vec, Vec>), TranslateError> { +) -> Result<(Vec, Vec>), TranslateError> { let mut local = Vec::with_capacity(sorted_statements.len()); let mut global = Vec::new(); for statement in sorted_statements { match statement { Statement::Variable( - var - @ - ast::Variable { + var @ ast::Variable { state_space: ast::StateSpace::Shared, .. }, ) | Statement::Variable( - var - @ - ast::Variable { + var @ ast::Variable { state_space: ast::StateSpace::Global, .. }, ) => global.push(var), - Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => { - let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", typ.to_ptx_name()].concat(); + Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", data.to_ptx_name()].concat(); local.push(instruction_to_fn_call( id_def, ptx_impl_imports, - ast::Instruction::Bfe { typ, arg }, + ast::Instruction::Bfe { data, arguments }, fn_name, )?); } - Statement::Instruction(ast::Instruction::Bfi { typ, arg }) => { - let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", typ.to_ptx_name()].concat(); + Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", data.to_ptx_name()].concat(); local.push(instruction_to_fn_call( id_def, ptx_impl_imports, - ast::Instruction::Bfi { typ, arg }, + ast::Instruction::Bfi { data, arguments }, fn_name, )?); } - Statement::Instruction(ast::Instruction::Brev { typ, arg }) => { - let fn_name = [ZLUDA_PTX_PREFIX, "brev_", typ.to_ptx_name()].concat(); + Statement::Instruction(ast::Instruction::Brev { data, arguments }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "brev_", data.to_ptx_name()].concat(); local.push(instruction_to_fn_call( id_def, ptx_impl_imports, - ast::Instruction::Brev { typ, arg }, + ast::Instruction::Brev { data, arguments }, fn_name, )?); } - Statement::Instruction(ast::Instruction::Activemask { arg }) => { + Statement::Instruction(ast::Instruction::Activemask { arguments }) => { let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat(); local.push(instruction_to_fn_call( id_def, ptx_impl_imports, - ast::Instruction::Activemask { arg }, + ast::Instruction::Activemask { arguments }, fn_name, )?); } Statement::Instruction(ast::Instruction::Atom( - details - @ - ast::AtomDetails { - inner: - ast::AtomInnerDetails::Unsigned { - op: ast::AtomUIntOp::Inc, - .. - }, + details @ ast::AtomDetails { + op: ast::AtomicOp::IncrementWrap, .. }, args, @@ -1691,14 +1658,8 @@ fn extract_globals<'input, 'b>( )?); } Statement::Instruction(ast::Instruction::Atom( - details - @ - ast::AtomDetails { - inner: - ast::AtomInnerDetails::Unsigned { - op: ast::AtomUIntOp::Dec, - .. - }, + details @ ast::AtomDetails { + op: ast::AtomicOp::DecrementWrap, .. }, args, @@ -1722,14 +1683,8 @@ fn extract_globals<'input, 'b>( )?); } Statement::Instruction(ast::Instruction::Atom( - details - @ - ast::AtomDetails { - inner: - ast::AtomInnerDetails::Float { - op: ast::AtomFloatOp::Add, - .. - }, + details @ ast::AtomDetails { + op: ast::AtomicOp::FloatAdd, .. }, args, @@ -1759,63 +1714,62 @@ fn extract_globals<'input, 'b>( Ok((local, global)) } -impl ast::ScalarType { - fn to_ptx_name(self) -> &'static str { - match self { - ast::ScalarType::B8 => "b8", - ast::ScalarType::B16 => "b16", - ast::ScalarType::B32 => "b32", - ast::ScalarType::B64 => "b64", - ast::ScalarType::U8 => "u8", - ast::ScalarType::U16 => "u16", - ast::ScalarType::U32 => "u32", - ast::ScalarType::U64 => "u64", - ast::ScalarType::S8 => "s8", - ast::ScalarType::S16 => "s16", - ast::ScalarType::S32 => "s32", - ast::ScalarType::S64 => "s64", - ast::ScalarType::F16 => "f16", - ast::ScalarType::F32 => "f32", - ast::ScalarType::F64 => "f64", - ast::ScalarType::F16x2 => "f16x2", - ast::ScalarType::Pred => "pred", - } +fn type_to_ptx_name(this: ast::ScalarType) -> &'static str { + match this { + ast::ScalarType::B8 => "b8", + ast::ScalarType::B16 => "b16", + ast::ScalarType::B32 => "b32", + ast::ScalarType::B64 => "b64", + ast::ScalarType::U8 => "u8", + ast::ScalarType::U16 => "u16", + ast::ScalarType::U32 => "u32", + ast::ScalarType::U64 => "u64", + ast::ScalarType::S8 => "s8", + ast::ScalarType::S16 => "s16", + ast::ScalarType::S32 => "s32", + ast::ScalarType::S64 => "s64", + ast::ScalarType::F16 => "f16", + ast::ScalarType::F32 => "f32", + ast::ScalarType::F64 => "f64", + ast::ScalarType::F16x2 => "f16x2", + ast::ScalarType::Pred => "pred", + ptx_parser::ScalarType::S16x2 => todo!(), + ptx_parser::ScalarType::B128 => todo!(), + ptx_parser::ScalarType::U16x2 => todo!(), + ptx_parser::ScalarType::F16x2 => todo!(), + ptx_parser::ScalarType::BF16 => todo!(), + ptx_parser::ScalarType::BF16x2 => todo!(), } } -impl ast::AtomSemantics { - fn to_ptx_name(self) -> &'static str { - match self { - ast::AtomSemantics::Relaxed => "relaxed", - ast::AtomSemantics::Acquire => "acquire", - ast::AtomSemantics::Release => "release", - ast::AtomSemantics::AcquireRelease => "acq_rel", - } +fn sema_to_ptx_name(this: ast::AtomSemantics) -> &'static str { + match this { + ast::AtomSemantics::Relaxed => "relaxed", + ast::AtomSemantics::Acquire => "acquire", + ast::AtomSemantics::Release => "release", + ast::AtomSemantics::AcqRel => "acq_rel", } } -impl ast::MemScope { - fn to_ptx_name(self) -> &'static str { - match self { - ast::MemScope::Cta => "cta", - ast::MemScope::Gpu => "gpu", - ast::MemScope::Sys => "sys", - } +fn scope_to_ptx_name(this: ast::MemScope) -> &'static str { + match this { + ast::MemScope::Cta => "cta", + ast::MemScope::Gpu => "gpu", + ast::MemScope::Sys => "sys", + ptx_parser::MemScope::Cluster => "cluster", } } -impl ast::StateSpace { - fn to_ptx_name(self) -> &'static str { - match self { - ast::StateSpace::Generic => "generic", - ast::StateSpace::Global => "global", - ast::StateSpace::Shared => "shared", - ast::StateSpace::Reg => "reg", - ast::StateSpace::Const => "const", - ast::StateSpace::Local => "local", - ast::StateSpace::Param => "param", - ast::StateSpace::Sreg => "sreg", - } +fn space_to_ptx_name(this: ast::StateSpace) -> &'static str { + match this { + ast::StateSpace::Generic => "generic", + ast::StateSpace::Global => "global", + ast::StateSpace::Shared => "shared", + ast::StateSpace::Reg => "reg", + ast::StateSpace::Const => "const", + ast::StateSpace::Local => "local", + ast::StateSpace::Param => "param", + ast::StateSpace::Sreg => "sreg", } } @@ -1846,7 +1800,7 @@ fn convert_to_typed_statements( Statement::Instruction(inst) => match inst { ast::Instruction::Mov( mov, - ast::Arg2Mov { + ast::MovArgs { dst: ast::Operand::Reg(dst_reg), src: ast::Operand::Reg(src_reg), }, @@ -1859,14 +1813,6 @@ fn convert_to_typed_statements( src: src_reg, })); } - ast::Instruction::Call(call) => { - let resolver = fn_defs.get_fn_sig_resolver(call.func)?; - let resolved_call = resolver.resolve_in_spirv_repr(call)?; - let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); - let reresolved_call = resolved_call.visit(&mut visitor)?; - visitor.func.push(reresolved_call); - visitor.func.extend(visitor.post_stmts); - } inst => { let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let instruction = Statement::Instruction(inst.map(&mut visitor)?); @@ -1909,8 +1855,8 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { >, typ: &ast::Type, state_space: ast::StateSpace, - idx: Vec, - ) -> Result { + idx: Vec, + ) -> Result { // mov.u32 foobar, {a,b}; let scalar_t = match typ { ast::Type::Vector(scalar_t, _) => *scalar_t, @@ -1935,47 +1881,14 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { } } -impl<'a, 'b> ArgumentMapVisitor - for VectorRepackVisitor<'a, 'b> -{ - fn id( - &mut self, - desc: ArgumentDescriptor, - _: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - Ok(desc.op) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - Ok(match desc.op { - ast::Operand::Reg(reg) => TypedOperand::Reg(reg), - ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), - ast::Operand::Imm(x) => TypedOperand::Imm(x), - ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), - ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector( - desc.is_dst, - desc.non_default_implicit_conversion, - typ, - state_space, - vec, - )?), - }) - } -} - fn instruction_to_fn_call( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, - inst: ast::Instruction, + inst: ast::Instruction, fn_name: String, ) -> Result { let mut arguments = Vec::new(); - inst.visit(&mut |desc: ArgumentDescriptor, + inst.visit(&mut |desc: ArgumentDescriptor, typ: Option<(&ast::Type, ast::StateSpace)>| { let (typ, space) = match typ { Some((typ, space)) => (typ.clone(), space), @@ -2010,13 +1923,13 @@ fn register_external_fn_call<'a>( name: String, return_arguments: impl Iterator, input_arguments: impl Iterator, -) -> Result { +) -> Result { match ptx_impl_imports.entry(name) { hash_map::Entry::Vacant(entry) => { let fn_id = id_defs.register_intermediate(None); let return_arguments = fn_arguments_to_variables(id_defs, return_arguments); let input_arguments = fn_arguments_to_variables(id_defs, input_arguments); - let func_decl = ast::MethodDeclaration:: { + let func_decl = ast::MethodDeclaration:: { return_arguments, name: ast::MethodName::Func(fn_id), input_arguments, @@ -2046,7 +1959,7 @@ fn register_external_fn_call<'a>( fn fn_arguments_to_variables<'a>( id_defs: &mut NumericIdResolver, args: impl Iterator, -) -> Vec> { +) -> Vec> { args.map(|(typ, space)| ast::Variable { align: None, v_type: typ.clone(), @@ -2058,8 +1971,8 @@ fn fn_arguments_to_variables<'a>( } fn arguments_to_resolved_arguments( - args: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], -) -> Vec<(spirv::Word, ast::Type, ast::StateSpace)> { + args: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], +) -> Vec<(SpirvWord, ast::Type, ast::StateSpace)> { args.iter() .map(|(desc, typ, space)| (desc.op, typ.clone(), *space)) .collect::>() @@ -2144,113 +2057,10 @@ fn normalize_predicates( Ok(result) } -/* - How do we handle arguments: - - input .params in kernels - .param .b64 in_arg - get turned into this SPIR-V: - %1 = OpFunctionParameter %ulong - %2 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %1 - We do this for two reasons. One, common treatment for argument-declared - .param variables and .param variables inside function (we assume that - at SPIR-V level every .param is a pointer in Function storage class) - - input .params in functions - .param .b64 in_arg - get turned into this SPIR-V: - %1 = OpFunctionParameter %_ptr_Function_ulong - - input .regs - .reg .b64 in_arg - get turned into the same SPIR-V as kernel .params: - %1 = OpFunctionParameter %ulong - %2 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %1 - - output .regs - .reg .b64 out_arg - get just a variable declaration: - %2 = OpVariable %%_ptr_Function_ulong Function - - output .params don't exist, they have been moved to input positions - by an earlier pass - Distinguishing betweem kernel .params and function .params is not the - cleanest solution. Alternatively, we could "deparamize" all kernel .param - arguments by turning them into .reg arguments like this: - .param .b64 arg -> .reg ptr<.b64,.param> arg - This has the massive downside that this transformation would have to run - very early and would muddy up already difficult code. It's simpler to just - have an if here -*/ -fn insert_mem_ssa_statements<'a, 'b>( - func: Vec, - id_def: &mut NumericIdResolver, - fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(func.len()); - for arg in fn_decl.input_arguments.iter_mut() { - insert_mem_ssa_argument(id_def, &mut result, arg, fn_decl.name.is_kernel()); - } - for arg in fn_decl.return_arguments.iter() { - insert_mem_ssa_argument_reg_return(&mut result, arg); - } - for s in func { - match s { - Statement::Call(call) => { - insert_mem_ssa_statement_default(id_def, &mut result, call.cast())? - } - Statement::Instruction(inst) => match inst { - ast::Instruction::Ret(d) => { - // TODO: handle multiple output args - match &fn_decl.return_arguments[..] { - [return_reg] => { - let new_id = id_def.register_intermediate(Some(( - return_reg.v_type.clone(), - ast::StateSpace::Reg, - ))); - result.push(Statement::LoadVar(LoadVarDetails { - arg: ast::Arg2 { - dst: new_id, - src: return_reg.name, - }, - // TODO: ret with stateful conversion - state_space: ast::StateSpace::Reg, - typ: return_reg.v_type.clone(), - member_index: None, - })); - result.push(Statement::RetValue(d, new_id)); - } - [] => result.push(Statement::Instruction(ast::Instruction::Ret(d))), - _ => unimplemented!(), - } - } - inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?, - }, - Statement::Conditional(bra) => { - insert_mem_ssa_statement_default(id_def, &mut result, bra)? - } - Statement::Conversion(conv) => { - insert_mem_ssa_statement_default(id_def, &mut result, conv)? - } - Statement::PtrAccess(ptr_access) => { - insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)? - } - Statement::RepackVector(repack) => { - insert_mem_ssa_statement_default(id_def, &mut result, repack)? - } - Statement::FunctionPointer(func_ptr) => { - insert_mem_ssa_statement_default(id_def, &mut result, func_ptr)? - } - s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => { - result.push(s) - } - _ => return Err(error_unreachable()), - } - } - Ok(result) -} - fn insert_mem_ssa_argument( id_def: &mut NumericIdResolver, func: &mut Vec, - arg: &mut ast::Variable, + arg: &mut ast::Variable, is_kernel: bool, ) { if !is_kernel && arg.state_space == ast::StateSpace::Param { @@ -2265,7 +2075,7 @@ fn insert_mem_ssa_argument( array_init: Vec::new(), })); func.push(Statement::StoreVar(StoreVarDetails { - arg: ast::Arg2St { + arg: ast::StArgs { src1: arg.name, src2: new_id, }, @@ -2277,7 +2087,7 @@ fn insert_mem_ssa_argument( fn insert_mem_ssa_argument_reg_return( func: &mut Vec, - arg: &ast::Variable, + arg: &ast::Variable, ) { func.push(Statement::Variable(ast::Variable { align: arg.align, @@ -2288,41 +2098,6 @@ fn insert_mem_ssa_argument_reg_return( })); } -trait Visitable: Sized { - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, To>, TranslateError>; -} - -struct VisitArgumentDescriptor< - 'a, - Ctor: FnOnce(spirv::Word) -> Statement, U>, - U: ArgParamsEx, -> { - desc: ArgumentDescriptor, - typ: &'a ast::Type, - state_space: ast::StateSpace, - stmt_ctor: Ctor, -} - -impl< - 'a, - Ctor: FnOnce(spirv::Word) -> Statement, U>, - T: ArgParamsEx, - U: ArgParamsEx, - > Visitable for VisitArgumentDescriptor<'a, Ctor, U> -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, U>, TranslateError> { - Ok((self.stmt_ctor)( - visitor.id(self.desc, Some((self.typ, self.state_space)))?, - )) - } -} - struct InsertMemSSAVisitor<'a, 'input> { id_def: &'a mut NumericIdResolver<'input>, func: &'a mut Vec, @@ -2332,9 +2107,9 @@ struct InsertMemSSAVisitor<'a, 'input> { impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { fn symbol( &mut self, - desc: ArgumentDescriptor<(spirv::Word, Option)>, + desc: ArgumentDescriptor<(SpirvWord, Option)>, expected: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { + ) -> Result { let symbol = desc.op.0; if expected.is_none() { return Ok(symbol); @@ -2368,7 +2143,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg))); if !desc.is_dst { self.func.push(Statement::LoadVar(LoadVarDetails { - arg: Arg2 { + arg: ast::MovArgs { dst: generated_id, src: symbol, }, @@ -2379,7 +2154,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { } else { self.post_statements .push(Statement::StoreVar(StoreVarDetails { - arg: Arg2St { + arg: ast::StArgs { src1: symbol, src2: generated_id, }, @@ -2391,55 +2166,6 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { } } -impl<'a, 'input> ArgumentMapVisitor - for InsertMemSSAVisitor<'a, 'input> -{ - fn id( - &mut self, - desc: ArgumentDescriptor, - typ: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - self.symbol(desc.new_op((desc.op, None)), typ) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - Ok(match desc.op { - TypedOperand::Reg(reg) => { - TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?) - } - TypedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset( - self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?, - offset, - ), - op @ TypedOperand::Imm(..) => op, - TypedOperand::VecMember(symbol, index) => TypedOperand::Reg( - self.symbol(desc.new_op((symbol, Some(index))), Some((typ, state_space)))?, - ), - }) - } -} - -fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable>( - id_def: &'a mut NumericIdResolver<'input>, - func: &'a mut Vec, - stmt: S, -) -> Result<(), TranslateError> { - let mut visitor = InsertMemSSAVisitor { - id_def, - func, - post_statements: Vec::new(), - }; - let new_stmt = stmt.visit(&mut visitor)?; - visitor.func.push(new_stmt); - visitor.func.extend(visitor.post_statements); - Ok(()) -} - fn expand_arguments<'a, 'b>( func: Vec, id_def: &'b mut MutableNumericIdResolver<'a>, @@ -2517,18 +2243,18 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn reg( &mut self, - desc: ArgumentDescriptor, + desc: ArgumentDescriptor, _: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { + ) -> Result { Ok(desc.op) } fn reg_offset( &mut self, - desc: ArgumentDescriptor<(spirv::Word, i32)>, + desc: ArgumentDescriptor<(SpirvWord, i32)>, typ: &ast::Type, state_space: ast::StateSpace, - ) -> Result { + ) -> Result { let (reg, offset) = desc.op; if !desc.is_memory_access { let (reg_type, reg_space) = self.id_def.get_typed(reg)?; @@ -2548,8 +2274,8 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { value: ast::ImmediateValue::S64(offset as i64), })); let arith_details = match reg_scalar_type.kind() { - ast::ScalarKind::Signed => ast::ArithDetails::Signed(ast::ArithSInt { - typ: reg_scalar_type, + ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, saturate: false, }), ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => { @@ -2560,7 +2286,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { let id_add_result = self.id_def.register_intermediate(reg_type, state_space); self.func.push(Statement::Instruction(ast::Instruction::Add( arith_details, - ast::Arg3 { + ast::AddArgs { dst: id_add_result, src1: reg, src2: id_constant_stmt, @@ -2594,7 +2320,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { desc: ArgumentDescriptor, typ: &ast::Type, state_space: ast::StateSpace, - ) -> Result { + ) -> Result { let scalar_t = if let ast::Type::Scalar(scalar) = typ { *scalar } else { @@ -2612,144 +2338,12 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } } -impl<'a, 'b> ArgumentMapVisitor for FlattenArguments<'a, 'b> { - fn id( - &mut self, - desc: ArgumentDescriptor, - t: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - self.reg(desc, t) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - match desc.op { - TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some((typ, state_space))), - TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ, state_space), - TypedOperand::RegOffset(reg, offset) => { - self.reg_offset(desc.new_op((reg, offset)), typ, state_space) - } - TypedOperand::VecMember(..) => Err(error_unreachable()), - } - } -} - -/* - There are several kinds of implicit conversions in PTX: - * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands - * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size - - ld.param: not documented, but for instruction `ld.param. x, [y]`, - semantics are to first zext/chop/bitcast `y` as needed and then do - documented special ld/st/cvt conversion rules for destination operands - - st.param [x] y (used as function return arguments) same rule as above applies - - generic/global ld: for instruction `ld x, [y]`, y must be of type - b64/u64/s64, which is bitcast to a pointer, dereferenced and then - documented special ld/st/cvt conversion rules are applied to dst - - generic/global st: for instruction `st [x], y`, x must be of type - b64/u64/s64, which is bitcast to a pointer -*/ -fn insert_implicit_conversions( - func: Vec, - id_def: &mut MutableNumericIdResolver, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(func.len()); - for s in func.into_iter() { - match s { - Statement::Call(call) => { - insert_implicit_conversions_impl(&mut result, id_def, call)?; - } - Statement::Instruction(inst) => { - insert_implicit_conversions_impl(&mut result, id_def, inst)?; - } - Statement::PtrAccess(access) => { - insert_implicit_conversions_impl(&mut result, id_def, access)?; - } - Statement::RepackVector(repack) => { - insert_implicit_conversions_impl(&mut result, id_def, repack)?; - } - s @ Statement::Conditional(_) - | s @ Statement::Conversion(_) - | s @ Statement::Label(_) - | s @ Statement::Constant(_) - | s @ Statement::Variable(_) - | s @ Statement::LoadVar(..) - | s @ Statement::StoreVar(..) - | s @ Statement::RetValue(..) - | s @ Statement::FunctionPointer(..) => result.push(s), - } - } - Ok(result) -} - -fn insert_implicit_conversions_impl( - func: &mut Vec, - id_def: &mut MutableNumericIdResolver, - stmt: impl Visitable, -) -> Result<(), TranslateError> { - let mut post_conv = Vec::new(); - let statement = - stmt.visit(&mut |desc: ArgumentDescriptor, - typ: Option<(&ast::Type, ast::StateSpace)>| { - let (instr_type, instruction_space) = match typ { - None => return Ok(desc.op), - Some(t) => t, - }; - let (operand_type, operand_space) = id_def.get_typed(desc.op)?; - let conversion_fn = desc - .non_default_implicit_conversion - .unwrap_or(default_implicit_conversion); - match conversion_fn( - (operand_space, &operand_type), - (instruction_space, instr_type), - )? { - Some(conv_kind) => { - let conv_output = if desc.is_dst { - &mut post_conv - } else { - &mut *func - }; - let mut from_type = instr_type.clone(); - let mut from_space = instruction_space; - let mut to_type = operand_type; - let mut to_space = operand_space; - let mut src = - id_def.register_intermediate(instr_type.clone(), instruction_space); - let mut dst = desc.op; - let result = Ok(src); - if !desc.is_dst { - mem::swap(&mut src, &mut dst); - mem::swap(&mut from_type, &mut to_type); - mem::swap(&mut from_space, &mut to_space); - } - conv_output.push(Statement::Conversion(ImplicitConversion { - src, - dst, - from_type, - from_space, - to_type, - to_space, - kind: conv_kind, - })); - result - } - None => Ok(desc.op), - } - })?; - func.push(statement); - func.append(&mut post_conv); - Ok(()) -} - fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, spirv_input: impl Iterator, - spirv_output: &[ast::Variable], -) -> (spirv::Word, spirv::Word) { + spirv_output: &[ast::Variable], +) -> (SpirvWord, SpirvWord) { map.get_or_add_fn( builder, spirv_input, @@ -2763,7 +2357,7 @@ fn emit_function_body_ops<'input>( builder: &mut dr::Builder, map: &mut TypeWordMap, id_defs: &GlobalStringIdResolver<'input>, - opencl: spirv::Word, + opencl: SpirvWord, func: &[ExpandedStatement], ) -> Result<(), TranslateError> { for s in func { @@ -3107,17 +2701,15 @@ fn emit_function_body_ops<'input>( &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(spirv::Scope::Workgroup as u32), )?; - let barrier_semantics = match d { - ast::BarDetails::SyncAligned => map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr( - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - )?, - }; + let barrier_semantics = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr( + spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + )?; builder.control_barrier(workgroup_scope, workgroup_scope, barrier_semantics)?; } ast::Instruction::Atom(details, arg) => { @@ -3170,7 +2762,7 @@ fn emit_function_body_ops<'input>( result_type, Some(a.dst), opencl, - spirv::CLOp::rsqrt as spirv::Word, + spirv::CLOp::rsqrt as SpirvWord, [dr::Operand::IdRef(a.src)].iter().cloned(), )?; } @@ -3183,71 +2775,77 @@ fn emit_function_body_ops<'input>( }; negate_func(builder, result_type, Some(arg.dst), arg.src)?; } - ast::Instruction::Sin { arg, .. } => { + ast::Instruction::Sin { arguments, .. } => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); builder.ext_inst( result_type, - Some(arg.dst), + Some(arguments.dst), opencl, spirv::CLOp::sin as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), + [dr::Operand::IdRef(arguments.src)].iter().cloned(), )?; } - ast::Instruction::Cos { arg, .. } => { + ast::Instruction::Cos { arguments, .. } => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); builder.ext_inst( result_type, - Some(arg.dst), + Some(arguments.dst), opencl, spirv::CLOp::cos as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), + [dr::Operand::IdRef(arguments.src)].iter().cloned(), )?; } - ast::Instruction::Lg2 { arg, .. } => { + ast::Instruction::Lg2 { arguments, .. } => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); builder.ext_inst( result_type, - Some(arg.dst), + Some(arguments.dst), opencl, spirv::CLOp::log2 as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), + [dr::Operand::IdRef(arguments.src)].iter().cloned(), )?; } - ast::Instruction::Ex2 { arg, .. } => { + ast::Instruction::Ex2 { arguments, .. } => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); builder.ext_inst( result_type, - Some(arg.dst), + Some(arguments.dst), opencl, spirv::CLOp::exp2 as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), + [dr::Operand::IdRef(arguments.src)].iter().cloned(), )?; } - ast::Instruction::Clz { typ, arg } => { - let result_type = map.get_or_add_scalar(builder, (*typ).into()); + ast::Instruction::Clz { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, (*data).into()); builder.ext_inst( result_type, - Some(arg.dst), + Some(arguments.dst), opencl, spirv::CLOp::clz as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), + [dr::Operand::IdRef(arguments.src)].iter().cloned(), )?; } - ast::Instruction::Brev { typ, arg } => { - let result_type = map.get_or_add_scalar(builder, (*typ).into()); - builder.bit_reverse(result_type, Some(arg.dst), arg.src)?; + ast::Instruction::Brev { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder.bit_reverse(result_type, Some(arguments.dst), arguments.src)?; } - ast::Instruction::Popc { typ, arg } => { - let result_type = map.get_or_add_scalar(builder, (*typ).into()); - builder.bit_count(result_type, Some(arg.dst), arg.src)?; + ast::Instruction::Popc { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder.bit_count(result_type, Some(arguments.dst), arguments.src)?; } - ast::Instruction::Xor { typ, arg } => { - let builder_fn = match typ { + ast::Instruction::Xor { data, arguments } => { + let builder_fn = match data { ast::ScalarType::Pred => emit_logical_xor_spirv, _ => dr::Builder::bitwise_xor, }; - let result_type = map.get_or_add_scalar(builder, (*typ).into()); - builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder_fn( + builder, + result_type, + Some(arguments.dst), + arguments.src1, + arguments.src2, + )?; } ast::Instruction::Bfe { .. } | ast::Instruction::Bfi { .. } @@ -3256,17 +2854,23 @@ fn emit_function_body_ops<'input>( return Err(error_unreachable()); } - ast::Instruction::Rem { typ, arg } => { - let builder_fn = if typ.kind() == ast::ScalarKind::Signed { + ast::Instruction::Rem { data, arguments } => { + let builder_fn = if data.kind() == ast::ScalarKind::Signed { dr::Builder::s_mod } else { dr::Builder::u_mod }; - let result_type = map.get_or_add_scalar(builder, (*typ).into()); - builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder_fn( + builder, + result_type, + Some(arguments.dst), + arguments.src1, + arguments.src2, + )?; } - ast::Instruction::Prmt { control, arg } => { - let control = *control as u32; + ast::Instruction::Prmt { data, arguments } => { + let control = *data as u32; let components = [ (control >> 0) & 0b1111, (control >> 4) & 0b1111, @@ -3279,8 +2883,8 @@ fn emit_function_body_ops<'input>( let vec4_b8_type = map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4)); let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32); - let src1_vector = builder.bitcast(vec4_b8_type, None, arg.src1)?; - let src2_vector = builder.bitcast(vec4_b8_type, None, arg.src2)?; + let src1_vector = builder.bitcast(vec4_b8_type, None, arguments.src1)?; + let src2_vector = builder.bitcast(vec4_b8_type, None, arguments.src2)?; let dst_vector = builder.vector_shuffle( vec4_b8_type, None, @@ -3288,10 +2892,10 @@ fn emit_function_body_ops<'input>( src2_vector, components, )?; - builder.bitcast(b32_type, Some(arg.dst), dst_vector)?; + builder.bitcast(b32_type, Some(arguments.dst), dst_vector)?; } - ast::Instruction::Membar { level } => { - let (scope, semantics) = match level { + ast::Instruction::Membar { data } => { + let (scope, semantics) = match data { ast::MemScope::Cta => ( spirv::Scope::Workgroup, spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY @@ -3310,6 +2914,7 @@ fn emit_function_body_ops<'input>( | spirv::MemorySemantics::WORKGROUP_MEMORY | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, ), + ast::MemScope::Cluster => todo!(), }; let spirv_scope = map.get_or_add_constant( builder, @@ -3423,9 +3028,9 @@ fn emit_function_body_ops<'input>( fn insert_shift_hack( builder: &mut dr::Builder, map: &mut TypeWordMap, - offset_var: spirv::Word, + offset_var: SpirvWord, size_of: usize, -) -> Result { +) -> Result { let result_type = match size_of { 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), @@ -3438,11 +3043,11 @@ fn insert_shift_hack( // TODO: check what kind of assembly do we emit fn emit_logical_xor_spirv( builder: &mut dr::Builder, - result_type: spirv::Word, - result_id: Option, - op1: spirv::Word, - op2: spirv::Word, -) -> Result { + result_type: SpirvWord, + result_id: Option, + op1: SpirvWord, + op2: SpirvWord, +) -> Result { let temp_or = builder.logical_or(result_type, None, op1, op2)?; let temp_and = builder.logical_and(result_type, None, op1, op2)?; let temp_neg = logical_not(builder, result_type, None, temp_and)?; @@ -3452,27 +3057,27 @@ fn emit_logical_xor_spirv( fn emit_sqrt( builder: &mut dr::Builder, map: &mut TypeWordMap, - opencl: spirv::Word, - details: &ast::SqrtDetails, - a: &ast::Arg2, + opencl: SpirvWord, + details: &ast::RcpData, + a: &ast::SqrtArgs, ) -> Result<(), TranslateError> { let result_type = map.get_or_add_scalar(builder, details.typ.into()); let (ocl_op, rounding) = match details.kind { - ast::SqrtKind::Approx => (spirv::CLOp::sqrt, None), - ast::SqrtKind::Rounding(rnd) => (spirv::CLOp::sqrt, Some(rnd)), + ast::RcpKind::Approx => (spirv::CLOp::sqrt, None), + ast::RcpKind::Rounding(rnd) => (spirv::CLOp::sqrt, Some(rnd)), }; builder.ext_inst( result_type, Some(a.dst), opencl, - ocl_op as spirv::Word, + ocl_op as SpirvWord, [dr::Operand::IdRef(a.src)].iter().cloned(), )?; emit_rounding_decoration(builder, a.dst, rounding); Ok(()) } -fn emit_float_div_decoration(builder: &mut dr::Builder, dst: spirv::Word, kind: ast::DivFloatKind) { +fn emit_float_div_decoration(builder: &mut dr::Builder, dst: SpirvWord, kind: ast::DivFloatKind) { match kind { ast::DivFloatKind::Approx => { builder.decorate( @@ -3496,45 +3101,25 @@ fn emit_atom( builder: &mut dr::Builder, map: &mut TypeWordMap, details: &ast::AtomDetails, - arg: &ast::Arg3, + arg: &ast::AtomArgs, ) -> Result<(), TranslateError> { - let (spirv_op, typ) = match details.inner { - ast::AtomInnerDetails::Bit { op, typ } => { - let spirv_op = match op { - ast::AtomBitOp::And => dr::Builder::atomic_and, - ast::AtomBitOp::Or => dr::Builder::atomic_or, - ast::AtomBitOp::Xor => dr::Builder::atomic_xor, - ast::AtomBitOp::Exchange => dr::Builder::atomic_exchange, - }; - (spirv_op, ast::ScalarType::from(typ)) - } - ast::AtomInnerDetails::Unsigned { op, typ } => { - let spirv_op = match op { - ast::AtomUIntOp::Add => dr::Builder::atomic_i_add, - ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => { - return Err(error_unreachable()); - } - ast::AtomUIntOp::Min => dr::Builder::atomic_u_min, - ast::AtomUIntOp::Max => dr::Builder::atomic_u_max, - }; - (spirv_op, typ.into()) - } - ast::AtomInnerDetails::Signed { op, typ } => { - let spirv_op = match op { - ast::AtomSIntOp::Add => dr::Builder::atomic_i_add, - ast::AtomSIntOp::Min => dr::Builder::atomic_s_min, - ast::AtomSIntOp::Max => dr::Builder::atomic_s_max, - }; - (spirv_op, typ.into()) - } - ast::AtomInnerDetails::Float { op, typ } => { - let spirv_op: fn(&mut dr::Builder, _, _, _, _, _, _) -> _ = match op { - ast::AtomFloatOp::Add => dr::Builder::atomic_f_add_ext, - }; - (spirv_op, typ.into()) - } + let spirv_op = match details.op { + ptx_parser::AtomicOp::And => dr::Builder::atomic_and, + ptx_parser::AtomicOp::Or => dr::Builder::atomic_or, + ptx_parser::AtomicOp::Xor => dr::Builder::atomic_xor, + ptx_parser::AtomicOp::Exchange => dr::Builder::atomic_exchange, + ptx_parser::AtomicOp::Add => dr::Builder::atomic_i_add, + ptx_parser::AtomicOp::IncrementWrap => return Err(error_unreachable()), + ptx_parser::AtomicOp::DecrementWrap => return Err(error_unreachable()), + ptx_parser::AtomicOp::SignedMin => dr::Builder::atomic_s_min, + ptx_parser::AtomicOp::UnsignedMin => dr::Builder::atomic_u_min, + ptx_parser::AtomicOp::SignedMax => dr::Builder::atomic_s_max, + ptx_parser::AtomicOp::UnsignedMax => dr::Builder::atomic_u_max, + ptx_parser::AtomicOp::FloatAdd => dr::Builder::atomic_f_add_ext, + ptx_parser::AtomicOp::FloatMin => dr::Builder::atomic_f_min_ext, + ptx_parser::AtomicOp::FloatMax => dr::Builder::atomic_f_max_ext, }; - let result_type = map.get_or_add_scalar(builder, typ); + let result_type = map.get_or_add_scalar(builder, details.type_); let memory_const = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), @@ -3568,7 +3153,7 @@ fn emit_mul_float( builder: &mut dr::Builder, map: &mut TypeWordMap, ctr: &ast::ArithFloat, - arg: &ast::Arg3, + arg: &ast::MulArgs, ) -> Result<(), dr::Error> { if ctr.saturate { todo!() @@ -3582,9 +3167,9 @@ fn emit_mul_float( fn emit_rcp( builder: &mut dr::Builder, map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::RcpDetails, - arg: &ast::Arg2, + opencl: SpirvWord, + desc: &ast::RcpData, + arg: &ast::RcpArgs, ) -> Result<(), TranslateError> { let (instr_type, constant) = if desc.is_f64 { (ast::ScalarType::F64, vec_repr(1.0f64)) @@ -3628,7 +3213,7 @@ fn emit_variable<'input>( map: &mut TypeWordMap, id_defs: &GlobalStringIdResolver<'input>, linking: ast::LinkingDirective, - var: &ast::Variable, + var: &ast::Variable, ) -> Result<(), TranslateError> { let (must_init, st_class) = match var.state_space { ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { @@ -3673,7 +3258,7 @@ fn emit_linking_decoration<'input>( builder: &mut dr::Builder, id_defs: &GlobalStringIdResolver<'input>, name_override: Option<&str>, - name: spirv::Word, + name: SpirvWord, linking: ast::LinkingDirective, ) { if linking == ast::LinkingDirective::NONE { @@ -3712,9 +3297,9 @@ fn emit_linking_decoration<'input>( fn emit_mad_uint( builder: &mut dr::Builder, map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MulUInt, - arg: &ast::Arg4, + opencl: SpirvWord, + desc: &ast::MulDetails, + arg: &ast::MulArgs, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); match desc.control { @@ -3727,7 +3312,7 @@ fn emit_mad_uint( inst_type, Some(arg.dst), opencl, - spirv::CLOp::u_mad_hi as spirv::Word, + spirv::CLOp::u_mad_hi as SpirvWord, [ dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2), @@ -3745,9 +3330,9 @@ fn emit_mad_uint( fn emit_mad_sint( builder: &mut dr::Builder, map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MulSInt, - arg: &ast::Arg4, + opencl: SpirvWord, + desc: &ast::MulDetails, + arg: &ast::MulArgs, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); match desc.control { @@ -3760,7 +3345,7 @@ fn emit_mad_sint( inst_type, Some(arg.dst), opencl, - spirv::CLOp::s_mad_hi as spirv::Word, + spirv::CLOp::s_mad_hi as SpirvWord, [ dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2), @@ -3778,16 +3363,16 @@ fn emit_mad_sint( fn emit_fma_float( builder: &mut dr::Builder, map: &mut TypeWordMap, - opencl: spirv::Word, + opencl: SpirvWord, desc: &ast::ArithFloat, - arg: &ast::Arg4, + arg: &ast::FmaArgs, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); builder.ext_inst( inst_type, Some(arg.dst), opencl, - spirv::CLOp::fma as spirv::Word, + spirv::CLOp::fma as SpirvWord, [ dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2), @@ -3802,16 +3387,16 @@ fn emit_fma_float( fn emit_mad_float( builder: &mut dr::Builder, map: &mut TypeWordMap, - opencl: spirv::Word, + opencl: SpirvWord, desc: &ast::ArithFloat, - arg: &ast::Arg4, + arg: &ast::MadArgs, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); builder.ext_inst( inst_type, Some(arg.dst), opencl, - spirv::CLOp::mad as spirv::Word, + spirv::CLOp::mad as SpirvWord, [ dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2), @@ -3827,7 +3412,7 @@ fn emit_add_float( builder: &mut dr::Builder, map: &mut TypeWordMap, desc: &ast::ArithFloat, - arg: &ast::Arg3, + arg: &ast::AddArgs, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); builder.f_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; @@ -3839,7 +3424,7 @@ fn emit_sub_float( builder: &mut dr::Builder, map: &mut TypeWordMap, desc: &ast::ArithFloat, - arg: &ast::Arg3, + arg: &ast::SubArgs, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?; @@ -3850,9 +3435,9 @@ fn emit_sub_float( fn emit_min( builder: &mut dr::Builder, map: &mut TypeWordMap, - opencl: spirv::Word, + opencl: SpirvWord, desc: &ast::MinMaxDetails, - arg: &ast::Arg3, + arg: &ast::MinArgs, ) -> Result<(), dr::Error> { let cl_op = match desc { ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min, @@ -3864,7 +3449,7 @@ fn emit_min( inst_type, Some(arg.dst), opencl, - cl_op as spirv::Word, + cl_op as SpirvWord, [dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2)] .iter() .cloned(), @@ -3875,9 +3460,9 @@ fn emit_min( fn emit_max( builder: &mut dr::Builder, map: &mut TypeWordMap, - opencl: spirv::Word, + opencl: SpirvWord, desc: &ast::MinMaxDetails, - arg: &ast::Arg3, + arg: &ast::MaxArgs, ) -> Result<(), dr::Error> { let cl_op = match desc { ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max, @@ -3889,7 +3474,7 @@ fn emit_max( inst_type, Some(arg.dst), opencl, - cl_op as spirv::Word, + cl_op as SpirvWord, [dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2)] .iter() .cloned(), @@ -3900,9 +3485,9 @@ fn emit_max( fn emit_cvt( builder: &mut dr::Builder, map: &mut TypeWordMap, - opencl: spirv::Word, + opencl: SpirvWord, dets: &ast::CvtDetails, - arg: &ast::Arg2, + arg: &ast::CvtArgs, ) -> Result<(), TranslateError> { match dets { ast::CvtDetails::FloatFromFloat(desc) => { @@ -4036,7 +3621,7 @@ fn emit_saturating_decoration(builder: &mut dr::Builder, dst: u32, saturate: boo fn emit_rounding_decoration( builder: &mut dr::Builder, - dst: spirv::Word, + dst: SpirvWord, rounding: Option, ) { if let Some(rounding) = rounding { @@ -4048,23 +3633,21 @@ fn emit_rounding_decoration( } } -impl ast::RoundingMode { - fn to_spirv(self) -> rspirv::dr::Operand { - let mode = match self { - ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE, - ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ, - ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP, - ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN, - }; - rspirv::dr::Operand::FPRoundingMode(mode) - } +fn rnd_to_spirv(this: ast::RoundingMode) -> rspirv::dr::Operand { + let mode = match this { + ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE, + ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ, + ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP, + ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN, + }; + rspirv::dr::Operand::FPRoundingMode(mode) } fn emit_setp( builder: &mut dr::Builder, map: &mut TypeWordMap, setp: &ast::SetpData, - arg: &ast::Arg4Setp, + arg: &ast::SetpArgs, ) -> Result<(), dr::Error> { let result_type = map.get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred)); let result_id = Some(arg.dst1); @@ -4169,10 +3752,10 @@ fn emit_setp( // https://github.com/intel/intel-graphics-compiler/issues/148 fn logical_not( builder: &mut dr::Builder, - result_type: spirv::Word, - result_id: Option, - operand: spirv::Word, -) -> Result { + result_type: SpirvWord, + result_id: Option, + operand: SpirvWord, +) -> Result { let const_true = builder.constant_true(result_type, None); let const_false = builder.constant_false(result_type, None); builder.select(result_type, result_id, operand, const_false, const_true) @@ -4181,9 +3764,9 @@ fn logical_not( fn emit_mul_sint( builder: &mut dr::Builder, map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MulSInt, - arg: &ast::Arg3, + opencl: SpirvWord, + desc: &ast::MulDetails, + arg: &ast::MulArgs, ) -> Result<(), dr::Error> { let instruction_type = desc.typ; let inst_type = map.get_or_add(builder, SpirvType::from(desc.typ)); @@ -4196,7 +3779,7 @@ fn emit_mul_sint( inst_type, Some(arg.dst), opencl, - spirv::CLOp::s_mul_hi as spirv::Word, + spirv::CLOp::s_mul_hi as SpirvWord, [dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2)] .iter() .cloned(), @@ -4219,9 +3802,9 @@ fn emit_mul_sint( fn emit_mul_uint( builder: &mut dr::Builder, map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MulUInt, - arg: &ast::Arg3, + opencl: SpirvWord, + desc: &ast::MulDetails, + arg: &ast::MulArgs, ) -> Result<(), dr::Error> { let instruction_type = ast::ScalarType::from(desc.typ); let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); @@ -4234,7 +3817,7 @@ fn emit_mul_uint( inst_type, Some(arg.dst), opencl, - spirv::CLOp::u_mul_hi as spirv::Word, + spirv::CLOp::u_mul_hi as SpirvWord, [dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2)] .iter() .cloned(), @@ -4259,10 +3842,10 @@ fn struct2_bitcast_to_wide( builder: &mut dr::Builder, map: &mut TypeWordMap, base_type_key: SpirvScalarKey, - instruction_type: spirv::Word, - dst: spirv::Word, - dst_type_id: spirv::Word, - src: spirv::Word, + instruction_type: SpirvWord, + dst: SpirvWord, + dst_type_id: SpirvWord, + src: SpirvWord, ) -> Result<(), dr::Error> { let low_bits = builder.composite_extract(instruction_type, None, src, [0].iter().copied())?; let high_bits = builder.composite_extract(instruction_type, None, src, [1].iter().copied())?; @@ -4276,9 +3859,9 @@ fn struct2_bitcast_to_wide( fn emit_abs( builder: &mut dr::Builder, map: &mut TypeWordMap, - opencl: spirv::Word, - d: &ast::AbsDetails, - arg: &ast::Arg2, + opencl: SpirvWord, + d: &ast::TypeFtz, + arg: &ast::AbsArgs, ) -> Result<(), dr::Error> { let scalar_t = ast::ScalarType::from(d.typ); let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); @@ -4291,7 +3874,7 @@ fn emit_abs( result_type, Some(arg.dst), opencl, - cl_abs as spirv::Word, + cl_abs as SpirvWord, [dr::Operand::IdRef(arg.src)].iter().cloned(), )?; Ok(()) @@ -4302,7 +3885,7 @@ fn emit_add_int( map: &mut TypeWordMap, typ: ast::ScalarType, saturate: bool, - arg: &ast::Arg3, + arg: &ast::AddArgs, ) -> Result<(), dr::Error> { if saturate { todo!() @@ -4317,7 +3900,7 @@ fn emit_sub_int( map: &mut TypeWordMap, typ: ast::ScalarType, saturate: bool, - arg: &ast::Arg3, + arg: &ast::SubArgs, ) -> Result<(), dr::Error> { if saturate { todo!() @@ -4530,7 +4113,7 @@ fn emit_load_var( fn normalize_identifiers<'input, 'b>( id_defs: &mut FnStringIdResolver<'input, 'b>, fn_defs: &GlobalFnDeclResolver<'input, 'b>, - func: Vec>>, + func: Vec>>, ) -> Result, TranslateError> { for s in func.iter() { match s { @@ -4551,7 +4134,7 @@ fn expand_map_variables<'a, 'b>( id_defs: &mut FnStringIdResolver<'a, 'b>, fn_defs: &GlobalFnDeclResolver<'a, 'b>, result: &mut Vec, - s: ast::Statement>, + s: ast::Statement>, ) -> Result<(), TranslateError> { match s { ast::Statement::Block(block) => { @@ -4636,12 +4219,12 @@ fn expand_map_variables<'a, 'b>( // argument expansion // TODO: propagate out of calls and into calls fn convert_to_stateful_memory_access<'a, 'input>( - func_args: Rc>>, + func_args: Rc>>, func_body: Vec, id_defs: &mut NumericIdResolver<'a>, ) -> Result< ( - Rc>>, + Rc>>, Vec, ), TranslateError, @@ -4668,16 +4251,16 @@ fn convert_to_stateful_memory_access<'a, 'input>( let mut stateful_init_reg = HashMap::<_, Vec<_>>::new(); for statement in func_body.iter() { match statement { - Statement::Instruction(ast::Instruction::Cvta( - ast::CvtaDetails { - to: ast::StateSpace::Global, - size: ast::CvtaSize::U64, - from: ast::StateSpace::Generic, - }, - arg, - )) => { + Statement::Instruction(ast::Instruction::Cvta { + data: + ast::CvtaDetails { + direction: ast::CvtaDirection::GenericToExplicit, + state_space: ast::StateSpace::Global, + }, + arguments, + }) => { if let (TypedOperand::Reg(dst), Some(src)) = - (arg.dst, arg.src.upcast().underlying_register()) + (arguments.dst, arguments.src.upcast().underlying_register()) { if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, dst) { stateful_markers.push((dst, *src)); @@ -4746,8 +4329,8 @@ fn convert_to_stateful_memory_access<'a, 'input>( arg, )) | Statement::Instruction(ast::Instruction::Add( - ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::ScalarType::S64, + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, saturate: false, }), arg, @@ -4757,8 +4340,8 @@ fn convert_to_stateful_memory_access<'a, 'input>( arg, )) | Statement::Instruction(ast::Instruction::Sub( - ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::ScalarType::S64, + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, saturate: false, }), arg, @@ -4832,8 +4415,8 @@ fn convert_to_stateful_memory_access<'a, 'input>( arg, )) | Statement::Instruction(ast::Instruction::Add( - ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::ScalarType::S64, + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, saturate: false, }), arg, @@ -4861,8 +4444,8 @@ fn convert_to_stateful_memory_access<'a, 'input>( arg, )) | Statement::Instruction(ast::Instruction::Sub( - ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::ScalarType::S64, + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, saturate: false, }), arg, @@ -4875,16 +4458,16 @@ fn convert_to_stateful_memory_access<'a, 'input>( ast::Type::Scalar(ast::ScalarType::S64), ast::StateSpace::Reg, ))); - result.push(Statement::Instruction(ast::Instruction::Neg( - ast::NegDetails { - typ: ast::ScalarType::S64, + result.push(Statement::Instruction(ast::Instruction::Neg { + data: ast::TypeFtz { + type_: ast::ScalarType::S64, flush_to_zero: None, }, - ast::Arg2 { + arguments: ast::NegArgs { src: offset, dst: TypedOperand::Reg(offset_neg), }, - ))); + })); let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { underlying_type: ast::Type::Scalar(ast::ScalarType::U8), @@ -4951,12 +4534,12 @@ fn convert_to_stateful_memory_access<'a, 'input>( fn convert_to_stateful_memory_access_postprocess( id_defs: &mut NumericIdResolver, - remapped_ids: &HashMap, + remapped_ids: &HashMap, result: &mut Vec, post_statements: &mut Vec, - arg_desc: ArgumentDescriptor, + arg_desc: ArgumentDescriptor, expected_type: Option<(&ast::Type, ast::StateSpace)>, -) -> Result { +) -> Result { Ok(match remapped_ids.get(&arg_desc.op) { Some(new_id) => { let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?; @@ -5009,7 +4592,7 @@ fn convert_to_stateful_memory_access_postprocess( }) } -fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { +fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::AddArgs) -> bool { match arg.dst { TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { return false @@ -5035,7 +4618,7 @@ fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3, arg: &ast::Arg3) -> bool { +fn is_sub_ptr_direct(remapped_ids: &HashMap, arg: &ast::SubArgs) -> bool { match arg.dst { TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { return false @@ -5062,7 +4645,7 @@ fn is_sub_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3 bool { +fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool { match id_defs.get_typed(id) { Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _)) | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _)) @@ -5138,8 +4721,8 @@ impl PtxSpecialRegister { } struct SpecialRegistersMap { - reg_to_id: HashMap, - id_to_reg: HashMap, + reg_to_id: HashMap, + id_to_reg: HashMap, } impl SpecialRegistersMap { @@ -5150,11 +4733,11 @@ impl SpecialRegistersMap { } } - fn get(&self, id: spirv::Word) -> Option { + fn get(&self, id: SpirvWord) -> Option { self.id_to_reg.get(&id).copied() } - fn get_or_add(&mut self, current_id: &mut spirv::Word, reg: PtxSpecialRegister) -> spirv::Word { + fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord { match self.reg_to_id.entry(reg) { hash_map::Entry::Occupied(e) => *e.get(), hash_map::Entry::Vacant(e) => { @@ -5172,11 +4755,11 @@ struct FnSigMapper<'input> { // true - stays as return argument // false - is moved to input argument return_param_args: Vec, - func_decl: Rc>>, + func_decl: Rc>>, } impl<'input> FnSigMapper<'input> { - fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, spirv::Word>) -> Self { + fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, SpirvWord>) -> Self { let return_param_args = method .return_arguments .iter() @@ -5196,69 +4779,19 @@ impl<'input> FnSigMapper<'input> { func_decl: Rc::new(RefCell::new(method)), } } - - fn resolve_in_spirv_repr( - &self, - call_inst: ast::CallInst, - ) -> Result, TranslateError> { - let func_decl = (*self.func_decl).borrow(); - let mut return_arguments = Vec::new(); - let mut input_arguments = call_inst - .param_list - .into_iter() - .zip(func_decl.input_arguments.iter()) - .map(|(id, var)| (id, var.v_type.clone(), var.state_space)) - .collect::>(); - let mut func_decl_return_iter = func_decl.return_arguments.iter(); - let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter(); - for (idx, id) in call_inst.ret_params.iter().enumerate() { - let stays_as_return = match self.return_param_args.get(idx) { - Some(x) => *x, - None => return Err(TranslateError::MismatchedType), - }; - if stays_as_return { - if let Some(var) = func_decl_return_iter.next() { - return_arguments.push((*id, var.v_type.clone(), var.state_space)); - } else { - return Err(TranslateError::MismatchedType); - } - } else { - if let Some(var) = func_decl_input_iter.next() { - input_arguments.push(( - ast::Operand::Reg(*id), - var.v_type.clone(), - var.state_space, - )); - } else { - return Err(TranslateError::MismatchedType); - } - } - } - if return_arguments.len() != func_decl.return_arguments.len() - || input_arguments.len() != func_decl.input_arguments.len() - { - return Err(TranslateError::MismatchedType); - } - Ok(ResolvedCall { - return_arguments, - input_arguments, - uniform: call_inst.uniform, - name: call_inst.func, - }) - } } struct GlobalStringIdResolver<'input> { - current_id: spirv::Word, - variables: HashMap, spirv::Word>, - reverse_variables: HashMap, + current_id: SpirvWord, + variables: HashMap, SpirvWord>, + reverse_variables: HashMap, variables_type_check: HashMap>, special_registers: SpecialRegistersMap, - fns: HashMap>, + fns: HashMap>, } impl<'input> GlobalStringIdResolver<'input> { - fn new(start_id: spirv::Word) -> Self { + fn new(start_id: SpirvWord) -> Self { Self { current_id: start_id, variables: HashMap::new(), @@ -5269,7 +4802,7 @@ impl<'input> GlobalStringIdResolver<'input> { } } - fn get_or_add_def(&mut self, id: &'input str) -> spirv::Word { + fn get_or_add_def(&mut self, id: &'input str) -> SpirvWord { self.get_or_add_impl(id, None) } @@ -5279,7 +4812,7 @@ impl<'input> GlobalStringIdResolver<'input> { typ: ast::Type, state_space: ast::StateSpace, is_variable: bool, - ) -> spirv::Word { + ) -> SpirvWord { self.get_or_add_impl(id, Some((typ, state_space, is_variable))) } @@ -5287,7 +4820,7 @@ impl<'input> GlobalStringIdResolver<'input> { &mut self, id: &'input str, typ: Option<(ast::Type, ast::StateSpace, bool)>, - ) -> spirv::Word { + ) -> SpirvWord { let id = match self.variables.entry(Cow::Borrowed(id)) { hash_map::Entry::Occupied(e) => *(e.get()), hash_map::Entry::Vacant(e) => { @@ -5302,14 +4835,14 @@ impl<'input> GlobalStringIdResolver<'input> { id } - fn get_id(&self, id: &str) -> Result { + fn get_id(&self, id: &str) -> Result { self.variables .get(id) .copied() .ok_or_else(error_unknown_symbol) } - fn current_id(&self) -> spirv::Word { + fn current_id(&self) -> SpirvWord { self.current_id } @@ -5320,7 +4853,7 @@ impl<'input> GlobalStringIdResolver<'input> { ( FnStringIdResolver<'input, 'b>, GlobalFnDeclResolver<'input, 'b>, - Rc>>, + Rc>>, ), TranslateError, > { @@ -5363,21 +4896,21 @@ impl<'input> GlobalStringIdResolver<'input> { } pub struct GlobalFnDeclResolver<'input, 'a> { - fns: &'a HashMap>, + fns: &'a HashMap>, } impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { - fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> { + fn get_fn_sig_resolver(&self, id: SpirvWord) -> Result<&FnSigMapper<'input>, TranslateError> { self.fns.get(&id).ok_or_else(error_unknown_symbol) } } struct FnStringIdResolver<'input, 'b> { - current_id: &'b mut spirv::Word, - global_variables: &'b HashMap, spirv::Word>, + current_id: &'b mut SpirvWord, + global_variables: &'b HashMap, SpirvWord>, global_type_check: &'b HashMap>, special_registers: &'b mut SpecialRegistersMap, - variables: Vec, spirv::Word>>, + variables: Vec, SpirvWord>>, type_check: HashMap>, } @@ -5399,7 +4932,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { self.variables.pop(); } - fn get_id(&mut self, id: &str) -> Result { + fn get_id(&mut self, id: &str) -> Result { for scope in self.variables.iter().rev() { match scope.get(id) { Some(id) => return Ok(*id), @@ -5420,7 +4953,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { id: &'a str, typ: Option<(ast::Type, ast::StateSpace)>, is_variable: bool, - ) -> spirv::Word { + ) -> SpirvWord { let numeric_id = *self.current_id; self.variables .last_mut() @@ -5442,25 +4975,27 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { typ: ast::Type, state_space: ast::StateSpace, is_variable: bool, - ) -> impl Iterator { + ) -> impl Iterator { let numeric_id = *self.current_id; for i in 0..count { - self.variables - .last_mut() - .unwrap() - .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); + self.variables.last_mut().unwrap().insert( + Cow::Owned(format!("{}{}", base_id, i)), + SpirvWord(numeric_id.0 + i), + ); self.type_check.insert( - numeric_id + i, + numeric_id.0 + i, Some((typ.clone(), state_space, is_variable)), ); } - *self.current_id += count; - (0..count).into_iter().map(move |i| i + numeric_id) + self.current_id.0 += count; + (0..count) + .into_iter() + .map(move |i| SpirvWord(i + numeric_id.0)) } } struct NumericIdResolver<'b> { - current_id: &'b mut spirv::Word, + current_id: &'b mut SpirvWord, global_type_check: &'b HashMap>, type_check: HashMap>, special_registers: &'b mut SpecialRegistersMap, @@ -5473,7 +5008,7 @@ impl<'b> NumericIdResolver<'b> { fn get_typed( &self, - id: spirv::Word, + id: SpirvWord, ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> { match self.type_check.get(&id) { Some(Some(x)) => Ok(x.clone()), @@ -5490,7 +5025,7 @@ impl<'b> NumericIdResolver<'b> { // This is for identifiers which will be emitted later as OpVariable // They are candidates for insertion of LoadVar/StoreVar - fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> spirv::Word { + fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord { let new_id = *self.current_id; self.type_check .insert(new_id, Some((typ, state_space, true))); @@ -5498,7 +5033,7 @@ impl<'b> NumericIdResolver<'b> { new_id } - fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> spirv::Word { + fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord { let new_id = *self.current_id; self.type_check .insert(new_id, typ.map(|(t, space)| (t, space, false))); @@ -5516,76 +5051,38 @@ impl<'b> MutableNumericIdResolver<'b> { self.base } - fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, ast::StateSpace), TranslateError> { + fn get_typed(&self, id: SpirvWord) -> Result<(ast::Type, ast::StateSpace), TranslateError> { self.base.get_typed(id).map(|(t, space, _)| (t, space)) } - fn register_intermediate( - &mut self, - typ: ast::Type, - state_space: ast::StateSpace, - ) -> spirv::Word { + fn register_intermediate(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord { self.base.register_intermediate(Some((typ, state_space))) } } struct FunctionPointerDetails { - dst: spirv::Word, - src: spirv::Word, + dst: SpirvWord, + src: SpirvWord, } -impl, U: ArgParamsEx> Visitable - for FunctionPointerDetails -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, U>, TranslateError> { - Ok(Statement::FunctionPointer(FunctionPointerDetails { - dst: visitor.id( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Scalar(ast::ScalarType::U64), - ast::StateSpace::Reg, - )), - )?, - src: visitor.id( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - None, - )?, - })) - } -} - -enum Statement { +enum Statement { Label(u32), - Variable(ast::Variable), + Variable(ast::Variable), Instruction(I), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), - Call(ResolvedCall

), LoadVar(LoadVarDetails), StoreVar(StoreVarDetails), Conversion(ImplicitConversion), Constant(ConstantDefinition), - RetValue(ast::RetData, spirv::Word), + RetValue(ast::RetData, SpirvWord), PtrAccess(PtrAccess

), RepackVector(RepackVectorDetails), FunctionPointer(FunctionPointerDetails), } impl ExpandedStatement { - fn map_id(self, f: &mut impl FnMut(spirv::Word, bool) -> spirv::Word) -> ExpandedStatement { + fn map_id(self, f: &mut impl FnMut(SpirvWord, bool) -> SpirvWord) -> ExpandedStatement { match self { Statement::Label(id) => Statement::Label(f(id, false)), Statement::Variable(mut var) => { @@ -5685,7 +5182,7 @@ impl ExpandedStatement { } struct LoadVarDetails { - arg: ast::Arg2, + arg: ast::MovArgs, typ: ast::Type, state_space: ast::StateSpace, // (index, vector_width) @@ -5697,7 +5194,7 @@ struct LoadVarDetails { } struct StoreVarDetails { - arg: ast::Arg2St, + arg: ast::StArgs, typ: ast::Type, member_index: Option, } @@ -5705,8 +5202,8 @@ struct StoreVarDetails { struct RepackVectorDetails { is_extract: bool, typ: ast::ScalarType, - packed: spirv::Word, - unpacked: Vec, + packed: SpirvWord, + unpacked: Vec, non_default_implicit_conversion: Option< fn( (ast::StateSpace, &ast::Type), @@ -5715,247 +5212,43 @@ struct RepackVectorDetails { >, } -impl RepackVectorDetails { - fn map< - From: ArgParamsEx, - To: ArgParamsEx, - V: ArgumentMapVisitor, - >( - self, - visitor: &mut V, - ) -> Result { - let scalar = visitor.id( - ArgumentDescriptor { - op: self.packed, - is_dst: !self.is_extract, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Vector(self.typ, self.unpacked.len() as u8), - ast::StateSpace::Reg, - )), - )?; - let scalar_type = self.typ; - let is_extract = self.is_extract; - let non_default_implicit_conversion = self.non_default_implicit_conversion; - let vector = self - .unpacked - .into_iter() - .map(|id| { - visitor.id( - ArgumentDescriptor { - op: id, - is_dst: is_extract, - is_memory_access: false, - non_default_implicit_conversion, - }, - Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)), - ) - }) - .collect::>()?; - Ok(RepackVectorDetails { - is_extract, - typ: self.typ, - packed: scalar, - unpacked: vector, - non_default_implicit_conversion, - }) +impl ast::Operand for SpirvWord { + type Ident = SpirvWord; +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +struct SpirvWord(spirv::Word); + +impl From for SpirvWord { + fn from(value: spirv::Word) -> Self { + Self(value) + } +} +impl From for spirv::Word { + fn from(value: SpirvWord) -> Self { + value.0 } } -impl, U: ArgParamsEx> Visitable - for RepackVectorDetails -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, U>, TranslateError> { - Ok(Statement::RepackVector(self.map::<_, _, _>(visitor)?)) - } -} +type NormalizedStatement = + Statement<(Option>, ast::Instruction), SpirvWord>; -struct ResolvedCall { - pub uniform: bool, - pub return_arguments: Vec<(P::Id, ast::Type, ast::StateSpace)>, - pub name: P::Id, - pub input_arguments: Vec<(P::Operand, ast::Type, ast::StateSpace)>, -} - -impl ResolvedCall { - fn cast>(self) -> ResolvedCall { - ResolvedCall { - uniform: self.uniform, - return_arguments: self.return_arguments, - name: self.name, - input_arguments: self.input_arguments, - } - } -} - -impl> ResolvedCall { - fn map, V: ArgumentMapVisitor>( - self, - visitor: &mut V, - ) -> Result, TranslateError> { - let return_arguments = self - .return_arguments - .into_iter() - .map::, _>(|(id, typ, space)| { - let new_id = visitor.id( - ArgumentDescriptor { - op: id, - is_dst: space != ast::StateSpace::Param, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some((&typ, space)), - )?; - Ok((new_id, typ, space)) - }) - .collect::, _>>()?; - let func = visitor.id( - ArgumentDescriptor { - op: self.name, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - None, - )?; - let input_arguments = self - .input_arguments - .into_iter() - .map::, _>(|(id, typ, space)| { - let new_id = visitor.operand( - ArgumentDescriptor { - op: id, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &typ, - space, - )?; - Ok((new_id, typ, space)) - }) - .collect::, _>>()?; - Ok(ResolvedCall { - uniform: self.uniform, - return_arguments, - name: func, - input_arguments, - }) - } -} - -impl, U: ArgParamsEx> Visitable - for ResolvedCall -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, U>, TranslateError> { - Ok(Statement::Call(self.map(visitor)?)) - } -} - -impl> PtrAccess

{ - fn map, V: ArgumentMapVisitor>( - self, - visitor: &mut V, - ) -> Result, TranslateError> { - let new_dst = visitor.id( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some((&self.underlying_type, self.state_space)), - )?; - let new_ptr_src = visitor.id( - ArgumentDescriptor { - op: self.ptr_src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some((&self.underlying_type, self.state_space)), - )?; - let new_constant_src = visitor.operand( - ArgumentDescriptor { - op: self.offset_src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::S64), - ast::StateSpace::Reg, - )?; - Ok(PtrAccess { - underlying_type: self.underlying_type, - state_space: self.state_space, - dst: new_dst, - ptr_src: new_ptr_src, - offset_src: new_constant_src, - }) - } -} - -impl, U: ArgParamsEx> Visitable - for PtrAccess -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, U>, TranslateError> { - Ok(Statement::PtrAccess(self.map(visitor)?)) - } -} - -pub trait ArgParamsEx: ast::ArgParams + Sized {} - -impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {} - -enum NormalizedArgParams {} - -impl ast::ArgParams for NormalizedArgParams { - type Id = spirv::Word; - type Operand = ast::Operand; -} - -impl ArgParamsEx for NormalizedArgParams {} - -type NormalizedStatement = Statement< - ( - Option>, - ast::Instruction, - ), - NormalizedArgParams, ->; - -type UnconditionalStatement = Statement, NormalizedArgParams>; - -enum TypedArgParams {} - -impl ast::ArgParams for TypedArgParams { - type Id = spirv::Word; - type Operand = TypedOperand; -} - -impl ArgParamsEx for TypedArgParams {} +type UnconditionalStatement = Statement, SpirvWord>; #[derive(Copy, Clone)] enum TypedOperand { - Reg(spirv::Word), - RegOffset(spirv::Word, i32), + Reg(SpirvWord), + RegOffset(SpirvWord, i32), Imm(ast::ImmediateValue), - VecMember(spirv::Word, u8), + VecMember(SpirvWord, u8), +} + +impl ast::Operand for TypedOperand { + type Ident = SpirvWord; } impl TypedOperand { - fn upcast(self) -> ast::Operand { + fn upcast(self) -> ast::Operand { match self { TypedOperand::Reg(reg) => ast::Operand::Reg(reg), TypedOperand::RegOffset(reg, idx) => ast::Operand::RegOffset(reg, idx), @@ -5965,103 +5258,23 @@ impl TypedOperand { } } -type TypedStatement = Statement, TypedArgParams>; - -enum ExpandedArgParams {} -type ExpandedStatement = Statement, ExpandedArgParams>; - -impl ast::ArgParams for ExpandedArgParams { - type Id = spirv::Word; - type Operand = spirv::Word; -} - -impl ArgParamsEx for ExpandedArgParams {} +type TypedStatement = Statement, TypedOperand>; +type ExpandedStatement = Statement, SpirvWord>; enum Directive<'input> { - Variable(ast::LinkingDirective, ast::Variable), + Variable(ast::LinkingDirective, ast::Variable), Method(Function<'input>), } struct Function<'input> { - pub func_decl: Rc>>, - pub globals: Vec>, + pub func_decl: Rc>>, + pub globals: Vec>, pub body: Option>, import_as: Option, tuning: Vec, linkage: ast::LinkingDirective, } -pub trait ArgumentMapVisitor { - fn id( - &mut self, - desc: ArgumentDescriptor, - typ: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result; - fn operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result; -} - -impl ArgumentMapVisitor for T -where - T: FnMut( - ArgumentDescriptor, - Option<(&ast::Type, ast::StateSpace)>, - ) -> Result, -{ - fn id( - &mut self, - desc: ArgumentDescriptor, - t: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - self(desc, t) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - self(desc, Some((typ, state_space))) - } -} - -impl<'a, T> ArgumentMapVisitor, NormalizedArgParams> for T -where - T: FnMut(&str) -> Result, -{ - fn id( - &mut self, - desc: ArgumentDescriptor<&str>, - _: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - self(desc.op) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result, TranslateError> { - Ok(match desc.op { - ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?), - ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id)?, imm), - ast::Operand::Imm(imm) => ast::Operand::Imm(imm), - ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member), - ast::Operand::VecPack(ref ids) => ast::Operand::VecPack( - ids.into_iter() - .map(|id| self.id(desc.new_op(id), Some((typ, state_space)))) - .collect::, _>>()?, - ), - }) - } -} - pub struct ArgumentDescriptor { op: Op, is_dst: bool, @@ -6074,12 +5287,12 @@ pub struct ArgumentDescriptor { >, } -pub struct PtrAccess { +pub struct PtrAccess { underlying_type: ast::Type, state_space: ast::StateSpace, - dst: spirv::Word, - ptr_src: spirv::Word, - offset_src: P::Operand, + dst: SpirvWord, + ptr_src: SpirvWord, + offset_src: P::Ident, } impl ArgumentDescriptor { @@ -6093,410 +5306,86 @@ impl ArgumentDescriptor { } } -impl ast::Instruction { - fn map>( - self, - visitor: &mut V, - ) -> Result, TranslateError> { - Ok(match self { - ast::Instruction::Abs(d, arg) => { - ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?) +fn type_widen(this: ast::Type) -> Result { + match this { + ast::Type::Scalar(scalar) => { + let kind = scalar.kind(); + let width = scalar.size_of(); + if (kind != ast::ScalarKind::Signed + && kind != ast::ScalarKind::Unsigned + && kind != ast::ScalarKind::Bit) + || (width == 8) + { + return Err(TranslateError::MismatchedType); } - // Call instruction is converted to a call statement early on - ast::Instruction::Call(_) => return Err(error_unreachable()), - ast::Instruction::Ld(d, a) => { - let new_args = a.map(visitor, &d)?; - ast::Instruction::Ld(d, new_args) - } - ast::Instruction::Mov(d, a) => { - let mapped = a.map(visitor, &d)?; - ast::Instruction::Mov(d, mapped) - } - ast::Instruction::Mul(d, a) => { - let inst_type = d.get_type(); - let is_wide = d.is_wide(); - ast::Instruction::Mul(d, a.map_non_shift(visitor, &inst_type, is_wide)?) - } - ast::Instruction::Add(d, a) => { - let inst_type = d.get_type(); - ast::Instruction::Add(d, a.map_non_shift(visitor, &inst_type, false)?) - } - ast::Instruction::Setp(d, a) => { - let inst_type = d.typ; - ast::Instruction::Setp(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) - } - ast::Instruction::SetpBool(d, a) => { - let inst_type = d.typ; - ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) - } - ast::Instruction::Not(t, a) => { - ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?) - } - ast::Instruction::Cvt(d, a) => { - let (dst_t, src_t, int_to_int) = match &d { - ast::CvtDetails::FloatFromFloat(desc) => ((desc.dst, desc.src, false)), - ast::CvtDetails::FloatFromInt(desc) => ((desc.dst, desc.src, false)), - ast::CvtDetails::IntFromFloat(desc) => ((desc.dst, desc.src, false)), - ast::CvtDetails::IntFromInt(desc) => ((desc.dst, desc.src, true)), - }; - ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t, int_to_int)?) - } - ast::Instruction::Shl(t, a) => { - ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?) - } - ast::Instruction::Shr(t, a) => { - ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?) - } - ast::Instruction::St(d, a) => { - let new_args = a.map(visitor, &d)?; - ast::Instruction::St(d, new_args) - } - ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, false, None)?), - ast::Instruction::Ret(d) => ast::Instruction::Ret(d), - ast::Instruction::Cvta(d, a) => { - let inst_type = ast::Type::Scalar(ast::ScalarType::B64); - ast::Instruction::Cvta(d, a.map(visitor, &inst_type)?) - } - ast::Instruction::Mad(d, a) => { - let inst_type = d.get_type(); - let is_wide = d.is_wide(); - ast::Instruction::Mad(d, a.map(visitor, &inst_type, is_wide)?) - } - ast::Instruction::Fma(d, a) => { - let inst_type = ast::Type::Scalar(d.typ); - ast::Instruction::Fma(d, a.map(visitor, &inst_type, false)?) - } - ast::Instruction::Or(t, a) => ast::Instruction::Or( - t, - a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?, - ), - ast::Instruction::Sub(d, a) => { - let typ = d.get_type(); - ast::Instruction::Sub(d, a.map_non_shift(visitor, &typ, false)?) - } - ast::Instruction::Min(d, a) => { - let typ = d.get_type(); - ast::Instruction::Min(d, a.map_non_shift(visitor, &typ, false)?) - } - ast::Instruction::Max(d, a) => { - let typ = d.get_type(); - ast::Instruction::Max(d, a.map_non_shift(visitor, &typ, false)?) - } - ast::Instruction::Rcp(d, a) => { - let typ = ast::Type::Scalar(if d.is_f64 { - ast::ScalarType::F64 - } else { - ast::ScalarType::F32 - }); - ast::Instruction::Rcp(d, a.map(visitor, &typ)?) - } - ast::Instruction::And(t, a) => ast::Instruction::And( - t, - a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?, - ), - ast::Instruction::Selp(t, a) => ast::Instruction::Selp(t, a.map_selp(visitor, t)?), - ast::Instruction::Bar(d, a) => ast::Instruction::Bar(d, a.map(visitor)?), - ast::Instruction::Atom(d, a) => { - ast::Instruction::Atom(d, a.map_atom(visitor, d.inner.get_type(), d.space)?) - } - ast::Instruction::AtomCas(d, a) => { - ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?) - } - ast::Instruction::Div(d, a) => { - ast::Instruction::Div(d, a.map_non_shift(visitor, &d.get_type(), false)?) - } - ast::Instruction::Sqrt(d, a) => { - ast::Instruction::Sqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?) - } - ast::Instruction::Rsqrt(d, a) => { - ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?) - } - ast::Instruction::Neg(d, a) => { - ast::Instruction::Neg(d, a.map(visitor, &ast::Type::Scalar(d.typ))?) - } - ast::Instruction::Sin { flush_to_zero, arg } => { - let typ = ast::Type::Scalar(ast::ScalarType::F32); - ast::Instruction::Sin { - flush_to_zero, - arg: arg.map(visitor, &typ)?, - } - } - ast::Instruction::Cos { flush_to_zero, arg } => { - let typ = ast::Type::Scalar(ast::ScalarType::F32); - ast::Instruction::Cos { - flush_to_zero, - arg: arg.map(visitor, &typ)?, - } - } - ast::Instruction::Lg2 { flush_to_zero, arg } => { - let typ = ast::Type::Scalar(ast::ScalarType::F32); - ast::Instruction::Lg2 { - flush_to_zero, - arg: arg.map(visitor, &typ)?, - } - } - ast::Instruction::Ex2 { flush_to_zero, arg } => { - let typ = ast::Type::Scalar(ast::ScalarType::F32); - ast::Instruction::Ex2 { - flush_to_zero, - arg: arg.map(visitor, &typ)?, - } - } - ast::Instruction::Clz { typ, arg } => { - let dst_type = ast::Type::Scalar(ast::ScalarType::B32); - let src_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Clz { - typ, - arg: arg.map_different_types(visitor, &dst_type, &src_type)?, - } - } - ast::Instruction::Brev { typ, arg } => { - let full_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Brev { - typ, - arg: arg.map(visitor, &full_type)?, - } - } - ast::Instruction::Popc { typ, arg } => { - let dst_type = ast::Type::Scalar(ast::ScalarType::B32); - let src_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Popc { - typ, - arg: arg.map_different_types(visitor, &dst_type, &src_type)?, - } - } - ast::Instruction::Xor { typ, arg } => { - let full_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Xor { - typ, - arg: arg.map_non_shift(visitor, &full_type, false)?, - } - } - ast::Instruction::Bfe { typ, arg } => { - let full_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Bfe { - typ, - arg: arg.map_bfe(visitor, &full_type)?, - } - } - ast::Instruction::Bfi { typ, arg } => { - let full_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Bfi { - typ, - arg: arg.map_bfi(visitor, &full_type)?, - } - } - ast::Instruction::Rem { typ, arg } => { - let full_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Rem { - typ, - arg: arg.map_non_shift(visitor, &full_type, false)?, - } - } - ast::Instruction::Prmt { control, arg } => ast::Instruction::Prmt { - control, - arg: arg.map_prmt(visitor)?, - }, - ast::Instruction::Activemask { arg } => ast::Instruction::Activemask { - arg: arg.map( - visitor, - true, - Some(( - &ast::Type::Scalar(ast::ScalarType::B32), - ast::StateSpace::Reg, - )), - )?, - }, - ast::Instruction::Membar { level } => ast::Instruction::Membar { level }, - }) - } -} - -impl Visitable for ast::Instruction { - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, U>, TranslateError> { - Ok(Statement::Instruction(self.map(visitor)?)) - } -} - -impl ImplicitConversion { - fn map< - T: ArgParamsEx, - U: ArgParamsEx, - V: ArgumentMapVisitor, - >( - self, - visitor: &mut V, - ) -> Result, U>, TranslateError> { - let new_dst = visitor.id( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some((&self.to_type, self.to_space)), - )?; - let new_src = visitor.id( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some((&self.from_type, self.from_space)), - )?; - Ok(Statement::Conversion({ - ImplicitConversion { - src: new_src, - dst: new_dst, - ..self - } - })) - } -} - -impl, To: ArgParamsEx> Visitable - for ImplicitConversion -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, To>, TranslateError> { - Ok(self.map(visitor)?) - } -} - -impl ArgumentMapVisitor for T -where - T: FnMut( - ArgumentDescriptor, - Option<(&ast::Type, ast::StateSpace)>, - ) -> Result, -{ - fn id( - &mut self, - desc: ArgumentDescriptor, - t: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - self(desc, t) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - Ok(match desc.op { - TypedOperand::Reg(id) => { - TypedOperand::Reg(self(desc.new_op(id), Some((typ, state_space)))?) - } - TypedOperand::Imm(imm) => TypedOperand::Imm(imm), - TypedOperand::RegOffset(id, imm) => { - TypedOperand::RegOffset(self(desc.new_op(id), Some((typ, state_space)))?, imm) - } - TypedOperand::VecMember(reg, index) => { - let scalar_type = match typ { - ast::Type::Scalar(scalar_t) => *scalar_t, - _ => return Err(error_unreachable()), - }; - let vec_type = ast::Type::Vector(scalar_type, index + 1); - TypedOperand::VecMember( - self(desc.new_op(reg), Some((&vec_type, state_space)))?, - index, - ) - } - }) - } -} - -impl ast::Type { - fn widen(self) -> Result { - match self { - ast::Type::Scalar(scalar) => { - let kind = scalar.kind(); - let width = scalar.size_of(); - if (kind != ast::ScalarKind::Signed - && kind != ast::ScalarKind::Unsigned - && kind != ast::ScalarKind::Bit) - || (width == 8) - { - return Err(TranslateError::MismatchedType); - } - Ok(ast::Type::Scalar(ast::ScalarType::from_parts( - width * 2, - kind, - ))) - } - _ => Err(error_unreachable()), + Ok(ast::Type::Scalar(ast::ScalarType::from_parts( + width * 2, + kind, + ))) } + _ => Err(error_unreachable()), } +} - fn to_parts(&self) -> TypeParts { - match self { - ast::Type::Scalar(scalar) => TypeParts { - kind: TypeKind::Scalar, - state_space: ast::StateSpace::Reg, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: Vec::new(), - }, - ast::Type::Vector(scalar, components) => TypeParts { - kind: TypeKind::Vector, - state_space: ast::StateSpace::Reg, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: vec![*components as u32], - }, - ast::Type::Array(scalar, components) => TypeParts { - kind: TypeKind::Array, - state_space: ast::StateSpace::Reg, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: components.clone(), - }, - ast::Type::Pointer(scalar, space) => TypeParts { - kind: TypeKind::Pointer, - state_space: *space, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: Vec::new(), - }, - } +fn to_parts(this: &ast::Type) -> TypeParts { + match this { + ast::Type::Scalar(scalar) => TypeParts { + kind: TypeKind::Scalar, + state_space: ast::StateSpace::Reg, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: Vec::new(), + }, + ast::Type::Vector(scalar, components) => TypeParts { + kind: TypeKind::Vector, + state_space: ast::StateSpace::Reg, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: vec![*components as u32], + }, + ast::Type::Array(scalar, components) => TypeParts { + kind: TypeKind::Array, + state_space: ast::StateSpace::Reg, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: components.clone(), + }, + ast::Type::Pointer(scalar, space) => TypeParts { + kind: TypeKind::Pointer, + state_space: *space, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: Vec::new(), + }, } +} - fn from_parts(t: TypeParts) -> Self { - match t.kind { - TypeKind::Scalar => { - ast::Type::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)) - } - TypeKind::Vector => ast::Type::Vector( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.components[0] as u8, - ), - TypeKind::Array => ast::Type::Array( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.components, - ), - TypeKind::Pointer => ast::Type::Pointer( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.state_space, - ), - } +fn from_parts(t: TypeParts) -> ast::Type { + match t.kind { + TypeKind::Scalar => ast::Type::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)), + TypeKind::Vector => ast::Type::Vector( + ast::ScalarType::from_parts(t.width, t.scalar_kind), + t.components[0] as u8, + ), + TypeKind::Array => ast::Type::Array( + ast::ScalarType::from_parts(t.width, t.scalar_kind), + t.components, + ), + TypeKind::Pointer => ast::Type::Pointer( + ast::ScalarType::from_parts(t.width, t.scalar_kind), + t.state_space, + ), } +} - pub fn size_of(&self) -> usize { - match self { - ast::Type::Scalar(typ) => typ.size_of() as usize, - ast::Type::Vector(typ, len) => (typ.size_of() as usize) * (*len as usize), - ast::Type::Array(typ, len) => len - .iter() - .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), - ast::Type::Pointer(..) => mem::size_of::(), - } +pub fn size_of(this: &ast::Type) -> usize { + match this { + ast::Type::Scalar(typ) => typ.size_of() as usize, + ast::Type::Vector(typ, len) => (typ.size_of() as usize) * (*len as usize), + ast::Type::Array(typ, len) => len + .iter() + .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), + ast::Type::Pointer(..) => mem::size_of::(), } } @@ -6517,164 +5406,140 @@ enum TypeKind { Pointer, } -impl ast::Instruction { - fn jump_target(&self) -> Option { - match self { - ast::Instruction::Bra(_, a) => Some(a.src), - _ => None, - } - } - - // .wide instructions don't support ftz, so it's enough to just look at the - // type declared by the instruction - fn flush_to_zero(&self) -> Option<(bool, u8)> { - match self { - ast::Instruction::Ld(_, _) => None, - ast::Instruction::St(_, _) => None, - ast::Instruction::Mov(_, _) => None, - ast::Instruction::Not(_, _) => None, - ast::Instruction::Bra(_, _) => None, - ast::Instruction::Shl(_, _) => None, - ast::Instruction::Shr(_, _) => None, - ast::Instruction::Ret(_) => None, - ast::Instruction::Call(_) => None, - ast::Instruction::Or(_, _) => None, - ast::Instruction::And(_, _) => None, - ast::Instruction::Cvta(_, _) => None, - ast::Instruction::Selp(_, _) => None, - ast::Instruction::Bar(_, _) => None, - ast::Instruction::Atom(_, _) => None, - ast::Instruction::AtomCas(_, _) => None, - ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None, - ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None, - ast::Instruction::Add(ast::ArithDetails::Signed(_), _) => None, - ast::Instruction::Add(ast::ArithDetails::Unsigned(_), _) => None, - ast::Instruction::Mul(ast::MulDetails::Unsigned(_), _) => None, - ast::Instruction::Mul(ast::MulDetails::Signed(_), _) => None, - ast::Instruction::Mad(ast::MulDetails::Unsigned(_), _) => None, - ast::Instruction::Mad(ast::MulDetails::Signed(_), _) => None, - ast::Instruction::Min(ast::MinMaxDetails::Signed(_), _) => None, - ast::Instruction::Min(ast::MinMaxDetails::Unsigned(_), _) => None, - ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _) => None, - ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _) => None, - ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _) => None, - ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None, - ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None, - ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None, - ast::Instruction::Clz { .. } => None, - ast::Instruction::Brev { .. } => None, - ast::Instruction::Popc { .. } => None, - ast::Instruction::Xor { .. } => None, - ast::Instruction::Bfe { .. } => None, - ast::Instruction::Bfi { .. } => None, - ast::Instruction::Rem { .. } => None, - ast::Instruction::Prmt { .. } => None, - ast::Instruction::Activemask { .. } => None, - ast::Instruction::Membar { .. } => None, - ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _) - | ast::Instruction::Add(ast::ArithDetails::Float(float_control), _) - | ast::Instruction::Mul(ast::MulDetails::Float(float_control), _) - | ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => float_control - .flush_to_zero - .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())), - ast::Instruction::Fma(d, _) => d.flush_to_zero.map(|ftz| (ftz, d.typ.size_of())), - ast::Instruction::Setp(details, _) => details - .flush_to_zero - .map(|ftz| (ftz, details.typ.size_of())), - ast::Instruction::SetpBool(details, _) => details - .flush_to_zero - .map(|ftz| (ftz, details.typ.size_of())), - ast::Instruction::Abs(details, _) => details - .flush_to_zero - .map(|ftz| (ftz, details.typ.size_of())), - ast::Instruction::Min(ast::MinMaxDetails::Float(float_control), _) - | ast::Instruction::Max(ast::MinMaxDetails::Float(float_control), _) => float_control - .flush_to_zero - .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())), - ast::Instruction::Rcp(details, _) => details - .flush_to_zero - .map(|ftz| (ftz, if details.is_f64 { 8 } else { 4 })), - // Modifier .ftz can only be specified when either .dtype or .atype - // is .f32 and applies only to single precision (.f32) inputs and results. - ast::Instruction::Cvt( - ast::CvtDetails::FloatFromFloat(ast::CvtDesc { flush_to_zero, .. }), - _, - ) - | ast::Instruction::Cvt( - ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }), - _, - ) => flush_to_zero.map(|ftz| (ftz, 4)), - ast::Instruction::Div(ast::DivDetails::Float(details), _) => details - .flush_to_zero - .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())), - ast::Instruction::Sqrt(details, _) => details - .flush_to_zero - .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())), - ast::Instruction::Rsqrt(details, _) => Some(( - details.flush_to_zero, - ast::ScalarType::from(details.typ).size_of(), - )), - ast::Instruction::Neg(details, _) => details - .flush_to_zero - .map(|ftz| (ftz, details.typ.size_of())), - ast::Instruction::Sin { flush_to_zero, .. } - | ast::Instruction::Cos { flush_to_zero, .. } - | ast::Instruction::Lg2 { flush_to_zero, .. } - | ast::Instruction::Ex2 { flush_to_zero, .. } => { - Some((*flush_to_zero, mem::size_of::() as u8)) - } - } +fn inst_jump_target(this: &ast::Instruction) -> Option { + match this { + ast::Instruction::Bra(_, a) => Some(a.src), + _ => None, } } -type Arg2 = ast::Arg2; -type Arg2St = ast::Arg2St; +// .wide instructions don't support ftz, so it's enough to just look at the +// type declared by the instruction +fn inst_flush_to_zero(this: &ast::Instruction) -> Option<(bool, u8)> { + match this { + ast::Instruction::Ld { .. } => None, + ast::Instruction::St { .. } => None, + ast::Instruction::Mov { .. } => None, + ast::Instruction::Not { .. } => None, + ast::Instruction::Bra { .. } => None, + ast::Instruction::Shl { .. } => None, + ast::Instruction::Shr { .. } => None, + ast::Instruction::Ret { .. } => None, + ast::Instruction::Call { .. } => None, + ast::Instruction::Or { .. } => None, + ast::Instruction::And { .. } => None, + ast::Instruction::Cvta { .. } => None, + ast::Instruction::Selp { .. } => None, + ast::Instruction::Bar { .. } => None, + ast::Instruction::Atom { .. } => None, + ast::Instruction::AtomCas { .. } => None, + ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None, + ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None, + ast::Instruction::Add(ast::ArithDetails::Signed(_), _) => None, + ast::Instruction::Add(ast::ArithDetails::Unsigned(_), _) => None, + ast::Instruction::Mul(ast::MulDetails::Unsigned(_), _) => None, + ast::Instruction::Mul(ast::MulDetails::Signed(_), _) => None, + ast::Instruction::Mad(ast::MulDetails::Unsigned(_), _) => None, + ast::Instruction::Mad(ast::MulDetails::Signed(_), _) => None, + ast::Instruction::Min(ast::MinMaxDetails::Signed(_), _) => None, + ast::Instruction::Min(ast::MinMaxDetails::Unsigned(_), _) => None, + ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _) => None, + ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _) => None, + ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _) => None, + ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None, + ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None, + ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None, + ast::Instruction::Clz { .. } => None, + ast::Instruction::Brev { .. } => None, + ast::Instruction::Popc { .. } => None, + ast::Instruction::Xor { .. } => None, + ast::Instruction::Bfe { .. } => None, + ast::Instruction::Bfi { .. } => None, + ast::Instruction::Rem { .. } => None, + ast::Instruction::Prmt { .. } => None, + ast::Instruction::Activemask { .. } => None, + ast::Instruction::Membar { .. } => None, + ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _) + | ast::Instruction::Add(ast::ArithDetails::Float(float_control), _) + | ast::Instruction::Mul(ast::MulDetails::Float(float_control), _) + | ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => float_control + .flush_to_zero + .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())), + ast::Instruction::Fma(d, _) => d.flush_to_zero.map(|ftz| (ftz, d.typ.size_of())), + ast::Instruction::Setp(details, _) => details + .flush_to_zero + .map(|ftz| (ftz, details.typ.size_of())), + ast::Instruction::SetpBool(details, _) => details + .flush_to_zero + .map(|ftz| (ftz, details.typ.size_of())), + ast::Instruction::Abs(details, _) => details + .flush_to_zero + .map(|ftz| (ftz, details.typ.size_of())), + ast::Instruction::Min(ast::MinMaxDetails::Float(float_control), _) + | ast::Instruction::Max(ast::MinMaxDetails::Float(float_control), _) => float_control + .flush_to_zero + .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())), + ast::Instruction::Rcp(details, _) => details + .flush_to_zero + .map(|ftz| (ftz, if details.is_f64 { 8 } else { 4 })), + // Modifier .ftz can only be specified when either .dtype or .atype + // is .f32 and applies only to single precision (.f32) inputs and results. + ast::Instruction::Cvt( + ast::CvtDetails::FloatFromFloat(ast::CvtDesc { flush_to_zero, .. }), + _, + ) + | ast::Instruction::Cvt( + ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }), + _, + ) => flush_to_zero.map(|ftz| (ftz, 4)), + ast::Instruction::Div(ast::DivDetails::Float(details), _) => details + .flush_to_zero + .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())), + ast::Instruction::Sqrt(details, _) => details + .flush_to_zero + .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())), + ast::Instruction::Rsqrt(details, _) => Some(( + details.flush_to_zero, + ast::ScalarType::from(details.typ).size_of(), + )), + ast::Instruction::Neg(details, _) => details + .flush_to_zero + .map(|ftz| (ftz, details.typ.size_of())), + ast::Instruction::Sin { + data: ast::FlushToZero { flush_to_zero }, + .. + } + | ast::Instruction::Cos { + data: ast::FlushToZero { flush_to_zero }, + .. + } + | ast::Instruction::Lg2 { + data: ast::FlushToZero { flush_to_zero }, + .. + } + | ast::Instruction::Ex2 { + data: ast::FlushToZero { flush_to_zero }, + .. + } => Some((*flush_to_zero, mem::size_of::() as u8)), + } +} struct ConstantDefinition { - pub dst: spirv::Word, + pub dst: SpirvWord, pub typ: ast::ScalarType, pub value: ast::ImmediateValue, } struct BrachCondition { - predicate: spirv::Word, - if_true: spirv::Word, - if_false: spirv::Word, -} - -impl, To: ArgParamsEx> Visitable - for BrachCondition -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, To>, TranslateError> { - let predicate = visitor.id( - ArgumentDescriptor { - op: self.predicate, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )), - )?; - let if_true = self.if_true; - let if_false = self.if_false; - Ok(Statement::Conditional(BrachCondition { - predicate, - if_true, - if_false, - })) - } + predicate: SpirvWord, + if_true: SpirvWord, + if_false: SpirvWord, } #[derive(Clone)] struct ImplicitConversion { - src: spirv::Word, - dst: spirv::Word, + src: SpirvWord, + dst: SpirvWord, from_type: ast::Type, to_type: ast::Type, from_space: ast::StateSpace, @@ -6692,1065 +5557,125 @@ enum ConversionKind { AddressOf, } -impl ast::PredAt { - fn map_variable Result>( - self, - f: &mut F, - ) -> Result, TranslateError> { - let new_label = f(self.label)?; - Ok(ast::PredAt { - not: self.not, - label: new_label, - }) +fn unwrap_reg(this: &ast::ParsedOperand) -> Result { + match this { + ast::Operand::Reg(reg) => Ok(*reg), + _ => Err(error_unreachable()), } } -impl<'a> ast::Instruction> { - fn map_variable Result>( - self, - f: &mut F, - ) -> Result, TranslateError> { - match self { - ast::Instruction::Call(call) => { - let call_inst = ast::CallInst { - uniform: call.uniform, - ret_params: call - .ret_params - .into_iter() - .map(|p| f(p)) - .collect::>()?, - func: f(call.func)?, - param_list: call - .param_list - .into_iter() - .map(|p| p.map_variable(f)) - .collect::>()?, - }; - Ok(ast::Instruction::Call(call_inst)) - } - i => i.map(f), +fn scalar_from_parts(width: u8, kind: ast::ScalarKind) -> ast::ScalarType { + match kind { + ast::ScalarKind::Float => match width { + 2 => ast::ScalarType::F16, + 4 => ast::ScalarType::F32, + 8 => ast::ScalarType::F64, + _ => unreachable!(), + }, + ast::ScalarKind::Bit => match width { + 1 => ast::ScalarType::B8, + 2 => ast::ScalarType::B16, + 4 => ast::ScalarType::B32, + 8 => ast::ScalarType::B64, + _ => unreachable!(), + }, + ast::ScalarKind::Signed => match width { + 1 => ast::ScalarType::S8, + 2 => ast::ScalarType::S16, + 4 => ast::ScalarType::S32, + 8 => ast::ScalarType::S64, + _ => unreachable!(), + }, + ast::ScalarKind::Unsigned => match width { + 1 => ast::ScalarType::U8, + 2 => ast::ScalarType::U16, + 4 => ast::ScalarType::U32, + 8 => ast::ScalarType::U64, + _ => unreachable!(), + }, + ast::ScalarKind::Float2 => match width { + 4 => ast::ScalarType::F16x2, + _ => unreachable!(), + }, + ast::ScalarKind::Pred => ast::ScalarType::Pred, + } +} + +fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass { + match this { + ast::StateSpace::Const => spirv::StorageClass::UniformConstant, + ast::StateSpace::Generic => spirv::StorageClass::Generic, + ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, + ast::StateSpace::Local => spirv::StorageClass::Function, + ast::StateSpace::Shared => spirv::StorageClass::Workgroup, + ast::StateSpace::Param => spirv::StorageClass::Function, + ast::StateSpace::Reg => spirv::StorageClass::Function, + ast::StateSpace::Sreg => spirv::StorageClass::Input, + } +} + +fn space_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool { + this == other + || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg + || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg +} + +fn space_coerces_to_generic(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Reg + | ast::StateSpace::Param + | ast::StateSpace::Generic + | ast::StateSpace::Sreg => false, + } +} + +fn space_is_addressable(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Const + | ast::StateSpace::Generic + | ast::StateSpace::Global + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false, + } +} + +fn operand_underlying_register(this: &impl ast::Operand) -> Option<&T> { + match this { + ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) | ast::Operand::VecMember(r, _) => { + Some(r) } + ast::Operand::Imm(_) | ast::Operand::VecPack(..) => None, } } -impl ast::Arg1 { - fn map>( - self, - visitor: &mut V, - is_dst: bool, - t: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result, TranslateError> { - let new_src = visitor.id( - ArgumentDescriptor { - op: self.src, - is_dst, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - )?; - Ok(ast::Arg1 { src: new_src }) +fn is_wide(this: &ast::MulDetails) -> bool { + match this { + ast::MulDetails::Unsigned(d) => d.control == ast::MulIntControl::Wide, + ast::MulDetails::Signed(d) => d.control == ast::MulIntControl::Wide, + ast::MulDetails::Float(_) => false, } } -impl ast::Arg1Bar { - fn map>( - self, - visitor: &mut V, - ) -> Result, TranslateError> { - let new_src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg1Bar { src: new_src }) +fn scope_to_spirv(this: &ast::MemScope) -> spirv::Scope { + match this { + ast::MemScope::Cta => spirv::Scope::Workgroup, + ast::MemScope::Gpu => spirv::Scope::Device, + ast::MemScope::Sys => spirv::Scope::CrossDevice, + ptx_parser::MemScope::Cluster => todo!(), } } -impl ast::Arg2 { - fn map>( - self, - visitor: &mut V, - t: &ast::Type, - ) -> Result, TranslateError> { - let new_dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let new_src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - Ok(ast::Arg2 { - dst: new_dst, - src: new_src, - }) - } - - fn map_cvt>( - self, - visitor: &mut V, - dst_t: ast::ScalarType, - src_t: ast::ScalarType, - is_int_to_int: bool, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: if is_int_to_int { - Some(should_convert_relaxed_dst_wrapper) - } else { - None - }, - }, - &ast::Type::Scalar(dst_t), - ast::StateSpace::Reg, - )?; - let src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: if is_int_to_int { - Some(should_convert_relaxed_src_wrapper) - } else { - None - }, - }, - &ast::Type::Scalar(src_t), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg2 { dst, src }) - } - - fn map_different_types>( - self, - visitor: &mut V, - dst_t: &ast::Type, - src_t: &ast::Type, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - dst_t, - ast::StateSpace::Reg, - )?; - let src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - src_t, - ast::StateSpace::Reg, - )?; - Ok(ast::Arg2 { dst, src }) - } -} - -impl ast::Arg2Ld { - fn map>( - self, - visitor: &mut V, - details: &ast::LdDetails, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: Some(should_convert_relaxed_dst_wrapper), - }, - &ast::Type::from(details.typ.clone()), - ast::StateSpace::Reg, - )?; - let src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: true, - non_default_implicit_conversion: None, - }, - &details.typ, - details.state_space, - )?; - Ok(ast::Arg2Ld { dst, src }) - } -} - -impl ast::Arg2St { - fn map>( - self, - visitor: &mut V, - details: &ast::StData, - ) -> Result, TranslateError> { - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: true, - non_default_implicit_conversion: None, - }, - &details.typ, - details.state_space, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: Some(should_convert_relaxed_src_wrapper), - }, - &details.typ.clone().into(), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg2St { src1, src2 }) - } -} - -impl ast::Arg2Mov { - fn map>( - self, - visitor: &mut V, - details: &ast::MovDetails, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &details.typ.clone().into(), - ast::StateSpace::Reg, - )?; - let src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: Some(implicit_conversion_mov), - }, - &details.typ.clone().into(), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg2Mov { dst, src }) - } -} - -impl ast::Arg3 { - fn map_non_shift>( - self, - visitor: &mut V, - typ: &ast::Type, - is_wide: bool, - ) -> Result, TranslateError> { - let wide_type = if is_wide { - Some(typ.clone().widen()?) - } else { - None - }; - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - wide_type.as_ref().unwrap_or(typ), - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - typ, - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - typ, - ast::StateSpace::Reg, - )?; - Ok(ast::Arg3 { dst, src1, src2 }) - } - - fn map_shift>( - self, - visitor: &mut V, - t: &ast::Type, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg3 { dst, src1, src2 }) - } - - fn map_atom>( - self, - visitor: &mut V, - t: ast::ScalarType, - state_space: ast::StateSpace, - ) -> Result, TranslateError> { - let scalar_type = ast::ScalarType::from(t); - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: true, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - state_space, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg3 { dst, src1, src2 }) - } - - fn map_prmt>( - self, - visitor: &mut V, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::B32), - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::B32), - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::B32), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg3 { dst, src1, src2 }) - } -} - -impl ast::Arg4 { - fn map>( - self, - visitor: &mut V, - t: &ast::Type, - is_wide: bool, - ) -> Result, TranslateError> { - let wide_type = if is_wide { - Some(t.clone().widen()?) - } else { - None - }; - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - wide_type.as_ref().unwrap_or(t), - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src3 = visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - Ok(ast::Arg4 { - dst, - src1, - src2, - src3, - }) - } - - fn map_selp>( - self, - visitor: &mut V, - t: ast::ScalarType, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(t.into()), - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(t.into()), - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(t.into()), - ast::StateSpace::Reg, - )?; - let src3 = visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg4 { - dst, - src1, - src2, - src3, - }) - } - - fn map_atom>( - self, - visitor: &mut V, - t: ast::ScalarType, - state_space: ast::StateSpace, - ) -> Result, TranslateError> { - let scalar_type = ast::ScalarType::from(t); - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: true, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - state_space, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - ast::StateSpace::Reg, - )?; - let src3 = visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg4 { - dst, - src1, - src2, - src3, - }) - } - - fn map_bfe>( - self, - visitor: &mut V, - typ: &ast::Type, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - typ, - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - typ, - ast::StateSpace::Reg, - )?; - let u32_type = ast::Type::Scalar(ast::ScalarType::U32); - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &u32_type, - ast::StateSpace::Reg, - )?; - let src3 = visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &u32_type, - ast::StateSpace::Reg, - )?; - Ok(ast::Arg4 { - dst, - src1, - src2, - src3, - }) - } -} - -impl ast::Arg4Setp { - fn map>( - self, - visitor: &mut V, - t: &ast::Type, - ) -> Result, TranslateError> { - let dst1 = visitor.id( - ArgumentDescriptor { - op: self.dst1, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )), - )?; - let dst2 = self - .dst2 - .map(|dst2| { - visitor.id( - ArgumentDescriptor { - op: dst2, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )), - ) - }) - .transpose()?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - Ok(ast::Arg4Setp { - dst1, - dst2, - src1, - src2, - }) - } -} - -impl ast::Arg5 { - fn map_bfi>( - self, - visitor: &mut V, - base_type: &ast::Type, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - base_type, - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - base_type, - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - base_type, - ast::StateSpace::Reg, - )?; - let src3 = visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - )?; - let src4 = visitor.operand( - ArgumentDescriptor { - op: self.src4, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg5 { - dst, - src1, - src2, - src3, - src4, - }) - } -} - -impl ast::Arg5Setp { - fn map>( - self, - visitor: &mut V, - t: &ast::Type, - ) -> Result, TranslateError> { - let dst1 = visitor.id( - ArgumentDescriptor { - op: self.dst1, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )), - )?; - let dst2 = self - .dst2 - .map(|dst2| { - visitor.id( - ArgumentDescriptor { - op: dst2, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )), - ) - }) - .transpose()?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src3 = visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg5Setp { - dst1, - dst2, - src1, - src2, - src3, - }) - } -} - -impl ast::Operand { - fn map_variable Result>( - self, - f: &mut F, - ) -> Result, TranslateError> { - Ok(match self { - ast::Operand::Reg(reg) => ast::Operand::Reg(f(reg)?), - ast::Operand::RegOffset(reg, offset) => ast::Operand::RegOffset(f(reg)?, offset), - ast::Operand::Imm(x) => ast::Operand::Imm(x), - ast::Operand::VecMember(reg, idx) => ast::Operand::VecMember(f(reg)?, idx), - ast::Operand::VecPack(vec) => { - ast::Operand::VecPack(vec.into_iter().map(f).collect::>()?) - } - }) - } -} - -impl ast::Operand { - fn unwrap_reg(&self) -> Result { - match self { - ast::Operand::Reg(reg) => Ok(*reg), - _ => Err(error_unreachable()), - } - } -} - -impl ast::ScalarType { - fn from_parts(width: u8, kind: ast::ScalarKind) -> Self { - match kind { - ast::ScalarKind::Float => match width { - 2 => ast::ScalarType::F16, - 4 => ast::ScalarType::F32, - 8 => ast::ScalarType::F64, - _ => unreachable!(), - }, - ast::ScalarKind::Bit => match width { - 1 => ast::ScalarType::B8, - 2 => ast::ScalarType::B16, - 4 => ast::ScalarType::B32, - 8 => ast::ScalarType::B64, - _ => unreachable!(), - }, - ast::ScalarKind::Signed => match width { - 1 => ast::ScalarType::S8, - 2 => ast::ScalarType::S16, - 4 => ast::ScalarType::S32, - 8 => ast::ScalarType::S64, - _ => unreachable!(), - }, - ast::ScalarKind::Unsigned => match width { - 1 => ast::ScalarType::U8, - 2 => ast::ScalarType::U16, - 4 => ast::ScalarType::U32, - 8 => ast::ScalarType::U64, - _ => unreachable!(), - }, - ast::ScalarKind::Float2 => match width { - 4 => ast::ScalarType::F16x2, - _ => unreachable!(), - }, - ast::ScalarKind::Pred => ast::ScalarType::Pred, - } - } -} - -impl ast::ArithDetails { - fn get_type(&self) -> ast::Type { - ast::Type::Scalar(match self { - ast::ArithDetails::Unsigned(t) => (*t).into(), - ast::ArithDetails::Signed(d) => d.typ.into(), - ast::ArithDetails::Float(d) => d.typ.into(), - }) - } -} - -impl ast::MulDetails { - fn get_type(&self) -> ast::Type { - ast::Type::Scalar(match self { - ast::MulDetails::Unsigned(d) => d.typ.into(), - ast::MulDetails::Signed(d) => d.typ.into(), - ast::MulDetails::Float(d) => d.typ.into(), - }) - } -} - -impl ast::MinMaxDetails { - fn get_type(&self) -> ast::Type { - ast::Type::Scalar(match self { - ast::MinMaxDetails::Signed(t) => (*t).into(), - ast::MinMaxDetails::Unsigned(t) => (*t).into(), - ast::MinMaxDetails::Float(d) => d.typ.into(), - }) - } -} - -impl ast::DivDetails { - fn get_type(&self) -> ast::Type { - ast::Type::Scalar(match self { - ast::DivDetails::Unsigned(t) => (*t).into(), - ast::DivDetails::Signed(t) => (*t).into(), - ast::DivDetails::Float(d) => d.typ.into(), - }) - } -} - -impl ast::AtomInnerDetails { - fn get_type(&self) -> ast::ScalarType { - match self { - ast::AtomInnerDetails::Bit { typ, .. } => (*typ).into(), - ast::AtomInnerDetails::Unsigned { typ, .. } => (*typ).into(), - ast::AtomInnerDetails::Signed { typ, .. } => (*typ).into(), - ast::AtomInnerDetails::Float { typ, .. } => (*typ).into(), - } - } -} - -impl ast::StateSpace { - fn to_spirv(self) -> spirv::StorageClass { - match self { - ast::StateSpace::Const => spirv::StorageClass::UniformConstant, - ast::StateSpace::Generic => spirv::StorageClass::Generic, - ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, - ast::StateSpace::Local => spirv::StorageClass::Function, - ast::StateSpace::Shared => spirv::StorageClass::Workgroup, - ast::StateSpace::Param => spirv::StorageClass::Function, - ast::StateSpace::Reg => spirv::StorageClass::Function, - ast::StateSpace::Sreg => spirv::StorageClass::Input, - } - } - - fn is_compatible(self, other: ast::StateSpace) -> bool { - self == other - || self == ast::StateSpace::Reg && other == ast::StateSpace::Sreg - || self == ast::StateSpace::Sreg && other == ast::StateSpace::Reg - } - - fn coerces_to_generic(self) -> bool { - match self { - ast::StateSpace::Global - | ast::StateSpace::Const - | ast::StateSpace::Local - | ast::StateSpace::Shared => true, - ast::StateSpace::Reg - | ast::StateSpace::Param - | ast::StateSpace::Generic - | ast::StateSpace::Sreg => false, - } - } - - fn is_addressable(self) -> bool { - match self { - ast::StateSpace::Const - | ast::StateSpace::Generic - | ast::StateSpace::Global - | ast::StateSpace::Local - | ast::StateSpace::Shared => true, - ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false, - } - } -} - -impl ast::Operand { - fn underlying_register(&self) -> Option<&T> { - match self { - ast::Operand::Reg(r) - | ast::Operand::RegOffset(r, _) - | ast::Operand::VecMember(r, _) => Some(r), - ast::Operand::Imm(_) | ast::Operand::VecPack(..) => None, - } - } -} - -impl ast::MulDetails { - fn is_wide(&self) -> bool { - match self { - ast::MulDetails::Unsigned(d) => d.control == ast::MulIntControl::Wide, - ast::MulDetails::Signed(d) => d.control == ast::MulIntControl::Wide, - ast::MulDetails::Float(_) => false, - } - } -} - -impl ast::MemScope { - fn to_spirv(self) -> spirv::Scope { - match self { - ast::MemScope::Cta => spirv::Scope::Workgroup, - ast::MemScope::Gpu => spirv::Scope::Device, - ast::MemScope::Sys => spirv::Scope::CrossDevice, - } - } -} - -impl ast::AtomSemantics { - fn to_spirv(self) -> spirv::MemorySemantics { - match self { - ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED, - ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE, - ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE, - ast::AtomSemantics::AcquireRelease => spirv::MemorySemantics::ACQUIRE_RELEASE, - } +fn to_spirv(this: ast::AtomSemantics) -> spirv::MemorySemantics { + match this { + ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED, + ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE, + ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE, + ast::AtomSemantics::AcqRel => spirv::MemorySemantics::ACQUIRE_RELEASE, } } @@ -8046,36 +5971,32 @@ fn should_convert_relaxed_dst( } } -impl<'a> ast::MethodDeclaration<'a, &'a str> { - fn name(&self) -> &'a str { - match self.name { - ast::MethodName::Kernel(name) => name, - ast::MethodName::Func(name) => name, - } +fn method_name<'a>(this: &ast::MethodDeclaration<'a, &'a str>) -> &'a str { + match this.name { + ast::MethodName::Kernel(name) => name, + ast::MethodName::Func(name) => name, } } -impl<'a> ast::MethodDeclaration<'a, spirv::Word> { - fn effective_input_arguments(&self) -> impl Iterator + '_ { - let is_kernel = self.name.is_kernel(); - self.input_arguments.iter().map(move |arg| { - if !is_kernel && arg.state_space != ast::StateSpace::Reg { - let spirv_type = - SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv()); - (arg.name, spirv_type) - } else { - (arg.name, SpirvType::new(arg.v_type.clone())) - } - }) - } +fn effective_input_arguments<'a, 'input>( + this: &'a ast::MethodDeclaration<'input, SpirvWord>, +) -> impl Iterator + 'a { + let is_kernel = method_is_kernel(this); + this.input_arguments.iter().map(move |arg| { + if !is_kernel && arg.state_space != ast::StateSpace::Reg { + let spirv_type = + SpirvType::pointer_to(arg.v_type.clone(), space_to_spirv(arg.state_space)); + (arg.name, spirv_type) + } else { + (arg.name, SpirvType::new(arg.v_type.clone())) + } + }) } -impl<'input, ID> ast::MethodName<'input, ID> { - fn is_kernel(&self) -> bool { - match self { - ast::MethodName::Kernel(..) => true, - ast::MethodName::Func(..) => false, - } +fn method_is_kernel<'a, T>(this: &ast::MethodDeclaration<'a, T>) -> bool { + match this { + ast::MethodName::Kernel(..) => true, + ast::MethodName::Func(..) => false, } } diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index af3058b..a4df14f 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -4,6 +4,8 @@ version = "0.0.0" authors = ["Andrzej Janik "] edition = "2021" +[lib] + [dependencies] logos = "0.14" winnow = { version = "0.6.18" } @@ -11,3 +13,4 @@ ptx_parser_macros = { path = "../ptx_parser_macros" } thiserror = "1.0" bitflags = "1.2" rustc-hash = "2.0.0" +derive_more = { version = "1", features = ["display"] } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 6cf1264..9c3312a 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -717,7 +717,7 @@ impl Operand for ParsedOperand { type Ident = Ident; } -pub trait Operand { +pub trait Operand: Sized { type Ident: Copy; } diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/lib.rs similarity index 98% rename from ptx_parser/src/main.rs rename to ptx_parser/src/lib.rs index 5db94f2..0b2bf65 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/lib.rs @@ -1,8 +1,8 @@ +use derive_more::Display; use logos::Logos; use ptx_parser_macros::derive_parser; use rustc_hash::FxHashMap; use std::fmt::Debug; -use std::mem; use std::num::{ParseFloatError, ParseIntError}; use winnow::ascii::dec_uint; use winnow::combinator::*; @@ -81,16 +81,16 @@ impl VectorPrefix { } } -struct PtxParserState<'input> { - errors: Vec, +struct PtxParserState<'a, 'input> { + errors: &'a mut Vec, function_declarations: FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>, } -impl<'input> PtxParserState<'input> { - fn new() -> Self { +impl<'a, 'input> PtxParserState<'a, 'input> { + fn new(errors: &'a mut Vec) -> Self { Self { - errors: Vec::new(), + errors, function_declarations: FxHashMap::default(), } } @@ -115,7 +115,7 @@ impl<'input> PtxParserState<'input> { } } -impl<'input> Debug for PtxParserState<'input> { +impl<'a, 'input> Debug for PtxParserState<'a, 'input> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PtxParserState") .field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */ @@ -123,7 +123,7 @@ impl<'input> Debug for PtxParserState<'input> { } } -type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'input>>; +type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>; fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { any.verify_map(|t| { @@ -277,6 +277,18 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult(text: &'input str) -> Option> { + let lexer = Token::lexer(text); + let input = lexer.collect::, _>>().ok()?; + let mut errors = Vec::new(); + let state = PtxParserState::new(&mut errors); + let parser = PtxParser { + state, + input: &input[..], + }; + module.parse(parser).ok() +} + fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { ( version, @@ -818,6 +830,8 @@ pub enum PtxError { source: ParseFloatError, }, #[error("")] + Lexer(#[from] TokenError), + #[error("")] Todo, #[error("")] SyntaxError, @@ -1042,9 +1056,15 @@ fn empty_call<'input>( type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>; +#[derive(Clone, PartialEq, Default, Debug, Display)] +pub struct TokenError; + +impl std::error::Error for TokenError {} + derive_parser!( #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] #[logos(skip r"\s+")] + #[logos(error = TokenError)] enum Token<'input> { #[token(",")] Comma, @@ -2825,57 +2845,6 @@ derive_parser!( ); -fn main() { - use winnow::Parser; - - let lexer = Token::lexer( - " - .version 6.5 - .target sm_30 - .address_size 64 - - .const .align 8 .b32 constparams; - - .visible .entry const( - .param .u64 input, - .param .u64 output - ) - { - .reg .u64 in_addr; - .reg .u64 out_addr; - .reg .b16 temp1; - .reg .b16 temp2; - .reg .b16 temp3; - .reg .b16 temp4; - - ld.param.u64 in_addr, [input]; - ld.param.u64 out_addr, [output]; - - ld.const.b16 temp1, [constparams]; - ld.const.b16 temp2, [constparams+2]; - ld.const.b16 temp3, [constparams+4]; - ld.const.b16 temp4, [constparams+6]; - st.u16 [out_addr], temp1; - st.u16 [out_addr+2], temp2; - st.u16 [out_addr+4], temp3; - st.u16 [out_addr+6], temp4; - ret; - } - - ", - ); - let tokens = lexer.clone().collect::>(); - println!("{:?}", &tokens); - let tokens = lexer.map(|t| t.unwrap()).collect::>(); - println!("{:?}", &tokens); - let stream = PtxParser { - input: &tokens[..], - state: PtxParserState::new(), - }; - let _module = module.parse(stream).unwrap(); - println!("{}", mem::size_of::()); -} - #[cfg(test)] mod tests { use super::target;