diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 716a25c..61c3444 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -1,7 +1,9 @@ #[macro_use] -extern crate quick_error; -#[macro_use] extern crate lalrpop_util; +#[macro_use] +extern crate quick_error; + +extern crate bit_vec; extern crate rspirv; extern crate spirv_headers as spirv; diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 2382180..8f38516 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,6 +1,8 @@ use crate::ast; +use bit_vec::BitVec; use rspirv::dr; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::{cell::RefCell, ptr}; #[derive(PartialEq, Eq, Hash, Clone, Copy)] enum SpirvType { @@ -80,7 +82,7 @@ pub fn to_spirv(ast: ast::Module) -> Result, rspirv::dr::Error> { emit_memory_model(&mut builder); let mut map = TypeWordMap::new(&mut builder); for f in ast.functions { - emit_function(&mut builder, &mut map, &mut ids, &f)?; + emit_function(&mut builder, &mut map, &mut ids, f)?; } Ok(vec![]) } @@ -110,7 +112,7 @@ fn emit_function<'a>( builder: &mut dr::Builder, map: &mut TypeWordMap, ids: &mut IdWordMap<'a>, - f: &ast::Function<'a>, + f: ast::Function<'a>, ) -> Result<(), rspirv::dr::Error> { let func_id = builder.begin_function( map.void(), @@ -122,21 +124,21 @@ fn emit_function<'a>( let arg_type = map.get_or_add(builder, SpirvType::Base(arg.a_type)); builder.function_parameter(arg_type)?; } - for s in f.body.iter() { - match s { - ast::Statement::Label(name) => { - let id = ids.get_or_add(builder, name); - builder.begin_block(Some(id))?; - } - ast::Statement::Variable(var) => panic!(), - ast::Statement::Instruction(_, _) => panic!(), - } - } + let normalized_ids = normalize_identifiers(f.body); + let bbs = get_basic_blocks(&normalized_ids); + let rpostorder = to_reverse_postorder(&bbs); + let dom_fronts = dominance_frontiers(&bbs, &rpostorder); + let ssa = ssa_legalize(normalized_ids, dom_fronts); + emit_function_body_ops(ssa, builder); builder.ret()?; builder.end_function()?; Ok(()) } +fn emit_function_body_ops(ssa: Vec, builder: &mut dr::Builder) { + unimplemented!() +} + // TODO: support scopes fn normalize_identifiers<'a>(func: Vec>) -> Vec { let mut result = Vec::with_capacity(func.len()); @@ -156,8 +158,175 @@ fn normalize_identifiers<'a>(func: Vec>) -> Vec) -> Vec { - vec![] +fn ssa_legalize(func: Vec, dom_fronts: Vec>) -> Vec { + unimplemented!() +} + +fn get_basic_blocks(fun: &Vec) -> Vec { + let mut direct_bb_start = Vec::new(); + let mut indirect_bb_start = Vec::new(); + let mut labels = HashMap::new(); + for (idx, s) in fun.iter().enumerate() { + match s { + Statement::Instruction(_, i) => { + if let Some(id) = i.jump_target() { + indirect_bb_start.push((StmtIndex(idx), id)); + if idx + 1 < fun.len() { + direct_bb_start.push((StmtIndex(idx), StmtIndex(idx + 1))); + } + } + } + Statement::Label(id) => { + labels.insert(id, StmtIndex(idx)); + } + Statement::Phi(_) => (), + }; + } + let mut bbs_map = BTreeMap::new(); + bbs_map.insert( + StmtIndex(0), + BasicBlock { + start: StmtIndex(0), + pred: Vec::new(), + succ: Vec::new(), + }, + ); + // TODO: handle jumps into nowhere + let resolved_indirect_bb_start = indirect_bb_start + .into_iter() + .map(|(idx, id)| (idx, labels[&id])) + .collect::>(); + for (_, to) in direct_bb_start + .iter() + .chain(resolved_indirect_bb_start.iter()) + { + bbs_map.entry(*to).or_insert_with(|| BasicBlock { + start: *to, + pred: Vec::new(), + succ: Vec::new(), + }); + } + let indexed_bbs_map = bbs_map + .into_iter() + .enumerate() + .map(|(idx, (key, val))| (key, (BBIndex(idx), RefCell::new(val)))) + .collect::>(); + for (from, to) in direct_bb_start + .iter() + .chain(resolved_indirect_bb_start.iter()) + { + let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=*from).next_back().unwrap(); + let (to_idx, to_ref) = indexed_bbs_map.get(to).unwrap(); + { + from_ref.borrow_mut().succ.push(*to_idx); + } + { + to_ref.borrow_mut().pred.push(*from_idx); + } + } + indexed_bbs_map + .into_iter() + .map(|(_, (_, bb))| bb.into_inner()) + .collect::>() +} + +// "A Simple, Fast Dominance Algorithm" - Keith D. Cooper, Timothy J. Harvey, and Ken Kennedy +// https://www.cs.rice.edu/~keith/EMBED/dom.pdf +fn dominance_frontiers(bbs: &Vec, order: &Vec) -> Vec> { + let doms = immediate_dominators(bbs, order); + let mut result = vec![HashSet::new(); bbs.len()]; + for (bb_idx, b) in bbs.iter().enumerate() { + if b.pred.len() < 2 { continue; } + for p in b.pred.iter() { + let mut runner = *p; + while runner != doms[bb_idx] { + result[runner.0].insert(BBIndex(bb_idx)); + runner = doms[runner.0]; + } + } + } + result +} + +fn immediate_dominators(bbs: &Vec, order: &Vec) -> Vec { + let mut doms = vec![BBIndex(usize::max_value()); bbs.len() - 1]; + let mut changed = true; + while changed { + changed = false; + for BBIndex(bb_idx) in order.iter().skip(1) { + let bb = &bbs[*bb_idx]; + if let Some(first_pred) = bb.pred.get(0) { + let mut new_idom = *first_pred; + for BBIndex(p_idx) in bb.pred.iter().copied().skip(1) { + if doms[p_idx] != BBIndex(usize::max_value()) { + new_idom = intersect(&mut doms, BBIndex(p_idx), new_idom); + } + } + if doms[*bb_idx] != new_idom { + doms[*bb_idx] = new_idom; + changed = true; + } + } + } + } + return doms; +} + +fn intersect(doms: &mut Vec, b1: BBIndex, b2: BBIndex) -> BBIndex { + let mut finger1 = b1; + let mut finger2 = b2; + while finger1 != finger2 { + while finger1 < finger2 { + finger1 = doms[finger1.0]; + } + while finger2 < finger1 { + finger2 = doms[finger2.0]; + } + } + finger1 +} + +// "A Simple Algorithm for Global Data Flow Analysis Problems" - Hecht, M. S., & Ullman, J. D. (1975) +fn to_reverse_postorder(input: &Vec) -> Vec { + let mut i = input.len(); + let mut old = BitVec::from_elem(input.len(), false); + // I would do just vec![BasicBlock::empty(), input.len()], but Vec is not Copy + let mut result = Vec::with_capacity(input.len()); + unsafe { result.set_len(input.len()) }; + // original uses recursion and implicit stack, we do it explictly + let mut state = Vec::new(); + state.push((BBIndex(0), 0usize)); + loop { + if let Some((BBIndex(bb), succ_iter_idx)) = state.last_mut() { + let bb = *bb; + old.set(bb, true); + if let Some(BBIndex(succ)) = &input[bb].succ.get(*succ_iter_idx) { + *succ_iter_idx += 1; + if !old.get(*succ).unwrap() { + state.push((BBIndex(*succ), 0)); + } + } else { + state.pop(); + i = i - 1; + result[i] = BBIndex(bb); + } + } else { + break; + } + } + result +} + +#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)] +struct StmtIndex(pub usize); +#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)] +struct BBIndex(pub usize); + +#[derive(Eq, PartialEq, Debug, Clone)] +struct BasicBlock { + start: StmtIndex, + pred: Vec, + succ: Vec, } enum Statement { @@ -167,13 +336,15 @@ enum Statement { } impl Statement { - fn from_ast<'a, F: FnMut(&'a str) -> u32>(s: ast::Statement<&'a str>, f: &mut F) -> Option { + fn from_ast<'a, F: FnMut(&'a str) -> u32>( + s: ast::Statement<&'a str>, + f: &mut F, + ) -> Option { match s { ast::Statement::Label(name) => Some(Statement::Label(f(name))), - ast::Statement::Instruction(p, i) => Some(Statement::Instruction( - p.map(|p| p.map_id(f)), - i.map_id(f), - )), + ast::Statement::Instruction(p, i) => { + Some(Statement::Instruction(p.map(|p| p.map_id(f)), i.map_id(f))) + } ast::Statement::Variable(_) => None, } } @@ -188,8 +359,8 @@ impl ast::PredAt { } } -impl ast::Instruction { - fn map_id U>(self, f: &mut F) -> ast::Instruction { +impl ast::Instruction { + fn map_id U>(self, f: &mut F) -> ast::Instruction { match self { ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)), ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)), @@ -208,6 +379,26 @@ impl ast::Instruction { } } +impl ast::Instruction { + fn jump_target(&self) -> Option { + match self { + ast::Instruction::Bra(d, a) => Some(a.dst), + ast::Instruction::Ld(_, _) + | ast::Instruction::Mov(_, _) + | ast::Instruction::Mul(_, _) + | ast::Instruction::Add(_, _) + | ast::Instruction::Setp(_, _) + | ast::Instruction::SetpBool(_, _) + | ast::Instruction::Not(_, _) + | ast::Instruction::Cvt(_, _) + | ast::Instruction::Shl(_, _) + | ast::Instruction::St(_, _) + | ast::Instruction::At(_, _) + | ast::Instruction::Ret(_) => None, + } + } +} + impl ast::Arg1 { fn map_id U>(self, f: &mut F) -> ast::Arg1 { ast::Arg1 { dst: f(self.dst) } @@ -279,7 +470,151 @@ impl ast::MovOperand { fn map_id U>(self, f: &mut F) -> ast::MovOperand { match self { ast::MovOperand::Op(o) => ast::MovOperand::Op(o.map_id(f)), - ast::MovOperand::Vec(s1, s2) => ast::MovOperand::Vec(s1, s2) + ast::MovOperand::Vec(s1, s2) => ast::MovOperand::Vec(s1, s2), } } -} \ No newline at end of file +} + +// CFGs below taken from "Modern Compiler Implementation in Java" +#[cfg(test)] +mod tests { + use super::*; + + #[test] + // page 411 + fn to_reverse_postorder1() { + let input = vec![ + BasicBlock { + // A + start: StmtIndex(0), + pred: vec![], + succ: vec![BBIndex(1), BBIndex(2)], + }, + BasicBlock { + // B + start: StmtIndex(1), + pred: vec![BBIndex(0), BBIndex(11)], + succ: vec![BBIndex(3), BBIndex(6)], + }, + BasicBlock { + // C + start: StmtIndex(2), + pred: vec![BBIndex(0), BBIndex(4)], + succ: vec![BBIndex(4), BBIndex(7)], + }, + BasicBlock { + // D + start: StmtIndex(3), + pred: vec![BBIndex(1)], + succ: vec![BBIndex(5), BBIndex(6)], + }, + BasicBlock { + // E + start: StmtIndex(4), + pred: vec![BBIndex(2)], + succ: vec![BBIndex(2), BBIndex(7)], + }, + BasicBlock { + // F + start: StmtIndex(5), + pred: vec![BBIndex(3)], + succ: vec![BBIndex(8), BBIndex(10)], + }, + BasicBlock { + // G + start: StmtIndex(6), + pred: vec![BBIndex(1), BBIndex(3)], + succ: vec![BBIndex(9)], + }, + BasicBlock { + // H + start: StmtIndex(7), + pred: vec![BBIndex(2), BBIndex(4)], + succ: vec![BBIndex(12)], + }, + BasicBlock { + // I + start: StmtIndex(8), + pred: vec![BBIndex(5), BBIndex(9)], + succ: vec![BBIndex(11)], + }, + BasicBlock { + // J + start: StmtIndex(9), + pred: vec![BBIndex(6)], + succ: vec![BBIndex(8)], + }, + BasicBlock { + // K + start: StmtIndex(10), + pred: vec![BBIndex(5)], + succ: vec![BBIndex(11)], + }, + BasicBlock { + // L + start: StmtIndex(11), + pred: vec![BBIndex(8), BBIndex(10)], + succ: vec![BBIndex(1), BBIndex(12)], + }, + BasicBlock { + // M + start: StmtIndex(12), + pred: vec![BBIndex(7), BBIndex(11)], + succ: vec![], + }, + ]; + let rpostord = to_reverse_postorder(&input); + assert_eq!( + rpostord, + vec![ + BBIndex(0), // A + BBIndex(2), // C + BBIndex(4), // E + BBIndex(7), // H + BBIndex(1), // B + BBIndex(3), // D + BBIndex(6), // G + BBIndex(9), // J + BBIndex(5), // F + BBIndex(10), // K + BBIndex(8), // I + BBIndex(11), // L + BBIndex(12), // M + ] + ); + } + + #[test] + fn get_basic_blocks_empty() { + let func = Vec::new(); + let bbs = get_basic_blocks(&func); + assert_eq!( + bbs, + vec![BasicBlock { + start: StmtIndex(0), + pred: vec![], + succ: vec![] + }] + ); + } + + #[test] + fn get_basic_blocks_miniloop() { + let func = vec![ + Statement::Label(12), + Statement::Instruction( + None, + ast::Instruction::Bra(ast::BraData {}, ast::Arg1 { dst: 12 }), + ), + ]; + let bbs = get_basic_blocks(&func); + assert_eq!( + bbs, + vec![BasicBlock { + start: StmtIndex(0), + pred: vec![BBIndex(0)], + succ: vec![BBIndex(0)] + }] + ); + } +}