From eee780e7b4f4b51e53536878f553d593bcdb0b0a Mon Sep 17 00:00:00 2001
From: Andrzej Janik <vosen@vosen.pl>
Date: Tue, 8 Dec 2020 22:49:33 +0100
Subject: [PATCH] Fix failures with access to builtins

---
 ptx/src/test/spirv_run/vector.spvtxt         | 124 ++++++-------
 ptx/src/test/spirv_run/vector_extract.spvtxt | 183 ++++++++++---------
 ptx/src/translate.rs                         | 121 +++++++-----
 3 files changed, 234 insertions(+), 194 deletions(-)

diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt
index 535e480..a77ab7d 100644
--- a/ptx/src/test/spirv_run/vector.spvtxt
+++ b/ptx/src/test/spirv_run/vector.spvtxt
@@ -7,91 +7,93 @@
                OpCapability Int64
                OpCapability Float16
                OpCapability Float64
-         %57 = OpExtInstImport "OpenCL.std"
+         %51 = OpExtInstImport "OpenCL.std"
                OpMemoryModel Physical64 OpenCL
-               OpEntryPoint Kernel %31 "vector"
+               OpEntryPoint Kernel %25 "vector"
        %void = OpTypeVoid
        %uint = OpTypeInt 32 0
      %v2uint = OpTypeVector %uint 2
-         %61 = OpTypeFunction %v2uint %v2uint
+         %55 = OpTypeFunction %v2uint %v2uint
 %_ptr_Function_v2uint = OpTypePointer Function %v2uint
 %_ptr_Function_uint = OpTypePointer Function %uint
+     %uint_0 = OpConstant %uint 0
+     %uint_1 = OpConstant %uint 1
       %ulong = OpTypeInt 64 0
-         %65 = OpTypeFunction %void %ulong %ulong
+         %67 = OpTypeFunction %void %ulong %ulong
 %_ptr_Function_ulong = OpTypePointer Function %ulong
 %_ptr_Generic_v2uint = OpTypePointer Generic %v2uint
-          %1 = OpFunction %v2uint None %61
+          %1 = OpFunction %v2uint None %55
           %7 = OpFunctionParameter %v2uint
-         %30 = OpLabel
+         %24 = OpLabel
           %2 = OpVariable %_ptr_Function_v2uint Function
           %3 = OpVariable %_ptr_Function_v2uint Function
           %4 = OpVariable %_ptr_Function_v2uint Function
           %5 = OpVariable %_ptr_Function_uint Function
           %6 = OpVariable %_ptr_Function_uint Function
                OpStore %3 %7
-          %9 = OpLoad %v2uint %3
-         %27 = OpCompositeExtract %uint %9 0
-          %8 = OpCopyObject %uint %27
+         %59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0
+          %9 = OpLoad %uint %59
+          %8 = OpCopyObject %uint %9
                OpStore %5 %8
-         %11 = OpLoad %v2uint %3
-         %28 = OpCompositeExtract %uint %11 1
-         %10 = OpCopyObject %uint %28
+         %61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1
+         %11 = OpLoad %uint %61
+         %10 = OpCopyObject %uint %11
                OpStore %6 %10
          %13 = OpLoad %uint %5
          %14 = OpLoad %uint %6
          %12 = OpIAdd %uint %13 %14
                OpStore %6 %12
-         %16 = OpLoad %v2uint %4
-         %17 = OpLoad %uint %6
-         %15 = OpCompositeInsert %v2uint %17 %16 0
-               OpStore %4 %15
-         %19 = OpLoad %v2uint %4
-         %20 = OpLoad %uint %6
-         %18 = OpCompositeInsert %v2uint %20 %19 1
-               OpStore %4 %18
+         %16 = OpLoad %uint %6
+         %15 = OpCopyObject %uint %16
+         %62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0
+               OpStore %62 %15
+         %18 = OpLoad %uint %6
+         %17 = OpCopyObject %uint %18
+         %63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1
+               OpStore %63 %17
+         %64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1
+         %20 = OpLoad %uint %64
+         %19 = OpCopyObject %uint %20
+         %65 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0
+               OpStore %65 %19
          %22 = OpLoad %v2uint %4
