mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Finish implementing implicit conversions
This commit is contained in:
parent
4a0edf0e14
commit
279e6246ba
6 changed files with 418 additions and 138 deletions
|
@ -1,5 +1,7 @@
|
|||
|
||||
/// Build script: emits the Cargo directives needed to link against the
/// `ze_loader` dynamic library.
fn main() {
    println!("cargo:rustc-link-lib=dylib=ze_loader");
    // The System32 search path is only meaningful when the *target* is Windows.
    // Build scripts are compiled for the host, so `cfg!(windows)` would be the
    // wrong test; Cargo exports the target OS through `CARGO_CFG_TARGET_OS`.
    let targets_windows = std::env::var("CARGO_CFG_TARGET_OS")
        .map(|os| os == "windows")
        .unwrap_or(false);
    if targets_windows {
        println!("cargo:rustc-link-search=native=C:\\Windows\\System32");
    }
    println!("cargo:rerun-if-changed=build.rs");
}
|
|
@ -137,7 +137,7 @@ pub enum Instruction<ID> {
|
|||
Bra(BraData, Arg1<ID>),
|
||||
Cvt(CvtData, Arg2<ID>),
|
||||
Shl(ShlData, Arg3<ID>),
|
||||
St(StData, Arg2<ID>),
|
||||
St(StData, Arg2St<ID>),
|
||||
Ret(RetData),
|
||||
}
|
||||
|
||||
|
@ -150,6 +150,11 @@ pub struct Arg2<ID> {
|
|||
pub src: Operand<ID>,
|
||||
}
|
||||
|
||||
/// Argument pack of an `st` instruction. Unlike `Arg2`, it has no plain
/// destination id: both positions are full operands.
pub struct Arg2St<ID> {
    /// First operand: the one written inside `[...]` in the `st` grammar rule.
    pub src1: Operand<ID>,
    /// Second operand: the one following the comma in the `st` grammar rule.
    pub src2: Operand<ID>,
}
|
||||
|
||||
pub struct Arg2Mov<ID> {
|
||||
pub dst: ID,
|
||||
pub src: MovOperand<ID>,
|
||||
|
@ -264,7 +269,7 @@ pub struct StData {
|
|||
pub typ: ScalarType,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
pub enum StStateSpace {
|
||||
Generic,
|
||||
Global,
|
||||
|
|
|
@ -12,8 +12,7 @@ extern crate rspirv;
|
|||
extern crate spirv_headers as spirv;
|
||||
|
||||
lalrpop_mod!(
|
||||
#[allow(dead_code)]
|
||||
#[allow(unused_imports)]
|
||||
#[allow(warnings)]
|
||||
ptx
|
||||
);
|
||||
|
||||
|
|
|
@ -386,7 +386,7 @@ ShlType = {
|
|||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
|
||||
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
|
||||
InstSt: ast::Instruction<&'input str> = {
|
||||
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <dst:ID> "]" "," <src:Operand> => {
|
||||
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <src1:Operand> "]" "," <src2:Operand> => {
|
||||
ast::Instruction::St(
|
||||
ast::StData {
|
||||
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
|
||||
|
@ -395,7 +395,7 @@ InstSt: ast::Instruction<&'input str> = {
|
|||
vector: v,
|
||||
typ: t
|
||||
},
|
||||
ast::Arg2{dst:dst, src:src}
|
||||
ast::Arg2St { src1:src1, src2:src2 }
|
||||
)
|
||||
}
|
||||
};
|
||||
|
|
|
@ -57,7 +57,7 @@ fn test_ptx_assert<'a, T: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq>(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn run_spirv<T: From<u8> + ze::SafeRepr + Copy>(
|
||||
fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
|
||||
name: &CStr,
|
||||
spirv: &[u32],
|
||||
input: &[T],
|
||||
|
@ -84,15 +84,16 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy>(
|
|||
let event_pool = ze::EventPool::new(&drv, 3, Some(&[&dev]))?;
|
||||
let ev0 = ze::Event::new(&event_pool, 0)?;
|
||||
let ev1 = ze::Event::new(&event_pool, 1)?;
|
||||
let ev2 = ze::Event::new(&event_pool, 2)?;
|
||||
let mut cmd_list = ze::CommandList::new(&dev)?;
|
||||
let out_b_ptr: ze::BufferPtrMut<T> = (&mut out_b).into();
|
||||
let out_b_ptr_mut: ze::BufferPtrMut<T> = (&mut out_b).into();
|
||||
cmd_list.append_memory_copy(inp_b_ptr_mut, input, None, Some(&ev0))?;
|
||||
cmd_list.append_memory_fill(out_b_ptr, 0u8.into(), Some(&ev1))?;
|
||||
cmd_list.append_memory_fill(out_b_ptr_mut, 0u8.into(), Some(&ev1))?;
|
||||
kernel.set_group_size(1, 1, 1)?;
|
||||
kernel.set_arg_buffer(0, inp_b_ptr_mut)?;
|
||||
kernel.set_arg_buffer(1, out_b_ptr)?;
|
||||
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], None, &[&ev0, &ev1])?;
|
||||
cmd_list.append_memory_copy(result.as_mut_slice(), inp_b_ptr_mut, None, Some(&ev0))?;
|
||||
kernel.set_arg_buffer(1, out_b_ptr_mut)?;
|
||||
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &[&ev0, &ev1])?;
|
||||
cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, Some(&ev2))?;
|
||||
queue.execute(cmd_list)?;
|
||||
Ok(result)
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ use bit_vec::BitVec;
|
|||
use rspirv::dr;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::{borrow::Cow, fmt};
|
||||
use std::{borrow::Cow, fmt, mem};
|
||||
|
||||
use rspirv::binary::{Assemble, Disassemble};
|
||||
|
||||
|
@ -86,7 +86,7 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, dr::Error> {
|
|||
emit_function(&mut builder, &mut map, f)?;
|
||||
}
|
||||
let module = builder.module();
|
||||
dbg!(print!("{}", module.disassemble()));
|
||||
println!("{}", module.disassemble());
|
||||
Ok(module.assemble())
|
||||
}
|
||||
|
||||
|
@ -206,8 +206,8 @@ fn collect_var_definitions<'a>(
|
|||
documented special ld/st/cvt conversion rules for destination operands
|
||||
- generic ld: for instruction `ld x, [y]`, y must be of type b64/u64/s64,
|
||||
which is bitcast to a pointer, dereferenced and then documented special
|
||||
ld/st/cvt conversion rules are applied
|
||||
- generic ld: for instruction `ld [x], y`, x must be of type b64/u64/s64,
|
||||
ld/st/cvt conversion rules are applied to dst
|
||||
- generic st: for instruction `st [x], y`, x must be of type b64/u64/s64,
|
||||
which is bitcast to a pointer
|
||||
*/
|
||||
fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
|
||||
|
@ -226,41 +226,56 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
|
|||
match s {
|
||||
Statement::Instruction(inst) => match inst {
|
||||
ast::Instruction::Ld(ld, mut arg) => {
|
||||
let new_arg_src = arg.src.map_id(&mut |arg_src| {
|
||||
arg.src = arg.src.map_id(&mut |arg_src| {
|
||||
insert_implicit_conversions_ld_src(
|
||||
&mut result,
|
||||
ast::Type::Scalar(ld.typ),
|
||||
type_check,
|
||||
new_id,
|
||||
|instr, op| ld.state_space.should_convert(instr, op),
|
||||
ld.state_space,
|
||||
arg_src,
|
||||
)
|
||||
});
|
||||
arg.src = new_arg_src;
|
||||
insert_implicit_bitcasts(
|
||||
false,
|
||||
true,
|
||||
insert_with_implicit_conversion_dst(
|
||||
&mut result,
|
||||
ld.typ,
|
||||
type_check,
|
||||
new_id,
|
||||
ast::Instruction::Ld(ld, arg),
|
||||
should_convert_relaxed_dst,
|
||||
arg,
|
||||
|arg| &mut arg.dst,
|
||||
|arg| ast::Instruction::Ld(ld, arg),
|
||||
);
|
||||
}
|
||||
ast::Instruction::St(st, mut arg) => {
|
||||
let arg_dst_type = type_check(arg.dst);
|
||||
let new_dst = new_id();
|
||||
result.push(Statement::Converison(ImplicitConversion{
|
||||
src: arg.dst,
|
||||
dst: new_dst,
|
||||
from: arg_dst_type,
|
||||
to: ast::Type::Scalar(st.typ),
|
||||
kind: ConversionKind::Ptr
|
||||
}));
|
||||
arg.dst = new_dst;
|
||||
}
|
||||
inst @ _ => {
|
||||
insert_implicit_bitcasts(true, true, &mut result, type_check, new_id, inst)
|
||||
arg.src2 = arg.src2.map_id(&mut |arg_src| {
|
||||
let arg_src_type = type_check(arg_src);
|
||||
if let Some(conv) = should_convert_relaxed_src(arg_src_type, st.typ) {
|
||||
insert_conversion_src(
|
||||
&mut result,
|
||||
new_id,
|
||||
arg_src,
|
||||
arg_src_type,
|
||||
ast::Type::Scalar(st.typ),
|
||||
conv,
|
||||
)
|
||||
} else {
|
||||
arg_src
|
||||
}
|
||||
});
|
||||
arg.src1 = arg.src1.map_id(&mut |arg_src| {
|
||||
insert_implicit_conversions_ld_src(
|
||||
&mut result,
|
||||
ast::Type::Scalar(st.typ),
|
||||
type_check,
|
||||
new_id,
|
||||
st.state_space.to_ld_ss(),
|
||||
arg_src,
|
||||
)
|
||||
});
|
||||
result.push(Statement::Instruction(ast::Instruction::St(st, arg)));
|
||||
}
|
||||
inst @ _ => insert_implicit_bitcasts(&mut result, type_check, new_id, inst),
|
||||
},
|
||||
s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s),
|
||||
Statement::Converison(_) => unreachable!(),
|
||||
|
@ -386,11 +401,15 @@ fn emit_function_body_ops(
|
|||
{
|
||||
todo!()
|
||||
}
|
||||
let src = match arg.src {
|
||||
let dst = match arg.src1 {
|
||||
ast::Operand::Reg(id) => id,
|
||||
_ => todo!(),
|
||||
};
|
||||
builder.store(arg.dst, src, None, &[])?;
|
||||
let src = match arg.src2 {
|
||||
ast::Operand::Reg(id) => id,
|
||||
_ => todo!(),
|
||||
};
|
||||
builder.store(dst, src, None, &[])?;
|
||||
}
|
||||
// SPIR-V does not support ret as guaranteed-converged
|
||||
ast::Instruction::Ret(_) => builder.ret()?,
|
||||
|
@ -417,17 +436,18 @@ fn emit_implicit_conversion(
|
|||
builder,
|
||||
SpirvType::Pointer(to_type, spirv_headers::StorageClass::Generic),
|
||||
);
|
||||
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
|
||||
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
|
||||
}
|
||||
ConversionKind::Default => {
|
||||
if from_type.width() == to_type.width() {
|
||||
let dst_type = map.get_or_add_scalar(builder, to_type);
|
||||
if from_type.kind() == ScalarKind::Unsigned && to_type.kind() == ScalarKind::Byte
|
||||
|| from_type.kind() == ScalarKind::Byte
|
||||
&& to_type.kind() == ScalarKind::Unsigned
|
||||
{
|
||||
return Ok(());
|
||||
// It is noop, but another instruction expects result of this conversion
|
||||
builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
|
||||
}
|
||||
let dst_type = map.get_or_add_scalar(builder, to_type);
|
||||
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
|
||||
} else {
|
||||
let as_unsigned_type = map.get_or_add_scalar(
|
||||
|
@ -1025,8 +1045,10 @@ struct ImplicitConversion {
|
|||
kind: ConversionKind,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
enum ConversionKind {
|
||||
Default, // zero-extend/chop/bitcast depending on types
|
||||
Default,
|
||||
// zero-extend/chop/bitcast depending on types
|
||||
SignExtend,
|
||||
Ptr,
|
||||
}
|
||||
|
@ -1136,10 +1158,7 @@ impl<T> ast::Instruction<T> {
|
|||
ast::Instruction::Not(_, a) => a.visit_id(f),
|
||||
ast::Instruction::Cvt(_, a) => a.visit_id(f),
|
||||
ast::Instruction::Shl(_, a) => a.visit_id(f),
|
||||
ast::Instruction::St(_, a) => {
|
||||
f(false, &a.dst);
|
||||
a.src.visit_id(f);
|
||||
}
|
||||
ast::Instruction::St(_, a) => a.visit_id(f),
|
||||
ast::Instruction::Bra(_, a) => a.visit_id(f),
|
||||
ast::Instruction::Ret(_) => (),
|
||||
}
|
||||
|
@ -1156,10 +1175,7 @@ impl<T> ast::Instruction<T> {
|
|||
ast::Instruction::Not(_, a) => a.visit_id_mut(f),
|
||||
ast::Instruction::Cvt(_, a) => a.visit_id_mut(f),
|
||||
ast::Instruction::Shl(_, a) => a.visit_id_mut(f),
|
||||
ast::Instruction::St(_, a) => {
|
||||
f(false, &mut a.dst);
|
||||
a.src.visit_id_mut(f);
|
||||
}
|
||||
ast::Instruction::St(_, a) => a.visit_id_mut(f),
|
||||
ast::Instruction::Bra(_, a) => a.visit_id_mut(f),
|
||||
ast::Instruction::Ret(_) => (),
|
||||
}
|
||||
|
@ -1245,6 +1261,25 @@ impl<T> ast::Arg2<T> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Identifier traversal helpers for the two-operand `st` argument pack.
/// Every method handles `src1` first and `src2` second.
impl<T> ast::Arg2St<T> {
    /// Consumes the pack and rebuilds it with each operand's ids mapped
    /// through `f`.
    fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2St<U> {
        let ast::Arg2St { src1, src2 } = self;
        let src1 = src1.map_id(f);
        let src2 = src2.map_id(f);
        ast::Arg2St { src1, src2 }
    }

    /// Forwards the read-only visitor `f` to both operands.
    fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
        let ast::Arg2St { src1, src2 } = self;
        src1.visit_id(f);
        src2.visit_id(f);
    }

    /// Forwards the mutating visitor `f` to both operands.
    fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
        let ast::Arg2St { src1, src2 } = self;
        src1.visit_id_mut(f);
        src2.visit_id_mut(f);
    }
}
|
||||
|
||||
impl<T> ast::Arg2Mov<T> {
|
||||
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2Mov<U> {
|
||||
ast::Arg2Mov {
|
||||
|
@ -1388,6 +1423,18 @@ impl<T> ast::MovOperand<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl ast::StStateSpace {
    /// Translates a store state space into the load state space that carries
    /// the same variant name.
    fn to_ld_ss(self) -> ast::LdStateSpace {
        type Ld = ast::LdStateSpace;
        match self {
            Self::Generic => Ld::Generic,
            Self::Global => Ld::Global,
            Self::Local => Ld::Local,
            Self::Param => Ld::Param,
            Self::Shared => Ld::Shared,
        }
    }
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq)]
|
||||
enum ScalarKind {
|
||||
Byte,
|
||||
|
@ -1491,75 +1538,200 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
|||
}
|
||||
}
|
||||
|
||||
impl ast::LdStateSpace {
|
||||
fn should_convert(self, instr_type: ast::Type, op_type: ast::Type) -> Option<ConversionKind> {
|
||||
match self {
|
||||
ast::LdStateSpace::Param => {
|
||||
if instr_type != op_type {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::LdStateSpace::Generic => Some(ConversionKind::Ptr),
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_forced_bitcast_src<
|
||||
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||
NewId: FnMut() -> spirv::Word,
|
||||
>(
|
||||
func: &mut Vec<Statement>,
|
||||
op_type: ast::Type,
|
||||
type_check: &TypeCheck,
|
||||
new_id: &mut NewId,
|
||||
src: spirv::Word,
|
||||
) -> spirv::Word {
|
||||
let src_type = type_check(src);
|
||||
if src_type == op_type {
|
||||
return src;
|
||||
}
|
||||
let new_src = new_id();
|
||||
func.push(Statement::Converison(ImplicitConversion {
|
||||
src: src,
|
||||
dst: new_src,
|
||||
from: src_type,
|
||||
to: op_type,
|
||||
kind: ConversionKind::Default,
|
||||
}));
|
||||
new_src
|
||||
}
|
||||
|
||||
fn insert_implicit_conversions_ld_src<
|
||||
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||
NewId: FnMut() -> spirv::Word,
|
||||
ShouldConvert: Fn(ast::Type, ast::Type) -> Option<ConversionKind>,
|
||||
>(
|
||||
func: &mut Vec<Statement>,
|
||||
instr_type: ast::Type,
|
||||
type_check: &TypeCheck,
|
||||
new_id: &mut NewId,
|
||||
should_convert: ShouldConvert,
|
||||
state_space: ast::LdStateSpace,
|
||||
src: spirv::Word,
|
||||
) -> spirv::Word {
|
||||
match state_space {
|
||||
ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl(
|
||||
func,
|
||||
type_check,
|
||||
new_id,
|
||||
instr_type,
|
||||
src,
|
||||
should_convert_ld_param_src,
|
||||
),
|
||||
ast::LdStateSpace::Generic => {
|
||||
let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
|
||||
mem::size_of::<usize>() as u8,
|
||||
ScalarKind::Byte,
|
||||
));
|
||||
let new_src = insert_implicit_conversions_ld_src_impl(
|
||||
func,
|
||||
type_check,
|
||||
new_id,
|
||||
new_src_type,
|
||||
src,
|
||||
should_convert_ld_generic_src_to_bitcast,
|
||||
);
|
||||
insert_conversion_src(
|
||||
func,
|
||||
new_id,
|
||||
new_src,
|
||||
new_src_type,
|
||||
instr_type,
|
||||
ConversionKind::Ptr,
|
||||
)
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_implicit_conversions_ld_src_impl<
|
||||
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||
NewId: FnMut() -> spirv::Word,
|
||||
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
|
||||
>(
|
||||
func: &mut Vec<Statement>,
|
||||
type_check: &TypeCheck,
|
||||
new_id: &mut NewId,
|
||||
instr_type: ast::Type,
|
||||
src: spirv::Word,
|
||||
should_convert: ShouldConvert,
|
||||
) -> spirv::Word {
|
||||
let src_type = type_check(src);
|
||||
if let Some(conv_kind) = should_convert(src_type, instr_type) {
|
||||
let new_src = new_id();
|
||||
func.push(Statement::Converison(ImplicitConversion {
|
||||
src: src,
|
||||
dst: new_src,
|
||||
from: src_type,
|
||||
to: instr_type,
|
||||
kind: conv_kind,
|
||||
}));
|
||||
new_src
|
||||
if let Some(conv) = should_convert(src_type, instr_type) {
|
||||
insert_conversion_src(func, new_id, src, src_type, instr_type, conv)
|
||||
} else {
|
||||
src
|
||||
}
|
||||
}
|
||||
|
||||
fn should_convert_ld_param_src(
|
||||
src_type: ast::Type,
|
||||
instr_type: ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if src_type != instr_type {
|
||||
return Some(ConversionKind::Default);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// HACK ALERT
|
||||
// IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
|
||||
// additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
|
||||
fn should_convert_ld_generic_src_to_bitcast(
|
||||
src_type: ast::Type,
|
||||
_instr_type: ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if let ast::Type::Scalar(src_type) = src_type {
|
||||
if src_type.kind() == ScalarKind::Signed {
|
||||
return Some(ConversionKind::Default);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn insert_conversion_src<NewId: FnMut() -> spirv::Word>(
|
||||
func: &mut Vec<Statement>,
|
||||
new_id: &mut NewId,
|
||||
src: spirv::Word,
|
||||
src_type: ast::Type,
|
||||
instr_type: ast::Type,
|
||||
conv: ConversionKind,
|
||||
) -> spirv::Word {
|
||||
let temp_src = new_id();
|
||||
func.push(Statement::Converison(ImplicitConversion {
|
||||
src: src,
|
||||
dst: temp_src,
|
||||
from: src_type,
|
||||
to: instr_type,
|
||||
kind: conv,
|
||||
}));
|
||||
temp_src
|
||||
}
|
||||
|
||||
fn insert_with_implicit_conversion_dst<
|
||||
T,
|
||||
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||
NewId: FnMut() -> spirv::Word,
|
||||
ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
|
||||
Setter: Fn(&mut T) -> &mut spirv::Word,
|
||||
ToInstruction: FnOnce(T) -> ast::Instruction<spirv::Word>,
|
||||
>(
|
||||
func: &mut Vec<Statement>,
|
||||
instr_type: ast::ScalarType,
|
||||
type_check: &TypeCheck,
|
||||
new_id: &mut NewId,
|
||||
should_convert: ShouldConvert,
|
||||
mut t: T,
|
||||
setter: Setter,
|
||||
to_inst: ToInstruction,
|
||||
) {
|
||||
let dst = setter(&mut t);
|
||||
let dst_type = type_check(*dst);
|
||||
let dst_coercion = should_convert(dst_type, instr_type)
|
||||
.map(|conv| get_conversion_dst(new_id, dst, ast::Type::Scalar(instr_type), dst_type, conv));
|
||||
func.push(Statement::Instruction(to_inst(t)));
|
||||
if let Some(conv) = dst_coercion {
|
||||
func.push(conv);
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn get_conversion_dst<NewId: FnMut() -> spirv::Word>(
|
||||
new_id: &mut NewId,
|
||||
dst: &mut spirv::Word,
|
||||
instr_type: ast::Type,
|
||||
dst_type: ast::Type,
|
||||
kind: ConversionKind,
|
||||
) -> Statement {
|
||||
let original_dst = *dst;
|
||||
let temp_dst = new_id();
|
||||
*dst = temp_dst;
|
||||
Statement::Converison(ImplicitConversion {
|
||||
src: temp_dst,
|
||||
dst: original_dst,
|
||||
from: instr_type,
|
||||
to: dst_type,
|
||||
kind: kind,
|
||||
})
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
|
||||
fn should_convert_relaxed_src(
|
||||
src_type: ast::Type,
|
||||
instr_type: ast::ScalarType,
|
||||
) -> Option<ConversionKind> {
|
||||
if src_type == ast::Type::Scalar(instr_type) {
|
||||
return None;
|
||||
}
|
||||
match src_type {
|
||||
ast::Type::Scalar(src_type) => match instr_type.kind() {
|
||||
ScalarKind::Byte => {
|
||||
if instr_type.width() <= src_type.width() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Signed | ScalarKind::Unsigned => {
|
||||
if instr_type.width() <= src_type.width() && src_type.kind() != ScalarKind::Float {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Float => {
|
||||
if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Byte {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
|
||||
fn should_convert_relaxed_dst(
|
||||
dst_type: ast::Type,
|
||||
|
@ -1578,8 +1750,14 @@ fn should_convert_relaxed_dst(
|
|||
}
|
||||
}
|
||||
ScalarKind::Signed => {
|
||||
if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float {
|
||||
Some(ConversionKind::SignExtend)
|
||||
if dst_type.kind() != ScalarKind::Float {
|
||||
if instr_type.width() == dst_type.width() {
|
||||
Some(ConversionKind::Default)
|
||||
} else if instr_type.width() < dst_type.width() {
|
||||
Some(ConversionKind::SignExtend)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -1592,7 +1770,7 @@ fn should_convert_relaxed_dst(
|
|||
}
|
||||
}
|
||||
ScalarKind::Float => {
|
||||
if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Float {
|
||||
if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Byte {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
|
@ -1607,8 +1785,6 @@ fn insert_implicit_bitcasts<
|
|||
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||
NewId: FnMut() -> spirv::Word,
|
||||
>(
|
||||
do_src_bitcast: bool,
|
||||
do_dst_bitcast: bool,
|
||||
func: &mut Vec<Statement>,
|
||||
type_check: &TypeCheck,
|
||||
new_id: &mut NewId,
|
||||
|
@ -1617,37 +1793,32 @@ fn insert_implicit_bitcasts<
|
|||
let mut dst_coercion = None;
|
||||
if let Some(instr_type) = instr.get_type() {
|
||||
instr.visit_id_mut(&mut |is_dst, id| {
|
||||
if (is_dst && !do_dst_bitcast) || (!is_dst && !do_src_bitcast) {
|
||||
return;
|
||||
}
|
||||
let id_type = type_check(*id);
|
||||
if should_bitcast(instr_type, type_check(*id)) {
|
||||
let replacement_id = new_id();
|
||||
if is_dst {
|
||||
dst_coercion = Some(ImplicitConversion {
|
||||
src: replacement_id,
|
||||
dst: *id,
|
||||
from: instr_type,
|
||||
to: id_type,
|
||||
kind: ConversionKind::Default,
|
||||
});
|
||||
*id = replacement_id;
|
||||
dst_coercion = Some(get_conversion_dst(
|
||||
new_id,
|
||||
id,
|
||||
instr_type,
|
||||
id_type,
|
||||
ConversionKind::Default,
|
||||
));
|
||||
} else {
|
||||
func.push(Statement::Converison(ImplicitConversion {
|
||||
src: *id,
|
||||
dst: replacement_id,
|
||||
from: id_type,
|
||||
to: instr_type,
|
||||
kind: ConversionKind::Default,
|
||||
}));
|
||||
*id = replacement_id;
|
||||
*id = insert_conversion_src(
|
||||
func,
|
||||
new_id,
|
||||
*id,
|
||||
id_type,
|
||||
instr_type,
|
||||
ConversionKind::Default,
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
func.push(Statement::Instruction(instr));
|
||||
if let Some(cond) = dst_coercion {
|
||||
func.push(Statement::Converison(cond));
|
||||
func.push(cond);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1771,7 +1942,7 @@ mod tests {
|
|||
vec![BasicBlock {
|
||||
start: StmtIndex(0),
|
||||
pred: vec![],
|
||||
succ: vec![]
|
||||
succ: vec![],
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
@ -1791,7 +1962,7 @@ mod tests {
|
|||
vec![BasicBlock {
|
||||
start: StmtIndex(0),
|
||||
pred: vec![BBIndex(0)],
|
||||
succ: vec![BBIndex(0)]
|
||||
succ: vec![BBIndex(0)],
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
@ -2032,37 +2203,37 @@ mod tests {
|
|||
BasicBlock {
|
||||
start: StmtIndex(0),
|
||||
pred: vec![],
|
||||
succ: vec![BBIndex(1)]
|
||||
succ: vec![BBIndex(1)],
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(3),
|
||||
pred: vec![BBIndex(0), BBIndex(5)],
|
||||
succ: vec![BBIndex(2), BBIndex(6)]
|
||||
succ: vec![BBIndex(2), BBIndex(6)],
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(6),
|
||||
pred: vec![BBIndex(1)],
|
||||
succ: vec![BBIndex(3), BBIndex(4)]
|
||||
succ: vec![BBIndex(3), BBIndex(4)],
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(9),
|
||||
pred: vec![BBIndex(2)],
|
||||
succ: vec![BBIndex(5)]
|
||||
succ: vec![BBIndex(5)],
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(13),
|
||||
pred: vec![BBIndex(2)],
|
||||
succ: vec![BBIndex(5)]
|
||||
succ: vec![BBIndex(5)],
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(16),
|
||||
pred: vec![BBIndex(3), BBIndex(4)],
|
||||
succ: vec![BBIndex(1)]
|
||||
succ: vec![BBIndex(1)],
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(18),
|
||||
pred: vec![BBIndex(1)],
|
||||
succ: vec![]
|
||||
succ: vec![],
|
||||
},
|
||||
]
|
||||
);
|
||||
|
@ -2350,4 +2521,106 @@ mod tests {
|
|||
}
|
||||
panic!()
|
||||
}
|
||||
|
||||
// All scalar types exercised by the conversion-table tests. The position of a
// type in this array is used as both the row index (instruction type) and the
// column index (operand type) into a parsed conversion table.
static SCALAR_TYPES: [ast::ScalarType; 15] = [
    ast::ScalarType::B8,
    ast::ScalarType::B16,
    ast::ScalarType::B32,
    ast::ScalarType::B64,
    ast::ScalarType::S8,
    ast::ScalarType::S16,
    ast::ScalarType::S32,
    ast::ScalarType::S64,
    ast::ScalarType::U8,
    ast::ScalarType::U16,
    ast::ScalarType::U32,
    ast::ScalarType::U64,
    ast::ScalarType::F16,
    ast::ScalarType::F32,
    ast::ScalarType::F64,
];
|
||||
|
||||
// Expected results of `should_convert_relaxed_src`.
// One row per instruction type and one column per operand type, both in
// SCALAR_TYPES order. The first token of each row is only a label and is
// skipped by the parser. `inv` means no conversion is expected; `-` and `chop`
// both mean a default conversion. Diagonal cells are never consulted: the
// checker expects no conversion when both types are the same.
static RELAXED_SRC_CONVERSION_TABLE: &'static str =
"b8 - chop chop chop - chop chop chop - chop chop chop chop chop chop
b16 inv - chop chop inv - chop chop inv - chop chop - chop chop
b32 inv inv - chop inv inv - chop inv inv - chop inv - chop
b64 inv inv inv - inv inv inv - inv inv inv - inv inv -
s8 - chop chop chop - chop chop chop - chop chop chop inv inv inv
s16 inv - chop chop inv - chop chop inv - chop chop inv inv inv
s32 inv inv - chop inv inv - chop inv inv - chop inv inv inv
s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
u8 - chop chop chop - chop chop chop - chop chop chop inv inv inv
u16 inv - chop chop inv - chop chop inv - chop chop inv inv inv
u32 inv inv - chop inv inv - chop inv inv - chop inv inv inv
u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
f16 inv - chop chop inv inv inv inv inv inv inv inv - inv inv
f32 inv inv - chop inv inv inv inv inv inv inv inv inv - inv
f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -";
|
||||
|
||||
// Expected results of `should_convert_relaxed_dst`.
// One row per instruction type and one column per operand type, both in
// SCALAR_TYPES order. The first token of each row is only a label and is
// skipped by the parser. `inv` means no conversion is expected; `-` and `zext`
// both mean a default conversion; `sext` means a sign-extending conversion.
// Diagonal cells are never consulted: the checker expects no conversion when
// both types are the same.
static RELAXED_DST_CONVERSION_TABLE: &'static str =
"b8 - zext zext zext - zext zext zext - zext zext zext zext zext zext
b16 inv - zext zext inv - zext zext inv - zext zext - zext zext
b32 inv inv - zext inv inv - zext inv inv - zext inv - zext
b64 inv inv inv - inv inv inv - inv inv inv - inv inv -
s8 - sext sext sext - sext sext sext - sext sext sext inv inv inv
s16 inv - sext sext inv - sext sext inv - sext sext inv inv inv
s32 inv inv - sext inv inv - sext inv inv - sext inv inv inv
s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
u8 - zext zext zext - zext zext zext - zext zext zext inv inv inv
u16 inv - zext zext inv - zext zext inv - zext zext inv inv inv
u32 inv inv - zext inv inv - zext inv inv - zext inv inv inv
u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
f16 inv - zext zext inv inv inv inv inv inv inv inv - inv inv
f32 inv inv - zext inv inv inv inv inv inv inv inv inv - inv
f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -";
|
||||
|
||||
fn table_entry_to_conversion(entry: &'static str) -> Option<ConversionKind> {
|
||||
match entry {
|
||||
"-" => Some(ConversionKind::Default),
|
||||
"inv" => None,
|
||||
"zext" => Some(ConversionKind::Default),
|
||||
"chop" => Some(ConversionKind::Default),
|
||||
"sext" => Some(ConversionKind::SignExtend),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_conversion_table(table: &'static str) -> Vec<Vec<Option<ConversionKind>>> {
|
||||
table
|
||||
.lines()
|
||||
.map(|line| {
|
||||
line.split_ascii_whitespace()
|
||||
.skip(1)
|
||||
.map(table_entry_to_conversion)
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn assert_conversion_table<F: Fn(ast::Type, ast::ScalarType) -> Option<ConversionKind>>(
|
||||
table: &'static str,
|
||||
f: F,
|
||||
) {
|
||||
let conv_table = parse_conversion_table(table);
|
||||
for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() {
|
||||
for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() {
|
||||
let conversion = f(ast::Type::Scalar(*op_type), *instr_type);
|
||||
if instr_idx == op_idx {
|
||||
assert_eq!(conversion, None);
|
||||
} else {
|
||||
assert_eq!(conversion, conv_table[instr_idx][op_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Checks `should_convert_relaxed_src` against RELAXED_SRC_CONVERSION_TABLE for
// every instruction/operand scalar type pair.
#[test]
fn should_convert_relaxed_src_all_combinations() {
    assert_conversion_table(RELAXED_SRC_CONVERSION_TABLE, should_convert_relaxed_src);
}
|
||||
|
||||
// Checks `should_convert_relaxed_dst` against RELAXED_DST_CONVERSION_TABLE for
// every instruction/operand scalar type pair.
#[test]
fn should_convert_relaxed_dst_all_combinations() {
    assert_conversion_table(RELAXED_DST_CONVERSION_TABLE, should_convert_relaxed_dst);
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue