diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 367f060..aba6bda 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -557,7 +557,7 @@ pub enum Instruction { Mul(MulDetails, Arg3

), Add(ArithDetails, Arg3

), Setp(SetpData, Arg4Setp

), - SetpBool(SetpBoolData, Arg5

), + SetpBool(SetpBoolData, Arg5Setp

), Not(BooleanType, Arg2

), Bra(BraData, Arg1

), Cvt(CvtDetails, Arg2

), @@ -614,16 +614,12 @@ pub struct CallInst { pub uniform: bool, pub ret_params: Vec, pub func: P::Id, - pub param_list: Vec, + pub param_list: Vec, } pub trait ArgParams { type Id; type Operand; - type IdOrVector; - type OperandOrVector; - type CallOperand; - type SrcMemberOperand; } pub struct ParsedArgParams<'a> { @@ -633,10 +629,6 @@ pub struct ParsedArgParams<'a> { impl<'a> ArgParams for ParsedArgParams<'a> { type Id = &'a str; type Operand = Operand<&'a str>; - type CallOperand = CallOperand<&'a str>; - type IdOrVector = IdOrVector<&'a str>; - type OperandOrVector = OperandOrVector<&'a str>; - type SrcMemberOperand = (&'a str, u8); } pub struct Arg1 { @@ -648,45 +640,32 @@ pub struct Arg1Bar { } pub struct Arg2 { - pub dst: P::Id, + pub dst: P::Operand, pub src: P::Operand, } pub struct Arg2Ld { - pub dst: P::IdOrVector, + pub dst: P::Operand, pub src: P::Operand, } pub struct Arg2St { pub src1: P::Operand, - pub src2: P::OperandOrVector, + pub src2: P::Operand, } -pub enum Arg2Mov { - Normal(Arg2MovNormal

), - Member(Arg2MovMember

), -} - -pub struct Arg2MovNormal { - pub dst: P::IdOrVector, - pub src: P::OperandOrVector, -} - -// We duplicate dst here because during further compilation -// composite dst and composite src will receive different ids -pub enum Arg2MovMember { - Dst((P::Id, u8), P::Id, P::Id), - Src(P::Id, P::SrcMemberOperand), - Both((P::Id, u8), P::Id, P::SrcMemberOperand), +pub struct Arg2Mov { + pub dst: P::Operand, + pub src: P::Operand, } pub struct Arg3 { - pub dst: P::Id, + pub dst: P::Operand, pub src1: P::Operand, pub src2: P::Operand, } pub struct Arg4 { - pub dst: P::Id, + pub dst: P::Operand, pub src1: P::Operand, pub src2: P::Operand, pub src3: P::Operand, @@ -699,7 +678,7 @@ pub struct Arg4Setp { pub src2: P::Operand, } -pub struct Arg5 { +pub struct Arg5Setp { pub dst1: P::Id, pub dst2: Option, pub src1: P::Operand, @@ -715,39 +694,13 @@ pub enum ImmediateValue { F64(f64), } -#[derive(Copy, Clone)] -pub enum Operand { - Reg(ID), - RegOffset(ID, i32), +#[derive(Clone)] +pub enum Operand { + Reg(Id), + RegOffset(Id, i32), Imm(ImmediateValue), -} - -#[derive(Copy, Clone)] -pub enum CallOperand { - Reg(ID), - Imm(ImmediateValue), -} - -pub enum IdOrVector { - Reg(ID), - Vec(Vec), -} - -pub enum OperandOrVector { - Reg(ID), - RegOffset(ID, i32), - Imm(ImmediateValue), - Vec(Vec), -} - -impl From> for OperandOrVector { - fn from(this: Operand) -> Self { - match this { - Operand::Reg(r) => OperandOrVector::Reg(r), - Operand::RegOffset(r, imm) => OperandOrVector::RegOffset(r, imm), - Operand::Imm(imm) => OperandOrVector::Imm(imm), - } - } + VecMember(Id, u8), + VecPack(Vec), } pub enum VectorPrefix { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index d2c235a..fd2a3f1 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -721,7 +721,7 @@ Instruction: ast::Instruction> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld InstLd: ast::Instruction> = { - "ld" "," => { + "ld" "," => { ast::Instruction::Ld( ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), @@ -734,16 +734,6 @@ InstLd: ast::Instruction> = { } }; -IdOrVector: ast::IdOrVector<&'input str> = { - => ast::IdOrVector::Reg(dst), - => ast::IdOrVector::Vec(dst) -} - -OperandOrVector: ast::OperandOrVector<&'input str> = { - => ast::OperandOrVector::from(op), - => ast::OperandOrVector::Vec(dst) -} - LdStType: ast::LdStType = { => ast::LdStType::Vector(t, v), => ast::LdStType::Scalar(t), @@ -780,27 +770,17 @@ LdCacheOperator: ast::LdCacheOperator = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov InstMov: ast::Instruction> = { - => ast::Instruction::Mov(m.0, m.1), - => ast::Instruction::Mov(m.0, m.1), -}; - - -MovNormal: (ast::MovDetails, ast::Arg2Mov>) = { - "mov" "," => {( - ast::MovDetails::new(ast::Type::Scalar(t)), - ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: ast::IdOrVector::Reg(dst), src: src.into() }) - )}, - "mov" "," => {( - ast::MovDetails::new(ast::Type::Vector(t, pref)), - ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: dst, src: src }) - )} -} - -MovVector: (ast::MovDetails, ast::Arg2Mov>) = { - "mov" => {( - ast::MovDetails::new(ast::Type::Scalar(t.into())), - ast::Arg2Mov::Member(a) - )}, + "mov" "," => { + let mov_type = match pref { + Some(vec_width) => ast::Type::Vector(t, vec_width), + None => ast::Type::Scalar(t) + }; + let details = ast::MovDetails::new(mov_type); + ast::Instruction::Mov( + details, + ast::Arg2Mov { dst, src } + ) + } } #[inline] @@ -819,21 +799,6 @@ MovScalarType: ast::ScalarType = { ".pred" => ast::ScalarType::Pred }; -#[inline] -MovVectorType: ast::ScalarType = { - ".b16" => ast::ScalarType::B16, - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, - ".f32" => ast::ScalarType::F32, - ".f64" => ast::ScalarType::F64, -}; - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul @@ -921,7 +886,7 @@ InstAdd: ast::Instruction> = { // TODO: support f16 setp InstSetp: ast::Instruction> = { "setp" => ast::Instruction::Setp(d, a), - "setp" => ast::Instruction::SetpBool(d, a), + "setp" => ast::Instruction::SetpBool(d, a), }; SetpMode: ast::SetpData = { @@ -1198,7 +1163,7 @@ ShrType: ast::ShrType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // Warning: NVIDIA documentation is incorrect, you can specify scope only once InstSt: ast::Instruction> = { - "st" "," => { + "st" "," => { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), @@ -1775,9 +1740,9 @@ Operand: ast::Operand<&'input str> = { => ast::Operand::Imm(x) }; -CallOperand: ast::CallOperand<&'input str> = { - => ast::CallOperand::Reg(r), - => ast::CallOperand::Imm(x) +CallOperand: ast::Operand<&'input str> = { + => ast::Operand::Reg(r), + => ast::Operand::Imm(x) }; // TODO: start parsing whole constants sub-language: @@ -1825,13 +1790,7 @@ Arg1Bar: ast::Arg1Bar> = { }; Arg2: ast::Arg2> = { - "," => ast::Arg2{<>} -}; - -Arg2MovMember: ast::Arg2MovMember> = { - "," => ast::Arg2MovMember::Dst(dst, dst.0, src), - "," => ast::Arg2MovMember::Src(dst, src), - "," => ast::Arg2MovMember::Both(dst, dst.0, src), + "," => ast::Arg2{<>} }; MemberOperand: (&'input str, u8) = { @@ -1855,19 +1814,19 @@ VectorExtract: Vec<&'input str> = { }; Arg3: ast::Arg3> = { - "," "," => ast::Arg3{<>} + "," "," => ast::Arg3{<>} }; Arg3Atom: ast::Arg3> = { - "," "[" "]" "," => ast::Arg3{<>} + "," "[" "]" "," => ast::Arg3{<>} }; Arg4: ast::Arg4> = { - "," "," "," => ast::Arg4{<>} + "," "," "," => ast::Arg4{<>} }; Arg4Atom: ast::Arg4> = { - "," "[" "]" "," "," => ast::Arg4{<>} + "," "[" "]" "," "," => ast::Arg4{<>} }; Arg4Setp: ast::Arg4Setp> = { @@ -1875,22 +1834,50 @@ Arg4Setp: ast::Arg4Setp> = { }; // TODO: pass src3 negation somewhere -Arg5: ast::Arg5> = { - "," "," "," "!"? => ast::Arg5{<>} +Arg5Setp: ast::Arg5Setp> = { + "," "," "," "!"? => ast::Arg5Setp{<>} }; -ArgCall: (Vec<&'input str>, &'input str, Vec>) = { +ArgCall: (Vec<&'input str>, &'input str, Vec>) = { "(" > ")" "," "," "(" > ")" => { (ret_params, func, param_list) }, "," "(" > ")" => (Vec::new(), func, param_list), - => (Vec::new(), func, Vec::>::new()), + => (Vec::new(), func, Vec::>::new()), }; OptionalDst: &'input str = { "|" => dst2 } +SrcOperand: ast::Operand<&'input str> = { + => ast::Operand::Reg(r), + "+" => ast::Operand::RegOffset(r, offset), + => ast::Operand::Imm(x), + => { + let (reg, idx) = mem_op; + ast::Operand::VecMember(reg, idx) + } +} + +SrcOperandVec: ast::Operand<&'input str> = { + => normal, + => ast::Operand::VecPack(vec), +} + +DstOperand: ast::Operand<&'input str> = { + => ast::Operand::Reg(r), + => { + let (reg, idx) = mem_op; + ast::Operand::VecMember(reg, idx) + } +} + +DstOperandVec: ast::Operand<&'input str> = { + => normal, + => ast::Operand::VecPack(vec), +} + VectorPrefix: u8 = { ".v2" => 2, ".v4" => 4 diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt index 535e480..a77ab7d 100644 --- a/ptx/src/test/spirv_run/vector.spvtxt +++ b/ptx/src/test/spirv_run/vector.spvtxt @@ -7,91 +7,93 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %57 = OpExtInstImport "OpenCL.std" + %51 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %31 "vector" + OpEntryPoint Kernel %25 "vector" %void = OpTypeVoid %uint = OpTypeInt 32 0 %v2uint = OpTypeVector %uint 2 - %61 = OpTypeFunction %v2uint %v2uint + %55 = OpTypeFunction %v2uint %v2uint %_ptr_Function_v2uint = OpTypePointer Function %v2uint %_ptr_Function_uint = OpTypePointer Function %uint + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 %ulong = OpTypeInt 64 0 - %65 = OpTypeFunction %void %ulong %ulong + %67 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_v2uint = OpTypePointer Generic %v2uint - %1 = OpFunction %v2uint None %61 + %1 = OpFunction %v2uint None %55 %7 = OpFunctionParameter %v2uint - %30 = OpLabel + %24 = OpLabel %2 = OpVariable %_ptr_Function_v2uint Function %3 = OpVariable %_ptr_Function_v2uint Function %4 = OpVariable %_ptr_Function_v2uint Function %5 = OpVariable %_ptr_Function_uint Function %6 = OpVariable %_ptr_Function_uint Function OpStore %3 %7 - %9 = OpLoad %v2uint %3 - %27 = OpCompositeExtract %uint %9 0 - %8 = OpCopyObject %uint %27 + %59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0 + %9 = OpLoad %uint %59 + %8 = OpCopyObject %uint %9 OpStore %5 %8 - %11 = OpLoad %v2uint %3 - %28 = OpCompositeExtract %uint %11 1 - %10 = OpCopyObject %uint %28 + %61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1 + %11 = OpLoad %uint %61 + %10 = OpCopyObject %uint %11 OpStore %6 %10 %13 = OpLoad %uint %5 %14 = OpLoad %uint %6 %12 = OpIAdd %uint %13 %14 OpStore %6 %12 - %16 = OpLoad %v2uint %4 - %17 = OpLoad %uint %6 - %15 = OpCompositeInsert %v2uint %17 %16 0 - OpStore %4 %15 - %19 = OpLoad %v2uint %4 - %20 = OpLoad %uint %6 - %18 = OpCompositeInsert %v2uint %20 %19 1 - OpStore %4 %18 + %16 = OpLoad %uint %6 + %15 = OpCopyObject %uint %16 + %62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 + OpStore %62 %15 + %18 = OpLoad %uint %6 + %17 = OpCopyObject %uint %18 + %63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 + OpStore %63 %17 + %64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 + %20 = OpLoad %uint %64 + %19 = OpCopyObject %uint %20 + %65 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 + OpStore %65 %19 %22 = OpLoad %v2uint %4 - %23 = OpLoad %v2uint %4 - %29 = OpCompositeExtract %uint %23 1 - %21 = OpCompositeInsert %v2uint %29 %22 0 - OpStore %4 %21 - %25 = OpLoad %v2uint %4 - %24 = OpCopyObject %v2uint %25 - OpStore %2 %24 - %26 = OpLoad %v2uint %2 - OpReturnValue %26 + %21 = OpCopyObject %v2uint %22 + OpStore %2 %21 + %23 = OpLoad %v2uint %2 + OpReturnValue %23 OpFunctionEnd - %31 = OpFunction %void None %65 - %40 = OpFunctionParameter %ulong - %41 = OpFunctionParameter %ulong - %55 = OpLabel - %32 = OpVariable %_ptr_Function_ulong Function + %25 = OpFunction %void None %67 + %34 = OpFunctionParameter %ulong + %35 = OpFunctionParameter %ulong + %49 = OpLabel + %26 = OpVariable %_ptr_Function_ulong Function + %27 = OpVariable %_ptr_Function_ulong Function + %28 = OpVariable %_ptr_Function_ulong Function + %29 = OpVariable %_ptr_Function_ulong Function + %30 = OpVariable %_ptr_Function_v2uint Function + %31 = OpVariable %_ptr_Function_uint Function + %32 = OpVariable %_ptr_Function_uint Function %33 = OpVariable %_ptr_Function_ulong Function - %34 = OpVariable %_ptr_Function_ulong Function - %35 = OpVariable %_ptr_Function_ulong Function - %36 = OpVariable %_ptr_Function_v2uint Function - %37 = OpVariable %_ptr_Function_uint Function - %38 = OpVariable %_ptr_Function_uint Function - %39 = OpVariable %_ptr_Function_ulong Function - OpStore %32 %40 - OpStore %33 %41 - %42 = OpLoad %ulong %32 - OpStore %34 %42 - %43 = OpLoad %ulong %33 - OpStore %35 %43 - %45 = OpLoad %ulong %34 - %52 = OpConvertUToPtr %_ptr_Generic_v2uint %45 - %44 = OpLoad %v2uint %52 - OpStore %36 %44 - %47 = OpLoad %v2uint %36 - %46 = OpFunctionCall %v2uint %1 %47 - OpStore %36 %46 - %49 = OpLoad %v2uint %36 - %53 = OpBitcast %ulong %49 - %48 = OpCopyObject %ulong %53 - OpStore %39 %48 - %50 = OpLoad %ulong %35 - %51 = OpLoad %v2uint %36 - %54 = OpConvertUToPtr %_ptr_Generic_v2uint %50 - OpStore %54 %51 + OpStore %26 %34 + OpStore %27 %35 + %36 = OpLoad %ulong %26 + OpStore %28 %36 + %37 = OpLoad %ulong %27 + OpStore %29 %37 + %39 = OpLoad %ulong %28 + %46 = OpConvertUToPtr %_ptr_Generic_v2uint %39 + %38 = OpLoad %v2uint %46 + OpStore %30 %38 + %41 = OpLoad %v2uint %30 + %40 = OpFunctionCall %v2uint %1 %41 + OpStore %30 %40 + %43 = OpLoad %v2uint %30 + %47 = OpBitcast %ulong %43 + %42 = OpCopyObject %ulong %47 + OpStore %33 %42 + %44 = OpLoad %ulong %29 + %45 = OpLoad %v2uint %30 + %48 = OpConvertUToPtr %_ptr_Generic_v2uint %44 + OpStore %48 %45 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector_extract.spvtxt b/ptx/src/test/spirv_run/vector_extract.spvtxt index 4943189..2037dec 100644 --- a/ptx/src/test/spirv_run/vector_extract.spvtxt +++ b/ptx/src/test/spirv_run/vector_extract.spvtxt @@ -7,12 +7,12 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %73 = OpExtInstImport "OpenCL.std" + %61 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "vector_extract" %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %76 = OpTypeFunction %void %ulong %ulong + %64 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %ushort = OpTypeInt 16 0 %_ptr_Function_ushort = OpTypePointer Function %ushort @@ -21,10 +21,10 @@ %uchar = OpTypeInt 8 0 %v4uchar = OpTypeVector %uchar 4 %_ptr_CrossWorkgroup_v4uchar = OpTypePointer CrossWorkgroup %v4uchar - %1 = OpFunction %void None %76 - %11 = OpFunctionParameter %ulong - %12 = OpFunctionParameter %ulong - %71 = OpLabel + %1 = OpFunction %void None %64 + %17 = OpFunctionParameter %ulong + %18 = OpFunctionParameter %ulong + %59 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -34,89 +34,92 @@ %8 = OpVariable %_ptr_Function_ushort Function %9 = OpVariable %_ptr_Function_ushort Function %10 = OpVariable %_ptr_Function_v4ushort Function - OpStore %2 %11 - OpStore %3 %12 - %13 = OpLoad %ulong %2 - OpStore %4 %13 - %14 = OpLoad %ulong %3 - OpStore %5 %14 - %19 = OpLoad %ulong %4 - %61 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %19 - %43 = OpLoad %v4uchar %61 - %62 = OpCompositeExtract %uchar %43 0 - %85 = OpBitcast %uchar %62 - %15 = OpUConvert %ushort %85 - %63 = OpCompositeExtract %uchar %43 1 - %86 = OpBitcast %uchar %63 - %16 = OpUConvert %ushort %86 - %64 = OpCompositeExtract %uchar %43 2 - %87 = OpBitcast %uchar %64 - %17 = OpUConvert %ushort %87 - %65 = OpCompositeExtract %uchar %43 3 - %88 = OpBitcast %uchar %65 - %18 = OpUConvert %ushort %88 - OpStore %6 %15 - OpStore %7 %16 - OpStore %8 %17 - OpStore %9 %18 - %21 = OpLoad %ushort %7 - %22 = OpLoad %ushort %8 - %23 = OpLoad %ushort %9 - %24 = OpLoad %ushort %6 - %44 = OpUndef %v4ushort - %45 = OpCompositeInsert %v4ushort %21 %44 0 - %46 = OpCompositeInsert %v4ushort %22 %45 1 - %47 = OpCompositeInsert %v4ushort %23 %46 2 - %48 = OpCompositeInsert %v4ushort %24 %47 3 - %20 = OpCopyObject %v4ushort %48 - OpStore %10 %20 - %29 = OpLoad %v4ushort %10 - %49 = OpCopyObject %v4ushort %29 - %25 = OpCompositeExtract %ushort %49 0 - %26 = OpCompositeExtract %ushort %49 1 - %27 = OpCompositeExtract %ushort %49 2 - %28 = OpCompositeExtract %ushort %49 3 - OpStore %8 %25 - OpStore %9 %26 - OpStore %6 %27 - OpStore %7 %28 - %34 = OpLoad %ushort %8 - %35 = OpLoad %ushort %9 - %36 = OpLoad %ushort %6 - %37 = OpLoad %ushort %7 - %51 = OpUndef %v4ushort - %52 = OpCompositeInsert %v4ushort %34 %51 0 - %53 = OpCompositeInsert %v4ushort %35 %52 1 - %54 = OpCompositeInsert %v4ushort %36 %53 2 - %55 = OpCompositeInsert %v4ushort %37 %54 3 - %50 = OpCopyObject %v4ushort %55 - %30 = OpCompositeExtract %ushort %50 0 - %31 = OpCompositeExtract %ushort %50 1 - %32 = OpCompositeExtract %ushort %50 2 - %33 = OpCompositeExtract %ushort %50 3 - OpStore %9 %30 - OpStore %6 %31 - OpStore %7 %32 - OpStore %8 %33 - %38 = OpLoad %ulong %5 - %39 = OpLoad %ushort %6 - %40 = OpLoad %ushort %7 - %41 = OpLoad %ushort %8 - %42 = OpLoad %ushort %9 - %56 = OpUndef %v4uchar - %89 = OpBitcast %ushort %39 - %66 = OpUConvert %uchar %89 - %57 = OpCompositeInsert %v4uchar %66 %56 0 - %90 = OpBitcast %ushort %40 - %67 = OpUConvert %uchar %90 - %58 = OpCompositeInsert %v4uchar %67 %57 1 - %91 = OpBitcast %ushort %41 - %68 = OpUConvert %uchar %91 - %59 = OpCompositeInsert %v4uchar %68 %58 2 - %92 = OpBitcast %ushort %42 - %69 = OpUConvert %uchar %92 - %60 = OpCompositeInsert %v4uchar %69 %59 3 - %70 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %38 - OpStore %70 %60 + OpStore %2 %17 + OpStore %3 %18 + %19 = OpLoad %ulong %2 + OpStore %4 %19 + %20 = OpLoad %ulong %3 + OpStore %5 %20 + %21 = OpLoad %ulong %4 + %49 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %21 + %11 = OpLoad %v4uchar %49 + %50 = OpCompositeExtract %uchar %11 0 + %51 = OpCompositeExtract %uchar %11 1 + %52 = OpCompositeExtract %uchar %11 2 + %53 = OpCompositeExtract %uchar %11 3 + %73 = OpBitcast %uchar %50 + %22 = OpUConvert %ushort %73 + %74 = OpBitcast %uchar %51 + %23 = OpUConvert %ushort %74 + %75 = OpBitcast %uchar %52 + %24 = OpUConvert %ushort %75 + %76 = OpBitcast %uchar %53 + %25 = OpUConvert %ushort %76 + OpStore %6 %22 + OpStore %7 %23 + OpStore %8 %24 + OpStore %9 %25 + %26 = OpLoad %ushort %7 + %27 = OpLoad %ushort %8 + %28 = OpLoad %ushort %9 + %29 = OpLoad %ushort %6 + %77 = OpUndef %v4ushort + %78 = OpCompositeInsert %v4ushort %26 %77 0 + %79 = OpCompositeInsert %v4ushort %27 %78 1 + %80 = OpCompositeInsert %v4ushort %28 %79 2 + %81 = OpCompositeInsert %v4ushort %29 %80 3 + %12 = OpCopyObject %v4ushort %81 + %30 = OpCopyObject %v4ushort %12 + OpStore %10 %30 + %31 = OpLoad %v4ushort %10 + %13 = OpCopyObject %v4ushort %31 + %32 = OpCompositeExtract %ushort %13 0 + %33 = OpCompositeExtract %ushort %13 1 + %34 = OpCompositeExtract %ushort %13 2 + %35 = OpCompositeExtract %ushort %13 3 + OpStore %8 %32 + OpStore %9 %33 + OpStore %6 %34 + OpStore %7 %35 + %36 = OpLoad %ushort %8 + %37 = OpLoad %ushort %9 + %38 = OpLoad %ushort %6 + %39 = OpLoad %ushort %7 + %82 = OpUndef %v4ushort + %83 = OpCompositeInsert %v4ushort %36 %82 0 + %84 = OpCompositeInsert %v4ushort %37 %83 1 + %85 = OpCompositeInsert %v4ushort %38 %84 2 + %86 = OpCompositeInsert %v4ushort %39 %85 3 + %15 = OpCopyObject %v4ushort %86 + %14 = OpCopyObject %v4ushort %15 + %40 = OpCompositeExtract %ushort %14 0 + %41 = OpCompositeExtract %ushort %14 1 + %42 = OpCompositeExtract %ushort %14 2 + %43 = OpCompositeExtract %ushort %14 3 + OpStore %9 %40 + OpStore %6 %41 + OpStore %7 %42 + OpStore %8 %43 + %44 = OpLoad %ushort %6 + %45 = OpLoad %ushort %7 + %46 = OpLoad %ushort %8 + %47 = OpLoad %ushort %9 + %87 = OpBitcast %ushort %44 + %54 = OpUConvert %uchar %87 + %88 = OpBitcast %ushort %45 + %55 = OpUConvert %uchar %88 + %89 = OpBitcast %ushort %46 + %56 = OpUConvert %uchar %89 + %90 = OpBitcast %ushort %47 + %57 = OpUConvert %uchar %90 + %91 = OpUndef %v4uchar + %92 = OpCompositeInsert %v4uchar %54 %91 0 + %93 = OpCompositeInsert %v4uchar %55 %92 1 + %94 = OpCompositeInsert %v4uchar %56 %93 2 + %95 = OpCompositeInsert %v4uchar %57 %94 3 + %16 = OpCopyObject %v4uchar %95 + %48 = OpLoad %ulong %5 + %58 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %48 + OpStore %58 %16 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 15211ab..20578eb 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -27,6 +27,16 @@ quick_error! { } } +#[cfg(debug_assertions)] +fn error_unreachable() -> TranslateError { + unreachable!() +} + +#[cfg(not(debug_assertions))] +fn error_unreachable() -> TranslateError { + TranslateError::Unreachable +} + #[derive(PartialEq, Eq, Hash, Clone)] enum SpirvType { Base(SpirvScalarKey), @@ -82,7 +92,7 @@ impl ast::Type { ast::Type::Pointer(ast::PointerType::Scalar(t), space) => { ast::Type::Pointer(ast::PointerType::Pointer(t, space), space) } - ast::Type::Pointer(_, _) => return Err(TranslateError::Unreachable), + ast::Type::Pointer(_, _) => return Err(error_unreachable()), }) } } @@ -364,7 +374,7 @@ impl TypeWordMap { b.constant_composite(result_type, None, &components) } ast::Type::Array(typ, dims) => match dims.as_slice() { - [] => return Err(TranslateError::Unreachable), + [] => return Err(error_unreachable()), [dim] => { let result_type = self .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim])); @@ -791,13 +801,14 @@ fn convert_dynamic_shared_memory_usage<'input>( ast::PointerStateSpace::Shared, )), }); - let shared_var_st = ExpandedStatement::StoreVar( - ast::Arg2St { + let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails { + arg: ast::Arg2St { src1: shared_var_id, src2: shared_id_param, }, - ast::Type::Scalar(ast::ScalarType::B8), - ); + typ: ast::Type::Scalar(ast::ScalarType::B8), + member_index: None, + }); let mut new_statements = vec![shared_var, shared_var_st]; replace_uses_of_shared_memory( &mut new_statements, @@ -963,18 +974,17 @@ fn compute_denorm_information<'input>( denorm_count_map_update(&mut flush_counter, width, flush); } } - Statement::LoadVar(_, _) => {} - Statement::StoreVar(_, _) => {} + Statement::LoadVar(..) => {} + Statement::StoreVar(..) => {} Statement::Call(_) => {} - Statement::Composite(_) => {} Statement::Conditional(_) => {} Statement::Conversion(_) => {} Statement::Constant(_) => {} Statement::RetValue(_, _) => {} - Statement::Undef(_, _) => {} Statement::Label(_) => {} Statement::Variable(_) => {} Statement::PtrAccess { .. } => {} + Statement::RepackVector(_) => {} } } denorm_methods.insert(method_key, flush_counter); @@ -1307,7 +1317,7 @@ fn to_ssa<'input, 'b>( let mut numeric_id_defs = id_defs.finish(); let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = - convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?; + convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; let typed_statements = convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( @@ -1431,7 +1441,7 @@ fn normalize_variable_decls(directives: &mut Vec) { fn convert_to_typed_statements( func: Vec, fn_defs: &GlobalFnDeclResolver, - id_defs: &NumericIdResolver, + id_defs: &mut NumericIdResolver, ) -> Result, TranslateError> { let mut result = Vec::::with_capacity(func.len()); for s in func { @@ -1447,7 +1457,7 @@ fn convert_to_typed_statements( .partition(|(_, arg_type)| arg_type.is_param()); let normalized_input_args = out_params .into_iter() - .map(|(id, typ)| (ast::CallOperand::Reg(id), typ)) + .map(|(id, typ)| (ast::Operand::Reg(id), typ)) .chain(in_args.into_iter()) .collect(); let resolved_call = ResolvedCall { @@ -1456,205 +1466,117 @@ fn convert_to_typed_statements( func: call.func, param_list: normalized_input_args, }; - result.push(Statement::Call(resolved_call)); + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let reresolved_call = resolved_call.visit(&mut visitor)?; + visitor.func.push(reresolved_call); + visitor.func.extend(visitor.post_stmts); } - ast::Instruction::Ld(d, arg) => { - result.push(Statement::Instruction(ast::Instruction::Ld(d, arg.cast()))); - } - ast::Instruction::St(d, arg) => { - result.push(Statement::Instruction(ast::Instruction::St(d, arg.cast()))); - } - ast::Instruction::Mov(mut d, args) => match args { - ast::Arg2Mov::Normal(arg) => { - if let Some(src_id) = arg.src.single_underlying() { - let (typ, _) = id_defs.get_typed(*src_id)?; - let take_address = match typ { - ast::Type::Scalar(_) => false, - ast::Type::Vector(_, _) => false, - ast::Type::Array(_, _) => true, - ast::Type::Pointer(_, _) => true, - }; - d.src_is_address = take_address; - } - result.push(Statement::Instruction(ast::Instruction::Mov( - d, - ast::Arg2Mov::Normal(arg.cast()), - ))); - } - ast::Arg2Mov::Member(args) => { - if let Some(dst_typ) = args.vector_dst() { - match id_defs.get_typed(*dst_typ)? { - (ast::Type::Vector(_, len), _) => { - d.dst_width = len; - } - _ => return Err(TranslateError::MismatchedType), - } + ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => { + if let Some(src_id) = src.underlying() { + let (typ, _) = id_defs.get_typed(*src_id)?; + let take_address = match typ { + ast::Type::Scalar(_) => false, + ast::Type::Vector(_, _) => false, + ast::Type::Array(_, _) => true, + ast::Type::Pointer(_, _) => true, }; - if let Some((src_typ, _)) = args.vector_src() { - match id_defs.get_typed(*src_typ)? { - (ast::Type::Vector(_, len), _) => { - d.src_width = len; - } - _ => return Err(TranslateError::MismatchedType), - } - }; - result.push(Statement::Instruction(ast::Instruction::Mov( - d, - ast::Arg2Mov::Member(args.cast()), - ))); + d.src_is_address = take_address; } - }, - ast::Instruction::Mul(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Mul(d, a.cast()))) + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let instruction = Statement::Instruction( + ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?, + ); + visitor.func.push(instruction); + visitor.func.extend(visitor.post_stmts); } - ast::Instruction::Add(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Add(d, a.cast()))) - } - ast::Instruction::Setp(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Setp(d, a.cast()))) - } - ast::Instruction::SetpBool(d, a) => result.push(Statement::Instruction( - ast::Instruction::SetpBool(d, a.cast()), - )), - ast::Instruction::Not(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Not(d, a.cast()))) - } - ast::Instruction::Bra(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Bra(d, a.cast()))) - } - ast::Instruction::Cvt(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Cvt(d, a.cast()))) - } - ast::Instruction::Cvta(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Cvta(d, a.cast()))) - } - ast::Instruction::Shl(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Shl(d, a.cast()))) - } - ast::Instruction::Ret(d) => { - result.push(Statement::Instruction(ast::Instruction::Ret(d))) - } - ast::Instruction::Abs(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Abs(d, a.cast()))) - } - ast::Instruction::Mad(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Mad(d, a.cast()))) - } - ast::Instruction::Shr(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Shr(d, a.cast()))) - } - ast::Instruction::Or(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast()))) - } - ast::Instruction::Sub(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Sub(d, a.cast()))) - } - ast::Instruction::Min(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Min(d, a.cast()))) - } - ast::Instruction::Max(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast()))) - } - ast::Instruction::Rcp(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Rcp(d, a.cast()))) - } - ast::Instruction::And(d, a) => { - result.push(Statement::Instruction(ast::Instruction::And(d, a.cast()))) - } - ast::Instruction::Selp(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Selp(d, a.cast()))) - } - ast::Instruction::Bar(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Bar(d, a.cast()))) - } - ast::Instruction::Atom(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Atom(d, a.cast()))) - } - ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction( - ast::Instruction::AtomCas(d, a.cast()), - )), - ast::Instruction::Div(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Div(d, a.cast()))) - } - ast::Instruction::Sqrt(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Sqrt(d, a.cast()))) - } - ast::Instruction::Rsqrt(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast()))) - } - ast::Instruction::Neg(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Neg(d, a.cast()))) - } - ast::Instruction::Sin { flush_to_zero, arg } => { - result.push(Statement::Instruction(ast::Instruction::Sin { - flush_to_zero, - arg: arg.cast(), - })) - } - ast::Instruction::Cos { flush_to_zero, arg } => { - result.push(Statement::Instruction(ast::Instruction::Cos { - flush_to_zero, - arg: arg.cast(), - })) - } - ast::Instruction::Lg2 { flush_to_zero, arg } => { - result.push(Statement::Instruction(ast::Instruction::Lg2 { - flush_to_zero, - arg: arg.cast(), - })) - } - ast::Instruction::Ex2 { flush_to_zero, arg } => { - result.push(Statement::Instruction(ast::Instruction::Ex2 { - flush_to_zero, - arg: arg.cast(), - })) - } - ast::Instruction::Clz { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Clz { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Brev { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Brev { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Popc { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Popc { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Xor { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Xor { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Bfe { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Bfe { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Rem { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Rem { - typ, - arg: arg.cast(), - })) + inst => { + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let instruction = Statement::Instruction(inst.map(&mut visitor)?); + visitor.func.push(instruction); + visitor.func.extend(visitor.post_stmts); } }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), Statement::Conditional(c) => result.push(Statement::Conditional(c)), - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), } } Ok(result) } +struct VectorRepackVisitor<'a, 'b> { + func: &'b mut Vec, + id_def: &'b mut NumericIdResolver<'a>, + post_stmts: Option, +} + +impl<'a, 'b> VectorRepackVisitor<'a, 'b> { + fn new(func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>) -> Self { + VectorRepackVisitor { + func, + id_def, + post_stmts: None, + } + } + + fn convert_vector( + &mut self, + is_dst: bool, + vector_sema: ArgumentSemantics, + typ: &ast::Type, + idx: Vec, + ) -> Result { + // mov.u32 foobar, {a,b}; + let scalar_t = match typ { + ast::Type::Vector(scalar_t, _) => *scalar_t, + _ => return Err(TranslateError::MismatchedType), + }; + let temp_vec = self.id_def.new_non_variable(Some(typ.clone())); + let statement = Statement::RepackVector(RepackVectorDetails { + is_extract: is_dst, + typ: scalar_t, + packed: temp_vec, + unpacked: idx, + vector_sema, + }); + if is_dst { + self.post_stmts = Some(statement); + } else { + self.func.push(statement); + } + Ok(temp_vec) + } +} + +impl<'a, 'b> ArgumentMapVisitor + for VectorRepackVisitor<'a, 'b> +{ + fn id( + &mut self, + desc: ArgumentDescriptor, + _: Option<&ast::Type>, + ) -> Result { + Ok(desc.op) + } + + fn operand( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result { + Ok(match desc.op { + ast::Operand::Reg(reg) => TypedOperand::Reg(reg), + ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), + ast::Operand::Imm(x) => TypedOperand::Imm(x), + ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), + ast::Operand::VecPack(vec) => { + TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?) + } + }) + } +} + //TODO: share common code between this and to_ptx_impl_bfe_call fn to_ptx_impl_atomic_call( id_defs: &mut NumericIdResolver, @@ -1872,17 +1794,16 @@ fn normalize_labels( labels_in_use.insert(cond.if_true); labels_in_use.insert(cond.if_false); } - Statement::Composite(_) - | Statement::Call(_) - | Statement::Variable(_) - | Statement::LoadVar(_, _) - | Statement::StoreVar(_, _) - | Statement::RetValue(_, _) - | Statement::Conversion(_) - | Statement::Constant(_) - | Statement::Label(_) - | Statement::Undef(_, _) - | Statement::PtrAccess { .. } => {} + Statement::Call(..) + | Statement::Variable(..) + | Statement::LoadVar(..) + | Statement::StoreVar(..) + | Statement::RetValue(..) + | Statement::Conversion(..) + | Statement::Constant(..) + | Statement::Label(..) + | Statement::PtrAccess { .. } + | Statement::RepackVector(..) => {} } } iter::once(Statement::Label(id_def.new_non_variable(None))) @@ -1929,7 +1850,7 @@ fn normalize_predicates( } Statement::Variable(var) => result.push(Statement::Variable(var)), // Blocks are flattened when resolving ids - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), } } Ok(result) @@ -1956,7 +1877,7 @@ fn insert_mem_ssa_statements<'a, 'b>( array_init: arg.array_init.clone(), })); } - None => return Err(TranslateError::Unreachable), + None => return Err(error_unreachable()), } } for spirv_arg in fn_decl.input.iter_mut() { @@ -1970,13 +1891,14 @@ fn insert_mem_ssa_statements<'a, 'b>( name: spirv_arg.name, array_init: spirv_arg.array_init.clone(), })); - result.push(Statement::StoreVar( - ast::Arg2St { + result.push(Statement::StoreVar(StoreVarDetails { + arg: ast::Arg2St { src1: spirv_arg.name, src2: new_id, }, typ, - )); + member_index: None, + })); spirv_arg.name = new_id; } None => {} @@ -1993,13 +1915,14 @@ fn insert_mem_ssa_statements<'a, 'b>( if let &[out_param] = &fn_decl.output.as_slice() { let (typ, _) = id_def.get_typed(out_param.name)?; let new_id = id_def.new_non_variable(Some(typ.clone())); - result.push(Statement::LoadVar( - ast::Arg2 { + result.push(Statement::LoadVar(LoadVarDetails { + arg: ast::Arg2 { dst: new_id, src: out_param.name, }, - typ.clone(), - )); + typ: typ.clone(), + member_index: None, + })); result.push(Statement::RetValue(d, new_id)); } else { result.push(Statement::Instruction(ast::Instruction::Ret(d))) @@ -2010,13 +1933,14 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Conditional(mut bra) => { let generated_id = id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred))); - result.push(Statement::LoadVar( - Arg2 { + result.push(Statement::LoadVar(LoadVarDetails { + arg: Arg2 { dst: generated_id, src: bra.predicate, }, - ast::Type::Scalar(ast::ScalarType::Pred), - )); + typ: ast::Type::Scalar(ast::ScalarType::Pred), + member_index: None, + })); bra.predicate = generated_id; result.push(Statement::Conditional(bra)); } @@ -2026,8 +1950,11 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::PtrAccess(ptr_access) => { insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)? } + Statement::RepackVector(repack) => { + insert_mem_ssa_statement_default(id_def, &mut result, repack)? + } s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s), - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), } } Ok(result) @@ -2059,101 +1986,156 @@ fn type_to_variable_type( scalar_type .clone() .try_into() - .map_err(|_| TranslateError::Unreachable)?, - (*space) - .try_into() - .map_err(|_| TranslateError::Unreachable)?, + .map_err(|_| error_unreachable())?, + (*space).try_into().map_err(|_| error_unreachable())?, ))) } ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None, - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), }) } -trait VisitVariable: Sized { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +trait Visitable: Sized { + fn visit( self, - f: &mut F, - ) -> Result; -} -trait VisitVariableExpanded { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result; + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, To>, TranslateError>; } -struct VisitArgumentDescriptor<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> { +struct VisitArgumentDescriptor< + 'a, + Ctor: FnOnce(spirv::Word) -> Statement, U>, + U: ArgParamsEx, +> { desc: ArgumentDescriptor, typ: &'a ast::Type, stmt_ctor: Ctor, } -impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded - for VisitArgumentDescriptor<'a, Ctor> +impl< + 'a, + Ctor: FnOnce(spirv::Word) -> Statement, U>, + T: ArgParamsEx, + U: ArgParamsEx, + > Visitable for VisitArgumentDescriptor<'a, Ctor, U> { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( + fn visit( self, - f: &mut F, - ) -> Result { - f(self.desc, Some(self.typ)).map(self.stmt_ctor) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?)) } } -fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( - id_def: &mut NumericIdResolver, - result: &mut Vec, - stmt: F, -) -> Result<(), TranslateError> { - let mut post_statements = Vec::new(); - let new_statement = stmt.visit_variable( - &mut |desc: ArgumentDescriptor, expected_type| { - if expected_type.is_none() { - return Ok(desc.op); - }; - let (var_type, is_variable) = id_def.get_typed(desc.op)?; - if !is_variable { - return Ok(desc.op); - } - let generated_id = id_def.new_non_variable(Some(var_type.clone())); - if !desc.is_dst { - result.push(Statement::LoadVar( - Arg2 { - dst: generated_id, - src: desc.op, +struct InsertMemSSAVisitor<'a, 'input> { + id_def: &'a mut NumericIdResolver<'input>, + func: &'a mut Vec, + post_statements: Vec, +} + +impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { + fn symbol( + &mut self, + desc: ArgumentDescriptor<(spirv::Word, Option)>, + expected_type: Option<&ast::Type>, + ) -> Result { + let symbol = desc.op.0; + if expected_type.is_none() { + return Ok(symbol); + }; + let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?; + if !is_variable { + return Ok(symbol); + }; + let member_index = match desc.op.1 { + Some(idx) => { + let vector_width = match var_type { + ast::Type::Vector(scalar_t, width) => { + var_type = ast::Type::Scalar(scalar_t); + width + } + _ => return Err(TranslateError::MismatchedType), + }; + Some(( + idx, + if self.id_def.special_registers.contains_key(&symbol) { + Some(vector_width) + } else { + None }, - var_type, - )); - } else { - post_statements.push(Statement::StoreVar( - Arg2St { - src1: desc.op, + )) + } + None => None, + }; + let generated_id = self.id_def.new_non_variable(Some(var_type.clone())); + if !desc.is_dst { + self.func.push(Statement::LoadVar(LoadVarDetails { + arg: Arg2 { + dst: generated_id, + src: symbol, + }, + typ: var_type, + member_index, + })); + } else { + self.post_statements + .push(Statement::StoreVar(StoreVarDetails { + arg: Arg2St { + src1: symbol, src2: generated_id, }, - var_type, - )); + typ: var_type, + member_index: member_index.map(|(idx, _)| idx), + })); + } + Ok(generated_id) + } +} + +impl<'a, 'input> ArgumentMapVisitor + for InsertMemSSAVisitor<'a, 'input> +{ + fn id( + &mut self, + desc: ArgumentDescriptor, + typ: Option<&ast::Type>, + ) -> Result { + self.symbol(desc.new_op((desc.op, None)), typ) + } + + fn operand( + &mut self, + desc: ArgumentDescriptor, + typ: &ast::Type, + ) -> Result { + Ok(match desc.op { + TypedOperand::Reg(reg) => { + TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) } - Ok(generated_id) - }, - )?; - result.push(new_statement); - result.append(&mut post_statements); + TypedOperand::RegOffset(reg, offset) => { + TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset) + } + op @ TypedOperand::Imm(..) => op, + TypedOperand::VecMember(symbol, index) => { + TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) + } + }) + } +} + +fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable>( + id_def: &'a mut NumericIdResolver<'input>, + func: &'a mut Vec, + stmt: S, +) -> Result<(), TranslateError> { + let mut visitor = InsertMemSSAVisitor { + id_def, + func, + post_statements: Vec::new(), + }; + let new_stmt = stmt.visit(&mut visitor)?; + visitor.func.push(new_stmt); + visitor.func.extend(visitor.post_statements); Ok(()) } @@ -2193,15 +2175,19 @@ fn expand_arguments<'a, 'b>( result.push(Statement::PtrAccess(new_inst)); result.extend(post_stmts); } + Statement::RepackVector(repack) => { + let mut visitor = FlattenArguments::new(&mut result, id_def); + let (new_inst, post_stmts) = (repack.map(&mut visitor)?, visitor.post_stmts); + result.push(Statement::RepackVector(new_inst)); + result.extend(post_stmts); + } Statement::Label(id) => result.push(Statement::Label(id)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), - Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), - Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)), + Statement::LoadVar(details) => result.push(Statement::LoadVar(details)), + Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), - Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => { - return Err(TranslateError::Unreachable) - } + Statement::Constant(_) => return Err(error_unreachable()), } } Ok(result) @@ -2225,27 +2211,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } } - fn insert_composite_read( - func: &mut Vec, - id_def: &mut MutableNumericIdResolver<'a>, - typ: (ast::ScalarType, u8), - scalar_dst: Option, - scalar_sema_override: Option, - composite_src: (spirv::Word, u8), - ) -> spirv::Word { - let new_id = - scalar_dst.unwrap_or_else(|| id_def.new_non_variable(ast::Type::Scalar(typ.0))); - func.push(Statement::Composite(CompositeRead { - typ: typ.0, - dst: new_id, - dst_semantics_override: scalar_sema_override, - src_composite: composite_src.0, - src_index: composite_src.1 as u32, - src_len: typ.1 as u32, - })); - new_id - } - fn reg( &mut self, desc: ArgumentDescriptor, @@ -2367,69 +2332,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { })); Ok(id) } - - fn member_src( - &mut self, - desc: ArgumentDescriptor<(spirv::Word, u8)>, - typ: (ast::ScalarType, u8), - ) -> Result { - if desc.is_dst { - return Err(TranslateError::Unreachable); - } - let new_id = Self::insert_composite_read( - self.func, - self.id_def, - typ, - None, - Some(desc.sema), - desc.op, - ); - Ok(new_id) - } - - fn vector( - &mut self, - desc: ArgumentDescriptor<&Vec>, - typ: &ast::Type, - ) -> Result { - let (scalar_type, vec_len) = typ.get_vector()?; - if !desc.is_dst { - let mut new_id = self.id_def.new_non_variable(typ.clone()); - self.func.push(Statement::Undef(typ.clone(), new_id)); - for (idx, id) in desc.op.iter().enumerate() { - let newer_id = self.id_def.new_non_variable(typ.clone()); - self.func.push(Statement::Instruction(ast::Instruction::Mov( - ast::MovDetails { - typ: ast::Type::Scalar(scalar_type), - src_is_address: false, - dst_width: vec_len, - src_width: 0, - relaxed_src2_conv: desc.sema == ArgumentSemantics::DefaultRelaxed, - }, - ast::Arg2Mov::Member(ast::Arg2MovMember::Dst( - (newer_id, idx as u8), - new_id, - *id, - )), - ))); - new_id = newer_id; - } - Ok(new_id) - } else { - let new_id = self.id_def.new_non_variable(typ.clone()); - for (idx, id) in desc.op.iter().enumerate() { - Self::insert_composite_read( - &mut self.post_stmts, - self.id_def, - (scalar_type, vec_len), - Some(*id), - Some(desc.sema), - (new_id, idx as u8), - ); - } - Ok(new_id) - } - } } impl<'a, 'b> ArgumentMapVisitor for FlattenArguments<'a, 'b> { @@ -2443,58 +2345,16 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr fn operand( &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result { match desc.op { - ast::Operand::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - ast::Operand::Imm(x) => self.immediate(desc.new_op(x), typ), - ast::Operand::RegOffset(reg, offset) => { + TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)), + TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ), + TypedOperand::RegOffset(reg, offset) => { self.reg_offset(desc.new_op((reg, offset)), typ) } - } - } - - fn src_call_operand( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result { - match desc.op { - ast::CallOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)), - ast::CallOperand::Imm(x) => self.immediate(desc.new_op(x), typ), - } - } - - fn src_member_operand( - &mut self, - desc: ArgumentDescriptor<(spirv::Word, u8)>, - typ: (ast::ScalarType, u8), - ) -> Result { - self.member_src(desc, typ) - } - - fn id_or_vector( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result { - match desc.op { - ast::IdOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - ast::IdOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ), - } - } - - fn operand_or_vector( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result { - match desc.op { - ast::OperandOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - ast::OperandOrVector::RegOffset(r, imm) => self.reg_offset(desc.new_op((r, imm)), typ), - ast::OperandOrVector::Imm(imm) => self.immediate(desc.new_op(imm), typ), - ast::OperandOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ), + TypedOperand::VecMember(..) => Err(error_unreachable()), } } } @@ -2543,7 +2403,7 @@ fn insert_implicit_conversions( if let ast::Instruction::AtomCas(d, _) = &inst { state_space = Some(d.space.to_ld_ss()); } - if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst { + if let ast::Instruction::Mov(..) = &inst { default_conversion_fn = should_bitcast_packed; } insert_implicit_conversions_impl( @@ -2554,13 +2414,6 @@ fn insert_implicit_conversions( state_space, )?; } - Statement::Composite(composite) => insert_implicit_conversions_impl( - &mut result, - id_def, - composite, - should_bitcast_wrapper, - None, - )?, Statement::PtrAccess(PtrAccess { underlying_type, state_space, @@ -2593,14 +2446,20 @@ fn insert_implicit_conversions( Some(state_space), )?; } + Statement::RepackVector(repack) => insert_implicit_conversions_impl( + &mut result, + id_def, + repack, + should_bitcast_wrapper, + None, + )?, s @ Statement::Conditional(_) | s @ Statement::Conversion(_) | s @ Statement::Label(_) | s @ Statement::Constant(_) | s @ Statement::Variable(_) - | s @ Statement::LoadVar(_, _) - | s @ Statement::StoreVar(_, _) - | s @ Statement::Undef(_, _) + | s @ Statement::LoadVar(..) + | s @ Statement::StoreVar(..) | s @ Statement::RetValue(_, _) => result.push(s), } } @@ -2610,7 +2469,7 @@ fn insert_implicit_conversions( fn insert_implicit_conversions_impl( func: &mut Vec, id_def: &mut MutableNumericIdResolver, - stmt: impl VisitVariableExpanded, + stmt: impl Visitable, default_conversion_fn: for<'a> fn( &'a ast::Type, &'a ast::Type, @@ -2619,62 +2478,64 @@ fn insert_implicit_conversions_impl( state_space: Option, ) -> Result<(), TranslateError> { let mut post_conv = Vec::new(); - let statement = stmt.visit_variable_extended(&mut |desc, typ| { - let instr_type = match typ { - None => return Ok(desc.op), - Some(t) => t, - }; - let operand_type = id_def.get_typed(desc.op)?; - let mut conversion_fn = default_conversion_fn; - match desc.sema { - ArgumentSemantics::Default => {} - ArgumentSemantics::DefaultRelaxed => { - if desc.is_dst { - conversion_fn = should_convert_relaxed_dst_wrapper; - } else { - conversion_fn = should_convert_relaxed_src_wrapper; + let statement = stmt.visit( + &mut |desc: ArgumentDescriptor, typ: Option<&ast::Type>| { + let instr_type = match typ { + None => return Ok(desc.op), + Some(t) => t, + }; + let operand_type = id_def.get_typed(desc.op)?; + let mut conversion_fn = default_conversion_fn; + match desc.sema { + ArgumentSemantics::Default => {} + ArgumentSemantics::DefaultRelaxed => { + if desc.is_dst { + conversion_fn = should_convert_relaxed_dst_wrapper; + } else { + conversion_fn = should_convert_relaxed_src_wrapper; + } } - } - ArgumentSemantics::PhysicalPointer => { - conversion_fn = bitcast_physical_pointer; - } - ArgumentSemantics::RegisterPointer => { - conversion_fn = bitcast_register_pointer; - } - ArgumentSemantics::Address => { - conversion_fn = force_bitcast_ptr_to_bit; - } - }; - match conversion_fn(&operand_type, instr_type, state_space)? { - Some(conv_kind) => { - let conv_output = if desc.is_dst { - &mut post_conv - } else { - &mut *func - }; - let mut from = instr_type.clone(); - let mut to = operand_type; - let mut src = id_def.new_non_variable(instr_type.clone()); - let mut dst = desc.op; - let result = Ok(src); - if !desc.is_dst { - mem::swap(&mut src, &mut dst); - mem::swap(&mut from, &mut to); + ArgumentSemantics::PhysicalPointer => { + conversion_fn = bitcast_physical_pointer; } - conv_output.push(Statement::Conversion(ImplicitConversion { - src, - dst, - from, - to, - kind: conv_kind, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, - })); - result + ArgumentSemantics::RegisterPointer => { + conversion_fn = bitcast_register_pointer; + } + ArgumentSemantics::Address => { + conversion_fn = force_bitcast_ptr_to_bit; + } + }; + match conversion_fn(&operand_type, instr_type, state_space)? { + Some(conv_kind) => { + let conv_output = if desc.is_dst { + &mut post_conv + } else { + &mut *func + }; + let mut from = instr_type.clone(); + let mut to = operand_type; + let mut src = id_def.new_non_variable(instr_type.clone()); + let mut dst = desc.op; + let result = Ok(src); + if !desc.is_dst { + mem::swap(&mut src, &mut dst); + mem::swap(&mut from, &mut to); + } + conv_output.push(Statement::Conversion(ImplicitConversion { + src, + dst, + from, + to, + kind: conv_kind, + src_sema: ArgumentSemantics::Default, + dst_sema: ArgumentSemantics::Default, + })); + result + } + None => Ok(desc.op), } - None => Ok(desc.op), - } - })?; + }, + )?; func.push(statement); func.append(&mut post_conv); Ok(()) @@ -2861,38 +2722,11 @@ fn emit_function_body_ops( } // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, - ast::Instruction::Mov(d, arg) => match arg { - ast::Arg2Mov::Normal(ast::Arg2MovNormal { dst, src }) - | ast::Arg2Mov::Member(ast::Arg2MovMember::Src(dst, src)) => { - let result_type = map - .get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone()))); - builder.copy_object(result_type, Some(*dst), *src)?; - } - ast::Arg2Mov::Member(ast::Arg2MovMember::Dst( - dst, - composite_src, - scalar_src, - )) - | ast::Arg2Mov::Member(ast::Arg2MovMember::Both( - dst, - composite_src, - scalar_src, - )) => { - let scalar_type = d.typ.get_scalar()?; - let result_type = map.get_or_add( - builder, - SpirvType::from(ast::Type::Vector(scalar_type, d.dst_width)), - ); - let result_id = Some(dst.0); - builder.composite_insert( - result_type, - result_id, - *scalar_src, - *composite_src, - [dst.1 as u32], - )?; - } - }, + ast::Instruction::Mov(d, arg) => { + let result_type = + map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone()))); + builder.copy_object(result_type, Some(arg.dst), arg.src)?; + } ast::Instruction::Mul(mul, arg) => match mul { ast::MulDetails::Signed(ref ctr) => { emit_mul_sint(builder, map, opencl, ctr, arg)? @@ -3202,30 +3036,38 @@ fn emit_function_body_ops( builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; } }, - Statement::LoadVar(arg, typ) => { - let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); - builder.load(type_id, Some(arg.dst), arg.src, None, [])?; + Statement::LoadVar(details) => { + emit_load_var(builder, map, details)?; } - Statement::StoreVar(arg, _) => { - builder.store(arg.src1, arg.src2, None, [])?; + Statement::StoreVar(details) => { + let dst_ptr = match details.member_index { + Some(index) => { + let result_ptr_type = map.get_or_add( + builder, + SpirvType::new_pointer( + details.typ.clone(), + spirv::StorageClass::Function, + ), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + builder.in_bounds_access_chain( + result_ptr_type, + None, + details.arg.src1, + &[index_spirv], + )? + } + None => details.arg.src1, + }; + builder.store(dst_ptr, details.arg.src2, None, [])?; } Statement::RetValue(_, id) => { builder.ret_value(*id)?; } - Statement::Composite(c) => { - let result_type = map.get_or_add_scalar(builder, c.typ.into()); - let result_id = Some(c.dst); - builder.composite_extract( - result_type, - result_id, - c.src_composite, - [c.src_index], - )?; - } - Statement::Undef(t, id) => { - let result_type = map.get_or_add(builder, SpirvType::from(t.clone())); - builder.undef(result_type, Some(*id)); - } Statement::PtrAccess(PtrAccess { underlying_type, state_space, @@ -3254,6 +3096,38 @@ fn emit_function_body_ops( )?; builder.bitcast(result_type, Some(*dst), temp)?; } + Statement::RepackVector(repack) => { + if repack.is_extract { + let scalar_type = map.get_or_add_scalar(builder, repack.typ); + for (index, dst_id) in repack.unpacked.iter().enumerate() { + builder.composite_extract( + scalar_type, + Some(*dst_id), + repack.packed, + &[index as u32], + )?; + } + } else { + let vector_type = map.get_or_add( + builder, + SpirvType::Vector( + SpirvScalarKey::from(repack.typ), + repack.unpacked.len() as u8, + ), + ); + let mut temp_vec = builder.undef(vector_type, None); + for (index, src_id) in repack.unpacked.iter().enumerate() { + temp_vec = builder.composite_insert( + vector_type, + None, + *src_id, + temp_vec, + &[index as u32], + )?; + } + builder.copy_object(vector_type, Some(repack.packed), temp_vec)?; + } + } } } Ok(()) @@ -3271,7 +3145,7 @@ fn insert_shift_hack( 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), 4 => return Ok(offset_var), - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), }; Ok(builder.u_convert(result_type, None, offset_var)?) } @@ -3351,7 +3225,7 @@ fn emit_atom( let spirv_op = match op { ast::AtomUIntOp::Add => dr::Builder::atomic_i_add, ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => { - return Err(TranslateError::Unreachable); + return Err(error_unreachable()); } ast::AtomUIntOp::Min => dr::Builder::atomic_u_min, ast::AtomUIntOp::Max => dr::Builder::atomic_u_max, @@ -4165,6 +4039,58 @@ fn emit_implicit_conversion( Ok(()) } +fn emit_load_var( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + details: &LoadVarDetails, +) -> Result<(), TranslateError> { + let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone())); + match details.member_index { + Some((index, Some(width))) => { + let vector_type = match details.typ { + ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), + _ => return Err(TranslateError::MismatchedType), + }; + let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type)); + let vector_temp = builder.load(vector_type_spirv, None, details.arg.src, None, [])?; + builder.composite_extract( + result_type, + Some(details.arg.dst), + vector_temp, + &[index as u32], + )?; + } + Some((index, None)) => { + let result_ptr_type = map.get_or_add( + builder, + SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + let src = builder.in_bounds_access_chain( + result_ptr_type, + None, + details.arg.src, + &[index_spirv], + )?; + builder.load(result_type, Some(details.arg.dst), src, None, [])?; + } + None => { + builder.load( + result_type, + Some(details.arg.dst), + details.arg.src, + None, + [], + )?; + } + }; + Ok(()) +} + fn normalize_identifiers<'a, 'b>( id_defs: &mut FnStringIdResolver<'a, 'b>, fn_defs: &GlobalFnDeclResolver<'a, 'b>, @@ -4290,9 +4216,11 @@ fn convert_to_stateful_memory_access<'a>( }, arg, )) => { - if let Some(src) = arg.src.underlying() { - if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, arg.dst) { - stateful_markers.push((arg.dst, *src)); + if let (TypedOperand::Reg(dst), Some(src)) = + (arg.dst, arg.src.upcast().underlying()) + { + if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, dst) { + stateful_markers.push((dst, *src)); } } } @@ -4320,7 +4248,9 @@ fn convert_to_stateful_memory_access<'a>( }, arg, )) => { - if let (ast::IdOrVector::Reg(dst), Some(src)) = (&arg.dst, arg.src.underlying()) { + if let (TypedOperand::Reg(dst), Some(src)) = + (&arg.dst, arg.src.upcast().underlying()) + { if func_args_64bit.contains(src) { multi_hash_map_append(&mut stateful_init_reg, *dst, *src); } @@ -4369,13 +4299,17 @@ fn convert_to_stateful_memory_access<'a>( }), arg, )) => { - if let Some(src1) = arg.src1.underlying() { + if let (TypedOperand::Reg(dst), Some(src1)) = + (arg.dst, arg.src1.upcast().underlying()) + { if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) { - regs_ptr_new.insert(arg.dst); + regs_ptr_new.insert(dst); } - } else if let Some(src2) = arg.src2.underlying() { + } else if let (TypedOperand::Reg(dst), Some(src2)) = + (arg.dst, arg.src2.upcast().underlying()) + { if regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) { - regs_ptr_new.insert(arg.dst); + regs_ptr_new.insert(dst); } } } @@ -4426,19 +4360,20 @@ fn convert_to_stateful_memory_access<'a>( }), arg, )) if is_add_ptr_direct(&remapped_ids, &arg) => { - let (ptr, offset) = match arg.src1.underlying() { + let (ptr, offset) = match arg.src1.upcast().underlying() { Some(src1) if remapped_ids.contains_key(src1) => { (remapped_ids.get(src1).unwrap(), arg.src2) } Some(src2) if remapped_ids.contains_key(src2) => { (remapped_ids.get(src2).unwrap(), arg.src1) } - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), }; + let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), state_space: ast::LdStateSpace::Global, - dst: *remapped_ids.get(&arg.dst).unwrap(), + dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, offset_src: offset, })) @@ -4454,14 +4389,14 @@ fn convert_to_stateful_memory_access<'a>( }), arg, )) if is_add_ptr_direct(&remapped_ids, &arg) => { - let (ptr, offset) = match arg.src1.underlying() { + let (ptr, offset) = match arg.src1.upcast().underlying() { Some(src1) if remapped_ids.contains_key(src1) => { (remapped_ids.get(src1).unwrap(), arg.src2) } Some(src2) if remapped_ids.contains_key(src2) => { (remapped_ids.get(src2).unwrap(), arg.src1) } - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), }; let offset_neg = id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64))); @@ -4472,21 +4407,23 @@ fn convert_to_stateful_memory_access<'a>( }, ast::Arg2 { src: offset, - dst: offset_neg, + dst: TypedOperand::Reg(offset_neg), }, ))); + let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), state_space: ast::LdStateSpace::Global, - dst: *remapped_ids.get(&arg.dst).unwrap(), + dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, - offset_src: ast::Operand::Reg(offset_neg), + offset_src: TypedOperand::Reg(offset_neg), })) } Statement::Instruction(inst) => { let mut post_statements = Vec::new(); - let new_statement = inst.visit_variable( - &mut |arg_desc: ArgumentDescriptor, expected_type| { + let new_statement = inst.visit( + &mut |arg_desc: ArgumentDescriptor, + expected_type: Option<&ast::Type>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4499,14 +4436,13 @@ fn convert_to_stateful_memory_access<'a>( }, )?; result.push(new_statement); - for s in post_statements { - result.push(s); - } + result.extend(post_statements); } Statement::Call(call) => { let mut post_statements = Vec::new(); - let new_statement = call.visit_variable( - &mut |arg_desc: ArgumentDescriptor, expected_type| { + let new_statement = call.visit( + &mut |arg_desc: ArgumentDescriptor, + expected_type: Option<&ast::Type>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4519,11 +4455,28 @@ fn convert_to_stateful_memory_access<'a>( }, )?; result.push(new_statement); - for s in post_statements { - result.push(s); - } + result.extend(post_statements); } - _ => return Err(TranslateError::Unreachable), + Statement::RepackVector(pack) => { + let mut post_statements = Vec::new(); + let new_statement = pack.visit( + &mut |arg_desc: ArgumentDescriptor, + expected_type: Option<&ast::Type>| { + convert_to_stateful_memory_access_postprocess( + id_defs, + &remapped_ids, + &func_args_ptr, + &mut result, + &mut post_statements, + arg_desc, + expected_type, + ) + }, + )?; + result.push(new_statement); + result.extend(post_statements); + } + _ => return Err(error_unreachable()), } } for arg in func_args.input.iter_mut() { @@ -4588,7 +4541,7 @@ fn convert_to_stateful_memory_access_postprocess( None => match func_args_ptr.get(&arg_desc.op) { Some(new_id) => { if arg_desc.is_dst { - return Err(TranslateError::Unreachable); + return Err(error_unreachable()); } // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { @@ -4617,13 +4570,20 @@ fn convert_to_stateful_memory_access_postprocess( } fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { - if !remapped_ids.contains_key(&arg.dst) { - return false; - } - match arg.src1.underlying() { - Some(src1) if remapped_ids.contains_key(src1) => true, - Some(src2) if remapped_ids.contains_key(src2) => true, - _ => false, + match arg.dst { + TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { + return false + } + TypedOperand::Reg(dst) => { + if !remapped_ids.contains_key(&dst) { + return false; + } + match arg.src1.upcast().underlying() { + Some(src1) if remapped_ids.contains_key(src1) => true, + Some(src2) if remapped_ids.contains_key(src2) => true, + _ => false, + } + } } } @@ -4962,14 +4922,13 @@ enum Statement { // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), Call(ResolvedCall

), - LoadVar(ast::Arg2, ast::Type), - StoreVar(ast::Arg2St, ast::Type), - Composite(CompositeRead), + LoadVar(LoadVarDetails), + StoreVar(StoreVarDetails), Conversion(ImplicitConversion), Constant(ConstantDefinition), RetValue(ast::RetData, spirv::Word), - Undef(ast::Type, spirv::Word), PtrAccess(PtrAccess

), + RepackVector(RepackVectorDetails), } impl ExpandedStatement { @@ -4981,19 +4940,19 @@ impl ExpandedStatement { Statement::Variable(var) } Statement::Instruction(inst) => inst - .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| { + .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| { Ok(f(arg.op, arg.is_dst)) }) .unwrap(), - Statement::LoadVar(mut arg, typ) => { - arg.dst = f(arg.dst, true); - arg.src = f(arg.src, false); - Statement::LoadVar(arg, typ) + Statement::LoadVar(mut details) => { + details.arg.dst = f(details.arg.dst, true); + details.arg.src = f(details.arg.src, false); + Statement::LoadVar(details) } - Statement::StoreVar(mut arg, typ) => { - arg.src1 = f(arg.src1, false); - arg.src2 = f(arg.src2, false); - Statement::StoreVar(arg, typ) + Statement::StoreVar(mut details) => { + details.arg.src1 = f(details.arg.src1, false); + details.arg.src2 = f(details.arg.src2, false); + Statement::StoreVar(details) } Statement::Call(mut call) => { for (id, typ) in call.ret_params.iter_mut() { @@ -5010,11 +4969,6 @@ impl ExpandedStatement { } Statement::Call(call) } - Statement::Composite(mut composite) => { - composite.dst = f(composite.dst, true); - composite.src_composite = f(composite.src_composite, false); - Statement::Composite(composite) - } Statement::Conditional(mut conditional) => { conditional.predicate = f(conditional.predicate, false); conditional.if_true = f(conditional.if_true, false); @@ -5034,10 +4988,6 @@ impl ExpandedStatement { let id = f(id, false); Statement::RetValue(data, id) } - Statement::Undef(typ, id) => { - let id = f(id, true); - Statement::Undef(typ, id) - } Statement::PtrAccess(PtrAccess { underlying_type, state_space, @@ -5056,19 +5006,100 @@ impl ExpandedStatement { offset_src: constant_src, }) } + Statement::RepackVector(_) => todo!(), } } } +struct LoadVarDetails { + arg: ast::Arg2, + typ: ast::Type, + // (index, vector_width) + // HACK ALERT + // For some reason IGC explodes when you try to load from builtin vectors + // using OpInBoundsAccessChain, the one true way to do it is to + // OpLoad+OpCompositeExtract + member_index: Option<(u8, Option)>, +} + +struct StoreVarDetails { + arg: ast::Arg2St, + typ: ast::Type, + member_index: Option, +} + +struct RepackVectorDetails { + is_extract: bool, + typ: ast::ScalarType, + packed: spirv::Word, + unpacked: Vec, + vector_sema: ArgumentSemantics, +} + +impl RepackVectorDetails { + fn map< + From: ArgParamsEx, + To: ArgParamsEx, + V: ArgumentMapVisitor, + >( + self, + visitor: &mut V, + ) -> Result { + let scalar = visitor.id( + ArgumentDescriptor { + op: self.packed, + is_dst: !self.is_extract, + sema: ArgumentSemantics::Default, + }, + Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)), + )?; + let scalar_type = self.typ; + let is_extract = self.is_extract; + let vector_sema = self.vector_sema; + let vector = self + .unpacked + .into_iter() + .map(|id| { + visitor.id( + ArgumentDescriptor { + op: id, + is_dst: is_extract, + sema: vector_sema, + }, + Some(&ast::Type::Scalar(scalar_type)), + ) + }) + .collect::>()?; + Ok(RepackVectorDetails { + is_extract, + typ: self.typ, + packed: scalar, + unpacked: vector, + vector_sema, + }) + } +} + +impl, U: ArgParamsEx> Visitable + for RepackVectorDetails +{ + fn visit( + self, + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::RepackVector(self.map::<_, _, _>(visitor)?)) + } +} + struct ResolvedCall { pub uniform: bool, - pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>, - pub func: spirv::Word, - pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>, + pub ret_params: Vec<(P::Id, ast::FnArgumentType)>, + pub func: P::Id, + pub param_list: Vec<(P::Operand, ast::FnArgumentType)>, } impl ResolvedCall { - fn cast>(self) -> ResolvedCall { + fn cast>(self) -> ResolvedCall { ResolvedCall { uniform: self.uniform, ret_params: self.ret_params, @@ -5110,7 +5141,7 @@ impl> ResolvedCall { .param_list .into_iter() .map::, _>(|(id, typ)| { - let new_id = visitor.src_call_operand( + let new_id = visitor.operand( ArgumentDescriptor { op: id, is_dst: false, @@ -5130,32 +5161,14 @@ impl> ResolvedCall { } } -impl VisitVariable for ResolvedCall { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl, U: ArgParamsEx> Visitable + for ResolvedCall +{ + fn visit( self, - f: &mut F, - ) -> Result { - Ok(Statement::Call(self.map(f)?)) - } -} - -impl VisitVariableExpanded for ResolvedCall { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - Ok(Statement::Call(self.map(f)?)) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::Call(self.map(visitor)?)) } } @@ -5208,18 +5221,14 @@ impl> PtrAccess

{ } } -impl VisitVariable for PtrAccess { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl, U: ArgParamsEx> Visitable + for PtrAccess +{ + fn visit( self, - f: &mut F, - ) -> Result { - Ok(Statement::PtrAccess(self.map(f)?)) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::PtrAccess(self.map(visitor)?)) } } @@ -5244,10 +5253,6 @@ enum NormalizedArgParams {} impl ast::ArgParams for NormalizedArgParams { type Id = spirv::Word; type Operand = ast::Operand; - type CallOperand = ast::CallOperand; - type IdOrVector = ast::IdOrVector; - type OperandOrVector = ast::OperandOrVector; - type SrcMemberOperand = (spirv::Word, u8); } impl ArgParamsEx for NormalizedArgParams { @@ -5273,11 +5278,7 @@ enum TypedArgParams {} impl ast::ArgParams for TypedArgParams { type Id = spirv::Word; - type Operand = ast::Operand; - type CallOperand = ast::CallOperand; - type IdOrVector = ast::IdOrVector; - type OperandOrVector = ast::OperandOrVector; - type SrcMemberOperand = (spirv::Word, u8); + type Operand = TypedOperand; } impl ArgParamsEx for TypedArgParams { @@ -5289,6 +5290,25 @@ impl ArgParamsEx for TypedArgParams { } } +#[derive(Copy, Clone)] +enum TypedOperand { + Reg(spirv::Word), + RegOffset(spirv::Word, i32), + Imm(ast::ImmediateValue), + VecMember(spirv::Word, u8), +} + +impl TypedOperand { + fn upcast(self) -> ast::Operand { + match self { + TypedOperand::Reg(reg) => ast::Operand::Reg(reg), + TypedOperand::RegOffset(reg, idx) => ast::Operand::RegOffset(reg, idx), + TypedOperand::Imm(x) => ast::Operand::Imm(x), + TypedOperand::VecMember(vec, idx) => ast::Operand::VecMember(vec, idx), + } + } +} + type TypedStatement = Statement, TypedArgParams>; enum ExpandedArgParams {} @@ -5297,10 +5317,6 @@ type ExpandedStatement = Statement, Expanded impl ast::ArgParams for ExpandedArgParams { type Id = spirv::Word; type Operand = spirv::Word; - type CallOperand = spirv::Word; - type IdOrVector = spirv::Word; - type OperandOrVector = spirv::Word; - type SrcMemberOperand = spirv::Word; } impl ArgParamsEx for ExpandedArgParams { @@ -5312,29 +5328,6 @@ impl ArgParamsEx for ExpandedArgParams { } } -#[derive(Copy, Clone)] -pub enum StateSpace { - Reg, - Const, - Global, - Local, - Shared, - Param, -} - -impl From for StateSpace { - fn from(ss: ast::StateSpace) -> Self { - match ss { - ast::StateSpace::Reg => StateSpace::Reg, - ast::StateSpace::Const => StateSpace::Const, - ast::StateSpace::Global => StateSpace::Global, - ast::StateSpace::Local => StateSpace::Local, - ast::StateSpace::Shared => StateSpace::Shared, - ast::StateSpace::Param => StateSpace::Param, - } - } -} - enum Directive<'input> { Variable(ast::Variable), Method(Function<'input>), @@ -5359,26 +5352,6 @@ pub trait ArgumentMapVisitor { desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result; - fn id_or_vector( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result; - fn operand_or_vector( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result; - fn src_call_operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result; - fn src_member_operand( - &mut self, - desc: ArgumentDescriptor, - typ: (ast::ScalarType, u8), - ) -> Result; } impl ArgumentMapVisitor for T @@ -5397,44 +5370,12 @@ where } fn operand( - &mut self, - desc: ArgumentDescriptor, - t: &ast::Type, - ) -> Result { - self(desc, Some(t)) - } - - fn id_or_vector( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result { self(desc, Some(typ)) } - - fn operand_or_vector( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result { - self(desc, Some(typ)) - } - - fn src_call_operand( - &mut self, - desc: ArgumentDescriptor, - t: &ast::Type, - ) -> Result { - self(desc, Some(t)) - } - - fn src_member_operand( - &mut self, - desc: ArgumentDescriptor, - (scalar_type, _): (ast::ScalarType, u8), - ) -> Result { - self(desc.new_op(desc.op), Some(&ast::Type::Scalar(scalar_type))) - } } impl<'a, T> ArgumentMapVisitor, NormalizedArgParams> for T @@ -5452,62 +5393,19 @@ where fn operand( &mut self, desc: ArgumentDescriptor>, - _: &ast::Type, + typ: &ast::Type, ) -> Result, TranslateError> { - match desc.op { - ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)), - ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(self(id)?, imm)), - ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)), - } - } - - fn id_or_vector( - &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(id)?)), - ast::IdOrVector::Vec(ids) => Ok(ast::IdOrVector::Vec( - ids.into_iter().map(self).collect::>()?, - )), - } - } - - fn operand_or_vector( - &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::OperandOrVector::Reg(id) => Ok(ast::OperandOrVector::Reg(self(id)?)), - ast::OperandOrVector::RegOffset(id, imm) => { - Ok(ast::OperandOrVector::RegOffset(self(id)?, imm)) - } - ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)), - ast::OperandOrVector::Vec(ids) => Ok(ast::OperandOrVector::Vec( - ids.into_iter().map(self).collect::>()?, - )), - } - } - - fn src_call_operand( - &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(id)?)), - ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)), - } - } - - fn src_member_operand( - &mut self, - desc: ArgumentDescriptor<(&str, u8)>, - _: (ast::ScalarType, u8), - ) -> Result<(spirv::Word, u8), TranslateError> { - Ok((self(desc.op.0)?, desc.op.1)) + Ok(match desc.op { + ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?), + ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id)?, imm), + ast::Operand::Imm(imm) => ast::Operand::Imm(imm), + ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member), + ast::Operand::VecPack(ref ids) => ast::Operand::VecPack( + ids.into_iter() + .map(|id| self.id(desc.new_op(id), Some(typ))) + .collect::, _>>()?, + ), + }) } } @@ -5559,7 +5457,7 @@ impl ast::Instruction { ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?) } // Call instruction is converted to a call statement early on - ast::Instruction::Call(_) => return Err(TranslateError::Unreachable), + ast::Instruction::Call(_) => return Err(error_unreachable()), ast::Instruction::Ld(d, a) => { let new_args = a.map(visitor, &d)?; ast::Instruction::Ld(d, new_args) @@ -5752,18 +5650,12 @@ impl ast::Instruction { } } -impl VisitVariable for ast::Instruction { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl Visitable for ast::Instruction { + fn visit( self, - f: &mut F, - ) -> Result { - Ok(Statement::Instruction(self.map(f)?)) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::Instruction(self.map(visitor)?)) } } @@ -5802,32 +5694,14 @@ impl ImplicitConversion { } } -impl VisitVariable for ImplicitConversion { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl, To: ArgParamsEx> Visitable + for ImplicitConversion +{ + fn visit( self, - f: &mut F, - ) -> Result { - self.map(f) - } -} - -impl VisitVariableExpanded for ImplicitConversion { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - self.map(f) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, To>, TranslateError> { + Ok(self.map(visitor)?) } } @@ -5848,79 +5722,24 @@ where fn operand( &mut self, - desc: ArgumentDescriptor>, - t: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(desc.new_op(id), Some(t))?)), - ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)), - ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset( - self(desc.new_op(id), Some(t))?, - imm, - )), - } - } - - fn src_call_operand( - &mut self, - desc: ArgumentDescriptor>, - t: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(desc.new_op(id), Some(t))?)), - ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)), - } - } - - fn id_or_vector( - &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor, typ: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(desc.new_op(id), Some(typ))?)), - ast::IdOrVector::Vec(ref ids) => Ok(ast::IdOrVector::Vec( - ids.iter() - .map(|id| self(desc.new_op(*id), Some(typ))) - .collect::>()?, - )), - } - } - - fn operand_or_vector( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::OperandOrVector::Reg(id) => { - Ok(ast::OperandOrVector::Reg(self(desc.new_op(id), Some(typ))?)) + ) -> Result { + Ok(match desc.op { + TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?), + TypedOperand::Imm(imm) => TypedOperand::Imm(imm), + TypedOperand::RegOffset(id, imm) => { + TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) } - ast::OperandOrVector::RegOffset(id, imm) => Ok(ast::OperandOrVector::RegOffset( - self(desc.new_op(id), Some(typ))?, - imm, - )), - ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)), - ast::OperandOrVector::Vec(ref ids) => Ok(ast::OperandOrVector::Vec( - ids.iter() - .map(|id| self(desc.new_op(*id), Some(typ))) - .collect::>()?, - )), - } - } - - fn src_member_operand( - &mut self, - desc: ArgumentDescriptor<(spirv::Word, u8)>, - (scalar_type, vector_len): (ast::ScalarType, u8), - ) -> Result<(spirv::Word, u8), TranslateError> { - Ok(( - self( - desc.new_op(desc.op.0), - Some(&ast::Type::Vector(scalar_type.into(), vector_len)), - )?, - desc.op.1, - )) + TypedOperand::VecMember(reg, index) => { + let scalar_type = match typ { + ast::Type::Scalar(scalar_t) => *scalar_t, + _ => return Err(error_unreachable()), + }; + let vec_type = ast::Type::Vector(scalar_type, index + 1); + TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) + } + }) } } @@ -5942,7 +5761,7 @@ impl ast::Type { kind, ))) } - _ => Err(TranslateError::Unreachable), + _ => Err(error_unreachable()), } } @@ -6182,67 +6001,9 @@ impl ast::Instruction { } } -impl VisitVariableExpanded for ast::Instruction { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - Ok(Statement::Instruction(self.map(f)?)) - } -} - type Arg2 = ast::Arg2; type Arg2St = ast::Arg2St; -struct CompositeRead { - pub typ: ast::ScalarType, - pub dst: spirv::Word, - pub dst_semantics_override: Option, - pub src_composite: spirv::Word, - pub src_index: u32, - pub src_len: u32, -} - -impl VisitVariableExpanded for CompositeRead { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - let dst_sema = self - .dst_semantics_override - .unwrap_or(ArgumentSemantics::Default); - Ok(Statement::Composite(CompositeRead { - dst: f( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - sema: dst_sema, - }, - Some(&ast::Type::Scalar(self.typ)), - )?, - src_composite: f( - ArgumentDescriptor { - op: self.src_composite, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(self.typ, self.src_len as u8)), - )?, - ..self - })) - } -} - struct ConstantDefinition { pub dst: spirv::Word, pub typ: ast::ScalarType, @@ -6330,10 +6091,6 @@ impl From for ast::Type { } impl ast::Arg1 { - fn cast>(self) -> ast::Arg1 { - ast::Arg1 { src: self.src } - } - fn map>( self, visitor: &mut V, @@ -6352,10 +6109,6 @@ impl ast::Arg1 { } impl ast::Arg1Bar { - fn cast>(self) -> ast::Arg1Bar { - ast::Arg1Bar { src: self.src } - } - fn map>( self, visitor: &mut V, @@ -6373,25 +6126,18 @@ impl ast::Arg1Bar { } impl ast::Arg2 { - fn cast>(self) -> ast::Arg2 { - ast::Arg2 { - src: self.src, - dst: self.dst, - } - } - fn map>( self, visitor: &mut V, t: &ast::Type, ) -> Result, TranslateError> { - let new_dst = visitor.id( + let new_dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(t), + t, )?; let new_src = visitor.operand( ArgumentDescriptor { @@ -6413,13 +6159,13 @@ impl ast::Arg2 { dst_t: &ast::Type, src_t: &ast::Type, ) -> Result, TranslateError> { - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(dst_t), + dst_t, )?; let src = visitor.operand( ArgumentDescriptor { @@ -6434,21 +6180,12 @@ impl ast::Arg2 { } impl ast::Arg2Ld { - fn cast>( - self, - ) -> ast::Arg2Ld { - ast::Arg2Ld { - dst: self.dst, - src: self.src, - } - } - fn map>( self, visitor: &mut V, details: &ast::LdDetails, ) -> Result, TranslateError> { - let dst = visitor.id_or_vector( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6478,15 +6215,6 @@ impl ast::Arg2Ld { } impl ast::Arg2St { - fn cast>( - self, - ) -> ast::Arg2St { - ast::Arg2St { - src1: self.src1, - src2: self.src2, - } - } - fn map>( self, visitor: &mut V, @@ -6509,7 +6237,7 @@ impl ast::Arg2St { details.state_space.to_ld_ss(), ), )?; - let src2 = visitor.operand_or_vector( + let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6527,29 +6255,7 @@ impl ast::Arg2Mov { visitor: &mut V, details: &ast::MovDetails, ) -> Result, TranslateError> { - Ok(match self { - ast::Arg2Mov::Normal(arg) => ast::Arg2Mov::Normal(arg.map(visitor, details)?), - ast::Arg2Mov::Member(arg) => ast::Arg2Mov::Member(arg.map(visitor, details)?), - }) - } -} - -impl ast::Arg2MovNormal

{ - fn cast>( - self, - ) -> ast::Arg2MovNormal { - ast::Arg2MovNormal { - dst: self.dst, - src: self.src, - } - } - - fn map>( - self, - visitor: &mut V, - details: &ast::MovDetails, - ) -> Result, TranslateError> { - let dst = visitor.id_or_vector( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6557,7 +6263,7 @@ impl ast::Arg2MovNormal

{ }, &details.typ.clone().into(), )?; - let src = visitor.operand_or_vector( + let src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, @@ -6569,144 +6275,11 @@ impl ast::Arg2MovNormal

{ }, &details.typ.clone().into(), )?; - Ok(ast::Arg2MovNormal { dst, src }) - } -} - -impl ast::Arg2MovMember { - fn cast>( - self, - ) -> ast::Arg2MovMember { - match self { - ast::Arg2MovMember::Dst(dst, src1, src2) => ast::Arg2MovMember::Dst(dst, src1, src2), - ast::Arg2MovMember::Src(dst, src) => ast::Arg2MovMember::Src(dst, src), - ast::Arg2MovMember::Both(dst, src1, src2) => ast::Arg2MovMember::Both(dst, src1, src2), - } - } - - fn vector_dst(&self) -> Option<&T::Id> { - match self { - ast::Arg2MovMember::Src(_, _) => None, - ast::Arg2MovMember::Dst((d, _), _, _) | ast::Arg2MovMember::Both((d, _), _, _) => { - Some(d) - } - } - } - - fn vector_src(&self) -> Option<&T::SrcMemberOperand> { - match self { - ast::Arg2MovMember::Src(_, d) | ast::Arg2MovMember::Both(_, _, d) => Some(d), - ast::Arg2MovMember::Dst(_, _, _) => None, - } - } -} - -impl ast::Arg2MovMember { - fn map>( - self, - visitor: &mut V, - details: &ast::MovDetails, - ) -> Result, TranslateError> { - match self { - ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => { - let scalar_type = details.typ.get_scalar()?; - let dst = visitor.id( - ArgumentDescriptor { - op: dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(scalar_type, details.dst_width)), - )?; - let src1 = visitor.id( - ArgumentDescriptor { - op: composite_src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(scalar_type, details.dst_width)), - )?; - let src2 = visitor.id( - ArgumentDescriptor { - op: scalar_src, - is_dst: false, - sema: if details.src_is_address { - ArgumentSemantics::Address - } else if details.relaxed_src2_conv { - ArgumentSemantics::DefaultRelaxed - } else { - ArgumentSemantics::Default - }, - }, - Some(&details.typ.clone().into()), - )?; - Ok(ast::Arg2MovMember::Dst((dst, len), src1, src2)) - } - ast::Arg2MovMember::Src(dst, src) => { - let dst = visitor.id( - ArgumentDescriptor { - op: dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(&details.typ.clone().into()), - )?; - let scalar_typ = details.typ.get_scalar()?; - let src = visitor.src_member_operand( - ArgumentDescriptor { - op: src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - (scalar_typ.into(), details.src_width), - )?; - Ok(ast::Arg2MovMember::Src(dst, src)) - } - ast::Arg2MovMember::Both((dst, len), composite_src, src) => { - let scalar_type = details.typ.get_scalar()?; - let dst = visitor.id( - ArgumentDescriptor { - op: dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(scalar_type, details.dst_width)), - )?; - let composite_src = visitor.id( - ArgumentDescriptor { - op: composite_src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(scalar_type, details.dst_width)), - )?; - let src = visitor.src_member_operand( - ArgumentDescriptor { - op: src, - is_dst: false, - sema: if details.relaxed_src2_conv { - ArgumentSemantics::DefaultRelaxed - } else { - ArgumentSemantics::Default - }, - }, - (scalar_type.into(), details.src_width), - )?; - Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src)) - } - } + Ok(ast::Arg2Mov { dst, src }) } } impl ast::Arg3 { - fn cast>(self) -> ast::Arg3 { - ast::Arg3 { - dst: self.dst, - src1: self.src1, - src2: self.src2, - } - } - fn map_non_shift>( self, visitor: &mut V, @@ -6718,13 +6291,13 @@ impl ast::Arg3 { } else { None }; - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(wide_type.as_ref().unwrap_or(typ)), + wide_type.as_ref().unwrap_or(typ), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6750,13 +6323,13 @@ impl ast::Arg3 { visitor: &mut V, t: &ast::Type, ) -> Result, TranslateError> { - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(t), + t, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6784,13 +6357,13 @@ impl ast::Arg3 { state_space: ast::AtomSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(scalar_type)), + &ast::Type::Scalar(scalar_type), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6816,15 +6389,6 @@ impl ast::Arg3 { } impl ast::Arg4 { - fn cast>(self) -> ast::Arg4 { - ast::Arg4 { - dst: self.dst, - src1: self.src1, - src2: self.src2, - src3: self.src3, - } - } - fn map>( self, visitor: &mut V, @@ -6836,13 +6400,13 @@ impl ast::Arg4 { } else { None }; - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(wide_type.as_ref().unwrap_or(t)), + wide_type.as_ref().unwrap_or(t), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6881,13 +6445,13 @@ impl ast::Arg4 { visitor: &mut V, t: ast::SelpType, ) -> Result, TranslateError> { - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(t.into())), + &ast::Type::Scalar(t.into()), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6928,13 +6492,13 @@ impl ast::Arg4 { state_space: ast::AtomSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(scalar_type)), + &ast::Type::Scalar(scalar_type), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6976,13 +6540,13 @@ impl ast::Arg4 { visitor: &mut V, typ: &ast::Type, ) -> Result, TranslateError> { - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(typ), + typ, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -7019,15 +6583,6 @@ impl ast::Arg4 { } impl ast::Arg4Setp { - fn cast>(self) -> ast::Arg4Setp { - ast::Arg4Setp { - dst1: self.dst1, - dst2: self.dst2, - src1: self.src1, - src2: self.src2, - } - } - fn map>( self, visitor: &mut V, @@ -7079,22 +6634,12 @@ impl ast::Arg4Setp { } } -impl ast::Arg5 { - fn cast>(self) -> ast::Arg5 { - ast::Arg5 { - dst1: self.dst1, - dst2: self.dst2, - src1: self.src1, - src2: self.src2, - src3: self.src3, - } - } - +impl ast::Arg5Setp { fn map>( self, visitor: &mut V, t: &ast::Type, - ) -> Result, TranslateError> { + ) -> Result, TranslateError> { let dst1 = visitor.id( ArgumentDescriptor { op: self.dst1, @@ -7140,7 +6685,7 @@ impl ast::Arg5 { }, &ast::Type::Scalar(ast::ScalarType::Pred), )?; - Ok(ast::Arg5 { + Ok(ast::Arg5Setp { dst1, dst2, src1, @@ -7150,30 +6695,28 @@ impl ast::Arg5 { } } -impl ast::Type { - fn get_vector(&self) -> Result<(ast::ScalarType, u8), TranslateError> { - match self { - ast::Type::Vector(t, len) => Ok((*t, *len)), - _ => Err(TranslateError::MismatchedType), - } - } - - fn get_scalar(&self) -> Result { - match self { - ast::Type::Scalar(t) => Ok(*t), - _ => Err(TranslateError::MismatchedType), - } - } -} - -impl ast::CallOperand { +impl ast::Operand { fn map_variable Result>( self, f: &mut F, - ) -> Result, TranslateError> { + ) -> Result, TranslateError> { + Ok(match self { + ast::Operand::Reg(reg) => ast::Operand::Reg(f(reg)?), + ast::Operand::RegOffset(reg, offset) => ast::Operand::RegOffset(f(reg)?, offset), + ast::Operand::Imm(x) => ast::Operand::Imm(x), + ast::Operand::VecMember(reg, idx) => ast::Operand::VecMember(f(reg)?, idx), + ast::Operand::VecPack(vec) => { + ast::Operand::VecPack(vec.into_iter().map(f).collect::>()?) + } + }) + } +} + +impl ast::Operand { + fn unwrap_reg(&self) -> Result { match self { - ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(f(id)?)), - ast::CallOperand::Imm(x) => Ok(ast::CallOperand::Imm(x)), + ast::Operand::Reg(reg) => Ok(*reg), + _ => Err(error_unreachable()), } } } @@ -7394,15 +6937,8 @@ impl ast::Operand { match self { ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r), ast::Operand::Imm(_) => None, - } - } -} - -impl ast::OperandOrVector { - fn single_underlying(&self) -> Option<&T> { - match self { - ast::OperandOrVector::Reg(r) | ast::OperandOrVector::RegOffset(r, _) => Some(r), - ast::OperandOrVector::Imm(_) | ast::OperandOrVector::Vec(_) => None, + ast::Operand::VecMember(reg, _) => Some(reg), + ast::Operand::VecPack(..) => None, } } } @@ -7500,7 +7036,7 @@ fn bitcast_physical_pointer( if let Some(space) = ss { Ok(Some(ConversionKind::BitToPtr(space))) } else { - Err(TranslateError::Unreachable) + Err(error_unreachable()) } } ast::Type::Scalar(ast::ScalarType::B32)