-         %23 = OpLoad %v2uint %4
-         %29 = OpCompositeExtract %uint %23 1
-         %21 = OpCompositeInsert %v2uint %29 %22 0
-               OpStore %4 %21
-         %25 = OpLoad %v2uint %4
-         %24 = OpCopyObject %v2uint %25
-               OpStore %2 %24
-         %26 = OpLoad %v2uint %2
-               OpReturnValue %26
+         %21 = OpCopyObject %v2uint %22
+               OpStore %2 %21
+         %23 = OpLoad %v2uint %2
+               OpReturnValue %23
                OpFunctionEnd
-         %31 = OpFunction %void None %65
-         %40 = OpFunctionParameter %ulong
-         %41 = OpFunctionParameter %ulong
-         %55 = OpLabel
-         %32 = OpVariable %_ptr_Function_ulong Function
+         %25 = OpFunction %void None %67
+         %34 = OpFunctionParameter %ulong
+         %35 = OpFunctionParameter %ulong
+         %49 = OpLabel
+         %26 = OpVariable %_ptr_Function_ulong Function
+         %27 = OpVariable %_ptr_Function_ulong Function
+         %28 = OpVariable %_ptr_Function_ulong Function
+         %29 = OpVariable %_ptr_Function_ulong Function
+         %30 = OpVariable %_ptr_Function_v2uint Function
+         %31 = OpVariable %_ptr_Function_uint Function
+         %32 = OpVariable %_ptr_Function_uint Function
          %33 = OpVariable %_ptr_Function_ulong Function
-         %34 = OpVariable %_ptr_Function_ulong Function
-         %35 = OpVariable %_ptr_Function_ulong Function
-         %36 = OpVariable %_ptr_Function_v2uint Function
-         %37 = OpVariable %_ptr_Function_uint Function
-         %38 = OpVariable %_ptr_Function_uint Function
-         %39 = OpVariable %_ptr_Function_ulong Function
-               OpStore %32 %40
-               OpStore %33 %41
-         %42 = OpLoad %ulong %32
-               OpStore %34 %42
-         %43 = OpLoad %ulong %33
-               OpStore %35 %43
-         %45 = OpLoad %ulong %34
-         %52 = OpConvertUToPtr %_ptr_Generic_v2uint %45
-         %44 = OpLoad %v2uint %52
-               OpStore %36 %44
-         %47 = OpLoad %v2uint %36
-         %46 = OpFunctionCall %v2uint %1 %47
-               OpStore %36 %46
-         %49 = OpLoad %v2uint %36
-         %53 = OpBitcast %ulong %49
-         %48 = OpCopyObject %ulong %53
-               OpStore %39 %48
-         %50 = OpLoad %ulong %35
-         %51 = OpLoad %v2uint %36
-         %54 = OpConvertUToPtr %_ptr_Generic_v2uint %50
-               OpStore %54 %51
+               OpStore %26 %34
+               OpStore %27 %35
+         %36 = OpLoad %ulong %26
+               OpStore %28 %36
+         %37 = OpLoad %ulong %27
+               OpStore %29 %37
+         %39 = OpLoad %ulong %28
+         %46 = OpConvertUToPtr %_ptr_Generic_v2uint %39
+         %38 = OpLoad %v2uint %46
+               OpStore %30 %38
+         %41 = OpLoad %v2uint %30
+         %40 = OpFunctionCall %v2uint %1 %41
+               OpStore %30 %40
+         %43 = OpLoad %v2uint %30
+         %47 = OpBitcast %ulong %43
+         %42 = OpCopyObject %ulong %47
+               OpStore %33 %42
+         %44 = OpLoad %ulong %29
+         %45 = OpLoad %v2uint %30
+         %48 = OpConvertUToPtr %_ptr_Generic_v2uint %44
+               OpStore %48 %45
                OpReturn
                OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/vector_extract.spvtxt b/ptx/src/test/spirv_run/vector_extract.spvtxt
