diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index f40846d..83a0fe2 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -86,7 +86,7 @@ FunctionInput: ast::Argument<'input> = { } }; -FunctionBody: Vec> = { +pub(crate) FunctionBody: Vec> = { "{" "}" => { without_none(s) } }; diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 651f996..1206e22 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -251,7 +251,11 @@ fn rename_succesor_phi_src( } } -fn pop_stacks(ssa_state: &mut SSARewriteState, old_phi: &HashSet, old_ids: &[spirv::Word]) { +fn pop_stacks( + ssa_state: &mut SSARewriteState, + old_phi: &HashSet, + old_ids: &[spirv::Word], +) { for id in old_phi.iter().chain(old_ids) { ssa_state.pop(*id); } @@ -335,7 +339,7 @@ fn gather_phi_sets( }); } } - for (id, to_work ) in def_sites.iter_mut().enumerate() { + for (id, to_work) in def_sites.iter_mut().enumerate() { let id = id as spirv::Word; let (ref mut set, ref mut stack) = to_work; loop { @@ -358,18 +362,26 @@ fn gather_phi_sets( result } -fn get_basic_blocks(fun: &Vec) -> Vec { - let mut direct_bb_start = Vec::new(); - let mut indirect_bb_start = Vec::new(); +fn get_basic_blocks(fun: &[Statement]) -> Vec { + // edge signify pred/succ relationship between bbs + let mut bb_edge = HashSet::new(); + let mut unresolved_bb_edge = Vec::new(); + // bb start means that a bb is starting at this statement, but there's no predecessor + let mut bb_start = Vec::new(); let mut labels = HashMap::new(); for (idx, s) in fun.iter().enumerate() { match s { - Statement::Instruction(_, i) => { + Statement::Instruction(pred, i) => { if let Some(id) = i.jump_target() { - indirect_bb_start.push((StmtIndex(idx), id)); + unresolved_bb_edge.push((StmtIndex(idx), id)); if idx + 1 < fun.len() { - direct_bb_start.push((StmtIndex(idx), StmtIndex(idx + 1))); + if pred.is_some() { + bb_edge.insert((StmtIndex(idx), StmtIndex(idx + 1))); + } + bb_start.push(StmtIndex(idx + 1)); } + } else if i.is_terminal() && idx + 1 < fun.len() { + bb_start.push(StmtIndex(idx + 1)); } } Statement::Label(id) => { @@ -377,6 +389,25 @@ fn get_basic_blocks(fun: &Vec) -> Vec { } }; } + // Resolve every into + // TODO: handle jumps into nowhere + for (idx, id) in unresolved_bb_edge { + let target = labels[&id]; + bb_edge.insert((idx, target)); + bb_start.push(target); + // now check if the preceding statement forms an edge + if target != StmtIndex(0) { + match &fun[target.0 - 1] { + Statement::Instruction(pred, i) => { + if !((pred.is_none() && i.jump_target().is_some()) || i.is_terminal()) { + bb_edge.insert((StmtIndex(target.0 - 1), target)); + } + } + Statement::Label(_) => (), + } + } + } + // Create list of bbs without succ/pred let mut bbs_map = BTreeMap::new(); bbs_map.insert( StmtIndex(0), @@ -386,32 +417,22 @@ fn get_basic_blocks(fun: &Vec) -> Vec { succ: Vec::new(), }, ); - // TODO: handle jumps into nowhere - let resolved_indirect_bb_start = indirect_bb_start - .into_iter() - .map(|(idx, id)| (idx, labels[&id])) - .collect::>(); - for (_, to) in direct_bb_start - .iter() - .chain(resolved_indirect_bb_start.iter()) - { - bbs_map.entry(*to).or_insert_with(|| BasicBlock { - start: *to, + for bb_first_stmt in bb_start { + bbs_map.entry(bb_first_stmt).or_insert_with(|| BasicBlock { + start: bb_first_stmt, pred: Vec::new(), succ: Vec::new(), }); } + // Populate succ/pred let indexed_bbs_map = bbs_map .into_iter() .enumerate() .map(|(idx, (key, val))| (key, (BBIndex(idx), RefCell::new(val)))) .collect::>(); - for (from, to) in direct_bb_start - .iter() - .chain(resolved_indirect_bb_start.iter()) - { - let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=*from).next_back().unwrap(); - let (to_idx, to_ref) = indexed_bbs_map.get(to).unwrap(); + for (from, to) in bb_edge { + let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=from).next_back().unwrap(); + let (to_idx, to_ref) = indexed_bbs_map.get(&to).unwrap(); { from_ref.borrow_mut().succ.push(*to_idx); } @@ -527,9 +548,9 @@ struct BasicBlock { succ: Vec, } -#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)] +#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)] struct StmtIndex(pub usize); -#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)] +#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)] struct BBIndex(pub usize); enum Statement { @@ -646,6 +667,23 @@ impl ast::Instruction { } } + fn is_terminal(&self) -> bool { + match self { + ast::Instruction::Ret(_) => true, + ast::Instruction::Ld(_, _) + | ast::Instruction::Mov(_, _) + | ast::Instruction::Mul(_, _) + | ast::Instruction::Add(_, _) + | ast::Instruction::Setp(_, _) + | ast::Instruction::SetpBool(_, _) + | ast::Instruction::Not(_, _) + | ast::Instruction::Cvt(_, _) + | ast::Instruction::Shl(_, _) + | ast::Instruction::St(_, _) + | ast::Instruction::Bra(_, _) => false, + } + } + fn for_dst_id(&self, f: &mut F) { match self { ast::Instruction::Ld(_, a) => a.for_dst_id(f), @@ -826,6 +864,8 @@ impl ast::MovOperand { #[cfg(test)] mod tests { use super::*; + use crate::ast; + use crate::ptx; // page 411 #[test] @@ -1140,4 +1180,84 @@ mod tests { ] ); } + + fn sort_pred_succ(bb: &mut BasicBlock) { + bb.pred.sort(); + bb.succ.sort(); + } + + // page 403 + #[test] + fn gather_phi_sets_19_4() { + let func = "{ + mov.u32 i, 1; + mov.u32 j, 1; + mov.u32 k, 0; + block_2: + setp.ge.u32 p, k, 100; + @p bra block_4; + block_3: + setp.ge.u32 q, j, 20; + @q bra block_6; + block_5: + mov.u32 j, i; + add.u32 k, k, 1; + bra block_7; + block_6: + mov.u32 j, k; + add.u32 k, k, 2; + block_7: + bra block_2; + block_4: + ret; + }"; + let mut errors = Vec::new(); + let ast = ptx::FunctionBodyParser::new() + .parse(&mut errors, func) + .unwrap(); + assert_eq!(errors.len(), 0); + let (normalized_ids, _) = normalize_identifiers(ast); + let mut bbs = get_basic_blocks(&normalized_ids); + bbs.iter_mut().for_each(sort_pred_succ); + assert_eq!( + bbs, + vec![ + BasicBlock { + start: StmtIndex(0), + pred: vec![], + succ: vec![BBIndex(1)] + }, + BasicBlock { + start: StmtIndex(3), + pred: vec![BBIndex(0), BBIndex(5)], + succ: vec![BBIndex(2), BBIndex(6)] + }, + BasicBlock { + start: StmtIndex(6), + pred: vec![BBIndex(1)], + succ: vec![BBIndex(3), BBIndex(4)] + }, + BasicBlock { + start: StmtIndex(9), + pred: vec![BBIndex(2)], + succ: vec![BBIndex(5)] + }, + BasicBlock { + start: StmtIndex(13), + pred: vec![BBIndex(2)], + succ: vec![BBIndex(5)] + }, + BasicBlock { + start: StmtIndex(16), + pred: vec![BBIndex(3), BBIndex(4)], + succ: vec![BBIndex(1)] + }, + BasicBlock { + start: StmtIndex(18), + pred: vec![BBIndex(1)], + succ: vec![] + }, + ] + ); + } }