From a69c12a3872c5f71000ae4a8eaa868ecbffedc7d Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 2 May 2020 01:08:44 +0200 Subject: [PATCH] Fix remaining bugs in SSA renaming --- ptx/src/translate.rs | 354 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 295 insertions(+), 59 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index e0f69a7..a0e8405 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -105,63 +105,122 @@ fn emit_function<'a>( spirv::FunctionControl::NONE, map.fn_void(), )?; - for arg in f.args.iter() { - let arg_type = map.get_or_add(builder, SpirvType::Base(arg.a_type)); - builder.function_parameter(arg_type)?; - } - let (mut normalized_ids, max_id) = normalize_identifiers(f.body); + let mut contant_ids = HashMap::new(); + collect_arg_ids(&mut contant_ids, &f.args); + collect_label_ids(&mut contant_ids, &f.body); + let (mut normalized_ids, unique_ids) = normalize_identifiers(f.body, &contant_ids); let bbs = get_basic_blocks(&normalized_ids); let rpostorder = to_reverse_postorder(&bbs); let doms = immediate_dominators(&bbs, &rpostorder); let dom_fronts = dominance_frontiers(&bbs, &doms); - ssa_legalize(&mut normalized_ids, max_id, bbs, &doms, &dom_fronts); + ssa_legalize( + &mut normalized_ids, + contant_ids.len() as u32, + unique_ids, + &bbs, + &doms, + &dom_fronts, + ); emit_function_body_ops(builder); builder.ret()?; builder.end_function()?; Ok(func_id) } +fn collect_arg_ids<'a>(result: &mut HashMap<&'a str, spirv::Word>, args: &'a [ast::Argument<'a>]) { + let mut id = result.len() as u32; + for arg in args { + result.insert(arg.name, id); + id += 1; + } +} + +fn collect_label_ids<'a>( + result: &mut HashMap<&'a str, spirv::Word>, + fn_body: &[ast::Statement<&'a str>], +) { + let mut id = result.len() as u32; + for s in fn_body { + match s { + ast::Statement::Label(name) => { + result.insert(name, id); + id += 1; + } + ast::Statement::Instruction(_, _) => (), + ast::Statement::Variable(_) => (), + } + } +} + fn emit_function_body_ops(builder: &mut dr::Builder) { todo!() } +// This functions converts string identifiers to numeric identifiers in a normalized form, where +// - identifiers in the range [0..constant_identifiers.len()) are arguments and labels +// - identifiers in the range [constant_identifiers.len()..result.1) are variables // TODO: support scopes -fn normalize_identifiers<'a>(func: Vec>) -> (Vec, spirv::Word) { +fn normalize_identifiers<'a>( + func: Vec>, + constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined +) -> (Vec, spirv::Word) { let mut result = Vec::with_capacity(func.len()); - let mut id: u32 = 0; - let mut known_ids = HashMap::new(); + let mut id: u32 = constant_identifiers.len() as u32; + let mut remapped_ids = HashMap::new(); let mut get_or_add = |key| { - *known_ids.entry(key).or_insert_with(|| { - let to_insert = id; - id += 1; - to_insert - }) + constant_identifiers.get(key).map_or_else( + || { + *remapped_ids.entry(key).or_insert_with(|| { + let to_insert = id; + id += 1; + to_insert + }) + }, + |id| *id, + ) }; for s in func { if let Some(s) = Statement::from_ast(s, &mut get_or_add) { result.push(s); } } - (result, id - 1) + (result, id) } fn ssa_legalize( func: &mut [Statement], - max_id: spirv::Word, - bbs: Vec, - doms: &Vec, - dom_fronts: &Vec>, + constant_ids: spirv::Word, + unique_ids: spirv::Word, + bbs: &[BasicBlock], + doms: &[BBIndex], + dom_fronts: &[HashSet], ) -> Vec> { - let phis = gather_phi_sets(&func, max_id, &bbs, dom_fronts); - apply_ssa_renaming(func, &bbs, doms, max_id, &phis) + let phis = gather_phi_sets( + &func, + constant_ids, + unique_ids, + &bbs, + dom_fronts, + ); + apply_ssa_renaming(func, &bbs, doms, constant_ids, unique_ids, &phis) } -// "Modern Compiler Implementation in Java" - Algorithm 19.7 +/* "Modern Compiler Implementation in Java" - Algorithm 19.7 + * This algorithm modifies passed function body in-place by renumbering ids, + * result ids can be divided into following categories + * - if id < constant_ids + * it's a non-redefinable id + * - if id >= constant_ids && id < all_ids + * then it's an undefined id (a0, b0, c0) + * - if id >= all_ids + * then it's a normally redefined id + */ fn apply_ssa_renaming( func: &mut [Statement], bbs: &[BasicBlock], doms: &[BBIndex], - max_id: spirv::Word, + constant_ids: spirv::Word, + all_ids: spirv::Word, old_phi: &[HashSet], ) -> Vec> { let mut dom_tree = vec![Vec::new(); bbs.len()]; @@ -182,7 +241,7 @@ fn apply_ssa_renaming( .collect::>() }) .collect::>(); - let mut ssa_state = SSARewriteState::new(max_id); + let mut ssa_state = SSARewriteState::new(constant_ids, all_ids); // once again, we do explicit stack let mut state = Vec::new(); state.push((BBIndex(0), 0)); @@ -300,32 +359,46 @@ fn get_bb_body_idx(func: &[Statement], all_bb: &[BasicBlock], bb: BBIndex) -> (u // We assume here that the variables are defined in the dense sequence 0..max struct SSARewriteState { next: spirv::Word, + constant_ids: spirv::Word, stack: Vec>, } -impl SSARewriteState { - fn new(max: spirv::Word) -> Self { - let len = max + 1; - let stack = (0..len).map(|x| vec![x + len]).collect::>(); +impl<'a> SSARewriteState { + fn new(constant_ids: spirv::Word, all_ids: spirv::Word) -> Self { + let to_redefine = all_ids - constant_ids; + let stack = (0..to_redefine) + .map(|x| vec![x + constant_ids]) + .collect::>(); SSARewriteState { - next: 2 * len, + next: all_ids, + constant_ids: constant_ids, stack, } } fn get(&self, x: spirv::Word) -> spirv::Word { - *self.stack[x as usize].last().unwrap() + if x < self.constant_ids { + x + } else { + *self.stack[(x - self.constant_ids) as usize].last().unwrap() + } } fn redefine(&mut self, x: spirv::Word) -> spirv::Word { - let result = self.next; - self.next += 1; - self.stack[x as usize].push(result); - return result; + if x < self.constant_ids { + x + } else { + let result = self.next; + self.next += 1; + self.stack[(x - self.constant_ids) as usize].push(result); + result + } } fn pop(&mut self, x: spirv::Word) { - self.stack[x as usize].pop(); + if x >= self.constant_ids { + self.stack[(x - self.constant_ids) as usize].pop(); + } } } @@ -333,24 +406,28 @@ impl SSARewriteState { // Calculates semi-pruned phis fn gather_phi_sets( func: &[Statement], - max_id: spirv::Word, + constant_ids: spirv::Word, + all_ids: spirv::Word, cfg: &[BasicBlock], dom_fronts: &[HashSet], ) -> Vec> { let mut result = vec![HashSet::new(); cfg.len()]; let mut globals = HashSet::new(); - let mut blocks = vec![(Vec::new(), HashSet::new()); (max_id as usize) + 1]; + let mut blocks = vec![(Vec::new(), HashSet::new()); (all_ids - constant_ids) as usize]; for bb in 0..cfg.len() { let mut var_kill = HashSet::new(); let mut visitor = |is_dst, id: &u32| { - if is_dst { - var_kill.insert(*id); - let (ref mut stack, ref mut set) = blocks[*id as usize]; - stack.push(BBIndex(bb)); - set.insert(BBIndex(bb)); - } else { - if !var_kill.contains(id) { - globals.insert(*id); + if *id >= constant_ids { + let id = id - constant_ids; + if is_dst { + var_kill.insert(id); + let (ref mut stack, ref mut set) = blocks[id as usize]; + stack.push(BBIndex(bb)); + set.insert(BBIndex(bb)); + } else { + if !var_kill.contains(&id) { + globals.insert(id); + } } } }; @@ -360,6 +437,7 @@ fn gather_phi_sets( pred.as_ref().map(|p| p.visit_id(&mut visitor)); inst.visit_id(&mut visitor); } + // label redefinition is a compile-time error Statement::Label(_) => (), } } @@ -370,7 +448,7 @@ fn gather_phi_sets( if let Some(bb) = work_stack.pop() { work_set.remove(&bb); for d_bb in &dom_fronts[bb.0] { - if result[d_bb.0].insert(id) { + if result[d_bb.0].insert(id + constant_ids) { if work_set.insert(*d_bb) { work_stack.push(*d_bb); } @@ -627,6 +705,8 @@ impl Statement { } } + // WARNING: It is very important to first visit src operands and then dst operands, + // otherwise SSA renaming will yield weird results fn visit_id_mut(&mut self, f: &mut F) { match self { Statement::Label(id) => f(true, id), @@ -793,8 +873,8 @@ impl ast::Arg2 { } fn visit_id_mut(&mut self, f: &mut F) { - f(true, &mut self.dst); self.src.visit_id_mut(f); + f(true, &mut self.dst); } } @@ -818,8 +898,8 @@ impl ast::Arg2Mov { } fn visit_id_mut(&mut self, f: &mut F) { - f(true, &mut self.dst); self.src.visit_id_mut(f); + f(true, &mut self.dst); } } @@ -845,9 +925,9 @@ impl ast::Arg3 { } fn visit_id_mut(&mut self, f: &mut F) { - f(true, &mut self.dst); self.src1.visit_id_mut(f); self.src2.visit_id_mut(f); + f(true, &mut self.dst); } } @@ -875,10 +955,10 @@ impl ast::Arg4 { } fn visit_id_mut(&mut self, f: &mut F) { - f(true, &mut self.dst1); - self.dst2.as_mut().map(|i| f(true, i)); self.src1.visit_id_mut(f); self.src2.visit_id_mut(f); + f(true, &mut self.dst1); + self.dst2.as_mut().map(|i| f(true, i)); } } @@ -909,11 +989,11 @@ impl ast::Arg5 { } fn visit_id_mut(&mut self, f: &mut F) { - f(true, &mut self.dst1); - self.dst2.as_mut().map(|i| f(true, i)); self.src1.visit_id_mut(f); self.src2.visit_id_mut(f); self.src3.visit_id_mut(f); + f(true, &mut self.dst1); + self.dst2.as_mut().map(|i| f(true, i)); } } @@ -1335,7 +1415,9 @@ mod tests { .parse(&mut errors, func) .unwrap(); assert_eq!(errors.len(), 0); - let (normalized_ids, _) = normalize_identifiers(ast); + let mut constant_ids = HashMap::new(); + collect_label_ids(&mut constant_ids, &ast); + let (normalized_ids, _) = normalize_identifiers(ast, &constant_ids); let mut bbs = get_basic_blocks(&normalized_ids); bbs.iter_mut().for_each(sort_pred_succ); assert_eq!( @@ -1479,21 +1561,30 @@ mod tests { .parse(&mut errors, func) .unwrap(); assert_eq!(errors.len(), 0); - let (normalized_ids, max_id) = normalize_identifiers(fn_ast); + let mut constant_ids = HashMap::new(); + collect_label_ids(&mut constant_ids, &fn_ast); + assert_eq!(constant_ids.len(), 6); + let (normalized_ids, max_id) = normalize_identifiers(fn_ast, &constant_ids); let bbs = get_basic_blocks(&normalized_ids); let rpostorder = to_reverse_postorder(&bbs); let doms = immediate_dominators(&bbs, &rpostorder); let dom_fronts = dominance_frontiers(&bbs, &doms); - let phi = gather_phi_sets(&normalized_ids, max_id, &bbs, &dom_fronts); + let phi = gather_phi_sets( + &normalized_ids, + constant_ids.len() as u32, + max_id, + &bbs, + &dom_fronts, + ); assert_eq!( phi, vec![ HashSet::new(), - to_hashset(vec![1, 2]), + to_hashset(vec![7, 8]), HashSet::new(), HashSet::new(), HashSet::new(), - to_hashset(vec![1, 2]), + to_hashset(vec![7, 8]), HashSet::new() ] ); @@ -1503,12 +1594,157 @@ mod tests { v.into_iter().collect::>() } - fn assert_dst_unique(func: &[Statement]) { + #[test] + fn ssa_rename_19_4() { + let func = FIG_19_4; + let mut errors = Vec::new(); + let fn_ast = ptx::FunctionBodyParser::new() + .parse(&mut errors, func) + .unwrap(); + assert_eq!(errors.len(), 0); + let mut constant_ids = HashMap::new(); + collect_label_ids(&mut constant_ids, &fn_ast); + let (mut func, unique_ids) = normalize_identifiers(fn_ast, &constant_ids); + let bbs = get_basic_blocks(&func); + let rpostorder = to_reverse_postorder(&bbs); + let doms = immediate_dominators(&bbs, &rpostorder); + let dom_fronts = dominance_frontiers(&bbs, &doms); + let mut ssa_phis = ssa_legalize( + &mut func, + constant_ids.len() as u32, + unique_ids, + &bbs, + &doms, + &dom_fronts, + ); + assert_phi_dst_id(unique_ids, &ssa_phis); + assert_dst_unique(&func, &ssa_phis); + sort_phi(&mut ssa_phis); + + let i1 = unique_ids; + let j1 = unique_ids + 1; + let j2 = get_dst_from_src(&ssa_phis[1], j1); + let j3 = get_dst(&func[10]); + let j4 = get_dst_from_src(&ssa_phis[5], j3); + let j5 = get_dst(&func[14]); + let k1 = unique_ids + 2; + let k2 = get_dst_from_src(&ssa_phis[1], k1); + let k3 = get_dst(&func[11]); + let k4 = get_dst_from_src(&ssa_phis[5], k3); + let k5 = get_dst(&func[15]); + let p1 = get_dst(&func[4]); + let q1 = get_dst(&func[7]); + let block_2 = get_dst(&func[3]); + let block_3 = get_dst(&func[6]); + let block_5 = get_dst(&func[9]); + let block_6 = get_dst(&func[13]); + let block_7 = get_dst(&func[16]); + let block_4 = get_dst(&func[18]); + + { + assert_eq!(get_ids(&func[0]), vec![i1]); + assert_eq!(get_ids(&func[1]), vec![j1]); + assert_eq!(get_ids(&func[2]), vec![k1]); + + assert_eq!( + ssa_phis[1], + to_phi(vec![(j2, vec![j4, j1]), (k2, vec![k4, k1])]) + ); + assert_eq!(get_ids(&func[3]), vec![block_2]); + assert_eq!(get_ids(&func[4]), vec![p1, k2]); + assert_eq!(get_ids(&func[5]), vec![p1, block_4]); + + assert_eq!(get_ids(&func[6]), vec![block_3]); + assert_eq!(get_ids(&func[7]), vec![q1, j2]); + assert_eq!(get_ids(&func[8]), vec![q1, block_6]); + + assert_eq!(get_ids(&func[9]), vec![block_5]); + assert_eq!(get_ids(&func[10]), vec![j3, i1]); + assert_eq!(get_ids(&func[11]), vec![k3, k2]); + assert_eq!(get_ids(&func[12]), vec![block_7]); + + assert_eq!(get_ids(&func[13]), vec![block_6]); + assert_eq!(get_ids(&func[14]), vec![j5, k2]); + assert_eq!(get_ids(&func[15]), vec![k5, k2]); + + assert_eq!( + ssa_phis[5], + to_phi(vec![(j4, vec![j3, j5]), (k4, vec![k3, k5])]) + ); + assert_eq!(get_ids(&func[16]), vec![block_7]); + assert_eq!(get_ids(&func[17]), vec![block_2]); + + assert_eq!(get_ids(&func[18]), vec![block_4]); + assert_eq!(get_ids(&func[19]), vec![]); + } + } + + fn assert_phi_dst_id(max_id: spirv::Word, phis: &[Vec]) { + for phi_set in phis { + for phi in phi_set { + assert!(phi.dst > max_id); + } + } + } + + fn assert_dst_unique(func: &[Statement], phis: &[Vec]) { let mut seen = HashSet::new(); for s in func { s.for_dst_id(&mut |id| { assert!(seen.insert(id)); }); } + for phi_set in phis { + for phi in phi_set { + assert!(seen.insert(phi.dst)); + } + } + } + + fn get_ids(s: &Statement) -> Vec { + let mut result = Vec::new(); + s.visit_id(&mut |_, id| { + result.push(*id); + }); + result + } + + fn sort_phi(phis: &mut [Vec]) { + for phi_set in phis { + phi_set.sort_by_key(|phi| phi.dst); + } + } + + fn to_phi(raw: Vec<(spirv::Word, Vec)>) -> Vec { + let result = raw + .into_iter() + .map(|(dst, src)| PhiDef { + dst: dst, + src: src.into_iter().collect::>(), + }) + .collect::>(); + let mut result = [result]; + sort_phi(&mut result); + let [result] = result; + result + } + + fn get_dst(s: &Statement) -> spirv::Word { + let mut result = None; + s.visit_id(&mut |is_dst, id| { + if is_dst { + assert_eq!(result.replace(*id), None); + } + }); + result.unwrap() + } + + fn get_dst_from_src(phi: &[PhiDef], src: spirv::Word) -> spirv::Word { + for phi_set in phi { + if phi_set.src.contains(&src) { + return phi_set.dst; + } + } + panic!() } }