index 4943189..2037dec 100644
--- a/ptx/src/test/spirv_run/vector_extract.spvtxt
+++ b/ptx/src/test/spirv_run/vector_extract.spvtxt
@@ -7,12 +7,12 @@
                OpCapability Int64
                OpCapability Float16
                OpCapability Float64
-         %73 = OpExtInstImport "OpenCL.std"
+         %61 = OpExtInstImport "OpenCL.std"
                OpMemoryModel Physical64 OpenCL
                OpEntryPoint Kernel %1 "vector_extract"
        %void = OpTypeVoid
       %ulong = OpTypeInt 64 0
-         %76 = OpTypeFunction %void %ulong %ulong
+         %64 = OpTypeFunction %void %ulong %ulong
 %_ptr_Function_ulong = OpTypePointer Function %ulong
      %ushort = OpTypeInt 16 0
 %_ptr_Function_ushort = OpTypePointer Function %ushort
@@ -21,10 +21,10 @@
       %uchar = OpTypeInt 8 0
     %v4uchar = OpTypeVector %uchar 4
 %_ptr_CrossWorkgroup_v4uchar = OpTypePointer CrossWorkgroup %v4uchar
-          %1 = OpFunction %void None %76
-         %11 = OpFunctionParameter %ulong
-         %12 = OpFunctionParameter %ulong
-         %71 = OpLabel
+          %1 = OpFunction %void None %64
+         %17 = OpFunctionParameter %ulong
+         %18 = OpFunctionParameter %ulong
+         %59 = OpLabel
           %2 = OpVariable %_ptr_Function_ulong Function
           %3 = OpVariable %_ptr_Function_ulong Function
           %4 = OpVariable %_ptr_Function_ulong Function
@@ -34,89 +34,92 @@
           %8 = OpVariable %_ptr_Function_ushort Function
           %9 = OpVariable %_ptr_Function_ushort Function
          %10 = OpVariable %_ptr_Function_v4ushort Function
-               OpStore %2 %11
-               OpStore %3 %12
-         %13 = OpLoad %ulong %2
-               OpStore %4 %13
-         %14 = OpLoad %ulong %3
-               OpStore %5 %14
-         %19 = OpLoad %ulong %4
-         %61 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %19
-         %43 = OpLoad %v4uchar %61
-         %62 = OpCompositeExtract %uchar %43 0
-         %85 = OpBitcast %uchar %62
-         %15 = OpUConvert %ushort %85
-         %63 = OpCompositeExtract %uchar %43 1
-         %86 = OpBitcast %uchar %63
-         %16 = OpUConvert %ushort %86
-         %64 = OpCompositeExtract %uchar %43 2
-         %87 = OpBitcast %uchar %64
-         %17 = OpUConvert %ushort %87
-         %65 = OpCompositeExtract %uchar %43 3
-         %88 = OpBitcast %uchar %65
-         %18 = OpUConvert %ushort %88
-               OpStore %6 %15
-               OpStore %7 %16
-               OpStore %8 %17
-               OpStore %9 %18
-         %21 = OpLoad %ushort %7
-         %22 = OpLoad %ushort %8
-         %23 = OpLoad %ushort %9
-         %24 = OpLoad %ushort %6
-         %44 = OpUndef %v4ushort
-         %45 = OpCompositeInsert %v4ushort %21 %44 0
-         %46 = OpCompositeInsert %v4ushort %22 %45 1
-         %47 = OpCompositeInsert %v4ushort %23 %46 2
-         %48 = OpCompositeInsert %v4ushort %24 %47 3
-         %20 = OpCopyObject %v4ushort %48
-               OpStore %10 %20
-         %29 = OpLoad %v4ushort %10
-         %49 = OpCopyObject %v4ushort %29
-         %25 = OpCompositeExtract %ushort %49 0
-         %26 = OpCompositeExtract %ushort %49 1
-         %27 = OpCompositeExtract %ushort %49 2
-         %28 = OpCompositeExtract %ushort %49 3
-               OpStore %8 %25
-               OpStore %9 %26
-               OpStore %6 %27
-               OpStore %7 %28
-         %34 = OpLoad %ushort %8
-         %35 = OpLoad %ushort %9
-         %36 = OpLoad %ushort %6
-         %37 = OpLoad %ushort %7
-         %51 = OpUndef %v4ushort
-         %52 = OpCompositeInsert %v4ushort %34 %51 0
-         %53 = OpCompositeInsert %v4ushort %35 %52 1
-         %54 = OpCompositeInsert %v4ushort %36 %53 2
-         %55 = OpCompositeInsert %v4ushort %37 %54 3
-         %50 = OpCopyObject %v4ushort %55
-         %30 = OpCompositeExtract %ushort %50 0
-         %31 = OpCompositeExtract %ushort %50 1
-         %32 = OpCompositeExtract %ushort %50 2
-         %33 = OpCompositeExtract %ushort %50 3
-               OpStore %9 %30
-               OpStore %6 %31
-               OpStore %7 %32
-               OpStore %8 %33
-         %38 = OpLoad %ulong %5
-         %39 = OpLoad %ushort %6
-         %40 = OpLoad %ushort %7
-         %41 = OpLoad %ushort %8
-         %42 = OpLoad %ushort %9
-         %56 = OpUndef %v4uchar
-         %89 = OpBitcast %ushort %39
-         %66 = OpUConvert %uchar %89
-         %57 = OpCompositeInsert %v4uchar %66 %56 0
-         %90 = OpBitcast %ushort %40
-         %67 = OpUConvert %uchar %90
-         %58 = OpCompositeInsert %v4uchar %67 %57 1
-         %91 = OpBitcast %ushort %41
-         %68 = OpUConvert %uchar %91
-         %59 = OpCompositeInsert %v4uchar %68 %58 2
-         %92 = OpBitcast %ushort %42
-         %69 = OpUConvert %uchar %92
-         %60 = OpCompositeInsert %v4uchar %69 %59 3
-         %70 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %38
-               OpStore %70 %60
+               OpStore %2 %17
+               OpStore %3 %18
+         %19 = OpLoad %ulong %2
+               OpStore %4 %19
+         %20 = OpLoad %ulong %3
+               OpStore %5 %20
+         %21 = OpLoad %ulong %4
+         %49 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %21
+         %11 = OpLoad %v4uchar %49
+         %50 = OpCompositeExtract %uchar %11 0
+         %51 = OpCompositeExtract %uchar %11 1
+         %52 = OpCompositeExtract %uchar %11 2
+         %53 = OpCompositeExtract %uchar %11 3
+         %73 = OpBitcast %uchar %50
+         %22 = OpUConvert %ushort %73
+         %74 = OpBitcast %uchar %51
+         %23 = OpUConvert %ushort %74
+         %75 = OpBitcast %uchar %52
+         %24 = OpUConvert %ushort %75
+         %76 = OpBitcast %uchar %53
+         %25 = OpUConvert %ushort %76
+               OpStore %6 %22
+               OpStore %7 %23
+               OpStore %8 %24
+               OpStore %9 %25
+         %26 = OpLoad %ushort %7
+         %27 = OpLoad %ushort %8
+         %28 = OpLoad %ushort %9
+         %29 = OpLoad %ushort %6
+         %77 = OpUndef %v4ushort
+         %78 = OpCompositeInsert %v4ushort %26 %77 0
+         %79 = OpCompositeInsert %v4ushort %27 %78 1
+         %80 = OpCompositeInsert %v4ushort %28 %79 2
+         %81 = OpCompositeInsert %v4ushort %29 %80 3
+         %12 = OpCopyObject %v4ushort %81
+         %30 = OpCopyObject %v4ushort %12
+               OpStore %10 %30
+         %31 = OpLoad %v4ushort %10
+         %13 = OpCopyObject %v4ushort %31
+         %32 = OpCompositeExtract %ushort %13 0
+         %33 = OpCompositeExtract %ushort %13 1
+         %34 = OpCompositeExtract %ushort %13 2
+         %35 = OpCompositeExtract %ushort %13 3
+               OpStore %8 %32
+               OpStore %9 %33
+               OpStore %6 %34
+               OpStore %7 %35
+         %36 = OpLoad %ushort %8
+         %37 = OpLoad %ushort %9
+         %38 = OpLoad %ushort %6
+         %39 = OpLoad %ushort %7
+         %82 = OpUndef %v4ushort
+         %83 = OpCompositeInsert %v4ushort %36 %82 0
+         %84 = OpCompositeInsert %v4ushort %37 %83 1
+         %85 = OpCompositeInsert %v4ushort %38 %84 2
+         %86 = OpCompositeInsert %v4ushort %39 %85 3
+         %15 = OpCopyObject %v4ushort %86
+         %14 = OpCopyObject %v4ushort %15
+         %40 = OpCompositeExtract %ushort %14 0
+         %41 = OpCompositeExtract %ushort %14 1
+         %42 = OpCompositeExtract %ushort %14 2
+         %43 = OpCompositeExtract %ushort %14 3
+               OpStore %9 %40
+               OpStore %6 %41
+               OpStore %7 %42
+               OpStore %8 %43
+         %44 = OpLoad %ushort %6
+         %45 = OpLoad %ushort %7
+         %46 = OpLoad %ushort %8
+         %47 = OpLoad %ushort %9
+         %87 = OpBitcast %ushort %44
+         %54 = OpUConvert %uchar %87
+         %88 = OpBitcast %ushort %45
+         %55 = OpUConvert %uchar %88
+         %89 = OpBitcast %ushort %46
+         %56 = OpUConvert %uchar %89
+         %90 = OpBitcast %ushort %47
+         %57 = OpUConvert %uchar %90
+         %91 = OpUndef %v4uchar
+         %92 = OpCompositeInsert %v4uchar %54 %91 0
+         %93 = OpCompositeInsert %v4uchar %55 %92 1
+         %94 = OpCompositeInsert %v4uchar %56 %93 2
+         %95 = OpCompositeInsert %v4uchar %57 %94 3
+         %16 = OpCopyObject %v4uchar %95
+         %48 = OpLoad %ulong %5
+         %58 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %48
+               OpStore %58 %16
                OpReturn
                OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index db062db..ca64e60 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -2091,13 +2091,21 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
         };
         let member_index = match desc.op.1 {
             Some(idx) => {
-                match var_type {
-                    ast::Type::Vector(scalar_t, _) => {
+                let vector_width = match var_type {
+                    ast::Type::Vector(scalar_t, width) => {
                         var_type = ast::Type::Scalar(scalar_t);
+                        width
                     }
                     _ => return Err(TranslateError::MismatchedType),
-                }
-                Some((idx, self.id_def.special_registers.contains_key(&symbol)))
+                };
+                Some((
+                    idx,
+                    if self.id_def.special_registers.contains_key(&symbol) {
+                        Some(vector_width)
+                    } else {
+                        None
+                    },
+                ))
             }
             None => None,
         };
@@ -2119,7 +2127,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
                         src2: generated_id,
                     },
                     typ: var_type,
