Implement initial support for cvt instruction (only integer-to-integer)

This commit is contained in:
Andrzej Janik 2020-08-05 01:58:01 +02:00
parent a10ee48e91
commit 7b407d1c44
9 changed files with 583 additions and 166 deletions

View file

@ -743,7 +743,7 @@ impl<'a> Kernel<'a> {
check!(sys::zeKernelSetArgumentValue(
self.0,
index,
mem::size_of::<T>(),
mem::size_of::<*const ()>(),
&ptr as *const _ as *const _,
));
Ok(())

View file

@ -268,6 +268,7 @@ pub struct Arg5<P: ArgParams> {
pub src3: P::Operand,
}
#[derive(Copy, Clone)]
pub enum Operand<ID> {
Reg(ID),
RegOffset(ID, i32),
@ -353,6 +354,7 @@ pub struct MulFloatDesc {
pub saturate: bool,
}
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum RoundingMode {
NearestEven,
Zero,

View file

@ -0,0 +1,24 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry cvt_sat_s_u(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .s32 temp;
.reg .u32 temp2;
.reg .s32 temp3;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.s32 temp, [in_addr];
cvt.sat.u32.s32 temp2, temp;
cvt.s32.u32 temp3, temp2;
st.s32 [out_addr], temp3;
ret;
}

View file

@ -0,0 +1,43 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %5 "cvt_sat_s_u"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%4 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%5 = OpFunction %void None %4
%6 = OpFunctionParameter %ulong
%7 = OpFunctionParameter %ulong
%23 = OpLabel
%8 = OpVariable %_ptr_Function_ulong Function
%9 = OpVariable %_ptr_Function_ulong Function
%10 = OpVariable %_ptr_Function_uint Function
%11 = OpVariable %_ptr_Function_uint Function
%12 = OpVariable %_ptr_Function_uint Function
OpStore %8 %6
OpStore %9 %7
%14 = OpLoad %ulong %8
%21 = OpConvertUToPtr %_ptr_Generic_uint %14
%13 = OpLoad %uint %21
OpStore %10 %13
%16 = OpLoad %uint %10
%15 = OpSatConvertSToU %uint %16
OpStore %11 %15
%18 = OpLoad %uint %11
%17 = OpBitcast %uint %18
OpStore %12 %17
%19 = OpLoad %ulong %9
%20 = OpLoad %uint %12
%22 = OpConvertUToPtr %_ptr_Generic_uint %19
OpStore %22 %20
OpReturn
OpFunctionEnd

View file

@ -1,7 +1,7 @@
use crate::ptx;
use crate::translate;
use rspirv::{
binary::Assemble,
binary::{Assemble, Disassemble},
dr::{Block, Function, Instruction, Loader, Operand},
};
use spirv_headers::Word;
@ -48,6 +48,7 @@ test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]);
test_ptx!(bra, [10u64], [11u64]);
test_ptx!(not, [0u64], [u64::max_value()]);
test_ptx!(shl, [11u64], [44u64]);
test_ptx!(cvt_sat_s_u, [0i32], [0i32]);
struct DisplayError<T: Display + Debug> {
err: T,

View file

@ -12,7 +12,6 @@
%4 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_0 = OpTypeInt 64 0
%5 = OpFunction %void None %4
%6 = OpFunctionParameter %ulong
%7 = OpFunctionParameter %ulong
@ -27,8 +26,8 @@
%18 = OpConvertUToPtr %_ptr_Generic_ulong %13
%12 = OpLoad %ulong %18
OpStore %10 %12
%15 = OpLoad %ulong_0 %10
%14 = OpNot %ulong_0 %15
%15 = OpLoad %ulong %10
%14 = OpNot %ulong %15
OpStore %11 %14
%16 = OpLoad %ulong %9
%17 = OpLoad %ulong %11
@ -36,4 +35,3 @@
OpStore %19 %17
OpReturn
OpFunctionEnd

View file

@ -20,7 +20,7 @@
%5 = OpFunction %void None %4
%6 = OpFunctionParameter %ulong
%7 = OpFunctionParameter %ulong
%38 = OpLabel
%39 = OpLabel
%8 = OpVariable %_ptr_Function_ulong Function
%9 = OpVariable %_ptr_Function_ulong Function
%10 = OpVariable %_ptr_Function_ulong Function
@ -34,9 +34,10 @@
%18 = OpLoad %ulong %35
OpStore %10 %18
%21 = OpLoad %ulong %8
%32 = OpIAdd %ulong %21 %ulong_8
%36 = OpConvertUToPtr %_ptr_Generic_ulong %32
%20 = OpLoad %ulong %36
%36 = OpCopyObject %ulong %21
%32 = OpIAdd %ulong %36 %ulong_8
%37 = OpConvertUToPtr %_ptr_Generic_ulong %32
%20 = OpLoad %ulong %37
OpStore %11 %20
%23 = OpLoad %ulong %10
%24 = OpLoad %ulong %11
@ -58,8 +59,7 @@
%17 = OpLabel
%29 = OpLoad %ulong %9
%30 = OpLoad %ulong %12
%37 = OpConvertUToPtr %_ptr_Generic_ulong %29
OpStore %37 %30
%38 = OpConvertUToPtr %_ptr_Generic_ulong %29
OpStore %38 %30
OpReturn
OpFunctionEnd

View file

@ -12,7 +12,6 @@
%4 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_0 = OpTypeInt 64 0
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%5 = OpFunction %void None %4
@ -29,8 +28,8 @@
%19 = OpConvertUToPtr %_ptr_Generic_ulong %13
%12 = OpLoad %ulong %19
OpStore %10 %12
%15 = OpLoad %ulong_0 %10
%14 = OpShiftLeftLogical %ulong_0 %15 %uint_2
%15 = OpLoad %ulong %10
%14 = OpShiftLeftLogical %ulong %15 %uint_2
OpStore %11 %14
%16 = OpLoad %ulong %9
%17 = OpLoad %ulong %11
@ -38,4 +37,3 @@
OpStore %20 %17
OpReturn
OpFunctionEnd

View file

@ -7,25 +7,85 @@ use rspirv::binary::Assemble;
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvType {
Base(ast::ScalarType),
Extended(ast::ExtendedScalarType),
Pointer(ast::Type, spirv::StorageClass),
Base(SpirvScalarKey),
Pointer(SpirvScalarKey, spirv::StorageClass),
}
impl SpirvType {
fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
let key = match t {
ast::Type::Scalar(typ) => SpirvScalarKey::from(typ),
ast::Type::ExtendedScalar(typ) => SpirvScalarKey::from(typ),
};
SpirvType::Pointer(key, sc)
}
}
impl From<ast::Type> for SpirvType {
fn from(t: ast::Type) -> Self {
match t {
ast::Type::Scalar(t) => SpirvType::Base(t),
ast::Type::ExtendedScalar(t) => SpirvType::Extended(t),
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()),
}
}
}
impl From<ast::ScalarType> for SpirvType {
fn from(t: ast::ScalarType) -> Self {
SpirvType::Base(t.into())
}
}
struct TypeWordMap {
void: spirv::Word,
complex: HashMap<SpirvType, spirv::Word>,
}
// SPIR-V integer type definitions are signless, more below:
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvScalarKey {
B8,
B16,
B32,
B64,
F16,
F32,
F64,
Pred,
F16x2,
}
impl From<ast::ScalarType> for SpirvScalarKey {
fn from(t: ast::ScalarType) -> Self {
match t {
ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8,
ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
SpirvScalarKey::B16
}
ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => {
SpirvScalarKey::B32
}
ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => {
SpirvScalarKey::B64
}
ast::ScalarType::F16 => SpirvScalarKey::F16,
ast::ScalarType::F32 => SpirvScalarKey::F32,
ast::ScalarType::F64 => SpirvScalarKey::F64,
}
}
}
impl From<ast::ExtendedScalarType> for SpirvScalarKey {
fn from(t: ast::ExtendedScalarType) -> Self {
match t {
ast::ExtendedScalarType::Pred => SpirvScalarKey::Pred,
ast::ExtendedScalarType::F16x2 => SpirvScalarKey::F16x2,
}
}
}
impl TypeWordMap {
fn new(b: &mut dr::Builder) -> TypeWordMap {
let void = b.type_void();
@ -40,21 +100,24 @@ impl TypeWordMap {
}
fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word {
let key: SpirvScalarKey = t.into();
self.get_or_add_spirv_scalar(b, key)
}
fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> spirv::Word {
*self
.complex
.entry(SpirvType::Base(t))
.or_insert_with(|| match t {
ast::ScalarType::B8 | ast::ScalarType::U8 => b.type_int(8, 0),
ast::ScalarType::B16 | ast::ScalarType::U16 => b.type_int(16, 0),
ast::ScalarType::B32 | ast::ScalarType::U32 => b.type_int(32, 0),
ast::ScalarType::B64 | ast::ScalarType::U64 => b.type_int(64, 0),
ast::ScalarType::S8 => b.type_int(8, 1),
ast::ScalarType::S16 => b.type_int(16, 1),
ast::ScalarType::S32 => b.type_int(32, 1),
ast::ScalarType::S64 => b.type_int(64, 1),
ast::ScalarType::F16 => b.type_float(16),
ast::ScalarType::F32 => b.type_float(32),
ast::ScalarType::F64 => b.type_float(64),
.entry(SpirvType::Base(key))
.or_insert_with(|| match key {
SpirvScalarKey::B8 => b.type_int(8, 0),
SpirvScalarKey::B16 => b.type_int(16, 0),
SpirvScalarKey::B32 => b.type_int(32, 0),
SpirvScalarKey::B64 => b.type_int(64, 0),
SpirvScalarKey::F16 => b.type_float(16),
SpirvScalarKey::F32 => b.type_float(32),
SpirvScalarKey::F64 => b.type_float(64),
SpirvScalarKey::Pred => b.type_bool(),
SpirvScalarKey::F16x2 => todo!(),
})
}
@ -63,24 +126,15 @@ impl TypeWordMap {
b: &mut dr::Builder,
t: ast::ExtendedScalarType,
) -> spirv::Word {
*self
.complex
.entry(SpirvType::Extended(t))
.or_insert_with(|| match t {
ast::ExtendedScalarType::Pred => b.type_bool(),
ast::ExtendedScalarType::F16x2 => todo!(),
})
let key: SpirvScalarKey = t.into();
self.get_or_add_spirv_scalar(b, key)
}
fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
match t {
SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar),
SpirvType::Extended(ext) => self.get_or_add_extended(b, ext),
SpirvType::Pointer(typ, storage) => {
let base = match typ {
ast::Type::Scalar(scalar) => self.get_or_add_scalar(b, scalar),
ast::Type::ExtendedScalar(ext) => self.get_or_add_extended(b, ext),
};
SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key),
SpirvType::Pointer(typ, mut storage) => {
let base = self.get_or_add_spirv_scalar(b, typ);
*self
.complex
.entry(t)
@ -102,7 +156,7 @@ impl TypeWordMap {
pub fn to_spirv_module(ast: ast::Module) -> Result<dr::Module, dr::Error> {
let mut builder = dr::Builder::new();
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 0);
builder.set_version(1, 3);
emit_capabilities(&mut builder);
emit_extensions(&mut builder);
let opencl_id = emit_opencl_import(&mut builder);
@ -277,24 +331,25 @@ fn insert_mem_ssa_statements(
}
inst => {
let mut post_statements = Vec::new();
let inst = inst.visit_variable(&mut |id, is_dst, id_type| {
let id_type = match id_type {
Some(t) => t,
None => return id,
let inst = inst.visit_variable(&mut |desc| {
let id_type = match (desc.typ, desc.is_pointer) {
(Some(t), false) => t,
(Some(_), true) => ast::Type::Scalar(ast::ScalarType::B64),
(None, _) => return desc.op,
};
let generated_id = id_def.new_id(Some(id_type));
if !is_dst {
if !desc.is_dst {
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
src: id,
src: desc.op,
},
id_type,
));
} else {
post_statements.push(Statement::StoreVar(
Arg2St {
src1: id,
src1: desc.op,
src2: generated_id,
},
id_type,
@ -365,15 +420,15 @@ impl<'a> FlattenArguments<'a> {
}
impl<'a> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> for FlattenArguments<'a> {
fn dst_variable(&mut self, x: spirv::Word, _: Option<ast::Type>) -> spirv::Word {
x
fn dst_variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
desc.op
}
fn src_operand(&mut self, op: ast::Operand<spirv::Word>, t: Option<ast::Type>) -> spirv::Word {
match op {
fn src_operand(&mut self, desc: ArgumentDescriptor<ast::Operand<spirv::Word>>) -> spirv::Word {
match desc.op {
ast::Operand::Reg(r) => r,
ast::Operand::Imm(x) => {
if let Some(typ) = t {
if let Some(typ) = desc.typ {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
scalar
} else {
@ -391,7 +446,7 @@ impl<'a> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> for FlattenA
}
}
ast::Operand::RegOffset(reg, offset) => {
if let Some(typ) = t {
if let Some(typ) = desc.typ {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
scalar
} else {
@ -403,7 +458,7 @@ impl<'a> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> for FlattenA
typ: scalar_t,
value: offset as i128,
}));
let result_id = self.id_def.new_id(t);
let result_id = self.id_def.new_id(desc.typ);
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Add(
@ -428,11 +483,10 @@ impl<'a> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> for FlattenA
fn src_mov_operand(
&mut self,
op: ast::MovOperand<spirv::Word>,
t: Option<ast::Type>,
desc: ArgumentDescriptor<ast::MovOperand<spirv::Word>>,
) -> spirv::Word {
match op {
ast::MovOperand::Op(opr) => self.src_operand(opr, t),
match &desc.op {
ast::MovOperand::Op(opr) => self.src_operand(desc.new_op(*opr)),
ast::MovOperand::Vec(_, _) => todo!(),
}
}
@ -517,7 +571,7 @@ fn get_function_type(
map: &mut TypeWordMap,
args: &[ast::Argument],
) -> spirv::Word {
map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::Base(arg.a_type)))
map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::from(arg.a_type)))
}
fn emit_function_args(
@ -565,7 +619,7 @@ fn emit_function_body_ops(
Statement::Variable(id, typ, ss) => {
let type_id = map.get_or_add(
builder,
SpirvType::Pointer(*typ, spirv::StorageClass::Function),
SpirvType::new_pointer(*typ, spirv::StorageClass::Function),
);
if *ss != ast::StateSpace::Reg {
todo!()
@ -672,7 +726,10 @@ fn emit_function_body_ops(
let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
builder.shift_left_logical(result_type, Some(a.dst), a.src1, a.src2)?;
}
_ => todo!(),
ast::Instruction::Cvt(dets, arg) => {
emit_cvt(builder, map, opencl, dets, arg)?;
}
ast::Instruction::SetpBool(_, _) => todo!(),
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@ -686,6 +743,133 @@ fn emit_function_body_ops(
Ok(())
}
fn emit_cvt(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
dets: &ast::CvtDetails,
arg: &ast::Arg2<ExpandedArgParams>,
) -> Result<(), dr::Error> {
match dets {
ast::CvtDetails::FloatFromFloat(desc) => {
if desc.dst == desc.src {
return Ok(());
}
if desc.saturate || desc.flush_to_zero {
todo!()
}
let dest_t: ast::Type = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
builder.f_convert(result_type, Some(arg.dst), arg.src)?;
emit_rounding_decoration(builder, arg.dst, desc.rounding);
}
ast::CvtDetails::FloatFromInt(desc) => {
if desc.saturate || desc.flush_to_zero {
todo!()
}
let dest_t: ast::Type = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.src.is_signed() {
builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
} else {
builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?;
}
emit_rounding_decoration(builder, arg.dst, desc.rounding);
}
ast::CvtDetails::IntFromFloat(desc) => {
if desc.flush_to_zero {
todo!()
}
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.dst.is_signed() {
builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?;
} else {
builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?;
}
emit_rounding_decoration(builder, arg.dst, desc.rounding);
emit_saturating_decoration(builder, arg.dst, desc.saturate);
}
ast::CvtDetails::IntFromInt(desc) => {
if desc.dst == desc.src {
return Ok(());
}
let dest_t: ast::ScalarType = desc.dst.into();
let src_t: ast::ScalarType = desc.src.into();
// first do shortening/widening
let src = if desc.dst.width() != desc.src.width() {
let new_dst = if dest_t.kind() == src_t.kind() {
arg.dst
} else {
builder.id()
};
let cv = ImplicitConversion {
src: arg.src,
dst: new_dst,
from: ast::Type::Scalar(src_t),
to: ast::Type::Scalar(ast::ScalarType::from_parts(
dest_t.width(),
src_t.kind(),
)),
kind: ConversionKind::Default,
};
emit_implicit_conversion(builder, map, &cv)?;
new_dst
} else {
arg.src
};
if dest_t.kind() == src_t.kind() {
return Ok(());
}
// now do actual conversion
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.saturate {
if desc.dst.is_signed() {
builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?;
} else {
builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?;
}
} else {
builder.bitcast(result_type, Some(arg.dst), src)?;
}
}
_ => todo!(),
}
Ok(())
}
fn emit_saturating_decoration(builder: &mut dr::Builder, dst: u32, saturate: bool) {
if saturate {
builder.decorate(dst, spirv::Decoration::SaturatedConversion, []);
}
}
fn emit_rounding_decoration(
builder: &mut dr::Builder,
dst: spirv::Word,
rounding: Option<ast::RoundingMode>,
) {
if let Some(rounding) = rounding {
builder.decorate(
dst,
spirv::Decoration::FPRoundingMode,
[rounding.to_spirv()],
);
}
}
impl ast::RoundingMode {
fn to_spirv(self) -> rspirv::dr::Operand {
let mode = match self {
ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE,
ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ,
ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP,
ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN,
};
rspirv::dr::Operand::FPRoundingMode(mode)
}
}
fn emit_setp(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@ -695,7 +879,7 @@ fn emit_setp(
if setp.flush_to_zero {
todo!()
}
let result_type = map.get_or_add(builder, SpirvType::Extended(ast::ExtendedScalarType::Pred));
let result_type = map.get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred));
let result_id = Some(arg.dst1);
let operand_1 = arg.src1;
let operand_2 = arg.src2;
@ -768,7 +952,7 @@ fn emit_mul_int(
desc: &ast::MulIntDesc,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::Base(desc.typ.into()));
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
match desc.control {
ast::MulIntControl::Low => {
builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
@ -798,7 +982,7 @@ fn emit_add_int(
ctr: &ast::AddIntDesc,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::Base(ctr.typ.into()));
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(ctr.typ)));
builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
Ok(())
}
@ -817,7 +1001,7 @@ fn emit_implicit_conversion(
let dst_type = map.get_or_add(
builder,
SpirvType::Pointer(
ast::Type::Scalar(to_type),
SpirvScalarKey::from(to_type),
spirv_headers::StorageClass::Generic,
),
);
@ -826,14 +1010,12 @@ fn emit_implicit_conversion(
ConversionKind::Default => {
if from_type.width() == to_type.width() {
let dst_type = map.get_or_add_scalar(builder, to_type);
if from_type.kind() == ScalarKind::Unsigned && to_type.kind() == ScalarKind::Byte
|| from_type.kind() == ScalarKind::Byte
&& to_type.kind() == ScalarKind::Unsigned
{
if from_type.kind() != ScalarKind::Float && to_type.kind() != ScalarKind::Float {
// It is noop, but another instruction expects result of this conversion
builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
} else {
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
}
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
} else {
let as_unsigned_type = map.get_or_add_scalar(
builder,
@ -1057,23 +1239,23 @@ impl ast::ArgParams for ExpandedArgParams {
}
trait ArgumentMapVisitor<T: ast::ArgParams, U: ast::ArgParams> {
fn dst_variable(&mut self, v: T::ID, typ: Option<ast::Type>) -> U::ID;
fn src_operand(&mut self, o: T::Operand, typ: Option<ast::Type>) -> U::Operand;
fn src_mov_operand(&mut self, o: T::MovOperand, typ: Option<ast::Type>) -> U::MovOperand;
fn dst_variable(&mut self, desc: ArgumentDescriptor<T::ID>) -> U::ID;
fn src_operand(&mut self, desc: ArgumentDescriptor<T::Operand>) -> U::Operand;
fn src_mov_operand(&mut self, desc: ArgumentDescriptor<T::MovOperand>) -> U::MovOperand;
}
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
where
T: FnMut(spirv::Word, bool, Option<ast::Type>) -> spirv::Word,
T: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word,
{
fn dst_variable(&mut self, x: spirv::Word, t: Option<ast::Type>) -> spirv::Word {
self(x, t.is_some(), t)
fn dst_variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
self(desc)
}
fn src_operand(&mut self, x: spirv::Word, t: Option<ast::Type>) -> spirv::Word {
self(x, false, t)
fn src_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
self(desc)
}
fn src_mov_operand(&mut self, x: spirv::Word, t: Option<ast::Type>) -> spirv::Word {
self(x, false, t)
fn src_mov_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
self(desc)
}
}
@ -1081,16 +1263,15 @@ impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> fo
where
T: FnMut(&str) -> spirv::Word,
{
fn dst_variable(&mut self, x: &str, _: Option<ast::Type>) -> spirv::Word {
self(x)
fn dst_variable(&mut self, desc: ArgumentDescriptor<&str>) -> spirv::Word {
self(desc.op)
}
fn src_operand(
&mut self,
x: ast::Operand<&str>,
_: Option<ast::Type>,
desc: ArgumentDescriptor<ast::Operand<&str>>,
) -> ast::Operand<spirv::Word> {
match x {
match desc.op {
ast::Operand::Reg(id) => ast::Operand::Reg(self(id)),
ast::Operand::Imm(imm) => ast::Operand::Imm(imm),
ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id), imm),
@ -1099,16 +1280,33 @@ where
fn src_mov_operand(
&mut self,
x: ast::MovOperand<&str>,
t: Option<ast::Type>,
desc: ArgumentDescriptor<ast::MovOperand<&str>>,
) -> ast::MovOperand<spirv::Word> {
match x {
ast::MovOperand::Op(op) => ast::MovOperand::Op(self.src_operand(op, t)),
match desc.op {
ast::MovOperand::Op(op) => ast::MovOperand::Op(self.src_operand(desc.new_op(op))),
ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2),
}
}
}
struct ArgumentDescriptor<T> {
op: T,
is_dst: bool,
typ: Option<ast::Type>,
is_pointer: bool,
}
impl<T> ArgumentDescriptor<T> {
fn new_op<U>(&self, u: U) -> ArgumentDescriptor<U> {
ArgumentDescriptor {
op: u,
is_dst: self.is_dst,
typ: self.typ,
is_pointer: self.is_pointer,
}
}
}
impl<T: ast::ArgParams> ast::Instruction<T> {
fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
self,
@ -1117,7 +1315,7 @@ impl<T: ast::ArgParams> ast::Instruction<T> {
match self {
ast::Instruction::Ld(d, a) => {
let inst_type = d.typ;
ast::Instruction::Ld(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
ast::Instruction::Ld(d, a.map_ld(visitor, Some(ast::Type::Scalar(inst_type))))
}
ast::Instruction::Mov(d, a) => {
let inst_type = d.typ;
@ -1142,7 +1340,22 @@ impl<T: ast::ArgParams> ast::Instruction<T> {
ast::Instruction::Not(t, a) => {
ast::Instruction::Not(t, a.map(visitor, Some(t.to_type())))
}
ast::Instruction::Cvt(_, _) => todo!(),
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (desc.dst.into(), desc.src.into()),
ast::CvtDetails::FloatFromInt(desc) => {
(desc.dst.into(), ast::Type::Scalar(desc.src.into()))
}
ast::CvtDetails::IntFromFloat(desc) => {
(ast::Type::Scalar(desc.dst.into()), desc.src.into())
}
ast::CvtDetails::IntFromInt(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
};
ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t))
}
ast::Instruction::Shl(t, a) => {
ast::Instruction::Shl(t, a.map_shift(visitor, Some(t.to_type())))
}
@ -1157,7 +1370,7 @@ impl<T: ast::ArgParams> ast::Instruction<T> {
}
impl ast::Instruction<NormalizedArgParams> {
fn visit_variable<F: FnMut(spirv::Word, bool, Option<ast::Type>) -> spirv::Word>(
fn visit_variable<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
self,
f: &mut F,
) -> ast::Instruction<NormalizedArgParams> {
@ -1167,34 +1380,34 @@ impl ast::Instruction<NormalizedArgParams> {
impl<T> ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams> for T
where
T: FnMut(spirv::Word, bool, Option<ast::Type>) -> spirv::Word,
T: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word,
{
fn dst_variable(&mut self, x: spirv::Word, t: Option<ast::Type>) -> spirv::Word {
self(x, t.is_some(), t)
fn dst_variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
self(desc)
}
fn src_operand(
&mut self,
x: ast::Operand<spirv::Word>,
t: Option<ast::Type>,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
) -> ast::Operand<spirv::Word> {
match x {
ast::Operand::Reg(id) => ast::Operand::Reg(self(id, false, t)),
match desc.op {
ast::Operand::Reg(id) => ast::Operand::Reg(self(desc.new_op(id))),
ast::Operand::Imm(imm) => ast::Operand::Imm(imm),
ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id, false, t), imm),
ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(desc.new_op(id)), imm),
}
}
fn src_mov_operand(
&mut self,
x: ast::MovOperand<spirv::Word>,
t: Option<ast::Type>,
desc: ArgumentDescriptor<ast::MovOperand<spirv::Word>>,
) -> ast::MovOperand<spirv::Word> {
match x {
match desc.op {
ast::MovOperand::Op(op) => ast::MovOperand::Op(ArgumentMapVisitor::<
NormalizedArgParams,
NormalizedArgParams,
>::src_operand(self, op, t)),
>::src_operand(
self, desc.new_op(op)
)),
ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2),
}
}
@ -1202,8 +1415,8 @@ where
fn reduced_visitor<'a>(
f: &'a mut impl FnMut(spirv::Word) -> spirv::Word,
) -> impl FnMut(spirv::Word, bool, Option<ast::Type>) -> spirv::Word + 'a {
move |id, _, _| f(id)
) -> impl FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word + 'a {
move |desc| f(desc.op)
}
impl ast::Instruction<ExpandedArgParams> {
@ -1212,7 +1425,7 @@ impl ast::Instruction<ExpandedArgParams> {
self.map(&mut visitor)
}
fn visit_variable_extended<F: FnMut(spirv::Word, bool, Option<ast::Type>) -> spirv::Word>(
fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
self,
f: &mut F,
) -> Self {
@ -1326,7 +1539,12 @@ impl<T: ast::ArgParams> ast::Arg1<T> {
t: Option<ast::Type>,
) -> ast::Arg1<U> {
ast::Arg1 {
src: visitor.dst_variable(self.src, t),
src: visitor.dst_variable(ArgumentDescriptor {
op: self.src,
typ: t,
is_dst: false,
is_pointer: false,
}),
}
}
}
@ -1338,8 +1556,61 @@ impl<T: ast::ArgParams> ast::Arg2<T> {
t: Option<ast::Type>,
) -> ast::Arg2<U> {
ast::Arg2 {
dst: visitor.dst_variable(self.dst, t),
src: visitor.src_operand(self.src, t),
dst: visitor.dst_variable(ArgumentDescriptor {
op: self.dst,
typ: t,
is_dst: true,
is_pointer: false,
}),
src: visitor.src_operand(ArgumentDescriptor {
op: self.src,
typ: t,
is_dst: false,
is_pointer: false,
}),
}
}
fn map_ld<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
) -> ast::Arg2<U> {
ast::Arg2 {
dst: visitor.dst_variable(ArgumentDescriptor {
op: self.dst,
typ: t,
is_dst: true,
is_pointer: false,
}),
src: visitor.src_operand(ArgumentDescriptor {
op: self.src,
typ: t,
is_dst: false,
is_pointer: true,
}),
}
}
fn map_cvt<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
dst_t: ast::Type,
src_t: ast::Type,
) -> ast::Arg2<U> {
ast::Arg2 {
dst: visitor.dst_variable(ArgumentDescriptor {
op: self.dst,
typ: Some(dst_t),
is_dst: true,
is_pointer: false,
}),
src: visitor.src_operand(ArgumentDescriptor {
op: self.src,
typ: Some(src_t),
is_dst: false,
is_pointer: false,
}),
}
}
}
@ -1351,8 +1622,18 @@ impl<T: ast::ArgParams> ast::Arg2St<T> {
t: Option<ast::Type>,
) -> ast::Arg2St<U> {
ast::Arg2St {
src1: visitor.src_operand(self.src1, t),
src2: visitor.src_operand(self.src2, t),
src1: visitor.src_operand(ArgumentDescriptor {
op: self.src1,
typ: t,
is_dst: false,
is_pointer: true,
}),
src2: visitor.src_operand(ArgumentDescriptor {
op: self.src2,
typ: t,
is_dst: false,
is_pointer: false,
}),
}
}
}
@ -1364,8 +1645,18 @@ impl<T: ast::ArgParams> ast::Arg2Mov<T> {
t: Option<ast::Type>,
) -> ast::Arg2Mov<U> {
ast::Arg2Mov {
dst: visitor.dst_variable(self.dst, t),
src: visitor.src_mov_operand(self.src, t),
dst: visitor.dst_variable(ArgumentDescriptor {
op: self.dst,
typ: t,
is_dst: true,
is_pointer: false,
}),
src: visitor.src_mov_operand(ArgumentDescriptor {
op: self.src,
typ: t,
is_dst: false,
is_pointer: false,
}),
}
}
}
@ -1377,9 +1668,24 @@ impl<T: ast::ArgParams> ast::Arg3<T> {
t: Option<ast::Type>,
) -> ast::Arg3<U> {
ast::Arg3 {
dst: visitor.dst_variable(self.dst, t),
src1: visitor.src_operand(self.src1, t),
src2: visitor.src_operand(self.src2, t),
dst: visitor.dst_variable(ArgumentDescriptor {
op: self.dst,
typ: t,
is_dst: true,
is_pointer: false,
}),
src1: visitor.src_operand(ArgumentDescriptor {
op: self.src1,
typ: t,
is_dst: false,
is_pointer: false,
}),
src2: visitor.src_operand(ArgumentDescriptor {
op: self.src2,
typ: t,
is_dst: false,
is_pointer: false,
}),
}
}
@ -1389,9 +1695,24 @@ impl<T: ast::ArgParams> ast::Arg3<T> {
t: Option<ast::Type>,
) -> ast::Arg3<U> {
ast::Arg3 {
dst: visitor.dst_variable(self.dst, t),
src1: visitor.src_operand(self.src1, t),
src2: visitor.src_operand(self.src2, Some(ast::Type::Scalar(ast::ScalarType::U32))),
dst: visitor.dst_variable(ArgumentDescriptor {
op: self.dst,
typ: t,
is_dst: true,
is_pointer: false,
}),
src1: visitor.src_operand(ArgumentDescriptor {
op: self.src1,
typ: t,
is_dst: false,
is_pointer: false,
}),
src2: visitor.src_operand(ArgumentDescriptor {
op: self.src2,
typ: Some(ast::Type::Scalar(ast::ScalarType::U32)),
is_dst: false,
is_pointer: false,
}),
}
}
}
@ -1403,18 +1724,32 @@ impl<T: ast::ArgParams> ast::Arg4<T> {
t: Option<ast::Type>,
) -> ast::Arg4<U> {
ast::Arg4 {
dst1: visitor.dst_variable(
self.dst1,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
),
dst2: self.dst2.map(|dst2| {
visitor.dst_variable(
dst2,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
dst1: visitor.dst_variable(ArgumentDescriptor {
op: self.dst1,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
is_dst: true,
is_pointer: false,
}),
dst2: self.dst2.map(|dst2| {
visitor.dst_variable(ArgumentDescriptor {
op: dst2,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
is_dst: true,
is_pointer: false,
})
}),
src1: visitor.src_operand(ArgumentDescriptor {
op: self.src1,
typ: t,
is_dst: false,
is_pointer: false,
}),
src2: visitor.src_operand(ArgumentDescriptor {
op: self.src2,
typ: t,
is_dst: false,
is_pointer: false,
}),
src1: visitor.src_operand(self.src1, t),
src2: visitor.src_operand(self.src2, t),
}
}
}
@ -1426,22 +1761,38 @@ impl<T: ast::ArgParams> ast::Arg5<T> {
t: Option<ast::Type>,
) -> ast::Arg5<U> {
ast::Arg5 {
dst1: visitor.dst_variable(
self.dst1,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
),
dst2: self.dst2.map(|dst2| {
visitor.dst_variable(
dst2,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
dst1: visitor.dst_variable(ArgumentDescriptor {
op: self.dst1,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
is_dst: true,
is_pointer: false,
}),
dst2: self.dst2.map(|dst2| {
visitor.dst_variable(ArgumentDescriptor {
op: dst2,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
is_dst: true,
is_pointer: false,
})
}),
src1: visitor.src_operand(ArgumentDescriptor {
op: self.src1,
typ: t,
is_dst: false,
is_pointer: false,
}),
src2: visitor.src_operand(ArgumentDescriptor {
op: self.src2,
typ: t,
is_dst: false,
is_pointer: false,
}),
src3: visitor.src_operand(ArgumentDescriptor {
op: self.src3,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
is_dst: false,
is_pointer: false,
}),
src1: visitor.src_operand(self.src1, t),
src2: visitor.src_operand(self.src2, t),
src3: visitor.src_operand(
self.src3,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
),
}
}
}
@ -1851,34 +2202,34 @@ fn insert_implicit_bitcasts(
instr: ast::Instruction<ExpandedArgParams>,
) {
let mut dst_coercion = None;
let instr = instr.visit_variable_extended(&mut |mut id, is_dst, id_type| {
let id_type_from_instr = match id_type {
let instr = instr.visit_variable_extended(&mut |mut desc| {
let id_type_from_instr = match desc.typ {
Some(t) => t,
None => return id,
None => return desc.op,
};
let id_actual_type = id_def.get_type(id);
if should_bitcast(id_type_from_instr, id_def.get_type(id)) {
if is_dst {
let id_actual_type = id_def.get_type(desc.op);
if should_bitcast(id_type_from_instr, id_def.get_type(desc.op)) {
if desc.is_dst {
dst_coercion = Some(get_conversion_dst(
id_def,
&mut id,
&mut desc.op,
id_type_from_instr,
id_actual_type,
ConversionKind::Default,
));
id
desc.op
} else {
insert_conversion_src(
func,
id_def,
id,
desc.op,
id_actual_type,
id_type_from_instr,
ConversionKind::Default,
)
}
} else {
id
desc.op
}
});
func.push(Statement::Instruction(instr));