Finish implementing implicit conversions

This commit is contained in:
Andrzej Janik 2020-06-17 02:53:46 +02:00
parent 4a0edf0e14
commit 279e6246ba
6 changed files with 418 additions and 138 deletions

View file

@ -1,5 +1,7 @@
fn main() {
println!("cargo:rustc-link-lib=dylib=ze_loader");
// TODO: make this windows-only
println!("cargo:rustc-link-search=native=C:\\Windows\\System32");
println!("cargo:rerun-if-changed=build.rs");
}

View file

@ -137,7 +137,7 @@ pub enum Instruction<ID> {
Bra(BraData, Arg1<ID>),
Cvt(CvtData, Arg2<ID>),
Shl(ShlData, Arg3<ID>),
St(StData, Arg2<ID>),
St(StData, Arg2St<ID>),
Ret(RetData),
}
@ -150,6 +150,11 @@ pub struct Arg2<ID> {
pub src: Operand<ID>,
}
pub struct Arg2St<ID> {
pub src1: Operand<ID>,
pub src2: Operand<ID>,
}
pub struct Arg2Mov<ID> {
pub dst: ID,
pub src: MovOperand<ID>,
@ -264,7 +269,7 @@ pub struct StData {
pub typ: ScalarType,
}
#[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum StStateSpace {
Generic,
Global,

View file

@ -12,8 +12,7 @@ extern crate rspirv;
extern crate spirv_headers as spirv;
lalrpop_mod!(
#[allow(dead_code)]
#[allow(unused_imports)]
#[allow(warnings)]
ptx
);

View file

@ -386,7 +386,7 @@ ShlType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
InstSt: ast::Instruction<&'input str> = {
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <dst:ID> "]" "," <src:Operand> => {
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <src1:Operand> "]" "," <src2:Operand> => {
ast::Instruction::St(
ast::StData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
@ -395,7 +395,7 @@ InstSt: ast::Instruction<&'input str> = {
vector: v,
typ: t
},
ast::Arg2{dst:dst, src:src}
ast::Arg2St { src1:src1, src2:src2 }
)
}
};

View file

@ -57,7 +57,7 @@ fn test_ptx_assert<'a, T: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq>(
Ok(())
}
fn run_spirv<T: From<u8> + ze::SafeRepr + Copy>(
fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
name: &CStr,
spirv: &[u32],
input: &[T],
@ -84,15 +84,16 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy>(
let event_pool = ze::EventPool::new(&drv, 3, Some(&[&dev]))?;
let ev0 = ze::Event::new(&event_pool, 0)?;
let ev1 = ze::Event::new(&event_pool, 1)?;
let ev2 = ze::Event::new(&event_pool, 2)?;
let mut cmd_list = ze::CommandList::new(&dev)?;
let out_b_ptr: ze::BufferPtrMut<T> = (&mut out_b).into();
let out_b_ptr_mut: ze::BufferPtrMut<T> = (&mut out_b).into();
cmd_list.append_memory_copy(inp_b_ptr_mut, input, None, Some(&ev0))?;
cmd_list.append_memory_fill(out_b_ptr, 0u8.into(), Some(&ev1))?;
cmd_list.append_memory_fill(out_b_ptr_mut, 0u8.into(), Some(&ev1))?;
kernel.set_group_size(1, 1, 1)?;
kernel.set_arg_buffer(0, inp_b_ptr_mut)?;
kernel.set_arg_buffer(1, out_b_ptr)?;
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], None, &[&ev0, &ev1])?;
cmd_list.append_memory_copy(result.as_mut_slice(), inp_b_ptr_mut, None, Some(&ev0))?;
kernel.set_arg_buffer(1, out_b_ptr_mut)?;
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &[&ev0, &ev1])?;
cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, Some(&ev2))?;
queue.execute(cmd_list)?;
Ok(result)
}

View file

