diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index c9ed9b1..ff48ae9 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -180,6 +180,7 @@ test_ptx!( ], [0u32, 0u32, 0u32, 2u32] ); +test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx b/ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx new file mode 100644 index 0000000..14d3d2c --- /dev/null +++ b/ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry non_scalar_ptr_offset( + .param .u64 input_p, + .param .u64 output_p +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 x; + .reg .u32 y; + + ld.param.u64 in_addr, [input_p]; + ld.param.u64 out_addr, [output_p]; + + ld.global.v2.u32 {x,y}, [in_addr+8]; + add.u32 x, x, y; + st.global.u32 [out_addr], x; + ret; +} diff --git a/ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt b/ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt new file mode 100644 index 0000000..92dc7cc --- /dev/null +++ b/ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt @@ -0,0 +1,60 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %27 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "non_scalar_ptr_offset" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %30 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint + %ulong_8 = OpConstant %ulong 8 + %v2uint = OpTypeVector %uint 2 +%_ptr_CrossWorkgroup_v2uint = OpTypePointer CrossWorkgroup %v2uint + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar +%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint + %1 = OpFunction %void None %30 + %9 = OpFunctionParameter %ulong + %10 = OpFunctionParameter %ulong + %25 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %9 + OpStore %3 %10 + %11 = OpLoad %ulong %2 Aligned 8 + OpStore %4 %11 + %12 = OpLoad %ulong %3 Aligned 8 + OpStore %5 %12 + %13 = OpLoad %ulong %4 + %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_v2uint %13 + %38 = OpBitcast %_ptr_CrossWorkgroup_uchar %23 + %39 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %38 %ulong_8 + %22 = OpBitcast %_ptr_CrossWorkgroup_v2uint %39 + %8 = OpLoad %v2uint %22 Aligned 8 + %14 = OpCompositeExtract %uint %8 0 + %15 = OpCompositeExtract %uint %8 1 + OpStore %6 %14 + OpStore %7 %15 + %17 = OpLoad %uint %6 + %18 = OpLoad %uint %7 + %16 = OpIAdd %uint %17 %18 + OpStore %6 %16 + %19 = OpLoad %ulong %5 + %20 = OpLoad %uint %6 + %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %19 + OpStore %24 %20 Aligned 4 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c2562c3..e0b82e8 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2389,10 +2389,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { ))); Ok(id_add_result) } else { - let scalar_type = match typ { - ast::Type::Scalar(underlying_type) => *underlying_type, - _ => return Err(error_unreachable()), - }; let id_constant_stmt = self.id_def.register_intermediate( ast::Type::Scalar(ast::ScalarType::S64), ast::StateSpace::Reg, @@ -2404,7 +2400,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { })); let dst = self.id_def.register_intermediate(typ.clone(), state_space); self.func.push(Statement::PtrAccess(PtrAccess { - underlying_type: scalar_type, + underlying_type: typ.clone(), state_space: state_space, dst, ptr_src: reg, @@ -3118,7 +3114,7 @@ fn emit_function_body_ops( ); let result_type = map.get_or_add( builder, - SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)), + SpirvType::pointer_to(underlying_type.clone(), state_space.to_spirv()), ); let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?; let temp = builder.in_bounds_ptr_access_chain( @@ -4532,7 +4528,7 @@ fn convert_to_stateful_memory_access<'a, 'input>( }; let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { - underlying_type: ast::ScalarType::U8, + underlying_type: ast::Type::Scalar(ast::ScalarType::U8), state_space: ast::StateSpace::Global, dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, @@ -4575,7 +4571,7 @@ fn convert_to_stateful_memory_access<'a, 'input>( ))); let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { - underlying_type: ast::ScalarType::U8, + underlying_type: ast::Type::Scalar(ast::ScalarType::U8), state_space: ast::StateSpace::Global, dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, @@ -5497,7 +5493,6 @@ impl> PtrAccess

{ self, visitor: &mut V, ) -> Result, TranslateError> { - let ptr_type = ast::Type::Scalar(self.underlying_type.clone()); let new_dst = visitor.id( ArgumentDescriptor { op: self.dst, @@ -5505,7 +5500,7 @@ impl> PtrAccess

{ is_memory_access: false, non_default_implicit_conversion: None, }, - Some((&ptr_type, self.state_space)), + Some((&self.underlying_type, self.state_space)), )?; let new_ptr_src = visitor.id( ArgumentDescriptor { @@ -5514,7 +5509,7 @@ impl> PtrAccess

{ is_memory_access: false, non_default_implicit_conversion: None, }, - Some((&ptr_type, self.state_space)), + Some((&self.underlying_type, self.state_space)), )?; let new_constant_src = visitor.operand( ArgumentDescriptor { @@ -5707,7 +5702,7 @@ pub struct ArgumentDescriptor { } pub struct PtrAccess { - underlying_type: ast::ScalarType, + underlying_type: ast::Type, state_space: ast::StateSpace, dst: spirv::Word, ptr_src: spirv::Word,