From 7b2bc69330f2043791db01f96a4daf8198116503 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 22 Apr 2020 00:55:49 +0200 Subject: [PATCH] Start doing SSA conversion --- ptx/src/ast.rs | 3 - ptx/src/lib.rs | 1 + ptx/src/test/mod.rs | 3 +- ptx/src/translate.rs | 186 +++++++++++++++++++++++++++++++++++-------- 4 files changed, 155 insertions(+), 38 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index a7bbe1f..9089c01 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,7 +1,4 @@ use std::convert::From; -use std::convert::Into; -use std::error::Error; -use std::mem; use std::num::ParseIntError; quick_error! { diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 61c3444..f8bb7fd 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -9,6 +9,7 @@ extern crate spirv_headers as spirv; lalrpop_mod!(ptx); +#[cfg(test)] mod test; mod translate; pub mod ast; diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index e12097a..15876ad 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -2,7 +2,7 @@ use super::ptx; fn parse_and_assert(s: &str) { let mut errors = Vec::new(); - let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap(); + ptx::ModuleParser::new().parse(&mut errors, s).unwrap(); assert!(errors.len() == 0); } @@ -12,6 +12,7 @@ fn empty() { } #[test] +#[allow(non_snake_case)] fn vectorAdd_kernel64_ptx() { let vector_add = include_str!("vectorAdd_kernel64.ptx"); parse_and_assert(vector_add); diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 259bcd2..5584af5 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,8 +1,8 @@ use crate::ast; use bit_vec::BitVec; use rspirv::dr; +use std::cell::RefCell; use std::collections::{BTreeMap, HashMap, HashSet}; -use std::{cell::RefCell, ptr}; #[derive(PartialEq, Eq, Hash, Clone, Copy)] enum SpirvType { @@ -57,23 +57,8 @@ impl TypeWordMap { } } -struct IdWordMap<'a>(HashMap<&'a str, spirv::Word>); - -impl<'a> IdWordMap<'a> { - fn new() -> Self { - IdWordMap(HashMap::new()) - } -} - -impl<'a> IdWordMap<'a> { - fn get_or_add(&mut self, b: &mut dr::Builder, id: &'a str) -> spirv::Word { - *self.0.entry(id).or_insert_with(|| b.id()) - } -} - pub fn to_spirv(ast: ast::Module) -> Result, rspirv::dr::Error> { let mut builder = dr::Builder::new(); - let mut ids = IdWordMap::new(); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module builder.set_version(1, 0); emit_capabilities(&mut builder); @@ -82,7 +67,7 @@ pub fn to_spirv(ast: ast::Module) -> Result, rspirv::dr::Error> { emit_memory_model(&mut builder); let mut map = TypeWordMap::new(&mut builder); for f in ast.functions { - emit_function(&mut builder, &mut map, &mut ids, f)?; + emit_function(&mut builder, &mut map, f)?; } Ok(vec![]) } @@ -111,9 +96,8 @@ fn emit_memory_model(builder: &mut dr::Builder) { fn emit_function<'a>( builder: &mut dr::Builder, map: &mut TypeWordMap, - ids: &mut IdWordMap<'a>, f: ast::Function<'a>, -) -> Result<(), rspirv::dr::Error> { +) -> Result { let func_id = builder.begin_function( map.void(), None, @@ -128,15 +112,19 @@ fn emit_function<'a>( let bbs = get_basic_blocks(&normalized_ids); let rpostorder = to_reverse_postorder(&bbs); let dom_fronts = dominance_frontiers(&bbs, &rpostorder); - let ssa = ssa_legalize(normalized_ids, dom_fronts); - emit_function_body_ops(ssa, builder); + let (ops, phis) = ssa_legalize(normalized_ids, bbs, &dom_fronts); + emit_function_body_ops(builder, ops, phis); builder.ret()?; builder.end_function()?; - Ok(()) + Ok(func_id) } -fn emit_function_body_ops(ssa: Vec, builder: &mut dr::Builder) { - unimplemented!() +fn emit_function_body_ops( + builder: &mut dr::Builder, + ops: Vec, + phis: Vec>, +) { + todo!() } // TODO: support scopes @@ -158,8 +146,47 @@ fn normalize_identifiers<'a>(func: Vec>) -> Vec, dom_fronts: Vec>) -> Vec { - unimplemented!() +fn ssa_legalize( + func: Vec, + bbs: Vec, + dom_fronts: &Vec>, +) -> (Vec, Vec>) { + let mut phis = gather_phi_sets(&func, &bbs, dom_fronts); + trim_singleton_phi_sets(&mut phis); + todo!() +} + +fn gather_phi_sets( + func: &Vec, + bbs: &Vec, + dom_fronts: &Vec>, +) -> Vec>> { + let mut phis = vec![HashMap::new(); bbs.len()]; + for (bb_idx, bb) in bbs.iter().enumerate() { + let StmtIndex(start) = bb.start; + let end = if bb_idx == bbs.len() - 1 { + bbs.len() + } else { + bbs[bb_idx + 1].start.0 + }; + for s in func[start..end].iter() { + s.for_dst_id(&mut |id| { + for BBIndex(phi_target) in dom_fronts[bb_idx].iter() { + phis[*phi_target] + .entry(id) + .or_insert_with(|| HashSet::new()) + .insert(BBIndex(bb_idx)); + } + }); + } + } + phis +} + +fn trim_singleton_phi_sets(phis: &mut Vec>>) { + for phi_map in phis.iter_mut() { + phi_map.retain(|_, set| set.len() > 1); + } } fn get_basic_blocks(fun: &Vec) -> Vec { @@ -179,7 +206,6 @@ fn get_basic_blocks(fun: &Vec) -> Vec { Statement::Label(id) => { labels.insert(id, StmtIndex(idx)); } - Statement::Phi(_) => (), }; } let mut bbs_map = BTreeMap::new(); @@ -322,10 +348,10 @@ fn to_reverse_postorder(input: &Vec) -> Vec { result } -#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)] -struct StmtIndex(pub usize); -#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)] -struct BBIndex(pub usize); +struct PhiBasicBlock { + bb: BasicBlock, + phi: Vec<(spirv::Word, Vec<(spirv::Word, BBIndex)>)>, +} #[derive(Eq, PartialEq, Debug, Clone)] struct BasicBlock { @@ -334,10 +360,17 @@ struct BasicBlock { succ: Vec, } +#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)] +struct StmtIndex(pub usize); +#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)] +struct BBIndex(pub usize); + enum Statement { Label(u32), - Instruction(Option>, ast::Instruction), - Phi(Vec), + Instruction( + Option>, + ast::Instruction, + ), } impl Statement { @@ -353,6 +386,16 @@ impl Statement { ast::Statement::Variable(_) => None, } } + + fn for_dst_id(&self, f: &mut F) { + match self { + Statement::Label(id) => f(*id), + Statement::Instruction(pred, inst) => { + pred.as_ref().map(|p| p.for_dst_id(f)); + inst.for_dst_id(f); + } + } + } } impl ast::PredAt { @@ -364,6 +407,10 @@ impl ast::PredAt { } } +impl ast::PredAt { + fn for_dst_id(&self, f: &mut F) {} +} + impl ast::Instruction { fn map_id U>(self, f: &mut F) -> ast::Instruction { match self { @@ -387,7 +434,7 @@ impl ast::Instruction { impl ast::Instruction { fn jump_target(&self) -> Option { match self { - ast::Instruction::Bra(d, a) => Some(a.dst), + ast::Instruction::Bra(_, a) => Some(a.dst), ast::Instruction::Ld(_, _) | ast::Instruction::Mov(_, _) | ast::Instruction::Mul(_, _) @@ -402,6 +449,24 @@ impl ast::Instruction { | ast::Instruction::Ret(_) => None, } } + + fn for_dst_id(&self, f: &mut F) { + match self { + ast::Instruction::Bra(_, a) => a.for_dst_id(f), + ast::Instruction::Ld(_, a) => a.for_dst_id(f), + ast::Instruction::Mov(_, a) => a.for_dst_id(f), + ast::Instruction::Mul(_, a) => a.for_dst_id(f), + ast::Instruction::Add(_, a) => a.for_dst_id(f), + ast::Instruction::Setp(_, a) => a.for_dst_id(f), + ast::Instruction::SetpBool(_, a) => a.for_dst_id(f), + ast::Instruction::Not(_, a) => a.for_dst_id(f), + ast::Instruction::Cvt(_, a) => a.for_dst_id(f), + ast::Instruction::Shl(_, a) => a.for_dst_id(f), + ast::Instruction::St(_, a) => a.for_dst_id(f), + ast::Instruction::At(_, a) => a.for_dst_id(f), + ast::Instruction::Ret(_) => (), + } + } } impl ast::Arg1 { @@ -410,6 +475,12 @@ impl ast::Arg1 { } } +impl ast::Arg1 { + fn for_dst_id(&self, f: &mut F) { + f(self.dst) + } +} + impl ast::Arg2 { fn map_id U>(self, f: &mut F) -> ast::Arg2 { ast::Arg2 { @@ -419,6 +490,12 @@ impl ast::Arg2 { } } +impl ast::Arg2 { + fn for_dst_id(&self, f: &mut F) { + f(self.dst); + } +} + impl ast::Arg2Mov { fn map_id U>(self, f: &mut F) -> ast::Arg2Mov { ast::Arg2Mov { @@ -428,6 +505,12 @@ impl ast::Arg2Mov { } } +impl ast::Arg2Mov { + fn for_dst_id(&self, f: &mut F) { + f(self.dst); + } +} + impl ast::Arg3 { fn map_id U>(self, f: &mut F) -> ast::Arg3 { ast::Arg3 { @@ -438,6 +521,12 @@ impl ast::Arg3 { } } +impl ast::Arg3 { + fn for_dst_id(&self, f: &mut F) { + f(self.dst); + } +} + impl ast::Arg4 { fn map_id U>(self, f: &mut F) -> ast::Arg4 { ast::Arg4 { @@ -449,6 +538,13 @@ impl ast::Arg4 { } } +impl ast::Arg4 { + fn for_dst_id(&self, f: &mut F) { + f(self.dst1); + self.dst2.map(|t| f(t)); + } +} + impl ast::Arg5 { fn map_id U>(self, f: &mut F) -> ast::Arg5 { ast::Arg5 { @@ -461,6 +557,13 @@ impl ast::Arg5 { } } +impl ast::Arg5 { + fn for_dst_id(&self, f: &mut F) { + f(self.dst1); + self.dst2.map(|t| f(t)); + } +} + impl ast::Operand { fn map_id U>(self, f: &mut F) -> ast::Operand { match self { @@ -471,6 +574,12 @@ impl ast::Operand { } } +impl ast::Operand { + fn for_dst_id(&self, f: &mut F) { + unreachable!() + } +} + impl ast::MovOperand { fn map_id U>(self, f: &mut F) -> ast::MovOperand { match self { @@ -480,6 +589,15 @@ impl ast::MovOperand { } } +impl ast::MovOperand { + fn for_dst_id(&self, f: &mut F) { + match self { + ast::MovOperand::Op(o) => o.for_dst_id(f), + ast::MovOperand::Vec(_, _) => (), + } + } +} + // CFGs below taken from "Modern Compiler Implementation in Java" #[cfg(test)] mod tests {