mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Fix bugs in basic block resolution
This commit is contained in:
parent
bce5f27843
commit
92b5dbd6a8
2 changed files with 148 additions and 28 deletions
|
@ -86,7 +86,7 @@ FunctionInput: ast::Argument<'input> = {
|
|||
}
|
||||
};
|
||||
|
||||
FunctionBody: Vec<ast::Statement<&'input str>> = {
|
||||
pub(crate) FunctionBody: Vec<ast::Statement<&'input str>> = {
|
||||
"{" <s:Statement*> "}" => { without_none(s) }
|
||||
};
|
||||
|
||||
|
|
|
@ -251,7 +251,11 @@ fn rename_succesor_phi_src(
|
|||
}
|
||||
}
|
||||
|
||||
fn pop_stacks(ssa_state: &mut SSARewriteState, old_phi: &HashSet<spirv::Word>, old_ids: &[spirv::Word]) {
|
||||
fn pop_stacks(
|
||||
ssa_state: &mut SSARewriteState,
|
||||
old_phi: &HashSet<spirv::Word>,
|
||||
old_ids: &[spirv::Word],
|
||||
) {
|
||||
for id in old_phi.iter().chain(old_ids) {
|
||||
ssa_state.pop(*id);
|
||||
}
|
||||
|
@ -335,7 +339,7 @@ fn gather_phi_sets(
|
|||
});
|
||||
}
|
||||
}
|
||||
for (id, to_work ) in def_sites.iter_mut().enumerate() {
|
||||
for (id, to_work) in def_sites.iter_mut().enumerate() {
|
||||
let id = id as spirv::Word;
|
||||
let (ref mut set, ref mut stack) = to_work;
|
||||
loop {
|
||||
|
@ -358,18 +362,26 @@ fn gather_phi_sets(
|
|||
result
|
||||
}
|
||||
|
||||
fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
|
||||
let mut direct_bb_start = Vec::new();
|
||||
let mut indirect_bb_start = Vec::new();
|
||||
fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
|
||||
// edge signify pred/succ relationship between bbs
|
||||
let mut bb_edge = HashSet::new();
|
||||
let mut unresolved_bb_edge = Vec::new();
|
||||
// bb start means that a bb is starting at this statement, but there's no predecessor
|
||||
let mut bb_start = Vec::new();
|
||||
let mut labels = HashMap::new();
|
||||
for (idx, s) in fun.iter().enumerate() {
|
||||
match s {
|
||||
Statement::Instruction(_, i) => {
|
||||
Statement::Instruction(pred, i) => {
|
||||
if let Some(id) = i.jump_target() {
|
||||
indirect_bb_start.push((StmtIndex(idx), id));
|
||||
unresolved_bb_edge.push((StmtIndex(idx), id));
|
||||
if idx + 1 < fun.len() {
|
||||
direct_bb_start.push((StmtIndex(idx), StmtIndex(idx + 1)));
|
||||
if pred.is_some() {
|
||||
bb_edge.insert((StmtIndex(idx), StmtIndex(idx + 1)));
|
||||
}
|
||||
bb_start.push(StmtIndex(idx + 1));
|
||||
}
|
||||
} else if i.is_terminal() && idx + 1 < fun.len() {
|
||||
bb_start.push(StmtIndex(idx + 1));
|
||||
}
|
||||
}
|
||||
Statement::Label(id) => {
|
||||
|
@ -377,6 +389,25 @@ fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
|
|||
}
|
||||
};
|
||||
}
|
||||
// Resolve every <jump into label> into <jump into statement index>
|
||||
// TODO: handle jumps into nowhere
|
||||
for (idx, id) in unresolved_bb_edge {
|
||||
let target = labels[&id];
|
||||
bb_edge.insert((idx, target));
|
||||
bb_start.push(target);
|
||||
// now check if the preceding statement forms an edge
|
||||
if target != StmtIndex(0) {
|
||||
match &fun[target.0 - 1] {
|
||||
Statement::Instruction(pred, i) => {
|
||||
if !((pred.is_none() && i.jump_target().is_some()) || i.is_terminal()) {
|
||||
bb_edge.insert((StmtIndex(target.0 - 1), target));
|
||||
}
|
||||
}
|
||||
Statement::Label(_) => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
// Create list of bbs without succ/pred
|
||||
let mut bbs_map = BTreeMap::new();
|
||||
bbs_map.insert(
|
||||
StmtIndex(0),
|
||||
|
@ -386,32 +417,22 @@ fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
|
|||
succ: Vec::new(),
|
||||
},
|
||||
);
|
||||
// TODO: handle jumps into nowhere
|
||||
let resolved_indirect_bb_start = indirect_bb_start
|
||||
.into_iter()
|
||||
.map(|(idx, id)| (idx, labels[&id]))
|
||||
.collect::<Vec<_>>();
|
||||
for (_, to) in direct_bb_start
|
||||
.iter()
|
||||
.chain(resolved_indirect_bb_start.iter())
|
||||
{
|
||||
bbs_map.entry(*to).or_insert_with(|| BasicBlock {
|
||||
start: *to,
|
||||
for bb_first_stmt in bb_start {
|
||||
bbs_map.entry(bb_first_stmt).or_insert_with(|| BasicBlock {
|
||||
start: bb_first_stmt,
|
||||
pred: Vec::new(),
|
||||
succ: Vec::new(),
|
||||
});
|
||||
}
|
||||
// Populate succ/pred
|
||||
let indexed_bbs_map = bbs_map
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, (key, val))| (key, (BBIndex(idx), RefCell::new(val))))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
for (from, to) in direct_bb_start
|
||||
.iter()
|
||||
.chain(resolved_indirect_bb_start.iter())
|
||||
{
|
||||
let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=*from).next_back().unwrap();
|
||||
let (to_idx, to_ref) = indexed_bbs_map.get(to).unwrap();
|
||||
for (from, to) in bb_edge {
|
||||
let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=from).next_back().unwrap();
|
||||
let (to_idx, to_ref) = indexed_bbs_map.get(&to).unwrap();
|
||||
{
|
||||
from_ref.borrow_mut().succ.push(*to_idx);
|
||||
}
|
||||
|
@ -527,9 +548,9 @@ struct BasicBlock {
|
|||
succ: Vec<BBIndex>,
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)]
|
||||
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)]
|
||||
struct StmtIndex(pub usize);
|
||||
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)]
|
||||
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)]
|
||||
struct BBIndex(pub usize);
|
||||
|
||||
enum Statement {
|
||||
|
@ -646,6 +667,23 @@ impl<T: Copy> ast::Instruction<T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn is_terminal(&self) -> bool {
|
||||
match self {
|
||||
ast::Instruction::Ret(_) => true,
|
||||
ast::Instruction::Ld(_, _)
|
||||
| ast::Instruction::Mov(_, _)
|
||||
| ast::Instruction::Mul(_, _)
|
||||
| ast::Instruction::Add(_, _)
|
||||
| ast::Instruction::Setp(_, _)
|
||||
| ast::Instruction::SetpBool(_, _)
|
||||
| ast::Instruction::Not(_, _)
|
||||
| ast::Instruction::Cvt(_, _)
|
||||
| ast::Instruction::Shl(_, _)
|
||||
| ast::Instruction::St(_, _)
|
||||
| ast::Instruction::Bra(_, _) => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
|
||||
match self {
|
||||
ast::Instruction::Ld(_, a) => a.for_dst_id(f),
|
||||
|
@ -826,6 +864,8 @@ impl<T> ast::MovOperand<T> {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ast;
|
||||
use crate::ptx;
|
||||
|
||||
// page 411
|
||||
#[test]
|
||||
|
@ -1140,4 +1180,84 @@ mod tests {
|
|||
]
|
||||
);
|
||||
}
|
||||
|
||||
fn sort_pred_succ(bb: &mut BasicBlock) {
|
||||
bb.pred.sort();
|
||||
bb.succ.sort();
|
||||
}
|
||||
|
||||
// page 403
|
||||
#[test]
|
||||
fn gather_phi_sets_19_4() {
|
||||
let func = "{
|
||||
mov.u32 i, 1;
|
||||
mov.u32 j, 1;
|
||||
mov.u32 k, 0;
|
||||
block_2:
|
||||
setp.ge.u32 p, k, 100;
|
||||
@p bra block_4;
|
||||
block_3:
|
||||
setp.ge.u32 q, j, 20;
|
||||
@q bra block_6;
|
||||
block_5:
|
||||
mov.u32 j, i;
|
||||
add.u32 k, k, 1;
|
||||
bra block_7;
|
||||
block_6:
|
||||
mov.u32 j, k;
|
||||
add.u32 k, k, 2;
|
||||
block_7:
|
||||
bra block_2;
|
||||
block_4:
|
||||
ret;
|
||||
}";
|
||||
let mut errors = Vec::new();
|
||||
let ast = ptx::FunctionBodyParser::new()
|
||||
.parse(&mut errors, func)
|
||||
.unwrap();
|
||||
assert_eq!(errors.len(), 0);
|
||||
let (normalized_ids, _) = normalize_identifiers(ast);
|
||||
let mut bbs = get_basic_blocks(&normalized_ids);
|
||||
bbs.iter_mut().for_each(sort_pred_succ);
|
||||
assert_eq!(
|
||||
bbs,
|
||||
vec![
|
||||
BasicBlock {
|
||||
start: StmtIndex(0),
|
||||
pred: vec![],
|
||||
succ: vec![BBIndex(1)]
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(3),
|
||||
pred: vec![BBIndex(0), BBIndex(5)],
|
||||
succ: vec![BBIndex(2), BBIndex(6)]
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(6),
|
||||
pred: vec![BBIndex(1)],
|
||||
succ: vec![BBIndex(3), BBIndex(4)]
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(9),
|
||||
pred: vec![BBIndex(2)],
|
||||
succ: vec![BBIndex(5)]
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(13),
|
||||
pred: vec![BBIndex(2)],
|
||||
succ: vec![BBIndex(5)]
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(16),
|
||||
pred: vec![BBIndex(3), BBIndex(4)],
|
||||
succ: vec![BBIndex(1)]
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(18),
|
||||
pred: vec![BBIndex(1)],
|
||||
succ: vec![]
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue