diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 8dd612d..c37413a 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -164,7 +164,7 @@ fn emit_function<'a>( fn apply_id_offset(func_body: &mut Vec, id_offset: u32) { for s in func_body { - s.visit_id_mut(&mut |_, id| *id += id_offset); + s.visit_id(&mut |id| *id += id_offset); } } @@ -200,7 +200,7 @@ fn normalize_labels( Statement::Variable(_, _, _) | Statement::LoadVar(_, _) | Statement::StoreVar(_, _) - | Statement::Converison(_) + | Statement::Conversion(_) | Statement::Constant(_) | Statement::Label(_) => (), } @@ -275,18 +275,20 @@ fn insert_mem_ssa_statements( result.push(Statement::Instruction(Instruction::Ld(ld, arg))); } mut inst => { - let inst_type = inst.get_type(); let mut post_statements = Vec::new(); - inst.visit_id_mut(&mut |is_dst, id| { - let inst_type = inst_type.unwrap(); - let generated_id = id_def.new_id(Some(inst_type)); + inst.visit_id(&mut |is_dst, id, id_type| { + let id_type = match id_type { + Some(t) => t, + None => return, + }; + let generated_id = id_def.new_id(Some(id_type)); if !is_dst { result.push(Statement::LoadVar( Arg2 { dst: generated_id, src: *id, }, - inst_type, + id_type, )); } else { post_statements.push(Statement::StoreVar( @@ -294,7 +296,7 @@ fn insert_mem_ssa_statements( src1: *id, src2: generated_id, }, - inst_type, + id_type, )); } *id = generated_id; @@ -308,7 +310,7 @@ fn insert_mem_ssa_statements( | s @ Statement::Conditional(_) => result.push(s), Statement::LoadVar(_, _) | Statement::StoreVar(_, _) - | Statement::Converison(_) + | Statement::Conversion(_) | Statement::Constant(_) => unreachable!(), } } @@ -331,7 +333,7 @@ fn expand_arguments( Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)), - Statement::Converison(_) | Statement::Constant(_) => unreachable!(), + Statement::Conversion(_) | Statement::Constant(_) => unreachable!(), } } result @@ -572,7 +574,7 @@ fn insert_implicit_conversions( | s @ Statement::Variable(_, _, _) | s @ Statement::LoadVar(_, _) | s @ Statement::StoreVar(_, _) => result.push(s), - Statement::Converison(_) => unreachable!(), + Statement::Conversion(_) => unreachable!(), } } result @@ -660,7 +662,7 @@ fn emit_function_body_ops( _ => unreachable!(), } } - Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?, + Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?, Statement::Conditional(bra) => { builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?; } @@ -973,38 +975,33 @@ enum Statement { Instruction(Instruction), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), - Converison(ImplicitConversion), + Conversion(ImplicitConversion), Constant(ConstantDefinition), } -impl Statement { - fn visit_id_mut(&mut self, f: &mut F) { +impl Statement { + fn visit_id(&mut self, f: &mut F) { match self { - Statement::Variable(id, _, _) => f(true, id), - Statement::LoadVar(a, _) => a.visit_id_mut(f), - Statement::StoreVar(a, _) => a.visit_id_mut(f), - Statement::Label(id) => f(false, id), - Statement::Instruction(inst) => inst.visit_id_mut(f), - Statement::Conditional(bra) => bra.visit_id_mut(f), - Statement::Converison(conv) => conv.visit_id_mut(f), - Statement::Constant(cons) => cons.visit_id_mut(f), + Statement::Variable(id, _, _) => f(id), + Statement::LoadVar(a, _) => a.visit_id(&mut |_, id, _| f(id), None), + Statement::StoreVar(a, _) => a.visit_id(&mut |_, id, _| f(id), None), + Statement::Label(id) => f(id), + Statement::Instruction(inst) => inst.visit_id(f), + Statement::Conditional(bra) => bra.visit_id(&mut |_, id, _| f(id)), + Statement::Conversion(conv) => conv.visit_id(f), + Statement::Constant(cons) => cons.visit_id(f), } } } trait Args { - type Arg1: Arg; - type Arg2: Arg; - type Arg2St: Arg; - type Arg2Mov: Arg; - type Arg3: Arg; - type Arg4: Arg; - type Arg5: Arg; -} - -trait Arg { - fn visit_id(&self, f: &mut F); - fn visit_id_mut(&mut self, f: &mut F); + type Arg1; + type Arg2; + type Arg2St; + type Arg2Mov; + type Arg3; + type Arg4; + type Arg5; } enum NormalizedArgs {} @@ -1049,48 +1046,24 @@ enum Instruction { Ret(ast::RetData), } -impl Instruction { - fn visit_id_mut(&mut self, f: &mut F) { +impl Instruction { + fn visit_id)>(&mut self, f: &mut F) { match self { - Instruction::Ld(_, a) => a.visit_id_mut(f), - Instruction::Mov(_, a) => a.visit_id_mut(f), - Instruction::Mul(_, a) => a.visit_id_mut(f), - Instruction::Add(_, a) => a.visit_id_mut(f), - Instruction::Setp(_, a) => a.visit_id_mut(f), - Instruction::SetpBool(_, a) => a.visit_id_mut(f), - Instruction::Not(_, a) => a.visit_id_mut(f), - Instruction::Cvt(_, a) => a.visit_id_mut(f), - Instruction::Shl(_, a) => a.visit_id_mut(f), - Instruction::St(_, a) => a.visit_id_mut(f), - Instruction::Bra(_, a) => a.visit_id_mut(f), + Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)), + Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())), + Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), + Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + Instruction::Not(_, _) => todo!(), + Instruction::Cvt(_, _) => todo!(), + Instruction::Shl(_, _) => todo!(), + Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + Instruction::Bra(_, a) => a.visit_id(f, None), Instruction::Ret(_) => (), } } - fn get_type(&self) -> Option { - match self { - Instruction::Add(add, _) => match add { - ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => { - Some(ast::Type::Scalar((*typ).into())) - } - ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => Some((*typ).into()), - }, - Instruction::Ret(_) => None, - Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)), - Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)), - Instruction::Mov(mov, _) => Some(mov.typ), - Instruction::Mul(mul, _) => match mul { - ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => { - Some(ast::Type::Scalar((*typ).into())) - } - ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => Some((*typ).into()), - }, - _ => todo!(), - } - } -} - -impl Instruction { fn from_ast(s: ast::Instruction) -> Self { match s { ast::Instruction::Ld(d, a) => Instruction::Ld(d, a), @@ -1110,6 +1083,50 @@ impl Instruction { } impl Instruction { + fn visit_id(&mut self, f: &mut F) { + let f_visitor = &mut Self::typed_visitor(f); + match self { + Instruction::Ld(_, a) => a.visit_id(f_visitor, None), + Instruction::Mov(_, a) => a.visit_id(f_visitor, None), + Instruction::Mul(_, a) => a.visit_id(f_visitor, None), + Instruction::Add(_, a) => a.visit_id(f_visitor, None), + Instruction::Setp(_, a) => a.visit_id(f_visitor, None), + Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None), + Instruction::Not(_, a) => a.visit_id(f_visitor, None), + Instruction::Cvt(_, a) => a.visit_id(f_visitor, None), + Instruction::Shl(_, a) => a.visit_id(f_visitor, None), + Instruction::St(_, a) => a.visit_id(f_visitor, None), + Instruction::Bra(_, a) => a.visit_id(f_visitor, None), + Instruction::Ret(_) => (), + } + } + + fn typed_visitor<'a>( + f: &'a mut impl FnMut(&mut spirv::Word), + ) -> impl FnMut(bool, &mut spirv::Word, Option) + 'a { + move |_, id, _| f(id) + } + + fn visit_id_extended)>( + &mut self, + f: &mut F, + ) { + match self { + Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)), + Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())), + Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), + Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + Instruction::Not(_, a) => todo!(), + Instruction::Cvt(_, a) => todo!(), + Instruction::Shl(_, a) => todo!(), + Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + Instruction::Bra(_, a) => a.visit_id(f, None), + Instruction::Ret(_) => (), + } + } + fn jump_target(&self) -> Option { match self { Instruction::Bra(_, a) => Some(a.src), @@ -1132,13 +1149,13 @@ struct Arg1 { pub src: spirv::Word, } -impl Arg for Arg1 { - fn visit_id(&self, f: &mut F) { - f(false, self.src); - } - - fn visit_id_mut(&mut self, f: &mut F) { - f(false, &mut self.src); +impl Arg1 { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f(false, &mut self.src, t); } } @@ -1147,15 +1164,14 @@ struct Arg2 { pub src: spirv::Word, } -impl Arg for Arg2 { - fn visit_id(&self, f: &mut F) { - f(true, self.dst); - f(false, self.src); - } - - fn visit_id_mut(&mut self, f: &mut F) { - f(false, &mut self.src); - f(true, &mut self.dst); +impl Arg2 { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f(true, &mut self.dst, t); + f(false, &mut self.src, t); } } @@ -1164,15 +1180,14 @@ pub struct Arg2St { pub src2: spirv::Word, } -impl Arg for Arg2St { - fn visit_id(&self, f: &mut F) { - f(false, self.src1); - f(false, self.src2); - } - - fn visit_id_mut(&mut self, f: &mut F) { - f(false, &mut self.src1); - f(false, &mut self.src2); +impl Arg2St { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f(false, &mut self.src1, t); + f(false, &mut self.src2, t); } } @@ -1182,17 +1197,15 @@ struct Arg3 { pub src2: spirv::Word, } -impl Arg for Arg3 { - fn visit_id(&self, f: &mut F) { - f(true, self.dst); - f(false, self.src1); - f(false, self.src2); - } - - fn visit_id_mut(&mut self, f: &mut F) { - f(false, &mut self.src1); - f(false, &mut self.src2); - f(true, &mut self.dst); +impl Arg3 { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f(true, &mut self.dst, t); + f(false, &mut self.src1, t); + f(false, &mut self.src2, t); } } @@ -1203,19 +1216,26 @@ struct Arg4 { pub src2: spirv::Word, } -impl Arg for Arg4 { - fn visit_id(&self, f: &mut F) { - f(true, self.dst1); - self.dst2.map(|dst2| f(true, dst2)); - f(false, self.src1); - f(false, self.src2); - } - - fn visit_id_mut(&mut self, f: &mut F) { - f(false, &mut self.src1); - f(false, &mut self.src2); - f(true, &mut self.dst1); - self.dst2.as_mut().map(|dst2| f(true, dst2)); +impl Arg4 { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f( + true, + &mut self.dst1, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ); + self.dst2.as_mut().map(|dst2| { + f( + true, + dst2, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ) + }); + f(false, &mut self.src1, t); + f(false, &mut self.src2, t); } } @@ -1227,21 +1247,31 @@ struct Arg5 { pub src3: spirv::Word, } -impl Arg for Arg5 { - fn visit_id(&self, f: &mut F) { - f(true, self.dst1); - self.dst2.map(|dst2| f(true, dst2)); - f(false, self.src1); - f(false, self.src2); - f(false, self.src3); - } - - fn visit_id_mut(&mut self, f: &mut F) { - f(false, &mut self.src1); - f(false, &mut self.src2); - f(false, &mut self.src3); - f(true, &mut self.dst1); - self.dst2.as_mut().map(|dst2| f(true, dst2)); +impl Arg5 { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f( + true, + &mut self.dst1, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ); + self.dst2.as_mut().map(|dst2| { + f( + true, + dst2, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ) + }); + f(false, &mut self.src1, t); + f(false, &mut self.src2, t); + f( + false, + &mut self.src3, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ); } } @@ -1252,12 +1282,8 @@ struct ConstantDefinition { } impl ConstantDefinition { - fn visit_id(&self, f: &mut F) { - f(true, self.dst); - } - - fn visit_id_mut(&mut self, f: &mut F) { - f(true, &mut self.dst); + fn visit_id(&mut self, f: &mut F) { + f(&mut self.dst); } } @@ -1268,16 +1294,14 @@ struct BrachCondition { } impl BrachCondition { - fn visit_id(&self, f: &mut F) { - f(false, self.predicate); - f(false, self.if_true); - f(false, self.if_false); - } - - fn visit_id_mut(&mut self, f: &mut F) { - f(false, &mut self.predicate); - f(false, &mut self.if_true); - f(false, &mut self.if_false); + fn visit_id)>(&mut self, f: &mut F) { + f( + false, + &mut self.predicate, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ); + f(false, &mut self.if_true, None); + f(false, &mut self.if_false, None); } } @@ -1298,14 +1322,9 @@ enum ConversionKind { } impl ImplicitConversion { - fn visit_id(&self, f: &mut F) { - f(false, self.src); - f(true, self.dst); - } - - fn visit_id_mut(&mut self, f: &mut F) { - f(false, &mut self.src); - f(true, &mut self.dst); + fn visit_id(&mut self, f: &mut F) { + f(&mut self.dst); + f(&mut self.src); } } @@ -1343,13 +1362,13 @@ impl ast::Arg1 { } } -impl Arg for ast::Arg1 { - fn visit_id(&self, f: &mut F) { - f(false, self.src); - } - - fn visit_id_mut(&mut self, f: &mut F) { - f(false, &mut self.src); +impl ast::Arg1 { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f(false, &mut self.src, t); } } @@ -1362,15 +1381,14 @@ impl ast::Arg2 { } } -impl Arg for ast::Arg2 { - fn visit_id(&self, f: &mut F) { - f(true, self.dst); - self.src.visit_id(f); - } - - fn visit_id_mut(&mut self, f: &mut F) { - self.src.visit_id_mut(f); - f(true, &mut self.dst); +impl ast::Arg2 { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f(true, &mut self.dst, t); + self.src.visit_id(f, t); } } @@ -1383,15 +1401,14 @@ impl ast::Arg2St { } } -impl Arg for ast::Arg2St { - fn visit_id(&self, f: &mut F) { - self.src1.visit_id(f); - self.src2.visit_id(f); - } - - fn visit_id_mut(&mut self, f: &mut F) { - self.src1.visit_id_mut(f); - self.src2.visit_id_mut(f); +impl ast::Arg2St { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + self.src1.visit_id(f, t); + self.src2.visit_id(f, t); } } @@ -1404,15 +1421,14 @@ impl ast::Arg2Mov { } } -impl Arg for ast::Arg2Mov { - fn visit_id(&self, f: &mut F) { - f(true, self.dst); - self.src.visit_id(f); - } - - fn visit_id_mut(&mut self, f: &mut F) { - self.src.visit_id_mut(f); - f(true, &mut self.dst); +impl ast::Arg2Mov { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f(true, &mut self.dst, t); + self.src.visit_id(f, t); } } @@ -1426,17 +1442,15 @@ impl ast::Arg3 { } } -impl Arg for ast::Arg3 { - fn visit_id(&self, f: &mut F) { - f(true, self.dst); - self.src1.visit_id(f); - self.src2.visit_id(f); - } - - fn visit_id_mut(&mut self, f: &mut F) { - self.src1.visit_id_mut(f); - self.src2.visit_id_mut(f); - f(true, &mut self.dst); +impl ast::Arg3 { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f(true, &mut self.dst, t); + self.src1.visit_id(f, t); + self.src2.visit_id(f, t); } } @@ -1451,19 +1465,26 @@ impl ast::Arg4 { } } -impl Arg for ast::Arg4 { - fn visit_id(&self, f: &mut F) { - f(true, self.dst1); - self.dst2.map(|i| f(true, i)); - self.src1.visit_id(f); - self.src2.visit_id(f); - } - - fn visit_id_mut(&mut self, f: &mut F) { - self.src1.visit_id_mut(f); - self.src2.visit_id_mut(f); - f(true, &mut self.dst1); - self.dst2.as_mut().map(|i| f(true, i)); +impl ast::Arg4 { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f( + true, + &mut self.dst1, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ); + self.dst2.as_mut().map(|i| { + f( + true, + i, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ) + }); + self.src1.visit_id(f, t); + self.src2.visit_id(f, t); } } @@ -1479,21 +1500,30 @@ impl ast::Arg5 { } } -impl Arg for ast::Arg5 { - fn visit_id(&self, f: &mut F) { - f(true, self.dst1); - self.dst2.map(|i| f(true, i)); - self.src1.visit_id(f); - self.src2.visit_id(f); - self.src3.visit_id(f); - } - - fn visit_id_mut(&mut self, f: &mut F) { - self.src1.visit_id_mut(f); - self.src2.visit_id_mut(f); - self.src3.visit_id_mut(f); - f(true, &mut self.dst1); - self.dst2.as_mut().map(|i| f(true, i)); +impl ast::Arg5 { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { + f( + true, + &mut self.dst1, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ); + self.dst2.as_mut().map(|i| { + f( + true, + i, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ) + }); + self.src1.visit_id(f, t); + self.src2.visit_id(f, t); + self.src3.visit_id( + f, + Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + ); } } @@ -1508,18 +1538,14 @@ impl ast::Operand { } impl ast::Operand { - fn visit_id(&self, f: &mut F) { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { match self { - ast::Operand::Reg(i) => f(false, *i), - ast::Operand::RegOffset(i, _) => f(false, *i), - ast::Operand::Imm(_) => (), - } - } - - fn visit_id_mut(&mut self, f: &mut F) { - match self { - ast::Operand::Reg(i) => f(false, i), - ast::Operand::RegOffset(i, _) => f(false, i), + ast::Operand::Reg(i) => f(false, i, t), + ast::Operand::RegOffset(i, _) => f(false, i, t), ast::Operand::Imm(_) => (), } } @@ -1535,16 +1561,13 @@ impl ast::MovOperand { } impl ast::MovOperand { - fn visit_id(&self, f: &mut F) { + fn visit_id)>( + &mut self, + f: &mut F, + t: Option, + ) { match self { - ast::MovOperand::Op(o) => o.visit_id(f), - ast::MovOperand::Vec(_, _) => todo!(), - } - } - - fn visit_id_mut(&mut self, f: &mut F) { - match self { - ast::MovOperand::Op(o) => o.visit_id_mut(f), + ast::MovOperand::Op(o) => o.visit_id(f, t), ast::MovOperand::Vec(_, _) => todo!(), } } @@ -1793,7 +1816,7 @@ fn insert_conversion_src( conv: ConversionKind, ) -> spirv::Word { let temp_src = id_def.new_id(Some(instr_type)); - func.push(Statement::Converison(ImplicitConversion { + func.push(Statement::Conversion(ImplicitConversion { src: src, dst: temp_src, from: src_type, @@ -1838,7 +1861,7 @@ fn get_conversion_dst( let original_dst = *dst; let temp_dst = id_def.new_id(Some(instr_type)); *dst = temp_dst; - Statement::Converison(ImplicitConversion { + Statement::Conversion(ImplicitConversion { src: temp_dst, dst: original_dst, from: instr_type, @@ -1938,31 +1961,33 @@ fn insert_implicit_bitcasts( mut instr: Instruction, ) { let mut dst_coercion = None; - if let Some(instr_type) = instr.get_type() { - instr.visit_id_mut(&mut |is_dst, id| { - let id_type = id_def.get_type(*id); - if should_bitcast(instr_type, id_def.get_type(*id)) { - if is_dst { - dst_coercion = Some(get_conversion_dst( - id_def, - id, - instr_type, - id_type, - ConversionKind::Default, - )); - } else { - *id = insert_conversion_src( - func, - id_def, - *id, - id_type, - instr_type, - ConversionKind::Default, - ); - } + instr.visit_id_extended(&mut |is_dst, id, id_type| { + let id_type_from_instr = match id_type { + Some(t) => t, + None => return, + }; + let id_actual_type = id_def.get_type(*id); + if should_bitcast(id_type_from_instr, id_def.get_type(*id)) { + if is_dst { + dst_coercion = Some(get_conversion_dst( + id_def, + id, + id_type_from_instr, + id_actual_type, + ConversionKind::Default, + )); + } else { + *id = insert_conversion_src( + func, + id_def, + *id, + id_actual_type, + id_type_from_instr, + ConversionKind::Default, + ); } - }); - } + } + }); func.push(Statement::Instruction(instr)); if let Some(cond) = dst_coercion { func.push(cond);