Calculate domination frontiers

This commit is contained in:
Andrzej Janik 2020-04-19 18:09:44 +02:00
parent 4a0e91949c
commit 426b9c5cb8
2 changed files with 363 additions and 26 deletions

View file

@ -1,7 +1,9 @@
#[macro_use]
extern crate quick_error;
#[macro_use]
extern crate lalrpop_util;
#[macro_use]
extern crate quick_error;
extern crate bit_vec;
extern crate rspirv;
extern crate spirv_headers as spirv;

View file

@ -1,6 +1,8 @@
use crate::ast;
use bit_vec::BitVec;
use rspirv::dr;
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::{cell::RefCell, ptr};
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvType {
@ -80,7 +82,7 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
for f in ast.functions {
emit_function(&mut builder, &mut map, &mut ids, &f)?;
emit_function(&mut builder, &mut map, &mut ids, f)?;
}
Ok(vec![])
}
@ -110,7 +112,7 @@ fn emit_function<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
ids: &mut IdWordMap<'a>,
f: &ast::Function<'a>,
f: ast::Function<'a>,
) -> Result<(), rspirv::dr::Error> {
let func_id = builder.begin_function(
map.void(),
@ -122,21 +124,21 @@ fn emit_function<'a>(
let arg_type = map.get_or_add(builder, SpirvType::Base(arg.a_type));
builder.function_parameter(arg_type)?;
}
for s in f.body.iter() {
match s {
ast::Statement::Label(name) => {
let id = ids.get_or_add(builder, name);
builder.begin_block(Some(id))?;
}
ast::Statement::Variable(var) => panic!(),
ast::Statement::Instruction(_, _) => panic!(),
}
}
let normalized_ids = normalize_identifiers(f.body);
let bbs = get_basic_blocks(&normalized_ids);
let rpostorder = to_reverse_postorder(&bbs);
let dom_fronts = dominance_frontiers(&bbs, &rpostorder);
let ssa = ssa_legalize(normalized_ids, dom_fronts);
emit_function_body_ops(ssa, builder);
builder.ret()?;
builder.end_function()?;
Ok(())
}
fn emit_function_body_ops(ssa: Vec<Statement>, builder: &mut dr::Builder) {
unimplemented!()
}
// TODO: support scopes
fn normalize_identifiers<'a>(func: Vec<ast::Statement<&'a str>>) -> Vec<Statement> {
let mut result = Vec::with_capacity(func.len());
@ -156,8 +158,175 @@ fn normalize_identifiers<'a>(func: Vec<ast::Statement<&'a str>>) -> Vec<Statemen
result
}
fn ssa_legalize(func: Vec<Statement>) -> Vec<Statement> {
vec![]
fn ssa_legalize(func: Vec<Statement>, dom_fronts: Vec<HashSet<BBIndex>>) -> Vec<Statement> {
unimplemented!()
}
fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
let mut direct_bb_start = Vec::new();
let mut indirect_bb_start = Vec::new();
let mut labels = HashMap::new();
for (idx, s) in fun.iter().enumerate() {
match s {
Statement::Instruction(_, i) => {
if let Some(id) = i.jump_target() {
indirect_bb_start.push((StmtIndex(idx), id));
if idx + 1 < fun.len() {
direct_bb_start.push((StmtIndex(idx), StmtIndex(idx + 1)));
}
}
}
Statement::Label(id) => {
labels.insert(id, StmtIndex(idx));
}
Statement::Phi(_) => (),
};
}
let mut bbs_map = BTreeMap::new();
bbs_map.insert(
StmtIndex(0),
BasicBlock {
start: StmtIndex(0),
pred: Vec::new(),
succ: Vec::new(),
},
);
// TODO: handle jumps into nowhere
let resolved_indirect_bb_start = indirect_bb_start
.into_iter()
.map(|(idx, id)| (idx, labels[&id]))
.collect::<Vec<_>>();
for (_, to) in direct_bb_start
.iter()
.chain(resolved_indirect_bb_start.iter())
{
bbs_map.entry(*to).or_insert_with(|| BasicBlock {
start: *to,
pred: Vec::new(),
succ: Vec::new(),
});
}
let indexed_bbs_map = bbs_map
.into_iter()
.enumerate()
.map(|(idx, (key, val))| (key, (BBIndex(idx), RefCell::new(val))))
.collect::<BTreeMap<_, _>>();
for (from, to) in direct_bb_start
.iter()
.chain(resolved_indirect_bb_start.iter())
{
let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=*from).next_back().unwrap();
let (to_idx, to_ref) = indexed_bbs_map.get(to).unwrap();
{
from_ref.borrow_mut().succ.push(*to_idx);
}
{
to_ref.borrow_mut().pred.push(*from_idx);
}
}
indexed_bbs_map
.into_iter()
.map(|(_, (_, bb))| bb.into_inner())
.collect::<Vec<_>>()
}
// "A Simple, Fast Dominance Algorithm" - Keith D. Cooper, Timothy J. Harvey, and Ken Kennedy
// https://www.cs.rice.edu/~keith/EMBED/dom.pdf
fn dominance_frontiers(bbs: &Vec<BasicBlock>, order: &Vec<BBIndex>) -> Vec<HashSet<BBIndex>> {
let doms = immediate_dominators(bbs, order);
let mut result = vec![HashSet::new(); bbs.len()];
for (bb_idx, b) in bbs.iter().enumerate() {
if b.pred.len() < 2 { continue; }
for p in b.pred.iter() {
let mut runner = *p;
while runner != doms[bb_idx] {
result[runner.0].insert(BBIndex(bb_idx));
runner = doms[runner.0];
}
}
}
result
}
fn immediate_dominators(bbs: &Vec<BasicBlock>, order: &Vec<BBIndex>) -> Vec<BBIndex> {
let mut doms = vec![BBIndex(usize::max_value()); bbs.len() - 1];
let mut changed = true;
while changed {
changed = false;
for BBIndex(bb_idx) in order.iter().skip(1) {
let bb = &bbs[*bb_idx];
if let Some(first_pred) = bb.pred.get(0) {
let mut new_idom = *first_pred;
for BBIndex(p_idx) in bb.pred.iter().copied().skip(1) {
if doms[p_idx] != BBIndex(usize::max_value()) {
new_idom = intersect(&mut doms, BBIndex(p_idx), new_idom);
}
}
if doms[*bb_idx] != new_idom {
doms[*bb_idx] = new_idom;
changed = true;
}
}
}
}
return doms;
}
fn intersect(doms: &mut Vec<BBIndex>, b1: BBIndex, b2: BBIndex) -> BBIndex {
let mut finger1 = b1;
let mut finger2 = b2;
while finger1 != finger2 {
while finger1 < finger2 {
finger1 = doms[finger1.0];
}
while finger2 < finger1 {
finger2 = doms[finger2.0];
}
}
finger1
}
// "A Simple Algorithm for Global Data Flow Analysis Problems" - Hecht, M. S., & Ullman, J. D. (1975)
fn to_reverse_postorder(input: &Vec<BasicBlock>) -> Vec<BBIndex> {
let mut i = input.len();
let mut old = BitVec::from_elem(input.len(), false);
// I would do just vec![BasicBlock::empty(), input.len()], but Vec<T> is not Copy
let mut result = Vec::with_capacity(input.len());
unsafe { result.set_len(input.len()) };
// original uses recursion and implicit stack, we do it explictly
let mut state = Vec::new();
state.push((BBIndex(0), 0usize));
loop {
if let Some((BBIndex(bb), succ_iter_idx)) = state.last_mut() {
let bb = *bb;
old.set(bb, true);
if let Some(BBIndex(succ)) = &input[bb].succ.get(*succ_iter_idx) {
*succ_iter_idx += 1;
if !old.get(*succ).unwrap() {
state.push((BBIndex(*succ), 0));
}
} else {
state.pop();
i = i - 1;
result[i] = BBIndex(bb);
}
} else {
break;
}
}
result
}
#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)]
struct StmtIndex(pub usize);
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)]
struct BBIndex(pub usize);
#[derive(Eq, PartialEq, Debug, Clone)]
struct BasicBlock {
start: StmtIndex,
pred: Vec<BBIndex>,
succ: Vec<BBIndex>,
}
enum Statement {
@ -167,13 +336,15 @@ enum Statement {
}
impl Statement {
fn from_ast<'a, F: FnMut(&'a str) -> u32>(s: ast::Statement<&'a str>, f: &mut F) -> Option<Self> {
fn from_ast<'a, F: FnMut(&'a str) -> u32>(
s: ast::Statement<&'a str>,
f: &mut F,
) -> Option<Self> {
match s {
ast::Statement::Label(name) => Some(Statement::Label(f(name))),
ast::Statement::Instruction(p, i) => Some(Statement::Instruction(
p.map(|p| p.map_id(f)),
i.map_id(f),
)),
ast::Statement::Instruction(p, i) => {
Some(Statement::Instruction(p.map(|p| p.map_id(f)), i.map_id(f)))
}
ast::Statement::Variable(_) => None,
}
}
@ -188,8 +359,8 @@ impl<T> ast::PredAt<T> {
}
}
impl<T> ast::Instruction<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> {
impl<T> ast::Instruction<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> {
match self {
ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)),
ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)),
@ -208,6 +379,26 @@ impl<T> ast::Instruction<T> {
}
}
impl<T: Copy> ast::Instruction<T> {
fn jump_target(&self) -> Option<T> {
match self {
ast::Instruction::Bra(d, a) => Some(a.dst),
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)
| ast::Instruction::Mul(_, _)
| ast::Instruction::Add(_, _)
| ast::Instruction::Setp(_, _)
| ast::Instruction::SetpBool(_, _)
| ast::Instruction::Not(_, _)
| ast::Instruction::Cvt(_, _)
| ast::Instruction::Shl(_, _)
| ast::Instruction::St(_, _)
| ast::Instruction::At(_, _)
| ast::Instruction::Ret(_) => None,
}
}
}
impl<T> ast::Arg1<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg1<U> {
ast::Arg1 { dst: f(self.dst) }
@ -279,7 +470,151 @@ impl<T> ast::MovOperand<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::MovOperand<U> {
match self {
ast::MovOperand::Op(o) => ast::MovOperand::Op(o.map_id(f)),
ast::MovOperand::Vec(s1, s2) => ast::MovOperand::Vec(s1, s2)
ast::MovOperand::Vec(s1, s2) => ast::MovOperand::Vec(s1, s2),
}
}
}
}
// CFGs below taken from "Modern Compiler Implementation in Java"
#[cfg(test)]
mod tests {
use super::*;
#[test]
// page 411
fn to_reverse_postorder1() {
let input = vec![
BasicBlock {
// A
start: StmtIndex(0),
pred: vec![],
succ: vec![BBIndex(1), BBIndex(2)],
},
BasicBlock {
// B
start: StmtIndex(1),
pred: vec![BBIndex(0), BBIndex(11)],
succ: vec![BBIndex(3), BBIndex(6)],
},
BasicBlock {
// C
start: StmtIndex(2),
pred: vec![BBIndex(0), BBIndex(4)],
succ: vec![BBIndex(4), BBIndex(7)],
},
BasicBlock {
// D
start: StmtIndex(3),
pred: vec![BBIndex(1)],
succ: vec![BBIndex(5), BBIndex(6)],
},
BasicBlock {
// E
start: StmtIndex(4),
pred: vec![BBIndex(2)],
succ: vec![BBIndex(2), BBIndex(7)],
},
BasicBlock {
// F
start: StmtIndex(5),
pred: vec![BBIndex(3)],
succ: vec![BBIndex(8), BBIndex(10)],
},
BasicBlock {
// G
start: StmtIndex(6),
pred: vec![BBIndex(1), BBIndex(3)],
succ: vec![BBIndex(9)],
},
BasicBlock {
// H
start: StmtIndex(7),
pred: vec![BBIndex(2), BBIndex(4)],
succ: vec![BBIndex(12)],
},
BasicBlock {
// I
start: StmtIndex(8),
pred: vec![BBIndex(5), BBIndex(9)],
succ: vec![BBIndex(11)],
},
BasicBlock {
// J
start: StmtIndex(9),
pred: vec![BBIndex(6)],
succ: vec![BBIndex(8)],
},
BasicBlock {
// K
start: StmtIndex(10),
pred: vec![BBIndex(5)],
succ: vec![BBIndex(11)],
},
BasicBlock {
// L
start: StmtIndex(11),
pred: vec![BBIndex(8), BBIndex(10)],
succ: vec![BBIndex(1), BBIndex(12)],
},
BasicBlock {
// M
start: StmtIndex(12),
pred: vec![BBIndex(7), BBIndex(11)],
succ: vec![],
},
];
let rpostord = to_reverse_postorder(&input);
assert_eq!(
rpostord,
vec![
BBIndex(0), // A
BBIndex(2), // C
BBIndex(4), // E
BBIndex(7), // H
BBIndex(1), // B
BBIndex(3), // D
BBIndex(6), // G
BBIndex(9), // J
BBIndex(5), // F
BBIndex(10), // K
BBIndex(8), // I
BBIndex(11), // L
BBIndex(12), // M
]
);
}
#[test]
fn get_basic_blocks_empty() {
let func = Vec::new();
let bbs = get_basic_blocks(&func);
assert_eq!(
bbs,
vec![BasicBlock {
start: StmtIndex(0),
pred: vec![],
succ: vec![]
}]
);
}
#[test]
fn get_basic_blocks_miniloop() {
let func = vec![
Statement::Label(12),
Statement::Instruction(
None,
ast::Instruction::Bra(ast::BraData {}, ast::Arg1 { dst: 12 }),
),
];
let bbs = get_basic_blocks(&func);
assert_eq!(
bbs,
vec![BasicBlock {
start: StmtIndex(0),
pred: vec![BBIndex(0)],
succ: vec![BBIndex(0)]
}]
);
}
}