diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c37413a..3486edd 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -182,9 +182,9 @@ fn to_ssa<'a>( } fn normalize_labels( - func: Vec>, + func: Vec, id_def: &mut NumericIdResolver, -) -> Vec> { +) -> Vec { let mut labels_in_use = HashSet::new(); for s in func.iter() { match s { @@ -240,11 +240,11 @@ fn normalize_predicates( result.push(Statement::Conditional(branch)); if folded_bra.is_none() { result.push(Statement::Label(if_true)); - result.push(Statement::Instruction(Instruction::from_ast(inst))); + result.push(Statement::Instruction(inst)); } result.push(Statement::Label(if_false)); } else { - result.push(Statement::Instruction(Instruction::from_ast(inst))); + result.push(Statement::Instruction(inst)); } } ast::Statement::Variable(var) => { @@ -263,7 +263,7 @@ fn insert_mem_ssa_statements( for s in func { match s { Statement::Instruction(inst) => match inst { - Instruction::Ld( + ast::Instruction::Ld( ld @ ast::LdData { @@ -272,7 +272,7 @@ fn insert_mem_ssa_statements( }, arg, ) => { - result.push(Statement::Instruction(Instruction::Ld(ld, arg))); + result.push(Statement::Instruction(ast::Instruction::Ld(ld, arg))); } mut inst => { let mut post_statements = Vec::new(); @@ -343,51 +343,51 @@ fn expand_arguments( fn normalize_insert_instruction( func: &mut Vec, id_def: &mut NumericIdResolver, - instr: Instruction, -) -> Instruction { + instr: ast::Instruction, +) -> Instruction { match instr { - Instruction::Ld(d, a) => { + ast::Instruction::Ld(d, a) => { let arg = normalize_expand_arg2(func, id_def, &|| Some(d.typ), a); Instruction::Ld(d, arg) } - Instruction::Mov(d, a) => { + ast::Instruction::Mov(d, a) => { let arg = normalize_expand_arg2mov(func, id_def, &|| d.typ.try_as_scalar(), a); Instruction::Mov(d, arg) } - Instruction::Mul(d, a) => { + ast::Instruction::Mul(d, a) => { let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a); Instruction::Mul(d, arg) } - Instruction::Add(d, a) => { + ast::Instruction::Add(d, a) => { let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a); Instruction::Add(d, arg) } - Instruction::Setp(d, a) => { + ast::Instruction::Setp(d, a) => { let arg = normalize_expand_arg4(func, id_def, &|| Some(d.typ), a); Instruction::Setp(d, arg) } - Instruction::SetpBool(d, a) => { + ast::Instruction::SetpBool(d, a) => { let arg = normalize_expand_arg5(func, id_def, &|| Some(d.typ), a); Instruction::SetpBool(d, arg) } - Instruction::Not(d, a) => { + ast::Instruction::Not(d, a) => { let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a); Instruction::Not(d, arg) } - Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }), - Instruction::Cvt(d, a) => { + ast::Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }), + ast::Instruction::Cvt(d, a) => { let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a); Instruction::Cvt(d, arg) } - Instruction::Shl(d, a) => { + ast::Instruction::Shl(d, a) => { let arg = normalize_expand_arg3(func, id_def, &|| todo!(), a); Instruction::Shl(d, arg) } - Instruction::St(d, a) => { + ast::Instruction::St(d, a) => { let arg = normalize_expand_arg2st(func, id_def, &|| todo!(), a); Instruction::St(d, arg) } - Instruction::Ret(d) => Instruction::Ret(d), + ast::Instruction::Ret(d) => Instruction::Ret(d), } } @@ -967,19 +967,19 @@ impl NumericIdResolver { } } -enum Statement { +enum Statement { Variable(spirv::Word, ast::Type, ast::StateSpace), LoadVar(Arg2, ast::Type), StoreVar(Arg2St, ast::Type), Label(u32), - Instruction(Instruction), + Instruction(I), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), Conversion(ImplicitConversion), Constant(ConstantDefinition), } -impl Statement { +impl Statement { fn visit_id(&mut self, f: &mut F) { match self { Statement::Variable(id, _, _) => f(id), @@ -994,95 +994,44 @@ impl Statement { } } -trait Args { - type Arg1; - type Arg2; - type Arg2St; - type Arg2Mov; - type Arg3; - type Arg4; - type Arg5; -} +type NormalizedStatement = Statement>; +type ExpandedStatement = Statement; -enum NormalizedArgs {} - -impl Args for NormalizedArgs { - type Arg1 = ast::Arg1; - type Arg2 = ast::Arg2; - type Arg2St = ast::Arg2St; - type Arg2Mov = ast::Arg2Mov; - type Arg3 = ast::Arg3; - type Arg4 = ast::Arg4; - type Arg5 = ast::Arg5; -} - -enum ExpandedArgs {} - -impl Args for ExpandedArgs { - type Arg1 = Arg1; - type Arg2 = Arg2; - type Arg2St = Arg2St; - type Arg2Mov = Arg2; - type Arg3 = Arg3; - type Arg4 = Arg4; - type Arg5 = Arg5; -} - -type NormalizedStatement = Statement; -type ExpandedStatement = Statement; - -enum Instruction { - Ld(ast::LdData, A::Arg2), - Mov(ast::MovData, A::Arg2Mov), - Mul(ast::MulDetails, A::Arg3), - Add(ast::AddDetails, A::Arg3), - Setp(ast::SetpData, A::Arg4), - SetpBool(ast::SetpBoolData, A::Arg5), - Not(ast::NotData, A::Arg2), - Bra(ast::BraData, A::Arg1), - Cvt(ast::CvtData, A::Arg2), - Shl(ast::ShlData, A::Arg3), - St(ast::StData, A::Arg2St), +enum Instruction { + Ld(ast::LdData, Arg2), + Mov(ast::MovData, Arg2), + Mul(ast::MulDetails, Arg3), + Add(ast::AddDetails, Arg3), + Setp(ast::SetpData, Arg4), + SetpBool(ast::SetpBoolData, Arg5), + Not(ast::NotData, Arg2), + Bra(ast::BraData, Arg1), + Cvt(ast::CvtData, Arg2), + Shl(ast::ShlData, Arg3), + St(ast::StData, Arg2St), Ret(ast::RetData), } -impl Instruction { +impl ast::Instruction { fn visit_id)>(&mut self, f: &mut F) { match self { - Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)), - Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())), - Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), - Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - Instruction::Not(_, _) => todo!(), - Instruction::Cvt(_, _) => todo!(), - Instruction::Shl(_, _) => todo!(), - Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - Instruction::Bra(_, a) => a.visit_id(f, None), - Instruction::Ret(_) => (), - } - } - - fn from_ast(s: ast::Instruction) -> Self { - match s { - ast::Instruction::Ld(d, a) => Instruction::Ld(d, a), - ast::Instruction::Mov(d, a) => Instruction::Mov(d, a), - ast::Instruction::Mul(d, a) => Instruction::Mul(d, a), - ast::Instruction::Add(d, a) => Instruction::Add(d, a), - ast::Instruction::Setp(d, a) => Instruction::Setp(d, a), - ast::Instruction::SetpBool(d, a) => Instruction::SetpBool(d, a), - ast::Instruction::Not(d, a) => Instruction::Not(d, a), - ast::Instruction::Cvt(d, a) => Instruction::Cvt(d, a), - ast::Instruction::Shl(d, a) => Instruction::Shl(d, a), - ast::Instruction::St(d, a) => Instruction::St(d, a), - ast::Instruction::Bra(d, a) => Instruction::Bra(d, a), - ast::Instruction::Ret(d) => Instruction::Ret(d), + ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + ast::Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)), + ast::Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())), + ast::Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), + ast::Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + ast::Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + ast::Instruction::Not(_, _) => todo!(), + ast::Instruction::Cvt(_, _) => todo!(), + ast::Instruction::Shl(_, _) => todo!(), + ast::Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), + ast::Instruction::Bra(_, a) => a.visit_id(f, None), + ast::Instruction::Ret(_) => (), } } } -impl Instruction { +impl Instruction { fn visit_id(&mut self, f: &mut F) { let f_visitor = &mut Self::typed_visitor(f); match self { @@ -1118,9 +1067,9 @@ impl Instruction { Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())), Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), - Instruction::Not(_, a) => todo!(), - Instruction::Cvt(_, a) => todo!(), - Instruction::Shl(_, a) => todo!(), + Instruction::Not(_, _) => todo!(), + Instruction::Cvt(_, _) => todo!(), + Instruction::Shl(_, _) => todo!(), Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))), Instruction::Bra(_, a) => a.visit_id(f, None), Instruction::Ret(_) => (), @@ -1830,7 +1779,7 @@ fn insert_with_implicit_conversion_dst< T, ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option, Setter: Fn(&mut T) -> &mut spirv::Word, - ToInstruction: FnOnce(T) -> Instruction, + ToInstruction: FnOnce(T) -> Instruction, >( func: &mut Vec, instr_type: ast::ScalarType, @@ -1958,7 +1907,7 @@ fn should_convert_relaxed_dst( fn insert_implicit_bitcasts( func: &mut Vec, id_def: &mut NumericIdResolver, - mut instr: Instruction, + mut instr: Instruction, ) { let mut dst_coercion = None; instr.visit_id_extended(&mut |is_dst, id, id_type| {