Regenerate SPIR-V for ptx_impl and fix weird handling of ptr-ptr add or sub

This commit is contained in:
Andrzej Janik 2021-07-03 02:13:38 +02:00
parent e328ecc550
commit ad2059872a
5 changed files with 164 additions and 22 deletions

Binary file not shown.

View file

@ -181,6 +181,7 @@ test_ptx!(
[0u32, 0u32, 0u32, 2u32]
);
test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]);
test_ptx!(stateful_neg_offset, [1237518u64], [1237518u64]);
struct DisplayError<T: Debug> {
err: T,

View file

@ -0,0 +1,29 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry stateful_neg_offset(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 in_addr2;
.reg .u64 out_addr2;
.reg .u64 offset;
.reg .u64 temp;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
cvta.to.global.u64 in_addr2, in_addr;
cvta.to.global.u64 out_addr2, out_addr;
add.u64 offset, in_addr2, out_addr2;
sub.u64 offset, in_addr2, out_addr2;
ld.global.u64 temp, [in_addr2];
st.global.u64 [out_addr2], temp;
ret;
}

View file

@ -0,0 +1,80 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%57 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "stateful_neg_offset"
%void = OpTypeVoid
%uchar = OpTypeInt 8 0
%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
%61 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
%ulong = OpTypeInt 64 0
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
%1 = OpFunction %void None %61
%29 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
%30 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
%55 = OpLabel
%15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%13 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%9 = OpVariable %_ptr_Function_ulong Function
OpStore %15 %29
OpStore %16 %30
%47 = OpBitcast %_ptr_Function_ulong %15
%17 = OpLoad %ulong %47 Aligned 8
%31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %17
OpStore %10 %31
%48 = OpBitcast %_ptr_Function_ulong %16
%18 = OpLoad %ulong %48 Aligned 8
%32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18
OpStore %11 %32
%33 = OpLoad %_ptr_CrossWorkgroup_uchar %10
%20 = OpConvertPtrToU %ulong %33
%50 = OpCopyObject %ulong %20
%49 = OpCopyObject %ulong %50
%19 = OpCopyObject %ulong %49
%34 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %19
OpStore %12 %34
%35 = OpLoad %_ptr_CrossWorkgroup_uchar %11
%22 = OpConvertPtrToU %ulong %35
%52 = OpCopyObject %ulong %22
%51 = OpCopyObject %ulong %52
%21 = OpCopyObject %ulong %51
%36 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %21
OpStore %13 %36
%37 = OpLoad %_ptr_CrossWorkgroup_uchar %12
%24 = OpConvertPtrToU %ulong %37
%38 = OpLoad %_ptr_CrossWorkgroup_uchar %13
%25 = OpConvertPtrToU %ulong %38
%23 = OpIAdd %ulong %24 %25
%39 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %23
OpStore %14 %39
%40 = OpLoad %_ptr_CrossWorkgroup_uchar %12
%27 = OpConvertPtrToU %ulong %40
%41 = OpLoad %_ptr_CrossWorkgroup_uchar %13
%28 = OpConvertPtrToU %ulong %41
%26 = OpISub %ulong %27 %28
%42 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26
OpStore %14 %42
%44 = OpLoad %_ptr_CrossWorkgroup_uchar %12
%53 = OpBitcast %_ptr_CrossWorkgroup_ulong %44
%43 = OpLoad %ulong %53 Aligned 8
OpStore %9 %43
%45 = OpLoad %_ptr_CrossWorkgroup_uchar %13
%46 = OpLoad %ulong %9
%54 = OpBitcast %_ptr_CrossWorkgroup_ulong %45
OpStore %54 %46 Aligned 8
OpReturn
OpFunctionEnd

View file

@ -4234,7 +4234,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
arg,
)) => {
if let (TypedOperand::Reg(dst), Some(src)) =
(arg.dst, arg.src.upcast().underlying())
(arg.dst, arg.src.upcast().underlying_register())
{
if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, dst) {
stateful_markers.push((dst, *src));
@ -4266,7 +4266,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
arg,
)) => {
if let (TypedOperand::Reg(dst), Some(src)) =
(&arg.dst, arg.src.upcast().underlying())
(&arg.dst, arg.src.upcast().underlying_register())
{
if func_args_64bit.contains(src) {
multi_hash_map_append(&mut stateful_init_reg, *dst, *src);
@ -4320,14 +4320,16 @@ fn convert_to_stateful_memory_access<'a, 'input>(
}),
arg,
)) => {
// TODO: don't mark result of double pointer sub or double
// pointer add as ptr result
if let (TypedOperand::Reg(dst), Some(src1)) =
(arg.dst, arg.src1.upcast().underlying())
(arg.dst, arg.src1.upcast().underlying_register())
{
if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) {
regs_ptr_new.insert(dst);
}
} else if let (TypedOperand::Reg(dst), Some(src2)) =
(arg.dst, arg.src2.upcast().underlying())
(arg.dst, arg.src2.upcast().underlying_register())
{
if regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) {
regs_ptr_new.insert(dst);
@ -4392,7 +4394,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
}),
arg,
)) if is_add_ptr_direct(&remapped_ids, &arg) => {
let (ptr, offset) = match arg.src1.upcast().underlying() {
let (ptr, offset) = match arg.src1.upcast().underlying_register() {
Some(src1) if remapped_ids.contains_key(src1) => {
(remapped_ids.get(src1).unwrap(), arg.src2)
}
@ -4420,14 +4422,9 @@ fn convert_to_stateful_memory_access<'a, 'input>(
saturate: false,
}),
arg,
)) if is_add_ptr_direct(&remapped_ids, &arg) => {
let (ptr, offset) = match arg.src1.upcast().underlying() {
Some(src1) if remapped_ids.contains_key(src1) => {
(remapped_ids.get(src1).unwrap(), arg.src2)
}
Some(src2) if remapped_ids.contains_key(src2) => {
(remapped_ids.get(src2).unwrap(), arg.src1)
}
)) if is_sub_ptr_direct(&remapped_ids, &arg) => {
let (ptr, offset) = match arg.src1.upcast().underlying_register() {
Some(src1) => (remapped_ids.get(src1).unwrap(), arg.src2),
_ => return Err(error_unreachable()),
};
let offset_neg = id_defs.register_intermediate(Some((
@ -4577,10 +4574,45 @@ fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgP
if !remapped_ids.contains_key(&dst) {
return false;
}
match arg.src1.upcast().underlying() {
Some(src1) if remapped_ids.contains_key(src1) => true,
Some(src2) if remapped_ids.contains_key(src2) => true,
_ => false,
if let Some(src1_reg) = arg.src1.upcast().underlying_register() {
if remapped_ids.contains_key(src1_reg) {
// don't trigger optimization when adding two pointers
if let Some(src2_reg) = arg.src2.upcast().underlying_register() {
return !remapped_ids.contains_key(src2_reg);
}
}
}
if let Some(src2_reg) = arg.src2.upcast().underlying_register() {
remapped_ids.contains_key(src2_reg)
} else {
false
}
}
}
}
fn is_sub_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
match arg.dst {
TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
return false
}
TypedOperand::Reg(dst) => {
if !remapped_ids.contains_key(&dst) {
return false;
}
match arg.src1.upcast().underlying_register() {
Some(src1_reg) => {
if remapped_ids.contains_key(src1_reg) {
// don't trigger optimization when subtracting two pointers
arg.src2
.upcast()
.underlying_register()
.map_or(true, |src2_reg| !remapped_ids.contains_key(src2_reg))
} else {
false
}
}
None => false,
}
}
}
@ -7099,12 +7131,12 @@ impl ast::StateSpace {
}
impl<T> ast::Operand<T> {
fn underlying(&self) -> Option<&T> {
fn underlying_register(&self) -> Option<&T> {
match self {
ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r),
ast::Operand::Imm(_) => None,
ast::Operand::VecMember(reg, _) => Some(reg),
ast::Operand::VecPack(..) => None,
ast::Operand::Reg(r)
| ast::Operand::RegOffset(r, _)
| ast::Operand::VecMember(r, _) => Some(r),
ast::Operand::Imm(_) | ast::Operand::VecPack(..) => None,
}
}
}