Refactor code to support per-variable type definition in callbacks

This commit is contained in:
Andrzej Janik 2020-07-28 00:37:16 +02:00
parent 04820fba2f
commit 72f5ffe2f9

View file

@ -164,7 +164,7 @@ fn emit_function<'a>(
fn apply_id_offset(func_body: &mut Vec<ExpandedStatement>, id_offset: u32) {
for s in func_body {
s.visit_id_mut(&mut |_, id| *id += id_offset);
s.visit_id(&mut |id| *id += id_offset);
}
}
@ -200,7 +200,7 @@ fn normalize_labels(
Statement::Variable(_, _, _)
| Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
| Statement::Converison(_)
| Statement::Conversion(_)
| Statement::Constant(_)
| Statement::Label(_) => (),
}
@ -275,18 +275,20 @@ fn insert_mem_ssa_statements(
result.push(Statement::Instruction(Instruction::Ld(ld, arg)));
}
mut inst => {
let inst_type = inst.get_type();
let mut post_statements = Vec::new();
inst.visit_id_mut(&mut |is_dst, id| {
let inst_type = inst_type.unwrap();
let generated_id = id_def.new_id(Some(inst_type));
inst.visit_id(&mut |is_dst, id, id_type| {
let id_type = match id_type {
Some(t) => t,
None => return,
};
let generated_id = id_def.new_id(Some(id_type));
if !is_dst {
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
src: *id,
},
inst_type,
id_type,
));
} else {
post_statements.push(Statement::StoreVar(
@ -294,7 +296,7 @@ fn insert_mem_ssa_statements(
src1: *id,
src2: generated_id,
},
inst_type,
id_type,
));
}
*id = generated_id;
@ -308,7 +310,7 @@ fn insert_mem_ssa_statements(
| s @ Statement::Conditional(_) => result.push(s),
Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
| Statement::Converison(_)
| Statement::Conversion(_)
| Statement::Constant(_) => unreachable!(),
}
}
@ -331,7 +333,7 @@ fn expand_arguments(
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
Statement::Converison(_) | Statement::Constant(_) => unreachable!(),
Statement::Conversion(_) | Statement::Constant(_) => unreachable!(),
}
}
result
@ -572,7 +574,7 @@ fn insert_implicit_conversions(
| s @ Statement::Variable(_, _, _)
| s @ Statement::LoadVar(_, _)
| s @ Statement::StoreVar(_, _) => result.push(s),
Statement::Converison(_) => unreachable!(),
Statement::Conversion(_) => unreachable!(),
}
}
result
@ -660,7 +662,7 @@ fn emit_function_body_ops(
_ => unreachable!(),
}
}
Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?,
Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
Statement::Conditional(bra) => {
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
}
@ -973,38 +975,33 @@ enum Statement<A: Args> {
Instruction(Instruction<A>),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
Converison(ImplicitConversion),
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
}
impl<A: Args> Statement<A> {
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
impl Statement<ExpandedArgs> {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
match self {
Statement::Variable(id, _, _) => f(true, id),
Statement::LoadVar(a, _) => a.visit_id_mut(f),
Statement::StoreVar(a, _) => a.visit_id_mut(f),
Statement::Label(id) => f(false, id),
Statement::Instruction(inst) => inst.visit_id_mut(f),
Statement::Conditional(bra) => bra.visit_id_mut(f),
Statement::Converison(conv) => conv.visit_id_mut(f),
Statement::Constant(cons) => cons.visit_id_mut(f),
Statement::Variable(id, _, _) => f(id),
Statement::LoadVar(a, _) => a.visit_id(&mut |_, id, _| f(id), None),
Statement::StoreVar(a, _) => a.visit_id(&mut |_, id, _| f(id), None),
Statement::Label(id) => f(id),
Statement::Instruction(inst) => inst.visit_id(f),
Statement::Conditional(bra) => bra.visit_id(&mut |_, id, _| f(id)),
Statement::Conversion(conv) => conv.visit_id(f),
Statement::Constant(cons) => cons.visit_id(f),
}
}
}
trait Args {
type Arg1: Arg;
type Arg2: Arg;
type Arg2St: Arg;
type Arg2Mov: Arg;
type Arg3: Arg;
type Arg4: Arg;
type Arg5: Arg;
}
trait Arg {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F);
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F);
type Arg1;
type Arg2;
type Arg2St;
type Arg2Mov;
type Arg3;
type Arg4;
type Arg5;
}
enum NormalizedArgs {}
@ -1049,48 +1046,24 @@ enum Instruction<A: Args> {
Ret(ast::RetData),
}
impl<A: Args> Instruction<A> {
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
impl Instruction<NormalizedArgs> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(&mut self, f: &mut F) {
match self {
Instruction::Ld(_, a) => a.visit_id_mut(f),
Instruction::Mov(_, a) => a.visit_id_mut(f),
Instruction::Mul(_, a) => a.visit_id_mut(f),
Instruction::Add(_, a) => a.visit_id_mut(f),
Instruction::Setp(_, a) => a.visit_id_mut(f),
Instruction::SetpBool(_, a) => a.visit_id_mut(f),
Instruction::Not(_, a) => a.visit_id_mut(f),
Instruction::Cvt(_, a) => a.visit_id_mut(f),
Instruction::Shl(_, a) => a.visit_id_mut(f),
Instruction::St(_, a) => a.visit_id_mut(f),
Instruction::Bra(_, a) => a.visit_id_mut(f),
Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)),
Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Not(_, _) => todo!(),
Instruction::Cvt(_, _) => todo!(),
Instruction::Shl(_, _) => todo!(),
Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Bra(_, a) => a.visit_id(f, None),
Instruction::Ret(_) => (),
}
}
fn get_type(&self) -> Option<ast::Type> {
match self {
Instruction::Add(add, _) => match add {
ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => {
Some(ast::Type::Scalar((*typ).into()))
}
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => Some((*typ).into()),
},
Instruction::Ret(_) => None,
Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)),
Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)),
Instruction::Mov(mov, _) => Some(mov.typ),
Instruction::Mul(mul, _) => match mul {
ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => {
Some(ast::Type::Scalar((*typ).into()))
}
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => Some((*typ).into()),
},
_ => todo!(),
}
}
}
impl Instruction<NormalizedArgs> {
fn from_ast(s: ast::Instruction<spirv::Word>) -> Self {
match s {
ast::Instruction::Ld(d, a) => Instruction::Ld(d, a),
@ -1110,6 +1083,50 @@ impl Instruction<NormalizedArgs> {
}
impl Instruction<ExpandedArgs> {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
let f_visitor = &mut Self::typed_visitor(f);
match self {
Instruction::Ld(_, a) => a.visit_id(f_visitor, None),
Instruction::Mov(_, a) => a.visit_id(f_visitor, None),
Instruction::Mul(_, a) => a.visit_id(f_visitor, None),
Instruction::Add(_, a) => a.visit_id(f_visitor, None),
Instruction::Setp(_, a) => a.visit_id(f_visitor, None),
Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None),
Instruction::Not(_, a) => a.visit_id(f_visitor, None),
Instruction::Cvt(_, a) => a.visit_id(f_visitor, None),
Instruction::Shl(_, a) => a.visit_id(f_visitor, None),
Instruction::St(_, a) => a.visit_id(f_visitor, None),
Instruction::Bra(_, a) => a.visit_id(f_visitor, None),
Instruction::Ret(_) => (),
}
}
fn typed_visitor<'a>(
f: &'a mut impl FnMut(&mut spirv::Word),
) -> impl FnMut(bool, &mut spirv::Word, Option<ast::Type>) + 'a {
move |_, id, _| f(id)
}
fn visit_id_extended<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
) {
match self {
Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)),
Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())),
Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Not(_, a) => todo!(),
Instruction::Cvt(_, a) => todo!(),
Instruction::Shl(_, a) => todo!(),
Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
Instruction::Bra(_, a) => a.visit_id(f, None),
Instruction::Ret(_) => (),
}
}
fn jump_target(&self) -> Option<spirv::Word> {
match self {
Instruction::Bra(_, a) => Some(a.src),
@ -1132,13 +1149,13 @@ struct Arg1 {
pub src: spirv::Word,
}
impl Arg for Arg1 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(false, self.src);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(false, &mut self.src);
impl Arg1 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(false, &mut self.src, t);
}
}
@ -1147,15 +1164,14 @@ struct Arg2 {
pub src: spirv::Word,
}
impl Arg for Arg2 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst);
f(false, self.src);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(false, &mut self.src);
f(true, &mut self.dst);
impl Arg2 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(true, &mut self.dst, t);
f(false, &mut self.src, t);
}
}
@ -1164,15 +1180,14 @@ pub struct Arg2St {
pub src2: spirv::Word,
}
impl Arg for Arg2St {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(false, self.src1);
f(false, self.src2);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(false, &mut self.src1);
f(false, &mut self.src2);
impl Arg2St {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(false, &mut self.src1, t);
f(false, &mut self.src2, t);
}
}
@ -1182,17 +1197,15 @@ struct Arg3 {
pub src2: spirv::Word,
}
impl Arg for Arg3 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst);
f(false, self.src1);
f(false, self.src2);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(false, &mut self.src1);
f(false, &mut self.src2);
f(true, &mut self.dst);
impl Arg3 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(true, &mut self.dst, t);
f(false, &mut self.src1, t);
f(false, &mut self.src2, t);
}
}
@ -1203,19 +1216,26 @@ struct Arg4 {
pub src2: spirv::Word,
}
impl Arg for Arg4 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst1);
self.dst2.map(|dst2| f(true, dst2));
f(false, self.src1);
f(false, self.src2);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(false, &mut self.src1);
f(false, &mut self.src2);
f(true, &mut self.dst1);
self.dst2.as_mut().map(|dst2| f(true, dst2));
impl Arg4 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(
true,
&mut self.dst1,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
self.dst2.as_mut().map(|dst2| {
f(
true,
dst2,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
});
f(false, &mut self.src1, t);
f(false, &mut self.src2, t);
}
}
@ -1227,21 +1247,31 @@ struct Arg5 {
pub src3: spirv::Word,
}
impl Arg for Arg5 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst1);
self.dst2.map(|dst2| f(true, dst2));
f(false, self.src1);
f(false, self.src2);
f(false, self.src3);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(false, &mut self.src1);
f(false, &mut self.src2);
f(false, &mut self.src3);
f(true, &mut self.dst1);
self.dst2.as_mut().map(|dst2| f(true, dst2));
impl Arg5 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(
true,
&mut self.dst1,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
self.dst2.as_mut().map(|dst2| {
f(
true,
dst2,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
});
f(false, &mut self.src1, t);
f(false, &mut self.src2, t);
f(
false,
&mut self.src3,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
}
}
@ -1252,12 +1282,8 @@ struct ConstantDefinition {
}
impl ConstantDefinition {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(true, &mut self.dst);
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
f(&mut self.dst);
}
}
@ -1268,16 +1294,14 @@ struct BrachCondition {
}
impl BrachCondition {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(false, self.predicate);
f(false, self.if_true);
f(false, self.if_false);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(false, &mut self.predicate);
f(false, &mut self.if_true);
f(false, &mut self.if_false);
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(&mut self, f: &mut F) {
f(
false,
&mut self.predicate,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
f(false, &mut self.if_true, None);
f(false, &mut self.if_false, None);
}
}
@ -1298,14 +1322,9 @@ enum ConversionKind {
}
impl ImplicitConversion {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(false, self.src);
f(true, self.dst);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(false, &mut self.src);
f(true, &mut self.dst);
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
f(&mut self.dst);
f(&mut self.src);
}
}
@ -1343,13 +1362,13 @@ impl<T> ast::Arg1<T> {
}
}
impl Arg for ast::Arg1<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(false, self.src);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(false, &mut self.src);
impl ast::Arg1<spirv::Word> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(false, &mut self.src, t);
}
}
@ -1362,15 +1381,14 @@ impl<T> ast::Arg2<T> {
}
}
impl Arg for ast::Arg2<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst);
self.src.visit_id(f);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
self.src.visit_id_mut(f);
f(true, &mut self.dst);
impl ast::Arg2<spirv::Word> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(true, &mut self.dst, t);
self.src.visit_id(f, t);
}
}
@ -1383,15 +1401,14 @@ impl<T> ast::Arg2St<T> {
}
}
impl Arg for ast::Arg2St<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
self.src1.visit_id(f);
self.src2.visit_id(f);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
self.src1.visit_id_mut(f);
self.src2.visit_id_mut(f);
impl ast::Arg2St<spirv::Word> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
self.src1.visit_id(f, t);
self.src2.visit_id(f, t);
}
}
@ -1404,15 +1421,14 @@ impl<T> ast::Arg2Mov<T> {
}
}
impl Arg for ast::Arg2Mov<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst);
self.src.visit_id(f);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
self.src.visit_id_mut(f);
f(true, &mut self.dst);
impl ast::Arg2Mov<spirv::Word> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(true, &mut self.dst, t);
self.src.visit_id(f, t);
}
}
@ -1426,17 +1442,15 @@ impl<T> ast::Arg3<T> {
}
}
impl Arg for ast::Arg3<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst);
self.src1.visit_id(f);
self.src2.visit_id(f);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
self.src1.visit_id_mut(f);
self.src2.visit_id_mut(f);
f(true, &mut self.dst);
impl ast::Arg3<spirv::Word> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(true, &mut self.dst, t);
self.src1.visit_id(f, t);
self.src2.visit_id(f, t);
}
}
@ -1451,19 +1465,26 @@ impl<T> ast::Arg4<T> {
}
}
impl Arg for ast::Arg4<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst1);
self.dst2.map(|i| f(true, i));
self.src1.visit_id(f);
self.src2.visit_id(f);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
self.src1.visit_id_mut(f);
self.src2.visit_id_mut(f);
f(true, &mut self.dst1);
self.dst2.as_mut().map(|i| f(true, i));
impl ast::Arg4<spirv::Word> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(
true,
&mut self.dst1,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
self.dst2.as_mut().map(|i| {
f(
true,
i,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
});
self.src1.visit_id(f, t);
self.src2.visit_id(f, t);
}
}
@ -1479,21 +1500,30 @@ impl<T> ast::Arg5<T> {
}
}
impl Arg for ast::Arg5<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst1);
self.dst2.map(|i| f(true, i));
self.src1.visit_id(f);
self.src2.visit_id(f);
self.src3.visit_id(f);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
self.src1.visit_id_mut(f);
self.src2.visit_id_mut(f);
self.src3.visit_id_mut(f);
f(true, &mut self.dst1);
self.dst2.as_mut().map(|i| f(true, i));
impl ast::Arg5<spirv::Word> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
f(
true,
&mut self.dst1,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
self.dst2.as_mut().map(|i| {
f(
true,
i,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
)
});
self.src1.visit_id(f, t);
self.src2.visit_id(f, t);
self.src3.visit_id(
f,
Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
);
}
}
@ -1508,18 +1538,14 @@ impl<T> ast::Operand<T> {
}
impl<T: Copy> ast::Operand<T> {
fn visit_id<F: FnMut(bool, T)>(&self, f: &mut F) {
fn visit_id<F: FnMut(bool, &mut T, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
match self {
ast::Operand::Reg(i) => f(false, *i),
ast::Operand::RegOffset(i, _) => f(false, *i),
ast::Operand::Imm(_) => (),
}
}
fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
match self {
ast::Operand::Reg(i) => f(false, i),
ast::Operand::RegOffset(i, _) => f(false, i),
ast::Operand::Reg(i) => f(false, i, t),
ast::Operand::RegOffset(i, _) => f(false, i, t),
ast::Operand::Imm(_) => (),
}
}
@ -1535,16 +1561,13 @@ impl<T> ast::MovOperand<T> {
}
impl<T: Copy> ast::MovOperand<T> {
fn visit_id<F: FnMut(bool, T)>(&self, f: &mut F) {
fn visit_id<F: FnMut(bool, &mut T, Option<ast::Type>)>(
&mut self,
f: &mut F,
t: Option<ast::Type>,
) {
match self {
ast::MovOperand::Op(o) => o.visit_id(f),
ast::MovOperand::Vec(_, _) => todo!(),
}
}
fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
match self {
ast::MovOperand::Op(o) => o.visit_id_mut(f),
ast::MovOperand::Op(o) => o.visit_id(f, t),
ast::MovOperand::Vec(_, _) => todo!(),
}
}
@ -1793,7 +1816,7 @@ fn insert_conversion_src(
conv: ConversionKind,
) -> spirv::Word {
let temp_src = id_def.new_id(Some(instr_type));
func.push(Statement::Converison(ImplicitConversion {
func.push(Statement::Conversion(ImplicitConversion {
src: src,
dst: temp_src,
from: src_type,
@ -1838,7 +1861,7 @@ fn get_conversion_dst(
let original_dst = *dst;
let temp_dst = id_def.new_id(Some(instr_type));
*dst = temp_dst;
Statement::Converison(ImplicitConversion {
Statement::Conversion(ImplicitConversion {
src: temp_dst,
dst: original_dst,
from: instr_type,
@ -1938,31 +1961,33 @@ fn insert_implicit_bitcasts(
mut instr: Instruction<ExpandedArgs>,
) {
let mut dst_coercion = None;
if let Some(instr_type) = instr.get_type() {
instr.visit_id_mut(&mut |is_dst, id| {
let id_type = id_def.get_type(*id);
if should_bitcast(instr_type, id_def.get_type(*id)) {
if is_dst {
dst_coercion = Some(get_conversion_dst(
id_def,
id,
instr_type,
id_type,
ConversionKind::Default,
));
} else {
*id = insert_conversion_src(
func,
id_def,
*id,
id_type,
instr_type,
ConversionKind::Default,
);
}
instr.visit_id_extended(&mut |is_dst, id, id_type| {
let id_type_from_instr = match id_type {
Some(t) => t,
None => return,
};
let id_actual_type = id_def.get_type(*id);
if should_bitcast(id_type_from_instr, id_def.get_type(*id)) {
if is_dst {
dst_coercion = Some(get_conversion_dst(
id_def,
id,
id_type_from_instr,
id_actual_type,
ConversionKind::Default,
));
} else {
*id = insert_conversion_src(
func,
id_def,
*id,
id_actual_type,
id_type_from_instr,
ConversionKind::Default,
);
}
});
}
}
});
func.push(Statement::Instruction(instr));
if let Some(cond) = dst_coercion {
func.push(cond);