From 66fa0706a473a4263334a3440402967b9178b177 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 30 Jul 2020 03:01:37 +0200 Subject: [PATCH] Refactor various functions for visiting/mapping statements and instructions into one --- ptx/src/translate.rs | 750 +++++++++++++++++++------------------------ 1 file changed, 335 insertions(+), 415 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index ebcb090..233c67f 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -156,16 +156,17 @@ fn emit_function<'a>( let (mut func_body, unique_ids) = to_ssa(&f.args, f.body); let id_offset = builder.reserve_ids(unique_ids); emit_function_args(builder, id_offset, map, &f.args); - apply_id_offset(&mut func_body, id_offset); + func_body = apply_id_offset(func_body, id_offset); emit_function_body_ops(builder, map, opencl_id, &func_body)?; builder.end_function()?; Ok(func_id) } -fn apply_id_offset(func_body: &mut Vec, id_offset: u32) { - for s in func_body { - s.visit_id(&mut |id| *id += id_offset); - } +fn apply_id_offset(func_body: Vec, id_offset: u32) -> Vec { + func_body + .into_iter() + .map(|s| s.visit_variable(&mut |id| id + id_offset)) + .collect() } fn to_ssa<'a, 'b>( @@ -274,32 +275,32 @@ fn insert_mem_ssa_statements( ) => { result.push(Statement::Instruction(ast::Instruction::Ld(ld, arg))); } - mut inst => { + inst => { let mut post_statements = Vec::new(); - inst.visit_id(&mut |is_dst, id, id_type| { + let inst = inst.visit_variable(&mut |id, is_dst, id_type| { let id_type = match id_type { Some(t) => t, - None => return, + None => return id, }; let generated_id = id_def.new_id(Some(id_type)); if !is_dst { result.push(Statement::LoadVar( Arg2 { dst: generated_id, - src: *id, + src: id, }, id_type, )); } else { post_statements.push(Statement::StoreVar( Arg2St { - src1: *id, + src1: id, src2: generated_id, }, id_type, )); } - *id = generated_id; + generated_id }); result.push(Statement::Instruction(inst)); result.append(&mut post_statements); @@ -847,12 +848,12 @@ fn normalize_identifiers<'a, 'b>( } let mut result = Vec::new(); for s in func { - expand_map_ids(&mut id_defs, &mut result, s); + expand_map_variables(&mut id_defs, &mut result, s); } (result, id_defs.finish()) } -fn expand_map_ids<'a>( +fn expand_map_variables<'a>( id_defs: &mut StringIdResolver<'a>, result: &mut Vec>, s: ast::Statement>, @@ -862,8 +863,8 @@ fn expand_map_ids<'a>( result.push(ast::Statement::Label(id_defs.add_def(name, None))) } ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction( - p.map(|p| p.map_id(&mut |id| id_defs.get_id(id))), - i.map_id(&mut |id| id_defs.get_id(id)), + p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))), + i.map_variable(&mut |id| id_defs.get_id(id)), )), ast::Statement::Variable(var) => match var.count { Some(count) => { @@ -969,8 +970,8 @@ impl NumericIdResolver { enum Statement { Variable(spirv::Word, ast::Type, ast::StateSpace), - LoadVar(Arg2, ast::Type), - StoreVar(Arg2St, ast::Type), + LoadVar(ast::Arg2, ast::Type), + StoreVar(ast::Arg2St, ast::Type), Label(u32), Instruction(I), // SPIR-V compatible replacement for PTX predicates @@ -980,16 +981,20 @@ enum Statement { } impl Statement> { - fn visit_id(&mut self, f: &mut F) { + fn visit_variable spirv::Word>(self, f: &mut F) -> Self { match self { - Statement::Variable(id, _, _) => f(id), - Statement::LoadVar(a, _) => a.visit_id(&mut |_, id, _| f(id), None), - Statement::StoreVar(a, _) => a.visit_id(&mut |_, id, _| f(id), None), - Statement::Label(id) => f(id), - Statement::Instruction(inst) => inst.visit_id(f), - Statement::Conditional(bra) => bra.visit_id(&mut |_, id, _| f(id)), - Statement::Conversion(conv) => conv.visit_id(f), - Statement::Constant(cons) => cons.visit_id(f), + Statement::Variable(id, t, ss) => Statement::Variable(f(id), t, ss), + Statement::LoadVar(a, t) => { + Statement::LoadVar(a.map(&mut reduced_visitor(f), Some(t)), t) + } + Statement::StoreVar(a, t) => { + Statement::StoreVar(a.map(&mut reduced_visitor(f), Some(t)), t) + } + Statement::Label(id) => Statement::Label(f(id)), + Statement::Instruction(inst) => Statement::Instruction(inst.visit_variable(f)), + Statement::Conditional(bra) => Statement::Conditional(bra.map(f)), + Statement::Conversion(conv) => Statement::Conversion(conv.map(f)), + Statement::Constant(cons) => Statement::Constant(cons.map(f)), } } } @@ -1012,69 +1017,211 @@ impl ast::ArgParams for ExpandedArgParams { type MovOperand = spirv::Word; } -impl ast::Instruction { - fn visit_id)>(&mut self, f: &mut F) { - match self { - ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - ast::Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)), - ast::Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())), - ast::Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), - ast::Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - ast::Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - ast::Instruction::Not(_, _) => todo!(), - ast::Instruction::Cvt(_, _) => todo!(), - ast::Instruction::Shl(_, _) => todo!(), - ast::Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - ast::Instruction::Bra(_, a) => a.visit_id(f, None), - ast::Instruction::Ret(_) => (), +trait ArgumentMapVisitor { + fn dst_variable(&mut self, v: T::ID, typ: Option) -> U::ID; + fn src_operand(&mut self, o: T::Operand, typ: Option) -> U::Operand; + fn src_mov_operand(&mut self, o: T::MovOperand, typ: Option) -> U::MovOperand; +} + +struct FlattenArguments<'a> { + func: &'a mut Vec, + id_def: &'a mut NumericIdResolver, +} + +impl<'a> ArgumentMapVisitor for FlattenArguments<'a> { + fn dst_variable(&mut self, x: spirv::Word, _: Option) -> spirv::Word { + x + } + + fn src_operand(&mut self, op: ast::Operand, t: Option) -> spirv::Word { + match op { + ast::Operand::Reg(r) => r, + ast::Operand::Imm(x) => { + if let Some(typ) = t { + let scalar_t = if let ast::Type::Scalar(scalar) = typ { + scalar + } else { + todo!() + }; + let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id, + typ: scalar_t, + value: x, + })); + id + } else { + todo!() + } + } + _ => todo!(), + } + } + + fn src_mov_operand( + &mut self, + op: ast::MovOperand, + t: Option, + ) -> spirv::Word { + match op { + ast::MovOperand::Op(opr) => self.src_operand(opr, t), + ast::MovOperand::Vec(_, _) => todo!(), } } } -impl ast::Instruction { - fn visit_id(&mut self, f: &mut F) { - let f_visitor = &mut Self::typed_visitor(f); - match self { - ast::Instruction::Ld(_, a) => a.visit_id(f_visitor, None), - ast::Instruction::Mov(_, a) => a.visit_id(f_visitor, None), - ast::Instruction::Mul(_, a) => a.visit_id(f_visitor, None), - ast::Instruction::Add(_, a) => a.visit_id(f_visitor, None), - ast::Instruction::Setp(_, a) => a.visit_id(f_visitor, None), - ast::Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None), - ast::Instruction::Not(_, a) => a.visit_id(f_visitor, None), - ast::Instruction::Cvt(_, a) => a.visit_id(f_visitor, None), - ast::Instruction::Shl(_, a) => a.visit_id(f_visitor, None), - ast::Instruction::St(_, a) => a.visit_id(f_visitor, None), - ast::Instruction::Bra(_, a) => a.visit_id(f_visitor, None), - ast::Instruction::Ret(_) => (), +impl ArgumentMapVisitor for T +where + T: FnMut(spirv::Word, bool, Option) -> spirv::Word, +{ + fn dst_variable(&mut self, x: spirv::Word, t: Option) -> spirv::Word { + self(x, t.is_some(), t) + } + fn src_operand(&mut self, x: spirv::Word, t: Option) -> spirv::Word { + self(x, false, t) + } + fn src_mov_operand(&mut self, x: spirv::Word, t: Option) -> spirv::Word { + self(x, false, t) + } +} + +impl<'a, T> ArgumentMapVisitor, NormalizedArgParams> for T +where + T: FnMut(&str) -> spirv::Word, +{ + fn dst_variable(&mut self, x: &str, _: Option) -> spirv::Word { + self(x) + } + + fn src_operand( + &mut self, + x: ast::Operand<&str>, + _: Option, + ) -> ast::Operand { + match x { + ast::Operand::Reg(id) => ast::Operand::Reg(self(id)), + ast::Operand::Imm(imm) => ast::Operand::Imm(imm), + ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id), imm), } } - fn typed_visitor<'a>( - f: &'a mut impl FnMut(&mut spirv::Word), - ) -> impl FnMut(bool, &mut spirv::Word, Option) + 'a { - move |_, id, _| f(id) - } - - fn visit_id_extended)>( + fn src_mov_operand( &mut self, - f: &mut F, - ) { + x: ast::MovOperand<&str>, + t: Option, + ) -> ast::MovOperand { + match x { + ast::MovOperand::Op(op) => ast::MovOperand::Op(self.src_operand(op, t)), + ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2), + } + } +} + +impl ast::Instruction { + fn map_variable_new>( + self, + visitor: &mut V, + ) -> ast::Instruction { match self { - ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - ast::Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)), - ast::Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())), - ast::Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), - ast::Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - ast::Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + ast::Instruction::Ld(d, a) => { + let inst_type = d.typ; + ast::Instruction::Ld(d, a.map(visitor, Some(ast::Type::Scalar(inst_type)))) + } + ast::Instruction::Mov(d, a) => { + let inst_type = d.typ; + ast::Instruction::Mov(d, a.map(visitor, Some(inst_type))) + } + ast::Instruction::Mul(d, a) => { + let inst_type = d.get_type(); + ast::Instruction::Mul(d, a.map(visitor, Some(inst_type))) + } + ast::Instruction::Add(d, a) => { + let inst_type = d.get_type(); + ast::Instruction::Add(d, a.map(visitor, Some(inst_type))) + } + ast::Instruction::Setp(d, a) => { + let inst_type = d.typ; + ast::Instruction::Setp(d, a.map(visitor, Some(ast::Type::Scalar(inst_type)))) + } + ast::Instruction::SetpBool(d, a) => { + let inst_type = d.typ; + ast::Instruction::SetpBool(d, a.map(visitor, Some(ast::Type::Scalar(inst_type)))) + } ast::Instruction::Not(_, _) => todo!(), ast::Instruction::Cvt(_, _) => todo!(), ast::Instruction::Shl(_, _) => todo!(), - ast::Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - ast::Instruction::Bra(_, a) => a.visit_id(f, None), - ast::Instruction::Ret(_) => (), + ast::Instruction::St(d, a) => { + let inst_type = d.typ; + ast::Instruction::St(d, a.map(visitor, Some(ast::Type::Scalar(inst_type)))) + } + ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)), + ast::Instruction::Ret(d) => ast::Instruction::Ret(d), } } +} + +impl ast::Instruction { + fn visit_variable) -> spirv::Word>( + self, + f: &mut F, + ) -> ast::Instruction { + self.map_variable_new(f) + } +} + +impl ArgumentMapVisitor for T +where + T: FnMut(spirv::Word, bool, Option) -> spirv::Word, +{ + fn dst_variable(&mut self, x: spirv::Word, t: Option) -> spirv::Word { + self(x, t.is_some(), t) + } + + fn src_operand( + &mut self, + x: ast::Operand, + t: Option, + ) -> ast::Operand { + match x { + ast::Operand::Reg(id) => ast::Operand::Reg(self(id, false, t)), + ast::Operand::Imm(imm) => ast::Operand::Imm(imm), + ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id, false, t), imm), + } + } + + fn src_mov_operand( + &mut self, + x: ast::MovOperand, + t: Option, + ) -> ast::MovOperand { + match x { + ast::MovOperand::Op(op) => ast::MovOperand::Op(ArgumentMapVisitor::< + NormalizedArgParams, + NormalizedArgParams, + >::src_operand(self, op, t)), + ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2), + } + } +} + +fn reduced_visitor<'a>( + f: &'a mut impl FnMut(spirv::Word) -> spirv::Word, +) -> impl FnMut(spirv::Word, bool, Option) -> spirv::Word + 'a { + move |id, _, _| f(id) +} + +impl ast::Instruction { + fn visit_variable spirv::Word>(self, f: &mut F) -> Self { + let mut visitor = reduced_visitor(f); + self.map_variable_new(&mut visitor) + } + + fn visit_variable_extended) -> spirv::Word>( + self, + f: &mut F, + ) -> Self { + self.map_variable_new(f) + } fn jump_target(&self) -> Option { match self { @@ -1094,126 +1241,9 @@ impl ast::Instruction { } } -type Arg1 = ast::Arg1; - -impl Arg1 { - fn visit_id)>( - &mut self, - f: &mut F, - t: Option, - ) { - f(false, &mut self.src, t); - } -} - type Arg2 = ast::Arg2; - -impl Arg2 { - fn visit_id)>( - &mut self, - f: &mut F, - t: Option, - ) { - f(true, &mut self.dst, t); - f(false, &mut self.src, t); - } -} - -type Arg2Mov = ast::Arg2Mov; - -impl Arg2Mov { - fn visit_id)>( - &mut self, - f: &mut F, - t: Option, - ) { - f(true, &mut self.dst, t); - f(false, &mut self.src, t); - } -} - type Arg2St = ast::Arg2St; -impl Arg2St { - fn visit_id)>( - &mut self, - f: &mut F, - t: Option, - ) { - f(false, &mut self.src1, t); - f(false, &mut self.src2, t); - } -} - -type Arg3 = ast::Arg3; - -impl Arg3 { - fn visit_id)>( - &mut self, - f: &mut F, - t: Option, - ) { - f(true, &mut self.dst, t); - f(false, &mut self.src1, t); - f(false, &mut self.src2, t); - } -} - -type Arg4 = ast::Arg4; - -impl Arg4 { - fn visit_id)>( - &mut self, - f: &mut F, - t: Option, - ) { - f( - true, - &mut self.dst1, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ); - self.dst2.as_mut().map(|dst2| { - f( - true, - dst2, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ) - }); - f(false, &mut self.src1, t); - f(false, &mut self.src2, t); - } -} - -type Arg5 = ast::Arg5; - -impl Arg5 { - fn visit_id)>( - &mut self, - f: &mut F, - t: Option, - ) { - f( - true, - &mut self.dst1, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ); - self.dst2.as_mut().map(|dst2| { - f( - true, - dst2, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ) - }); - f(false, &mut self.src1, t); - f(false, &mut self.src2, t); - f( - false, - &mut self.src3, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ); - } -} - struct ConstantDefinition { pub dst: spirv::Word, pub typ: ast::ScalarType, @@ -1221,8 +1251,12 @@ struct ConstantDefinition { } impl ConstantDefinition { - fn visit_id(&mut self, f: &mut F) { - f(&mut self.dst); + fn map spirv::Word>(self, f: &mut F) -> Self { + Self { + dst: f(self.dst), + typ: self.typ, + value: self.value, + } } } @@ -1233,14 +1267,12 @@ struct BrachCondition { } impl BrachCondition { - fn visit_id)>(&mut self, f: &mut F) { - f( - false, - &mut self.predicate, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ); - f(false, &mut self.if_true, None); - f(false, &mut self.if_false, None); + fn map spirv::Word>(self, f: &mut F) -> Self { + Self { + predicate: f(self.predicate), + if_true: f(self.if_true), + if_false: f(self.if_false), + } } } @@ -1261,14 +1293,19 @@ enum ConversionKind { } impl ImplicitConversion { - fn visit_id(&mut self, f: &mut F) { - f(&mut self.dst); - f(&mut self.src); + fn map spirv::Word>(self, f: &mut F) -> Self { + Self { + src: f(self.src), + dst: f(self.dst), + from: self.from, + to: self.to, + kind: self.kind, + } } } impl ast::PredAt { - fn map_id U>(self, f: &mut F) -> ast::PredAt { + fn map_variable U>(self, f: &mut F) -> ast::PredAt { ast::PredAt { not: self.not, label: f(self.label), @@ -1276,247 +1313,127 @@ impl ast::PredAt { } } +// REMOVE impl<'a> ast::Instruction> { - fn map_id spirv::Word>( + fn map_variable spirv::Word>( self, f: &mut F, ) -> ast::Instruction { - match self { - ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)), - ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)), - ast::Instruction::Mul(d, a) => ast::Instruction::Mul(d, a.map_id(f)), - ast::Instruction::Add(d, a) => ast::Instruction::Add(d, a.map_id(f)), - ast::Instruction::Setp(d, a) => ast::Instruction::Setp(d, a.map_id(f)), - ast::Instruction::SetpBool(d, a) => ast::Instruction::SetpBool(d, a.map_id(f)), - ast::Instruction::Not(d, a) => ast::Instruction::Not(d, a.map_id(f)), - ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map_id(f)), - ast::Instruction::Cvt(d, a) => ast::Instruction::Cvt(d, a.map_id(f)), - ast::Instruction::Shl(d, a) => ast::Instruction::Shl(d, a.map_id(f)), - ast::Instruction::St(d, a) => ast::Instruction::St(d, a.map_id(f)), - ast::Instruction::Ret(d) => ast::Instruction::Ret(d), + self.map_variable_new(f) + } +} + +impl ast::Arg1 { + fn map>( + self, + visitor: &mut V, + t: Option, + ) -> ast::Arg1 { + ast::Arg1 { + src: visitor.dst_variable(self.src, t), } } } -impl<'a> ast::Arg1> { - fn map_id spirv::Word>(self, f: &mut F) -> ast::Arg1 { - ast::Arg1 { src: f(self.src) } - } -} - -impl ast::Arg1 { - fn visit_id)>( - &mut self, - f: &mut F, +impl ast::Arg2 { + fn map>( + self, + visitor: &mut V, t: Option, - ) { - f(false, &mut self.src, t); - } -} - -impl<'a> ast::Arg2> { - fn map_id spirv::Word>(self, f: &mut F) -> ast::Arg2 { + ) -> ast::Arg2 { ast::Arg2 { - dst: f(self.dst), - src: self.src.map_id(f), + dst: visitor.dst_variable(self.dst, t), + src: visitor.src_operand(self.src, t), } } } -impl ast::Arg2 { - fn visit_id)>( - &mut self, - f: &mut F, - t: Option, - ) { - f(true, &mut self.dst, t); - self.src.visit_id(f, t); - } -} - -impl<'a> ast::Arg2St> { - fn map_id spirv::Word>( +impl ast::Arg2St { + fn map>( self, - f: &mut F, - ) -> ast::Arg2St { + visitor: &mut V, + t: Option, + ) -> ast::Arg2St { ast::Arg2St { - src1: self.src1.map_id(f), - src2: self.src2.map_id(f), + src1: visitor.src_operand(self.src1, t), + src2: visitor.src_operand(self.src2, t), } } } -impl ast::Arg2St { - fn visit_id)>( - &mut self, - f: &mut F, - t: Option, - ) { - self.src1.visit_id(f, t); - self.src2.visit_id(f, t); - } -} - -impl<'a> ast::Arg2Mov> { - fn map_id spirv::Word>( +impl ast::Arg2Mov { + fn map>( self, - f: &mut F, - ) -> ast::Arg2Mov { + visitor: &mut V, + t: Option, + ) -> ast::Arg2Mov { ast::Arg2Mov { - dst: f(self.dst), - src: self.src.map_id(f), + dst: visitor.dst_variable(self.dst, t), + src: visitor.src_mov_operand(self.src, t), } } } -impl ast::Arg2Mov { - fn visit_id)>( - &mut self, - f: &mut F, +impl ast::Arg3 { + fn map>( + self, + visitor: &mut V, t: Option, - ) { - f(true, &mut self.dst, t); - self.src.visit_id(f, t); - } -} - -impl<'a> ast::Arg3> { - fn map_id spirv::Word>(self, f: &mut F) -> ast::Arg3 { + ) -> ast::Arg3 { ast::Arg3 { - dst: f(self.dst), - src1: self.src1.map_id(f), - src2: self.src2.map_id(f), + dst: visitor.dst_variable(self.dst, t), + src1: visitor.src_operand(self.src1, t), + src2: visitor.src_operand(self.src2, t), } } } -impl ast::Arg3 { - fn visit_id)>( - &mut self, - f: &mut F, +impl ast::Arg4 { + fn map>( + self, + visitor: &mut V, t: Option, - ) { - f(true, &mut self.dst, t); - self.src1.visit_id(f, t); - self.src2.visit_id(f, t); - } -} - -impl<'a> ast::Arg4> { - fn map_id spirv::Word>(self, f: &mut F) -> ast::Arg4 { + ) -> ast::Arg4 { ast::Arg4 { - dst1: f(self.dst1), - dst2: self.dst2.map(|i| f(i)), - src1: self.src1.map_id(f), - src2: self.src2.map_id(f), + dst1: visitor.dst_variable( + self.dst1, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ), + dst2: self.dst2.map(|dst2| { + visitor.dst_variable( + dst2, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ) + }), + src1: visitor.src_operand(self.src1, t), + src2: visitor.src_operand(self.src2, t), } } } -impl ast::Arg4 { - fn visit_id)>( - &mut self, - f: &mut F, +impl ast::Arg5 { + fn map>( + self, + visitor: &mut V, t: Option, - ) { - f( - true, - &mut self.dst1, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ); - self.dst2.as_mut().map(|i| { - f( - true, - i, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ) - }); - self.src1.visit_id(f, t); - self.src2.visit_id(f, t); - } -} - -impl<'a> ast::Arg5> { - fn map_id spirv::Word>(self, f: &mut F) -> ast::Arg5 { + ) -> ast::Arg5 { ast::Arg5 { - dst1: f(self.dst1), - dst2: self.dst2.map(|i| f(i)), - src1: self.src1.map_id(f), - src2: self.src2.map_id(f), - src3: self.src3.map_id(f), - } - } -} - -impl ast::Arg5 { - fn visit_id)>( - &mut self, - f: &mut F, - t: Option, - ) { - f( - true, - &mut self.dst1, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ); - self.dst2.as_mut().map(|i| { - f( - true, - i, + dst1: visitor.dst_variable( + self.dst1, Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ) - }); - self.src1.visit_id(f, t); - self.src2.visit_id(f, t); - self.src3.visit_id( - f, - Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), - ); - } -} - -impl ast::Operand { - fn map_id U>(self, f: &mut F) -> ast::Operand { - match self { - ast::Operand::Reg(i) => ast::Operand::Reg(f(i)), - ast::Operand::RegOffset(i, o) => ast::Operand::RegOffset(f(i), o), - ast::Operand::Imm(v) => ast::Operand::Imm(v), - } - } -} - -impl ast::Operand { - fn visit_id)>( - &mut self, - f: &mut F, - t: Option, - ) { - match self { - ast::Operand::Reg(i) => f(false, i, t), - ast::Operand::RegOffset(i, _) => f(false, i, t), - ast::Operand::Imm(_) => (), - } - } -} - -impl ast::MovOperand { - fn map_id U>(self, f: &mut F) -> ast::MovOperand { - match self { - ast::MovOperand::Op(o) => ast::MovOperand::Op(o.map_id(f)), - ast::MovOperand::Vec(s1, s2) => ast::MovOperand::Vec(s1, s2), - } - } -} - -impl ast::MovOperand { - fn visit_id)>( - &mut self, - f: &mut F, - t: Option, - ) { - match self { - ast::MovOperand::Op(o) => o.visit_id(f, t), - ast::MovOperand::Vec(_, _) => todo!(), + ), + dst2: self.dst2.map(|dst2| { + visitor.dst_variable( + dst2, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ) + }), + src1: visitor.src_operand(self.src1, t), + src2: visitor.src_operand(self.src2, t), + src3: visitor.src_operand( + self.src3, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ), } } } @@ -1906,34 +1823,37 @@ fn should_convert_relaxed_dst( fn insert_implicit_bitcasts( func: &mut Vec, id_def: &mut NumericIdResolver, - mut instr: ast::Instruction, + instr: ast::Instruction, ) { let mut dst_coercion = None; - instr.visit_id_extended(&mut |is_dst, id, id_type| { + let instr = instr.visit_variable_extended(&mut |mut id, is_dst, id_type| { let id_type_from_instr = match id_type { Some(t) => t, - None => return, + None => return id, }; - let id_actual_type = id_def.get_type(*id); - if should_bitcast(id_type_from_instr, id_def.get_type(*id)) { + let id_actual_type = id_def.get_type(id); + if should_bitcast(id_type_from_instr, id_def.get_type(id)) { if is_dst { dst_coercion = Some(get_conversion_dst( id_def, - id, + &mut id, id_type_from_instr, id_actual_type, ConversionKind::Default, )); + id } else { - *id = insert_conversion_src( + insert_conversion_src( func, id_def, - *id, + id, id_actual_type, id_type_from_instr, ConversionKind::Default, - ); + ) } + } else { + id } }); func.push(Statement::Instruction(instr));