diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index ac78cbe..e0f69a7 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -151,9 +151,9 @@ fn ssa_legalize( bbs: Vec, doms: &Vec, dom_fronts: &Vec>, -) { +) -> Vec> { let phis = gather_phi_sets(&func, max_id, &bbs, dom_fronts); - apply_ssa_renaming(func, &bbs, doms, max_id, &phis); + apply_ssa_renaming(func, &bbs, doms, max_id, &phis) } // "Modern Compiler Implementation in Java" - Algorithm 19.7 @@ -163,9 +163,9 @@ fn apply_ssa_renaming( doms: &[BBIndex], max_id: spirv::Word, old_phi: &[HashSet], -) { +) -> Vec> { let mut dom_tree = vec![Vec::new(); bbs.len()]; - for (bb, idom) in doms.iter().enumerate() { + for (bb, idom) in doms.iter().enumerate().skip(1) { dom_tree[idom.0].push(BBIndex(bb)); } let mut old_dst_id = vec![Vec::new(); bbs.len()]; @@ -178,7 +178,7 @@ fn apply_ssa_renaming( .iter() .map(|ids| { ids.iter() - .map(|id| (*id, Vec::new())) + .map(|id| (*id, (u32::max_value(), HashSet::new()))) .collect::>() }) .collect::>(); @@ -190,7 +190,7 @@ fn apply_ssa_renaming( if let Some((BBIndex(bb), dom_succ_idx)) = state.last_mut() { let bb = *bb; if *dom_succ_idx == 0 { - rename_phi_dst(max_id, &mut ssa_state, &mut new_phi[bb]); + rename_phi_dst(&mut ssa_state, &mut new_phi[bb]); rename_bb_body(&mut ssa_state, func, bbs, BBIndex(bb)); for BBIndex(succ_idx) in bbs[bb].succ.iter() { rename_succesor_phi_src(&ssa_state, &mut new_phi[*succ_idx]); @@ -207,22 +207,32 @@ fn apply_ssa_renaming( break; } } + new_phi + .into_iter() + .map(|map| { + map.into_iter() + .map(|(_, (new_id, defs))| PhiDef { + dst: new_id, + src: defs, + }) + .collect::>() + }) + .collect::>() +} + +// before ssa-renaming every phi is x <- phi(x,x,x,x) +#[derive(Debug, PartialEq)] +struct PhiDef { + dst: spirv::Word, + src: HashSet, } fn rename_phi_dst( - max_old_id: spirv::Word, rewriter: &mut SSARewriteState, - phi: &mut HashMap>, + phi: &mut HashMap)>, ) { - let old_keys = phi - .keys() - .copied() - .filter(|id| *id <= max_old_id) - .collect::>(); - for k in old_keys.into_iter() { - let remapped_id = rewriter.redefine(k); - let values = phi.remove(&k).unwrap(); - phi.insert(remapped_id, values); + for (old_k, (new_k, _)) in phi.iter_mut() { + *new_k = rewriter.redefine(*old_k); } } @@ -245,10 +255,10 @@ fn rename_bb_body( fn rename_succesor_phi_src( ssa_state: &SSARewriteState, - phi: &mut HashMap>, + phi: &mut HashMap)>, ) { - for (id, v) in phi.iter_mut() { - v.push(ssa_state.get(*id)); + for (id, (_, v)) in phi.iter_mut() { + v.insert(ssa_state.get(*id)); } } @@ -295,9 +305,10 @@ struct SSARewriteState { impl SSARewriteState { fn new(max: spirv::Word) -> Self { - let stack = vec![Vec::new(); max as usize]; + let len = max + 1; + let stack = (0..len).map(|x| vec![x + len]).collect::>(); SSARewriteState { - next: max + 1, + next: 2 * len, stack, } } @@ -318,40 +329,50 @@ impl SSARewriteState { } } -// "Modern Compiler Implementation in Java" - Algorithm 19.6 +// "Engineering a Compiler" - Figure 9.9 +// Calculates semi-pruned phis fn gather_phi_sets( func: &[Statement], max_id: spirv::Word, - bbs: &[BasicBlock], + cfg: &[BasicBlock], dom_fronts: &[HashSet], ) -> Vec> { - let mut result = vec![HashSet::new(); bbs.len()]; - let mut bb_dst_definitions = vec![HashSet::new(); bbs.len()]; - let mut def_sites = vec![(HashSet::new(), Vec::new()); (max_id as usize) + 1]; - for bb in 0..bbs.len() { - let bb = BBIndex(bb); - for s in get_bb_body(func, bbs, bb) { - s.for_dst_id(&mut |id| { - bb_dst_definitions[bb.0].insert(id); - let (ref mut set, ref mut stack) = def_sites[id as usize]; - if set.insert(bb) { - stack.push(bb); + let mut result = vec![HashSet::new(); cfg.len()]; + let mut globals = HashSet::new(); + let mut blocks = vec![(Vec::new(), HashSet::new()); (max_id as usize) + 1]; + for bb in 0..cfg.len() { + let mut var_kill = HashSet::new(); + let mut visitor = |is_dst, id: &u32| { + if is_dst { + var_kill.insert(*id); + let (ref mut stack, ref mut set) = blocks[*id as usize]; + stack.push(BBIndex(bb)); + set.insert(BBIndex(bb)); + } else { + if !var_kill.contains(id) { + globals.insert(*id); } - }); + } + }; + for s in get_bb_body(func, cfg, BBIndex(bb)) { + match s { + Statement::Instruction(pred, inst) => { + pred.as_ref().map(|p| p.visit_id(&mut visitor)); + inst.visit_id(&mut visitor); + } + Statement::Label(_) => (), + } } } - for (id, to_work) in def_sites.iter_mut().enumerate() { - let id = id as spirv::Word; - let (ref mut set, ref mut stack) = to_work; + for id in globals { + let (ref mut work_stack, ref mut work_set) = &mut blocks[id as usize]; loop { - if let Some(bb) = stack.pop() { - set.remove(&bb); - for y_bb in &dom_fronts[bb.0] { - if result[y_bb.0].insert(id) { - if !bb_dst_definitions[y_bb.0].contains(&id) { - if set.insert(*y_bb) { - stack.push(*y_bb); - } + if let Some(bb) = work_stack.pop() { + work_set.remove(&bb); + for d_bb in &dom_fronts[bb.0] { + if result[d_bb.0].insert(id) { + if work_set.insert(*d_bb) { + work_stack.push(*d_bb); } } } @@ -596,6 +617,16 @@ impl Statement { } } + fn visit_id(&self, f: &mut F) { + match self { + Statement::Label(id) => f(true, id), + Statement::Instruction(pred, inst) => { + pred.as_ref().map(|p| p.visit_id(f)); + inst.visit_id(f); + } + } + } + fn visit_id_mut(&mut self, f: &mut F) { match self { Statement::Label(id) => f(true, id), @@ -615,6 +646,10 @@ impl ast::PredAt { } } + fn visit_id(&self, f: &mut F) { + f(false, &self.label) + } + fn visit_id_mut(&mut self, f: &mut F) { f(false, &mut self.label) } @@ -642,6 +677,23 @@ impl ast::Instruction { } } + fn visit_id(&self, f: &mut F) { + match self { + ast::Instruction::Ld(_, a) => a.visit_id(f), + ast::Instruction::Mov(_, a) => a.visit_id(f), + ast::Instruction::Mul(_, a) => a.visit_id(f), + ast::Instruction::Add(_, a) => a.visit_id(f), + ast::Instruction::Setp(_, a) => a.visit_id(f), + ast::Instruction::SetpBool(_, a) => a.visit_id(f), + ast::Instruction::Not(_, a) => a.visit_id(f), + ast::Instruction::Cvt(_, a) => a.visit_id(f), + ast::Instruction::Shl(_, a) => a.visit_id(f), + ast::Instruction::St(_, a) => a.visit_id(f), + ast::Instruction::Bra(_, a) => a.visit_id(f), + ast::Instruction::Ret(_) => (), + } + } + fn visit_id_mut(&mut self, f: &mut F) { match self { ast::Instruction::Ld(_, a) => a.visit_id_mut(f), @@ -718,6 +770,10 @@ impl ast::Arg1 { ast::Arg1 { src: f(self.src) } } + fn visit_id(&self, f: &mut F) { + f(false, &self.src); + } + fn visit_id_mut(&mut self, f: &mut F) { f(false, &mut self.src); } @@ -731,6 +787,11 @@ impl ast::Arg2 { } } + fn visit_id(&self, f: &mut F) { + f(true, &self.dst); + self.src.visit_id(f); + } + fn visit_id_mut(&mut self, f: &mut F) { f(true, &mut self.dst); self.src.visit_id_mut(f); @@ -751,6 +812,11 @@ impl ast::Arg2Mov { } } + fn visit_id(&self, f: &mut F) { + f(true, &self.dst); + self.src.visit_id(f); + } + fn visit_id_mut(&mut self, f: &mut F) { f(true, &mut self.dst); self.src.visit_id_mut(f); @@ -772,6 +838,12 @@ impl ast::Arg3 { } } + fn visit_id(&self, f: &mut F) { + f(true, &self.dst); + self.src1.visit_id(f); + self.src2.visit_id(f); + } + fn visit_id_mut(&mut self, f: &mut F) { f(true, &mut self.dst); self.src1.visit_id_mut(f); @@ -795,6 +867,13 @@ impl ast::Arg4 { } } + fn visit_id(&self, f: &mut F) { + f(true, &self.dst1); + self.dst2.as_ref().map(|i| f(true, i)); + self.src1.visit_id(f); + self.src2.visit_id(f); + } + fn visit_id_mut(&mut self, f: &mut F) { f(true, &mut self.dst1); self.dst2.as_mut().map(|i| f(true, i)); @@ -821,6 +900,14 @@ impl ast::Arg5 { } } + fn visit_id(&self, f: &mut F) { + f(true, &self.dst1); + self.dst2.as_ref().map(|i| f(true, i)); + self.src1.visit_id(f); + self.src2.visit_id(f); + self.src3.visit_id(f); + } + fn visit_id_mut(&mut self, f: &mut F) { f(true, &mut self.dst1); self.dst2.as_mut().map(|i| f(true, i)); @@ -846,6 +933,14 @@ impl ast::Operand { } } + fn visit_id(&self, f: &mut F) { + match self { + ast::Operand::Reg(i) => f(false, i), + ast::Operand::RegOffset(i, _) => f(false, i), + ast::Operand::Imm(_) => (), + } + } + fn visit_id_mut(&mut self, f: &mut F) { match self { ast::Operand::Reg(i) => f(false, i), @@ -863,6 +958,13 @@ impl ast::MovOperand { } } + fn visit_id(&self, f: &mut F) { + match self { + ast::MovOperand::Op(o) => o.visit_id(f), + ast::MovOperand::Vec(_, _) => (), + } + } + fn visit_id_mut(&mut self, f: &mut F) { match self { ast::MovOperand::Op(o) => o.visit_id_mut(f), @@ -1202,7 +1304,7 @@ mod tests { } // page 403 - const fig_19_4: &'static str = "{ + const FIG_19_4: &'static str = "{ mov.u32 i, 1; mov.u32 j, 1; mov.u32 k, 0; @@ -1226,8 +1328,8 @@ mod tests { }"; #[test] - fn gather_phi_sets_fig_19_4() { - let func = fig_19_4; + fn get_basic_blocks_fig_19_4() { + let func = FIG_19_4; let mut errors = Vec::new(); let ast = ptx::FunctionBodyParser::new() .parse(&mut errors, func) @@ -1348,7 +1450,6 @@ mod tests { ); } - // page 403 #[test] fn dominance_frontiers_fig_19_4() { let cfg = cfg_fig_19_4(); @@ -1369,4 +1470,45 @@ mod tests { ]; assert_eq!(dom_fronts, should); } + + #[test] + fn gather_phi_sets_fig_19_4() { + let func = FIG_19_4; + let mut errors = Vec::new(); + let fn_ast = ptx::FunctionBodyParser::new() + .parse(&mut errors, func) + .unwrap(); + assert_eq!(errors.len(), 0); + let (normalized_ids, max_id) = normalize_identifiers(fn_ast); + let bbs = get_basic_blocks(&normalized_ids); + let rpostorder = to_reverse_postorder(&bbs); + let doms = immediate_dominators(&bbs, &rpostorder); + let dom_fronts = dominance_frontiers(&bbs, &doms); + let phi = gather_phi_sets(&normalized_ids, max_id, &bbs, &dom_fronts); + assert_eq!( + phi, + vec![ + HashSet::new(), + to_hashset(vec![1, 2]), + HashSet::new(), + HashSet::new(), + HashSet::new(), + to_hashset(vec![1, 2]), + HashSet::new() + ] + ); + } + + fn to_hashset(v: Vec) -> HashSet { + v.into_iter().collect::>() + } + + fn assert_dst_unique(func: &[Statement]) { + let mut seen = HashSet::new(); + for s in func { + s.for_dst_id(&mut |id| { + assert!(seen.insert(id)); + }); + } + } }