@ -3,7 +3,7 @@ use bit_vec::BitVec;
use rspirv::dr;
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::{borrow::Cow, fmt};
use std::{borrow::Cow, fmt, mem};
use rspirv::binary::{Assemble, Disassemble};
@ -86,7 +86,7 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, dr::Error> {
emit_function(&mut builder, &mut map, f)?;
}
let module = builder.module();
dbg!(print!("{}", module.disassemble()));
println!("{}", module.disassemble());
Ok(module.assemble())
}
@ -206,8 +206,8 @@ fn collect_var_definitions<'a>(
documented special ld/st/cvt conversion rules for destination operands
- generic ld: for instruction `ld x, [y]`, y must be of type b64/u64/s64,
which is bitcast to a pointer, dereferenced and then documented special
ld/st/cvt conversion rules are applied
- generic ld: for instruction `ld [x], y`, x must be of type b64/u64/s64,
ld/st/cvt conversion rules are applied to dst
- generic st: for instruction `st [x], y`, x must be of type b64/u64/s64,
which is bitcast to a pointer
*/
fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
@ -226,41 +226,56 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
match s {
Statement::Instruction(inst) => match inst {
ast::Instruction::Ld(ld, mut arg) => {
let new_arg_src = arg.src.map_id(&mut |arg_src| {
arg.src = arg.src.map_id(&mut |arg_src| {
insert_implicit_conversions_ld_src(
&mut result,
ast::Type::Scalar(ld.typ),
type_check,
new_id,
|instr, op| ld.state_space.should_convert(instr, op),
ld.state_space,
arg_src,
)
});
arg.src = new_arg_src;
insert_implicit_bitcasts(
false,
true,
insert_with_implicit_conversion_dst(
&mut result,
ld.typ,
type_check,
new_id,
ast::Instruction::Ld(ld, arg),
should_convert_relaxed_dst,
arg,
|arg| &mut arg.dst,
|arg| ast::Instruction::Ld(ld, arg),
);
}
ast::Instruction::St(st, mut arg) => {
let arg_dst_type = type_check(arg.dst);
let new_dst = new_id();
result.push(Statement::Converison(ImplicitConversion{
src: arg.dst,
dst: new_dst,
from: arg_dst_type,
to: ast::Type::Scalar(st.typ),
kind: ConversionKind::Ptr
}));
arg.dst = new_dst;
}
inst @ _ => {
insert_implicit_bitcasts(true, true, &mut result, type_check, new_id, inst)
arg.src2 = arg.src2.map_id(&mut |arg_src| {
let arg_src_type = type_check(arg_src);
if let Some(conv) = should_convert_relaxed_src(arg_src_type, st.typ) {
insert_conversion_src(
&mut result,
new_id,
arg_src,
arg_src_type,
ast::Type::Scalar(st.typ),
conv,
)
} else {
arg_src
}
});
arg.src1 = arg.src1.map_id(&mut |arg_src| {
insert_implicit_conversions_ld_src(
&mut result,
ast::Type::Scalar(st.typ),
type_check,
new_id,
st.state_space.to_ld_ss(),
arg_src,
)
});
result.push(Statement::Instruction(ast::Instruction::St(st, arg)));
}
inst @ _ => insert_implicit_bitcasts(&mut result, type_check, new_id, inst),
},
s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s),
Statement::Converison(_) => unreachable!(),
@ -386,11 +401,15 @@ fn emit_function_body_ops(
{
todo!()
}
let src = match arg.src {
let dst = match arg.src1 {
ast::Operand::Reg(id) => id,
_ => todo!(),
};
builder.store(arg.dst, src, None, &[])?;
let src = match arg.src2 {
ast::Operand::Reg(id) => id,
_ => todo!(),
};
builder.store(dst, src, None, &[])?;
}
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
@ -417,17 +436,18 @@ fn emit_implicit_conversion(
builder,
SpirvType::Pointer(to_type, spirv_headers::StorageClass::Generic),
);
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
ConversionKind::Default => {
if from_type.width() == to_type.width() {
let dst_type = map.get_or_add_scalar(builder, to_type);
if from_type.kind() == ScalarKind::Unsigned && to_type.kind() == ScalarKind::Byte
|| from_type.kind() == ScalarKind::Byte
&& to_type.kind() == ScalarKind::Unsigned
{
return Ok(());
// It is noop, but another instruction expects result of this conversion
builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
}
let dst_type = map.get_or_add_scalar(builder, to_type);
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
} else {
let as_unsigned_type = map.get_or_add_scalar(
@ -1025,8 +1045,10 @@ struct ImplicitConversion {
kind: ConversionKind,
}
#[derive(Debug, PartialEq)]
enum ConversionKind {
Default, // zero-extend/chop/bitcast depending on types
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
Ptr,
}
@ -1136,10 +1158,7 @@ impl<T> ast::Instruction<T> {
ast::Instruction::Not(_, a) => a.visit_id(f),
ast::Instruction::Cvt(_, a) => a.visit_id(f),
ast::Instruction::Shl(_, a) => a.visit_id(f),
ast::Instruction::St(_, a) => {
f(false, &a.dst);
a.src.visit_id(f);
}
ast::Instruction::St(_, a) => a.visit_id(f),
ast::Instruction::Bra(_, a) => a.visit_id(f),
ast::Instruction::Ret(_) => (),
}
@ -1156,10 +1175,7 @@ impl<T> ast::Instruction<T> {
ast::Instruction::Not(_, a) => a.visit_id_mut(f),
ast::Instruction::Cvt(_, a) => a.visit_id_mut(f),
ast::Instruction::Shl(_, a) => a.visit_id_mut(f),
ast::Instruction::St(_, a) => {
f(false, &mut a.dst);
a.src.visit_id_mut(f);
}
ast::Instruction::St(_, a) => a.visit_id_mut(f),
ast::Instruction::Bra(_, a) => a.visit_id_mut(f),
ast::Instruction::Ret(_) => (),
}
@ -1245,6 +1261,25 @@ impl<T> ast::Arg2<T> {
}
}
impl<T> ast::Arg2St<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2St<U> {
ast::Arg2St {
src1: self.src1.map_id(f),
src2: self.src2.map_id(f),
}
}
fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
self.src1.visit_id(f);
self.src2.visit_id(f);
}
fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
self.src1.visit_id_mut(f);
self.src2.visit_id_mut(f);
}
}
impl<T> ast::Arg2Mov<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2Mov<U> {
ast::Arg2Mov {
@ -1388,6 +1423,18 @@ impl<T> ast::MovOperand<T> {
}
}
impl ast::StStateSpace {
fn to_ld_ss(self) -> ast::LdStateSpace {
match self {
ast::StStateSpace::Generic => ast::LdStateSpace::Generic,
ast::StStateSpace::Global => ast::LdStateSpace::Global,
ast::StStateSpace::Local => ast::LdStateSpace::Local,
ast::StStateSpace::Param => ast::LdStateSpace::Param,
ast::StStateSpace::Shared => ast::LdStateSpace::Shared,
}
}
}
#[derive(Clone, Copy, PartialEq)]
enum ScalarKind {
Byte,
@ -1491,75 +1538,200 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
}
}
impl ast::LdStateSpace {
fn should_convert(self, instr_type: ast::Type, op_type: ast::Type) -> Option<ConversionKind> {
match self {
ast::LdStateSpace::Param => {
if instr_type != op_type {
Some(ConversionKind::Default)
} else {
None
}
}
ast::LdStateSpace::Generic => Some(ConversionKind::Ptr),
_ => todo!(),
}
}
}
fn insert_forced_bitcast_src<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
>(
func: &mut Vec<Statement>,
op_type: ast::Type,
type_check: &TypeCheck,
new_id: &mut NewId,
src: spirv::Word,
) -> spirv::Word {
let src_type = type_check(src);
if src_type == op_type {
return src;
}
let new_src = new_id();
func.push(Statement::Converison(ImplicitConversion {
src: src,
dst: new_src,
from: src_type,
to: op_type,
kind: ConversionKind::Default,
}));
new_src
}
fn insert_implicit_conversions_ld_src<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
ShouldConvert: Fn(ast::Type, ast::Type) -> Option<ConversionKind>,
>(
func: &mut Vec<Statement>,
instr_type: ast::Type,
type_check: &TypeCheck,
new_id: &mut NewId,
should_convert: ShouldConvert,
state_space: ast::LdStateSpace,
src: spirv::Word,
) -> spirv::Word {
match state_space {
ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl(
func,
type_check,
new_id,
instr_type,
src,
should_convert_ld_param_src,
),
ast::LdStateSpace::Generic => {
let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
mem::size_of::<usize>() as u8,
ScalarKind::Byte,
));
let new_src = insert_implicit_conversions_ld_src_impl(
func,
type_check,
new_id,
new_src_type,
src,
should_convert_ld_generic_src_to_bitcast,
);
insert_conversion_src(
func,
new_id,
new_src,
new_src_type,
instr_type,
ConversionKind::Ptr,
)
}
_ => todo!(),
}
}
fn insert_implicit_conversions_ld_src_impl<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
>(
func: &mut Vec<Statement>,
type_check: &TypeCheck,
new_id: &mut NewId,
instr_type: ast::Type,
src: spirv::Word,
should_convert: ShouldConvert,
) -> spirv::Word {
let src_type = type_check(src);
if let Some(conv_kind) = should_convert(src_type, instr_type) {
let new_src = new_id();
func.push(Statement::Converison(ImplicitConversion {
src: src,
dst: new_src,
from: src_type,
to: instr_type,
kind: conv_kind,
}));
new_src
if let Some(conv) = should_convert(src_type, instr_type) {
insert_conversion_src(func, new_id, src, src_type, instr_type, conv)
} else {
src
}
}
fn should_convert_ld_param_src(
src_type: ast::Type,
instr_type: ast::Type,
) -> Option<ConversionKind> {
if src_type != instr_type {
return Some(ConversionKind::Default);
}
None
}
// HACK ALERT
// IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
// additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
fn should_convert_ld_generic_src_to_bitcast(
src_type: ast::Type,
_instr_type: ast::Type,
) -> Option<ConversionKind> {
if let ast::Type::Scalar(src_type) = src_type {
if src_type.kind() == ScalarKind::Signed {
return Some(ConversionKind::Default);
}
}
None
}
#[must_use]
fn insert_conversion_src<NewId: FnMut() -> spirv::Word>(
func: &mut Vec<Statement>,
new_id: &mut NewId,
src: spirv::Word,
src_type: ast::Type,
instr_type: ast::Type,
conv: ConversionKind,
) -> spirv::Word {
let temp_src = new_id();
func.push(Statement::Converison(ImplicitConversion {
src: src,
dst: temp_src,
from: src_type,
to: instr_type,
kind: conv,
}));
temp_src
}
fn insert_with_implicit_conversion_dst<
T,
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word,
ToInstruction: FnOnce(T) -> ast::Instruction<spirv::Word>,
>(
func: &mut Vec<Statement>,
instr_type: ast::ScalarType,
type_check: &TypeCheck,
new_id: &mut NewId,
should_convert: ShouldConvert,
mut t: T,
setter: Setter,
to_inst: ToInstruction,
) {
let dst = setter(&mut t);
let dst_type = type_check(*dst);
let dst_coercion = should_convert(dst_type, instr_type)
.map(|conv| get_conversion_dst(new_id, dst, ast::Type::Scalar(instr_type), dst_type, conv));
func.push(Statement::Instruction(to_inst(t)));
if let Some(conv) = dst_coercion {
func.push(conv);
}
}
#[must_use]
fn get_conversion_dst<NewId: FnMut() -> spirv::Word>(
new_id: &mut NewId,
dst: &mut spirv::Word,
instr_type: ast::Type,
dst_type: ast::Type,
kind: ConversionKind,
) -> Statement {
let original_dst = *dst;
let temp_dst = new_id();
*dst = temp_dst;
Statement::Converison(ImplicitConversion {
src: temp_dst,
dst: original_dst,
from: instr_type,
to: dst_type,
kind: kind,
})
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src(
src_type: ast::Type,
instr_type: ast::ScalarType,
) -> Option<ConversionKind> {
if src_type == ast::Type::Scalar(instr_type) {
return None;
}
match src_type {
ast::Type::Scalar(src_type) => match instr_type.kind() {
ScalarKind::Byte => {
if instr_type.width() <= src_type.width() {
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Signed | ScalarKind::Unsigned => {
if instr_type.width() <= src_type.width() && src_type.kind() != ScalarKind::Float {
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Float => {
if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Byte {
Some(ConversionKind::Default)
} else {
None
}
}
},
_ => None,
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: ast::Type,
@ -1578,8 +1750,14 @@ fn should_convert_relaxed_dst(
}
}
ScalarKind::Signed => {
if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float {
Some(ConversionKind::SignExtend)
if dst_type.kind() != ScalarKind::Float {
if instr_type.width() == dst_type.width() {
Some(ConversionKind::Default)
} else if instr_type.width() < dst_type.width() {
Some(ConversionKind::SignExtend)
} else {
None
}
} else {
None
}
@ -1592,7 +1770,7 @@ fn should_convert_relaxed_dst(
}
}
ScalarKind::Float => {
if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Float {
if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Byte {
Some(ConversionKind::Default)
} else {
None
@ -1607,8 +1785,6 @@ fn insert_implicit_bitcasts<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
>(
do_src_bitcast: bool,
do_dst_bitcast: bool,
func: &mut Vec<Statement>,
type_check: &TypeCheck,
new_id: &mut NewId,
@ -1617,37 +1793,32 @@ fn insert_implicit_bitcasts<
let mut dst_coercion = None;
if let Some(instr_type) = instr.get_type() {
instr.visit_id_mut(&mut |is_dst, id| {
if (is_dst && !do_dst_bitcast) || (!is_dst && !do_src_bitcast) {
return;
}
let id_type = type_check(*id);
if should_bitcast(instr_type, type_check(*id)) {
let replacement_id = new_id();
if is_dst {
dst_coercion = Some(ImplicitConversion {
src: replacement_id,
dst: *id,
from: instr_type,
to: id_type,
kind: ConversionKind::Default,
});
*id = replacement_id;
dst_coercion = Some(get_conversion_dst(
new_id,
id,
instr_type,
id_type,
ConversionKind::Default,
));
} else {
func.push(Statement::Converison(ImplicitConversion {
src: *id,
dst: replacement_id,
from: id_type,
to: instr_type,
kind: ConversionKind::Default,
}));
*id = replacement_id;
*id = insert_conversion_src(
func,
new_id,
*id,
id_type,
instr_type,
ConversionKind::Default,
);
}
}
});
}
func.push(Statement::Instruction(instr));
if let Some(cond) = dst_coercion {
func.push(Statement::Converison(cond));
func.push(cond);
}
}
@ -1771,7 +1942,7 @@ mod tests {
vec![BasicBlock {
start: StmtIndex(0),
pred: vec![],
succ: vec![]
succ: vec![],
}]
);
}
@ -1791,7 +1962,7 @@ mod tests {
vec![BasicBlock {
start: StmtIndex(0),
pred: vec![BBIndex(0)],
succ: vec![BBIndex(0)]
succ: vec![BBIndex(0)],
}]
);
}
@ -2032,37 +2203,37 @@ mod tests {
BasicBlock {
start: StmtIndex(0),
pred: vec![],
succ: vec![BBIndex(1)]
succ: vec![BBIndex(1)],
},
BasicBlock {
start: StmtIndex(3),
pred: vec![BBIndex(0), BBIndex(5)],
succ: vec![BBIndex(2), BBIndex(6)]
succ: vec![BBIndex(2), BBIndex(6)],
},
BasicBlock {
start: StmtIndex(6),
pred: vec![BBIndex(1)],
succ: vec![BBIndex(3), BBIndex(4)]
succ: vec![BBIndex(3), BBIndex(4)],
},
BasicBlock {
start: StmtIndex(9),
pred: vec![BBIndex(2)],
succ: vec![BBIndex(5)]
succ: vec![BBIndex(5)],
},
BasicBlock {
start: StmtIndex(13),
pred: vec![BBIndex(2)],
succ: vec![BBIndex(5)]
succ: vec![BBIndex(5)],
},
BasicBlock {
start: StmtIndex(16),
pred: vec![BBIndex(3), BBIndex(4)],
succ: vec![BBIndex(1)]
succ: vec![BBIndex(1)],
},
BasicBlock {
start: StmtIndex(18),
pred: vec![BBIndex(1)],
succ: vec![]
succ: vec![],
},
]
);
@ -2350,4 +2521,106 @@ mod tests {
}
panic!()
}
static SCALAR_TYPES: [ast::ScalarType; 15] = [
ast::ScalarType::B8,
ast::ScalarType::B16,
ast::ScalarType::B32,
ast::ScalarType::B64,
ast::ScalarType::S8,
ast::ScalarType::S16,
ast::ScalarType::S32,
ast::ScalarType::S64,
ast::ScalarType::U8,
ast::ScalarType::U16,
ast::ScalarType::U32,
ast::ScalarType::U64,
ast::ScalarType::F16,
ast::ScalarType::F32,
ast::ScalarType::F64,
];
static RELAXED_SRC_CONVERSION_TABLE: &'static str =
"b8 - chop chop chop - chop chop chop - chop chop chop chop chop chop
b16 inv - chop chop inv - chop chop inv - chop chop - chop chop
b32 inv inv - chop inv inv - chop inv inv - chop inv - chop
b64 inv inv inv - inv inv inv - inv inv inv - inv inv -
s8 - chop chop chop - chop chop chop - chop chop chop inv inv inv
s16 inv - chop chop inv - chop chop inv - chop chop inv inv inv
s32 inv inv - chop inv inv - chop inv inv - chop inv inv inv
s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
u8 - chop chop chop - chop chop chop - chop chop chop inv inv inv
u16 inv - chop chop inv - chop chop inv - chop chop inv inv inv
u32 inv inv - chop inv inv - chop inv inv - chop inv inv inv
u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
f16 inv - chop chop inv inv inv inv inv inv inv inv - inv inv
f32 inv inv - chop inv inv inv inv inv inv inv inv inv - inv
f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -";
static RELAXED_DST_CONVERSION_TABLE: &'static str =
"b8 - zext zext zext - zext zext zext - zext zext zext zext zext zext
b16 inv - zext zext inv - zext zext inv - zext zext - zext zext
b32 inv inv - zext inv inv - zext inv inv - zext inv - zext
b64 inv inv inv - inv inv inv - inv inv inv - inv inv -
s8 - sext sext sext - sext sext sext - sext sext sext inv inv inv
s16 inv - sext sext inv - sext sext inv - sext sext inv inv inv
s32 inv inv - sext inv inv - sext inv inv - sext inv inv inv
s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
u8 - zext zext zext - zext zext zext - zext zext zext inv inv inv
u16 inv - zext zext inv - zext zext inv - zext zext inv inv inv
u32 inv inv - zext inv inv - zext inv inv - zext inv inv inv
u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
f16 inv - zext zext inv inv inv inv inv inv inv inv - inv inv
f32 inv inv - zext inv inv inv inv inv inv inv inv inv - inv
f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -";
fn table_entry_to_conversion(entry: &'static str) -> Option<ConversionKind> {
match entry {
"-" => Some(ConversionKind::Default),
"inv" => None,
"zext" => Some(ConversionKind::Default),
"chop" => Some(ConversionKind::Default),
"sext" => Some(ConversionKind::SignExtend),
_ => unreachable!(),
}
}
fn parse_conversion_table(table: &'static str) -> Vec<Vec<Option<ConversionKind>>> {
table
.lines()
.map(|line| {
line.split_ascii_whitespace()
.skip(1)
.map(table_entry_to_conversion)
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
}
fn assert_conversion_table<F: Fn(ast::Type, ast::ScalarType) -> Option<ConversionKind>>(
table: &'static str,
f: F,
) {
let conv_table = parse_conversion_table(table);
for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() {
for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() {
let conversion = f(ast::Type::Scalar(*op_type), *instr_type);
if instr_idx == op_idx {
assert_eq!(conversion, None);
} else {
assert_eq!(conversion, conv_table[instr_idx][op_idx]);
}
}
}
}
#[test]
fn should_convert_relaxed_src_all_combinations() {
assert_conversion_table(RELAXED_SRC_CONVERSION_TABLE, should_convert_relaxed_src);
}
#[test]
fn should_convert_relaxed_dst_all_combinations() {
assert_conversion_table(RELAXED_DST_CONVERSION_TABLE, should_convert_relaxed_dst);
}
}