mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 08:24:44 +00:00
Start implementing implicit conversions
This commit is contained in:
parent
9f60990765
commit
4a0edf0e14
2 changed files with 455 additions and 154 deletions
|
@ -200,7 +200,7 @@ pub struct LdData {
|
|||
pub typ: ScalarType,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum LdStQualifier {
|
||||
Weak,
|
||||
Volatile,
|
||||
|
@ -208,14 +208,14 @@ pub enum LdStQualifier {
|
|||
Acquire(LdScope),
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum LdScope {
|
||||
Cta,
|
||||
Gpu,
|
||||
Sys,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum LdStateSpace {
|
||||
Generic,
|
||||
Const,
|
||||
|
@ -225,7 +225,7 @@ pub enum LdStateSpace {
|
|||
Shared,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum LdCacheOperator {
|
||||
Cached,
|
||||
L2Only,
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::cell::RefCell;
|
|||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::{borrow::Cow, fmt};
|
||||
|
||||
use rspirv::binary::Assemble;
|
||||
use rspirv::binary::{Assemble, Disassemble};
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
enum SpirvType {
|
||||
|
@ -86,6 +86,7 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, dr::Error> {
|
|||
emit_function(&mut builder, &mut map, f)?;
|
||||
}
|
||||
let module = builder.module();
|
||||
dbg!(print!("{}", module.disassemble()));
|
||||
Ok(module.assemble())
|
||||
}
|
||||
|
||||
|
@ -122,19 +123,44 @@ fn emit_function<'a>(
|
|||
if f.kernel {
|
||||
builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[]);
|
||||
}
|
||||
let (mut func_body, bbs, _, unique_ids) = to_ssa(&f.args, f.body);
|
||||
let id_offset = builder.reserve_ids(unique_ids);
|
||||
emit_function_args(builder, id_offset, map, &f.args);
|
||||
apply_id_offset(&mut func_body, id_offset);
|
||||
emit_function_body_ops(builder, map, &func_body, &bbs)?;
|
||||
builder.end_function()?;
|
||||
Ok(func_id)
|
||||
}
|
||||
|
||||
fn apply_id_offset(func_body: &mut Vec<Statement>, id_offset: u32) {
|
||||
for s in func_body {
|
||||
s.visit_id_mut(&mut |_, id| *id += id_offset);
|
||||
}
|
||||
}
|
||||
|
||||
fn to_ssa<'a>(
|
||||
f_args: &[ast::Argument],
|
||||
f_body: Vec<ast::Statement<&'a str>>,
|
||||
) -> (
|
||||
Vec<Statement>,
|
||||
Vec<BasicBlock>,
|
||||
Vec<Vec<PhiDef>>,
|
||||
spirv::Word,
|
||||
) {
|
||||
let mut contant_ids = HashMap::new();
|
||||
collect_arg_ids(&mut contant_ids, &f.args);
|
||||
collect_label_ids(&mut contant_ids, &f.body);
|
||||
let registers = collect_registers(&f.body);
|
||||
let (normalized_ids, unique_ids, type_check) =
|
||||
normalize_identifiers(f.body, &contant_ids, registers);
|
||||
let mut type_check = HashMap::new();
|
||||
collect_arg_ids(&mut contant_ids, &mut type_check, &f_args);
|
||||
collect_label_ids(&mut contant_ids, &f_body);
|
||||
let registers = collect_var_definitions(&f_args, &f_body);
|
||||
let (normalized_ids, unique_ids) =
|
||||
normalize_identifiers(f_body, &contant_ids, &mut type_check, registers);
|
||||
let (mut func_body, unique_ids) =
|
||||
insert_implicit_conversion(normalized_ids, unique_ids, &|x| type_check[&x]);
|
||||
insert_implicit_conversions(normalized_ids, unique_ids, &|x| type_check[&x]);
|
||||
let bbs = get_basic_blocks(&func_body);
|
||||
let rpostorder = to_reverse_postorder(&bbs);
|
||||
let doms = immediate_dominators(&bbs, &rpostorder);
|
||||
let dom_fronts = dominance_frontiers(&bbs, &doms);
|
||||
let (_, unique_ids) = ssa_legalize(
|
||||
let (phis, unique_ids) = ssa_legalize(
|
||||
&mut func_body,
|
||||
contant_ids.len() as u32,
|
||||
unique_ids,
|
||||
|
@ -142,15 +168,17 @@ fn emit_function<'a>(
|
|||
&doms,
|
||||
&dom_fronts,
|
||||
);
|
||||
let id_offset = builder.reserve_ids(unique_ids);
|
||||
emit_function_args(builder, id_offset, map, &f.args);
|
||||
emit_function_body_ops(builder, id_offset, map, &func_body, &bbs)?;
|
||||
builder.end_function()?;
|
||||
Ok(func_id)
|
||||
(func_body, bbs, phis, unique_ids)
|
||||
}
|
||||
|
||||
fn collect_registers<'a>(body: &[ast::Statement<&'a str>]) -> HashMap<Cow<'a, str>, ast::Type> {
|
||||
fn collect_var_definitions<'a>(
|
||||
args: &[ast::Argument<'a>],
|
||||
body: &[ast::Statement<&'a str>],
|
||||
) -> HashMap<Cow<'a, str>, ast::Type> {
|
||||
let mut result = HashMap::new();
|
||||
for param in args {
|
||||
result.insert(Cow::Borrowed(param.name), ast::Type::Scalar(param.a_type));
|
||||
}
|
||||
for s in body {
|
||||
match s {
|
||||
ast::Statement::Variable(var) => match var.count {
|
||||
|
@ -170,12 +198,19 @@ fn collect_registers<'a>(body: &[ast::Statement<&'a str>]) -> HashMap<Cow<'a, st
|
|||
}
|
||||
|
||||
/*
|
||||
There are three kinds of implicit conversions in PTX:
|
||||
There are several kinds of implicit conversions in PTX:
|
||||
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
|
||||
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
|
||||
* pointer dereference in st/ld: not documented, but for instruction `ld.<space>.<type> x, [y]` semantics are x = *(<type>*)y
|
||||
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
|
||||
semantics are to first zext/chop/bitcast `y` as needed and then do
|
||||
documented special ld/st/cvt conversion rules for destination operands
|
||||
- generic ld: for instruction `ld x, [y]`, y must be of type b64/u64/s64,
|
||||
which is bitcast to a pointer, dereferenced and then documented special
|
||||
ld/st/cvt conversion rules are applied
|
||||
- generic st: for instruction `st [x], y`, x must be of type b64/u64/s64,
|
||||
which is bitcast to a pointer
|
||||
*/
|
||||
fn insert_implicit_conversion<TypeCheck: Fn(spirv::Word) -> ast::Type>(
|
||||
fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
|
||||
normalized_ids: Vec<Statement>,
|
||||
unique_ids: spirv::Word,
|
||||
type_check: &TypeCheck,
|
||||
|
@ -190,16 +225,42 @@ fn insert_implicit_conversion<TypeCheck: Fn(spirv::Word) -> ast::Type>(
|
|||
for s in normalized_ids.into_iter() {
|
||||
match s {
|
||||
Statement::Instruction(inst) => match inst {
|
||||
ast::Instruction::Add(add, arg) => {
|
||||
arg.insert_implicit_conversions(
|
||||
ast::Instruction::Ld(ld, mut arg) => {
|
||||
let new_arg_src = arg.src.map_id(&mut |arg_src| {
|
||||
insert_implicit_conversions_ld_src(
|
||||
&mut result,
|
||||
ast::Type::Scalar(ld.typ),
|
||||
type_check,
|
||||
new_id,
|
||||
|instr, op| ld.state_space.should_convert(instr, op),
|
||||
arg_src,
|
||||
)
|
||||
});
|
||||
arg.src = new_arg_src;
|
||||
insert_implicit_bitcasts(
|
||||
false,
|
||||
true,
|
||||
&mut result,
|
||||
ast::Type::Scalar(add.typ),
|
||||
type_check,
|
||||
new_id,
|
||||
|arg| Statement::Instruction(ast::Instruction::Add(add, arg)),
|
||||
ast::Instruction::Ld(ld, arg),
|
||||
);
|
||||
}
|
||||
_ => todo!(),
|
||||
ast::Instruction::St(st, mut arg) => {
|
||||
let arg_dst_type = type_check(arg.dst);
|
||||
let new_dst = new_id();
|
||||
result.push(Statement::Converison(ImplicitConversion{
|
||||
src: arg.dst,
|
||||
dst: new_dst,
|
||||
from: arg_dst_type,
|
||||
to: ast::Type::Scalar(st.typ),
|
||||
kind: ConversionKind::Ptr
|
||||
}));
|
||||
arg.dst = new_dst;
|
||||
}
|
||||
inst @ _ => {
|
||||
insert_implicit_bitcasts(true, true, &mut result, type_check, new_id, inst)
|
||||
}
|
||||
},
|
||||
s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s),
|
||||
Statement::Converison(_) => unreachable!(),
|
||||
|
@ -236,10 +297,15 @@ fn emit_function_args(
|
|||
}
|
||||
}
|
||||
|
||||
fn collect_arg_ids<'a>(result: &mut HashMap<&'a str, spirv::Word>, args: &'a [ast::Argument<'a>]) {
|
||||
fn collect_arg_ids<'a>(
|
||||
result: &mut HashMap<&'a str, spirv::Word>,
|
||||
type_check: &mut HashMap<spirv::Word, ast::Type>,
|
||||
args: &'a [ast::Argument<'a>],
|
||||
) {
|
||||
let mut id = result.len() as u32;
|
||||
for arg in args {
|
||||
result.insert(arg.name, id);
|
||||
type_check.insert(id, ast::Type::Scalar(arg.a_type));
|
||||
id += 1;
|
||||
}
|
||||
}
|
||||
|
@ -263,7 +329,6 @@ fn collect_label_ids<'a>(
|
|||
|
||||
fn emit_function_body_ops(
|
||||
builder: &mut dr::Builder,
|
||||
id_offset: spirv::Word,
|
||||
map: &mut TypeWordMap,
|
||||
func: &[Statement],
|
||||
cfg: &[BasicBlock],
|
||||
|
@ -276,56 +341,40 @@ fn emit_function_body_ops(
|
|||
continue;
|
||||
}
|
||||
let header_id = if let Statement::Label(id) = body[0] {
|
||||
Some(id_offset + id)
|
||||
Some(id)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
builder.begin_block(header_id)?;
|
||||
for s in body {
|
||||
match s {
|
||||
// If block startd with a label it has already been emitted,
|
||||
// If block starts with a label it has already been emitted,
|
||||
// all other labels in the block are unused
|
||||
Statement::Label(_) => (),
|
||||
Statement::Converison(_) => todo!(),
|
||||
Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?,
|
||||
Statement::Conditional(bra) => {
|
||||
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
|
||||
}
|
||||
Statement::Instruction(inst) => match inst {
|
||||
// SPIR-V does not support marking jumps as guaranteed-converged
|
||||
ast::Instruction::Bra(_, arg) => {
|
||||
builder.branch(arg.src + id_offset)?;
|
||||
builder.branch(arg.src)?;
|
||||
}
|
||||
ast::Instruction::Ld(data, arg) => {
|
||||
if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
|
||||
todo!()
|
||||
}
|
||||
let src = match arg.src {
|
||||
ast::Operand::Reg(id) => id + id_offset,
|
||||
ast::Operand::Reg(id) => id,
|
||||
_ => todo!(),
|
||||
};
|
||||
let result_type = map.get_or_add_scalar(builder, data.typ);
|
||||
match data.state_space {
|
||||
ast::LdStateSpace::Generic => {
|
||||
// TODO: make the cast optional
|
||||
let ptr_result_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::Pointer(
|
||||
data.typ,
|
||||
spirv::StorageClass::CrossWorkgroup,
|
||||
),
|
||||
);
|
||||
let bitcast =
|
||||
builder.convert_u_to_ptr(ptr_result_type, None, src)?;
|
||||
builder.load(
|
||||
result_type,
|
||||
Some(arg.dst + id_offset),
|
||||
bitcast,
|
||||
None,
|
||||
[],
|
||||
)?;
|
||||
builder.load(result_type, Some(arg.dst), src, None, [])?;
|
||||
}
|
||||
ast::LdStateSpace::Param => {
|
||||
builder.copy_object(result_type, Some(arg.dst + id_offset), src)?;
|
||||
builder.copy_object(result_type, Some(arg.dst), src)?;
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
|
@ -338,17 +387,10 @@ fn emit_function_body_ops(
|
|||
todo!()
|
||||
}
|
||||
let src = match arg.src {
|
||||
ast::Operand::Reg(id) => id + id_offset,
|
||||
ast::Operand::Reg(id) => id,
|
||||
_ => todo!(),
|
||||
};
|
||||
// TODO make cast optional
|
||||
let ptr_result_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::Pointer(data.typ, spirv::StorageClass::CrossWorkgroup),
|
||||
);
|
||||
let bitcast =
|
||||
builder.convert_u_to_ptr(ptr_result_type, None, arg.dst + id_offset)?;
|
||||
builder.store(bitcast, src, None, &[])?;
|
||||
builder.store(arg.dst, src, None, &[])?;
|
||||
}
|
||||
// SPIR-V does not support ret as guaranteed-converged
|
||||
ast::Instruction::Ret(_) => builder.ret()?,
|
||||
|
@ -360,12 +402,76 @@ fn emit_function_body_ops(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_implicit_conversion(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
cv: &ImplicitConversion,
|
||||
) -> Result<(), dr::Error> {
|
||||
let (from_type, to_type) = match (cv.from, cv.to) {
|
||||
(ast::Type::Scalar(from), ast::Type::Scalar(to)) => (from, to),
|
||||
_ => todo!(),
|
||||
};
|
||||
match cv.kind {
|
||||
ConversionKind::Ptr => {
|
||||
let dst_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::Pointer(to_type, spirv_headers::StorageClass::Generic),
|
||||
);
|
||||
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
|
||||
}
|
||||
ConversionKind::Default => {
|
||||
if from_type.width() == to_type.width() {
|
||||
if from_type.kind() == ScalarKind::Unsigned && to_type.kind() == ScalarKind::Byte
|
||||
|| from_type.kind() == ScalarKind::Byte
|
||||
&& to_type.kind() == ScalarKind::Unsigned
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
let dst_type = map.get_or_add_scalar(builder, to_type);
|
||||
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
|
||||
} else {
|
||||
let as_unsigned_type = map.get_or_add_scalar(
|
||||
builder,
|
||||
ast::ScalarType::from_parts(from_type.width(), ScalarKind::Unsigned),
|
||||
);
|
||||
let as_unsigned = builder.bitcast(as_unsigned_type, None, cv.src)?;
|
||||
let as_unsigned_wide_type =
|
||||
ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned);
|
||||
let as_unsigned_wide_spirv = map.get_or_add_scalar(
|
||||
builder,
|
||||
ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned),
|
||||
);
|
||||
if to_type.kind() == ScalarKind::Unsigned || to_type.kind() == ScalarKind::Byte {
|
||||
builder.u_convert(as_unsigned_wide_spirv, Some(cv.dst), as_unsigned)?;
|
||||
} else {
|
||||
let as_unsigned_wide =
|
||||
builder.u_convert(as_unsigned_wide_spirv, None, as_unsigned)?;
|
||||
emit_implicit_conversion(
|
||||
builder,
|
||||
map,
|
||||
&ImplicitConversion {
|
||||
src: as_unsigned_wide,
|
||||
dst: cv.dst,
|
||||
from: ast::Type::Scalar(as_unsigned_wide_type),
|
||||
to: cv.to,
|
||||
kind: ConversionKind::Default,
|
||||
},
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
ConversionKind::SignExtend => todo!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO: support scopes
|
||||
fn normalize_identifiers<'a>(
|
||||
func: Vec<ast::Statement<&'a str>>,
|
||||
constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined
|
||||
type_map: &mut HashMap<spirv::Word, ast::Type>,
|
||||
types: HashMap<Cow<'a, str>, ast::Type>,
|
||||
) -> (Vec<Statement>, spirv::Word, HashMap<spirv::Word, ast::Type>) {
|
||||
) -> (Vec<Statement>, spirv::Word) {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
let mut id: u32 = constant_identifiers.len() as u32;
|
||||
let mut remapped_ids = HashMap::new();
|
||||
|
@ -389,11 +495,12 @@ fn normalize_identifiers<'a>(
|
|||
for s in func {
|
||||
Statement::from_ast(s, &mut result, &mut get_or_add);
|
||||
}
|
||||
let mut type_map = HashMap::with_capacity(types.len());
|
||||
for (old_id, new_id) in remapped_ids {
|
||||
type_map.insert(new_id, types[old_id]);
|
||||
}
|
||||
(result, id, type_map)
|
||||
type_map.extend(
|
||||
remapped_ids
|
||||
.into_iter()
|
||||
.map(|(old_id, new_id)| (new_id, types[old_id])),
|
||||
);
|
||||
(result, id)
|
||||
}
|
||||
|
||||
fn ssa_legalize(
|
||||
|
@ -911,10 +1018,17 @@ impl BrachCondition {
|
|||
}
|
||||
|
||||
struct ImplicitConversion {
|
||||
dst: spirv::Word,
|
||||
src: spirv::Word,
|
||||
dst: spirv::Word,
|
||||
from: ast::Type,
|
||||
to: ast::Type,
|
||||
kind: ConversionKind,
|
||||
}
|
||||
|
||||
/// Selects how an `ImplicitConversion` is lowered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConversionKind {
    /// Zero-extend, chop or bitcast, depending on the source and target types.
    Default,
    /// Sign-extending conversion.
    SignExtend,
    /// Reinterpret the source value as a pointer to the target type.
    Ptr,
}
|
||||
|
||||
impl ImplicitConversion {
|
||||
|
@ -1050,6 +1164,16 @@ impl<T> ast::Instruction<T> {
|
|||
ast::Instruction::Ret(_) => (),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the scalar type carried by the instruction, wrapped in `ast::Type`.
///
/// `Ret` carries no type and yields `None`; instructions other than `Add`,
/// `Ld`, `St` and `Ret` are not handled yet and hit `todo!()`.
fn get_type(&self) -> Option<ast::Type> {
    let scalar = match self {
        ast::Instruction::Ret(_) => return None,
        ast::Instruction::Add(details, _) => details.typ,
        ast::Instruction::Ld(details, _) => details.typ,
        ast::Instruction::St(details, _) => details.typ,
        _ => todo!(),
    };
    Some(ast::Type::Scalar(scalar))
}
|
||||
}
|
||||
|
||||
impl<T: Copy> ast::Instruction<T> {
|
||||
|
@ -1162,31 +1286,6 @@ impl<T> ast::Arg3<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl ast::Arg3<spirv::Word> {
|
||||
fn insert_implicit_conversions<
|
||||
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||
NewId: FnMut() -> spirv::Word,
|
||||
NewStatement: FnOnce(Self) -> Statement,
|
||||
>(
|
||||
self,
|
||||
func: &mut Vec<Statement>,
|
||||
op_type: ast::Type,
|
||||
type_check: &TypeCheck,
|
||||
new_id: &mut NewId,
|
||||
new_statement: NewStatement,
|
||||
) {
|
||||
let src1 = self
|
||||
.src1
|
||||
.insert_implicit_conversion(func, op_type, type_check, new_id);
|
||||
let src2 = self
|
||||
.src2
|
||||
.insert_implicit_conversion(func, op_type, type_check, new_id);
|
||||
insert_implicit_conversion_dst(func, op_type, type_check, new_id, self.dst, |dst| {
|
||||
new_statement(Self { dst, src1, src2 })
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ast::Arg4<T> {
|
||||
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg4<U> {
|
||||
ast::Arg4 {
|
||||
|
@ -1266,37 +1365,6 @@ impl<T> ast::Operand<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl ast::Operand<spirv::Word> {
|
||||
fn insert_implicit_conversion<
|
||||
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||
NewId: FnMut() -> spirv::Word,
|
||||
>(
|
||||
self,
|
||||
func: &mut Vec<Statement>,
|
||||
op_type: ast::Type,
|
||||
type_check: &TypeCheck,
|
||||
new_id: &mut NewId,
|
||||
) -> Self {
|
||||
match self {
|
||||
ast::Operand::Reg(src) => {
|
||||
if type_check(src) == op_type {
|
||||
return self;
|
||||
}
|
||||
let new_src = new_id();
|
||||
func.push(Statement::Converison(ImplicitConversion {
|
||||
src: src,
|
||||
dst: new_src,
|
||||
from: type_check(src),
|
||||
to: op_type,
|
||||
}));
|
||||
ast::Operand::Reg(new_src)
|
||||
}
|
||||
o @ ast::Operand::Imm(_) => o,
|
||||
ast::Operand::RegOffset(_, _) => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ast::MovOperand<T> {
|
||||
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::MovOperand<U> {
|
||||
match self {
|
||||
|
@ -1320,29 +1388,266 @@ impl<T> ast::MovOperand<T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn insert_implicit_conversion_dst<
|
||||
/// Coarse classification of a scalar type, independent of its width.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScalarKind {
    /// Untyped bits (the `B*` scalar types).
    Byte,
    /// Unsigned integer (the `U*` scalar types).
    Unsigned,
    /// Signed integer (the `S*` scalar types).
    Signed,
    /// Floating point (the `F*` scalar types).
    Float,
}
|
||||
|
||||
impl ast::ScalarType {
|
||||
fn width(self) -> u8 {
|
||||
match self {
|
||||
ast::ScalarType::U8 => 1,
|
||||
ast::ScalarType::S8 => 1,
|
||||
ast::ScalarType::B8 => 1,
|
||||
ast::ScalarType::U16 => 2,
|
||||
ast::ScalarType::S16 => 2,
|
||||
ast::ScalarType::B16 => 2,
|
||||
ast::ScalarType::F16 => 2,
|
||||
ast::ScalarType::U32 => 4,
|
||||
ast::ScalarType::S32 => 4,
|
||||
ast::ScalarType::B32 => 4,
|
||||
ast::ScalarType::F32 => 4,
|
||||
ast::ScalarType::U64 => 8,
|
||||
ast::ScalarType::S64 => 8,
|
||||
ast::ScalarType::B64 => 8,
|
||||
ast::ScalarType::F64 => 8,
|
||||
}
|
||||
}
|
||||
|
||||
fn kind(self) -> ScalarKind {
|
||||
match self {
|
||||
ast::ScalarType::U8 => ScalarKind::Unsigned,
|
||||
ast::ScalarType::U16 => ScalarKind::Unsigned,
|
||||
ast::ScalarType::U32 => ScalarKind::Unsigned,
|
||||
ast::ScalarType::U64 => ScalarKind::Unsigned,
|
||||
ast::ScalarType::S8 => ScalarKind::Signed,
|
||||
ast::ScalarType::S16 => ScalarKind::Signed,
|
||||
ast::ScalarType::S32 => ScalarKind::Signed,
|
||||
ast::ScalarType::S64 => ScalarKind::Signed,
|
||||
ast::ScalarType::B8 => ScalarKind::Byte,
|
||||
ast::ScalarType::B16 => ScalarKind::Byte,
|
||||
ast::ScalarType::B32 => ScalarKind::Byte,
|
||||
ast::ScalarType::B64 => ScalarKind::Byte,
|
||||
ast::ScalarType::F16 => ScalarKind::Float,
|
||||
ast::ScalarType::F32 => ScalarKind::Float,
|
||||
ast::ScalarType::F64 => ScalarKind::Float,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_parts(width: u8, kind: ScalarKind) -> Self {
|
||||
match kind {
|
||||
ScalarKind::Float => match width {
|
||||
2 => ast::ScalarType::F16,
|
||||
4 => ast::ScalarType::F32,
|
||||
8 => ast::ScalarType::F64,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
ScalarKind::Byte => match width {
|
||||
1 => ast::ScalarType::B8,
|
||||
2 => ast::ScalarType::B16,
|
||||
4 => ast::ScalarType::B32,
|
||||
8 => ast::ScalarType::B64,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
ScalarKind::Signed => match width {
|
||||
1 => ast::ScalarType::S8,
|
||||
2 => ast::ScalarType::S16,
|
||||
4 => ast::ScalarType::S32,
|
||||
8 => ast::ScalarType::S64,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
ScalarKind::Unsigned => match width {
|
||||
1 => ast::ScalarType::U8,
|
||||
2 => ast::ScalarType::U16,
|
||||
4 => ast::ScalarType::U32,
|
||||
8 => ast::ScalarType::U64,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
||||
match (instr, operand) {
|
||||
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
||||
if inst.width() != operand.width() {
|
||||
return false;
|
||||
}
|
||||
match inst.kind() {
|
||||
ScalarKind::Byte => operand.kind() != ScalarKind::Byte,
|
||||
ScalarKind::Float => operand.kind() == ScalarKind::Byte,
|
||||
ScalarKind::Signed => {
|
||||
operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Unsigned
|
||||
}
|
||||
ScalarKind::Unsigned => {
|
||||
operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::LdStateSpace {
|
||||
fn should_convert(self, instr_type: ast::Type, op_type: ast::Type) -> Option<ConversionKind> {
|
||||
match self {
|
||||
ast::LdStateSpace::Param => {
|
||||
if instr_type != op_type {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::LdStateSpace::Generic => Some(ConversionKind::Ptr),
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_forced_bitcast_src<
|
||||
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||
NewId: FnMut() -> spirv::Word,
|
||||
NewStatement: FnOnce(spirv::Word) -> Statement,
|
||||
>(
|
||||
func: &mut Vec<Statement>,
|
||||
op_type: ast::Type,
|
||||
type_check: &TypeCheck,
|
||||
new_id: &mut NewId,
|
||||
dst: spirv::Word,
|
||||
new_statement: NewStatement,
|
||||
) {
|
||||
if type_check(dst) == op_type {
|
||||
func.push(new_statement(dst));
|
||||
} else {
|
||||
let new_dst = new_id();
|
||||
func.push(new_statement(new_dst));
|
||||
src: spirv::Word,
|
||||
) -> spirv::Word {
|
||||
let src_type = type_check(src);
|
||||
if src_type == op_type {
|
||||
return src;
|
||||
}
|
||||
let new_src = new_id();
|
||||
func.push(Statement::Converison(ImplicitConversion {
|
||||
src: src,
|
||||
dst: new_src,
|
||||
from: src_type,
|
||||
to: op_type,
|
||||
kind: ConversionKind::Default,
|
||||
}));
|
||||
new_src
|
||||
}
|
||||
|
||||
fn insert_implicit_conversions_ld_src<
|
||||
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||
NewId: FnMut() -> spirv::Word,
|
||||
ShouldConvert: Fn(ast::Type, ast::Type) -> Option<ConversionKind>,
|
||||
>(
|
||||
func: &mut Vec<Statement>,
|
||||
instr_type: ast::Type,
|
||||
type_check: &TypeCheck,
|
||||
new_id: &mut NewId,
|
||||
should_convert: ShouldConvert,
|
||||
src: spirv::Word,
|
||||
) -> spirv::Word {
|
||||
let src_type = type_check(src);
|
||||
if let Some(conv_kind) = should_convert(src_type, instr_type) {
|
||||
let new_src = new_id();
|
||||
func.push(Statement::Converison(ImplicitConversion {
|
||||
src: new_dst,
|
||||
dst: dst,
|
||||
from: type_check(new_dst),
|
||||
to: op_type,
|
||||
src: src,
|
||||
dst: new_src,
|
||||
from: src_type,
|
||||
to: instr_type,
|
||||
kind: conv_kind,
|
||||
}));
|
||||
new_src
|
||||
} else {
|
||||
src
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
|
||||
fn should_convert_relaxed_dst(
|
||||
dst_type: ast::Type,
|
||||
instr_type: ast::ScalarType,
|
||||
) -> Option<ConversionKind> {
|
||||
if dst_type == ast::Type::Scalar(instr_type) {
|
||||
return None;
|
||||
}
|
||||
match dst_type {
|
||||
ast::Type::Scalar(dst_type) => match instr_type.kind() {
|
||||
ScalarKind::Byte => {
|
||||
if instr_type.width() <= dst_type.width() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Signed => {
|
||||
if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float {
|
||||
Some(ConversionKind::SignExtend)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Unsigned => {
|
||||
if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Float => {
|
||||
if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Float {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_implicit_bitcasts<
|
||||
TypeCheck: Fn(spirv::Word) -> ast::Type,
|
||||
NewId: FnMut() -> spirv::Word,
|
||||
>(
|
||||
do_src_bitcast: bool,
|
||||
do_dst_bitcast: bool,
|
||||
func: &mut Vec<Statement>,
|
||||
type_check: &TypeCheck,
|
||||
new_id: &mut NewId,
|
||||
mut instr: ast::Instruction<spirv::Word>,
|
||||
) {
|
||||
let mut dst_coercion = None;
|
||||
if let Some(instr_type) = instr.get_type() {
|
||||
instr.visit_id_mut(&mut |is_dst, id| {
|
||||
if (is_dst && !do_dst_bitcast) || (!is_dst && !do_src_bitcast) {
|
||||
return;
|
||||
}
|
||||
let id_type = type_check(*id);
|
||||
if should_bitcast(instr_type, type_check(*id)) {
|
||||
let replacement_id = new_id();
|
||||
if is_dst {
|
||||
dst_coercion = Some(ImplicitConversion {
|
||||
src: replacement_id,
|
||||
dst: *id,
|
||||
from: instr_type,
|
||||
to: id_type,
|
||||
kind: ConversionKind::Default,
|
||||
});
|
||||
*id = replacement_id;
|
||||
} else {
|
||||
func.push(Statement::Converison(ImplicitConversion {
|
||||
src: *id,
|
||||
dst: replacement_id,
|
||||
from: id_type,
|
||||
to: instr_type,
|
||||
kind: ConversionKind::Default,
|
||||
}));
|
||||
*id = replacement_id;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
func.push(Statement::Instruction(instr));
|
||||
if let Some(cond) = dst_coercion {
|
||||
func.push(Statement::Converison(cond));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1678,6 +1983,12 @@ mod tests {
|
|||
|
||||
// page 403
|
||||
const FIG_19_4: &'static str = "{
|
||||
.reg.u32 i;
|
||||
.reg.u32 j;
|
||||
.reg.u32 k;
|
||||
.reg.pred p;
|
||||
.reg.pred q;
|
||||
|
||||
mov.u32 i, 1;
|
||||
mov.u32 j, 1;
|
||||
mov.u32 k, 0;
|
||||
|
@ -1710,7 +2021,9 @@ mod tests {
|
|||
assert_eq!(errors.len(), 0);
|
||||
let mut constant_ids = HashMap::new();
|
||||
collect_label_ids(&mut constant_ids, &ast);
|
||||
let (normalized_ids, _) = normalize_identifiers(ast, &constant_ids);
|
||||
let registers = collect_var_definitions(&[], &ast);
|
||||
let (normalized_ids, _) =
|
||||
normalize_identifiers(ast, &constant_ids, &mut HashMap::new(), registers);
|
||||
let mut bbs = get_basic_blocks(&normalized_ids);
|
||||
bbs.iter_mut().for_each(sort_pred_succ);
|
||||
assert_eq!(
|
||||
|
@ -1857,7 +2170,9 @@ mod tests {
|
|||
let mut constant_ids = HashMap::new();
|
||||
collect_label_ids(&mut constant_ids, &fn_ast);
|
||||
assert_eq!(constant_ids.len(), 4);
|
||||
let (normalized_ids, max_id) = normalize_identifiers(fn_ast, &constant_ids);
|
||||
let registers = collect_var_definitions(&[], &fn_ast);
|
||||
let (normalized_ids, max_id) =
|
||||
normalize_identifiers(fn_ast, &constant_ids, &mut HashMap::new(), registers);
|
||||
let bbs = get_basic_blocks(&normalized_ids);
|
||||
let rpostorder = to_reverse_postorder(&bbs);
|
||||
let doms = immediate_dominators(&bbs, &rpostorder);
|
||||
|
@ -1895,21 +2210,7 @@ mod tests {
|
|||
.parse(&mut errors, func)
|
||||
.unwrap();
|
||||
assert_eq!(errors.len(), 0);
|
||||
let mut constant_ids = HashMap::new();
|
||||
collect_label_ids(&mut constant_ids, &fn_ast);
|
||||
let (mut func, unique_ids) = normalize_identifiers(fn_ast, &constant_ids);
|
||||
let bbs = get_basic_blocks(&func);
|
||||
let rpostorder = to_reverse_postorder(&bbs);
|
||||
let doms = immediate_dominators(&bbs, &rpostorder);
|
||||
let dom_fronts = dominance_frontiers(&bbs, &doms);
|
||||
let (mut ssa_phis, _) = ssa_legalize(
|
||||
&mut func,
|
||||
constant_ids.len() as u32,
|
||||
unique_ids,
|
||||
&bbs,
|
||||
&doms,
|
||||
&dom_fronts,
|
||||
);
|
||||
let (func, _, mut ssa_phis, unique_ids) = to_ssa(&[], fn_ast);
|
||||
assert_phi_dst_id(unique_ids, &ssa_phis);
|
||||
assert_dst_unique(&func, &ssa_phis);
|
||||
sort_phi(&mut ssa_phis);
|
||||
|
|
Loading…
Add table
Reference in a new issue