diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt index 535e480..a77ab7d 100644 --- a/ptx/src/test/spirv_run/vector.spvtxt +++ b/ptx/src/test/spirv_run/vector.spvtxt @@ -7,91 +7,93 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %57 = OpExtInstImport "OpenCL.std" + %51 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %31 "vector" + OpEntryPoint Kernel %25 "vector" %void = OpTypeVoid %uint = OpTypeInt 32 0 %v2uint = OpTypeVector %uint 2 - %61 = OpTypeFunction %v2uint %v2uint + %55 = OpTypeFunction %v2uint %v2uint %_ptr_Function_v2uint = OpTypePointer Function %v2uint %_ptr_Function_uint = OpTypePointer Function %uint + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 %ulong = OpTypeInt 64 0 - %65 = OpTypeFunction %void %ulong %ulong + %67 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_v2uint = OpTypePointer Generic %v2uint - %1 = OpFunction %v2uint None %61 + %1 = OpFunction %v2uint None %55 %7 = OpFunctionParameter %v2uint - %30 = OpLabel + %24 = OpLabel %2 = OpVariable %_ptr_Function_v2uint Function %3 = OpVariable %_ptr_Function_v2uint Function %4 = OpVariable %_ptr_Function_v2uint Function %5 = OpVariable %_ptr_Function_uint Function %6 = OpVariable %_ptr_Function_uint Function OpStore %3 %7 - %9 = OpLoad %v2uint %3 - %27 = OpCompositeExtract %uint %9 0 - %8 = OpCopyObject %uint %27 + %59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0 + %9 = OpLoad %uint %59 + %8 = OpCopyObject %uint %9 OpStore %5 %8 - %11 = OpLoad %v2uint %3 - %28 = OpCompositeExtract %uint %11 1 - %10 = OpCopyObject %uint %28 + %61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1 + %11 = OpLoad %uint %61 + %10 = OpCopyObject %uint %11 OpStore %6 %10 %13 = OpLoad %uint %5 %14 = OpLoad %uint %6 %12 = OpIAdd %uint %13 %14 OpStore %6 %12 - %16 = OpLoad %v2uint %4 - %17 = OpLoad %uint %6 - %15 = OpCompositeInsert %v2uint %17 %16 0 - OpStore %4 %15 - %19 = OpLoad %v2uint %4 - %20 = OpLoad %uint %6 - %18 = OpCompositeInsert %v2uint %20 %19 1 - OpStore %4 %18 + %16 = OpLoad %uint %6 + %15 = OpCopyObject %uint %16 + %62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 + OpStore %62 %15 + %18 = OpLoad %uint %6 + %17 = OpCopyObject %uint %18 + %63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 + OpStore %63 %17 + %64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 + %20 = OpLoad %uint %64 + %19 = OpCopyObject %uint %20 + %65 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 + OpStore %65 %19 %22 = OpLoad %v2uint %4 - %23 = OpLoad %v2uint %4 - %29 = OpCompositeExtract %uint %23 1 - %21 = OpCompositeInsert %v2uint %29 %22 0 - OpStore %4 %21 - %25 = OpLoad %v2uint %4 - %24 = OpCopyObject %v2uint %25 - OpStore %2 %24 - %26 = OpLoad %v2uint %2 - OpReturnValue %26 + %21 = OpCopyObject %v2uint %22 + OpStore %2 %21 + %23 = OpLoad %v2uint %2 + OpReturnValue %23 OpFunctionEnd - %31 = OpFunction %void None %65 - %40 = OpFunctionParameter %ulong - %41 = OpFunctionParameter %ulong - %55 = OpLabel - %32 = OpVariable %_ptr_Function_ulong Function + %25 = OpFunction %void None %67 + %34 = OpFunctionParameter %ulong + %35 = OpFunctionParameter %ulong + %49 = OpLabel + %26 = OpVariable %_ptr_Function_ulong Function + %27 = OpVariable %_ptr_Function_ulong Function + %28 = OpVariable %_ptr_Function_ulong Function + %29 = OpVariable %_ptr_Function_ulong Function + %30 = OpVariable %_ptr_Function_v2uint Function + %31 = OpVariable %_ptr_Function_uint Function + %32 = OpVariable %_ptr_Function_uint Function %33 = OpVariable %_ptr_Function_ulong Function - %34 = OpVariable %_ptr_Function_ulong Function - %35 = OpVariable %_ptr_Function_ulong Function - %36 = OpVariable %_ptr_Function_v2uint Function - %37 = OpVariable %_ptr_Function_uint Function - %38 = OpVariable %_ptr_Function_uint Function - %39 = OpVariable %_ptr_Function_ulong Function - OpStore %32 %40 - OpStore %33 %41 - %42 = OpLoad %ulong %32 - OpStore %34 %42 - %43 = OpLoad %ulong %33 - OpStore %35 %43 - %45 = OpLoad %ulong %34 - %52 = OpConvertUToPtr %_ptr_Generic_v2uint %45 - %44 = OpLoad %v2uint %52 - OpStore %36 %44 - %47 = OpLoad %v2uint %36 - %46 = OpFunctionCall %v2uint %1 %47 - OpStore %36 %46 - %49 = OpLoad %v2uint %36 - %53 = OpBitcast %ulong %49 - %48 = OpCopyObject %ulong %53 - OpStore %39 %48 - %50 = OpLoad %ulong %35 - %51 = OpLoad %v2uint %36 - %54 = OpConvertUToPtr %_ptr_Generic_v2uint %50 - OpStore %54 %51 + OpStore %26 %34 + OpStore %27 %35 + %36 = OpLoad %ulong %26 + OpStore %28 %36 + %37 = OpLoad %ulong %27 + OpStore %29 %37 + %39 = OpLoad %ulong %28 + %46 = OpConvertUToPtr %_ptr_Generic_v2uint %39 + %38 = OpLoad %v2uint %46 + OpStore %30 %38 + %41 = OpLoad %v2uint %30 + %40 = OpFunctionCall %v2uint %1 %41 + OpStore %30 %40 + %43 = OpLoad %v2uint %30 + %47 = OpBitcast %ulong %43 + %42 = OpCopyObject %ulong %47 + OpStore %33 %42 + %44 = OpLoad %ulong %29 + %45 = OpLoad %v2uint %30 + %48 = OpConvertUToPtr %_ptr_Generic_v2uint %44 + OpStore %48 %45 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector_extract.spvtxt b/ptx/src/test/spirv_run/vector_extract.spvtxt index 4943189..2037dec 100644 --- a/ptx/src/test/spirv_run/vector_extract.spvtxt +++ b/ptx/src/test/spirv_run/vector_extract.spvtxt @@ -7,12 +7,12 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %73 = OpExtInstImport "OpenCL.std" + %61 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "vector_extract" %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %76 = OpTypeFunction %void %ulong %ulong + %64 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %ushort = OpTypeInt 16 0 %_ptr_Function_ushort = OpTypePointer Function %ushort @@ -21,10 +21,10 @@ %uchar = OpTypeInt 8 0 %v4uchar = OpTypeVector %uchar 4 %_ptr_CrossWorkgroup_v4uchar = OpTypePointer CrossWorkgroup %v4uchar - %1 = OpFunction %void None %76 - %11 = OpFunctionParameter %ulong - %12 = OpFunctionParameter %ulong - %71 = OpLabel + %1 = OpFunction %void None %64 + %17 = OpFunctionParameter %ulong + %18 = OpFunctionParameter %ulong + %59 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -34,89 +34,92 @@ %8 = OpVariable %_ptr_Function_ushort Function %9 = OpVariable %_ptr_Function_ushort Function %10 = OpVariable %_ptr_Function_v4ushort Function - OpStore %2 %11 - OpStore %3 %12 - %13 = OpLoad %ulong %2 - OpStore %4 %13 - %14 = OpLoad %ulong %3 - OpStore %5 %14 - %19 = OpLoad %ulong %4 - %61 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %19 - %43 = OpLoad %v4uchar %61 - %62 = OpCompositeExtract %uchar %43 0 - %85 = OpBitcast %uchar %62 - %15 = OpUConvert %ushort %85 - %63 = OpCompositeExtract %uchar %43 1 - %86 = OpBitcast %uchar %63 - %16 = OpUConvert %ushort %86 - %64 = OpCompositeExtract %uchar %43 2 - %87 = OpBitcast %uchar %64 - %17 = OpUConvert %ushort %87 - %65 = OpCompositeExtract %uchar %43 3 - %88 = OpBitcast %uchar %65 - %18 = OpUConvert %ushort %88 - OpStore %6 %15 - OpStore %7 %16 - OpStore %8 %17 - OpStore %9 %18 - %21 = OpLoad %ushort %7 - %22 = OpLoad %ushort %8 - %23 = OpLoad %ushort %9 - %24 = OpLoad %ushort %6 - %44 = OpUndef %v4ushort - %45 = OpCompositeInsert %v4ushort %21 %44 0 - %46 = OpCompositeInsert %v4ushort %22 %45 1 - %47 = OpCompositeInsert %v4ushort %23 %46 2 - %48 = OpCompositeInsert %v4ushort %24 %47 3 - %20 = OpCopyObject %v4ushort %48 - OpStore %10 %20 - %29 = OpLoad %v4ushort %10 - %49 = OpCopyObject %v4ushort %29 - %25 = OpCompositeExtract %ushort %49 0 - %26 = OpCompositeExtract %ushort %49 1 - %27 = OpCompositeExtract %ushort %49 2 - %28 = OpCompositeExtract %ushort %49 3 - OpStore %8 %25 - OpStore %9 %26 - OpStore %6 %27 - OpStore %7 %28 - %34 = OpLoad %ushort %8 - %35 = OpLoad %ushort %9 - %36 = OpLoad %ushort %6 - %37 = OpLoad %ushort %7 - %51 = OpUndef %v4ushort - %52 = OpCompositeInsert %v4ushort %34 %51 0 - %53 = OpCompositeInsert %v4ushort %35 %52 1 - %54 = OpCompositeInsert %v4ushort %36 %53 2 - %55 = OpCompositeInsert %v4ushort %37 %54 3 - %50 = OpCopyObject %v4ushort %55 - %30 = OpCompositeExtract %ushort %50 0 - %31 = OpCompositeExtract %ushort %50 1 - %32 = OpCompositeExtract %ushort %50 2 - %33 = OpCompositeExtract %ushort %50 3 - OpStore %9 %30 - OpStore %6 %31 - OpStore %7 %32 - OpStore %8 %33 - %38 = OpLoad %ulong %5 - %39 = OpLoad %ushort %6 - %40 = OpLoad %ushort %7 - %41 = OpLoad %ushort %8 - %42 = OpLoad %ushort %9 - %56 = OpUndef %v4uchar - %89 = OpBitcast %ushort %39 - %66 = OpUConvert %uchar %89 - %57 = OpCompositeInsert %v4uchar %66 %56 0 - %90 = OpBitcast %ushort %40 - %67 = OpUConvert %uchar %90 - %58 = OpCompositeInsert %v4uchar %67 %57 1 - %91 = OpBitcast %ushort %41 - %68 = OpUConvert %uchar %91 - %59 = OpCompositeInsert %v4uchar %68 %58 2 - %92 = OpBitcast %ushort %42 - %69 = OpUConvert %uchar %92 - %60 = OpCompositeInsert %v4uchar %69 %59 3 - %70 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %38 - OpStore %70 %60 + OpStore %2 %17 + OpStore %3 %18 + %19 = OpLoad %ulong %2 + OpStore %4 %19 + %20 = OpLoad %ulong %3 + OpStore %5 %20 + %21 = OpLoad %ulong %4 + %49 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %21 + %11 = OpLoad %v4uchar %49 + %50 = OpCompositeExtract %uchar %11 0 + %51 = OpCompositeExtract %uchar %11 1 + %52 = OpCompositeExtract %uchar %11 2 + %53 = OpCompositeExtract %uchar %11 3 + %73 = OpBitcast %uchar %50 + %22 = OpUConvert %ushort %73 + %74 = OpBitcast %uchar %51 + %23 = OpUConvert %ushort %74 + %75 = OpBitcast %uchar %52 + %24 = OpUConvert %ushort %75 + %76 = OpBitcast %uchar %53 + %25 = OpUConvert %ushort %76 + OpStore %6 %22 + OpStore %7 %23 + OpStore %8 %24 + OpStore %9 %25 + %26 = OpLoad %ushort %7 + %27 = OpLoad %ushort %8 + %28 = OpLoad %ushort %9 + %29 = OpLoad %ushort %6 + %77 = OpUndef %v4ushort + %78 = OpCompositeInsert %v4ushort %26 %77 0 + %79 = OpCompositeInsert %v4ushort %27 %78 1 + %80 = OpCompositeInsert %v4ushort %28 %79 2 + %81 = OpCompositeInsert %v4ushort %29 %80 3 + %12 = OpCopyObject %v4ushort %81 + %30 = OpCopyObject %v4ushort %12 + OpStore %10 %30 + %31 = OpLoad %v4ushort %10 + %13 = OpCopyObject %v4ushort %31 + %32 = OpCompositeExtract %ushort %13 0 + %33 = OpCompositeExtract %ushort %13 1 + %34 = OpCompositeExtract %ushort %13 2 + %35 = OpCompositeExtract %ushort %13 3 + OpStore %8 %32 + OpStore %9 %33 + OpStore %6 %34 + OpStore %7 %35 + %36 = OpLoad %ushort %8 + %37 = OpLoad %ushort %9 + %38 = OpLoad %ushort %6 + %39 = OpLoad %ushort %7 + %82 = OpUndef %v4ushort + %83 = OpCompositeInsert %v4ushort %36 %82 0 + %84 = OpCompositeInsert %v4ushort %37 %83 1 + %85 = OpCompositeInsert %v4ushort %38 %84 2 + %86 = OpCompositeInsert %v4ushort %39 %85 3 + %15 = OpCopyObject %v4ushort %86 + %14 = OpCopyObject %v4ushort %15 + %40 = OpCompositeExtract %ushort %14 0 + %41 = OpCompositeExtract %ushort %14 1 + %42 = OpCompositeExtract %ushort %14 2 + %43 = OpCompositeExtract %ushort %14 3 + OpStore %9 %40 + OpStore %6 %41 + OpStore %7 %42 + OpStore %8 %43 + %44 = OpLoad %ushort %6 + %45 = OpLoad %ushort %7 + %46 = OpLoad %ushort %8 + %47 = OpLoad %ushort %9 + %87 = OpBitcast %ushort %44 + %54 = OpUConvert %uchar %87 + %88 = OpBitcast %ushort %45 + %55 = OpUConvert %uchar %88 + %89 = OpBitcast %ushort %46 + %56 = OpUConvert %uchar %89 + %90 = OpBitcast %ushort %47 + %57 = OpUConvert %uchar %90 + %91 = OpUndef %v4uchar + %92 = OpCompositeInsert %v4uchar %54 %91 0 + %93 = OpCompositeInsert %v4uchar %55 %92 1 + %94 = OpCompositeInsert %v4uchar %56 %93 2 + %95 = OpCompositeInsert %v4uchar %57 %94 3 + %16 = OpCopyObject %v4uchar %95 + %48 = OpLoad %ulong %5 + %58 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %48 + OpStore %58 %16 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index db062db..ca64e60 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2091,13 +2091,21 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { }; let member_index = match desc.op.1 { Some(idx) => { - match var_type { - ast::Type::Vector(scalar_t, _) => { + let vector_width = match var_type { + ast::Type::Vector(scalar_t, width) => { var_type = ast::Type::Scalar(scalar_t); + width } _ => return Err(TranslateError::MismatchedType), - } - Some((idx, self.id_def.special_registers.contains_key(&symbol))) + }; + Some(( + idx, + if self.id_def.special_registers.contains_key(&symbol) { + Some(vector_width) + } else { + None + }, + )) } None => None, }; @@ -2119,7 +2127,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { src2: generated_id, }, typ: var_type, - member_index, + member_index: member_index.map(|(idx, _)| idx), })); } Ok(generated_id) @@ -3159,45 +3167,17 @@ fn emit_function_body_ops( } }, Statement::LoadVar(details) => { - let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone())); - let src = match details.member_index { - Some((index, is_sreg)) => { - let storage_class = if is_sreg { - spirv::StorageClass::Input - } else { - spirv::StorageClass::Function - }; - let result_ptr_type = map.get_or_add( - builder, - SpirvType::new_pointer(details.typ.clone(), storage_class), - ); - let index_spirv = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(index as u32), - )?; - builder.in_bounds_access_chain( - result_ptr_type, - None, - details.arg.src, - &[index_spirv], - )? - } - None => details.arg.src, - }; - builder.load(result_type, Some(details.arg.dst), src, None, [])?; + emit_load_var(builder, map, details)?; } Statement::StoreVar(details) => { let dst_ptr = match details.member_index { - Some((index, is_sreg)) => { - let storage_class = if is_sreg { - spirv::StorageClass::Input - } else { - spirv::StorageClass::Function - }; + Some(index) => { let result_ptr_type = map.get_or_add( builder, - SpirvType::new_pointer(details.typ.clone(), storage_class), + SpirvType::new_pointer( + details.typ.clone(), + spirv::StorageClass::Function, + ), ); let index_spirv = map.get_or_add_constant( builder, @@ -4189,6 +4169,58 @@ fn emit_implicit_conversion( Ok(()) } +fn emit_load_var( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + details: &LoadVarDetails, +) -> Result<(), TranslateError> { + let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone())); + match details.member_index { + Some((index, Some(width))) => { + let vector_type = match details.typ { + ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), + _ => return Err(TranslateError::MismatchedType), + }; + let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type)); + let vector_temp = builder.load(vector_type_spirv, None, details.arg.src, None, [])?; + builder.composite_extract( + result_type, + Some(details.arg.dst), + vector_temp, + &[index as u32], + )?; + } + Some((index, None)) => { + let result_ptr_type = map.get_or_add( + builder, + SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + let src = builder.in_bounds_access_chain( + result_ptr_type, + None, + details.arg.src, + &[index_spirv], + )?; + builder.load(result_type, Some(details.arg.dst), src, None, [])?; + } + None => { + builder.load( + result_type, + Some(details.arg.dst), + details.arg.src, + None, + [], + )?; + } + }; + Ok(()) +} + fn normalize_identifiers<'a, 'b>( id_defs: &mut FnStringIdResolver<'a, 'b>, fn_defs: &GlobalFnDeclResolver<'a, 'b>, @@ -5106,15 +5138,18 @@ impl ExpandedStatement { struct LoadVarDetails { arg: ast::Arg2, typ: ast::Type, - // (index, is_sreg) - member_index: Option<(u8, bool)>, + // (index, vector_width) + // HACK ALERT + // For some reason IGC explodes when you try to load from builtin vectors + // using OpInBoundsAccessChain, the one true way to do it is to + // OpLoad+OpCompositeExtract + member_index: Option<(u8, Option)>, } struct StoreVarDetails { arg: ast::Arg2St, typ: ast::Type, - // (index, is_sreg) - member_index: Option<(u8, bool)>, + member_index: Option, } struct RepackVectorDetails {