mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Fix remaining bugs in vector destructuring and in the process improve implicit conversions
This commit is contained in:
parent
1e0b35be4b
commit
3e92921275
8 changed files with 433 additions and 485 deletions
|
@ -11,5 +11,5 @@ members = [
|
|||
]
|
||||
|
||||
[patch.crates-io]
|
||||
rspirv = { git = 'https://github.com/vosen/rspirv', rev = '4523d54d785faff59c1e928dd1f210c531a70258' }
|
||||
spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '4523d54d785faff59c1e928dd1f210c531a70258' }
|
||||
rspirv = { git = 'https://github.com/vosen/rspirv', rev = '0f5761918624f4a95107c14abe64946c5c5f60ce' }
|
||||
spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '0f5761918624f4a95107c14abe64946c5c5f60ce' }
|
|
@ -10,7 +10,7 @@ edition = "2018"
|
|||
lalrpop-util = "0.19"
|
||||
regex = "1"
|
||||
rspirv = "0.6"
|
||||
spirv_headers = "1.4"
|
||||
spirv_headers = "~1.4.2"
|
||||
quick-error = "1.2"
|
||||
bit-vec = "0.6"
|
||||
half ="1.6"
|
||||
|
|
|
@ -463,14 +463,14 @@ pub enum CallOperand<ID> {
|
|||
|
||||
pub enum IdOrVector<ID> {
|
||||
Reg(ID),
|
||||
Vec(Vec<ID>)
|
||||
Vec(Vec<ID>),
|
||||
}
|
||||
|
||||
pub enum OperandOrVector<ID> {
|
||||
Reg(ID),
|
||||
RegOffset(ID, i32),
|
||||
Imm(u32),
|
||||
Vec(Vec<ID>)
|
||||
Vec(Vec<ID>),
|
||||
}
|
||||
|
||||
impl<T> From<Operand<T>> for OperandOrVector<T> {
|
||||
|
@ -536,6 +536,8 @@ pub struct MovDetails {
|
|||
// two fields below are in use by member moves
|
||||
pub dst_width: u8,
|
||||
pub src_width: u8,
|
||||
// This is in use by auto-generated movs
|
||||
pub relaxed_src2_conv: bool,
|
||||
}
|
||||
|
||||
impl MovDetails {
|
||||
|
@ -544,7 +546,8 @@ impl MovDetails {
|
|||
typ,
|
||||
src_is_address: false,
|
||||
dst_width: 0,
|
||||
src_width: 0
|
||||
src_width: 0,
|
||||
relaxed_src2_conv: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -560,7 +563,7 @@ pub struct MulIntDesc {
|
|||
pub control: MulIntControl,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum MulIntControl {
|
||||
Low,
|
||||
High,
|
||||
|
|
|
@ -2,8 +2,10 @@
|
|||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%23 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
|
@ -33,17 +35,17 @@
|
|||
%11 = OpCopyObject %ulong %12
|
||||
OpStore %5 %11
|
||||
%14 = OpLoad %ulong %4
|
||||
%17 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14
|
||||
%18 = OpLoad %float %17
|
||||
%31 = OpBitcast %uint %18
|
||||
%18 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14
|
||||
%17 = OpLoad %float %18
|
||||
%31 = OpBitcast %uint %17
|
||||
%13 = OpUConvert %ulong %31
|
||||
OpStore %6 %13
|
||||
%15 = OpLoad %ulong %5
|
||||
%16 = OpLoad %ulong %6
|
||||
%19 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15
|
||||
%32 = OpBitcast %ulong %16
|
||||
%33 = OpUConvert %uint %32
|
||||
%19 = OpBitcast %float %33
|
||||
%20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15
|
||||
OpStore %20 %19
|
||||
%20 = OpBitcast %float %33
|
||||
OpStore %19 %20
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
|
|
|
@ -2,8 +2,10 @@
|
|||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%32 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
|
@ -57,8 +59,8 @@
|
|||
OpStore %8 %19
|
||||
%22 = OpLoad %ulong %5
|
||||
%23 = OpLoad %ulong %8
|
||||
%28 = OpCopyObject %ulong %23
|
||||
%29 = OpConvertUToPtr %_ptr_Generic_ulong %22
|
||||
OpStore %29 %28
|
||||
%28 = OpConvertUToPtr %_ptr_Generic_ulong %22
|
||||
%29 = OpCopyObject %ulong %23
|
||||
OpStore %28 %29
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
.reg .u16 temp4;
|
||||
.reg .v4.u16 foo;
|
||||
|
||||
ld.param.u64 in_addr, [input_p];
|
||||
ld.param.u64 out_addr, [output_p];
|
||||
|
||||
ld.global.v4.u8 {temp1, temp2, temp3, temp4}, [in_addr];
|
||||
mov.v4.u16 foo, {temp2, temp3, temp4, temp1};
|
||||
mov.v4.u16 {temp3, temp4, temp1, temp2}, foo;
|
||||
|
|
|
@ -2,96 +2,123 @@
|
|||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
%60 = OpExtInstImport "OpenCL.std"
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%75 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %31 "vector"
|
||||
OpEntryPoint Kernel %1 "vector_extract"
|
||||
%void = OpTypeVoid
|
||||
%uint = OpTypeInt 32 0
|
||||
%v2uint = OpTypeVector %uint 2
|
||||
%64 = OpTypeFunction %v2uint %v2uint
|
||||
%_ptr_Function_v2uint = OpTypePointer Function %v2uint
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%ulong = OpTypeInt 64 0
|
||||
%68 = OpTypeFunction %void %ulong %ulong
|
||||
%78 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint
|
||||
%1 = OpFunction %v2uint None %64
|
||||
%7 = OpFunctionParameter %v2uint
|
||||
%30 = OpLabel
|
||||
%3 = OpVariable %_ptr_Function_v2uint Function
|
||||
%2 = OpVariable %_ptr_Function_v2uint Function
|
||||
%4 = OpVariable %_ptr_Function_v2uint Function
|
||||
%5 = OpVariable %_ptr_Function_uint Function
|
||||
%6 = OpVariable %_ptr_Function_uint Function
|
||||
OpStore %3 %7
|
||||
%9 = OpLoad %v2uint %3
|
||||
%27 = OpCompositeExtract %uint %9 0
|
||||
%8 = OpCopyObject %uint %27
|
||||
OpStore %5 %8
|
||||
%11 = OpLoad %v2uint %3
|
||||
%28 = OpCompositeExtract %uint %11 1
|
||||
%10 = OpCopyObject %uint %28
|
||||
OpStore %6 %10
|
||||
%13 = OpLoad %uint %5
|
||||
%14 = OpLoad %uint %6
|
||||
%12 = OpIAdd %uint %13 %14
|
||||
OpStore %6 %12
|
||||
%16 = OpLoad %v2uint %4
|
||||
%17 = OpLoad %uint %6
|
||||
%15 = OpCompositeInsert %v2uint %17 %16 0
|
||||
OpStore %4 %15
|
||||
%19 = OpLoad %v2uint %4
|
||||
%20 = OpLoad %uint %6
|
||||
%18 = OpCompositeInsert %v2uint %20 %19 1
|
||||
OpStore %4 %18
|
||||
%22 = OpLoad %v2uint %4
|
||||
%23 = OpLoad %v2uint %4
|
||||
%29 = OpCompositeExtract %uint %23 1
|
||||
%21 = OpCompositeInsert %v2uint %29 %22 0
|
||||
OpStore %4 %21
|
||||
%25 = OpLoad %v2uint %4
|
||||
%24 = OpCopyObject %v2uint %25
|
||||
OpStore %2 %24
|
||||
%26 = OpLoad %v2uint %2
|
||||
OpReturnValue %26
|
||||
OpFunctionEnd
|
||||
%31 = OpFunction %void None %68
|
||||
%40 = OpFunctionParameter %ulong
|
||||
%41 = OpFunctionParameter %ulong
|
||||
%58 = OpLabel
|
||||
%32 = OpVariable %_ptr_Function_ulong Function
|
||||
%33 = OpVariable %_ptr_Function_ulong Function
|
||||
%34 = OpVariable %_ptr_Function_ulong Function
|
||||
%35 = OpVariable %_ptr_Function_ulong Function
|
||||
%36 = OpVariable %_ptr_Function_v2uint Function
|
||||
%37 = OpVariable %_ptr_Function_uint Function
|
||||
%38 = OpVariable %_ptr_Function_uint Function
|
||||
%39 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %32 %40
|
||||
OpStore %33 %41
|
||||
%43 = OpLoad %ulong %32
|
||||
%42 = OpCopyObject %ulong %43
|
||||
OpStore %34 %42
|
||||
%45 = OpLoad %ulong %33
|
||||
%44 = OpCopyObject %ulong %45
|
||||
OpStore %35 %44
|
||||
%47 = OpLoad %ulong %34
|
||||
%54 = OpConvertUToPtr %_ptr_Generic_v2uint %47
|
||||
%46 = OpLoad %v2uint %54
|
||||
OpStore %36 %46
|
||||
%49 = OpLoad %v2uint %36
|
||||
%48 = OpFunctionCall %v2uint %1 %49
|
||||
OpStore %36 %48
|
||||
%51 = OpLoad %v2uint %36
|
||||
%55 = OpBitcast %ulong %51
|
||||
%56 = OpCopyObject %ulong %55
|
||||
%50 = OpCopyObject %ulong %56
|
||||
OpStore %39 %50
|
||||
%52 = OpLoad %ulong %35
|
||||
%53 = OpLoad %v2uint %36
|
||||
%57 = OpConvertUToPtr %_ptr_Generic_v2uint %52
|
||||
OpStore %57 %53
|
||||
%ushort = OpTypeInt 16 0
|
||||
%_ptr_Function_ushort = OpTypePointer Function %ushort
|
||||
%v4ushort = OpTypeVector %ushort 4
|
||||
%_ptr_Function_v4ushort = OpTypePointer Function %v4ushort
|
||||
%uchar = OpTypeInt 8 0
|
||||
%v4uchar = OpTypeVector %uchar 4
|
||||
%_ptr_CrossWorkgroup_v4uchar = OpTypePointer CrossWorkgroup %v4uchar
|
||||
%1 = OpFunction %void None %78
|
||||
%11 = OpFunctionParameter %ulong
|
||||
%12 = OpFunctionParameter %ulong
|
||||
%73 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_ushort Function
|
||||
%7 = OpVariable %_ptr_Function_ushort Function
|
||||
%8 = OpVariable %_ptr_Function_ushort Function
|
||||
%9 = OpVariable %_ptr_Function_ushort Function
|
||||
%10 = OpVariable %_ptr_Function_v4ushort Function
|
||||
OpStore %2 %11
|
||||
OpStore %3 %12
|
||||
%14 = OpLoad %ulong %2
|
||||
%13 = OpCopyObject %ulong %14
|
||||
OpStore %4 %13
|
||||
%16 = OpLoad %ulong %3
|
||||
%15 = OpCopyObject %ulong %16
|
||||
OpStore %5 %15
|
||||
%21 = OpLoad %ulong %4
|
||||
%63 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %21
|
||||
%45 = OpLoad %v4uchar %63
|
||||
%64 = OpCompositeExtract %uchar %45 0
|
||||
%87 = OpBitcast %uchar %64
|
||||
%17 = OpUConvert %ushort %87
|
||||
%65 = OpCompositeExtract %uchar %45 1
|
||||
%88 = OpBitcast %uchar %65
|
||||
%18 = OpUConvert %ushort %88
|
||||
%66 = OpCompositeExtract %uchar %45 2
|
||||
%89 = OpBitcast %uchar %66
|
||||
%19 = OpUConvert %ushort %89
|
||||
%67 = OpCompositeExtract %uchar %45 3
|
||||
%90 = OpBitcast %uchar %67
|
||||
%20 = OpUConvert %ushort %90
|
||||
OpStore %6 %17
|
||||
OpStore %7 %18
|
||||
OpStore %8 %19
|
||||
OpStore %9 %20
|
||||
%23 = OpLoad %ushort %7
|
||||
%24 = OpLoad %ushort %8
|
||||
%25 = OpLoad %ushort %9
|
||||
%26 = OpLoad %ushort %6
|
||||
%46 = OpUndef %v4ushort
|
||||
%47 = OpCompositeInsert %v4ushort %23 %46 0
|
||||
%48 = OpCompositeInsert %v4ushort %24 %47 1
|
||||
%49 = OpCompositeInsert %v4ushort %25 %48 2
|
||||
%50 = OpCompositeInsert %v4ushort %26 %49 3
|
||||
%22 = OpCopyObject %v4ushort %50
|
||||
OpStore %10 %22
|
||||
%31 = OpLoad %v4ushort %10
|
||||
%51 = OpCopyObject %v4ushort %31
|
||||
%27 = OpCompositeExtract %ushort %51 0
|
||||
%28 = OpCompositeExtract %ushort %51 1
|
||||
%29 = OpCompositeExtract %ushort %51 2
|
||||
%30 = OpCompositeExtract %ushort %51 3
|
||||
OpStore %8 %27
|
||||
OpStore %9 %28
|
||||
OpStore %6 %29
|
||||
OpStore %7 %30
|
||||
%36 = OpLoad %ushort %8
|
||||
%37 = OpLoad %ushort %9
|
||||
%38 = OpLoad %ushort %6
|
||||
%39 = OpLoad %ushort %7
|
||||
%53 = OpUndef %v4ushort
|
||||
%54 = OpCompositeInsert %v4ushort %36 %53 0
|
||||
%55 = OpCompositeInsert %v4ushort %37 %54 1
|
||||
%56 = OpCompositeInsert %v4ushort %38 %55 2
|
||||
%57 = OpCompositeInsert %v4ushort %39 %56 3
|
||||
%52 = OpCopyObject %v4ushort %57
|
||||
%32 = OpCompositeExtract %ushort %52 0
|
||||
%33 = OpCompositeExtract %ushort %52 1
|
||||
%34 = OpCompositeExtract %ushort %52 2
|
||||
%35 = OpCompositeExtract %ushort %52 3
|
||||
OpStore %9 %32
|
||||
OpStore %6 %33
|
||||
OpStore %7 %34
|
||||
OpStore %8 %35
|
||||
%40 = OpLoad %ulong %5
|
||||
%41 = OpLoad %ushort %6
|
||||
%42 = OpLoad %ushort %7
|
||||
%43 = OpLoad %ushort %8
|
||||
%44 = OpLoad %ushort %9
|
||||
%58 = OpUndef %v4uchar
|
||||
%91 = OpBitcast %ushort %41
|
||||
%68 = OpUConvert %uchar %91
|
||||
%59 = OpCompositeInsert %v4uchar %68 %58 0
|
||||
%92 = OpBitcast %ushort %42
|
||||
%69 = OpUConvert %uchar %92
|
||||
%60 = OpCompositeInsert %v4uchar %69 %59 1
|
||||
%93 = OpBitcast %ushort %43
|
||||
%70 = OpUConvert %uchar %93
|
||||
%61 = OpCompositeInsert %v4uchar %70 %60 2
|
||||
%94 = OpBitcast %ushort %44
|
||||
%71 = OpUConvert %uchar %94
|
||||
%62 = OpCompositeInsert %v4uchar %71 %61 3
|
||||
%72 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %40
|
||||
OpStore %72 %62
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
|
|
|
@ -843,6 +843,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
|
|||
(_, ArgumentSemantics::Address) => return Ok(desc.op),
|
||||
(t, ArgumentSemantics::RegisterPointer)
|
||||
| (t, ArgumentSemantics::Default)
|
||||
| (t, ArgumentSemantics::DefaultRelaxed)
|
||||
| (t, ArgumentSemantics::PhysicalPointer) => t,
|
||||
};
|
||||
let generated_id = id_def.new_id(id_type);
|
||||
|
@ -933,17 +934,19 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
|||
fn insert_composite_read(
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
id_def: &mut MutableNumericIdResolver<'a>,
|
||||
(scalar_type, vec_len): (ast::ScalarType, u8),
|
||||
typ: (ast::ScalarType, u8),
|
||||
scalar_dst: Option<spirv::Word>,
|
||||
scalar_sema_override: Option<ArgumentSemantics>,
|
||||
composite_src: (spirv::Word, u8),
|
||||
) -> spirv::Word {
|
||||
let new_id =
|
||||
scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Vector(scalar_type, vec_len)));
|
||||
let new_id = scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Scalar(typ.0)));
|
||||
func.push(Statement::Composite(CompositeRead {
|
||||
typ: scalar_type,
|
||||
typ: typ.0,
|
||||
dst: new_id,
|
||||
dst_semantics_override: scalar_sema_override,
|
||||
src_composite: composite_src.0,
|
||||
src_index: composite_src.1 as u32,
|
||||
src_len: typ.1 as u32,
|
||||
}));
|
||||
new_id
|
||||
}
|
||||
|
@ -963,7 +966,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
|||
) -> Result<spirv::Word, TranslateError> {
|
||||
let (reg, offset) = desc.op;
|
||||
match desc.sema {
|
||||
ArgumentSemantics::Default => {
|
||||
ArgumentSemantics::Default | ArgumentSemantics::DefaultRelaxed => {
|
||||
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
|
||||
scalar
|
||||
} else {
|
||||
|
@ -1049,18 +1052,19 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
|||
fn member_src(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<(spirv::Word, u8)>,
|
||||
(scalar_type, vec_len): (ast::ScalarType, u8),
|
||||
typ: (ast::ScalarType, u8),
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
if desc.is_dst {
|
||||
return Err(TranslateError::Unreachable);
|
||||
}
|
||||
let new_id = self.id_def.new_id(ast::Type::Vector(scalar_type, vec_len));
|
||||
self.func.push(Statement::Composite(CompositeRead {
|
||||
typ: scalar_type,
|
||||
dst: new_id,
|
||||
src_composite: desc.op.0,
|
||||
src_index: desc.op.1 as u32,
|
||||
}));
|
||||
let new_id = Self::insert_composite_read(
|
||||
self.func,
|
||||
self.id_def,
|
||||
typ,
|
||||
None,
|
||||
Some(desc.sema),
|
||||
desc.op,
|
||||
);
|
||||
Ok(new_id)
|
||||
}
|
||||
|
||||
|
@ -1077,10 +1081,11 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
|||
let newer_id = self.id_def.new_id(typ);
|
||||
self.func.push(Statement::Instruction(ast::Instruction::Mov(
|
||||
ast::MovDetails {
|
||||
typ: typ,
|
||||
typ: ast::Type::Scalar(scalar_type),
|
||||
src_is_address: false,
|
||||
dst_width: 0,
|
||||
src_width: vec_len,
|
||||
dst_width: vec_len,
|
||||
src_width: 0,
|
||||
relaxed_src2_conv: desc.sema == ArgumentSemantics::DefaultRelaxed,
|
||||
},
|
||||
ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
|
||||
(newer_id, idx as u8),
|
||||
|
@ -1099,6 +1104,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
|
|||
self.id_def,
|
||||
(scalar_type, vec_len),
|
||||
Some(*id),
|
||||
Some(desc.sema),
|
||||
(new_id, idx as u8),
|
||||
);
|
||||
}
|
||||
|
@ -1144,9 +1150,9 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
|
|||
fn src_member_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<(spirv::Word, u8)>,
|
||||
(scalar_type, vec_len): (ast::ScalarType, u8),
|
||||
typ: (ast::ScalarType, u8),
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
self.member_src(desc, (scalar_type, vec_len))
|
||||
self.member_src(desc, typ)
|
||||
}
|
||||
|
||||
fn id_or_vector(
|
||||
|
@ -1195,123 +1201,41 @@ fn insert_implicit_conversions(
|
|||
let mut result = Vec::with_capacity(func.len());
|
||||
for s in func.into_iter() {
|
||||
match s {
|
||||
Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call)?,
|
||||
Statement::Instruction(inst) => match inst {
|
||||
ast::Instruction::Ld(ld, arg) => {
|
||||
let pre_conv = get_implicit_conversions_ld_src(
|
||||
id_def,
|
||||
ld.typ,
|
||||
ld.state_space,
|
||||
arg.src,
|
||||
false,
|
||||
)?;
|
||||
let post_conv = get_implicit_conversions_ld_dst(
|
||||
id_def,
|
||||
ld.typ,
|
||||
arg.dst,
|
||||
should_convert_relaxed_dst,
|
||||
false,
|
||||
)?;
|
||||
insert_with_conversions(
|
||||
&mut result,
|
||||
id_def,
|
||||
arg,
|
||||
pre_conv.into_iter(),
|
||||
iter::empty(),
|
||||
post_conv.into_iter().collect(),
|
||||
|arg| &mut arg.src,
|
||||
|arg| &mut arg.dst,
|
||||
|arg| ast::Instruction::Ld(ld, arg),
|
||||
)
|
||||
Statement::Call(call) => insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
call,
|
||||
should_bitcast_wrapper,
|
||||
None,
|
||||
)?,
|
||||
Statement::Instruction(inst) => {
|
||||
let mut default_conversion_fn = should_bitcast_wrapper
|
||||
as fn(_, _, _) -> Result<Option<ConversionKind>, TranslateError>;
|
||||
let mut state_space = None;
|
||||
if let ast::Instruction::Ld(d, _) = &inst {
|
||||
state_space = Some(d.state_space);
|
||||
}
|
||||
ast::Instruction::St(st, arg) => {
|
||||
let pre_conv = get_implicit_conversions_ld_dst(
|
||||
id_def,
|
||||
st.typ,
|
||||
arg.src2,
|
||||
should_convert_relaxed_src,
|
||||
true,
|
||||
)?;
|
||||
let post_conv = get_implicit_conversions_ld_src(
|
||||
id_def,
|
||||
st.typ,
|
||||
st.state_space.to_ld_ss(),
|
||||
arg.src1,
|
||||
true,
|
||||
)?;
|
||||
let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param
|
||||
|| st.state_space == ast::StStateSpace::Local
|
||||
{
|
||||
(Vec::new(), post_conv)
|
||||
} else {
|
||||
(post_conv, Vec::new())
|
||||
};
|
||||
insert_with_conversions(
|
||||
&mut result,
|
||||
id_def,
|
||||
arg,
|
||||
pre_conv.into_iter(),
|
||||
pre_conv_dest.into_iter(),
|
||||
post_conv,
|
||||
|arg| &mut arg.src2,
|
||||
|arg| &mut arg.src1,
|
||||
|arg| ast::Instruction::St(st, arg),
|
||||
)
|
||||
if let ast::Instruction::St(d, _) = &inst {
|
||||
state_space = Some(d.state_space.to_ld_ss());
|
||||
}
|
||||
ast::Instruction::Mov(d, ast::Arg2Mov::Normal(mut arg)) => {
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov-2
|
||||
// TODO: handle the case of mixed vector/scalar implicit conversions
|
||||
let inst_typ_is_bit = match d.typ {
|
||||
ast::Type::Scalar(t) => ast::ScalarType::from(t).kind() == ScalarKind::Bit,
|
||||
ast::Type::Vector(_, _) => false,
|
||||
ast::Type::Array(_, _) => false,
|
||||
};
|
||||
let mut did_vector_implicit = false;
|
||||
let mut post_conv = None;
|
||||
if inst_typ_is_bit {
|
||||
let src_type = id_def.get_typed(arg.src)?;
|
||||
if let ast::Type::Vector(_, _) = src_type {
|
||||
arg.src = insert_conversion_src(
|
||||
&mut result,
|
||||
id_def,
|
||||
arg.src,
|
||||
src_type,
|
||||
d.typ.into(),
|
||||
ConversionKind::Default,
|
||||
);
|
||||
did_vector_implicit = true;
|
||||
}
|
||||
let dst_type = id_def.get_typed(arg.dst)?;
|
||||
if let ast::Type::Vector(_, _) = dst_type {
|
||||
post_conv = Some(get_conversion_dst(
|
||||
id_def,
|
||||
&mut arg.dst,
|
||||
d.typ.into(),
|
||||
dst_type,
|
||||
ConversionKind::Default,
|
||||
));
|
||||
did_vector_implicit = true;
|
||||
}
|
||||
}
|
||||
if did_vector_implicit {
|
||||
result.push(Statement::Instruction(ast::Instruction::Mov(
|
||||
d,
|
||||
ast::Arg2Mov::Normal(arg),
|
||||
)));
|
||||
} else {
|
||||
insert_implicit_bitcasts(
|
||||
&mut result,
|
||||
id_def,
|
||||
ast::Instruction::Mov(d, ast::Arg2Mov::Normal(arg)),
|
||||
)?;
|
||||
}
|
||||
if let Some(post_conv) = post_conv {
|
||||
result.push(post_conv);
|
||||
}
|
||||
if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst {
|
||||
default_conversion_fn = should_bitcast_packed;
|
||||
}
|
||||
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst)?,
|
||||
},
|
||||
Statement::Composite(c) => insert_implicit_bitcasts(&mut result, id_def, c)?,
|
||||
insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
inst,
|
||||
default_conversion_fn,
|
||||
state_space,
|
||||
)?;
|
||||
}
|
||||
Statement::Composite(composite) => insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
composite,
|
||||
should_bitcast_wrapper,
|
||||
None,
|
||||
)?,
|
||||
s @ Statement::Conditional(_)
|
||||
| s @ Statement::Label(_)
|
||||
| s @ Statement::Constant(_)
|
||||
|
@ -1326,6 +1250,77 @@ fn insert_implicit_conversions(
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
fn insert_implicit_conversions_impl(
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
stmt: impl VisitVariableExpanded,
|
||||
default_conversion_fn: fn(
|
||||
ast::Type,
|
||||
ast::Type,
|
||||
Option<ast::LdStateSpace>,
|
||||
) -> Result<Option<ConversionKind>, TranslateError>,
|
||||
state_space: Option<ast::LdStateSpace>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut post_conv = Vec::new();
|
||||
let statement = stmt.visit_variable_extended(&mut |desc, typ| {
|
||||
let instr_type = match typ {
|
||||
None => return Ok(desc.op),
|
||||
Some(t) => t,
|
||||
};
|
||||
let operand_type = id_def.get_typed(desc.op)?;
|
||||
let mut conversion_fn = default_conversion_fn;
|
||||
match desc.sema {
|
||||
ArgumentSemantics::Default => {}
|
||||
ArgumentSemantics::DefaultRelaxed => {
|
||||
if desc.is_dst {
|
||||
conversion_fn = should_convert_relaxed_dst_wrapper;
|
||||
} else {
|
||||
conversion_fn = should_convert_relaxed_src_wrapper;
|
||||
}
|
||||
}
|
||||
ArgumentSemantics::PhysicalPointer => {
|
||||
conversion_fn = bitcast_physical_pointer;
|
||||
}
|
||||
ArgumentSemantics::RegisterPointer => {
|
||||
conversion_fn = force_bitcast;
|
||||
}
|
||||
ArgumentSemantics::Address => {
|
||||
conversion_fn = force_bitcast_ptr_to_bit;
|
||||
}
|
||||
};
|
||||
match conversion_fn(operand_type, instr_type, state_space)? {
|
||||
Some(conv_kind) => {
|
||||
let conv_output = if desc.is_dst {
|
||||
&mut post_conv
|
||||
} else {
|
||||
&mut *func
|
||||
};
|
||||
let mut from = instr_type;
|
||||
let mut to = operand_type;
|
||||
let mut src = id_def.new_id(instr_type);
|
||||
let mut dst = desc.op;
|
||||
let result = Ok(src);
|
||||
if !desc.is_dst {
|
||||
mem::swap(&mut src, &mut dst);
|
||||
mem::swap(&mut from, &mut to);
|
||||
}
|
||||
conv_output.push(Statement::Conversion(ImplicitConversion {
|
||||
src,
|
||||
dst,
|
||||
from,
|
||||
to,
|
||||
kind: conv_kind,
|
||||
}));
|
||||
result
|
||||
}
|
||||
None => Ok(desc.op),
|
||||
}
|
||||
})?;
|
||||
func.push(statement);
|
||||
func.append(&mut post_conv);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_function_type(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
|
@ -1505,7 +1500,11 @@ fn emit_function_body_ops(
|
|||
composite_src,
|
||||
scalar_src,
|
||||
)) => {
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(d.typ));
|
||||
let scalar_type = d.typ.get_scalar()?;
|
||||
let result_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::from(ast::Type::Vector(scalar_type, d.dst_width)),
|
||||
);
|
||||
let result_id = Some(dst.0);
|
||||
builder.composite_insert(
|
||||
result_type,
|
||||
|
@ -1545,8 +1544,8 @@ fn emit_function_body_ops(
|
|||
// Obviously, old and buggy one is used for compiling L0 SPIRV
|
||||
// https://github.com/intel/intel-graphics-compiler/issues/148
|
||||
let type_pred = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
|
||||
let const_true = builder.constant_true(type_pred);
|
||||
let const_false = builder.constant_false(type_pred);
|
||||
let const_true = builder.constant_true(type_pred, None);
|
||||
let const_false = builder.constant_false(type_pred, None);
|
||||
builder.select(result_type, result_id, operand, const_false, const_true)
|
||||
}
|
||||
_ => builder.not(result_type, result_id, operand),
|
||||
|
@ -2700,12 +2699,9 @@ where
|
|||
fn src_member_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<spirv::Word>,
|
||||
(scalar_type, vec_len): (ast::ScalarType, u8),
|
||||
(scalar_type, _): (ast::ScalarType, u8),
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
self(
|
||||
desc.new_op(desc.op),
|
||||
Some(ast::Type::Vector(scalar_type.into(), vec_len)),
|
||||
)
|
||||
self(desc.new_op(desc.op), Some(ast::Type::Scalar(scalar_type)))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2793,6 +2789,8 @@ pub struct ArgumentDescriptor<Op> {
|
|||
pub enum ArgumentSemantics {
|
||||
// normal register access
|
||||
Default,
|
||||
// normal register access with relaxed conversion rules (ld/st)
|
||||
DefaultRelaxed,
|
||||
// st/ld global
|
||||
PhysicalPointer,
|
||||
// st/ld .param, .local
|
||||
|
@ -2834,11 +2832,12 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||
}
|
||||
ast::Instruction::Mul(d, a) => {
|
||||
let inst_type = d.get_type();
|
||||
ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type)?)
|
||||
let is_wide = d.is_wide();
|
||||
ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type, is_wide)?)
|
||||
}
|
||||
ast::Instruction::Add(d, a) => {
|
||||
let inst_type = d.get_type();
|
||||
ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type)?)
|
||||
ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type, false)?)
|
||||
}
|
||||
ast::Instruction::Setp(d, a) => {
|
||||
let inst_type = d.typ;
|
||||
|
@ -2889,7 +2888,8 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||
}
|
||||
ast::Instruction::Mad(d, a) => {
|
||||
let inst_type = d.get_type();
|
||||
ast::Instruction::Mad(d, a.map(visitor, inst_type)?)
|
||||
let is_wide = d.is_wide();
|
||||
ast::Instruction::Mad(d, a.map(visitor, inst_type, is_wide)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -3004,6 +3004,27 @@ where
|
|||
}
|
||||
|
||||
impl ast::Type {
|
||||
fn widen(self) -> Result<Self, TranslateError> {
|
||||
match self {
|
||||
ast::Type::Scalar(scalar) => {
|
||||
let kind = scalar.kind();
|
||||
let width = scalar.width();
|
||||
if (kind != ScalarKind::Signed
|
||||
&& kind != ScalarKind::Unsigned
|
||||
&& kind != ScalarKind::Bit)
|
||||
|| (width == 8)
|
||||
{
|
||||
return Err(TranslateError::MismatchedType);
|
||||
}
|
||||
Ok(ast::Type::Scalar(ast::ScalarType::from_parts(
|
||||
width * 2,
|
||||
kind,
|
||||
)))
|
||||
}
|
||||
_ => Err(TranslateError::Unreachable),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_parts(self) -> TypeParts {
|
||||
match self {
|
||||
ast::Type::Scalar(scalar) => TypeParts {
|
||||
|
@ -3102,8 +3123,10 @@ type Arg2St = ast::Arg2St<ExpandedArgParams>;
|
|||
struct CompositeRead {
|
||||
pub typ: ast::ScalarType,
|
||||
pub dst: spirv::Word,
|
||||
pub dst_semantics_override: Option<ArgumentSemantics>,
|
||||
pub src_composite: spirv::Word,
|
||||
pub src_index: u32,
|
||||
pub src_len: u32,
|
||||
}
|
||||
|
||||
impl VisitVariableExpanded for CompositeRead {
|
||||
|
@ -3116,12 +3139,15 @@ impl VisitVariableExpanded for CompositeRead {
|
|||
self,
|
||||
f: &mut F,
|
||||
) -> Result<ExpandedStatement, TranslateError> {
|
||||
let dst_sema = self
|
||||
.dst_semantics_override
|
||||
.unwrap_or(ArgumentSemantics::Default);
|
||||
Ok(Statement::Composite(CompositeRead {
|
||||
dst: f(
|
||||
ArgumentDescriptor {
|
||||
op: self.dst,
|
||||
is_dst: true,
|
||||
sema: ArgumentSemantics::Default,
|
||||
sema: dst_sema,
|
||||
},
|
||||
Some(ast::Type::Scalar(self.typ)),
|
||||
)?,
|
||||
|
@ -3131,7 +3157,7 @@ impl VisitVariableExpanded for CompositeRead {
|
|||
is_dst: false,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
Some(ast::Type::Scalar(self.typ)),
|
||||
Some(ast::Type::Vector(self.typ, self.src_len as u8)),
|
||||
)?,
|
||||
..self
|
||||
}))
|
||||
|
@ -3328,7 +3354,7 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
|
|||
ArgumentDescriptor {
|
||||
op: self.dst,
|
||||
is_dst: true,
|
||||
sema: ArgumentSemantics::Default,
|
||||
sema: ArgumentSemantics::DefaultRelaxed,
|
||||
},
|
||||
t.into(),
|
||||
)?;
|
||||
|
@ -3380,7 +3406,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
|
|||
ArgumentDescriptor {
|
||||
op: self.src2,
|
||||
is_dst: false,
|
||||
sema: ArgumentSemantics::Default,
|
||||
sema: ArgumentSemantics::DefaultRelaxed,
|
||||
},
|
||||
t,
|
||||
)?;
|
||||
|
@ -3429,9 +3455,9 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
|
|||
op: self.src,
|
||||
is_dst: false,
|
||||
sema: if details.src_is_address {
|
||||
ArgumentSemantics::RegisterPointer
|
||||
ArgumentSemantics::Address
|
||||
} else {
|
||||
ArgumentSemantics::PhysicalPointer
|
||||
ArgumentSemantics::Default
|
||||
},
|
||||
},
|
||||
details.typ.into(),
|
||||
|
@ -3476,13 +3502,14 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
|
|||
) -> Result<ast::Arg2MovMember<U>, TranslateError> {
|
||||
match self {
|
||||
ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => {
|
||||
let scalar_type = details.typ.get_scalar()?;
|
||||
let dst = visitor.id(
|
||||
ArgumentDescriptor {
|
||||
op: dst,
|
||||
is_dst: true,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
Some(details.typ.into()),
|
||||
Some(ast::Type::Vector(scalar_type, details.dst_width)),
|
||||
)?;
|
||||
let src1 = visitor.id(
|
||||
ArgumentDescriptor {
|
||||
|
@ -3490,7 +3517,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
|
|||
is_dst: false,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
Some(details.typ.into()),
|
||||
Some(ast::Type::Vector(scalar_type, details.dst_width)),
|
||||
)?;
|
||||
let src2 = visitor.id(
|
||||
ArgumentDescriptor {
|
||||
|
@ -3498,6 +3525,8 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
|
|||
is_dst: false,
|
||||
sema: if details.src_is_address {
|
||||
ArgumentSemantics::Address
|
||||
} else if details.relaxed_src2_conv {
|
||||
ArgumentSemantics::DefaultRelaxed
|
||||
} else {
|
||||
ArgumentSemantics::Default
|
||||
},
|
||||
|
@ -3527,13 +3556,14 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
|
|||
Ok(ast::Arg2MovMember::Src(dst, src))
|
||||
}
|
||||
ast::Arg2MovMember::Both((dst, len), composite_src, src) => {
|
||||
let scalar_type = details.typ.get_scalar()?;
|
||||
let dst = visitor.id(
|
||||
ArgumentDescriptor {
|
||||
op: dst,
|
||||
is_dst: true,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
Some(details.typ.into()),
|
||||
Some(ast::Type::Vector(scalar_type, details.dst_width)),
|
||||
)?;
|
||||
let composite_src = visitor.id(
|
||||
ArgumentDescriptor {
|
||||
|
@ -3541,16 +3571,19 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
|
|||
is_dst: false,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
Some(details.typ.into()),
|
||||
Some(ast::Type::Vector(scalar_type, details.dst_width)),
|
||||
)?;
|
||||
let scalar_typ = details.typ.get_scalar()?;
|
||||
let src = visitor.src_member_operand(
|
||||
ArgumentDescriptor {
|
||||
op: src,
|
||||
is_dst: false,
|
||||
sema: ArgumentSemantics::Default,
|
||||
sema: if details.relaxed_src2_conv {
|
||||
ArgumentSemantics::DefaultRelaxed
|
||||
} else {
|
||||
ArgumentSemantics::Default
|
||||
},
|
||||
},
|
||||
(scalar_typ.into(), details.src_width),
|
||||
(scalar_type.into(), details.src_width),
|
||||
)?;
|
||||
Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src))
|
||||
}
|
||||
|
@ -3570,7 +3603,8 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
|
|||
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
||||
self,
|
||||
visitor: &mut V,
|
||||
t: ast::Type,
|
||||
typ: ast::Type,
|
||||
is_wide: bool,
|
||||
) -> Result<ast::Arg3<U>, TranslateError> {
|
||||
let dst = visitor.id(
|
||||
ArgumentDescriptor {
|
||||
|
@ -3578,7 +3612,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
|
|||
is_dst: true,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
Some(t),
|
||||
Some(if is_wide { typ.widen()? } else { typ }),
|
||||
)?;
|
||||
let src1 = visitor.operand(
|
||||
ArgumentDescriptor {
|
||||
|
@ -3586,7 +3620,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
|
|||
is_dst: false,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
t,
|
||||
typ,
|
||||
)?;
|
||||
let src2 = visitor.operand(
|
||||
ArgumentDescriptor {
|
||||
|
@ -3594,7 +3628,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
|
|||
is_dst: false,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
t,
|
||||
typ,
|
||||
)?;
|
||||
Ok(ast::Arg3 { dst, src1, src2 })
|
||||
}
|
||||
|
@ -3646,6 +3680,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
|
|||
self,
|
||||
visitor: &mut V,
|
||||
t: ast::Type,
|
||||
is_wide: bool,
|
||||
) -> Result<ast::Arg4<U>, TranslateError> {
|
||||
let dst = visitor.id(
|
||||
ArgumentDescriptor {
|
||||
|
@ -3653,7 +3688,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
|
|||
is_dst: true,
|
||||
sema: ArgumentSemantics::Default,
|
||||
},
|
||||
Some(t),
|
||||
Some(if is_wide { t.widen()? } else { t }),
|
||||
)?;
|
||||
let src1 = visitor.operand(
|
||||
ArgumentDescriptor {
|
||||
|
@ -4050,6 +4085,54 @@ impl<T> ast::OperandOrVector<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl ast::MulDetails {
|
||||
fn is_wide(&self) -> bool {
|
||||
match self {
|
||||
ast::MulDetails::Int(desc) => desc.control == ast::MulIntControl::Wide,
|
||||
ast::MulDetails::Float(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn force_bitcast(
|
||||
operand: ast::Type,
|
||||
instr: ast::Type,
|
||||
_: Option<ast::LdStateSpace>,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if instr != operand {
|
||||
Ok(Some(ConversionKind::Default))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn bitcast_physical_pointer(
|
||||
operand_type: ast::Type,
|
||||
_: ast::Type,
|
||||
ss: Option<ast::LdStateSpace>,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
match operand_type {
|
||||
ast::Type::Scalar(ast::ScalarType::B64)
|
||||
| ast::Type::Scalar(ast::ScalarType::U64)
|
||||
| ast::Type::Scalar(ast::ScalarType::S64) => {
|
||||
if let Some(space) = ss {
|
||||
Ok(Some(ConversionKind::BitToPtr(space)))
|
||||
} else {
|
||||
Err(TranslateError::Unreachable)
|
||||
}
|
||||
}
|
||||
_ => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
}
|
||||
|
||||
fn force_bitcast_ptr_to_bit(
|
||||
_: ast::Type,
|
||||
_: ast::Type,
|
||||
_: Option<ast::LdStateSpace>,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
Ok(Some(ConversionKind::PtrToBit))
|
||||
}
|
||||
|
||||
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
||||
match (instr, operand) {
|
||||
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
||||
|
@ -4077,187 +4160,50 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
|||
}
|
||||
}
|
||||
|
||||
fn insert_with_conversions<T, ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>>(
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
mut instr: T,
|
||||
pre_conv_src: impl ExactSizeIterator<Item = ImplicitConversion>,
|
||||
pre_conv_dst: impl ExactSizeIterator<Item = ImplicitConversion>,
|
||||
mut post_conv: Vec<ImplicitConversion>,
|
||||
mut src: impl FnMut(&mut T) -> &mut spirv::Word,
|
||||
mut dst: impl FnMut(&mut T) -> &mut spirv::Word,
|
||||
to_inst: ToInstruction,
|
||||
) {
|
||||
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src);
|
||||
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst);
|
||||
if post_conv.len() > 0 {
|
||||
let new_id = id_def.new_id(post_conv[0].from);
|
||||
post_conv[0].src = new_id;
|
||||
post_conv.last_mut().unwrap().dst = *dst(&mut instr);
|
||||
*dst(&mut instr) = new_id;
|
||||
}
|
||||
func.push(Statement::Instruction(to_inst(instr)));
|
||||
for conv in post_conv {
|
||||
func.push(Statement::Conversion(conv));
|
||||
fn should_bitcast_packed(
|
||||
operand: ast::Type,
|
||||
instr: ast::Type,
|
||||
ss: Option<ast::LdStateSpace>,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
|
||||
(operand, instr)
|
||||
{
|
||||
if scalar.kind() == ScalarKind::Bit
|
||||
&& scalar.width() == (vec_underlying_type.width() * vec_len)
|
||||
{
|
||||
return Ok(Some(ConversionKind::Default));
|
||||
}
|
||||
}
|
||||
should_bitcast_wrapper(operand, instr, ss)
|
||||
}
|
||||
|
||||
fn insert_with_conversions_pre_conv<T>(
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
mut instr: &mut T,
|
||||
pre_conv: impl ExactSizeIterator<Item = ImplicitConversion>,
|
||||
src: &mut impl FnMut(&mut T) -> &mut spirv::Word,
|
||||
) {
|
||||
let pre_conv_len = pre_conv.len();
|
||||
for (i, mut conv) in pre_conv.enumerate() {
|
||||
let original_src = src(&mut instr);
|
||||
if i == 0 {
|
||||
conv.src = *original_src;
|
||||
}
|
||||
if i == pre_conv_len - 1 {
|
||||
let new_id = id_def.new_id(conv.to);
|
||||
conv.dst = new_id;
|
||||
*original_src = new_id;
|
||||
}
|
||||
func.push(Statement::Conversion(conv));
|
||||
fn should_bitcast_wrapper(
|
||||
operand: ast::Type,
|
||||
instr: ast::Type,
|
||||
_: Option<ast::LdStateSpace>,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if instr == operand {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
fn get_implicit_conversions_ld_dst<
|
||||
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
|
||||
>(
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
instr_type: ast::Type,
|
||||
dst: spirv::Word,
|
||||
should_convert: ShouldConvert,
|
||||
in_reverse: bool,
|
||||
) -> Result<Option<ImplicitConversion>, TranslateError> {
|
||||
let dst_type = id_def.get_typed(dst)?;
|
||||
if let Some(conv) = should_convert(dst_type, instr_type) {
|
||||
Ok(Some(ImplicitConversion {
|
||||
src: u32::max_value(),
|
||||
dst: u32::max_value(),
|
||||
from: if !in_reverse { instr_type } else { dst_type },
|
||||
to: if !in_reverse { dst_type } else { instr_type },
|
||||
kind: conv,
|
||||
}))
|
||||
if should_bitcast(instr, operand) {
|
||||
Ok(Some(ConversionKind::Default))
|
||||
} else {
|
||||
Ok(None)
|
||||
Err(TranslateError::MismatchedType)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_implicit_conversions_ld_src(
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
instr_type: ast::Type,
|
||||
state_space: ast::LdStateSpace,
|
||||
src: spirv::Word,
|
||||
in_reverse_param_local: bool,
|
||||
) -> Result<Vec<ImplicitConversion>, TranslateError> {
|
||||
let src_type = id_def.get_typed(src)?;
|
||||
match state_space {
|
||||
ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
|
||||
if src_type != instr_type {
|
||||
Ok(vec![
|
||||
ImplicitConversion {
|
||||
src: u32::max_value(),
|
||||
dst: u32::max_value(),
|
||||
from: if !in_reverse_param_local {
|
||||
src_type
|
||||
} else {
|
||||
instr_type
|
||||
},
|
||||
to: if !in_reverse_param_local {
|
||||
instr_type
|
||||
} else {
|
||||
src_type
|
||||
},
|
||||
kind: ConversionKind::Default,
|
||||
};
|
||||
1
|
||||
])
|
||||
} else {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
|
||||
let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
|
||||
mem::size_of::<usize>() as u8,
|
||||
ScalarKind::Bit,
|
||||
));
|
||||
let mut result = Vec::new();
|
||||
// HACK ALERT
|
||||
// IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
|
||||
// additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
|
||||
// TODO: error out if the src is not B64/U64/S64
|
||||
if let ast::Type::Scalar(scalar_src_type) = src_type {
|
||||
if scalar_src_type.kind() == ScalarKind::Signed {
|
||||
result.push(ImplicitConversion {
|
||||
src: u32::max_value(),
|
||||
dst: u32::max_value(),
|
||||
from: src_type,
|
||||
to: new_src_type,
|
||||
kind: ConversionKind::Default,
|
||||
});
|
||||
}
|
||||
}
|
||||
result.push(ImplicitConversion {
|
||||
src: u32::max_value(),
|
||||
dst: u32::max_value(),
|
||||
from: src_type,
|
||||
to: instr_type,
|
||||
kind: ConversionKind::BitToPtr(state_space),
|
||||
});
|
||||
if result.len() == 2 {
|
||||
let new_id = id_def.new_id(new_src_type);
|
||||
result[0].dst = new_id;
|
||||
result[1].src = new_id;
|
||||
result[1].from = new_src_type;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
_ => Err(TranslateError::Todo),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn insert_conversion_src(
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
src: spirv::Word,
|
||||
fn should_convert_relaxed_src_wrapper(
|
||||
src_type: ast::Type,
|
||||
instr_type: ast::Type,
|
||||
conv: ConversionKind,
|
||||
) -> spirv::Word {
|
||||
let temp_src = id_def.new_id(instr_type);
|
||||
func.push(Statement::Conversion(ImplicitConversion {
|
||||
src: src,
|
||||
dst: temp_src,
|
||||
from: src_type,
|
||||
to: instr_type,
|
||||
kind: conv,
|
||||
}));
|
||||
temp_src
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn get_conversion_dst(
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
dst: &mut spirv::Word,
|
||||
instr_type: ast::Type,
|
||||
dst_type: ast::Type,
|
||||
kind: ConversionKind,
|
||||
) -> ExpandedStatement {
|
||||
let original_dst = *dst;
|
||||
let temp_dst = id_def.new_id(instr_type);
|
||||
*dst = temp_dst;
|
||||
Statement::Conversion(ImplicitConversion {
|
||||
src: temp_dst,
|
||||
dst: original_dst,
|
||||
from: instr_type,
|
||||
to: dst_type,
|
||||
kind: kind,
|
||||
})
|
||||
_: Option<ast::LdStateSpace>,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if src_type == instr_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_src(src_type, instr_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
|
||||
|
@ -4302,6 +4248,20 @@ fn should_convert_relaxed_src(
|
|||
}
|
||||
}
|
||||
|
||||
fn should_convert_relaxed_dst_wrapper(
|
||||
dst_type: ast::Type,
|
||||
instr_type: ast::Type,
|
||||
_: Option<ast::LdStateSpace>,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if dst_type == instr_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_dst(dst_type, instr_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
|
||||
fn should_convert_relaxed_dst(
|
||||
dst_type: ast::Type,
|
||||
|
@ -4357,55 +4317,6 @@ fn should_convert_relaxed_dst(
|
|||
}
|
||||
}
|
||||
|
||||
fn insert_implicit_bitcasts(
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
stmt: impl VisitVariableExpanded,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut dst_coercion = None;
|
||||
let instr = stmt.visit_variable_extended(&mut |mut desc, typ| {
|
||||
let id_type_from_instr = match typ {
|
||||
Some(t) => t,
|
||||
None => return Ok(desc.op),
|
||||
};
|
||||
let id_actual_type = id_def.get_typed(desc.op)?;
|
||||
let conv_kind = if desc.sema == ArgumentSemantics::Address {
|
||||
Some(ConversionKind::PtrToBit)
|
||||
} else if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if let Some(conv_kind) = conv_kind {
|
||||
if desc.is_dst {
|
||||
dst_coercion = Some(get_conversion_dst(
|
||||
id_def,
|
||||
&mut desc.op,
|
||||
id_type_from_instr,
|
||||
id_actual_type,
|
||||
conv_kind,
|
||||
));
|
||||
Ok(desc.op)
|
||||
} else {
|
||||
Ok(insert_conversion_src(
|
||||
func,
|
||||
id_def,
|
||||
desc.op,
|
||||
id_actual_type,
|
||||
id_type_from_instr,
|
||||
conv_kind,
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Ok(desc.op)
|
||||
}
|
||||
})?;
|
||||
func.push(instr);
|
||||
if let Some(cond) = dst_coercion {
|
||||
func.push(cond);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
impl<'a> ast::MethodDecl<'a, &'a str> {
|
||||
fn name(&self) -> &'a str {
|
||||
match self {
|
||||
|
|
Loading…
Add table
Reference in a new issue