Allow ptr offsets to non-scalar types

This commit is contained in:
Andrzej Janik 2021-06-25 22:29:25 +02:00
parent 8ef6c3d8b6
commit 23874efe68
4 changed files with 90 additions and 12 deletions

View file

@ -180,6 +180,7 @@ test_ptx!(
],
[0u32, 0u32, 0u32, 2u32]
);
test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]);
struct DisplayError<T: Debug> {
err: T,

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry non_scalar_ptr_offset(
.param .u64 input_p,
.param .u64 output_p
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u32 x;
.reg .u32 y;
ld.param.u64 in_addr, [input_p];
ld.param.u64 out_addr, [output_p];
ld.global.v2.u32 {x,y}, [in_addr+8];
add.u32 x, x, y;
st.global.u32 [out_addr], x;
ret;
}

View file

@ -0,0 +1,60 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%27 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "non_scalar_ptr_offset"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%30 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%ulong_8 = OpConstant %ulong 8
%v2uint = OpTypeVector %uint 2
%_ptr_CrossWorkgroup_v2uint = OpTypePointer CrossWorkgroup %v2uint
%uchar = OpTypeInt 8 0
%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
%1 = OpFunction %void None %30
%9 = OpFunctionParameter %ulong
%10 = OpFunctionParameter %ulong
%25 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_uint Function
OpStore %2 %9
OpStore %3 %10
%11 = OpLoad %ulong %2 Aligned 8
OpStore %4 %11
%12 = OpLoad %ulong %3 Aligned 8
OpStore %5 %12
%13 = OpLoad %ulong %4
%23 = OpConvertUToPtr %_ptr_CrossWorkgroup_v2uint %13
%38 = OpBitcast %_ptr_CrossWorkgroup_uchar %23
%39 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %38 %ulong_8
%22 = OpBitcast %_ptr_CrossWorkgroup_v2uint %39
%8 = OpLoad %v2uint %22 Aligned 8
%14 = OpCompositeExtract %uint %8 0
%15 = OpCompositeExtract %uint %8 1
OpStore %6 %14
OpStore %7 %15
%17 = OpLoad %uint %6
%18 = OpLoad %uint %7
%16 = OpIAdd %uint %17 %18
OpStore %6 %16
%19 = OpLoad %ulong %5
%20 = OpLoad %uint %6
%24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %19
OpStore %24 %20 Aligned 4
OpReturn
OpFunctionEnd

View file

@ -2389,10 +2389,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
)));
Ok(id_add_result)
} else {
let scalar_type = match typ {
ast::Type::Scalar(underlying_type) => *underlying_type,
_ => return Err(error_unreachable()),
};
let id_constant_stmt = self.id_def.register_intermediate(
ast::Type::Scalar(ast::ScalarType::S64),
ast::StateSpace::Reg,
@ -2404,7 +2400,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
}));
let dst = self.id_def.register_intermediate(typ.clone(), state_space);
self.func.push(Statement::PtrAccess(PtrAccess {
underlying_type: scalar_type,
underlying_type: typ.clone(),
state_space: state_space,
dst,
ptr_src: reg,
@ -3118,7 +3114,7 @@ fn emit_function_body_ops(
);
let result_type = map.get_or_add(
builder,
SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)),
SpirvType::pointer_to(underlying_type.clone(), state_space.to_spirv()),
);
let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
let temp = builder.in_bounds_ptr_access_chain(
@ -4532,7 +4528,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
};
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
underlying_type: ast::ScalarType::U8,
underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
@ -4575,7 +4571,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
)));
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
underlying_type: ast::ScalarType::U8,
underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
@ -5497,7 +5493,6 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
self,
visitor: &mut V,
) -> Result<PtrAccess<To>, TranslateError> {
let ptr_type = ast::Type::Scalar(self.underlying_type.clone());
let new_dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
@ -5505,7 +5500,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
is_memory_access: false,
non_default_implicit_conversion: None,
},
Some((&ptr_type, self.state_space)),
Some((&self.underlying_type, self.state_space)),
)?;
let new_ptr_src = visitor.id(
ArgumentDescriptor {
@ -5514,7 +5509,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
is_memory_access: false,
non_default_implicit_conversion: None,
},
Some((&ptr_type, self.state_space)),
Some((&self.underlying_type, self.state_space)),
)?;
let new_constant_src = visitor.operand(
ArgumentDescriptor {
@ -5707,7 +5702,7 @@ pub struct ArgumentDescriptor<Op> {
}
pub struct PtrAccess<P: ast::ArgParams> {
underlying_type: ast::ScalarType,
underlying_type: ast::Type,
state_space: ast::StateSpace,
dst: spirv::Word,
ptr_src: spirv::Word,