mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 08:24:44 +00:00
Allow ptr offsets to non-scalar types
This commit is contained in:
parent
8ef6c3d8b6
commit
23874efe68
4 changed files with 90 additions and 12 deletions
|
@ -180,6 +180,7 @@ test_ptx!(
|
|||
],
|
||||
[0u32, 0u32, 0u32, 2u32]
|
||||
);
|
||||
test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]);
|
||||
|
||||
struct DisplayError<T: Debug> {
|
||||
err: T,
|
||||
|
|
22
ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx
Normal file
22
ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx
Normal file
|
@ -0,0 +1,22 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry non_scalar_ptr_offset(
|
||||
.param .u64 input_p,
|
||||
.param .u64 output_p
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u32 x;
|
||||
.reg .u32 y;
|
||||
|
||||
ld.param.u64 in_addr, [input_p];
|
||||
ld.param.u64 out_addr, [output_p];
|
||||
|
||||
ld.global.v2.u32 {x,y}, [in_addr+8];
|
||||
add.u32 x, x, y;
|
||||
st.global.u32 [out_addr], x;
|
||||
ret;
|
||||
}
|
60
ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt
Normal file
60
ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt
Normal file
|
@ -0,0 +1,60 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%27 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "non_scalar_ptr_offset"
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%30 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%uint = OpTypeInt 32 0
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%ulong_8 = OpConstant %ulong 8
|
||||
%v2uint = OpTypeVector %uint 2
|
||||
%_ptr_CrossWorkgroup_v2uint = OpTypePointer CrossWorkgroup %v2uint
|
||||
%uchar = OpTypeInt 8 0
|
||||
%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
|
||||
%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
|
||||
%1 = OpFunction %void None %30
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%10 = OpFunctionParameter %ulong
|
||||
%25 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_uint Function
|
||||
%7 = OpVariable %_ptr_Function_uint Function
|
||||
OpStore %2 %9
|
||||
OpStore %3 %10
|
||||
%11 = OpLoad %ulong %2 Aligned 8
|
||||
OpStore %4 %11
|
||||
%12 = OpLoad %ulong %3 Aligned 8
|
||||
OpStore %5 %12
|
||||
%13 = OpLoad %ulong %4
|
||||
%23 = OpConvertUToPtr %_ptr_CrossWorkgroup_v2uint %13
|
||||
%38 = OpBitcast %_ptr_CrossWorkgroup_uchar %23
|
||||
%39 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %38 %ulong_8
|
||||
%22 = OpBitcast %_ptr_CrossWorkgroup_v2uint %39
|
||||
%8 = OpLoad %v2uint %22 Aligned 8
|
||||
%14 = OpCompositeExtract %uint %8 0
|
||||
%15 = OpCompositeExtract %uint %8 1
|
||||
OpStore %6 %14
|
||||
OpStore %7 %15
|
||||
%17 = OpLoad %uint %6
|
||||
%18 = OpLoad %uint %7
|
||||
%16 = OpIAdd %uint %17 %18
|
||||
OpStore %6 %16
|
||||
%19 = OpLoad %ulong %5
|
||||
%20 = OpLoad %uint %6
|
||||
%24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %19
|
||||
OpStore %24 %20 Aligned 4
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -2389,10 +2389,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
|||
)));
|
||||
Ok(id_add_result)
|
||||
} else {
|
||||
let scalar_type = match typ {
|
||||
ast::Type::Scalar(underlying_type) => *underlying_type,
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let id_constant_stmt = self.id_def.register_intermediate(
|
||||
ast::Type::Scalar(ast::ScalarType::S64),
|
||||
ast::StateSpace::Reg,
|
||||
|
@ -2404,7 +2400,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
|||
}));
|
||||
let dst = self.id_def.register_intermediate(typ.clone(), state_space);
|
||||
self.func.push(Statement::PtrAccess(PtrAccess {
|
||||
underlying_type: scalar_type,
|
||||
underlying_type: typ.clone(),
|
||||
state_space: state_space,
|
||||
dst,
|
||||
ptr_src: reg,
|
||||
|
@ -3118,7 +3114,7 @@ fn emit_function_body_ops(
|
|||
);
|
||||
let result_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)),
|
||||
SpirvType::pointer_to(underlying_type.clone(), state_space.to_spirv()),
|
||||
);
|
||||
let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
|
||||
let temp = builder.in_bounds_ptr_access_chain(
|
||||
|
@ -4532,7 +4528,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
|||
};
|
||||
let dst = arg.dst.upcast().unwrap_reg()?;
|
||||
result.push(Statement::PtrAccess(PtrAccess {
|
||||
underlying_type: ast::ScalarType::U8,
|
||||
underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
|
||||
state_space: ast::StateSpace::Global,
|
||||
dst: *remapped_ids.get(&dst).unwrap(),
|
||||
ptr_src: *ptr,
|
||||
|
@ -4575,7 +4571,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
|||
)));
|
||||
let dst = arg.dst.upcast().unwrap_reg()?;
|
||||
result.push(Statement::PtrAccess(PtrAccess {
|
||||
underlying_type: ast::ScalarType::U8,
|
||||
underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
|
||||
state_space: ast::StateSpace::Global,
|
||||
dst: *remapped_ids.get(&dst).unwrap(),
|
||||
ptr_src: *ptr,
|
||||
|
@ -5497,7 +5493,6 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
|
|||
self,
|
||||
visitor: &mut V,
|
||||
) -> Result<PtrAccess<To>, TranslateError> {
|
||||
let ptr_type = ast::Type::Scalar(self.underlying_type.clone());
|
||||
let new_dst = visitor.id(
|
||||
ArgumentDescriptor {
|
||||
op: self.dst,
|
||||
|
@ -5505,7 +5500,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
|
|||
is_memory_access: false,
|
||||
non_default_implicit_conversion: None,
|
||||
},
|
||||
Some((&ptr_type, self.state_space)),
|
||||
Some((&self.underlying_type, self.state_space)),
|
||||
)?;
|
||||
let new_ptr_src = visitor.id(
|
||||
ArgumentDescriptor {
|
||||
|
@ -5514,7 +5509,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
|
|||
is_memory_access: false,
|
||||
non_default_implicit_conversion: None,
|
||||
},
|
||||
Some((&ptr_type, self.state_space)),
|
||||
Some((&self.underlying_type, self.state_space)),
|
||||
)?;
|
||||
let new_constant_src = visitor.operand(
|
||||
ArgumentDescriptor {
|
||||
|
@ -5707,7 +5702,7 @@ pub struct ArgumentDescriptor<Op> {
|
|||
}
|
||||
|
||||
pub struct PtrAccess<P: ast::ArgParams> {
|
||||
underlying_type: ast::ScalarType,
|
||||
underlying_type: ast::Type,
|
||||
state_space: ast::StateSpace,
|
||||
dst: spirv::Word,
|
||||
ptr_src: spirv::Word,
|
||||
|
|
Loading…
Add table
Reference in a new issue