Start implementing implicit conversions

This commit is contained in:
Andrzej Janik 2020-05-26 00:33:32 +02:00
parent 9f60990765
commit 4a0edf0e14
2 changed files with 455 additions and 154 deletions

View file

@ -200,7 +200,7 @@ pub struct LdData {
pub typ: ScalarType,
}
#[derive(PartialEq, Eq)]
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum LdStQualifier {
Weak,
Volatile,
@ -208,14 +208,14 @@ pub enum LdStQualifier {
Acquire(LdScope),
}
#[derive(PartialEq, Eq)]
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum LdScope {
Cta,
Gpu,
Sys,
}
#[derive(PartialEq, Eq)]
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum LdStateSpace {
Generic,
Const,
@ -225,7 +225,7 @@ pub enum LdStateSpace {
Shared,
}
#[derive(PartialEq, Eq)]
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum LdCacheOperator {
Cached,
L2Only,

View file

@ -5,7 +5,7 @@ use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::{borrow::Cow, fmt};
use rspirv::binary::Assemble;
use rspirv::binary::{Assemble, Disassemble};
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvType {
@ -86,6 +86,7 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, dr::Error> {
emit_function(&mut builder, &mut map, f)?;
}
let module = builder.module();
dbg!(print!("{}", module.disassemble()));
Ok(module.assemble())
}
@ -122,19 +123,44 @@ fn emit_function<'a>(
if f.kernel {
builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[]);
}
let (mut func_body, bbs, _, unique_ids) = to_ssa(&f.args, f.body);
let id_offset = builder.reserve_ids(unique_ids);
emit_function_args(builder, id_offset, map, &f.args);
apply_id_offset(&mut func_body, id_offset);
emit_function_body_ops(builder, map, &func_body, &bbs)?;
builder.end_function()?;
Ok(func_id)
}
fn apply_id_offset(func_body: &mut Vec<Statement>, id_offset: u32) {
for s in func_body {
s.visit_id_mut(&mut |_, id| *id += id_offset);
}
}
fn to_ssa<'a>(
f_args: &[ast::Argument],
f_body: Vec<ast::Statement<&'a str>>,
) -> (
Vec<Statement>,
Vec<BasicBlock>,
Vec<Vec<PhiDef>>,
spirv::Word,
) {
let mut contant_ids = HashMap::new();
collect_arg_ids(&mut contant_ids, &f.args);
collect_label_ids(&mut contant_ids, &f.body);
let registers = collect_registers(&f.body);
let (normalized_ids, unique_ids, type_check) =
normalize_identifiers(f.body, &contant_ids, registers);
let mut type_check = HashMap::new();
collect_arg_ids(&mut contant_ids, &mut type_check, &f_args);
collect_label_ids(&mut contant_ids, &f_body);
let registers = collect_var_definitions(&f_args, &f_body);
let (normalized_ids, unique_ids) =
normalize_identifiers(f_body, &contant_ids, &mut type_check, registers);
let (mut func_body, unique_ids) =
insert_implicit_conversion(normalized_ids, unique_ids, &|x| type_check[&x]);
insert_implicit_conversions(normalized_ids, unique_ids, &|x| type_check[&x]);
let bbs = get_basic_blocks(&func_body);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
let dom_fronts = dominance_frontiers(&bbs, &doms);
let (_, unique_ids) = ssa_legalize(
let (phis, unique_ids) = ssa_legalize(
&mut func_body,
contant_ids.len() as u32,
unique_ids,
@ -142,15 +168,17 @@ fn emit_function<'a>(
&doms,
&dom_fronts,
);
let id_offset = builder.reserve_ids(unique_ids);
emit_function_args(builder, id_offset, map, &f.args);
emit_function_body_ops(builder, id_offset, map, &func_body, &bbs)?;
builder.end_function()?;
Ok(func_id)
(func_body, bbs, phis, unique_ids)
}
fn collect_registers<'a>(body: &[ast::Statement<&'a str>]) -> HashMap<Cow<'a, str>, ast::Type> {
fn collect_var_definitions<'a>(
args: &[ast::Argument<'a>],
body: &[ast::Statement<&'a str>],
) -> HashMap<Cow<'a, str>, ast::Type> {
let mut result = HashMap::new();
for param in args {
result.insert(Cow::Borrowed(param.name), ast::Type::Scalar(param.a_type));
}
for s in body {
match s {
ast::Statement::Variable(var) => match var.count {
@ -170,12 +198,19 @@ fn collect_registers<'a>(body: &[ast::Statement<&'a str>]) -> HashMap<Cow<'a, st
}
/*
There are three kinds of implicit conversions in PTX:
There are several kinds of implicit conversions in PTX:
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
* pointer dereference in st/ld: not documented, but for instruction `ld.<space>.<type> x, [y]` semantics are x = *(<type>*)y
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
semantics are to first zext/chop/bitcast `y` as needed and then do
documented special ld/st/cvt conversion rules for destination operands
- generic ld: for instruction `ld x, [y]`, y must be of type b64/u64/s64,
which is bitcast to a pointer, dereferenced and then documented special
ld/st/cvt conversion rules are applied
- generic ld: for instruction `ld [x], y`, x must be of type b64/u64/s64,
which is bitcast to a pointer
*/
fn insert_implicit_conversion<TypeCheck: Fn(spirv::Word) -> ast::Type>(
fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
normalized_ids: Vec<Statement>,
unique_ids: spirv::Word,
type_check: &TypeCheck,
@ -190,16 +225,42 @@ fn insert_implicit_conversion<TypeCheck: Fn(spirv::Word) -> ast::Type>(
for s in normalized_ids.into_iter() {
match s {
Statement::Instruction(inst) => match inst {
ast::Instruction::Add(add, arg) => {
arg.insert_implicit_conversions(
ast::Instruction::Ld(ld, mut arg) => {
let new_arg_src = arg.src.map_id(&mut |arg_src| {
insert_implicit_conversions_ld_src(
&mut result,
ast::Type::Scalar(ld.typ),
type_check,
new_id,
|instr, op| ld.state_space.should_convert(instr, op),
arg_src,
)
});
arg.src = new_arg_src;
insert_implicit_bitcasts(
false,
true,
&mut result,
ast::Type::Scalar(add.typ),
type_check,
new_id,
|arg| Statement::Instruction(ast::Instruction::Add(add, arg)),
ast::Instruction::Ld(ld, arg),
);
}
_ => todo!(),
ast::Instruction::St(st, mut arg) => {
let arg_dst_type = type_check(arg.dst);
let new_dst = new_id();
result.push(Statement::Converison(ImplicitConversion{
src: arg.dst,
dst: new_dst,
from: arg_dst_type,
to: ast::Type::Scalar(st.typ),
kind: ConversionKind::Ptr
}));
arg.dst = new_dst;
}
inst @ _ => {
insert_implicit_bitcasts(true, true, &mut result, type_check, new_id, inst)
}
},
s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s),
Statement::Converison(_) => unreachable!(),
@ -236,10 +297,15 @@ fn emit_function_args(
}
}
fn collect_arg_ids<'a>(result: &mut HashMap<&'a str, spirv::Word>, args: &'a [ast::Argument<'a>]) {
fn collect_arg_ids<'a>(
result: &mut HashMap<&'a str, spirv::Word>,
type_check: &mut HashMap<spirv::Word, ast::Type>,
args: &'a [ast::Argument<'a>],
) {
let mut id = result.len() as u32;
for arg in args {
result.insert(arg.name, id);
type_check.insert(id, ast::Type::Scalar(arg.a_type));
id += 1;
}
}
@ -263,7 +329,6 @@ fn collect_label_ids<'a>(
fn emit_function_body_ops(
builder: &mut dr::Builder,
id_offset: spirv::Word,
map: &mut TypeWordMap,
func: &[Statement],
cfg: &[BasicBlock],
@ -276,56 +341,40 @@ fn emit_function_body_ops(
continue;
}
let header_id = if let Statement::Label(id) = body[0] {
Some(id_offset + id)
Some(id)
} else {
None
};
builder.begin_block(header_id)?;
for s in body {
match s {
// If block startd with a label it has already been emitted,
// If block starts with a label it has already been emitted,
// all other labels in the block are unused
Statement::Label(_) => (),
Statement::Converison(_) => todo!(),
Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?,
Statement::Conditional(bra) => {
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
}
Statement::Instruction(inst) => match inst {
// SPIR-V does not support marking jumps as guaranteed-converged
ast::Instruction::Bra(_, arg) => {
builder.branch(arg.src + id_offset)?;
builder.branch(arg.src)?;
}
ast::Instruction::Ld(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
todo!()
}
let src = match arg.src {
ast::Operand::Reg(id) => id + id_offset,
ast::Operand::Reg(id) => id,
_ => todo!(),
};
let result_type = map.get_or_add_scalar(builder, data.typ);
match data.state_space {
ast::LdStateSpace::Generic => {
// TODO: make the cast optional
let ptr_result_type = map.get_or_add(
builder,
SpirvType::Pointer(
data.typ,
spirv::StorageClass::CrossWorkgroup,
),
);
let bitcast =
builder.convert_u_to_ptr(ptr_result_type, None, src)?;
builder.load(
result_type,
Some(arg.dst + id_offset),
bitcast,
None,
[],
)?;
builder.load(result_type, Some(arg.dst), src, None, [])?;
}
ast::LdStateSpace::Param => {
builder.copy_object(result_type, Some(arg.dst + id_offset), src)?;
builder.copy_object(result_type, Some(arg.dst), src)?;
}
_ => todo!(),
}
@ -338,17 +387,10 @@ fn emit_function_body_ops(
todo!()
}
let src = match arg.src {
ast::Operand::Reg(id) => id + id_offset,
ast::Operand::Reg(id) => id,
_ => todo!(),
};
// TODO make cast optional
let ptr_result_type = map.get_or_add(
builder,
SpirvType::Pointer(data.typ, spirv::StorageClass::CrossWorkgroup),
);
let bitcast =
builder.convert_u_to_ptr(ptr_result_type, None, arg.dst + id_offset)?;
builder.store(bitcast, src, None, &[])?;
builder.store(arg.dst, src, None, &[])?;
}
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
@ -360,12 +402,76 @@ fn emit_function_body_ops(
Ok(())
}
fn emit_implicit_conversion(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
cv: &ImplicitConversion,
) -> Result<(), dr::Error> {
let (from_type, to_type) = match (cv.from, cv.to) {
(ast::Type::Scalar(from), ast::Type::Scalar(to)) => (from, to),
_ => todo!(),
};
match cv.kind {
ConversionKind::Ptr => {
let dst_type = map.get_or_add(
builder,
SpirvType::Pointer(to_type, spirv_headers::StorageClass::Generic),
);
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
}
ConversionKind::Default => {
if from_type.width() == to_type.width() {
if from_type.kind() == ScalarKind::Unsigned && to_type.kind() == ScalarKind::Byte
|| from_type.kind() == ScalarKind::Byte
&& to_type.kind() == ScalarKind::Unsigned
{
return Ok(());
}
let dst_type = map.get_or_add_scalar(builder, to_type);
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
} else {
let as_unsigned_type = map.get_or_add_scalar(
builder,
ast::ScalarType::from_parts(from_type.width(), ScalarKind::Unsigned),
);
let as_unsigned = builder.bitcast(as_unsigned_type, None, cv.src)?;
let as_unsigned_wide_type =
ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned);
let as_unsigned_wide_spirv = map.get_or_add_scalar(
builder,
ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned),
);
if to_type.kind() == ScalarKind::Unsigned || to_type.kind() == ScalarKind::Byte {
builder.u_convert(as_unsigned_wide_spirv, Some(cv.dst), as_unsigned)?;
} else {
let as_unsigned_wide =
builder.u_convert(as_unsigned_wide_spirv, None, as_unsigned)?;
emit_implicit_conversion(
builder,
map,
&ImplicitConversion {
src: as_unsigned_wide,
dst: cv.dst,
from: ast::Type::Scalar(as_unsigned_wide_type),
to: cv.to,
kind: ConversionKind::Default,
},
)?;
}
}
}
ConversionKind::SignExtend => todo!(),
}
Ok(())
}
// TODO: support scopes
fn normalize_identifiers<'a>(
func: Vec<ast::Statement<&'a str>>,
constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined
type_map: &mut HashMap<spirv::Word, ast::Type>,
types: HashMap<Cow<'a, str>, ast::Type>,
) -> (Vec<Statement>, spirv::Word, HashMap<spirv::Word, ast::Type>) {
) -> (Vec<Statement>, spirv::Word) {
let mut result = Vec::with_capacity(func.len());
let mut id: u32 = constant_identifiers.len() as u32;
let mut remapped_ids = HashMap::new();
@ -389,11 +495,12 @@ fn normalize_identifiers<'a>(
for s in func {
Statement::from_ast(s, &mut result, &mut get_or_add);
}
let mut type_map = HashMap::with_capacity(types.len());
for (old_id, new_id) in remapped_ids {
type_map.insert(new_id, types[old_id]);
}
(result, id, type_map)
type_map.extend(
remapped_ids
.into_iter()
.map(|(old_id, new_id)| (new_id, types[old_id])),
);
(result, id)
}
fn ssa_legalize(
@ -911,10 +1018,17 @@ impl BrachCondition {
}
struct ImplicitConversion {
dst: spirv::Word,
src: spirv::Word,
dst: spirv::Word,
from: ast::Type,
to: ast::Type,
kind: ConversionKind,
}
enum ConversionKind {
Default, // zero-extend/chop/bitcast depending on types
SignExtend,
Ptr,
}
impl ImplicitConversion {
@ -1050,6 +1164,16 @@ impl<T> ast::Instruction<T> {
ast::Instruction::Ret(_) => (),
}
}
fn get_type(&self) -> Option<ast::Type> {
match self {
ast::Instruction::Add(add, _) => Some(ast::Type::Scalar(add.typ)),
ast::Instruction::Ret(_) => None,
ast::Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)),
ast::Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)),
_ => todo!(),
}
}
}
impl<T: Copy> ast::Instruction<T> {
@ -1162,31 +1286,6 @@ impl<T> ast::Arg3<T> {
}
}
impl ast::Arg3<spirv::Word> {
fn insert_implicit_conversions<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
NewStatement: FnOnce(Self) -> Statement,
>(
self,
func: &mut Vec<Statement>,
op_type: ast::Type,
type_check: &TypeCheck,
new_id: &mut NewId,
new_statement: NewStatement,
) {
let src1 = self
.src1
.insert_implicit_conversion(func, op_type, type_check, new_id);
let src2 = self
.src2
.insert_implicit_conversion(func, op_type, type_check, new_id);
insert_implicit_conversion_dst(func, op_type, type_check, new_id, self.dst, |dst| {
new_statement(Self { dst, src1, src2 })
});
}
}
impl<T> ast::Arg4<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg4<U> {
ast::Arg4 {
@ -1266,37 +1365,6 @@ impl<T> ast::Operand<T> {
}
}
impl ast::Operand<spirv::Word> {
fn insert_implicit_conversion<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
>(
self,
func: &mut Vec<Statement>,
op_type: ast::Type,
type_check: &TypeCheck,
new_id: &mut NewId,
) -> Self {
match self {
ast::Operand::Reg(src) => {
if type_check(src) == op_type {
return self;
}
let new_src = new_id();
func.push(Statement::Converison(ImplicitConversion {
src: src,
dst: new_src,
from: type_check(src),
to: op_type,
}));
ast::Operand::Reg(new_src)
}
o @ ast::Operand::Imm(_) => o,
ast::Operand::RegOffset(_, _) => todo!(),
}
}
}
impl<T> ast::MovOperand<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::MovOperand<U> {
match self {
@ -1320,29 +1388,266 @@ impl<T> ast::MovOperand<T> {
}
}
fn insert_implicit_conversion_dst<
#[derive(Clone, Copy, PartialEq)]
enum ScalarKind {
Byte,
Unsigned,
Signed,
Float,
}
impl ast::ScalarType {
fn width(self) -> u8 {
match self {
ast::ScalarType::U8 => 1,
ast::ScalarType::S8 => 1,
ast::ScalarType::B8 => 1,
ast::ScalarType::U16 => 2,
ast::ScalarType::S16 => 2,
ast::ScalarType::B16 => 2,
ast::ScalarType::F16 => 2,
ast::ScalarType::U32 => 4,
ast::ScalarType::S32 => 4,
ast::ScalarType::B32 => 4,
ast::ScalarType::F32 => 4,
ast::ScalarType::U64 => 8,
ast::ScalarType::S64 => 8,
ast::ScalarType::B64 => 8,
ast::ScalarType::F64 => 8,
}
}
fn kind(self) -> ScalarKind {
match self {
ast::ScalarType::U8 => ScalarKind::Unsigned,
ast::ScalarType::U16 => ScalarKind::Unsigned,
ast::ScalarType::U32 => ScalarKind::Unsigned,
ast::ScalarType::U64 => ScalarKind::Unsigned,
ast::ScalarType::S8 => ScalarKind::Signed,
ast::ScalarType::S16 => ScalarKind::Signed,
ast::ScalarType::S32 => ScalarKind::Signed,
ast::ScalarType::S64 => ScalarKind::Signed,
ast::ScalarType::B8 => ScalarKind::Byte,
ast::ScalarType::B16 => ScalarKind::Byte,
ast::ScalarType::B32 => ScalarKind::Byte,
ast::ScalarType::B64 => ScalarKind::Byte,
ast::ScalarType::F16 => ScalarKind::Float,
ast::ScalarType::F32 => ScalarKind::Float,
ast::ScalarType::F64 => ScalarKind::Float,
}
}
fn from_parts(width: u8, kind: ScalarKind) -> Self {
match kind {
ScalarKind::Float => match width {
2 => ast::ScalarType::F16,
4 => ast::ScalarType::F32,
8 => ast::ScalarType::F64,
_ => unreachable!(),
},
ScalarKind::Byte => match width {
1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32,
8 => ast::ScalarType::B64,
_ => unreachable!(),
},
ScalarKind::Signed => match width {
1 => ast::ScalarType::S8,
2 => ast::ScalarType::S16,
4 => ast::ScalarType::S32,
8 => ast::ScalarType::S64,
_ => unreachable!(),
},
ScalarKind::Unsigned => match width {
1 => ast::ScalarType::U8,
2 => ast::ScalarType::U16,
4 => ast::ScalarType::U32,
8 => ast::ScalarType::U64,
_ => unreachable!(),
},
}
}
}
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
if inst.width() != operand.width() {
return false;
}
match inst.kind() {
ScalarKind::Byte => operand.kind() != ScalarKind::Byte,
ScalarKind::Float => operand.kind() == ScalarKind::Byte,
ScalarKind::Signed => {
operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Unsigned
}
ScalarKind::Unsigned => {
operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed
}
}
}
_ => false,
}
}
impl ast::LdStateSpace {
fn should_convert(self, instr_type: ast::Type, op_type: ast::Type) -> Option<ConversionKind> {
match self {
ast::LdStateSpace::Param => {
if instr_type != op_type {
Some(ConversionKind::Default)
} else {
None
}
}
ast::LdStateSpace::Generic => Some(ConversionKind::Ptr),
_ => todo!(),
}
}
}
fn insert_forced_bitcast_src<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
NewStatement: FnOnce(spirv::Word) -> Statement,
>(
func: &mut Vec<Statement>,
op_type: ast::Type,
type_check: &TypeCheck,
new_id: &mut NewId,
dst: spirv::Word,
new_statement: NewStatement,
) {
if type_check(dst) == op_type {
func.push(new_statement(dst));
} else {
let new_dst = new_id();
func.push(new_statement(new_dst));
src: spirv::Word,
) -> spirv::Word {
let src_type = type_check(src);
if src_type == op_type {
return src;
}
let new_src = new_id();
func.push(Statement::Converison(ImplicitConversion {
src: src,
dst: new_src,
from: src_type,
to: op_type,
kind: ConversionKind::Default,
}));
new_src
}
fn insert_implicit_conversions_ld_src<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
ShouldConvert: Fn(ast::Type, ast::Type) -> Option<ConversionKind>,
>(
func: &mut Vec<Statement>,
instr_type: ast::Type,
type_check: &TypeCheck,
new_id: &mut NewId,
should_convert: ShouldConvert,
src: spirv::Word,
) -> spirv::Word {
let src_type = type_check(src);
if let Some(conv_kind) = should_convert(src_type, instr_type) {
let new_src = new_id();
func.push(Statement::Converison(ImplicitConversion {
src: new_dst,
dst: dst,
from: type_check(new_dst),
to: op_type,
src: src,
dst: new_src,
from: src_type,
to: instr_type,
kind: conv_kind,
}));
new_src
} else {
src
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: ast::Type,
instr_type: ast::ScalarType,
) -> Option<ConversionKind> {
if dst_type == ast::Type::Scalar(instr_type) {
return None;
}
match dst_type {
ast::Type::Scalar(dst_type) => match instr_type.kind() {
ScalarKind::Byte => {
if instr_type.width() <= dst_type.width() {
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Signed => {
if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float {
Some(ConversionKind::SignExtend)
} else {
None
}
}
ScalarKind::Unsigned => {
if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float {
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Float => {
if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Float {
Some(ConversionKind::Default)
} else {
None
}
}
},
_ => None,
}
}
fn insert_implicit_bitcasts<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
>(
do_src_bitcast: bool,
do_dst_bitcast: bool,
func: &mut Vec<Statement>,
type_check: &TypeCheck,
new_id: &mut NewId,
mut instr: ast::Instruction<spirv::Word>,
) {
let mut dst_coercion = None;
if let Some(instr_type) = instr.get_type() {
instr.visit_id_mut(&mut |is_dst, id| {
if (is_dst && !do_dst_bitcast) || (!is_dst && !do_src_bitcast) {
return;
}
let id_type = type_check(*id);
if should_bitcast(instr_type, type_check(*id)) {
let replacement_id = new_id();
if is_dst {
dst_coercion = Some(ImplicitConversion {
src: replacement_id,
dst: *id,
from: instr_type,
to: id_type,
kind: ConversionKind::Default,
});
*id = replacement_id;
} else {
func.push(Statement::Converison(ImplicitConversion {
src: *id,
dst: replacement_id,
from: id_type,
to: instr_type,
kind: ConversionKind::Default,
}));
*id = replacement_id;
}
}
});
}
func.push(Statement::Instruction(instr));
if let Some(cond) = dst_coercion {
func.push(Statement::Converison(cond));
}
}
@ -1678,6 +1983,12 @@ mod tests {
// page 403
const FIG_19_4: &'static str = "{
.reg.u32 i;
.reg.u32 j;
.reg.u32 k;
.reg.pred p;
.reg.pred q;
mov.u32 i, 1;
mov.u32 j, 1;
mov.u32 k, 0;
@ -1710,7 +2021,9 @@ mod tests {
assert_eq!(errors.len(), 0);
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &ast);
let (normalized_ids, _) = normalize_identifiers(ast, &constant_ids);
let registers = collect_var_definitions(&[], &ast);
let (normalized_ids, _) =
normalize_identifiers(ast, &constant_ids, &mut HashMap::new(), registers);
let mut bbs = get_basic_blocks(&normalized_ids);
bbs.iter_mut().for_each(sort_pred_succ);
assert_eq!(
@ -1857,7 +2170,9 @@ mod tests {
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &fn_ast);
assert_eq!(constant_ids.len(), 4);
let (normalized_ids, max_id) = normalize_identifiers(fn_ast, &constant_ids);
let registers = collect_var_definitions(&[], &fn_ast);
let (normalized_ids, max_id) =
normalize_identifiers(fn_ast, &constant_ids, &mut HashMap::new(), registers);
let bbs = get_basic_blocks(&normalized_ids);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
@ -1895,21 +2210,7 @@ mod tests {
.parse(&mut errors, func)
.unwrap();
assert_eq!(errors.len(), 0);
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &fn_ast);
let (mut func, unique_ids) = normalize_identifiers(fn_ast, &constant_ids);
let bbs = get_basic_blocks(&func);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
let dom_fronts = dominance_frontiers(&bbs, &doms);
let (mut ssa_phis, _) = ssa_legalize(
&mut func,
constant_ids.len() as u32,
unique_ids,
&bbs,
&doms,
&dom_fronts,
);
let (func, _, mut ssa_phis, unique_ids) = to_ssa(&[], fn_ast);
assert_phi_dst_id(unique_ids, &ssa_phis);
assert_dst_unique(&func, &ssa_phis);
sort_phi(&mut ssa_phis);