diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 7921930..078cb31 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -349,6 +349,7 @@ pub trait ArgParams { type ID; type Operand; type CallOperand; + type VecOperand; } pub struct ParsedArgParams<'a> { @@ -359,6 +360,7 @@ impl<'a> ArgParams for ParsedArgParams<'a> { type ID = &'a str; type Operand = Operand<&'a str>; type CallOperand = CallOperand<&'a str>; + type VecOperand = (&'a str, u8); } pub struct Arg1 { @@ -376,9 +378,9 @@ pub struct Arg2St { } pub enum Arg2Vec { - Dst((P::ID, u8), P::ID), - Src(P::ID, (P::ID, u8)), - Both((P::ID, u8), (P::ID, u8)), + Dst(P::VecOperand, P::ID), + Src(P::ID, P::VecOperand), + Both(P::VecOperand, P::VecOperand), } pub struct Arg3 { @@ -424,8 +426,7 @@ pub struct LdData { pub qualifier: LdStQualifier, pub state_space: LdStateSpace, pub caching: LdCacheOperator, - pub vector: Option, - pub typ: ScalarType, + pub typ: Type, } #[derive(Copy, Clone, PartialEq, Eq)] @@ -710,8 +711,7 @@ pub struct StData { pub qualifier: LdStQualifier, pub state_space: StStateSpace, pub caching: StCacheOperator, - pub vector: Option, - pub typ: ScalarType, + pub typ: Type, } #[derive(PartialEq, Eq, Copy, Clone)] diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index fd419f5..6e5f5e3 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -269,10 +269,10 @@ ScalarType: ast::ScalarType = { ".f16" => ast::ScalarType::F16, ".f16x2" => ast::ScalarType::F16x2, ".pred" => ast::ScalarType::Pred, - MemoryType + LdStScalarType }; -MemoryType: ast::ScalarType = { +LdStScalarType: ast::ScalarType = { ".b8" => ast::ScalarType::B8, ".b16" => ast::ScalarType::B16, ".b32" => ast::ScalarType::B32, @@ -446,13 +446,12 @@ Instruction: ast::Instruction> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld InstLd: ast::Instruction> = { - "ld" "," "[" "]" => { + "ld" "," "[" "]" => { ast::Instruction::Ld( ast::LdData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), state_space: ss.unwrap_or(ast::LdStateSpace::Generic), caching: cop.unwrap_or(ast::LdCacheOperator::Cached), - vector: v, typ: t }, ast::Arg2 { dst:dst, src:src } @@ -460,6 +459,11 @@ InstLd: ast::Instruction> = { } }; +LdStType: ast::Type = { + => ast::Type::Vector(t, v), + => ast::Type::Scalar(t), +} + LdStQualifier: ast::LdStQualifier = { ".weak" => ast::LdStQualifier::Weak, ".volatile" => ast::LdStQualifier::Volatile, @@ -895,13 +899,12 @@ ShlType: ast::ShlType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // Warning: NVIDIA documentation is incorrect, you can specify scope only once InstSt: ast::Instruction> = { - "st" "[" "]" "," => { + "st" "[" "]" "," => { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), state_space: ss.unwrap_or(ast::StStateSpace::Generic), caching: cop.unwrap_or(ast::StCacheOperator::Writeback), - vector: v, typ: t }, ast::Arg2St { src1:src1, src2:src2 } diff --git a/ptx/src/test/spirv_run/call.spvtxt b/ptx/src/test/spirv_run/call.spvtxt index 001cda3..ca4685a 100644 --- a/ptx/src/test/spirv_run/call.spvtxt +++ b/ptx/src/test/spirv_run/call.spvtxt @@ -4,20 +4,20 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 - %45 = OpExtInstImport "OpenCL.std" + %47 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %4 "call" %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %48 = OpTypeFunction %void %ulong %ulong + %50 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %51 = OpTypeFunction %ulong %ulong + %53 = OpTypeFunction %ulong %ulong %ulong_1 = OpConstant %ulong 1 - %4 = OpFunction %void None %48 + %4 = OpFunction %void None %50 %12 = OpFunctionParameter %ulong %13 = OpFunctionParameter %ulong - %30 = OpLabel + %32 = OpLabel %5 = OpVariable %_ptr_Function_ulong Function %6 = OpVariable %_ptr_Function_ulong Function %7 = OpVariable %_ptr_Function_ulong Function @@ -38,7 +38,9 @@ %18 = OpLoad %ulong %28 OpStore %9 %18 %21 = OpLoad %ulong %9 - %20 = OpCopyObject %ulong %21 + %29 = OpCopyObject %ulong %21 + %30 = OpCopyObject %ulong %29 + %20 = OpCopyObject %ulong %30 OpStore %10 %20 %23 = OpLoad %ulong %10 %22 = OpFunctionCall %ulong %1 %23 @@ -48,26 +50,26 @@ OpStore %9 %24 %26 = OpLoad %ulong %8 %27 = OpLoad %ulong %9 - %29 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26 - OpStore %29 %27 + %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26 + OpStore %31 %27 OpReturn OpFunctionEnd - %1 = OpFunction %ulong None %51 - %34 = OpFunctionParameter %ulong - %43 = OpLabel - %32 = OpVariable %_ptr_Function_ulong Function - %31 = OpVariable %_ptr_Function_ulong Function + %1 = OpFunction %ulong None %53 + %36 = OpFunctionParameter %ulong + %45 = OpLabel + %34 = OpVariable %_ptr_Function_ulong Function %33 = OpVariable %_ptr_Function_ulong Function - OpStore %32 %34 - %36 = OpLoad %ulong %32 - %35 = OpCopyObject %ulong %36 - OpStore %33 %35 - %38 = OpLoad %ulong %33 - %37 = OpIAdd %ulong %38 %ulong_1 - OpStore %33 %37 - %40 = OpLoad %ulong %33 - %39 = OpCopyObject %ulong %40 - OpStore %31 %39 - %41 = OpLoad %ulong %31 - OpReturnValue %41 + %35 = OpVariable %_ptr_Function_ulong Function + OpStore %34 %36 + %38 = OpLoad %ulong %34 + %37 = OpCopyObject %ulong %38 + OpStore %35 %37 + %40 = OpLoad %ulong %35 + %39 = OpIAdd %ulong %40 %ulong_1 + OpStore %35 %39 + %42 = OpLoad %ulong %35 + %41 = OpCopyObject %ulong %42 + OpStore %33 %41 + %43 = OpLoad %ulong %33 + OpReturnValue %43 OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt index e708613..84e7eac 100644 --- a/ptx/src/test/spirv_run/cvta.spvtxt +++ b/ptx/src/test/spirv_run/cvta.spvtxt @@ -4,20 +4,20 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 - %25 = OpExtInstImport "OpenCL.std" + %29 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvta" %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %28 = OpTypeFunction %void %ulong %ulong + %32 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %float = OpTypeFloat 32 %_ptr_Function_float = OpTypePointer Function %float %_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float - %1 = OpFunction %void None %28 + %1 = OpFunction %void None %32 %7 = OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong - %23 = OpLabel + %27 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -32,18 +32,22 @@ %11 = OpCopyObject %ulong %12 OpStore %5 %11 %14 = OpLoad %ulong %4 - %13 = OpCopyObject %ulong %14 + %22 = OpCopyObject %ulong %14 + %21 = OpCopyObject %ulong %22 + %13 = OpCopyObject %ulong %21 OpStore %4 %13 %16 = OpLoad %ulong %5 - %15 = OpCopyObject %ulong %16 + %24 = OpCopyObject %ulong %16 + %23 = OpCopyObject %ulong %24 + %15 = OpCopyObject %ulong %23 OpStore %5 %15 %18 = OpLoad %ulong %4 - %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18 - %17 = OpLoad %float %21 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18 + %17 = OpLoad %float %25 OpStore %6 %17 %19 = OpLoad %ulong %5 %20 = OpLoad %float %6 - %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19 - OpStore %22 %20 + %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19 + OpStore %26 %20 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/ld_st_implicit.ptx b/ptx/src/test/spirv_run/ld_st_implicit.ptx new file mode 100644 index 0000000..8562286 --- /dev/null +++ b/ptx/src/test/spirv_run/ld_st_implicit.ptx @@ -0,0 +1,20 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry ld_st_implicit( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b64 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.global.f32 temp, [in_addr]; + st.global.f32 [out_addr], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt new file mode 100644 index 0000000..e7dba5a --- /dev/null +++ b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt @@ -0,0 +1,48 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %23 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "ld_st_implicit" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %26 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %float = OpTypeFloat 32 +%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float + %uint = OpTypeInt 32 0 + %1 = OpFunction %void None %26 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %21 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %7 + OpStore %3 %8 + %10 = OpLoad %ulong %2 + %9 = OpCopyObject %ulong %10 + OpStore %4 %9 + %12 = OpLoad %ulong %3 + %11 = OpCopyObject %ulong %12 + OpStore %5 %11 + %14 = OpLoad %ulong %4 + %17 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14 + %18 = OpLoad %float %17 + %30 = OpBitcast %ulong %18 + %32 = OpUConvert %uint %30 + %13 = OpBitcast %uint %32 + OpStore %6 %13 + %15 = OpLoad %ulong %5 + %16 = OpLoad %ulong %6 + %33 = OpBitcast %uint %16 + %19 = OpUConvert %ulong %33 + %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15 + OpStore %20 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index a04f0eb..fd50d3c 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -8,10 +8,12 @@ use spirv_headers::Word; use spirv_tools_sys::{ spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env, }; +use std::collections::hash_map::Entry; use std::error; use std::ffi::{c_void, CStr, CString}; use std::fmt; use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; use std::mem; use std::slice; use std::{borrow::Cow, collections::HashMap, env, fs, path::PathBuf, ptr, str}; @@ -41,6 +43,7 @@ macro_rules! test_ptx { } test_ptx!(ld_st, [1u64], [1u64]); +test_ptx!(ld_st_implicit, [0.5f32], [0.5f32]); test_ptx!(mov, [1u64], [1u64]); test_ptx!(mul_lo, [1u64], [2u64]); test_ptx!(mul_hi, [u64::max_value()], [1u64]); @@ -214,14 +217,45 @@ fn test_spvtxt_assert<'a>( } } } - panic!(spirv_text); + panic!(spirv_text.to_string()); } unsafe { spirv_tools::spvContextDestroy(spv_context) }; Ok(()) } +struct EqMap +where + T: Eq + Copy + Hash, +{ + m1: HashMap, + m2: HashMap, +} + +impl EqMap { + fn new() -> Self { + EqMap { + m1: HashMap::new(), + m2: HashMap::new(), + } + } + + fn is_equal(&mut self, t1: T, t2: T) -> bool { + match (self.m1.entry(t1), self.m2.entry(t2)) { + (Entry::Occupied(entry1), Entry::Occupied(entry2)) => { + *entry1.get() == t2 && *entry2.get() == t1 + } + (Entry::Vacant(entry1), Entry::Vacant(entry2)) => { + entry1.insert(t2); + entry2.insert(t1); + true + } + _ => false, + } + } +} + fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool { - let mut map = HashMap::new(); + let mut map = EqMap::new(); if !is_option_equal(&fn1.def, &fn2.def, &mut map, is_instr_equal) { return false; } @@ -247,7 +281,7 @@ fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool { true } -fn is_block_equal(b1: &Block, b2: &Block, map: &mut HashMap) -> bool { +fn is_block_equal(b1: &Block, b2: &Block, map: &mut EqMap) -> bool { if !is_option_equal(&b1.label, &b2.label, map, is_instr_equal) { return false; } @@ -262,11 +296,7 @@ fn is_block_equal(b1: &Block, b2: &Block, map: &mut HashMap) -> bool true } -fn is_instr_equal( - instr1: &Instruction, - instr2: &Instruction, - map: &mut HashMap, -) -> bool { +fn is_instr_equal(instr1: &Instruction, instr2: &Instruction, map: &mut EqMap) -> bool { if instr1.class.opcode != instr2.class.opcode { return false; } @@ -306,24 +336,14 @@ fn is_instr_equal( true } -fn is_word_equal(w1: &Word, w2: &Word, map: &mut HashMap) -> bool { - match map.entry(*w1) { - std::collections::hash_map::Entry::Occupied(entry) => { - if entry.get() != w2 { - return false; - } - } - std::collections::hash_map::Entry::Vacant(entry) => { - entry.insert(*w2); - } - } - true +fn is_word_equal(t1: &Word, t2: &Word, map: &mut EqMap) -> bool { + map.is_equal(*t1, *t2) } -fn is_option_equal) -> bool>( +fn is_option_equal) -> bool>( o1: &Option, o2: &Option, - map: &mut HashMap, + map: &mut EqMap, f: F, ) -> bool { match (o1, o2) { diff --git a/ptx/src/test/spirv_run/not.spvtxt b/ptx/src/test/spirv_run/not.spvtxt index de340ed..b358858 100644 --- a/ptx/src/test/spirv_run/not.spvtxt +++ b/ptx/src/test/spirv_run/not.spvtxt @@ -4,18 +4,18 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 - %24 = OpExtInstImport "OpenCL.std" + %26 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "not" %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %27 = OpTypeFunction %void %ulong %ulong + %29 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_ulong = OpTypePointer Generic %ulong - %1 = OpFunction %void None %27 + %1 = OpFunction %void None %29 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong - %22 = OpLabel + %24 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -35,11 +35,13 @@ %14 = OpLoad %ulong %20 OpStore %6 %14 %17 = OpLoad %ulong %6 - %16 = OpNot %ulong %17 + %22 = OpCopyObject %ulong %17 + %21 = OpNot %ulong %22 + %16 = OpCopyObject %ulong %21 OpStore %7 %16 %18 = OpLoad %ulong %5 %19 = OpLoad %ulong %7 - %21 = OpConvertUToPtr %_ptr_Generic_ulong %18 - OpStore %21 %19 + %23 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %23 %19 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shl.spvtxt b/ptx/src/test/spirv_run/shl.spvtxt index dbd2664..4843a65 100644 --- a/ptx/src/test/spirv_run/shl.spvtxt +++ b/ptx/src/test/spirv_run/shl.spvtxt @@ -4,20 +4,20 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 - %25 = OpExtInstImport "OpenCL.std" + %27 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "shl" %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %28 = OpTypeFunction %void %ulong %ulong + %30 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_ulong = OpTypePointer Generic %ulong %uint = OpTypeInt 32 0 %uint_2 = OpConstant %uint 2 - %1 = OpFunction %void None %28 + %1 = OpFunction %void None %30 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong - %23 = OpLabel + %25 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -37,11 +37,13 @@ %14 = OpLoad %ulong %21 OpStore %6 %14 %17 = OpLoad %ulong %6 - %16 = OpShiftLeftLogical %ulong %17 %uint_2 + %23 = OpCopyObject %ulong %17 + %22 = OpShiftLeftLogical %ulong %23 %uint_2 + %16 = OpCopyObject %ulong %22 OpStore %7 %16 %18 = OpLoad %ulong %5 %19 = OpLoad %ulong %7 - %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 - OpStore %22 %19 + %24 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %24 %19 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt index 6810fec..25dd80e 100644 --- a/ptx/src/test/spirv_run/vector.spvtxt +++ b/ptx/src/test/spirv_run/vector.spvtxt @@ -4,43 +4,92 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 - %25 = OpExtInstImport "OpenCL.std" + %58 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "add" + OpEntryPoint Kernel %31 "vector" %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %62 = OpTypeFunction %v2uint %v2uint +%_ptr_Function_v2uint = OpTypePointer Function %v2uint +%_ptr_Function_uint = OpTypePointer Function %uint %ulong = OpTypeInt 64 0 - %28 = OpTypeFunction %void %ulong %ulong + %66 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_1 = OpConstant %ulong 1 - %1 = OpFunction %void None %28 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %23 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %11 = OpLoad %ulong %2 - %10 = OpCopyObject %ulong %11 - OpStore %4 %10 - %13 = OpLoad %ulong %3 - %12 = OpCopyObject %ulong %13 - OpStore %5 %12 - %15 = OpLoad %ulong %4 - %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 - %14 = OpLoad %ulong %21 - OpStore %6 %14 - %17 = OpLoad %ulong %6 - %16 = OpIAdd %ulong %17 %ulong_1 - OpStore %7 %16 - %18 = OpLoad %ulong %5 - %19 = OpLoad %ulong %7 - %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 - OpStore %22 %19 +%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint + %1 = OpFunction %v2uint None %62 + %7 = OpFunctionParameter %v2uint + %30 = OpLabel + %3 = OpVariable %_ptr_Function_v2uint Function + %2 = OpVariable %_ptr_Function_v2uint Function + %4 = OpVariable %_ptr_Function_v2uint Function + %5 = OpVariable %_ptr_Function_uint Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %3 %7 + %9 = OpLoad %v2uint %3 + %24 = OpCompositeExtract %uint %9 0 + %8 = OpCopyObject %uint %24 + OpStore %5 %8 + %11 = OpLoad %v2uint %3 + %25 = OpCompositeExtract %uint %11 1 + %10 = OpCopyObject %uint %25 + OpStore %6 %10 + %13 = OpLoad %uint %5 + %14 = OpLoad %uint %6 + %12 = OpIAdd %uint %13 %14 + OpStore %6 %12 + %16 = OpLoad %uint %6 + %26 = OpCopyObject %uint %16 + %15 = OpCompositeInsert %uint %26 %15 0 + OpStore %4 %15 + %18 = OpLoad %uint %6 + %27 = OpCopyObject %uint %18 + %17 = OpCompositeInsert %uint %27 %17 1 + OpStore %4 %17 + %20 = OpLoad %v2uint %4 + %29 = OpCompositeExtract %uint %20 1 + %28 = OpCopyObject %uint %29 + %19 = OpCompositeInsert %uint %28 %19 0 + OpStore %4 %19 + %22 = OpLoad %v2uint %4 + %21 = OpCopyObject %v2uint %22 + OpStore %2 %21 + %23 = OpLoad %v2uint %2 + OpReturnValue %23 + OpFunctionEnd + %31 = OpFunction %void None %66 + %40 = OpFunctionParameter %ulong + %41 = OpFunctionParameter %ulong + %56 = OpLabel + %32 = OpVariable %_ptr_Function_ulong Function + %33 = OpVariable %_ptr_Function_ulong Function + %34 = OpVariable %_ptr_Function_ulong Function + %35 = OpVariable %_ptr_Function_ulong Function + %36 = OpVariable %_ptr_Function_v2uint Function + %37 = OpVariable %_ptr_Function_uint Function + %38 = OpVariable %_ptr_Function_uint Function + %39 = OpVariable %_ptr_Function_ulong Function + OpStore %32 %40 + OpStore %33 %41 + %43 = OpLoad %ulong %32 + %42 = OpCopyObject %ulong %43 + OpStore %34 %42 + %45 = OpLoad %ulong %33 + %44 = OpCopyObject %ulong %45 + OpStore %35 %44 + %47 = OpLoad %ulong %34 + %54 = OpConvertUToPtr %_ptr_Generic_v2uint %47 + %46 = OpLoad %v2uint %54 + OpStore %36 %46 + %49 = OpLoad %v2uint %36 + %48 = OpFunctionCall %v2uint %1 %49 + OpStore %36 %48 + %51 = OpLoad %v2uint %36 + %50 = OpCopyObject %ulong %51 + OpStore %39 %50 + %52 = OpLoad %ulong %35 + %53 = OpLoad %v2uint %36 + %55 = OpConvertUToPtr %_ptr_Generic_v2uint %52 + OpStore %55 %53 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7591722..57d3485 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,5 +1,5 @@ use crate::ast; -use rspirv::dr; +use rspirv::{binary::Disassemble, dr}; use std::collections::{hash_map, HashMap, HashSet}; use std::{borrow::Cow, iter, mem}; @@ -398,7 +398,8 @@ fn normalize_labels( labels_in_use.insert(cond.if_true); labels_in_use.insert(cond.if_false); } - Statement::Call(_) + Statement::Composite(_) + | Statement::Call(_) | Statement::Variable(_) | Statement::LoadVar(_, _) | Statement::StoreVar(_, _) @@ -528,13 +529,13 @@ fn insert_mem_ssa_statements<'a, 'b>( ast::Instruction::Ret(d) => { if let Some(out_param) = out_param { let typ = id_def.get_type(out_param); - let new_id = id_def.new_id(Some(typ)); + let new_id = id_def.new_id(typ); result.push(Statement::LoadVar( ast::Arg2 { dst: new_id, src: out_param, }, - typ, + typ.unwrap(), )); result.push(Statement::RetValue(d, new_id)); } else { @@ -561,19 +562,25 @@ fn insert_mem_ssa_statements<'a, 'b>( | Statement::Conversion(_) | Statement::RetValue(_, _) | Statement::Constant(_) => unreachable!(), + Statement::Composite(_) => todo!(), } } (f_args, result) } trait VisitVariable: Sized { - fn visit_variable<'a, F: FnMut(ArgumentDescriptor) -> spirv::Word>( + fn visit_variable< + 'a, + F: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + >( self, f: &mut F, ) -> UnadornedStatement; } trait VisitVariableExpanded { - fn visit_variable_extended) -> spirv::Word>( + fn visit_variable_extended< + F: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + >( self, f: &mut F, ) -> ExpandedStatement; @@ -585,8 +592,8 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( stmt: F, ) { let mut post_statements = Vec::new(); - let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor| { - let id_type = match (desc.typ, desc.is_pointer) { + let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor, _| { + let id_type = match (id_def.get_type(desc.op), desc.is_pointer) { (Some(t), false) => t, (Some(_), true) => ast::Type::Scalar(ast::ScalarType::B64), (None, _) => return desc.op, @@ -624,13 +631,15 @@ fn expand_arguments<'a, 'b>( match s { Statement::Call(call) => { let mut visitor = FlattenArguments::new(&mut result, id_def); - let new_call = call.map(&mut visitor); + let (new_call, post_stmts) = (call.map(&mut visitor), visitor.post_stmts); result.push(Statement::Call(new_call)); + result.extend(post_stmts); } Statement::Instruction(inst) => { let mut visitor = FlattenArguments::new(&mut result, id_def); - let new_inst = inst.map(&mut visitor); + let (new_inst, post_stmts) = (inst.map(&mut visitor), visitor.post_stmts); result.push(Statement::Instruction(new_inst)); + result.extend(post_stmts); } Statement::Variable(ast::Variable { align, @@ -646,7 +655,9 @@ fn expand_arguments<'a, 'b>( Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)), Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), - Statement::Conversion(_) | Statement::Constant(_) => unreachable!(), + Statement::Composite(_) | Statement::Conversion(_) | Statement::Constant(_) => { + unreachable!() + } } } result @@ -655,74 +666,79 @@ fn expand_arguments<'a, 'b>( struct FlattenArguments<'a, 'b> { func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>, + post_stmts: Vec, } impl<'a, 'b> FlattenArguments<'a, 'b> { fn new(func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>) -> Self { - FlattenArguments { func, id_def } + FlattenArguments { + func, + id_def, + post_stmts: Vec::new(), + } } } impl<'a, 'b> ArgumentMapVisitor for FlattenArguments<'a, 'b> { - fn variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { + fn variable( + &mut self, + desc: ArgumentDescriptor, + typ: Option, + ) -> spirv::Word { desc.op } - fn operand(&mut self, desc: ArgumentDescriptor>) -> spirv::Word { + fn operand( + &mut self, + desc: ArgumentDescriptor>, + typ: ast::Type, + ) -> spirv::Word { match desc.op { - ast::Operand::Reg(r) => self.variable(desc.new_op(r)), + ast::Operand::Reg(r) => self.variable(desc.new_op(r), Some(typ)), ast::Operand::Imm(x) => { - if let Some(typ) = desc.typ { - let scalar_t = if let ast::Type::Scalar(scalar) = typ { - scalar - } else { - todo!() - }; - let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id, - typ: scalar_t, - value: x, - })); - id + let scalar_t = if let ast::Type::Scalar(scalar) = typ { + scalar } else { todo!() - } + }; + let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id, + typ: scalar_t, + value: x, + })); + id } ast::Operand::RegOffset(reg, offset) => { - if let Some(typ) = desc.typ { - let scalar_t = if let ast::Type::Scalar(scalar) = typ { - scalar - } else { - todo!() - }; - let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: scalar_t, - value: offset as i128, - })); - let result_id = self.id_def.new_id(desc.typ); - let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); - self.func.push(Statement::Instruction( - ast::Instruction::::Add( - ast::AddDetails::Int(ast::AddIntDesc { - typ: int_type, - saturate: false, - }), - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); - result_id + let scalar_t = if let ast::Type::Scalar(scalar) = typ { + scalar } else { todo!() - } + }; + let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: scalar_t, + value: offset as i128, + })); + let result_id = self.id_def.new_id(Some(typ)); + let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); + self.func.push(Statement::Instruction( + ast::Instruction::::Add( + ast::AddDetails::Int(ast::AddIntDesc { + typ: int_type, + saturate: false, + }), + ast::Arg3 { + dst: result_id, + src1: reg, + src2: id_constant_stmt, + }, + ), + )); + result_id } } } @@ -730,18 +746,45 @@ impl<'a, 'b> ArgumentMapVisitor fn src_call_operand( &mut self, desc: ArgumentDescriptor>, + typ: ast::Type, ) -> spirv::Word { match desc.op { - ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg)), - ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x))), + ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg), Some(typ)), + ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x)), typ), } } fn src_vec_operand( &mut self, desc: ArgumentDescriptor<(spirv::Word, u8)>, - ) -> (spirv::Word, u8) { - (self.variable(desc.new_op(desc.op.0)), desc.op.1) + typ: ast::MovVectorType, + ) -> spirv::Word { + let (vector_id, index) = desc.op; + let new_id = self.id_def.new_id(Some(ast::Type::Scalar(typ.into()))); + let composite = if desc.is_dst { + Statement::Composite(CompositeAccess { + typ: typ, + dst: new_id, + src: vector_id, + index: index as u32, + is_write: true + }) + } else { + Statement::Composite(CompositeAccess { + typ: typ, + dst: new_id, + src: vector_id, + index: index as u32, + is_write: false + }) + }; + if desc.is_dst { + self.post_stmts.push(composite); + new_id + } else { + self.func.push(composite); + new_id + } } } @@ -768,48 +811,63 @@ fn insert_implicit_conversions( match s { Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call), Statement::Instruction(inst) => match inst { - ast::Instruction::Ld(ld, mut arg) => { - arg.src = insert_implicit_conversions_ld_src( - &mut result, - ast::Type::Scalar(ld.typ), + ast::Instruction::Ld(ld, arg) => { + let pre_conv = + get_implicit_conversions_ld_src(id_def, ld.typ, ld.state_space, arg.src); + let post_conv = get_implicit_conversions_ld_dst( id_def, - ld.state_space, - arg.src, - ); - insert_with_implicit_conversion_dst( - &mut result, ld.typ, - id_def, + arg.dst, should_convert_relaxed_dst, + false, + ); + insert_with_conversions( + &mut result, + id_def, arg, + pre_conv.into_iter(), + iter::empty(), + post_conv.into_iter().collect(), + |arg| &mut arg.src, |arg| &mut arg.dst, |arg| ast::Instruction::Ld(ld, arg), - ); + ) } - ast::Instruction::St(st, mut arg) => { - let arg_src2_type = id_def.get_type(arg.src2); - if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) { - arg.src2 = insert_conversion_src( - &mut result, - id_def, - arg.src2, - arg_src2_type, - ast::Type::Scalar(st.typ), - conv, - ); - } - arg.src1 = insert_implicit_conversions_ld_src( - &mut result, - ast::Type::Scalar(st.typ), + ast::Instruction::St(st, arg) => { + let pre_conv = get_implicit_conversions_ld_dst( id_def, + st.typ, + arg.src2, + should_convert_relaxed_src, + true, + ); + let post_conv = get_implicit_conversions_ld_src( + id_def, + st.typ, st.state_space.to_ld_ss(), arg.src1, ); - result.push(Statement::Instruction(ast::Instruction::St(st, arg))); + let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param { + (Vec::new(), post_conv) + } else { + (post_conv, Vec::new()) + }; + insert_with_conversions( + &mut result, + id_def, + arg, + pre_conv.into_iter(), + pre_conv_dest.into_iter(), + post_conv, + |arg| &mut arg.src2, + |arg| &mut arg.src1, + |arg| ast::Instruction::St(st, arg), + ) } inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst), }, - s @ Statement::Conditional(_) + s @ Statement::Composite(_) + | s @ Statement::Conditional(_) | s @ Statement::Label(_) | s @ Statement::Constant(_) | s @ Statement::Variable(_) @@ -950,10 +1008,10 @@ fn emit_function_body_ops( builder.branch(arg.src)?; } ast::Instruction::Ld(data, arg) => { - if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() { + if data.qualifier != ast::LdStQualifier::Weak { todo!() } - let result_type = map.get_or_add_scalar(builder, data.typ); + let result_type = map.get_or_add(builder, SpirvType::from(data.typ)); match data.state_space { ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { builder.load(result_type, Some(arg.dst), arg.src, None, [])?; @@ -967,7 +1025,6 @@ fn emit_function_body_ops( } ast::Instruction::St(data, arg) => { if data.qualifier != ast::LdStQualifier::Weak - || data.vector.is_some() || (data.state_space != ast::StStateSpace::Generic && data.state_space != ast::StStateSpace::Param && data.state_space != ast::StStateSpace::Global) @@ -1030,7 +1087,10 @@ fn emit_function_body_ops( builder.copy_object(result_type, Some(arg.dst), arg.src)?; } ast::Instruction::SetpBool(_, _) => todo!(), - ast::Instruction::MovVector(_, _) => todo!(), + ast::Instruction::MovVector(t, arg) => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); + builder.copy_object(result_type, Some(arg.dst()), arg.src())?; + } }, Statement::LoadVar(arg, typ) => { let type_id = map.get_or_add(builder, SpirvType::from(*typ)); @@ -1042,6 +1102,19 @@ fn emit_function_body_ops( Statement::RetValue(_, id) => { builder.ret_value(*id)?; } + Statement::Composite(c) => { + let result_type = map.get_or_add_scalar(builder, c.typ.into()); + let result_id = Some(c.dst); + let indexes = [c.index]; + if c.is_write { + let object = c.src; + let composite = c.dst; + builder.composite_insert(result_type, result_id, object, composite, indexes)?; + } else { + let composite = c.src; + builder.composite_extract(result_type, result_id, composite, indexes)?; + } + } } } Ok(()) @@ -1188,7 +1261,7 @@ fn emit_setp( match (setp.cmp_op, setp.typ.kind()) { (ast::SetpCompareOp::Eq, ScalarKind::Signed) | (ast::SetpCompareOp::Eq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::Eq, ScalarKind::Byte) => { + | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => { builder.i_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::Eq, ScalarKind::Float) => { @@ -1196,14 +1269,14 @@ fn emit_setp( } (ast::SetpCompareOp::NotEq, ScalarKind::Signed) | (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::NotEq, ScalarKind::Byte) => { + | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => { builder.i_not_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::NotEq, ScalarKind::Float) => { builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::Less, ScalarKind::Unsigned) - | (ast::SetpCompareOp::Less, ScalarKind::Byte) => { + | (ast::SetpCompareOp::Less, ScalarKind::Bit) => { builder.u_less_than(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::Less, ScalarKind::Signed) => { @@ -1213,7 +1286,7 @@ fn emit_setp( builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::LessOrEq, ScalarKind::Byte) => { + | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => { builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => { @@ -1223,7 +1296,7 @@ fn emit_setp( builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::Greater, ScalarKind::Unsigned) - | (ast::SetpCompareOp::Greater, ScalarKind::Byte) => { + | (ast::SetpCompareOp::Greater, ScalarKind::Bit) => { builder.u_greater_than(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::Greater, ScalarKind::Signed) => { @@ -1233,7 +1306,7 @@ fn emit_setp( builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned) - | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Byte) => { + | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => { builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => { @@ -1294,54 +1367,56 @@ fn emit_implicit_conversion( map: &mut TypeWordMap, cv: &ImplicitConversion, ) -> Result<(), dr::Error> { - let (from_type, to_type) = match (cv.from, cv.to) { - (ast::Type::Scalar(from), ast::Type::Scalar(to)) => (from, to), - _ => todo!(), - }; + let from_parts = cv.from.to_parts(); + let to_parts = cv.to.to_parts(); match cv.kind { ConversionKind::Ptr(space) => { let dst_type = map.get_or_add( builder, - SpirvType::Pointer( - Box::new(SpirvType::Base(SpirvScalarKey::from(to_type))), - space.to_spirv(), - ), + SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()), ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } ConversionKind::Default => { - if from_type.width() == to_type.width() { - let dst_type = map.get_or_add_scalar(builder, to_type); - if from_type.kind() != ScalarKind::Float && to_type.kind() != ScalarKind::Float { + if from_parts.width == to_parts.width { + let dst_type = map.get_or_add(builder, SpirvType::from(cv.from)); + if from_parts.scalar_kind != ScalarKind::Float + && to_parts.scalar_kind != ScalarKind::Float + { // It is noop, but another instruction expects result of this conversion builder.copy_object(dst_type, Some(cv.dst), cv.src)?; } else { builder.bitcast(dst_type, Some(cv.dst), cv.src)?; } } else { - let as_unsigned_type = map.get_or_add_scalar( + // This block is safe because it's illegal to implictly convert between floating point instructions + let same_width_bit_type = map.get_or_add( builder, - ast::ScalarType::from_parts(from_type.width(), ScalarKind::Unsigned), + SpirvType::from(ast::Type::from_parts(TypeParts { + scalar_kind: ScalarKind::Bit, + ..from_parts + })), ); - let as_unsigned = builder.bitcast(as_unsigned_type, None, cv.src)?; - let as_unsigned_wide_type = - ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned); - let as_unsigned_wide_spirv = map.get_or_add_scalar( - builder, - ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned), - ); - if to_type.kind() == ScalarKind::Unsigned || to_type.kind() == ScalarKind::Byte { - builder.u_convert(as_unsigned_wide_spirv, Some(cv.dst), as_unsigned)?; + let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?; + let wide_bit_type = ast::Type::from_parts(TypeParts { + scalar_kind: ScalarKind::Bit, + ..to_parts + }); + let wide_bit_type_spirv = map.get_or_add(builder, SpirvType::from(wide_bit_type)); + if to_parts.scalar_kind == ScalarKind::Unsigned + || to_parts.scalar_kind == ScalarKind::Bit + { + builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?; } else { - let as_unsigned_wide = - builder.u_convert(as_unsigned_wide_spirv, None, as_unsigned)?; + let wide_bit_value = + builder.u_convert(wide_bit_type_spirv, None, same_width_bit_value)?; emit_implicit_conversion( builder, map, &ImplicitConversion { - src: as_unsigned_wide, + src: wide_bit_value, dst: cv.dst, - from: ast::Type::Scalar(as_unsigned_wide_type), + from: wide_bit_type, to: cv.to, kind: ConversionKind::Default, }, @@ -1627,8 +1702,8 @@ struct NumericIdResolver<'b> { } impl<'b> NumericIdResolver<'b> { - fn get_type(&self, id: spirv::Word) -> ast::Type { - self.type_check[&id] + fn get_type(&self, id: spirv::Word) -> Option { + self.type_check.get(&id).map(|x| *x) } fn new_id(&mut self, typ: Option) -> spirv::Word { @@ -1648,6 +1723,7 @@ enum Statement { LoadVar(ast::Arg2, ast::Type), StoreVar(ast::Arg2St, ast::Type), Call(ResolvedCall

), + Composite(CompositeAccess), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), Conversion(ImplicitConversion), @@ -1671,31 +1747,37 @@ impl> ResolvedCall { .ret_params .into_iter() .map(|(id, typ)| { - let new_id = visitor.variable(ArgumentDescriptor { - op: id, - typ: Some(typ.into()), - is_dst: true, - is_pointer: false, - }); + let new_id = visitor.variable( + ArgumentDescriptor { + op: id, + is_dst: true, + is_pointer: false, + }, + Some(typ.into()), + ); (new_id, typ) }) .collect(); - let func = visitor.variable(ArgumentDescriptor { - op: self.func, - typ: None, - is_dst: false, - is_pointer: false, - }); + let func = visitor.variable( + ArgumentDescriptor { + op: self.func, + is_dst: false, + is_pointer: false, + }, + None, + ); let param_list = self .param_list .into_iter() .map(|(id, typ)| { - let new_id = visitor.src_call_operand(ArgumentDescriptor { - op: id, - typ: Some(typ.into()), - is_dst: false, - is_pointer: false, - }); + let new_id = visitor.src_call_operand( + ArgumentDescriptor { + op: id, + is_dst: false, + is_pointer: false, + }, + typ.into(), + ); (new_id, typ) }) .collect(); @@ -1709,7 +1791,10 @@ impl> ResolvedCall { } impl VisitVariable for ResolvedCall { - fn visit_variable<'a, F: FnMut(ArgumentDescriptor) -> spirv::Word>( + fn visit_variable< + 'a, + F: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + >( self, f: &mut F, ) -> UnadornedStatement { @@ -1718,7 +1803,9 @@ impl VisitVariable for ResolvedCall { } impl VisitVariableExpanded for ResolvedCall { - fn visit_variable_extended) -> spirv::Word>( + fn visit_variable_extended< + F: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + >( self, f: &mut F, ) -> ExpandedStatement { @@ -1750,6 +1837,7 @@ impl ast::ArgParams for NormalizedArgParams { type ID = spirv::Word; type Operand = ast::Operand; type CallOperand = ast::CallOperand; + type VecOperand = (spirv::Word, u8); } impl ArgParamsEx for NormalizedArgParams { @@ -1766,6 +1854,7 @@ impl ast::ArgParams for ExpandedArgParams { type ID = spirv::Word; type Operand = spirv::Word; type CallOperand = spirv::Word; + type VecOperand = spirv::Word; } impl ArgParamsEx for ExpandedArgParams { @@ -1775,30 +1864,47 @@ impl ArgParamsEx for ExpandedArgParams { } trait ArgumentMapVisitor { - fn variable(&mut self, desc: ArgumentDescriptor) -> U::ID; - fn operand(&mut self, desc: ArgumentDescriptor) -> U::Operand; - fn src_call_operand(&mut self, desc: ArgumentDescriptor) -> U::CallOperand; - fn src_vec_operand(&mut self, desc: ArgumentDescriptor<(T::ID, u8)>) -> (U::ID, u8); + fn variable(&mut self, desc: ArgumentDescriptor, typ: Option) -> U::ID; + fn operand(&mut self, desc: ArgumentDescriptor, typ: ast::Type) -> U::Operand; + fn src_call_operand( + &mut self, + desc: ArgumentDescriptor, + typ: ast::Type, + ) -> U::CallOperand; + fn src_vec_operand( + &mut self, + desc: ArgumentDescriptor, + typ: ast::MovVectorType, + ) -> U::VecOperand; } impl ArgumentMapVisitor for T where - T: FnMut(ArgumentDescriptor) -> spirv::Word, + T: FnMut(ArgumentDescriptor, Option) -> spirv::Word, { - fn variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { - self(desc) + fn variable( + &mut self, + desc: ArgumentDescriptor, + t: Option, + ) -> spirv::Word { + self(desc, t) } - fn operand(&mut self, desc: ArgumentDescriptor) -> spirv::Word { - self(desc) + fn operand(&mut self, desc: ArgumentDescriptor, t: ast::Type) -> spirv::Word { + self(desc, Some(t)) } - fn src_call_operand(&mut self, desc: ArgumentDescriptor) -> spirv::Word { - self(desc.new_op(desc.op)) + fn src_call_operand( + &mut self, + desc: ArgumentDescriptor, + t: ast::Type, + ) -> spirv::Word { + self(desc, Some(t)) } fn src_vec_operand( &mut self, - desc: ArgumentDescriptor<(spirv::Word, u8)>, - ) -> (spirv::Word, u8) { - (self(desc.new_op(desc.op.0)), desc.op.1) + desc: ArgumentDescriptor, + t: ast::MovVectorType, + ) -> spirv::Word { + self(desc, Some(ast::Type::Scalar(t.into()))) } } @@ -1806,13 +1912,14 @@ impl<'a, T> ArgumentMapVisitor, NormalizedArgParams> fo where T: FnMut(&str) -> spirv::Word, { - fn variable(&mut self, desc: ArgumentDescriptor<&str>) -> spirv::Word { + fn variable(&mut self, desc: ArgumentDescriptor<&str>, _: Option) -> spirv::Word { self(desc.op) } fn operand( &mut self, desc: ArgumentDescriptor>, + _: ast::Type, ) -> ast::Operand { match desc.op { ast::Operand::Reg(id) => ast::Operand::Reg(self(id)), @@ -1824,6 +1931,7 @@ where fn src_call_operand( &mut self, desc: ArgumentDescriptor>, + _: ast::Type, ) -> ast::CallOperand { match desc.op { ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(id)), @@ -1831,15 +1939,18 @@ where } } - fn src_vec_operand(&mut self, desc: ArgumentDescriptor<(&str, u8)>) -> (spirv::Word, u8) { + fn src_vec_operand( + &mut self, + desc: ArgumentDescriptor<(&str, u8)>, + _: ast::MovVectorType, + ) -> (spirv::Word, u8) { (self(desc.op.0), desc.op.1) } } -struct ArgumentDescriptor { - op: T, +struct ArgumentDescriptor { + op: Op, is_dst: bool, - typ: Option, is_pointer: bool, } @@ -1848,7 +1959,6 @@ impl ArgumentDescriptor { ArgumentDescriptor { op: u, is_dst: self.is_dst, - typ: self.typ, is_pointer: self.is_pointer, } } @@ -1860,39 +1970,35 @@ impl ast::Instruction { visitor: &mut V, ) -> ast::Instruction { match self { - ast::Instruction::MovVector(_, _) => todo!(), + ast::Instruction::MovVector(t, a) => ast::Instruction::MovVector(t, a.map(visitor, t)), ast::Instruction::Abs(_, _) => todo!(), + // Call instruction is converted to a call statement early on ast::Instruction::Call(_) => unreachable!(), ast::Instruction::Ld(d, a) => { let inst_type = d.typ; let src_is_pointer = d.state_space != ast::LdStateSpace::Param; - ast::Instruction::Ld( - d, - a.map_ld(visitor, Some(ast::Type::Scalar(inst_type)), src_is_pointer), - ) + ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, src_is_pointer)) } ast::Instruction::Mov(mov_type, a) => { - ast::Instruction::Mov(mov_type, a.map(visitor, Some(mov_type.into()))) + ast::Instruction::Mov(mov_type, a.map(visitor, mov_type.into())) } ast::Instruction::Mul(d, a) => { let inst_type = d.get_type(); - ast::Instruction::Mul(d, a.map_non_shift(visitor, Some(inst_type))) + ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type)) } ast::Instruction::Add(d, a) => { let inst_type = d.get_type(); - ast::Instruction::Add(d, a.map_non_shift(visitor, Some(inst_type))) + ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type)) } ast::Instruction::Setp(d, a) => { let inst_type = d.typ; - ast::Instruction::Setp(d, a.map(visitor, Some(ast::Type::Scalar(inst_type)))) + ast::Instruction::Setp(d, a.map(visitor, ast::Type::Scalar(inst_type))) } ast::Instruction::SetpBool(d, a) => { let inst_type = d.typ; - ast::Instruction::SetpBool(d, a.map(visitor, Some(ast::Type::Scalar(inst_type)))) - } - ast::Instruction::Not(t, a) => { - ast::Instruction::Not(t, a.map(visitor, Some(t.to_type()))) + ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type))) } + ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, t.to_type())), ast::Instruction::Cvt(d, a) => { let (dst_t, src_t) = match &d { ast::CvtDetails::FloatFromFloat(desc) => ( @@ -1915,28 +2021,28 @@ impl ast::Instruction { ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t)) } ast::Instruction::Shl(t, a) => { - ast::Instruction::Shl(t, a.map_shift(visitor, Some(t.to_type()))) + ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type())) } ast::Instruction::St(d, a) => { let inst_type = d.typ; let param_space = d.state_space == ast::StStateSpace::Param; - ast::Instruction::St( - d, - a.map(visitor, Some(ast::Type::Scalar(inst_type)), param_space), - ) + ast::Instruction::St(d, a.map(visitor, inst_type, param_space)) } ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)), ast::Instruction::Ret(d) => ast::Instruction::Ret(d), ast::Instruction::Cvta(d, a) => { let inst_type = ast::Type::Scalar(ast::ScalarType::B64); - ast::Instruction::Cvta(d, a.map(visitor, Some(inst_type))) + ast::Instruction::Cvta(d, a.map(visitor, inst_type)) } } } } impl VisitVariable for ast::Instruction { - fn visit_variable<'a, F: FnMut(ArgumentDescriptor) -> spirv::Word>( + fn visit_variable< + 'a, + F: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + >( self, f: &mut F, ) -> UnadornedStatement { @@ -1946,29 +2052,37 @@ impl VisitVariable for ast::Instruction { impl ArgumentMapVisitor for T where - T: FnMut(ArgumentDescriptor) -> spirv::Word, + T: FnMut(ArgumentDescriptor, Option) -> spirv::Word, { - fn variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { - self(desc) + fn variable( + &mut self, + desc: ArgumentDescriptor, + t: Option, + ) -> spirv::Word { + self(desc, t) } fn operand( &mut self, desc: ArgumentDescriptor>, + t: ast::Type, ) -> ast::Operand { match desc.op { - ast::Operand::Reg(id) => ast::Operand::Reg(self(desc.new_op(id))), + ast::Operand::Reg(id) => ast::Operand::Reg(self(desc.new_op(id), Some(t))), ast::Operand::Imm(imm) => ast::Operand::Imm(imm), - ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(desc.new_op(id)), imm), + ast::Operand::RegOffset(id, imm) => { + ast::Operand::RegOffset(self(desc.new_op(id), Some(t)), imm) + } } } fn src_call_operand( &mut self, desc: ArgumentDescriptor>, + t: ast::Type, ) -> ast::CallOperand { match desc.op { - ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(desc.new_op(id))), + ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(desc.new_op(id), Some(t))), ast::CallOperand::Imm(imm) => ast::CallOperand::Imm(imm), } } @@ -1976,11 +2090,74 @@ where fn src_vec_operand( &mut self, desc: ArgumentDescriptor<(spirv::Word, u8)>, + t: ast::MovVectorType, ) -> (spirv::Word, u8) { - (self(desc.new_op(desc.op.0)), desc.op.1) + ( + self( + desc.new_op(desc.op.0), + Some(ast::Type::Vector(t.into(), desc.op.1)), + ), + desc.op.1, + ) } } +impl ast::Type { + fn to_parts(self) -> TypeParts { + match self { + ast::Type::Scalar(scalar) => TypeParts { + kind: TypeKind::Scalar, + scalar_kind: scalar.kind(), + width: scalar.width(), + components: 0, + }, + ast::Type::Vector(scalar, components) => TypeParts { + kind: TypeKind::Vector, + scalar_kind: scalar.kind(), + width: scalar.width(), + components: components as u32, + }, + ast::Type::Array(scalar, components) => TypeParts { + kind: TypeKind::Array, + scalar_kind: scalar.kind(), + width: scalar.width(), + components: components, + }, + } + } + + fn from_parts(t: TypeParts) -> Self { + match t.kind { + TypeKind::Scalar => { + ast::Type::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)) + } + TypeKind::Vector => ast::Type::Vector( + ast::ScalarType::from_parts(t.width, t.scalar_kind), + t.components as u8, + ), + TypeKind::Array => ast::Type::Array( + ast::ScalarType::from_parts(t.width, t.scalar_kind), + t.components, + ), + } + } +} + +#[derive(Eq, PartialEq, Copy, Clone)] +struct TypeParts { + kind: TypeKind, + scalar_kind: ScalarKind, + width: u8, + components: u32, +} + +#[derive(Eq, PartialEq, Copy, Clone)] +enum TypeKind { + Scalar, + Vector, + Array, +} + impl ast::Instruction { fn jump_target(&self) -> Option { match self { @@ -2005,7 +2182,9 @@ impl ast::Instruction { } impl VisitVariableExpanded for ast::Instruction { - fn visit_variable_extended) -> spirv::Word>( + fn visit_variable_extended< + F: FnMut(ArgumentDescriptor, Option) -> spirv::Word, + >( self, f: &mut F, ) -> ExpandedStatement { @@ -2016,6 +2195,29 @@ impl VisitVariableExpanded for ast::Instruction { type Arg2 = ast::Arg2; type Arg2St = ast::Arg2St; +struct CompositeAccess { + pub typ: ast::MovVectorType, + pub dst: spirv::Word, + pub src: spirv::Word, + pub index: u32, + pub is_write: bool +} + +struct CompositeWrite { + pub typ: ast::MovVectorType, + pub dst: spirv::Word, + pub src_composite: spirv::Word, + pub src_scalar: spirv::Word, + pub index: u32, +} + +struct CompositeRead { + pub typ: ast::MovVectorType, + pub dst: spirv::Word, + pub src: spirv::Word, + pub index: u32, +} + struct ConstantDefinition { pub dst: spirv::Word, pub typ: ast::ScalarType, @@ -2028,6 +2230,7 @@ struct BrachCondition { if_false: spirv::Word, } +#[derive(Copy, Clone)] struct ImplicitConversion { src: spirv::Word, dst: spirv::Word, @@ -2036,7 +2239,7 @@ struct ImplicitConversion { kind: ConversionKind, } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Copy, Clone)] enum ConversionKind { Default, // zero-extend/chop/bitcast depending on types @@ -2084,12 +2287,14 @@ impl ast::Arg1 { t: Option, ) -> ast::Arg1 { ast::Arg1 { - src: visitor.variable(ArgumentDescriptor { - op: self.src, - typ: t, - is_dst: false, - is_pointer: false, - }), + src: visitor.variable( + ArgumentDescriptor { + op: self.src, + is_dst: false, + is_pointer: false, + }, + t, + ), } } } @@ -2098,43 +2303,51 @@ impl ast::Arg2 { fn map>( self, visitor: &mut V, - t: Option, + t: ast::Type, ) -> ast::Arg2 { ast::Arg2 { - dst: visitor.variable(ArgumentDescriptor { - op: self.dst, - typ: t, - is_dst: true, - is_pointer: false, - }), - src: visitor.operand(ArgumentDescriptor { - op: self.src, - typ: t, - is_dst: false, - is_pointer: false, - }), + dst: visitor.variable( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + is_pointer: false, + }, + Some(t), + ), + src: visitor.operand( + ArgumentDescriptor { + op: self.src, + is_dst: false, + is_pointer: false, + }, + t, + ), } } fn map_ld>( self, visitor: &mut V, - t: Option, + t: ast::Type, is_src_pointer: bool, ) -> ast::Arg2 { ast::Arg2 { - dst: visitor.variable(ArgumentDescriptor { - op: self.dst, - typ: t, - is_dst: true, - is_pointer: false, - }), - src: visitor.operand(ArgumentDescriptor { - op: self.src, - typ: t, - is_dst: false, - is_pointer: is_src_pointer, - }), + dst: visitor.variable( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + is_pointer: false, + }, + Some(t), + ), + src: visitor.operand( + ArgumentDescriptor { + op: self.src, + is_dst: false, + is_pointer: is_src_pointer, + }, + t, + ), } } @@ -2145,18 +2358,22 @@ impl ast::Arg2 { src_t: ast::Type, ) -> ast::Arg2 { ast::Arg2 { - dst: visitor.variable(ArgumentDescriptor { - op: self.dst, - typ: Some(dst_t), - is_dst: true, - is_pointer: false, - }), - src: visitor.operand(ArgumentDescriptor { - op: self.src, - typ: Some(src_t), - is_dst: false, - is_pointer: false, - }), + dst: visitor.variable( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + is_pointer: false, + }, + Some(dst_t), + ), + src: visitor.operand( + ArgumentDescriptor { + op: self.src, + is_dst: false, + is_pointer: false, + }, + src_t, + ), } } } @@ -2165,22 +2382,26 @@ impl ast::Arg2St { fn map>( self, visitor: &mut V, - t: Option, + t: ast::Type, param_space: bool, ) -> ast::Arg2St { ast::Arg2St { - src1: visitor.operand(ArgumentDescriptor { - op: self.src1, - typ: t, - is_dst: param_space, - is_pointer: !param_space, - }), - src2: visitor.operand(ArgumentDescriptor { - op: self.src2, - typ: t, - is_dst: false, - is_pointer: false, - }), + src1: visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: param_space, + is_pointer: !param_space, + }, + t, + ), + src2: visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + is_pointer: false, + }, + t, + ), } } } @@ -2189,107 +2410,149 @@ impl ast::Arg2Vec { fn map>( self, visitor: &mut V, - t: ast::Type, + t: ast::MovVectorType, ) -> ast::Arg2Vec { match self { ast::Arg2Vec::Dst(dst, src) => ast::Arg2Vec::Dst( - visitor.src_vec_operand(ArgumentDescriptor { - op: dst, - typ: Some(t), - is_dst: true, - is_pointer: false, - }), - visitor.variable(ArgumentDescriptor { - op: src, - typ: Some(t), - is_dst: false, - is_pointer: false, - }), + visitor.src_vec_operand( + ArgumentDescriptor { + op: dst, + is_dst: true, + is_pointer: false, + }, + t, + ), + visitor.variable( + ArgumentDescriptor { + op: src, + is_dst: false, + is_pointer: false, + }, + Some(ast::Type::Scalar(t.into())), + ), ), - ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src ( - visitor.variable(ArgumentDescriptor { - op: dst, - typ: Some(t), - is_dst: true, - is_pointer: false, - }), - visitor.src_vec_operand(ArgumentDescriptor { - op: src, - typ: Some(t), - is_dst: false, - is_pointer: false, - }), + ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src( + visitor.variable( + ArgumentDescriptor { + op: dst, + is_dst: true, + is_pointer: false, + }, + Some(ast::Type::Scalar(t.into())), + ), + visitor.src_vec_operand( + ArgumentDescriptor { + op: src, + is_dst: false, + is_pointer: false, + }, + t, + ), ), - ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both ( - visitor.src_vec_operand(ArgumentDescriptor { - op: dst, - typ: Some(t), - is_dst: true, - is_pointer: false, - }), - visitor.src_vec_operand(ArgumentDescriptor { - op: src, - typ: Some(t), - is_dst: false, - is_pointer: false, - }), + ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both( + visitor.src_vec_operand( + ArgumentDescriptor { + op: dst, + is_dst: true, + is_pointer: false, + }, + t, + ), + visitor.src_vec_operand( + ArgumentDescriptor { + op: src, + is_dst: false, + is_pointer: false, + }, + t, + ), ), } } } +impl ast::Arg2Vec { + fn dst(&self) -> spirv::Word { + match self { + ast::Arg2Vec::Dst(dst, _) | ast::Arg2Vec::Src(dst, _) | ast::Arg2Vec::Both(dst, _) => { + *dst + } + } + } + + fn src(&self) -> spirv::Word { + match self { + ast::Arg2Vec::Dst(_, src) | ast::Arg2Vec::Src(_, src) | ast::Arg2Vec::Both(_, src) => { + *src + } + } + } +} + impl ast::Arg3 { fn map_non_shift>( self, visitor: &mut V, - t: Option, + t: ast::Type, ) -> ast::Arg3 { ast::Arg3 { - dst: visitor.variable(ArgumentDescriptor { - op: self.dst, - typ: t, - is_dst: true, - is_pointer: false, - }), - src1: visitor.operand(ArgumentDescriptor { - op: self.src1, - typ: t, - is_dst: false, - is_pointer: false, - }), - src2: visitor.operand(ArgumentDescriptor { - op: self.src2, - typ: t, - is_dst: false, - is_pointer: false, - }), + dst: visitor.variable( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + is_pointer: false, + }, + Some(t), + ), + src1: visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + is_pointer: false, + }, + t, + ), + src2: visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + is_pointer: false, + }, + t, + ), } } fn map_shift>( self, visitor: &mut V, - t: Option, + t: ast::Type, ) -> ast::Arg3 { ast::Arg3 { - dst: visitor.variable(ArgumentDescriptor { - op: self.dst, - typ: t, - is_dst: true, - is_pointer: false, - }), - src1: visitor.operand(ArgumentDescriptor { - op: self.src1, - typ: t, - is_dst: false, - is_pointer: false, - }), - src2: visitor.operand(ArgumentDescriptor { - op: self.src2, - typ: Some(ast::Type::Scalar(ast::ScalarType::U32)), - is_dst: false, - is_pointer: false, - }), + dst: visitor.variable( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + is_pointer: false, + }, + Some(t), + ), + src1: visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + is_pointer: false, + }, + t, + ), + src2: visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + is_pointer: false, + }, + ast::Type::Scalar(ast::ScalarType::U32), + ), } } } @@ -2298,35 +2561,43 @@ impl ast::Arg4 { fn map>( self, visitor: &mut V, - t: Option, + t: ast::Type, ) -> ast::Arg4 { ast::Arg4 { - dst1: visitor.variable(ArgumentDescriptor { - op: self.dst1, - typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)), - is_dst: true, - is_pointer: false, - }), - dst2: self.dst2.map(|dst2| { - visitor.variable(ArgumentDescriptor { - op: dst2, - typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)), + dst1: visitor.variable( + ArgumentDescriptor { + op: self.dst1, is_dst: true, is_pointer: false, - }) - }), - src1: visitor.operand(ArgumentDescriptor { - op: self.src1, - typ: t, - is_dst: false, - is_pointer: false, - }), - src2: visitor.operand(ArgumentDescriptor { - op: self.src2, - typ: t, - is_dst: false, - is_pointer: false, + }, + Some(ast::Type::Scalar(ast::ScalarType::Pred)), + ), + dst2: self.dst2.map(|dst2| { + visitor.variable( + ArgumentDescriptor { + op: dst2, + is_dst: true, + is_pointer: false, + }, + Some(ast::Type::Scalar(ast::ScalarType::Pred)), + ) }), + src1: visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + is_pointer: false, + }, + t, + ), + src2: visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + is_pointer: false, + }, + t, + ), } } } @@ -2335,41 +2606,51 @@ impl ast::Arg5 { fn map>( self, visitor: &mut V, - t: Option, + t: ast::Type, ) -> ast::Arg5 { ast::Arg5 { - dst1: visitor.variable(ArgumentDescriptor { - op: self.dst1, - typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)), - is_dst: true, - is_pointer: false, - }), - dst2: self.dst2.map(|dst2| { - visitor.variable(ArgumentDescriptor { - op: dst2, - typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)), + dst1: visitor.variable( + ArgumentDescriptor { + op: self.dst1, is_dst: true, is_pointer: false, - }) - }), - src1: visitor.operand(ArgumentDescriptor { - op: self.src1, - typ: t, - is_dst: false, - is_pointer: false, - }), - src2: visitor.operand(ArgumentDescriptor { - op: self.src2, - typ: t, - is_dst: false, - is_pointer: false, - }), - src3: visitor.operand(ArgumentDescriptor { - op: self.src3, - typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)), - is_dst: false, - is_pointer: false, + }, + Some(ast::Type::Scalar(ast::ScalarType::Pred)), + ), + dst2: self.dst2.map(|dst2| { + visitor.variable( + ArgumentDescriptor { + op: dst2, + is_dst: true, + is_pointer: false, + }, + Some(ast::Type::Scalar(ast::ScalarType::Pred)), + ) }), + src1: visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + is_pointer: false, + }, + t, + ), + src2: visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + is_pointer: false, + }, + t, + ), + src3: visitor.operand( + ArgumentDescriptor { + op: self.src3, + is_dst: false, + is_pointer: false, + }, + ast::Type::Scalar(ast::ScalarType::Pred), + ), } } } @@ -2395,9 +2676,9 @@ impl ast::StStateSpace { } } -#[derive(Clone, Copy, PartialEq)] +#[derive(Clone, Copy, PartialEq, Eq)] enum ScalarKind { - Byte, + Bit, Unsigned, Signed, Float, @@ -2438,10 +2719,10 @@ impl ast::ScalarType { ast::ScalarType::S16 => ScalarKind::Signed, ast::ScalarType::S32 => ScalarKind::Signed, ast::ScalarType::S64 => ScalarKind::Signed, - ast::ScalarType::B8 => ScalarKind::Byte, - ast::ScalarType::B16 => ScalarKind::Byte, - ast::ScalarType::B32 => ScalarKind::Byte, - ast::ScalarType::B64 => ScalarKind::Byte, + ast::ScalarType::B8 => ScalarKind::Bit, + ast::ScalarType::B16 => ScalarKind::Bit, + ast::ScalarType::B32 => ScalarKind::Bit, + ast::ScalarType::B64 => ScalarKind::Bit, ast::ScalarType::F16 => ScalarKind::Float, ast::ScalarType::F32 => ScalarKind::Float, ast::ScalarType::F64 => ScalarKind::Float, @@ -2458,7 +2739,7 @@ impl ast::ScalarType { 8 => ast::ScalarType::F64, _ => unreachable!(), }, - ScalarKind::Byte => match width { + ScalarKind::Bit => match width { 1 => ast::ScalarType::B8, 2 => ast::ScalarType::B16, 4 => ast::ScalarType::B32, @@ -2574,22 +2855,159 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { return false; } match inst.kind() { - ScalarKind::Byte => operand.kind() != ScalarKind::Byte, - ScalarKind::Float => operand.kind() == ScalarKind::Byte, + ScalarKind::Bit => operand.kind() != ScalarKind::Bit, + ScalarKind::Float => operand.kind() == ScalarKind::Bit, ScalarKind::Signed => { - operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Unsigned + operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned } ScalarKind::Unsigned => { - operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed + operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed } - ScalarKind::Float2 => todo!(), + ScalarKind::Float2 => false, ScalarKind::Pred => false, } } + (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _)) + | (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => { + should_bitcast(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) + } _ => false, } } +fn insert_with_conversions ast::Instruction>( + func: &mut Vec, + id_def: &mut NumericIdResolver, + mut instr: T, + pre_conv_src: impl ExactSizeIterator, + pre_conv_dst: impl ExactSizeIterator, + mut post_conv: Vec, + mut src: impl FnMut(&mut T) -> &mut spirv::Word, + mut dst: impl FnMut(&mut T) -> &mut spirv::Word, + to_inst: ToInstruction, +) { + insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src); + insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst); + if post_conv.len() > 0 { + let new_id = id_def.new_id(Some(post_conv[0].from)); + post_conv[0].src = new_id; + post_conv.last_mut().unwrap().dst = *dst(&mut instr); + *dst(&mut instr) = new_id; + } + func.push(Statement::Instruction(to_inst(instr))); + for conv in post_conv { + func.push(Statement::Conversion(conv)); + } +} + +fn insert_with_conversions_pre_conv( + func: &mut Vec, + id_def: &mut NumericIdResolver, + mut instr: &mut T, + pre_conv: impl ExactSizeIterator, + src: &mut impl FnMut(&mut T) -> &mut spirv::Word, +) { + let pre_conv_len = pre_conv.len(); + for (i, mut conv) in pre_conv.enumerate() { + let original_src = src(&mut instr); + if i == 0 { + conv.src = *original_src; + } + if i == pre_conv_len - 1 { + let new_id = id_def.new_id(Some(conv.to)); + conv.dst = new_id; + *original_src = new_id; + } + func.push(Statement::Conversion(conv)); + } +} + +fn get_implicit_conversions_ld_dst< + ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option, +>( + id_def: &mut NumericIdResolver, + instr_type: ast::Type, + dst: spirv::Word, + should_convert: ShouldConvert, + in_reverse: bool, +) -> Option { + let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!()); + if let Some(conv) = should_convert(dst_type, instr_type) { + Some(ImplicitConversion { + src: u32::max_value(), + dst: u32::max_value(), + from: if !in_reverse { dst_type } else { instr_type }, + to: if !in_reverse { instr_type } else { dst_type }, + kind: conv, + }) + } else { + None + } +} + +fn get_implicit_conversions_ld_src( + id_def: &mut NumericIdResolver, + instr_type: ast::Type, + state_space: ast::LdStateSpace, + src: spirv::Word, +) -> Vec { + let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()); + match state_space { + ast::LdStateSpace::Param => { + if src_type != instr_type { + vec![ + ImplicitConversion { + src: u32::max_value(), + dst: u32::max_value(), + from: src_type, + to: instr_type, + kind: ConversionKind::Default, + }; + 1 + ] + } else { + Vec::new() + } + } + ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { + let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts( + mem::size_of::() as u8, + ScalarKind::Bit, + )); + let mut result = Vec::new(); + // HACK ALERT + // IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an + // additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier + // TODO: error out if the src is not B64/U64/S64 + if let ast::Type::Scalar(scalar_src_type) = src_type { + if scalar_src_type.kind() == ScalarKind::Signed { + result.push(ImplicitConversion { + src: u32::max_value(), + dst: u32::max_value(), + from: src_type, + to: new_src_type, + kind: ConversionKind::Default, + }); + } + } + result.push(ImplicitConversion { + src: u32::max_value(), + dst: u32::max_value(), + from: src_type, + to: instr_type, + kind: ConversionKind::Ptr(state_space), + }); + if result.len() == 2 { + let new_id = id_def.new_id(Some(new_src_type)); + result[0].dst = new_id; + result[1].src = new_id; + result[1].from = new_src_type; + } + result + } + _ => todo!(), + } +} fn insert_implicit_conversions_ld_src( func: &mut Vec, instr_type: ast::Type, @@ -2608,7 +3026,7 @@ fn insert_implicit_conversions_ld_src( ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts( mem::size_of::() as u8, - ScalarKind::Byte, + ScalarKind::Bit, )); let new_src = insert_implicit_conversions_ld_src_impl( func, @@ -2640,8 +3058,8 @@ fn insert_implicit_conversions_ld_src_impl< should_convert: ShouldConvert, ) -> spirv::Word { let src_type = id_def.get_type(src); - if let Some(conv) = should_convert(src_type, instr_type) { - insert_conversion_src(func, id_def, src, src_type, instr_type, conv) + if let Some(conv) = should_convert(src_type.unwrap(), instr_type) { + insert_conversion_src(func, id_def, src, src_type.unwrap(), instr_type, conv) } else { src } @@ -2692,14 +3110,15 @@ fn insert_conversion_src( temp_src } +/* fn insert_with_implicit_conversion_dst< T, - ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option, + ShouldConvert: FnOnce(ast::StateSpace, ast::Type, ast::Type) -> Option, Setter: Fn(&mut T) -> &mut spirv::Word, ToInstruction: FnOnce(T) -> ast::Instruction, >( func: &mut Vec, - instr_type: ast::ScalarType, + instr_type: ast::Type, id_def: &mut NumericIdResolver, should_convert: ShouldConvert, mut t: T, @@ -2708,13 +3127,14 @@ fn insert_with_implicit_conversion_dst< ) { let dst = setter(&mut t); let dst_type = id_def.get_type(*dst); - let dst_coercion = should_convert(dst_type, instr_type) - .map(|conv| get_conversion_dst(id_def, dst, ast::Type::Scalar(instr_type), dst_type, conv)); + let dst_coercion = should_convert(dst_type.unwrap(), instr_type) + .map(|conv| get_conversion_dst(id_def, dst, instr_type, dst_type.unwrap(), conv)); func.push(Statement::Instruction(to_inst(t))); if let Some(conv) = dst_coercion { func.push(conv); } } +*/ #[must_use] fn get_conversion_dst( @@ -2739,14 +3159,14 @@ fn get_conversion_dst( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands fn should_convert_relaxed_src( src_type: ast::Type, - instr_type: ast::ScalarType, + instr_type: ast::Type, ) -> Option { - if src_type == ast::Type::Scalar(instr_type) { + if src_type == instr_type { return None; } - match src_type { - ast::Type::Scalar(src_type) => match instr_type.kind() { - ScalarKind::Byte => { + match (src_type, instr_type) { + (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ScalarKind::Bit => { if instr_type.width() <= src_type.width() { Some(ConversionKind::Default) } else { @@ -2761,7 +3181,7 @@ fn should_convert_relaxed_src( } } ScalarKind::Float => { - if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Byte { + if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Bit { Some(ConversionKind::Default) } else { None @@ -2770,6 +3190,10 @@ fn should_convert_relaxed_src( ScalarKind::Float2 => todo!(), ScalarKind::Pred => None, }, + (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) + | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { + should_convert_relaxed_src(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) + } _ => None, } } @@ -2777,14 +3201,14 @@ fn should_convert_relaxed_src( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands fn should_convert_relaxed_dst( dst_type: ast::Type, - instr_type: ast::ScalarType, + instr_type: ast::Type, ) -> Option { - if dst_type == ast::Type::Scalar(instr_type) { + if dst_type == instr_type { return None; } - match dst_type { - ast::Type::Scalar(dst_type) => match instr_type.kind() { - ScalarKind::Byte => { + match (dst_type, instr_type) { + (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ScalarKind::Bit => { if instr_type.width() <= dst_type.width() { Some(ConversionKind::Default) } else { @@ -2812,7 +3236,7 @@ fn should_convert_relaxed_dst( } } ScalarKind::Float => { - if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Byte { + if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Bit { Some(ConversionKind::Default) } else { None @@ -2821,6 +3245,10 @@ fn should_convert_relaxed_dst( ScalarKind::Float2 => todo!(), ScalarKind::Pred => None, }, + (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) + | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { + should_convert_relaxed_dst(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) + } _ => None, } } @@ -2831,13 +3259,13 @@ fn insert_implicit_bitcasts( stmt: impl VisitVariableExpanded, ) { let mut dst_coercion = None; - let instr = stmt.visit_variable_extended(&mut |mut desc| { - let id_type_from_instr = match desc.typ { + let instr = stmt.visit_variable_extended(&mut |mut desc, typ| { + let id_type_from_instr = match typ { Some(t) => t, None => return desc.op, }; - let id_actual_type = id_def.get_type(desc.op); - if should_bitcast(id_type_from_instr, id_def.get_type(desc.op)) { + let id_actual_type = id_def.get_type(desc.op).unwrap(); + if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap()) { if desc.is_dst { dst_coercion = Some(get_conversion_dst( id_def, @@ -2970,14 +3398,14 @@ mod tests { .collect::>() } - fn assert_conversion_table Option>( + fn assert_conversion_table Option>( table: &'static str, f: F, ) { let conv_table = parse_conversion_table(table); for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() { for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() { - let conversion = f(ast::Type::Scalar(*op_type), *instr_type); + let conversion = f(ast::Type::Scalar(*op_type), ast::Type::Scalar(*instr_type)); if instr_idx == op_idx { assert_eq!(conversion, None); } else {