diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 02ff958..651f996 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -151,7 +151,7 @@ fn ssa_legalize( doms: &Vec, dom_fronts: &Vec>, ) { - let phis = gather_phi_sets(&func, &bbs, dom_fronts); + let phis = gather_phi_sets(&func, max_id, &bbs, dom_fronts); apply_ssa_renaming(func, &bbs, doms, max_id, &phis); } @@ -161,7 +161,7 @@ fn apply_ssa_renaming( bbs: &[BasicBlock], doms: &[BBIndex], max_id: spirv::Word, - old_phi: &[Vec], + old_phi: &[HashSet], ) { let mut dom_tree = vec![Vec::new(); bbs.len()]; for (bb, idom) in doms.iter().enumerate() { @@ -251,7 +251,7 @@ fn rename_succesor_phi_src( } } -fn pop_stacks(ssa_state: &mut SSARewriteState, old_phi: &[spirv::Word], old_ids: &[spirv::Word]) { +fn pop_stacks(ssa_state: &mut SSARewriteState, old_phi: &HashSet, old_ids: &[spirv::Word]) { for id in old_phi.iter().chain(old_ids) { ssa_state.pop(*id); } @@ -313,12 +313,49 @@ impl SSARewriteState { } } +// "Modern Compiler Implementation in Java" - Algorithm 19.6 fn gather_phi_sets( func: &[Statement], + max_id: spirv::Word, bbs: &[BasicBlock], dom_fronts: &[HashSet], -) -> Vec> { - todo!() +) -> Vec> { + let mut result = vec![HashSet::new(); bbs.len()]; + let mut bb_dst_definitions = vec![HashSet::new(); bbs.len()]; + let mut def_sites = vec![(HashSet::new(), Vec::new()); (max_id as usize) + 1]; + for bb in 0..bbs.len() { + let bb = BBIndex(bb); + for s in get_bb_body(func, bbs, bb) { + s.for_dst_id(&mut |id| { + bb_dst_definitions[bb.0].insert(id); + let (ref mut set, ref mut stack) = def_sites[id as usize]; + if set.insert(bb) { + stack.push(bb); + } + }); + } + } + for (id, to_work ) in def_sites.iter_mut().enumerate() { + let id = id as spirv::Word; + let (ref mut set, ref mut stack) = to_work; + loop { + if let Some(bb) = stack.pop() { + set.remove(&bb); + for y_bb in &dom_fronts[bb.0] { + if result[y_bb.0].insert(id) { + if !bb_dst_definitions[y_bb.0].contains(&id) { + if set.insert(*y_bb) { + stack.push(*y_bb); + } + } + } + } + } else { + break; + } + } + } + result } fn get_basic_blocks(fun: &Vec) -> Vec {