Start doing SSA conversion

This commit is contained in:
Andrzej Janik 2020-04-22 00:55:49 +02:00
parent 0c71826bc7
commit 7b2bc69330
4 changed files with 155 additions and 38 deletions

View file

@ -1,7 +1,4 @@
use std::convert::From;
use std::convert::Into;
use std::error::Error;
use std::mem;
use std::num::ParseIntError;
quick_error! {

View file

@ -9,6 +9,7 @@ extern crate spirv_headers as spirv;
lalrpop_mod!(ptx);
#[cfg(test)]
mod test;
mod translate;
pub mod ast;

View file

@ -2,7 +2,7 @@ use super::ptx;
fn parse_and_assert(s: &str) {
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
assert!(errors.len() == 0);
}
@ -12,6 +12,7 @@ fn empty() {
}
#[test]
#[allow(non_snake_case)]
fn vectorAdd_kernel64_ptx() {
let vector_add = include_str!("vectorAdd_kernel64.ptx");
parse_and_assert(vector_add);

View file

@ -1,8 +1,8 @@
use crate::ast;
use bit_vec::BitVec;
use rspirv::dr;
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::{cell::RefCell, ptr};
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvType {
@ -57,23 +57,8 @@ impl TypeWordMap {
}
}
struct IdWordMap<'a>(HashMap<&'a str, spirv::Word>);
impl<'a> IdWordMap<'a> {
fn new() -> Self {
IdWordMap(HashMap::new())
}
}
impl<'a> IdWordMap<'a> {
fn get_or_add(&mut self, b: &mut dr::Builder, id: &'a str) -> spirv::Word {
*self.0.entry(id).or_insert_with(|| b.id())
}
}
pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
let mut builder = dr::Builder::new();
let mut ids = IdWordMap::new();
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 0);
emit_capabilities(&mut builder);
@ -82,7 +67,7 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
for f in ast.functions {
emit_function(&mut builder, &mut map, &mut ids, f)?;
emit_function(&mut builder, &mut map, f)?;
}
Ok(vec![])
}
@ -111,9 +96,8 @@ fn emit_memory_model(builder: &mut dr::Builder) {
fn emit_function<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
ids: &mut IdWordMap<'a>,
f: ast::Function<'a>,
) -> Result<(), rspirv::dr::Error> {
) -> Result<spirv::Word, rspirv::dr::Error> {
let func_id = builder.begin_function(
map.void(),
None,
@ -128,15 +112,19 @@ fn emit_function<'a>(
let bbs = get_basic_blocks(&normalized_ids);
let rpostorder = to_reverse_postorder(&bbs);
let dom_fronts = dominance_frontiers(&bbs, &rpostorder);
let ssa = ssa_legalize(normalized_ids, dom_fronts);
emit_function_body_ops(ssa, builder);
let (ops, phis) = ssa_legalize(normalized_ids, bbs, &dom_fronts);
emit_function_body_ops(builder, ops, phis);
builder.ret()?;
builder.end_function()?;
Ok(())
Ok(func_id)
}
fn emit_function_body_ops(ssa: Vec<Statement>, builder: &mut dr::Builder) {
unimplemented!()
fn emit_function_body_ops(
builder: &mut dr::Builder,
ops: Vec<Statement>,
phis: Vec<RefCell<PhiBasicBlock>>,
) {
todo!()
}
// TODO: support scopes
@ -158,8 +146,47 @@ fn normalize_identifiers<'a>(func: Vec<ast::Statement<&'a str>>) -> Vec<Statemen
result
}
fn ssa_legalize(func: Vec<Statement>, dom_fronts: Vec<HashSet<BBIndex>>) -> Vec<Statement> {
unimplemented!()
fn ssa_legalize(
func: Vec<Statement>,
bbs: Vec<BasicBlock>,
dom_fronts: &Vec<HashSet<BBIndex>>,
) -> (Vec<Statement>, Vec<RefCell<PhiBasicBlock>>) {
let mut phis = gather_phi_sets(&func, &bbs, dom_fronts);
trim_singleton_phi_sets(&mut phis);
todo!()
}
fn gather_phi_sets(
func: &Vec<Statement>,
bbs: &Vec<BasicBlock>,
dom_fronts: &Vec<HashSet<BBIndex>>,
) -> Vec<HashMap<spirv::Word, HashSet<BBIndex>>> {
let mut phis = vec![HashMap::new(); bbs.len()];
for (bb_idx, bb) in bbs.iter().enumerate() {
let StmtIndex(start) = bb.start;
let end = if bb_idx == bbs.len() - 1 {
bbs.len()
} else {
bbs[bb_idx + 1].start.0
};
for s in func[start..end].iter() {
s.for_dst_id(&mut |id| {
for BBIndex(phi_target) in dom_fronts[bb_idx].iter() {
phis[*phi_target]
.entry(id)
.or_insert_with(|| HashSet::new())
.insert(BBIndex(bb_idx));
}
});
}
}
phis
}
fn trim_singleton_phi_sets(phis: &mut Vec<HashMap<spirv::Word, HashSet<BBIndex>>>) {
for phi_map in phis.iter_mut() {
phi_map.retain(|_, set| set.len() > 1);
}
}
fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
@ -179,7 +206,6 @@ fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
Statement::Label(id) => {
labels.insert(id, StmtIndex(idx));
}
Statement::Phi(_) => (),
};
}
let mut bbs_map = BTreeMap::new();
@ -322,10 +348,10 @@ fn to_reverse_postorder(input: &Vec<BasicBlock>) -> Vec<BBIndex> {
result
}
#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)]
struct StmtIndex(pub usize);
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)]
struct BBIndex(pub usize);
struct PhiBasicBlock {
bb: BasicBlock,
phi: Vec<(spirv::Word, Vec<(spirv::Word, BBIndex)>)>,
}
#[derive(Eq, PartialEq, Debug, Clone)]
struct BasicBlock {
@ -334,10 +360,17 @@ struct BasicBlock {
succ: Vec<BBIndex>,
}
#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)]
struct StmtIndex(pub usize);
#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)]
struct BBIndex(pub usize);
enum Statement {
Label(u32),
Instruction(Option<ast::PredAt<u32>>, ast::Instruction<u32>),
Phi(Vec<spirv::Word>),
Instruction(
Option<ast::PredAt<spirv::Word>>,
ast::Instruction<spirv::Word>,
),
}
impl Statement {
@ -353,6 +386,16 @@ impl Statement {
ast::Statement::Variable(_) => None,
}
}
fn for_dst_id<F: FnMut(spirv::Word)>(&self, f: &mut F) {
match self {
Statement::Label(id) => f(*id),
Statement::Instruction(pred, inst) => {
pred.as_ref().map(|p| p.for_dst_id(f));
inst.for_dst_id(f);
}
}
}
}
impl<T> ast::PredAt<T> {
@ -364,6 +407,10 @@ impl<T> ast::PredAt<T> {
}
}
impl<T: Copy> ast::PredAt<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {}
}
impl<T> ast::Instruction<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> {
match self {
@ -387,7 +434,7 @@ impl<T> ast::Instruction<T> {
impl<T: Copy> ast::Instruction<T> {
fn jump_target(&self) -> Option<T> {
match self {
ast::Instruction::Bra(d, a) => Some(a.dst),
ast::Instruction::Bra(_, a) => Some(a.dst),
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)
| ast::Instruction::Mul(_, _)
@ -402,6 +449,24 @@ impl<T: Copy> ast::Instruction<T> {
| ast::Instruction::Ret(_) => None,
}
}
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
match self {
ast::Instruction::Bra(_, a) => a.for_dst_id(f),
ast::Instruction::Ld(_, a) => a.for_dst_id(f),
ast::Instruction::Mov(_, a) => a.for_dst_id(f),
ast::Instruction::Mul(_, a) => a.for_dst_id(f),
ast::Instruction::Add(_, a) => a.for_dst_id(f),
ast::Instruction::Setp(_, a) => a.for_dst_id(f),
ast::Instruction::SetpBool(_, a) => a.for_dst_id(f),
ast::Instruction::Not(_, a) => a.for_dst_id(f),
ast::Instruction::Cvt(_, a) => a.for_dst_id(f),
ast::Instruction::Shl(_, a) => a.for_dst_id(f),
ast::Instruction::St(_, a) => a.for_dst_id(f),
ast::Instruction::At(_, a) => a.for_dst_id(f),
ast::Instruction::Ret(_) => (),
}
}
}
impl<T> ast::Arg1<T> {
@ -410,6 +475,12 @@ impl<T> ast::Arg1<T> {
}
}
impl<T: Copy> ast::Arg1<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
f(self.dst)
}
}
impl<T> ast::Arg2<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2<U> {
ast::Arg2 {
@ -419,6 +490,12 @@ impl<T> ast::Arg2<T> {
}
}
impl<T: Copy> ast::Arg2<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
f(self.dst);
}
}
impl<T> ast::Arg2Mov<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2Mov<U> {
ast::Arg2Mov {
@ -428,6 +505,12 @@ impl<T> ast::Arg2Mov<T> {
}
}
impl<T: Copy> ast::Arg2Mov<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
f(self.dst);
}
}
impl<T> ast::Arg3<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg3<U> {
ast::Arg3 {
@ -438,6 +521,12 @@ impl<T> ast::Arg3<T> {
}
}
impl<T: Copy> ast::Arg3<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
f(self.dst);
}
}
impl<T> ast::Arg4<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg4<U> {
ast::Arg4 {
@ -449,6 +538,13 @@ impl<T> ast::Arg4<T> {
}
}
impl<T: Copy> ast::Arg4<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
f(self.dst1);
self.dst2.map(|t| f(t));
}
}
impl<T> ast::Arg5<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg5<U> {
ast::Arg5 {
@ -461,6 +557,13 @@ impl<T> ast::Arg5<T> {
}
}
impl<T: Copy> ast::Arg5<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
f(self.dst1);
self.dst2.map(|t| f(t));
}
}
impl<T> ast::Operand<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Operand<U> {
match self {
@ -471,6 +574,12 @@ impl<T> ast::Operand<T> {
}
}
impl<T: Copy> ast::Operand<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
unreachable!()
}
}
impl<T> ast::MovOperand<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::MovOperand<U> {
match self {
@ -480,6 +589,15 @@ impl<T> ast::MovOperand<T> {
}
}
impl<T: Copy> ast::MovOperand<T> {
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
match self {
ast::MovOperand::Op(o) => o.for_dst_id(f),
ast::MovOperand::Vec(_, _) => (),
}
}
}
// CFGs below taken from "Modern Compiler Implementation in Java"
#[cfg(test)]
mod tests {