diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 5a5f6be..367f060 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -766,6 +766,8 @@ sub_type! { LdStType { Scalar(LdStScalarType), Vector(LdStScalarType, u8), + // Used in generated code + Pointer(PointerType, LdStateSpace), } } @@ -774,6 +776,10 @@ impl From for PointerType { match t { LdStType::Scalar(t) => PointerType::Scalar(t.into()), LdStType::Vector(t, len) => PointerType::Vector(t.into(), len), + LdStType::Pointer(PointerType::Scalar(scalar_type), space) => { + PointerType::Pointer(scalar_type, space) + } + LdStType::Pointer(..) => unreachable!(), } } } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 6c231b2..d2c235a 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -1237,18 +1237,18 @@ InstRet: ast::Instruction> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvta InstCvta: ast::Instruction> = { - "cvta" => { + "cvta" => { ast::Instruction::Cvta(ast::CvtaDetails { - to: to, - from: ast::CvtaStateSpace::Generic, + to: ast::CvtaStateSpace::Generic, + from, size: s }, a) }, - "cvta" ".to" => { + "cvta" ".to" => { ast::Instruction::Cvta(ast::CvtaDetails { - to: ast::CvtaStateSpace::Generic, - from: from, + to, + from: ast::CvtaStateSpace::Generic, size: s }, a) diff --git a/ptx/src/test/spirv_run/atom_inc.spvtxt b/ptx/src/test/spirv_run/atom_inc.spvtxt index 6948cd9..fda26c5 100644 --- a/ptx/src/test/spirv_run/atom_inc.spvtxt +++ b/ptx/src/test/spirv_run/atom_inc.spvtxt @@ -1,89 +1,81 @@ -; SPIR-V -; Version: 1.3 -; Generator: rspirv -; Bound: 60 -OpCapability GenericPointer -OpCapability Linkage -OpCapability Addresses -OpCapability Kernel -OpCapability Int8 -OpCapability Int16 -OpCapability Int64 -OpCapability Float16 -OpCapability Float64 -; OpCapability FunctionFloatControlINTEL -; OpExtension "SPV_INTEL_float_controls2" -%49 = OpExtInstImport "OpenCL.std" -OpMemoryModel Physical64 OpenCL -OpEntryPoint Kernel %1 "atom_inc" -OpDecorate %40 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_generic_inc" Import -OpDecorate %44 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_global_inc" Import -%50 = OpTypeVoid -%51 = OpTypeInt 32 0 -%52 = OpTypePointer Generic %51 -%53 = OpTypeFunction %51 %52 %51 -%54 = OpTypePointer CrossWorkgroup %51 -%55 = OpTypeFunction %51 %54 %51 -%56 = OpTypeInt 64 0 -%57 = OpTypeFunction %50 %56 %56 -%58 = OpTypePointer Function %56 -%59 = OpTypePointer Function %51 -%27 = OpConstant %51 101 -%28 = OpConstant %51 101 -%29 = OpConstant %56 4 -%31 = OpConstant %56 8 -%40 = OpFunction %51 None %53 -%42 = OpFunctionParameter %52 -%43 = OpFunctionParameter %51 -OpFunctionEnd -%44 = OpFunction %51 None %55 -%46 = OpFunctionParameter %54 -%47 = OpFunctionParameter %51 -OpFunctionEnd -%1 = OpFunction %50 None %57 -%9 = OpFunctionParameter %56 -%10 = OpFunctionParameter %56 -%39 = OpLabel -%2 = OpVariable %58 Function -%3 = OpVariable %58 Function -%4 = OpVariable %58 Function -%5 = OpVariable %58 Function -%6 = OpVariable %59 Function -%7 = OpVariable %59 Function -%8 = OpVariable %59 Function -OpStore %2 %9 -OpStore %3 %10 -%12 = OpLoad %56 %2 -%11 = OpCopyObject %56 %12 -OpStore %4 %11 -%14 = OpLoad %56 %3 -%13 = OpCopyObject %56 %14 -OpStore %5 %13 -%16 = OpLoad %56 %4 -%33 = OpConvertUToPtr %52 %16 -%15 = OpFunctionCall %51 %40 %33 %27 -OpStore %6 %15 -%18 = OpLoad %56 %4 -%34 = OpConvertUToPtr %54 %18 -%17 = OpFunctionCall %51 %44 %34 %28 -OpStore %7 %17 -%20 = OpLoad %56 %4 -%35 = OpConvertUToPtr %52 %20 -%19 = OpLoad %51 %35 -OpStore %8 %19 -%21 = OpLoad %56 %5 -%22 = OpLoad %51 %6 -%36 = OpConvertUToPtr %52 %21 -OpStore %36 %22 -%23 = OpLoad %56 %5 -%24 = OpLoad %51 %7 -%30 = OpIAdd %56 %23 %29 -%37 = OpConvertUToPtr %52 %30 -OpStore %37 %24 -%25 = OpLoad %56 %5 -%26 = OpLoad %51 %8 -%32 = OpIAdd %56 %25 %31 -%38 = OpConvertUToPtr %52 %32 -OpStore %38 %26 -OpReturn -OpFunctionEnd \ No newline at end of file + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %47 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "atom_inc" + OpDecorate %38 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_generic_inc" Import + OpDecorate %42 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_global_inc" Import + %void = OpTypeVoid + %uint = OpTypeInt 32 0 +%_ptr_Generic_uint = OpTypePointer Generic %uint + %51 = OpTypeFunction %uint %_ptr_Generic_uint %uint +%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint + %53 = OpTypeFunction %uint %_ptr_CrossWorkgroup_uint %uint + %ulong = OpTypeInt 64 0 + %55 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Function_uint = OpTypePointer Function %uint + %uint_101 = OpConstant %uint 101 + %uint_101_0 = OpConstant %uint 101 + %ulong_4 = OpConstant %ulong 4 + %ulong_8 = OpConstant %ulong 8 + %38 = OpFunction %uint None %51 + %40 = OpFunctionParameter %_ptr_Generic_uint + %41 = OpFunctionParameter %uint + OpFunctionEnd + %42 = OpFunction %uint None %53 + %44 = OpFunctionParameter %_ptr_CrossWorkgroup_uint + %45 = OpFunctionParameter %uint + OpFunctionEnd + %1 = OpFunction %void None %55 + %9 = OpFunctionParameter %ulong + %10 = OpFunctionParameter %ulong + %37 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + %8 = OpVariable %_ptr_Function_uint Function + OpStore %2 %9 + OpStore %3 %10 + %11 = OpLoad %ulong %2 + OpStore %4 %11 + %12 = OpLoad %ulong %3 + OpStore %5 %12 + %14 = OpLoad %ulong %4 + %31 = OpConvertUToPtr %_ptr_Generic_uint %14 + %13 = OpFunctionCall %uint %38 %31 %uint_101 + OpStore %6 %13 + %16 = OpLoad %ulong %4 + %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %16 + %15 = OpFunctionCall %uint %42 %32 %uint_101_0 + OpStore %7 %15 + %18 = OpLoad %ulong %4 + %33 = OpConvertUToPtr %_ptr_Generic_uint %18 + %17 = OpLoad %uint %33 + OpStore %8 %17 + %19 = OpLoad %ulong %5 + %20 = OpLoad %uint %6 + %34 = OpConvertUToPtr %_ptr_Generic_uint %19 + OpStore %34 %20 + %21 = OpLoad %ulong %5 + %22 = OpLoad %uint %7 + %28 = OpIAdd %ulong %21 %ulong_4 + %35 = OpConvertUToPtr %_ptr_Generic_uint %28 + OpStore %35 %22 + %23 = OpLoad %ulong %5 + %24 = OpLoad %uint %8 + %30 = OpIAdd %ulong %23 %ulong_8 + %36 = OpConvertUToPtr %_ptr_Generic_uint %30 + OpStore %36 %24 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt index cf6ff8b..143d0a5 100644 --- a/ptx/src/test/spirv_run/cvta.spvtxt +++ b/ptx/src/test/spirv_run/cvta.spvtxt @@ -7,48 +7,59 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %27 = OpExtInstImport "OpenCL.std" + %37 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvta" %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %30 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %41 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %float = OpTypeFloat 32 %_ptr_Function_float = OpTypePointer Function %float + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float - %1 = OpFunction %void None %30 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %25 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function + %1 = OpFunction %void None %41 + %17 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %18 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %35 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %7 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %8 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %6 = OpVariable %_ptr_Function_float Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 - OpStore %4 %9 - %10 = OpLoad %ulong %3 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %20 = OpCopyObject %ulong %12 - %19 = OpCopyObject %ulong %20 - %11 = OpCopyObject %ulong %19 - OpStore %4 %11 - %14 = OpLoad %ulong %5 - %22 = OpCopyObject %ulong %14 - %21 = OpCopyObject %ulong %22 - %13 = OpCopyObject %ulong %21 - OpStore %5 %13 - %16 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %16 - %15 = OpLoad %float %23 - OpStore %6 %15 - %17 = OpLoad %ulong %5 - %18 = OpLoad %float %6 - %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %17 - OpStore %24 %18 + OpStore %2 %17 + OpStore %3 %18 + %10 = OpBitcast %_ptr_Function_ulong %2 + %9 = OpLoad %ulong %10 + %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %9 + OpStore %7 %19 + %12 = OpBitcast %_ptr_Function_ulong %3 + %11 = OpLoad %ulong %12 + %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %11 + OpStore %8 %20 + %21 = OpLoad %_ptr_CrossWorkgroup_uchar %7 + %14 = OpConvertPtrToU %ulong %21 + %30 = OpCopyObject %ulong %14 + %29 = OpCopyObject %ulong %30 + %13 = OpCopyObject %ulong %29 + %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %13 + OpStore %7 %22 + %23 = OpLoad %_ptr_CrossWorkgroup_uchar %8 + %16 = OpConvertPtrToU %ulong %23 + %32 = OpCopyObject %ulong %16 + %31 = OpCopyObject %ulong %32 + %15 = OpCopyObject %ulong %31 + %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15 + OpStore %8 %24 + %26 = OpLoad %_ptr_CrossWorkgroup_uchar %7 + %33 = OpBitcast %_ptr_CrossWorkgroup_float %26 + %25 = OpLoad %float %33 + OpStore %6 %25 + %27 = OpLoad %_ptr_CrossWorkgroup_uchar %8 + %28 = OpLoad %float %6 + %34 = OpBitcast %_ptr_CrossWorkgroup_float %27 + OpStore %34 %28 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt index d979193..39f8683 100644 --- a/ptx/src/test/spirv_run/extern_shared_call.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared_call.spvtxt @@ -7,7 +7,7 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %48 = OpExtInstImport "OpenCL.std" + %46 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %14 "extern_shared_call" %1 OpDecorate %1 Alignment 4 @@ -18,78 +18,76 @@ %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup %uchar = OpTypeInt 8 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %55 = OpTypeFunction %void %_ptr_Workgroup_uchar + %53 = OpTypeFunction %void %_ptr_Workgroup_uchar %_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar %ulong = OpTypeInt 64 0 %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong %ulong_2 = OpConstant %ulong 2 - %62 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar + %60 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %2 = OpFunction %void None %55 - %40 = OpFunctionParameter %_ptr_Workgroup_uchar - %56 = OpLabel - %41 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %2 = OpFunction %void None %53 + %38 = OpFunctionParameter %_ptr_Workgroup_uchar + %54 = OpLabel + %39 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function %3 = OpVariable %_ptr_Function_ulong Function - OpStore %41 %40 + OpStore %39 %38 OpBranch %13 %13 = OpLabel - %42 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %41 - %5 = OpLoad %_ptr_Workgroup_uint %42 + %40 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39 + %5 = OpLoad %_ptr_Workgroup_uint %40 %11 = OpBitcast %_ptr_Workgroup_ulong %5 %4 = OpLoad %ulong %11 OpStore %3 %4 %7 = OpLoad %ulong %3 %6 = OpIAdd %ulong %7 %ulong_2 OpStore %3 %6 - %43 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %41 - %8 = OpLoad %_ptr_Workgroup_uint %43 + %41 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39 + %8 = OpLoad %_ptr_Workgroup_uint %41 %9 = OpLoad %ulong %3 %12 = OpBitcast %_ptr_Workgroup_ulong %8 OpStore %12 %9 OpReturn OpFunctionEnd - %14 = OpFunction %void None %62 + %14 = OpFunction %void None %60 %20 = OpFunctionParameter %ulong %21 = OpFunctionParameter %ulong - %44 = OpFunctionParameter %_ptr_Workgroup_uchar - %63 = OpLabel - %45 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %42 = OpFunctionParameter %_ptr_Workgroup_uchar + %61 = OpLabel + %43 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function %15 = OpVariable %_ptr_Function_ulong Function %16 = OpVariable %_ptr_Function_ulong Function %17 = OpVariable %_ptr_Function_ulong Function %18 = OpVariable %_ptr_Function_ulong Function %19 = OpVariable %_ptr_Function_ulong Function - OpStore %45 %44 - OpBranch %38 - %38 = OpLabel + OpStore %43 %42 + OpBranch %36 + %36 = OpLabel OpStore %15 %20 OpStore %16 %21 - %23 = OpLoad %ulong %15 - %22 = OpCopyObject %ulong %23 + %22 = OpLoad %ulong %15 OpStore %17 %22 - %25 = OpLoad %ulong %16 - %24 = OpCopyObject %ulong %25 - OpStore %18 %24 - %27 = OpLoad %ulong %17 - %34 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %27 - %26 = OpLoad %ulong %34 - OpStore %19 %26 - %46 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %45 - %28 = OpLoad %_ptr_Workgroup_uint %46 - %29 = OpLoad %ulong %19 - %35 = OpBitcast %_ptr_Workgroup_ulong %28 - OpStore %35 %29 - %65 = OpFunctionCall %void %2 %44 - %47 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %45 - %31 = OpLoad %_ptr_Workgroup_uint %47 - %36 = OpBitcast %_ptr_Workgroup_ulong %31 - %30 = OpLoad %ulong %36 - OpStore %19 %30 - %32 = OpLoad %ulong %18 - %33 = OpLoad %ulong %19 - %37 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %32 - OpStore %37 %33 + %23 = OpLoad %ulong %16 + OpStore %18 %23 + %25 = OpLoad %ulong %17 + %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %25 + %24 = OpLoad %ulong %32 + OpStore %19 %24 + %44 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43 + %26 = OpLoad %_ptr_Workgroup_uint %44 + %27 = OpLoad %ulong %19 + %33 = OpBitcast %_ptr_Workgroup_ulong %26 + OpStore %33 %27 + %63 = OpFunctionCall %void %2 %42 + %45 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43 + %29 = OpLoad %_ptr_Workgroup_uint %45 + %34 = OpBitcast %_ptr_Workgroup_ulong %29 + %28 = OpLoad %ulong %34 + OpStore %19 %28 + %30 = OpLoad %ulong %18 + %31 = OpLoad %ulong %19 + %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %30 + OpStore %35 %31 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index bd74508..f18b15c 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -133,6 +133,10 @@ test_ptx!( [0b11111000_11000001_00100010_10100000u32, 16u32, 8u32], [0b11000001u32] ); +test_ptx!(stateful_ld_st_simple, [121u64], [121u64]); +test_ptx!(stateful_ld_st_ntid, [123u64], [123u64]); +test_ptx!(stateful_ld_st_ntid_chain, [12651u64], [12651u64]); +test_ptx!(stateful_ld_st_ntid_sub, [96311u64], [96311u64]); struct DisplayError { err: T, @@ -292,7 +296,7 @@ fn test_spvtxt_assert<'a>( rspirv::binary::parse_words(&parsed_spirv, &mut loader)?; let spvtxt_mod = loader.module(); unsafe { spirv_tools::spvBinaryDestroy(spv_binary) }; - if !is_spirv_fn_equal(&spirv_module.spirv.functions[0], &spvtxt_mod.functions[0]) { + if !is_spirv_fns_equal(&spirv_module.spirv.functions, &spvtxt_mod.functions) { // We could simply use ptx_mod.disassemble, but SPIRV-Tools text formattinmg is so much nicer let spv_from_ptx_binary = spirv_module.spirv.assemble(); let mut spv_text: spirv_tools::spv_text = ptr::null_mut(); @@ -364,6 +368,18 @@ impl EqMap { } } +fn is_spirv_fns_equal(fns1: &[Function], fns2: &[Function]) -> bool { + if fns1.len() != fns2.len() { + return false; + } + for (fn1, fn2) in fns1.iter().zip(fns2.iter()) { + if !is_spirv_fn_equal(fn1, fn2) { + return false; + } + } + true +} + fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool { let mut map = EqMap::new(); if !is_option_equal(&fn1.def, &fn2.def, &mut map, is_instr_equal) { diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt index 596cedc..5ce3689 100644 --- a/ptx/src/test/spirv_run/reg_local.spvtxt +++ b/ptx/src/test/spirv_run/reg_local.spvtxt @@ -22,7 +22,9 @@ %_ptr_Function__arr_uchar_uint_8 = OpTypePointer Function %_arr_uchar_uint_8 %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %ulong_1 = OpConstant %ulong 1 +%_ptr_Generic_ulong = OpTypePointer Generic %ulong %ulong_0 = OpConstant %ulong 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_0_0 = OpConstant %ulong 0 %1 = OpFunction %void None %37 %8 = OpFunctionParameter %ulong @@ -48,12 +50,12 @@ %14 = OpLoad %ulong %7 %26 = OpCopyObject %ulong %14 %19 = OpIAdd %ulong %26 %ulong_1 - %27 = OpBitcast %_ptr_Function_ulong %4 + %27 = OpBitcast %_ptr_Generic_ulong %4 OpStore %27 %19 - %28 = OpBitcast %_ptr_Function_ulong %4 - %45 = OpBitcast %ulong %28 - %46 = OpIAdd %ulong %45 %ulong_0 - %21 = OpBitcast %_ptr_Function_ulong %46 + %28 = OpBitcast %_ptr_Generic_ulong %4 + %47 = OpBitcast %_ptr_Generic_uchar %28 + %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_0 + %21 = OpBitcast %_ptr_Generic_ulong %48 %29 = OpLoad %ulong %21 %15 = OpCopyObject %ulong %29 OpStore %7 %15 diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.ptx b/ptx/src/test/spirv_run/stateful_ld_st_ntid.ptx new file mode 100644 index 0000000..1fc37d1 --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.ptx @@ -0,0 +1,31 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry stateful_ld_st_ntid( + .param .u64 input, + .param .u64 output +) +{ + .reg .b64 in_addr; + .reg .b64 out_addr; + .reg .u32 tid_32; + .reg .u64 tid_64; + .reg .u64 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + cvta.to.global.u64 in_addr, in_addr; + cvta.to.global.u64 out_addr, out_addr; + + mov.u32 tid_32, %tid.x; + cvt.u64.u32 tid_64, tid_32; + + add.u64 in_addr, in_addr, tid_64; + add.u64 out_addr, out_addr, tid_64; + + ld.global.u64 temp, [in_addr]; + st.global.u64 [out_addr], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt new file mode 100644 index 0000000..c53ad51 --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt @@ -0,0 +1,89 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %49 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "stateful_ld_st_ntid" %gl_LocalInvocationID + OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %v4uint = OpTypeVector %uint 4 +%_ptr_Input_v4uint = OpTypePointer Input %v4uint +%gl_LocalInvocationID = OpVariable %_ptr_Input_v4uint Input + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %56 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar +%_ptr_Function_uint = OpTypePointer Function %uint + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %1 = OpFunction %void None %56 + %20 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %47 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_ulong Function + %8 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %20 + OpStore %3 %21 + %13 = OpBitcast %_ptr_Function_ulong %2 + %43 = OpLoad %ulong %13 + %12 = OpCopyObject %ulong %43 + %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %12 + OpStore %10 %22 + %15 = OpBitcast %_ptr_Function_ulong %3 + %44 = OpLoad %ulong %15 + %14 = OpCopyObject %ulong %44 + %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14 + OpStore %11 %23 + %24 = OpLoad %_ptr_CrossWorkgroup_uchar %10 + %17 = OpConvertPtrToU %ulong %24 + %16 = OpCopyObject %ulong %17 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %16 + OpStore %10 %25 + %26 = OpLoad %_ptr_CrossWorkgroup_uchar %11 + %19 = OpConvertPtrToU %ulong %26 + %18 = OpCopyObject %ulong %19 + %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18 + OpStore %11 %27 + %29 = OpLoad %v4uint %gl_LocalInvocationID + %42 = OpCompositeExtract %uint %29 0 + %28 = OpCopyObject %uint %42 + OpStore %6 %28 + %31 = OpLoad %uint %6 + %61 = OpBitcast %uint %31 + %30 = OpUConvert %ulong %61 + OpStore %7 %30 + %33 = OpLoad %_ptr_CrossWorkgroup_uchar %10 + %34 = OpLoad %ulong %7 + %62 = OpBitcast %_ptr_CrossWorkgroup_uchar %33 + %63 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %62 %34 + %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %63 + OpStore %10 %32 + %36 = OpLoad %_ptr_CrossWorkgroup_uchar %11 + %37 = OpLoad %ulong %7 + %64 = OpBitcast %_ptr_CrossWorkgroup_uchar %36 + %65 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %64 %37 + %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %65 + OpStore %11 %35 + %39 = OpLoad %_ptr_CrossWorkgroup_uchar %10 + %45 = OpBitcast %_ptr_CrossWorkgroup_ulong %39 + %38 = OpLoad %ulong %45 + OpStore %8 %38 + %40 = OpLoad %_ptr_CrossWorkgroup_uchar %11 + %41 = OpLoad %ulong %8 + %46 = OpBitcast %_ptr_CrossWorkgroup_ulong %40 + OpStore %46 %41 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.ptx b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.ptx new file mode 100644 index 0000000..ef7645d --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.ptx @@ -0,0 +1,35 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry stateful_ld_st_ntid_chain( + .param .u64 input, + .param .u64 output +) +{ + .reg .b64 in_addr1; + .reg .b64 in_addr2; + .reg .b64 in_addr3; + .reg .b64 out_addr1; + .reg .b64 out_addr2; + .reg .b64 out_addr3; + .reg .u32 tid_32; + .reg .u64 tid_64; + .reg .u64 temp; + + ld.param.u64 in_addr1, [input]; + ld.param.u64 out_addr1, [output]; + + cvta.to.global.u64 in_addr2, in_addr1; + cvta.to.global.u64 out_addr2, out_addr1; + + mov.u32 tid_32, %tid.x; + cvt.u64.u32 tid_64, tid_32; + + add.u64 in_addr3, in_addr2, tid_64; + add.u64 out_addr3, out_addr2, tid_64; + + ld.global.u64 temp, [in_addr3]; + st.global.u64 [out_addr3], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt new file mode 100644 index 0000000..5ba889c --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt @@ -0,0 +1,93 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %57 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "stateful_ld_st_ntid_chain" %gl_LocalInvocationID + OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %v4uint = OpTypeVector %uint 4 +%_ptr_Input_v4uint = OpTypePointer Input %v4uint +%gl_LocalInvocationID = OpVariable %_ptr_Input_v4uint Input + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %64 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar +%_ptr_Function_uint = OpTypePointer Function %uint + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %1 = OpFunction %void None %64 + %28 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %29 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %55 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %10 = OpVariable %_ptr_Function_uint Function + %11 = OpVariable %_ptr_Function_ulong Function + %12 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %28 + OpStore %3 %29 + %21 = OpBitcast %_ptr_Function_ulong %2 + %51 = OpLoad %ulong %21 + %20 = OpCopyObject %ulong %51 + %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20 + OpStore %14 %30 + %23 = OpBitcast %_ptr_Function_ulong %3 + %52 = OpLoad %ulong %23 + %22 = OpCopyObject %ulong %52 + %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22 + OpStore %17 %31 + %32 = OpLoad %_ptr_CrossWorkgroup_uchar %14 + %25 = OpConvertPtrToU %ulong %32 + %24 = OpCopyObject %ulong %25 + %33 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %24 + OpStore %15 %33 + %34 = OpLoad %_ptr_CrossWorkgroup_uchar %17 + %27 = OpConvertPtrToU %ulong %34 + %26 = OpCopyObject %ulong %27 + %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26 + OpStore %18 %35 + %37 = OpLoad %v4uint %gl_LocalInvocationID + %50 = OpCompositeExtract %uint %37 0 + %36 = OpCopyObject %uint %50 + OpStore %10 %36 + %39 = OpLoad %uint %10 + %69 = OpBitcast %uint %39 + %38 = OpUConvert %ulong %69 + OpStore %11 %38 + %41 = OpLoad %_ptr_CrossWorkgroup_uchar %15 + %42 = OpLoad %ulong %11 + %70 = OpBitcast %_ptr_CrossWorkgroup_uchar %41 + %71 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %70 %42 + %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %71 + OpStore %16 %40 + %44 = OpLoad %_ptr_CrossWorkgroup_uchar %18 + %45 = OpLoad %ulong %11 + %72 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 + %73 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %72 %45 + %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %73 + OpStore %19 %43 + %47 = OpLoad %_ptr_CrossWorkgroup_uchar %16 + %53 = OpBitcast %_ptr_CrossWorkgroup_ulong %47 + %46 = OpLoad %ulong %53 + OpStore %12 %46 + %48 = OpLoad %_ptr_CrossWorkgroup_uchar %19 + %49 = OpLoad %ulong %12 + %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %48 + OpStore %54 %49 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.ptx b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.ptx new file mode 100644 index 0000000..018918c --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.ptx @@ -0,0 +1,35 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry stateful_ld_st_ntid_sub( + .param .u64 input, + .param .u64 output +) +{ + .reg .b64 in_addr1; + .reg .b64 in_addr2; + .reg .b64 in_addr3; + .reg .b64 out_addr1; + .reg .b64 out_addr2; + .reg .b64 out_addr3; + .reg .u32 tid_32; + .reg .u64 tid_64; + .reg .u64 temp; + + ld.param.u64 in_addr1, [input]; + ld.param.u64 out_addr1, [output]; + + cvta.to.global.u64 in_addr2, in_addr1; + cvta.to.global.u64 out_addr2, out_addr1; + + mov.u32 tid_32, %tid.x; + cvt.u64.u32 tid_64, tid_32; + + sub.s64 in_addr3, in_addr2, tid_64; + sub.s64 out_addr3, out_addr2, tid_64; + + ld.global.u64 temp, [in_addr3+-0]; + st.global.u64 [out_addr3+-0], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt new file mode 100644 index 0000000..3c215d4 --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt @@ -0,0 +1,105 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %65 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "stateful_ld_st_ntid_sub" %gl_LocalInvocationID + OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %v4uint = OpTypeVector %uint 4 +%_ptr_Input_v4uint = OpTypePointer Input %v4uint +%gl_LocalInvocationID = OpVariable %_ptr_Input_v4uint Input + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %72 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar +%_ptr_Function_uint = OpTypePointer Function %uint + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong + %ulong_0 = OpConstant %ulong 0 +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %ulong_0_0 = OpConstant %ulong 0 + %1 = OpFunction %void None %72 + %30 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %31 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %63 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %10 = OpVariable %_ptr_Function_uint Function + %11 = OpVariable %_ptr_Function_ulong Function + %12 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %30 + OpStore %3 %31 + %21 = OpBitcast %_ptr_Function_ulong %2 + %57 = OpLoad %ulong %21 + %20 = OpCopyObject %ulong %57 + %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20 + OpStore %14 %32 + %23 = OpBitcast %_ptr_Function_ulong %3 + %58 = OpLoad %ulong %23 + %22 = OpCopyObject %ulong %58 + %33 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22 + OpStore %17 %33 + %34 = OpLoad %_ptr_CrossWorkgroup_uchar %14 + %25 = OpConvertPtrToU %ulong %34 + %24 = OpCopyObject %ulong %25 + %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %24 + OpStore %15 %35 + %36 = OpLoad %_ptr_CrossWorkgroup_uchar %17 + %27 = OpConvertPtrToU %ulong %36 + %26 = OpCopyObject %ulong %27 + %37 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26 + OpStore %18 %37 + %39 = OpLoad %v4uint %gl_LocalInvocationID + %52 = OpCompositeExtract %uint %39 0 + %38 = OpCopyObject %uint %52 + OpStore %10 %38 + %41 = OpLoad %uint %10 + %77 = OpBitcast %uint %41 + %40 = OpUConvert %ulong %77 + OpStore %11 %40 + %42 = OpLoad %ulong %11 + %59 = OpCopyObject %ulong %42 + %28 = OpSNegate %ulong %59 + %44 = OpLoad %_ptr_CrossWorkgroup_uchar %15 + %78 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 + %79 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %78 %28 + %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %79 + OpStore %16 %43 + %45 = OpLoad %ulong %11 + %60 = OpCopyObject %ulong %45 + %29 = OpSNegate %ulong %60 + %47 = OpLoad %_ptr_CrossWorkgroup_uchar %18 + %80 = OpBitcast %_ptr_CrossWorkgroup_uchar %47 + %81 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %80 %29 + %46 = OpBitcast %_ptr_CrossWorkgroup_uchar %81 + OpStore %19 %46 + %49 = OpLoad %_ptr_CrossWorkgroup_uchar %16 + %61 = OpBitcast %_ptr_CrossWorkgroup_ulong %49 + %83 = OpBitcast %_ptr_CrossWorkgroup_uchar %61 + %84 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %83 %ulong_0 + %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %84 + %48 = OpLoad %ulong %54 + OpStore %12 %48 + %50 = OpLoad %_ptr_CrossWorkgroup_uchar %19 + %51 = OpLoad %ulong %12 + %62 = OpBitcast %_ptr_CrossWorkgroup_ulong %50 + %85 = OpBitcast %_ptr_CrossWorkgroup_uchar %62 + %86 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %85 %ulong_0_0 + %56 = OpBitcast %_ptr_CrossWorkgroup_ulong %86 + OpStore %56 %51 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx b/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx new file mode 100644 index 0000000..5650ada --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx @@ -0,0 +1,25 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry stateful_ld_st_simple( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 in_addr2; + .reg .u64 out_addr2; + .reg .u64 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + cvta.to.global.u64 in_addr2, in_addr; + cvta.to.global.u64 out_addr2, out_addr; + + ld.global.u64 temp, [in_addr2]; + st.global.u64 [out_addr2], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt new file mode 100644 index 0000000..cfd87eb --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt @@ -0,0 +1,65 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %41 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "stateful_ld_st_simple" + %void = OpTypeVoid + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %45 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %1 = OpFunction %void None %45 + %21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %22 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %39 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %9 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %8 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %21 + OpStore %3 %22 + %14 = OpBitcast %_ptr_Function_ulong %2 + %13 = OpLoad %ulong %14 + %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %13 + OpStore %9 %23 + %16 = OpBitcast %_ptr_Function_ulong %3 + %15 = OpLoad %ulong %16 + %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15 + OpStore %10 %24 + %25 = OpLoad %_ptr_CrossWorkgroup_uchar %9 + %18 = OpConvertPtrToU %ulong %25 + %34 = OpCopyObject %ulong %18 + %33 = OpCopyObject %ulong %34 + %17 = OpCopyObject %ulong %33 + %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %17 + OpStore %11 %26 + %27 = OpLoad %_ptr_CrossWorkgroup_uchar %10 + %20 = OpConvertPtrToU %ulong %27 + %36 = OpCopyObject %ulong %20 + %35 = OpCopyObject %ulong %36 + %19 = OpCopyObject %ulong %35 + %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %19 + OpStore %12 %28 + %30 = OpLoad %_ptr_CrossWorkgroup_uchar %11 + %37 = OpBitcast %_ptr_CrossWorkgroup_ulong %30 + %29 = OpLoad %ulong %37 + OpStore %8 %29 + %31 = OpLoad %_ptr_CrossWorkgroup_uchar %12 + %32 = OpLoad %ulong %8 + %38 = OpBitcast %_ptr_CrossWorkgroup_ulong %31 + OpStore %38 %32 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index f0a3187..328bf30 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,7 @@ use crate::ast; use half::f16; use rspirv::dr; -use std::{borrow::Cow, ffi::CString, hash::Hash, iter, mem}; +use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem}; use std::{ collections::{hash_map, HashMap, HashSet}, convert::TryInto, @@ -72,15 +72,13 @@ impl From for ast::Type { } impl ast::Type { - fn pointer_to(self, space: ast::LdStateSpace) -> Result { + fn param_pointer_to(self, space: ast::LdStateSpace) -> Result { Ok(match self { ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space), ast::Type::Vector(t, len) => { ast::Type::Pointer(ast::PointerType::Vector(t, len), space) } - ast::Type::Array(t, dims) => { - ast::Type::Pointer(ast::PointerType::Array(t, dims), space) - } + ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space), ast::Type::Pointer(ast::PointerType::Scalar(t), space) => { ast::Type::Pointer(ast::PointerType::Pointer(t, space), space) } @@ -726,7 +724,7 @@ fn convert_dynamic_shared_memory_usage<'input>( multi_hash_map_append(&mut directly_called_by, call.func, call_key); Statement::Call(call) } - statement => statement.map_id(&mut |id| { + statement => statement.map_id(&mut |id, _| { if extern_shared_decls.contains_key(&id) { methods_using_extern_shared.insert(call_key); } @@ -843,7 +841,7 @@ fn replace_uses_of_shared_memory<'a>( result.push(Statement::Call(call)) } statement => { - let new_statement = statement.map_id(&mut |id| { + let new_statement = statement.map_id(&mut |id, _| { if let Some(typ) = extern_shared_decls.get(&id) { let replacement_id = new_id(); if *typ != ast::SizedScalarType::B8 { @@ -859,6 +857,8 @@ fn replace_uses_of_shared_memory<'a>( ast::LdStateSpace::Shared, ), kind: ConversionKind::PtrToPtr { spirv_ptr: true }, + src_sema: ArgumentSemantics::Default, + dst_sema: ArgumentSemantics::Default, })); } replacement_id @@ -971,7 +971,7 @@ fn compute_denorm_information<'input>( Statement::Undef(_, _) => {} Statement::Label(_) => {} Statement::Variable(_) => {} - Statement::PtrAdd { .. } => {} + Statement::PtrAccess { .. } => {} } } denorm_methods.insert(method_key, flush_counter); @@ -1022,15 +1022,10 @@ fn emit_builtins( builder, SpirvType::Pointer( Box::new(SpirvType::from(reg.get_type())), - spirv::StorageClass::UniformConstant, + spirv::StorageClass::Input, ), ); - builder.variable( - result_type, - Some(*id), - spirv::StorageClass::UniformConstant, - None, - ); + builder.variable(result_type, Some(*id), spirv::StorageClass::Input, None); builder.decorate( *id, spirv::Decoration::BuiltIn, @@ -1192,11 +1187,31 @@ fn translate_variable<'a>( id_defs: &mut GlobalStringIdResolver<'a>, var: ast::Variable, ) -> Result, TranslateError> { - let (state_space, typ) = var.v_type.to_type(); + let (space, var_type) = var.v_type.to_type(); + let mut is_variable = false; + let var_type = match space { + ast::StateSpace::Reg => { + is_variable = true; + var_type + } + ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?, + ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?, + ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?, + ast::StateSpace::Shared => { + // If it's a pointer it will be translated to a method parameter later + if let ast::Type::Pointer(..) = var_type { + is_variable = true; + var_type + } else { + var_type.param_pointer_to(ast::LdStateSpace::Shared)? + } + } + ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?, + }; Ok(ast::Variable { align: var.align, v_type: var.v_type, - name: id_defs.get_or_add_def_typed(var.name, (state_space.into(), typ)), + name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable), array_init: var.array_init, }) } @@ -1218,10 +1233,8 @@ fn expand_kernel_params<'a, 'b>( Ok(ast::KernelArgument { name: fn_resolver.add_def( a.name, - Some(( - StateSpace::Param, - ast::Type::from(a.v_type.clone()).pointer_to(ast::LdStateSpace::Param)?, - )), + Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?), + false, ), v_type: a.v_type.clone(), align: a.align, @@ -1236,14 +1249,13 @@ fn expand_fn_params<'a, 'b>( args: impl Iterator>, ) -> Result>, TranslateError> { args.map(|a| { - let var_type = a.v_type.to_func_type(); - let ss = match a.v_type { - ast::FnArgumentType::Reg(_) => StateSpace::Reg, - ast::FnArgumentType::Param(_) => StateSpace::Param, - ast::FnArgumentType::Shared => StateSpace::Shared, + let is_variable = match a.v_type { + ast::FnArgumentType::Reg(_) => true, + _ => false, }; + let var_type = a.v_type.to_func_type(); Ok(ast::FnArgument { - name: fn_resolver.add_def(a.name, Some((ss, var_type))), + name: fn_resolver.add_def(a.name, Some(var_type), is_variable), v_type: a.v_type.clone(), align: a.align, array_init: Vec::new(), @@ -1274,12 +1286,18 @@ fn to_ssa<'input, 'b>( }; let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?; let mut numeric_id_defs = id_defs.finish(); - let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs); + let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?; + let typed_statements = + convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; + let ssa_statements = insert_mem_ssa_statements( + typed_statements, + &mut numeric_id_defs, + &f_args, + &mut spirv_decl, + )?; let mut numeric_id_defs = numeric_id_defs.finish(); - let ssa_statements = - insert_mem_ssa_statements(typed_statements, &mut numeric_id_defs, &mut spirv_decl)?; let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; @@ -1421,62 +1439,23 @@ fn convert_to_typed_statements( }; result.push(Statement::Call(resolved_call)); } - // Supported ld/st: - // global: only compatible with reg b64/u64/s64 source/dest - // generic: compatible with global/local sources - // param: compiled as mov - // local compiled as mov - // We would like to convert ld/st local/param to movs here, - // but they have different semantics for implicit conversions - // For now, we convert generic ld from local params to ld.local. - // This way, we can rely on further stages of the compilation on - // ld.generic & ld.global having bytes address source - // One complication: immediate address is only allowed in local, - // It is not supported in generic ld - // ld.local foo, [1]; - ast::Instruction::Ld(mut d, arg) => { - match arg.src.underlying() { - None => {} - Some(u) => { - let (ss, _) = id_defs.get_typed(*u)?; - match (d.state_space, ss) { - (ast::LdStateSpace::Generic, StateSpace::Local) => { - d.state_space = ast::LdStateSpace::Local; - } - _ => {} - }; - } - }; + ast::Instruction::Ld(d, arg) => { result.push(Statement::Instruction(ast::Instruction::Ld(d, arg.cast()))); } - ast::Instruction::St(mut d, arg) => { - match arg.src1.underlying() { - None => {} - Some(u) => { - let (ss, _) = id_defs.get_typed(*u)?; - match (d.state_space, ss) { - (ast::StStateSpace::Generic, StateSpace::Local) => { - d.state_space = ast::StStateSpace::Local; - } - _ => (), - }; - } - }; + ast::Instruction::St(d, arg) => { result.push(Statement::Instruction(ast::Instruction::St(d, arg.cast()))); } ast::Instruction::Mov(mut d, args) => match args { ast::Arg2Mov::Normal(arg) => { if let Some(src_id) = arg.src.single_underlying() { - let (scope, _) = id_defs.get_typed(*src_id)?; - d.src_is_address = match scope { - StateSpace::Reg => false, - StateSpace::Const - | StateSpace::Global - | StateSpace::Local - | StateSpace::Shared - | StateSpace::Param - | StateSpace::ParamReg => true, + let (typ, _) = id_defs.get_typed(*src_id)?; + let take_address = match typ { + ast::Type::Scalar(_) => false, + ast::Type::Vector(_, _) => false, + ast::Type::Array(_, _) => true, + ast::Type::Pointer(_, _) => true, }; + d.src_is_address = take_address; } result.push(Statement::Instruction(ast::Instruction::Mov( d, @@ -1486,7 +1465,7 @@ fn convert_to_typed_statements( ast::Arg2Mov::Member(args) => { if let Some(dst_typ) = args.vector_dst() { match id_defs.get_typed(*dst_typ)? { - (_, ast::Type::Vector(_, len)) => { + (ast::Type::Vector(_, len), _) => { d.dst_width = len; } _ => return Err(TranslateError::MismatchedType), @@ -1494,7 +1473,7 @@ fn convert_to_typed_statements( }; if let Some((src_typ, _)) = args.vector_src() { match id_defs.get_typed(*src_typ)? { - (_, ast::Type::Vector(_, len)) => { + (ast::Type::Vector(_, len), _) => { d.src_width = len; } _ => return Err(TranslateError::MismatchedType), @@ -1650,17 +1629,8 @@ fn convert_to_typed_statements( }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), - Statement::LoadVar(a, t) => result.push(Statement::LoadVar(a, t)), - Statement::StoreVar(a, t) => result.push(Statement::StoreVar(a, t)), - Statement::Call(c) => result.push(Statement::Call(c.cast())), - Statement::Composite(c) => result.push(Statement::Composite(c)), Statement::Conditional(c) => result.push(Statement::Conditional(c)), - Statement::Conversion(c) => result.push(Statement::Conversion(c)), - Statement::Constant(c) => result.push(Statement::Constant(c)), - Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), - Statement::Undef(_, _) | Statement::PtrAdd { .. } => { - return Err(TranslateError::Unreachable) - } + _ => return Err(TranslateError::Unreachable), } } Ok(result) @@ -1689,14 +1659,14 @@ fn to_ptx_impl_atomic_call( }; let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.new_id(None); + let fn_id = id_defs.new_non_variable(None); let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( ast::ScalarType::U32, )), - name: id_defs.new_id(None), + name: id_defs.new_non_variable(None), array_init: Vec::new(), }], fn_id, @@ -1707,7 +1677,7 @@ fn to_ptx_impl_atomic_call( ast::SizedScalarType::U32, ptr_space, )), - name: id_defs.new_id(None), + name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { @@ -1715,7 +1685,7 @@ fn to_ptx_impl_atomic_call( v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( ast::ScalarType::U32, )), - name: id_defs.new_id(None), + name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ], @@ -1779,12 +1749,12 @@ fn to_ptx_impl_bfe_call( let fn_name = format!("{}{}", prefix, suffix); let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.new_id(None); + let fn_id = id_defs.new_non_variable(None); let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - name: id_defs.new_id(None), + name: id_defs.new_non_variable(None), array_init: Vec::new(), }], fn_id, @@ -1792,7 +1762,7 @@ fn to_ptx_impl_bfe_call( ast::FnArgument { align: None, v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - name: id_defs.new_id(None), + name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { @@ -1800,7 +1770,7 @@ fn to_ptx_impl_bfe_call( v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( ast::ScalarType::U32, )), - name: id_defs.new_id(None), + name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { @@ -1808,7 +1778,7 @@ fn to_ptx_impl_bfe_call( v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( ast::ScalarType::U32, )), - name: id_defs.new_id(None), + name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ], @@ -1893,10 +1863,10 @@ fn normalize_labels( | Statement::Constant(_) | Statement::Label(_) | Statement::Undef(_, _) - | Statement::PtrAdd { .. } => {} + | Statement::PtrAccess { .. } => {} } } - iter::once(Statement::Label(id_def.new_id(None))) + iter::once(Statement::Label(id_def.new_non_variable(None))) .chain(func.into_iter().filter(|s| match s { Statement::Label(i) => labels_in_use.contains(i), _ => true, @@ -1907,15 +1877,15 @@ fn normalize_labels( fn normalize_predicates( func: Vec, id_def: &mut NumericIdResolver, -) -> Vec { +) -> Result, TranslateError> { let mut result = Vec::with_capacity(func.len()); for s in func { match s { Statement::Label(id) => result.push(Statement::Label(id)), Statement::Instruction((pred, inst)) => { if let Some(pred) = pred { - let if_true = id_def.new_id(None); - let if_false = id_def.new_id(None); + let if_true = id_def.new_non_variable(None); + let if_false = id_def.new_non_variable(None); let folded_bra = match &inst { ast::Instruction::Bra(_, arg) => Some(arg.src), _ => None, @@ -1940,20 +1910,25 @@ fn normalize_predicates( } Statement::Variable(var) => result.push(Statement::Variable(var)), // Blocks are flattened when resolving ids - _ => unreachable!(), + _ => return Err(TranslateError::Unreachable), } } - result + Ok(result) } fn insert_mem_ssa_statements<'a, 'b>( func: Vec, - id_def: &mut MutableNumericIdResolver, + id_def: &mut NumericIdResolver, + ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>, fn_decl: &mut SpirvMethodDecl, ) -> Result, TranslateError> { + let is_func = match ast_fn_decl { + ast::MethodDecl::Func(..) => true, + ast::MethodDecl::Kernel { .. } => false, + }; let mut result = Vec::with_capacity(func.len()); for arg in fn_decl.output.iter() { - match type_to_variable_type(&arg.v_type)? { + match type_to_variable_type(&arg.v_type, is_func)? { Some(var_type) => { result.push(Statement::Variable(ast::Variable { align: arg.align, @@ -1965,25 +1940,25 @@ fn insert_mem_ssa_statements<'a, 'b>( None => return Err(TranslateError::Unreachable), } } - for arg in fn_decl.input.iter_mut() { - match type_to_variable_type(&arg.v_type)? { + for spirv_arg in fn_decl.input.iter_mut() { + match type_to_variable_type(&spirv_arg.v_type, is_func)? { Some(var_type) => { - let typ = arg.v_type.clone(); - let new_id = id_def.new_id(typ.clone()); + let typ = spirv_arg.v_type.clone(); + let new_id = id_def.new_non_variable(Some(typ.clone())); result.push(Statement::Variable(ast::Variable { - align: arg.align, + align: spirv_arg.align, v_type: var_type, - name: arg.name, - array_init: arg.array_init.clone(), + name: spirv_arg.name, + array_init: spirv_arg.array_init.clone(), })); result.push(Statement::StoreVar( ast::Arg2St { - src1: arg.name, + src1: spirv_arg.name, src2: new_id, }, typ, )); - arg.name = new_id; + spirv_arg.name = new_id; } None => {} } @@ -1997,8 +1972,8 @@ fn insert_mem_ssa_statements<'a, 'b>( ast::Instruction::Ret(d) => { // TODO: handle multiple output args if let &[out_param] = &fn_decl.output.as_slice() { - let typ = id_def.get_typed(out_param.name)?; - let new_id = id_def.new_id(typ.clone()); + let (typ, _) = id_def.get_typed(out_param.name)?; + let new_id = id_def.new_non_variable(Some(typ.clone())); result.push(Statement::LoadVar( ast::Arg2 { dst: new_id, @@ -2014,7 +1989,8 @@ fn insert_mem_ssa_statements<'a, 'b>( inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?, }, Statement::Conditional(mut bra) => { - let generated_id = id_def.new_id(ast::Type::Scalar(ast::ScalarType::Pred)); + let generated_id = + id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred))); result.push(Statement::LoadVar( Arg2 { dst: generated_id, @@ -2025,21 +2001,23 @@ fn insert_mem_ssa_statements<'a, 'b>( bra.predicate = generated_id; result.push(Statement::Conditional(bra)); } + Statement::Conversion(conv) => { + insert_mem_ssa_statement_default(id_def, &mut result, conv)? + } + Statement::PtrAccess(ptr_access) => { + insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)? + } s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s), - Statement::LoadVar(_, _) - | Statement::StoreVar(_, _) - | Statement::Conversion(_) - | Statement::RetValue(_, _) - | Statement::Constant(_) - | Statement::Undef(_, _) - | Statement::PtrAdd { .. } => {} - Statement::Composite(_) => todo!(), + _ => return Err(TranslateError::Unreachable), } } Ok(result) } -fn type_to_variable_type(t: &ast::Type) -> Result, TranslateError> { +fn type_to_variable_type( + t: &ast::Type, + is_func: bool, +) -> Result, TranslateError> { Ok(match t { ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))), ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector( @@ -2054,7 +2032,22 @@ fn type_to_variable_type(t: &ast::Type) -> Result, Tra .map_err(|_| TranslateError::MismatchedType)?, len.clone(), ))), - ast::Type::Pointer(_, _) => None, + ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => { + if is_func { + return Ok(None); + } + Some(ast::VariableType::Reg(ast::VariableRegType::Pointer( + scalar_type + .clone() + .try_into() + .map_err(|_| TranslateError::Unreachable)?, + (*space) + .try_into() + .map_err(|_| TranslateError::Unreachable)?, + ))) + } + ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None, + _ => return Err(TranslateError::Unreachable), }) } @@ -2105,34 +2098,28 @@ impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded } fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( - id_def: &mut MutableNumericIdResolver, + id_def: &mut NumericIdResolver, result: &mut Vec, stmt: F, ) -> Result<(), TranslateError> { let mut post_statements = Vec::new(); - let new_statement = - stmt.visit_variable(&mut |desc: ArgumentDescriptor, instr_type| { - if instr_type.is_none() || desc.sema == ArgumentSemantics::RegisterPointer { + let new_statement = stmt.visit_variable( + &mut |desc: ArgumentDescriptor, expected_type| { + if expected_type.is_none() { return Ok(desc.op); - } - let id_type = match (id_def.get_typed(desc.op)?, desc.sema) { - (_, ArgumentSemantics::Address) => return Ok(desc.op), - (t, ArgumentSemantics::RegisterPointer) - | (t, ArgumentSemantics::Default) - | (t, ArgumentSemantics::DefaultRelaxed) - | (t, ArgumentSemantics::PhysicalPointer) => t, }; - if let ast::Type::Array(_, _) = id_type { + let (var_type, is_variable) = id_def.get_typed(desc.op)?; + if !is_variable { return Ok(desc.op); } - let generated_id = id_def.new_id(id_type.clone()); + let generated_id = id_def.new_non_variable(Some(var_type.clone())); if !desc.is_dst { result.push(Statement::LoadVar( Arg2 { dst: generated_id, src: desc.op, }, - id_type, + var_type, )); } else { post_statements.push(Statement::StoreVar( @@ -2140,11 +2127,12 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( src1: desc.op, src2: generated_id, }, - id_type, + var_type, )); } Ok(generated_id) - })?; + }, + )?; result.push(new_statement); result.append(&mut post_statements); Ok(()) @@ -2180,65 +2168,21 @@ fn expand_arguments<'a, 'b>( name, array_init, })), - Statement::PtrAdd { - underlying_type, - state_space, - dst, - ptr_src, - constant_src, - } => { + Statement::PtrAccess(ptr_access) => { let mut visitor = FlattenArguments::new(&mut result, id_def); - let sema = match state_space { - ast::LdStateSpace::Const - | ast::LdStateSpace::Global - | ast::LdStateSpace::Shared - | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer, - ast::LdStateSpace::Local | ast::LdStateSpace::Param => { - ArgumentSemantics::RegisterPointer - } - }; - let ptr_type = ast::Type::Pointer(underlying_type.clone(), state_space); - let new_dst = visitor.id( - ArgumentDescriptor { - op: dst, - is_dst: true, - sema, - }, - Some(&ptr_type), - )?; - let new_ptr_src = visitor.id( - ArgumentDescriptor { - op: ptr_src, - is_dst: false, - sema, - }, - Some(&ptr_type), - )?; - let new_constant_src = visitor.id( - ArgumentDescriptor { - op: constant_src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Scalar(ast::ScalarType::S64)), - )?; - result.push(Statement::PtrAdd { - underlying_type, - state_space, - dst: new_dst, - ptr_src: new_ptr_src, - constant_src: new_constant_src, - }) + let (new_inst, post_stmts) = (ptr_access.map(&mut visitor)?, visitor.post_stmts); + result.push(Statement::PtrAccess(new_inst)); + result.extend(post_stmts); } Statement::Label(id) => result.push(Statement::Label(id)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)), Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), - Statement::Composite(_) - | Statement::Conversion(_) - | Statement::Constant(_) - | Statement::Undef(_, _) => unreachable!(), + Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), + Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => { + return Err(TranslateError::Unreachable) + } } } Ok(result) @@ -2270,7 +2214,8 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { scalar_sema_override: Option, composite_src: (spirv::Word, u8), ) -> spirv::Word { - let new_id = scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Scalar(typ.0))); + let new_id = + scalar_dst.unwrap_or_else(|| id_def.new_non_variable(ast::Type::Scalar(typ.0))); func.push(Statement::Composite(CompositeRead { typ: typ.0, dst: new_id, @@ -2301,20 +2246,20 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { ast::Type::Pointer(underlying_type, state_space) => { let reg_typ = self.id_def.get_typed(reg)?; if let ast::Type::Pointer(_, _) = reg_typ { - let id_constant_stmt = self.id_def.new_id(typ.clone()); + let id_constant_stmt = self.id_def.new_non_variable(typ.clone()); self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, typ: ast::ScalarType::S64, value: ast::ImmediateValue::S64(offset as i64), })); - let dst = self.id_def.new_id(typ.clone()); - self.func.push(Statement::PtrAdd { + let dst = self.id_def.new_non_variable(typ.clone()); + self.func.push(Statement::PtrAccess(PtrAccess { underlying_type: underlying_type.clone(), state_space: *state_space, dst, ptr_src: reg, - constant_src: id_constant_stmt, - }); + offset_src: id_constant_stmt, + })); return Ok(dst); } else { add_type = self.id_def.get_typed(reg)?; @@ -2346,8 +2291,8 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } else { ast::ArithDetails::Unsigned(ast::UIntType::from_size(width)) }; - let id_constant_stmt = self.id_def.new_id(add_type.clone()); - let result_id = self.id_def.new_id(add_type); + let id_constant_stmt = self.id_def.new_non_variable(add_type.clone()); + let result_id = self.id_def.new_non_variable(add_type); // TODO: check for edge cases around min value/max value/wrapping if offset < 0 && kind != ScalarKind::Signed { self.func.push(Statement::Constant(ConstantDefinition { @@ -2395,7 +2340,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } else { todo!() }; - let id = self.id_def.new_id(ast::Type::Scalar(scalar_t)); + let id = self.id_def.new_non_variable(ast::Type::Scalar(scalar_t)); self.func.push(Statement::Constant(ConstantDefinition { dst: id, typ: scalar_t, @@ -2430,10 +2375,10 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { ) -> Result { let (scalar_type, vec_len) = typ.get_vector()?; if !desc.is_dst { - let mut new_id = self.id_def.new_id(typ.clone()); + let mut new_id = self.id_def.new_non_variable(typ.clone()); self.func.push(Statement::Undef(typ.clone(), new_id)); for (idx, id) in desc.op.iter().enumerate() { - let newer_id = self.id_def.new_id(typ.clone()); + let newer_id = self.id_def.new_non_variable(typ.clone()); self.func.push(Statement::Instruction(ast::Instruction::Mov( ast::MovDetails { typ: ast::Type::Scalar(scalar_type), @@ -2452,7 +2397,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } Ok(new_id) } else { - let new_id = self.id_def.new_id(typ.clone()); + let new_id = self.id_def.new_non_variable(typ.clone()); for (idx, id) in desc.op.iter().enumerate() { Self::insert_composite_read( &mut self.post_stmts, @@ -2597,13 +2542,13 @@ fn insert_implicit_conversions( should_bitcast_wrapper, None, )?, - Statement::PtrAdd { + Statement::PtrAccess(PtrAccess { underlying_type, state_space, dst, ptr_src, - constant_src, - } => { + offset_src: constant_src, + }) => { let visit_desc = VisitArgumentDescriptor { desc: ArgumentDescriptor { op: ptr_src, @@ -2611,12 +2556,14 @@ fn insert_implicit_conversions( sema: ArgumentSemantics::PhysicalPointer, }, typ: &ast::Type::Pointer(underlying_type.clone(), state_space), - stmt_ctor: |new_ptr_src| Statement::PtrAdd { - underlying_type, - state_space, - dst, - ptr_src: new_ptr_src, - constant_src, + stmt_ctor: |new_ptr_src| { + Statement::PtrAccess(PtrAccess { + underlying_type, + state_space, + dst, + ptr_src: new_ptr_src, + offset_src: constant_src, + }) }, }; insert_implicit_conversions_impl( @@ -2628,6 +2575,7 @@ fn insert_implicit_conversions( )?; } s @ Statement::Conditional(_) + | s @ Statement::Conversion(_) | s @ Statement::Label(_) | s @ Statement::Constant(_) | s @ Statement::Variable(_) @@ -2635,7 +2583,6 @@ fn insert_implicit_conversions( | s @ Statement::StoreVar(_, _) | s @ Statement::Undef(_, _) | s @ Statement::RetValue(_, _) => result.push(s), - Statement::Conversion(_) => unreachable!(), } } Ok(result) @@ -2688,7 +2635,7 @@ fn insert_implicit_conversions_impl( }; let mut from = instr_type.clone(); let mut to = operand_type; - let mut src = id_def.new_id(instr_type.clone()); + let mut src = id_def.new_non_variable(instr_type.clone()); let mut dst = desc.op; let result = Ok(src); if !desc.is_dst { @@ -2701,6 +2648,8 @@ fn insert_implicit_conversions_impl( from, to, kind: conv_kind, + src_sema: ArgumentSemantics::Default, + dst_sema: ArgumentSemantics::Default, })); result } @@ -3242,21 +3191,33 @@ fn emit_function_body_ops( let result_type = map.get_or_add(builder, SpirvType::from(t.clone())); builder.undef(result_type, Some(*id)); } - Statement::PtrAdd { + Statement::PtrAccess(PtrAccess { underlying_type, state_space, dst, ptr_src, - constant_src, - } => { - let s64_type = map.get_or_add_scalar(builder, ast::ScalarType::S64); - let ptr_as_s64 = builder.bitcast(s64_type, None, *ptr_src)?; - let added_ptr = builder.i_add(s64_type, None, ptr_as_s64, *constant_src)?; + offset_src, + }) => { + let u8_pointer = map.get_or_add( + builder, + SpirvType::from(ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::U8), + *state_space, + )), + ); let result_type = map.get_or_add( builder, SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)), ); - builder.bitcast(result_type, Some(*dst), added_ptr)?; + let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?; + let temp = builder.in_bounds_ptr_access_chain( + u8_pointer, + None, + ptr_src_u8, + *offset_src, + &[], + )?; + builder.bitcast(result_type, Some(*dst), temp)?; } } } @@ -3745,6 +3706,8 @@ fn emit_cvt( src_t.kind(), )), kind: ConversionKind::Default, + src_sema: ArgumentSemantics::Default, + dst_sema: ArgumentSemantics::Default, }; emit_implicit_conversion(builder, map, &cv)?; new_dst @@ -4117,6 +4080,8 @@ fn emit_implicit_conversion( from: wide_bit_type, to: cv.to.clone(), kind: ConversionKind::Default, + src_sema: cv.src_sema, + dst_sema: cv.dst_sema, }, )?; } @@ -4156,7 +4121,7 @@ fn normalize_identifiers<'a, 'b>( for s in func.iter() { match s { ast::Statement::Label(id) => { - id_defs.add_def(*id, None); + id_defs.add_def(*id, None, false); } _ => (), } @@ -4189,23 +4154,35 @@ fn expand_map_variables<'a, 'b>( i.map_variable(&mut |id| id_defs.get_id(id))?, ))), ast::Statement::Variable(var) => { - let ss = match var.var.v_type { - ast::VariableType::Reg(_) => StateSpace::Reg, - ast::VariableType::Global(_) => StateSpace::Global, - ast::VariableType::Shared(_) => StateSpace::Shared, - ast::VariableType::Param(_) => StateSpace::ParamReg, - ast::VariableType::Local(_) => StateSpace::Local, - }; let mut var_type = ast::Type::from(var.var.v_type.clone()); + let mut is_variable = false; var_type = match var.var.v_type { - ast::VariableType::Reg(_) | ast::VariableType::Shared(_) => var_type, - ast::VariableType::Global(_) => var_type.pointer_to(ast::LdStateSpace::Global)?, - ast::VariableType::Param(_) => var_type.pointer_to(ast::LdStateSpace::Param)?, - ast::VariableType::Local(_) => var_type.pointer_to(ast::LdStateSpace::Local)?, + ast::VariableType::Reg(_) => { + is_variable = true; + var_type + } + ast::VariableType::Shared(_) => { + // If it's a pointer it will be translated to a method parameter later + if let ast::Type::Pointer(..) = var_type { + is_variable = true; + var_type + } else { + var_type.param_pointer_to(ast::LdStateSpace::Shared)? + } + } + ast::VariableType::Global(_) => { + var_type.param_pointer_to(ast::LdStateSpace::Global)? + } + ast::VariableType::Param(_) => { + var_type.param_pointer_to(ast::LdStateSpace::Param)? + } + ast::VariableType::Local(_) => { + var_type.param_pointer_to(ast::LdStateSpace::Local)? + } }; match var.count { Some(count) => { - for new_id in id_defs.add_defs(var.var.name, count, ss, var_type) { + for new_id in id_defs.add_defs(var.var.name, count, var_type, is_variable) { result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), @@ -4215,7 +4192,7 @@ fn expand_map_variables<'a, 'b>( } } None => { - let new_id = id_defs.add_def(var.var.name, Some((ss, var_type))); + let new_id = id_defs.add_def(var.var.name, Some(var_type), is_variable); result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), @@ -4229,6 +4206,384 @@ fn expand_map_variables<'a, 'b>( Ok(()) } +// TODO: detect more patterns (mov, call via reg, call via param) +// TODO: don't convert to ptr if the register is not ultimately used for ld/st +// TODO: once insert_mem_ssa_statements is moved to later, move this pass after +// argument expansion +// TODO: propagate through calls? +fn convert_to_stateful_memory_access<'a>( + func_args: &mut SpirvMethodDecl, + func_body: Vec, + id_defs: &mut NumericIdResolver<'a>, +) -> Result, TranslateError> { + let func_args_64bit = func_args + .input + .iter() + .filter_map(|arg| match arg.v_type { + ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name), + _ => None, + }) + .collect::>(); + let mut stateful_markers = Vec::new(); + let mut stateful_init_reg = MultiHashMap::new(); + for statement in func_body.iter() { + match statement { + Statement::Instruction(ast::Instruction::Cvta( + ast::CvtaDetails { + to: ast::CvtaStateSpace::Global, + size: ast::CvtaSize::U64, + from: ast::CvtaStateSpace::Generic, + }, + arg, + )) => { + if let Some(src) = arg.src.underlying() { + if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, arg.dst) { + stateful_markers.push((arg.dst, *src)); + } + } + } + Statement::Instruction(ast::Instruction::Ld( + ast::LdDetails { + state_space: ast::LdStateSpace::Param, + typ: ast::LdStType::Scalar(ast::LdStScalarType::U64), + .. + }, + arg, + )) + | Statement::Instruction(ast::Instruction::Ld( + ast::LdDetails { + state_space: ast::LdStateSpace::Param, + typ: ast::LdStType::Scalar(ast::LdStScalarType::S64), + .. + }, + arg, + )) + | Statement::Instruction(ast::Instruction::Ld( + ast::LdDetails { + state_space: ast::LdStateSpace::Param, + typ: ast::LdStType::Scalar(ast::LdStScalarType::B64), + .. + }, + arg, + )) => { + if let (ast::IdOrVector::Reg(dst), Some(src)) = (&arg.dst, arg.src.underlying()) { + if func_args_64bit.contains(src) { + multi_hash_map_append(&mut stateful_init_reg, *dst, *src); + } + } + } + _ => {} + } + } + let mut func_args_ptr = HashSet::new(); + let mut regs_ptr_current = HashSet::new(); + for (dst, src) in stateful_markers { + if let Some(func_args) = stateful_init_reg.get(&src) { + for a in func_args { + func_args_ptr.insert(*a); + regs_ptr_current.insert(src); + regs_ptr_current.insert(dst); + } + } + } + // BTreeSet here to have a stable order of iteration, + // unfortunately our tests rely on it + let mut regs_ptr_seen = BTreeSet::new(); + while regs_ptr_current.len() > 0 { + let mut regs_ptr_new = HashSet::new(); + for statement in func_body.iter() { + match statement { + Statement::Instruction(ast::Instruction::Add( + ast::ArithDetails::Unsigned(ast::UIntType::U64), + arg, + )) + | Statement::Instruction(ast::Instruction::Add( + ast::ArithDetails::Signed(ast::ArithSInt { + typ: ast::SIntType::S64, + saturate: false, + }), + arg, + )) + | Statement::Instruction(ast::Instruction::Sub( + ast::ArithDetails::Unsigned(ast::UIntType::U64), + arg, + )) + | Statement::Instruction(ast::Instruction::Sub( + ast::ArithDetails::Signed(ast::ArithSInt { + typ: ast::SIntType::S64, + saturate: false, + }), + arg, + )) => { + if let Some(src1) = arg.src1.underlying() { + if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) { + regs_ptr_new.insert(arg.dst); + } + } else if let Some(src2) = arg.src2.underlying() { + if regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) { + regs_ptr_new.insert(arg.dst); + } + } + } + _ => {} + } + } + for id in regs_ptr_current { + regs_ptr_seen.insert(id); + } + regs_ptr_current = regs_ptr_new; + } + drop(regs_ptr_current); + let mut remapped_ids = HashMap::new(); + let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); + for reg in regs_ptr_seen { + let new_id = id_defs.new_variable(ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::U8), + ast::LdStateSpace::Global, + )); + result.push(Statement::Variable(ast::Variable { + align: None, + name: new_id, + array_init: Vec::new(), + v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( + ast::SizedScalarType::U8, + ast::PointerStateSpace::Global, + )), + })); + remapped_ids.insert(reg, new_id); + } + for statement in func_body { + match statement { + l @ Statement::Label(_) => result.push(l), + c @ Statement::Conditional(_) => result.push(c), + Statement::Variable(var) => { + if !remapped_ids.contains_key(&var.name) { + result.push(Statement::Variable(var)); + } + } + Statement::Instruction(ast::Instruction::Add( + ast::ArithDetails::Unsigned(ast::UIntType::U64), + arg, + )) + | Statement::Instruction(ast::Instruction::Add( + ast::ArithDetails::Signed(ast::ArithSInt { + typ: ast::SIntType::S64, + saturate: false, + }), + arg, + )) if is_add_ptr_direct(&remapped_ids, &arg) => { + let (ptr, offset) = match arg.src1.underlying() { + Some(src1) if remapped_ids.contains_key(src1) => { + (remapped_ids.get(src1).unwrap(), arg.src2) + } + Some(src2) if remapped_ids.contains_key(src2) => { + (remapped_ids.get(src2).unwrap(), arg.src1) + } + _ => return Err(TranslateError::Unreachable), + }; + result.push(Statement::PtrAccess(PtrAccess { + underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), + state_space: ast::LdStateSpace::Global, + dst: *remapped_ids.get(&arg.dst).unwrap(), + ptr_src: *ptr, + offset_src: offset, + })) + } + Statement::Instruction(ast::Instruction::Sub( + ast::ArithDetails::Unsigned(ast::UIntType::U64), + arg, + )) + | Statement::Instruction(ast::Instruction::Sub( + ast::ArithDetails::Signed(ast::ArithSInt { + typ: ast::SIntType::S64, + saturate: false, + }), + arg, + )) if is_add_ptr_direct(&remapped_ids, &arg) => { + let (ptr, offset) = match arg.src1.underlying() { + Some(src1) if remapped_ids.contains_key(src1) => { + (remapped_ids.get(src1).unwrap(), arg.src2) + } + Some(src2) if remapped_ids.contains_key(src2) => { + (remapped_ids.get(src2).unwrap(), arg.src1) + } + _ => return Err(TranslateError::Unreachable), + }; + let offset_neg = + id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64))); + result.push(Statement::Instruction(ast::Instruction::Neg( + ast::NegDetails { + typ: ast::ScalarType::S64, + flush_to_zero: None, + }, + ast::Arg2 { + src: offset, + dst: offset_neg, + }, + ))); + result.push(Statement::PtrAccess(PtrAccess { + underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), + state_space: ast::LdStateSpace::Global, + dst: *remapped_ids.get(&arg.dst).unwrap(), + ptr_src: *ptr, + offset_src: ast::Operand::Reg(offset_neg), + })) + } + Statement::Instruction(inst) => { + let mut post_statements = Vec::new(); + let new_statement = inst.visit_variable( + &mut |arg_desc: ArgumentDescriptor, expected_type| { + convert_to_stateful_memory_access_postprocess( + id_defs, + &remapped_ids, + &func_args_ptr, + &mut result, + &mut post_statements, + arg_desc, + expected_type, + ) + }, + )?; + result.push(new_statement); + for s in post_statements { + result.push(s); + } + } + Statement::Call(call) => { + let mut post_statements = Vec::new(); + let new_statement = call.visit_variable( + &mut |arg_desc: ArgumentDescriptor, expected_type| { + convert_to_stateful_memory_access_postprocess( + id_defs, + &remapped_ids, + &func_args_ptr, + &mut result, + &mut post_statements, + arg_desc, + expected_type, + ) + }, + )?; + result.push(new_statement); + for s in post_statements { + result.push(s); + } + } + _ => return Err(TranslateError::Unreachable), + } + } + for arg in func_args.input.iter_mut() { + if func_args_ptr.contains(&arg.name) { + arg.v_type = ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::U8), + ast::LdStateSpace::Global, + ); + } + } + Ok(result) +} + +fn convert_to_stateful_memory_access_postprocess( + id_defs: &mut NumericIdResolver, + remapped_ids: &HashMap, + func_args_ptr: &HashSet, + result: &mut Vec, + post_statements: &mut Vec, + arg_desc: ArgumentDescriptor, + expected_type: Option<&ast::Type>, +) -> Result { + Ok(match remapped_ids.get(&arg_desc.op) { + Some(new_id) => { + // We skip conversion here to trigger PtrAcces in a later pass + let old_type = match expected_type { + Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id), + _ => id_defs.get_typed(arg_desc.op)?.0, + }; + let old_type_clone = old_type.clone(); + let converting_id = id_defs.new_non_variable(Some(old_type_clone)); + if arg_desc.is_dst { + post_statements.push(Statement::Conversion(ImplicitConversion { + src: converting_id, + dst: *new_id, + from: old_type, + to: ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::U8), + ast::LdStateSpace::Global, + ), + kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global), + src_sema: ArgumentSemantics::Default, + dst_sema: arg_desc.sema, + })); + converting_id + } else { + result.push(Statement::Conversion(ImplicitConversion { + src: *new_id, + dst: converting_id, + from: ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::U8), + ast::LdStateSpace::Global, + ), + to: old_type, + kind: ConversionKind::PtrToBit(ast::UIntType::U64), + src_sema: arg_desc.sema, + dst_sema: ArgumentSemantics::Default, + })); + converting_id + } + } + None => match func_args_ptr.get(&arg_desc.op) { + Some(new_id) => { + if arg_desc.is_dst { + return Err(TranslateError::Unreachable); + } + // We skip conversion here to trigger PtrAcces in a later pass + let old_type = match expected_type { + Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id), + _ => id_defs.get_typed(arg_desc.op)?.0, + }; + let old_type_clone = old_type.clone(); + let converting_id = id_defs.new_non_variable(Some(old_type)); + result.push(Statement::Conversion(ImplicitConversion { + src: *new_id, + dst: converting_id, + from: ast::Type::Pointer( + ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global), + ast::LdStateSpace::Param, + ), + to: old_type_clone, + kind: ConversionKind::PtrToPtr { spirv_ptr: false }, + src_sema: arg_desc.sema, + dst_sema: ArgumentSemantics::Default, + })); + converting_id + } + None => arg_desc.op, + }, + }) +} + +fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { + if !remapped_ids.contains_key(&arg.dst) { + return false; + } + match arg.src1.underlying() { + Some(src1) if remapped_ids.contains_key(src1) => true, + Some(src2) if remapped_ids.contains_key(src2) => true, + _ => false, + } +} + +fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool { + match id_defs.get_typed(id) { + Ok((ast::Type::Scalar(ast::ScalarType::U64), _)) + | Ok((ast::Type::Scalar(ast::ScalarType::S64), _)) + | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true, + _ => false, + } +} + #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] enum PtxSpecialRegister { Tid, @@ -4270,7 +4625,7 @@ impl PtxSpecialRegister { struct GlobalStringIdResolver<'input> { current_id: spirv::Word, variables: HashMap, spirv::Word>, - variables_type_check: HashMap>, + variables_type_check: HashMap>, special_registers: HashMap, fns: HashMap, } @@ -4295,15 +4650,16 @@ impl<'a> GlobalStringIdResolver<'a> { self.get_or_add_impl(id, None) } - fn get_or_add_def_typed(&mut self, id: &'a str, typ: (StateSpace, ast::Type)) -> spirv::Word { - self.get_or_add_impl(id, Some(typ)) - } - - fn get_or_add_impl( + fn get_or_add_def_typed( &mut self, id: &'a str, - typ: Option<(StateSpace, ast::Type)>, + typ: ast::Type, + is_variable: bool, ) -> spirv::Word { + self.get_or_add_impl(id, Some((typ, is_variable))) + } + + fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word { let id = match self.variables.entry(Cow::Borrowed(id)) { hash_map::Entry::Occupied(e) => *(e.get()), hash_map::Entry::Vacant(e) => { @@ -4399,10 +4755,10 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { struct FnStringIdResolver<'input, 'b> { current_id: &'b mut spirv::Word, global_variables: &'b HashMap, spirv::Word>, - global_type_check: &'b HashMap>, + global_type_check: &'b HashMap>, special_registers: &'b mut HashMap, variables: Vec, spirv::Word>>, - type_check: HashMap>, + type_check: HashMap>, } impl<'a, 'b> FnStringIdResolver<'a, 'b> { @@ -4452,13 +4808,14 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { } } - fn add_def(&mut self, id: &'a str, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word { + fn add_def(&mut self, id: &'a str, typ: Option, is_variable: bool) -> spirv::Word { let numeric_id = *self.current_id; self.variables .last_mut() .unwrap() .insert(Cow::Borrowed(id), numeric_id); - self.type_check.insert(numeric_id, typ); + self.type_check + .insert(numeric_id, typ.map(|t| (t, is_variable))); *self.current_id += 1; numeric_id } @@ -4468,8 +4825,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { &mut self, base_id: &'a str, count: u32, - ss: StateSpace, typ: ast::Type, + is_variable: bool, ) -> impl Iterator { let numeric_id = *self.current_id; for i in 0..count { @@ -4478,7 +4835,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .unwrap() .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); self.type_check - .insert(numeric_id + i, Some((ss, typ.clone()))); + .insert(numeric_id + i, Some((typ.clone(), is_variable))); } *self.current_id += count; (0..count).into_iter().map(move |i| i + numeric_id) @@ -4487,8 +4844,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> { current_id: &'b mut spirv::Word, - global_type_check: &'b HashMap>, - type_check: HashMap>, + global_type_check: &'b HashMap>, + type_check: HashMap>, special_registers: HashMap, } @@ -4497,23 +4854,32 @@ impl<'b> NumericIdResolver<'b> { MutableNumericIdResolver { base: self } } - fn get_typed(&self, id: spirv::Word) -> Result<(StateSpace, ast::Type), TranslateError> { + fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, bool), TranslateError> { match self.type_check.get(&id) { Some(Some(x)) => Ok(x.clone()), Some(None) => Err(TranslateError::UntypedSymbol), None => match self.special_registers.get(&id) { - Some(x) => Ok((StateSpace::Reg, x.get_type())), + Some(x) => Ok((x.get_type(), true)), None => match self.global_type_check.get(&id) { - Some(Some(x)) => Ok(x.clone()), + Some(Some(result)) => Ok(result.clone()), Some(None) | None => Err(TranslateError::UntypedSymbol), }, }, } } - fn new_id(&mut self, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word { + // This is for identifiers which will be emitted later as OpVariable + // They are candidates for insertion of LoadVar/StoreVar + fn new_variable(&mut self, typ: ast::Type) -> spirv::Word { let new_id = *self.current_id; - self.type_check.insert(new_id, typ); + self.type_check.insert(new_id, Some((typ, true))); + *self.current_id += 1; + new_id + } + + fn new_non_variable(&mut self, typ: Option) -> spirv::Word { + let new_id = *self.current_id; + self.type_check.insert(new_id, typ.map(|t| (t, false))); *self.current_id += 1; new_id } @@ -4529,11 +4895,11 @@ impl<'b> MutableNumericIdResolver<'b> { } fn get_typed(&self, id: spirv::Word) -> Result { - self.base.get_typed(id).map(|(_, t)| t) + self.base.get_typed(id).map(|(t, _)| t) } - fn new_id(&mut self, typ: ast::Type) -> spirv::Word { - self.base.new_id(Some((StateSpace::Reg, typ))) + fn new_non_variable(&mut self, typ: ast::Type) -> spirv::Word { + self.base.new_non_variable(Some(typ)) } } @@ -4541,101 +4907,102 @@ enum Statement { Label(u32), Variable(ast::Variable), Instruction(I), - LoadVar(ast::Arg2, ast::Type), - StoreVar(ast::Arg2St, ast::Type), - Call(ResolvedCall

), - Composite(CompositeRead), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), + Call(ResolvedCall

), + LoadVar(ast::Arg2, ast::Type), + StoreVar(ast::Arg2St, ast::Type), + Composite(CompositeRead), Conversion(ImplicitConversion), Constant(ConstantDefinition), RetValue(ast::RetData, spirv::Word), Undef(ast::Type, spirv::Word), - PtrAdd { - underlying_type: ast::PointerType, - state_space: ast::LdStateSpace, - dst: spirv::Word, - ptr_src: spirv::Word, - constant_src: spirv::Word, - }, + PtrAccess(PtrAccess

), } impl ExpandedStatement { - fn map_id(self, f: &mut impl FnMut(spirv::Word) -> spirv::Word) -> ExpandedStatement { + fn map_id(self, f: &mut impl FnMut(spirv::Word, bool) -> spirv::Word) -> ExpandedStatement { match self { - Statement::Label(id) => Statement::Label(f(id)), + Statement::Label(id) => Statement::Label(f(id, false)), Statement::Variable(mut var) => { - var.name = f(var.name); + var.name = f(var.name, true); Statement::Variable(var) } Statement::Instruction(inst) => inst - .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| Ok(f(arg.op))) + .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| { + Ok(f(arg.op, arg.is_dst)) + }) .unwrap(), Statement::LoadVar(mut arg, typ) => { - arg.dst = f(arg.dst); - arg.src = f(arg.src); + arg.dst = f(arg.dst, true); + arg.src = f(arg.src, false); Statement::LoadVar(arg, typ) } Statement::StoreVar(mut arg, typ) => { - arg.src1 = f(arg.src1); - arg.src2 = f(arg.src2); + arg.src1 = f(arg.src1, false); + arg.src2 = f(arg.src2, false); Statement::StoreVar(arg, typ) } Statement::Call(mut call) => { - for (id, _) in call.ret_params.iter_mut() { - *id = f(*id); + for (id, typ) in call.ret_params.iter_mut() { + let is_dst = match typ { + ast::FnArgumentType::Reg(_) => true, + ast::FnArgumentType::Param(_) => false, + ast::FnArgumentType::Shared => false, + }; + *id = f(*id, is_dst); } - call.func = f(call.func); + call.func = f(call.func, false); for (id, _) in call.param_list.iter_mut() { - *id = f(*id); + *id = f(*id, false); } Statement::Call(call) } Statement::Composite(mut composite) => { - composite.dst = f(composite.dst); - composite.src_composite = f(composite.src_composite); + composite.dst = f(composite.dst, true); + composite.src_composite = f(composite.src_composite, false); Statement::Composite(composite) } Statement::Conditional(mut conditional) => { - conditional.predicate = f(conditional.predicate); - conditional.if_true = f(conditional.if_true); - conditional.if_false = f(conditional.if_false); + conditional.predicate = f(conditional.predicate, false); + conditional.if_true = f(conditional.if_true, false); + conditional.if_false = f(conditional.if_false, false); Statement::Conditional(conditional) } Statement::Conversion(mut conv) => { - conv.dst = f(conv.dst); - conv.src = f(conv.src); + conv.dst = f(conv.dst, true); + conv.src = f(conv.src, false); Statement::Conversion(conv) } Statement::Constant(mut constant) => { - constant.dst = f(constant.dst); + constant.dst = f(constant.dst, true); Statement::Constant(constant) } Statement::RetValue(data, id) => { - let id = f(id); + let id = f(id, false); Statement::RetValue(data, id) } Statement::Undef(typ, id) => { - let id = f(id); + let id = f(id, true); Statement::Undef(typ, id) } - Statement::PtrAdd { + Statement::PtrAccess(PtrAccess { underlying_type, state_space, dst, ptr_src, - constant_src, - } => { - let dst = f(dst); - let ptr_src = f(ptr_src); - let constant_src = f(constant_src); - Statement::PtrAdd { + offset_src: constant_src, + }) => { + let dst = f(dst, true); + let ptr_src = f(ptr_src, false); + let constant_src = f(constant_src, false); + Statement::PtrAccess(PtrAccess { underlying_type, state_space, dst, ptr_src, - constant_src, - } + offset_src: constant_src, + }) } } } @@ -4740,6 +5107,70 @@ impl VisitVariableExpanded for ResolvedCall { } } +impl> PtrAccess

{ + fn map, V: ArgumentMapVisitor>( + self, + visitor: &mut V, + ) -> Result, TranslateError> { + let sema = match self.state_space { + ast::LdStateSpace::Const + | ast::LdStateSpace::Global + | ast::LdStateSpace::Shared + | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer, + ast::LdStateSpace::Local | ast::LdStateSpace::Param => { + ArgumentSemantics::RegisterPointer + } + }; + let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space); + let new_dst = visitor.id( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema, + }, + Some(&ptr_type), + )?; + let new_ptr_src = visitor.id( + ArgumentDescriptor { + op: self.ptr_src, + is_dst: false, + sema, + }, + Some(&ptr_type), + )?; + let new_constant_src = visitor.operand( + ArgumentDescriptor { + op: self.offset_src, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + &ast::Type::Scalar(ast::ScalarType::S64), + )?; + Ok(PtrAccess { + underlying_type: self.underlying_type, + state_space: self.state_space, + dst: new_dst, + ptr_src: new_ptr_src, + offset_src: new_constant_src, + }) + } +} + +impl VisitVariable for PtrAccess { + fn visit_variable< + 'a, + F: FnMut( + ArgumentDescriptor, + Option<&ast::Type>, + ) -> Result, + >( + self, + f: &mut F, + ) -> Result { + Ok(Statement::PtrAccess(self.map(f)?)) + } +} + pub trait ArgParamsEx: ast::ArgParams + Sized { fn get_fn_decl<'x, 'b>( id: &Self::Id, @@ -5035,6 +5466,14 @@ pub struct ArgumentDescriptor { sema: ArgumentSemantics, } +pub struct PtrAccess { + underlying_type: ast::PointerType, + state_space: ast::LdStateSpace, + dst: spirv::Word, + ptr_src: spirv::Word, + offset_src: P::Operand, +} + #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum ArgumentSemantics { // normal register access @@ -5263,6 +5702,56 @@ impl ast::Instruction { } impl VisitVariable for ast::Instruction { + fn visit_variable< + 'a, + F: FnMut( + ArgumentDescriptor, + Option<&ast::Type>, + ) -> Result, + >( + self, + f: &mut F, + ) -> Result { + Ok(Statement::Instruction(self.map(f)?)) + } +} + +impl ImplicitConversion { + fn map< + T: ArgParamsEx, + U: ArgParamsEx, + V: ArgumentMapVisitor, + >( + self, + visitor: &mut V, + ) -> Result, U>, TranslateError> { + let new_dst = visitor.id( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: self.dst_sema, + }, + Some(&self.to), + )?; + let new_src = visitor.id( + ArgumentDescriptor { + op: self.src, + is_dst: false, + sema: self.src_sema, + }, + Some(&self.from), + )?; + Ok(Statement::Conversion({ + ImplicitConversion { + src: new_src, + dst: new_dst, + ..self + } + })) + } +} + +impl VisitVariable for ImplicitConversion { fn visit_variable< 'a, F: FnMut( @@ -5273,7 +5762,21 @@ impl VisitVariable for ast::Instruction { self, f: &mut F, ) -> Result { - Ok(Statement::Instruction(self.map(f)?)) + self.map(f) + } +} + +impl VisitVariableExpanded for ImplicitConversion { + fn visit_variable_extended< + F: FnMut( + ArgumentDescriptor, + Option<&ast::Type>, + ) -> Result, + >( + self, + f: &mut F, + ) -> Result { + self.map(f) } } @@ -5708,6 +6211,8 @@ struct ImplicitConversion { from: ast::Type, to: ast::Type, kind: ConversionKind, + src_sema: ArgumentSemantics, + dst_sema: ArgumentSemantics, } #[derive(PartialEq, Copy, Clone)]