Fix remaining bugs in vector destructuring and in the process improve implicit conversions

This commit is contained in:
Andrzej Janik 2020-10-01 00:44:58 +02:00
parent 1e0b35be4b
commit 3e92921275
8 changed files with 433 additions and 485 deletions

View file

@ -11,5 +11,5 @@ members = [
]
[patch.crates-io]
rspirv = { git = 'https://github.com/vosen/rspirv', rev = '4523d54d785faff59c1e928dd1f210c531a70258' }
spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '4523d54d785faff59c1e928dd1f210c531a70258' }
rspirv = { git = 'https://github.com/vosen/rspirv', rev = '0f5761918624f4a95107c14abe64946c5c5f60ce' }
spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '0f5761918624f4a95107c14abe64946c5c5f60ce' }

View file

@ -10,7 +10,7 @@ edition = "2018"
lalrpop-util = "0.19"
regex = "1"
rspirv = "0.6"
spirv_headers = "1.4"
spirv_headers = "~1.4.2"
quick-error = "1.2"
bit-vec = "0.6"
half ="1.6"

View file

@ -463,14 +463,14 @@ pub enum CallOperand<ID> {
pub enum IdOrVector<ID> {
Reg(ID),
Vec(Vec<ID>)
Vec(Vec<ID>),
}
pub enum OperandOrVector<ID> {
Reg(ID),
RegOffset(ID, i32),
Imm(u32),
Vec(Vec<ID>)
Vec(Vec<ID>),
}
impl<T> From<Operand<T>> for OperandOrVector<T> {
@ -536,6 +536,8 @@ pub struct MovDetails {
// two fields below are in use by member moves
pub dst_width: u8,
pub src_width: u8,
// This is in use by auto-generated movs
pub relaxed_src2_conv: bool,
}
impl MovDetails {
@ -544,7 +546,8 @@ impl MovDetails {
typ,
src_is_address: false,
dst_width: 0,
src_width: 0
src_width: 0,
relaxed_src2_conv: false,
}
}
}
@ -560,7 +563,7 @@ pub struct MulIntDesc {
pub control: MulIntControl,
}
#[derive(Copy, Clone)]
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum MulIntControl {
Low,
High,

View file

@ -2,8 +2,10 @@
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%23 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
@ -33,17 +35,17 @@
%11 = OpCopyObject %ulong %12
OpStore %5 %11
%14 = OpLoad %ulong %4
%17 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14
%18 = OpLoad %float %17
%31 = OpBitcast %uint %18
%18 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14
%17 = OpLoad %float %18
%31 = OpBitcast %uint %17
%13 = OpUConvert %ulong %31
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %ulong %6
%19 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15
%32 = OpBitcast %ulong %16
%33 = OpUConvert %uint %32
%19 = OpBitcast %float %33
%20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15
OpStore %20 %19
%20 = OpBitcast %float %33
OpStore %19 %20
OpReturn
OpFunctionEnd

View file

@ -2,8 +2,10 @@
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%32 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
@ -57,8 +59,8 @@
OpStore %8 %19
%22 = OpLoad %ulong %5
%23 = OpLoad %ulong %8
%28 = OpCopyObject %ulong %23
%29 = OpConvertUToPtr %_ptr_Generic_ulong %22
OpStore %29 %28
%28 = OpConvertUToPtr %_ptr_Generic_ulong %22
%29 = OpCopyObject %ulong %23
OpStore %28 %29
OpReturn
OpFunctionEnd

View file

@ -15,6 +15,9 @@
.reg .u16 temp4;
.reg .v4.u16 foo;
ld.param.u64 in_addr, [input_p];
ld.param.u64 out_addr, [output_p];
ld.global.v4.u8 {temp1, temp2, temp3, temp4}, [in_addr];
mov.v4.u16 foo, {temp2, temp3, temp4, temp1};
mov.v4.u16 {temp3, temp4, temp1, temp2}, foo;

View file

@ -2,96 +2,123 @@
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%60 = OpExtInstImport "OpenCL.std"
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%75 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %31 "vector"
OpEntryPoint Kernel %1 "vector_extract"
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%v2uint = OpTypeVector %uint 2
%64 = OpTypeFunction %v2uint %v2uint
%_ptr_Function_v2uint = OpTypePointer Function %v2uint
%_ptr_Function_uint = OpTypePointer Function %uint
%ulong = OpTypeInt 64 0
%68 = OpTypeFunction %void %ulong %ulong
%78 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint
%1 = OpFunction %v2uint None %64
%7 = OpFunctionParameter %v2uint
%30 = OpLabel
%3 = OpVariable %_ptr_Function_v2uint Function
%2 = OpVariable %_ptr_Function_v2uint Function
%4 = OpVariable %_ptr_Function_v2uint Function
%5 = OpVariable %_ptr_Function_uint Function
%6 = OpVariable %_ptr_Function_uint Function
OpStore %3 %7
%9 = OpLoad %v2uint %3
%27 = OpCompositeExtract %uint %9 0
%8 = OpCopyObject %uint %27
OpStore %5 %8
%11 = OpLoad %v2uint %3
%28 = OpCompositeExtract %uint %11 1
%10 = OpCopyObject %uint %28
OpStore %6 %10
%13 = OpLoad %uint %5
%14 = OpLoad %uint %6
%12 = OpIAdd %uint %13 %14
OpStore %6 %12
%16 = OpLoad %v2uint %4
%17 = OpLoad %uint %6
%15 = OpCompositeInsert %v2uint %17 %16 0
OpStore %4 %15
%19 = OpLoad %v2uint %4
%20 = OpLoad %uint %6
%18 = OpCompositeInsert %v2uint %20 %19 1
OpStore %4 %18
%22 = OpLoad %v2uint %4
%23 = OpLoad %v2uint %4
%29 = OpCompositeExtract %uint %23 1
%21 = OpCompositeInsert %v2uint %29 %22 0
OpStore %4 %21
%25 = OpLoad %v2uint %4
%24 = OpCopyObject %v2uint %25
OpStore %2 %24
%26 = OpLoad %v2uint %2
OpReturnValue %26
OpFunctionEnd
%31 = OpFunction %void None %68
%40 = OpFunctionParameter %ulong
%41 = OpFunctionParameter %ulong
%58 = OpLabel
%32 = OpVariable %_ptr_Function_ulong Function
%33 = OpVariable %_ptr_Function_ulong Function
%34 = OpVariable %_ptr_Function_ulong Function
%35 = OpVariable %_ptr_Function_ulong Function
%36 = OpVariable %_ptr_Function_v2uint Function
%37 = OpVariable %_ptr_Function_uint Function
%38 = OpVariable %_ptr_Function_uint Function
%39 = OpVariable %_ptr_Function_ulong Function
OpStore %32 %40
OpStore %33 %41
%43 = OpLoad %ulong %32
%42 = OpCopyObject %ulong %43
OpStore %34 %42
%45 = OpLoad %ulong %33
%44 = OpCopyObject %ulong %45
OpStore %35 %44
%47 = OpLoad %ulong %34
%54 = OpConvertUToPtr %_ptr_Generic_v2uint %47
%46 = OpLoad %v2uint %54
OpStore %36 %46
%49 = OpLoad %v2uint %36
%48 = OpFunctionCall %v2uint %1 %49
OpStore %36 %48
%51 = OpLoad %v2uint %36
%55 = OpBitcast %ulong %51
%56 = OpCopyObject %ulong %55
%50 = OpCopyObject %ulong %56
OpStore %39 %50
%52 = OpLoad %ulong %35
%53 = OpLoad %v2uint %36
%57 = OpConvertUToPtr %_ptr_Generic_v2uint %52
OpStore %57 %53
%ushort = OpTypeInt 16 0
%_ptr_Function_ushort = OpTypePointer Function %ushort
%v4ushort = OpTypeVector %ushort 4
%_ptr_Function_v4ushort = OpTypePointer Function %v4ushort
%uchar = OpTypeInt 8 0
%v4uchar = OpTypeVector %uchar 4
%_ptr_CrossWorkgroup_v4uchar = OpTypePointer CrossWorkgroup %v4uchar
%1 = OpFunction %void None %78
%11 = OpFunctionParameter %ulong
%12 = OpFunctionParameter %ulong
%73 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ushort Function
%7 = OpVariable %_ptr_Function_ushort Function
%8 = OpVariable %_ptr_Function_ushort Function
%9 = OpVariable %_ptr_Function_ushort Function
%10 = OpVariable %_ptr_Function_v4ushort Function
OpStore %2 %11
OpStore %3 %12
%14 = OpLoad %ulong %2
%13 = OpCopyObject %ulong %14
OpStore %4 %13
%16 = OpLoad %ulong %3
%15 = OpCopyObject %ulong %16
OpStore %5 %15
%21 = OpLoad %ulong %4
%63 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %21
%45 = OpLoad %v4uchar %63
%64 = OpCompositeExtract %uchar %45 0
%87 = OpBitcast %uchar %64
%17 = OpUConvert %ushort %87
%65 = OpCompositeExtract %uchar %45 1
%88 = OpBitcast %uchar %65
%18 = OpUConvert %ushort %88
%66 = OpCompositeExtract %uchar %45 2
%89 = OpBitcast %uchar %66
%19 = OpUConvert %ushort %89
%67 = OpCompositeExtract %uchar %45 3
%90 = OpBitcast %uchar %67
%20 = OpUConvert %ushort %90
OpStore %6 %17
OpStore %7 %18
OpStore %8 %19
OpStore %9 %20
%23 = OpLoad %ushort %7
%24 = OpLoad %ushort %8
%25 = OpLoad %ushort %9
%26 = OpLoad %ushort %6
%46 = OpUndef %v4ushort
%47 = OpCompositeInsert %v4ushort %23 %46 0
%48 = OpCompositeInsert %v4ushort %24 %47 1
%49 = OpCompositeInsert %v4ushort %25 %48 2
%50 = OpCompositeInsert %v4ushort %26 %49 3
%22 = OpCopyObject %v4ushort %50
OpStore %10 %22
%31 = OpLoad %v4ushort %10
%51 = OpCopyObject %v4ushort %31
%27 = OpCompositeExtract %ushort %51 0
%28 = OpCompositeExtract %ushort %51 1
%29 = OpCompositeExtract %ushort %51 2
%30 = OpCompositeExtract %ushort %51 3
OpStore %8 %27
OpStore %9 %28
OpStore %6 %29
OpStore %7 %30
%36 = OpLoad %ushort %8
%37 = OpLoad %ushort %9
%38 = OpLoad %ushort %6
%39 = OpLoad %ushort %7
%53 = OpUndef %v4ushort
%54 = OpCompositeInsert %v4ushort %36 %53 0
%55 = OpCompositeInsert %v4ushort %37 %54 1
%56 = OpCompositeInsert %v4ushort %38 %55 2
%57 = OpCompositeInsert %v4ushort %39 %56 3
%52 = OpCopyObject %v4ushort %57
%32 = OpCompositeExtract %ushort %52 0
%33 = OpCompositeExtract %ushort %52 1
%34 = OpCompositeExtract %ushort %52 2
%35 = OpCompositeExtract %ushort %52 3
OpStore %9 %32
OpStore %6 %33
OpStore %7 %34
OpStore %8 %35
%40 = OpLoad %ulong %5
%41 = OpLoad %ushort %6
%42 = OpLoad %ushort %7
%43 = OpLoad %ushort %8
%44 = OpLoad %ushort %9
%58 = OpUndef %v4uchar
%91 = OpBitcast %ushort %41
%68 = OpUConvert %uchar %91
%59 = OpCompositeInsert %v4uchar %68 %58 0
%92 = OpBitcast %ushort %42
%69 = OpUConvert %uchar %92
%60 = OpCompositeInsert %v4uchar %69 %59 1
%93 = OpBitcast %ushort %43
%70 = OpUConvert %uchar %93
%61 = OpCompositeInsert %v4uchar %70 %60 2
%94 = OpBitcast %ushort %44
%71 = OpUConvert %uchar %94
%62 = OpCompositeInsert %v4uchar %71 %61 3
%72 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %40
OpStore %72 %62
OpReturn
OpFunctionEnd

View file

@ -843,6 +843,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
(_, ArgumentSemantics::Address) => return Ok(desc.op),
(t, ArgumentSemantics::RegisterPointer)
| (t, ArgumentSemantics::Default)
| (t, ArgumentSemantics::DefaultRelaxed)
| (t, ArgumentSemantics::PhysicalPointer) => t,
};
let generated_id = id_def.new_id(id_type);
@ -933,17 +934,19 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn insert_composite_read(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver<'a>,
(scalar_type, vec_len): (ast::ScalarType, u8),
typ: (ast::ScalarType, u8),
scalar_dst: Option<spirv::Word>,
scalar_sema_override: Option<ArgumentSemantics>,
composite_src: (spirv::Word, u8),
) -> spirv::Word {
let new_id =
scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Vector(scalar_type, vec_len)));
let new_id = scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Scalar(typ.0)));
func.push(Statement::Composite(CompositeRead {
typ: scalar_type,
typ: typ.0,
dst: new_id,
dst_semantics_override: scalar_sema_override,
src_composite: composite_src.0,
src_index: composite_src.1 as u32,
src_len: typ.1 as u32,
}));
new_id
}
@ -963,7 +966,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
) -> Result<spirv::Word, TranslateError> {
let (reg, offset) = desc.op;
match desc.sema {
ArgumentSemantics::Default => {
ArgumentSemantics::Default | ArgumentSemantics::DefaultRelaxed => {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
scalar
} else {
@ -1049,18 +1052,19 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn member_src(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
(scalar_type, vec_len): (ast::ScalarType, u8),
typ: (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
if desc.is_dst {
return Err(TranslateError::Unreachable);
}
let new_id = self.id_def.new_id(ast::Type::Vector(scalar_type, vec_len));
self.func.push(Statement::Composite(CompositeRead {
typ: scalar_type,
dst: new_id,
src_composite: desc.op.0,
src_index: desc.op.1 as u32,
}));
let new_id = Self::insert_composite_read(
self.func,
self.id_def,
typ,
None,
Some(desc.sema),
desc.op,
);
Ok(new_id)
}
@ -1077,10 +1081,11 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
let newer_id = self.id_def.new_id(typ);
self.func.push(Statement::Instruction(ast::Instruction::Mov(
ast::MovDetails {
typ: typ,
typ: ast::Type::Scalar(scalar_type),
src_is_address: false,
dst_width: 0,
src_width: vec_len,
dst_width: vec_len,
src_width: 0,
relaxed_src2_conv: desc.sema == ArgumentSemantics::DefaultRelaxed,
},
ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
(newer_id, idx as u8),
@ -1099,6 +1104,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
self.id_def,
(scalar_type, vec_len),
Some(*id),
Some(desc.sema),
(new_id, idx as u8),
);
}
@ -1144,9 +1150,9 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
(scalar_type, vec_len): (ast::ScalarType, u8),
typ: (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
self.member_src(desc, (scalar_type, vec_len))
self.member_src(desc, typ)
}
fn id_or_vector(
@ -1195,123 +1201,41 @@ fn insert_implicit_conversions(
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
match s {
Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call)?,
Statement::Instruction(inst) => match inst {
ast::Instruction::Ld(ld, arg) => {
let pre_conv = get_implicit_conversions_ld_src(
id_def,
ld.typ,
ld.state_space,
arg.src,
false,
)?;
let post_conv = get_implicit_conversions_ld_dst(
id_def,
ld.typ,
arg.dst,
should_convert_relaxed_dst,
false,
)?;
insert_with_conversions(
&mut result,
id_def,
arg,
pre_conv.into_iter(),
iter::empty(),
post_conv.into_iter().collect(),
|arg| &mut arg.src,
|arg| &mut arg.dst,
|arg| ast::Instruction::Ld(ld, arg),
)
Statement::Call(call) => insert_implicit_conversions_impl(
&mut result,
id_def,
call,
should_bitcast_wrapper,
None,
)?,
Statement::Instruction(inst) => {
let mut default_conversion_fn = should_bitcast_wrapper
as fn(_, _, _) -> Result<Option<ConversionKind>, TranslateError>;
let mut state_space = None;
if let ast::Instruction::Ld(d, _) = &inst {
state_space = Some(d.state_space);
}
ast::Instruction::St(st, arg) => {
let pre_conv = get_implicit_conversions_ld_dst(
id_def,
st.typ,
arg.src2,
should_convert_relaxed_src,
true,
)?;
let post_conv = get_implicit_conversions_ld_src(
id_def,
st.typ,
st.state_space.to_ld_ss(),
arg.src1,
true,
)?;
let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param
|| st.state_space == ast::StStateSpace::Local
{
(Vec::new(), post_conv)
} else {
(post_conv, Vec::new())
};
insert_with_conversions(
&mut result,
id_def,
arg,
pre_conv.into_iter(),
pre_conv_dest.into_iter(),
post_conv,
|arg| &mut arg.src2,
|arg| &mut arg.src1,
|arg| ast::Instruction::St(st, arg),
)
if let ast::Instruction::St(d, _) = &inst {
state_space = Some(d.state_space.to_ld_ss());
}
ast::Instruction::Mov(d, ast::Arg2Mov::Normal(mut arg)) => {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov-2
// TODO: handle the case of mixed vector/scalar implicit conversions
let inst_typ_is_bit = match d.typ {
ast::Type::Scalar(t) => ast::ScalarType::from(t).kind() == ScalarKind::Bit,
ast::Type::Vector(_, _) => false,
ast::Type::Array(_, _) => false,
};
let mut did_vector_implicit = false;
let mut post_conv = None;
if inst_typ_is_bit {
let src_type = id_def.get_typed(arg.src)?;
if let ast::Type::Vector(_, _) = src_type {
arg.src = insert_conversion_src(
&mut result,
id_def,
arg.src,
src_type,
d.typ.into(),
ConversionKind::Default,
);
did_vector_implicit = true;
}
let dst_type = id_def.get_typed(arg.dst)?;
if let ast::Type::Vector(_, _) = dst_type {
post_conv = Some(get_conversion_dst(
id_def,
&mut arg.dst,
d.typ.into(),
dst_type,
ConversionKind::Default,
));
did_vector_implicit = true;
}
}
if did_vector_implicit {
result.push(Statement::Instruction(ast::Instruction::Mov(
d,
ast::Arg2Mov::Normal(arg),
)));
} else {
insert_implicit_bitcasts(
&mut result,
id_def,
ast::Instruction::Mov(d, ast::Arg2Mov::Normal(arg)),
)?;
}
if let Some(post_conv) = post_conv {
result.push(post_conv);
}
if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst {
default_conversion_fn = should_bitcast_packed;
}
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst)?,
},
Statement::Composite(c) => insert_implicit_bitcasts(&mut result, id_def, c)?,
insert_implicit_conversions_impl(
&mut result,
id_def,
inst,
default_conversion_fn,
state_space,
)?;
}
Statement::Composite(composite) => insert_implicit_conversions_impl(
&mut result,
id_def,
composite,
should_bitcast_wrapper,
None,
)?,
s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
@ -1326,6 +1250,77 @@ fn insert_implicit_conversions(
Ok(result)
}
// Visits every typed operand of `stmt` and, where the operand's declared type
// disagrees with the instruction's type, emits an ImplicitConversion statement.
// Source-operand conversions are pushed into `func` *before* the statement;
// destination-operand conversions are buffered in `post_conv` and appended
// *after* it, so dataflow order is preserved.
//
// `default_conversion_fn` decides the conversion for ArgumentSemantics::Default
// operands; the other semantics override it below. `state_space` is only
// populated for ld/st (see caller) and is forwarded to the conversion decision.
fn insert_implicit_conversions_impl(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
stmt: impl VisitVariableExpanded,
default_conversion_fn: fn(
ast::Type,
ast::Type,
Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError>,
state_space: Option<ast::LdStateSpace>,
) -> Result<(), TranslateError> {
// Conversions for destination operands, emitted after the statement itself.
let mut post_conv = Vec::new();
let statement = stmt.visit_variable_extended(&mut |desc, typ| {
// Untyped operands (e.g. labels) never need a conversion.
let instr_type = match typ {
None => return Ok(desc.op),
Some(t) => t,
};
let operand_type = id_def.get_typed(desc.op)?;
// Select the conversion policy from the operand's semantics.
let mut conversion_fn = default_conversion_fn;
match desc.sema {
ArgumentSemantics::Default => {}
ArgumentSemantics::DefaultRelaxed => {
// ld/st operands: PTX relaxed type-checking rules, which differ
// for destinations and sources.
if desc.is_dst {
conversion_fn = should_convert_relaxed_dst_wrapper;
} else {
conversion_fn = should_convert_relaxed_src_wrapper;
}
}
ArgumentSemantics::PhysicalPointer => {
conversion_fn = bitcast_physical_pointer;
}
ArgumentSemantics::RegisterPointer => {
conversion_fn = force_bitcast;
}
ArgumentSemantics::Address => {
conversion_fn = force_bitcast_ptr_to_bit;
}
};
match conversion_fn(operand_type, instr_type, state_space)? {
Some(conv_kind) => {
// dst conversions run after the statement, src conversions before.
let conv_output = if desc.is_dst {
&mut post_conv
} else {
&mut *func
};
// Build the conversion assuming a destination (instr -> operand,
// new id as src); for a source operand the two swaps below flip
// it into operand -> instr with the new id as dst.
let mut from = instr_type;
let mut to = operand_type;
let mut src = id_def.new_id(instr_type);
let mut dst = desc.op;
// The statement itself always refers to the freshly minted id.
let result = Ok(src);
if !desc.is_dst {
mem::swap(&mut src, &mut dst);
mem::swap(&mut from, &mut to);
}
conv_output.push(Statement::Conversion(ImplicitConversion {
src,
dst,
from,
to,
kind: conv_kind,
}));
result
}
None => Ok(desc.op),
}
})?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
}
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@ -1505,7 +1500,11 @@ fn emit_function_body_ops(
composite_src,
scalar_src,
)) => {
let result_type = map.get_or_add(builder, SpirvType::from(d.typ));
let scalar_type = d.typ.get_scalar()?;
let result_type = map.get_or_add(
builder,
SpirvType::from(ast::Type::Vector(scalar_type, d.dst_width)),
);
let result_id = Some(dst.0);
builder.composite_insert(
result_type,
@ -1545,8 +1544,8 @@ fn emit_function_body_ops(
// Obviously, old and buggy one is used for compiling L0 SPIRV
// https://github.com/intel/intel-graphics-compiler/issues/148
let type_pred = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
let const_true = builder.constant_true(type_pred);
let const_false = builder.constant_false(type_pred);
let const_true = builder.constant_true(type_pred, None);
let const_false = builder.constant_false(type_pred, None);
builder.select(result_type, result_id, operand, const_false, const_true)
}
_ => builder.not(result_type, result_id, operand),
@ -2700,12 +2699,9 @@ where
fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
(scalar_type, vec_len): (ast::ScalarType, u8),
(scalar_type, _): (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
self(
desc.new_op(desc.op),
Some(ast::Type::Vector(scalar_type.into(), vec_len)),
)
self(desc.new_op(desc.op), Some(ast::Type::Scalar(scalar_type)))
}
}
@ -2793,6 +2789,8 @@ pub struct ArgumentDescriptor<Op> {
pub enum ArgumentSemantics {
// normal register access
Default,
// normal register access with relaxed conversion rules (ld/st)
DefaultRelaxed,
// st/ld global
PhysicalPointer,
// st/ld .param, .local
@ -2834,11 +2832,12 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
}
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type)?)
let is_wide = d.is_wide();
ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type, is_wide)?)
}
ast::Instruction::Add(d, a) => {
let inst_type = d.get_type();
ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type)?)
ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type, false)?)
}
ast::Instruction::Setp(d, a) => {
let inst_type = d.typ;
@ -2889,7 +2888,8 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
}
ast::Instruction::Mad(d, a) => {
let inst_type = d.get_type();
ast::Instruction::Mad(d, a.map(visitor, inst_type)?)
let is_wide = d.is_wide();
ast::Instruction::Mad(d, a.map(visitor, inst_type, is_wide)?)
}
})
}
@ -3004,6 +3004,27 @@ where
}
impl ast::Type {
/// Doubles the bit-width of a scalar type, preserving its kind
/// (used for `.wide` forms of mul/mad, e.g. `u16 -> u32`).
///
/// Errors with `MismatchedType` for kinds that have no wide form
/// (float/pred) and for 8-bit scalars; non-scalar types are
/// `Unreachable` because callers only widen scalar instruction types.
fn widen(self) -> Result<Self, TranslateError> {
    let scalar = match self {
        ast::Type::Scalar(scalar) => scalar,
        _ => return Err(TranslateError::Unreachable),
    };
    let kind = scalar.kind();
    let width = scalar.width();
    // Only integer-like scalars wider than one byte can be widened.
    let widenable = (kind == ScalarKind::Signed
        || kind == ScalarKind::Unsigned
        || kind == ScalarKind::Bit)
        && width != 8;
    if !widenable {
        return Err(TranslateError::MismatchedType);
    }
    Ok(ast::Type::Scalar(ast::ScalarType::from_parts(
        width * 2,
        kind,
    )))
}
fn to_parts(self) -> TypeParts {
match self {
ast::Type::Scalar(scalar) => TypeParts {
@ -3102,8 +3123,10 @@ type Arg2St = ast::Arg2St<ExpandedArgParams>;
struct CompositeRead {
pub typ: ast::ScalarType,
pub dst: spirv::Word,
pub dst_semantics_override: Option<ArgumentSemantics>,
pub src_composite: spirv::Word,
pub src_index: u32,
pub src_len: u32,
}
impl VisitVariableExpanded for CompositeRead {
@ -3116,12 +3139,15 @@ impl VisitVariableExpanded for CompositeRead {
self,
f: &mut F,
) -> Result<ExpandedStatement, TranslateError> {
let dst_sema = self
.dst_semantics_override
.unwrap_or(ArgumentSemantics::Default);
Ok(Statement::Composite(CompositeRead {
dst: f(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
sema: dst_sema,
},
Some(ast::Type::Scalar(self.typ)),
)?,
@ -3131,7 +3157,7 @@ impl VisitVariableExpanded for CompositeRead {
is_dst: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(self.typ)),
Some(ast::Type::Vector(self.typ, self.src_len as u8)),
)?,
..self
}))
@ -3328,7 +3354,7 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
sema: ArgumentSemantics::DefaultRelaxed,
},
t.into(),
)?;
@ -3380,7 +3406,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
ArgumentDescriptor {
op: self.src2,
is_dst: false,
sema: ArgumentSemantics::Default,
sema: ArgumentSemantics::DefaultRelaxed,
},
t,
)?;
@ -3429,9 +3455,9 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
op: self.src,
is_dst: false,
sema: if details.src_is_address {
ArgumentSemantics::RegisterPointer
ArgumentSemantics::Address
} else {
ArgumentSemantics::PhysicalPointer
ArgumentSemantics::Default
},
},
details.typ.into(),
@ -3476,13 +3502,14 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
) -> Result<ast::Arg2MovMember<U>, TranslateError> {
match self {
ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => {
let scalar_type = details.typ.get_scalar()?;
let dst = visitor.id(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
Some(details.typ.into()),
Some(ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let src1 = visitor.id(
ArgumentDescriptor {
@ -3490,7 +3517,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
Some(details.typ.into()),
Some(ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let src2 = visitor.id(
ArgumentDescriptor {
@ -3498,6 +3525,8 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
is_dst: false,
sema: if details.src_is_address {
ArgumentSemantics::Address
} else if details.relaxed_src2_conv {
ArgumentSemantics::DefaultRelaxed
} else {
ArgumentSemantics::Default
},
@ -3527,13 +3556,14 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
Ok(ast::Arg2MovMember::Src(dst, src))
}
ast::Arg2MovMember::Both((dst, len), composite_src, src) => {
let scalar_type = details.typ.get_scalar()?;
let dst = visitor.id(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
Some(details.typ.into()),
Some(ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let composite_src = visitor.id(
ArgumentDescriptor {
@ -3541,16 +3571,19 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
Some(details.typ.into()),
Some(ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let scalar_typ = details.typ.get_scalar()?;
let src = visitor.src_member_operand(
ArgumentDescriptor {
op: src,
is_dst: false,
sema: ArgumentSemantics::Default,
sema: if details.relaxed_src2_conv {
ArgumentSemantics::DefaultRelaxed
} else {
ArgumentSemantics::Default
},
},
(scalar_typ.into(), details.src_width),
(scalar_type.into(), details.src_width),
)?;
Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src))
}
@ -3570,7 +3603,8 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::Type,
typ: ast::Type,
is_wide: bool,
) -> Result<ast::Arg3<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
@ -3578,7 +3612,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
Some(t),
Some(if is_wide { typ.widen()? } else { typ }),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@ -3586,7 +3620,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
t,
typ,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@ -3594,7 +3628,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
t,
typ,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@ -3646,6 +3680,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
self,
visitor: &mut V,
t: ast::Type,
is_wide: bool,
) -> Result<ast::Arg4<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
@ -3653,7 +3688,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
Some(t),
Some(if is_wide { t.widen()? } else { t }),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@ -4050,6 +4085,54 @@ impl<T> ast::OperandOrVector<T> {
}
}
impl ast::MulDetails {
    /// True only for integer multiplications using the `.wide` control
    /// qualifier; float multiplications have no wide form.
    fn is_wide(&self) -> bool {
        if let ast::MulDetails::Int(desc) = self {
            desc.control == ast::MulIntControl::Wide
        } else {
            false
        }
    }
}
/// Conversion policy for RegisterPointer operands: any type mismatch
/// between the operand and the instruction is resolved with a plain
/// bitcast; matching types need nothing.
fn force_bitcast(
    operand: ast::Type,
    instr: ast::Type,
    _: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    if operand == instr {
        Ok(None)
    } else {
        Ok(Some(ConversionKind::Default))
    }
}
/// Conversion policy for PhysicalPointer operands (global/generic ld/st):
/// a 64-bit integer register holding an address is converted to a pointer
/// in the instruction's state space via `BitToPtr`.
///
/// Errors with `MismatchedType` if the operand is not a 64-bit b/u/s
/// scalar, and `Unreachable` if the caller failed to supply a state space
/// (it always does for ld/st).
fn bitcast_physical_pointer(
    operand_type: ast::Type,
    _: ast::Type,
    ss: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    // Only 64-bit integer scalars can carry a physical address.
    let is_64bit_int = match operand_type {
        ast::Type::Scalar(ast::ScalarType::B64)
        | ast::Type::Scalar(ast::ScalarType::U64)
        | ast::Type::Scalar(ast::ScalarType::S64) => true,
        _ => false,
    };
    if !is_64bit_int {
        return Err(TranslateError::MismatchedType);
    }
    match ss {
        Some(space) => Ok(Some(ConversionKind::BitToPtr(space))),
        None => Err(TranslateError::Unreachable),
    }
}
/// Conversion policy for Address operands: unconditionally reinterpret the
/// pointer value as its raw bit pattern, regardless of either type.
fn force_bitcast_ptr_to_bit(
    _: ast::Type,
    _: ast::Type,
    _: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    let conversion = ConversionKind::PtrToBit;
    Ok(Some(conversion))
}
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
@ -4077,187 +4160,50 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
}
}
fn insert_with_conversions<T, ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>>(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
mut instr: T,
pre_conv_src: impl ExactSizeIterator<Item = ImplicitConversion>,
pre_conv_dst: impl ExactSizeIterator<Item = ImplicitConversion>,
mut post_conv: Vec<ImplicitConversion>,
mut src: impl FnMut(&mut T) -> &mut spirv::Word,
mut dst: impl FnMut(&mut T) -> &mut spirv::Word,
to_inst: ToInstruction,
) {
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src);
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst);
if post_conv.len() > 0 {
let new_id = id_def.new_id(post_conv[0].from);
post_conv[0].src = new_id;
post_conv.last_mut().unwrap().dst = *dst(&mut instr);
*dst(&mut instr) = new_id;
}
func.push(Statement::Instruction(to_inst(instr)));
for conv in post_conv {
func.push(Statement::Conversion(conv));
/// Conversion policy for `mov` with a packed operand: a vector operand may
/// be bitcast into a bit-typed scalar of exactly the same total width
/// (e.g. `.v2.u32` into `.b64`); every other pairing falls back to the
/// ordinary bitcast rules.
fn should_bitcast_packed(
    operand: ast::Type,
    instr: ast::Type,
    ss: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    match (operand, instr) {
        (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar))
            if scalar.kind() == ScalarKind::Bit
                && scalar.width() == vec_underlying_type.width() * vec_len =>
        {
            Ok(Some(ConversionKind::Default))
        }
        _ => should_bitcast_wrapper(operand, instr, ss),
    }
}
fn insert_with_conversions_pre_conv<T>(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
mut instr: &mut T,
pre_conv: impl ExactSizeIterator<Item = ImplicitConversion>,
src: &mut impl FnMut(&mut T) -> &mut spirv::Word,
) {
let pre_conv_len = pre_conv.len();
for (i, mut conv) in pre_conv.enumerate() {
let original_src = src(&mut instr);
if i == 0 {
conv.src = *original_src;
}
if i == pre_conv_len - 1 {
let new_id = id_def.new_id(conv.to);
conv.dst = new_id;
*original_src = new_id;
}
func.push(Statement::Conversion(conv));
fn should_bitcast_wrapper(
operand: ast::Type,
instr: ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
if instr == operand {
return Ok(None);
}
}
fn get_implicit_conversions_ld_dst<
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
>(
id_def: &mut MutableNumericIdResolver,
instr_type: ast::Type,
dst: spirv::Word,
should_convert: ShouldConvert,
in_reverse: bool,
) -> Result<Option<ImplicitConversion>, TranslateError> {
let dst_type = id_def.get_typed(dst)?;
if let Some(conv) = should_convert(dst_type, instr_type) {
Ok(Some(ImplicitConversion {
src: u32::max_value(),
dst: u32::max_value(),
from: if !in_reverse { instr_type } else { dst_type },
to: if !in_reverse { dst_type } else { instr_type },
kind: conv,
}))
if should_bitcast(instr, operand) {
Ok(Some(ConversionKind::Default))
} else {
Ok(None)
Err(TranslateError::MismatchedType)
}
}
fn get_implicit_conversions_ld_src(
id_def: &mut MutableNumericIdResolver,
instr_type: ast::Type,
state_space: ast::LdStateSpace,
src: spirv::Word,
in_reverse_param_local: bool,
) -> Result<Vec<ImplicitConversion>, TranslateError> {
let src_type = id_def.get_typed(src)?;
match state_space {
ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
if src_type != instr_type {
Ok(vec![
ImplicitConversion {
src: u32::max_value(),
dst: u32::max_value(),
from: if !in_reverse_param_local {
src_type
} else {
instr_type
},
to: if !in_reverse_param_local {
instr_type
} else {
src_type
},
kind: ConversionKind::Default,
};
1
])
} else {
Ok(Vec::new())
}
}
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
mem::size_of::<usize>() as u8,
ScalarKind::Bit,
));
let mut result = Vec::new();
// HACK ALERT
// IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
// additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
// TODO: error out if the src is not B64/U64/S64
if let ast::Type::Scalar(scalar_src_type) = src_type {
if scalar_src_type.kind() == ScalarKind::Signed {
result.push(ImplicitConversion {
src: u32::max_value(),
dst: u32::max_value(),
from: src_type,
to: new_src_type,
kind: ConversionKind::Default,
});
}
}
result.push(ImplicitConversion {
src: u32::max_value(),
dst: u32::max_value(),
from: src_type,
to: instr_type,
kind: ConversionKind::BitToPtr(state_space),
});
if result.len() == 2 {
let new_id = id_def.new_id(new_src_type);
result[0].dst = new_id;
result[1].src = new_id;
result[1].from = new_src_type;
}
Ok(result)
}
_ => Err(TranslateError::Todo),
}
}
#[must_use]
fn insert_conversion_src(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
src: spirv::Word,
fn should_convert_relaxed_src_wrapper(
src_type: ast::Type,
instr_type: ast::Type,
conv: ConversionKind,
) -> spirv::Word {
let temp_src = id_def.new_id(instr_type);
func.push(Statement::Conversion(ImplicitConversion {
src: src,
dst: temp_src,
from: src_type,
to: instr_type,
kind: conv,
}));
temp_src
}
#[must_use]
fn get_conversion_dst(
id_def: &mut MutableNumericIdResolver,
dst: &mut spirv::Word,
instr_type: ast::Type,
dst_type: ast::Type,
kind: ConversionKind,
) -> ExpandedStatement {
let original_dst = *dst;
let temp_dst = id_def.new_id(instr_type);
*dst = temp_dst;
Statement::Conversion(ImplicitConversion {
src: temp_dst,
dst: original_dst,
from: instr_type,
to: dst_type,
kind: kind,
})
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
if src_type == instr_type {
return Ok(None);
}
match should_convert_relaxed_src(src_type, instr_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
@ -4302,6 +4248,20 @@ fn should_convert_relaxed_src(
}
}
/// Conversion policy for DefaultRelaxed destination operands (ld results):
/// equal types need no conversion; otherwise the pair must be admissible
/// under PTX's relaxed destination type-checking rules, and anything the
/// relaxed rules reject is a hard type mismatch.
fn should_convert_relaxed_dst_wrapper(
    dst_type: ast::Type,
    instr_type: ast::Type,
    _: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    if dst_type == instr_type {
        return Ok(None);
    }
    if let Some(conv) = should_convert_relaxed_dst(dst_type, instr_type) {
        Ok(Some(conv))
    } else {
        Err(TranslateError::MismatchedType)
    }
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: ast::Type,
@ -4357,55 +4317,6 @@ fn should_convert_relaxed_dst(
}
}
// Legacy conversion pass (replaced by insert_implicit_conversions_impl in
// this commit): walks the statement's typed operands and inserts bitcasts
// where the operand's actual type differs from the instruction's type.
// Source conversions are emitted before the statement; at most one
// destination coercion is buffered in `dst_coercion` and emitted after it.
fn insert_implicit_bitcasts(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
stmt: impl VisitVariableExpanded,
) -> Result<(), TranslateError> {
// Deferred dst conversion; NOTE(review): only the last dst operand's
// coercion survives if a statement had several convertible dsts.
let mut dst_coercion = None;
let instr = stmt.visit_variable_extended(&mut |mut desc, typ| {
// Operands without a type (labels etc.) pass through untouched.
let id_type_from_instr = match typ {
Some(t) => t,
None => return Ok(desc.op),
};
let id_actual_type = id_def.get_typed(desc.op)?;
// Address operands always get a ptr-to-bit reinterpretation; otherwise
// bitcast only when the type pair qualifies.
// NOTE(review): the second get_typed call duplicates id_actual_type.
let conv_kind = if desc.sema == ArgumentSemantics::Address {
Some(ConversionKind::PtrToBit)
} else if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) {
Some(ConversionKind::Default)
} else {
None
};
if let Some(conv_kind) = conv_kind {
if desc.is_dst {
// Rewrites desc.op to a fresh temporary and remembers the
// statement that copies it into the real destination.
dst_coercion = Some(get_conversion_dst(
id_def,
&mut desc.op,
id_type_from_instr,
id_actual_type,
conv_kind,
));
Ok(desc.op)
} else {
// Emits the conversion immediately and substitutes its result.
Ok(insert_conversion_src(
func,
id_def,
desc.op,
id_actual_type,
id_type_from_instr,
conv_kind,
))
}
} else {
Ok(desc.op)
}
})?;
func.push(instr);
if let Some(cond) = dst_coercion {
func.push(cond);
}
Ok(())
}
impl<'a> ast::MethodDecl<'a, &'a str> {
fn name(&self) -> &'a str {
match self {