-                    member_index,
+                    member_index: member_index.map(|(idx, _)| idx),
                 }));
         }
         Ok(generated_id)
@@ -3159,45 +3167,17 @@ fn emit_function_body_ops(
                 }
             },
             Statement::LoadVar(details) => {
-                let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone()));
-                let src = match details.member_index {
-                    Some((index, is_sreg)) => {
-                        let storage_class = if is_sreg {
-                            spirv::StorageClass::Input
-                        } else {
-                            spirv::StorageClass::Function
-                        };
-                        let result_ptr_type = map.get_or_add(
-                            builder,
-                            SpirvType::new_pointer(details.typ.clone(), storage_class),
-                        );
-                        let index_spirv = map.get_or_add_constant(
-                            builder,
-                            &ast::Type::Scalar(ast::ScalarType::U32),
-                            &vec_repr(index as u32),
-                        )?;
-                        builder.in_bounds_access_chain(
-                            result_ptr_type,
-                            None,
-                            details.arg.src,
-                            &[index_spirv],
-                        )?
-                    }
-                    None => details.arg.src,
-                };
-                builder.load(result_type, Some(details.arg.dst), src, None, [])?;
+                emit_load_var(builder, map, details)?;
             }
             Statement::StoreVar(details) => {
                 let dst_ptr = match details.member_index {
-                    Some((index, is_sreg)) => {
-                        let storage_class = if is_sreg {
-                            spirv::StorageClass::Input
-                        } else {
-                            spirv::StorageClass::Function
-                        };
+                    Some(index) => {
                         let result_ptr_type = map.get_or_add(
                             builder,
-                            SpirvType::new_pointer(details.typ.clone(), storage_class),
+                            SpirvType::new_pointer(
+                                details.typ.clone(),
+                                spirv::StorageClass::Function,
+                            ),
                         );
                         let index_spirv = map.get_or_add_constant(
                             builder,
@@ -4189,6 +4169,58 @@ fn emit_implicit_conversion(
     Ok(())
 }
 
+fn emit_load_var(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    details: &LoadVarDetails,
+) -> Result<(), TranslateError> {
+    let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone()));
+    match details.member_index {
+        Some((index, Some(width))) => {
+            let vector_type = match details.typ {
+                ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
+                _ => return Err(TranslateError::MismatchedType),
+            };
+            let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type));
+            let vector_temp = builder.load(vector_type_spirv, None, details.arg.src, None, [])?;
+            builder.composite_extract(
+                result_type,
+                Some(details.arg.dst),
+                vector_temp,
+                &[index as u32],
+            )?;
+        }
+        Some((index, None)) => {
+            let result_ptr_type = map.get_or_add(
+                builder,
+                SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function),
+            );
+            let index_spirv = map.get_or_add_constant(
+                builder,
+                &ast::Type::Scalar(ast::ScalarType::U32),
+                &vec_repr(index as u32),
+            )?;
+            let src = builder.in_bounds_access_chain(
+                result_ptr_type,
+                None,
+                details.arg.src,
+                &[index_spirv],
+            )?;
+            builder.load(result_type, Some(details.arg.dst), src, None, [])?;
+        }
+        None => {
+            builder.load(
+                result_type,
+                Some(details.arg.dst),
+                details.arg.src,
+                None,
+                [],
+            )?;
+        }
+    };
+    Ok(())
+}
+
 fn normalize_identifiers<'a, 'b>(
     id_defs: &mut FnStringIdResolver<'a, 'b>,
     fn_defs: &GlobalFnDeclResolver<'a, 'b>,
@@ -5106,15 +5138,18 @@ impl ExpandedStatement {
 struct LoadVarDetails {
     arg: ast::Arg2<ExpandedArgParams>,
     typ: ast::Type,
-    // (index, is_sreg)
-    member_index: Option<(u8, bool)>,
+    // (index, vector_width)
+    // HACK ALERT
+    // For some reason IGC explodes when you try to load from builtin vectors
+    // using OpInBoundsAccessChain, the one true way to do it is to
+    // OpLoad+OpCompositeExtract
+    member_index: Option<(u8, Option<u8>)>,
 }
 
 struct StoreVarDetails {
     arg: ast::Arg2St<ExpandedArgParams>,
     typ: ast::Type,
-    // (index, is_sreg)
-    member_index: Option<(u8, bool)>,
+    member_index: Option<u8>,
 }
 
 struct RepackVectorDetails {