Fix bugs in basic block resolution

This commit is contained in:
Andrzej Janik 2020-04-28 00:02:34 +02:00
parent bce5f27843
commit 92b5dbd6a8
2 changed files with 148 additions and 28 deletions

View file

@ -86,7 +86,7 @@ FunctionInput: ast::Argument<'input> = {
}
};
FunctionBody: Vec<ast::Statement<&'input str>> = {
pub(crate) FunctionBody: Vec<ast::Statement<&'input str>> = {
"{" <s:Statement*> "}" => { without_none(s) }
};

View file

@ -251,7 +251,11 @@ fn rename_succesor_phi_src(
}
}
fn pop_stacks(ssa_state: &mut SSARewriteState, old_phi: &HashSet<spirv::Word>, old_ids: &[spirv::Word]) {
fn pop_stacks(
ssa_state: &mut SSARewriteState,
old_phi: &HashSet<spirv::Word>,
old_ids: &[spirv::Word],
) {
for id in old_phi.iter().chain(old_ids) {
ssa_state.pop(*id);
}
@ -335,7 +339,7 @@ fn gather_phi_sets(
});
}
}
for (id, to_work ) in def_sites.iter_mut().enumerate() {
for (id, to_work) in def_sites.iter_mut().enumerate() {
let id = id as spirv::Word;
let (ref mut set, ref mut stack) = to_work;
loop {
@ -358,18 +362,26 @@ fn gather_phi_sets(
result
}
fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
let mut direct_bb_start = Vec::new();
let mut indirect_bb_start = Vec::new();
fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
// edge signify pred/succ relationship between bbs
let mut bb_edge = HashSet::new();
let mut unresolved_bb_edge = Vec::new();
// bb start means that a bb is starting at this statement, but there's no predecessor
let mut bb_start = Vec::new();
let mut labels = HashMap::new();
for (idx, s) in fun.iter().enumerate() {
match s {
Statement::Instruction(_, i) => {
Statement::Instruction(pred, i) => {
if let Some(id) = i.jump_target() {
indirect_bb_start.push((StmtIndex(idx), id));
unresolved_bb_edge.push((StmtIndex(idx), id));
if idx + 1 < fun.len() {
direct_bb_start.push((StmtIndex(idx), StmtIndex(idx + 1)));
if pred.is_some() {
bb_edge.insert((StmtIndex(idx), StmtIndex(idx + 1)));
}
bb_start.push(StmtIndex(idx + 1));
}
} else if i.is_terminal() && idx + 1 < fun.len() {
bb_start.push(StmtIndex(idx + 1));
}
}
Statement::Label(id) => {
@ -377,6 +389,25 @@ fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
}
};
}
// Resolve every <jump into label> into <jump into statement index>
// TODO: handle jumps into nowhere
for (idx, id) in unresolved_bb_edge {
let target = labels[&id];
bb_edge.insert((idx, target));
bb_start.push(target);
// now check if the preceding statement forms an edge
if target != StmtIndex(0) {
match &fun[target.0 - 1] {
Statement::Instruction(pred, i) => {
if !((pred.is_none() && i.jump_target().is_some()) || i.is_terminal()) {
bb_edge.insert((StmtIndex(target.0 - 1), target));
}
}
Statement::Label(_) => (),
}
}
}
// Create list of bbs without succ/pred
let mut bbs_map = BTreeMap::new();
bbs_map.insert(
StmtIndex(0),
@ -386,32 +417,22 @@ fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
succ: Vec::new(),
},
);
// TODO: handle jumps into nowhere
let resolved_indirect_bb_start = indirect_bb_start
.into_iter()
.map(|(idx, id)| (idx, labels[&id]))
.collect::<Vec<_>>();
for (_, to) in direct_bb_start
.iter()
.chain(resolved_indirect_bb_start.iter())
{
bbs_map.entry(*to).or_insert_with(|| BasicBlock {
start: *to,
for bb_first_stmt in bb_start {
bbs_map.entry(bb_first_stmt).or_insert_with(|| BasicBlock {
start: bb_first_stmt,
pred: Vec::new(),
succ: Vec::new(),
});
}
// Populate succ/pred
let indexed_bbs_map = bbs_map
.into_iter()
.enumerate()
.map(|(idx, (key, val))| (key, (BBIndex(idx), RefCell::new(val))))
.collect::<BTreeMap<_, _>>();
for (from, to) in direct_bb_start
.iter()
.chain(resolved_indirect_bb_start.iter())
{
let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=*from).next_back().unwrap();
let (to_idx, to_ref) = indexed_bbs_map.get(to).unwrap();
for (from, to) in bb_edge {
let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=from).next_back().unwrap();
let (to_idx, to_ref) = indexed_bbs_map.get(&to).unwrap();
{
from_ref.borrow_mut().succ.push(*to_idx);
}
@ -527,9 +548,9 @@ struct BasicBlock {
succ: Vec<BBIndex>,
}
#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)]
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)]
struct StmtIndex(pub usize);
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)]
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)]
struct BBIndex(pub usize);
enum Statement {
@ -646,6 +667,23 @@ impl<T: Copy> ast::Instruction<T> {
}
}
fn is_terminal(&self) -> bool {
match self {
ast::Instruction::Ret(_) => true,
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)
| ast::Instruction::Mul(_, _)
| ast::Instruction::Add(_, _)
| ast::Instruction::Setp(_, _)
| ast::Instruction::SetpBool(_, _)
| ast::Instruction::Not(_, _)
| ast::Instruction::Cvt(_, _)
| ast::Instruction::Shl(_, _)
| ast::Instruction::St(_, _)
| ast::Instruction::Bra(_, _) => false,
}
}
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
match self {
ast::Instruction::Ld(_, a) => a.for_dst_id(f),
@ -826,6 +864,8 @@ impl<T> ast::MovOperand<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::ast;
use crate::ptx;
// page 411
#[test]
@ -1140,4 +1180,84 @@ mod tests {
]
);
}
fn sort_pred_succ(bb: &mut BasicBlock) {
bb.pred.sort();
bb.succ.sort();
}
// page 403
#[test]
fn gather_phi_sets_19_4() {
let func = "{
mov.u32 i, 1;
mov.u32 j, 1;
mov.u32 k, 0;
block_2:
setp.ge.u32 p, k, 100;
@p bra block_4;
block_3:
setp.ge.u32 q, j, 20;
@q bra block_6;
block_5:
mov.u32 j, i;
add.u32 k, k, 1;
bra block_7;
block_6:
mov.u32 j, k;
add.u32 k, k, 2;
block_7:
bra block_2;
block_4:
ret;
}";
let mut errors = Vec::new();
let ast = ptx::FunctionBodyParser::new()
.parse(&mut errors, func)
.unwrap();
assert_eq!(errors.len(), 0);
let (normalized_ids, _) = normalize_identifiers(ast);
let mut bbs = get_basic_blocks(&normalized_ids);
bbs.iter_mut().for_each(sort_pred_succ);
assert_eq!(
bbs,
vec![
BasicBlock {
start: StmtIndex(0),
pred: vec![],
succ: vec![BBIndex(1)]
},
BasicBlock {
start: StmtIndex(3),
pred: vec![BBIndex(0), BBIndex(5)],
succ: vec![BBIndex(2), BBIndex(6)]
},
BasicBlock {
start: StmtIndex(6),
pred: vec![BBIndex(1)],
succ: vec![BBIndex(3), BBIndex(4)]
},
BasicBlock {
start: StmtIndex(9),
pred: vec![BBIndex(2)],
succ: vec![BBIndex(5)]
},
BasicBlock {
start: StmtIndex(13),
pred: vec![BBIndex(2)],
succ: vec![BBIndex(5)]
},
BasicBlock {
start: StmtIndex(16),
pred: vec![BBIndex(3), BBIndex(4)],
succ: vec![BBIndex(1)]
},
BasicBlock {
start: StmtIndex(18),
pred: vec![BBIndex(1)],
succ: vec![]
},
]
);
}
}