From 3b433456a1428a423f7f5ec8aaa3e926eb9eea99 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 3 May 2020 23:53:56 +0200 Subject: [PATCH] Convert PTX predicates to a distinct conditional jump statement --- Cargo.toml | 6 +- ptx/src/translate.rs | 249 ++++++++++++++++++++++++++++++------------- 2 files changed, 178 insertions(+), 77 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e311ea7..5106d33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,4 +6,8 @@ members = [ "notcuda_inject", "notcuda_redirect", "ptx", -] \ No newline at end of file +] + +[patch.crates-io] +rspirv = { git = 'https://github.com/vosen/rspirv', branch = 'notcuda' } +spirv_headers = { git = 'https://github.com/vosen/rspirv', branch = 'notcuda' } \ No newline at end of file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a0e8405..52de35d 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -113,7 +113,7 @@ fn emit_function<'a>( let rpostorder = to_reverse_postorder(&bbs); let doms = immediate_dominators(&bbs, &rpostorder); let dom_fronts = dominance_frontiers(&bbs, &doms); - ssa_legalize( + let phis = ssa_legalize( &mut normalized_ids, contant_ids.len() as u32, unique_ids, @@ -121,12 +121,35 @@ fn emit_function<'a>( &doms, &dom_fronts, ); - emit_function_body_ops(builder); + let id_offset = builder.reserve_ids(unique_ids); + emit_function_args(builder, id_offset, map, &f.args); + emit_function_body_ops(builder, id_offset, &normalized_ids, &bbs)?; + builder.end_function()?; builder.ret()?; builder.end_function()?; Ok(func_id) } +fn emit_function_args( + builder: &mut dr::Builder, + id_offset: spirv::Word, + map: &mut TypeWordMap, + args: &[ast::Argument], +) { + let mut id = id_offset; + for arg in args { + let result_type = map.get_or_add(builder, SpirvType::Base(arg.a_type)); + let inst = dr::Instruction::new( + spirv::Op::FunctionParameter, + Some(result_type), + Some(id), + Vec::new(), + ); + builder.function.as_mut().unwrap().parameters.push(inst); + id 
+= 1; + } +} + fn collect_arg_ids<'a>(result: &mut HashMap<&'a str, spirv::Word>, args: &'a [ast::Argument<'a>]) { let mut id = result.len() as u32; for arg in args { @@ -152,13 +175,35 @@ fn collect_label_ids<'a>( } } -fn emit_function_body_ops(builder: &mut dr::Builder) { - todo!() +fn emit_function_body_ops( + builder: &mut dr::Builder, + id_offset: spirv::Word, + func: &[Statement], + cfg: &[BasicBlock], +) -> Result<(), dr::Error> { + for bb_idx in 0..cfg.len() { + let body = get_bb_body(func, cfg, BBIndex(bb_idx)); + if body.len() == 0 { + continue; + } + let header_id = if let Statement::Label(id) = body[0] { + Some(id_offset + id) + } else { + None + }; + builder.begin_block(header_id)?; + for s in body { + /* + match s { + Statement::Instruction(pred, inst) => (), + Statement::Label(_) => (), + } + */ + } + } + Ok(()) } -// This functions converts string identifiers to numeric identifiers in a normalized form, where -// - identifiers in the range [0..constant_identifiers.len()) are arguments and labels -// - identifiers in the range [constant_identifiers.len()..result.1) are variables // TODO: support scopes fn normalize_identifiers<'a>( func: Vec>, @@ -167,8 +212,8 @@ fn normalize_identifiers<'a>( let mut result = Vec::with_capacity(func.len()); let mut id: u32 = constant_identifiers.len() as u32; let mut remapped_ids = HashMap::new(); - let mut get_or_add = |key| { - constant_identifiers.get(key).map_or_else( + let mut get_or_add = |key| match key { + Some(key) => constant_identifiers.get(key).map_or_else( || { *remapped_ids.entry(key).or_insert_with(|| { let to_insert = id; @@ -177,12 +222,15 @@ fn normalize_identifiers<'a>( }) }, |id| *id, - ) + ), + None => { + let to_insert = id; + id += 1; + to_insert + } }; for s in func { - if let Some(s) = Statement::from_ast(s, &mut get_or_add) { - result.push(s); - } + Statement::from_ast(s, &mut result, &mut get_or_add); } (result, id) } @@ -195,13 +243,7 @@ fn ssa_legalize( doms: &[BBIndex], dom_fronts: 
&[HashSet], ) -> Vec> { - let phis = gather_phi_sets( - &func, - constant_ids, - unique_ids, - &bbs, - dom_fronts, - ); + let phis = gather_phi_sets(&func, constant_ids, unique_ids, &bbs, dom_fronts); apply_ssa_renaming(func, &bbs, doms, constant_ids, unique_ids, &phis) } @@ -431,12 +473,12 @@ fn gather_phi_sets( } } }; + // We try to avoid adding labels to the global-visibility set. + // We are not 100% precise (we add jump targets in bra), but it shouldn't be a problem for s in get_bb_body(func, cfg, BBIndex(bb)) { match s { - Statement::Instruction(pred, inst) => { - pred.as_ref().map(|p| p.visit_id(&mut visitor)); - inst.visit_id(&mut visitor); - } + Statement::Instruction(inst) => inst.visit_id(&mut visitor), + Statement::Conditional(brc) => visitor(false, &brc.predicate), // label redefinition is a compile-time error Statement::Label(_) => (), } @@ -464,20 +506,16 @@ fn gather_phi_sets( fn get_basic_blocks(fun: &[Statement]) -> Vec { // edge signify pred/succ relationship between bbs - let mut bb_edge = HashSet::new(); let mut unresolved_bb_edge = Vec::new(); // bb start means that a bb is starting at this statement, but there's no predecessor let mut bb_start = Vec::new(); let mut labels = HashMap::new(); for (idx, s) in fun.iter().enumerate() { match s { - Statement::Instruction(pred, i) => { + Statement::Instruction(i) => { if let Some(id) = i.jump_target() { unresolved_bb_edge.push((StmtIndex(idx), id)); if idx + 1 < fun.len() { - if pred.is_some() { - bb_edge.insert((StmtIndex(idx), StmtIndex(idx + 1))); - } bb_start.push(StmtIndex(idx + 1)); } } else if i.is_terminal() && idx + 1 < fun.len() { @@ -487,23 +525,32 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec { Statement::Label(id) => { labels.insert(id, StmtIndex(idx)); } + Statement::Conditional(bra) => { + unresolved_bb_edge.push((StmtIndex(idx), bra.if_false)); + unresolved_bb_edge.push((StmtIndex(idx), bra.if_true)); + } }; } + let mut bb_edge = HashSet::new(); // Resolve every into // TODO: handle
jumps into nowhere for (idx, id) in unresolved_bb_edge { let target = labels[&id]; bb_edge.insert((idx, target)); bb_start.push(target); - // now check if the preceding statement forms an edge + // now check if there is an edge target-1 -> target if target != StmtIndex(0) { match &fun[target.0 - 1] { - Statement::Instruction(pred, i) => { - if !((pred.is_none() && i.jump_target().is_some()) || i.is_terminal()) { + Statement::Instruction(i) => { + if !(i.jump_target().is_some() || i.is_terminal()) { bb_edge.insert((StmtIndex(target.0 - 1), target)); } } - Statement::Label(_) => (), + Statement::Label(_) => { + bb_edge.insert((StmtIndex(target.0 - 1), target)); + } + // This is already in `unresolved_bb_edge` + Statement::Conditional(_) => (), } } } @@ -665,43 +712,88 @@ impl fmt::Display for BBIndex { enum Statement { Label(u32), - Instruction( - Option>, - ast::Instruction, - ), + Instruction(ast::Instruction), + // SPIR-V compatible replacement for PTX predicates + Conditional(BrachCondition), +} + +struct BrachCondition { + predicate: spirv::Word, + if_true: spirv::Word, + if_false: spirv::Word, +} +impl BrachCondition { + fn visit_id(&self, f: &mut F) { + f(false, &self.predicate); + f(false, &self.if_true); + f(false, &self.if_false); + } + + fn visit_id_mut(&mut self, f: &mut F) { + f(false, &mut self.predicate); + f(false, &mut self.if_true); + f(false, &mut self.if_false); + } } impl Statement { - fn from_ast<'a, F: FnMut(&'a str) -> u32>( + fn from_ast<'a, F: FnMut(Option<&'a str>) -> u32>( s: ast::Statement<&'a str>, - f: &mut F, - ) -> Option { + out: &mut Vec, + get_id: &mut F, + ) { match s { - ast::Statement::Label(name) => Some(Statement::Label(f(name))), + ast::Statement::Label(name) => out.push(Statement::Label(get_id(Some(name)))), ast::Statement::Instruction(p, i) => { - Some(Statement::Instruction(p.map(|p| p.map_id(f)), i.map_id(f))) + if let Some(pred) = p { + let predicate = get_id(Some(pred.label)); + let mut if_true = get_id(None); + let mut 
if_false = get_id(None); + if pred.not { + std::mem::swap(&mut if_true, &mut if_false); + } + let folded_bra = match &i { + ast::Instruction::Bra(_, arg) => Some(get_id(Some(arg.src))), + _ => None, + }; + let branch = BrachCondition { + predicate, + if_true: folded_bra.unwrap_or(if_true), + if_false, + }; + out.push(Statement::Conditional(branch)); + if folded_bra.is_none() { + out.push(Statement::Label(if_true)); + out.push(Statement::Instruction( + i.map_id(&mut |name| get_id(Some(name))), + )); + } + out.push(Statement::Label(if_false)); + } else { + out.push(Statement::Instruction( + i.map_id(&mut |name| get_id(Some(name))), + )); + } } - ast::Statement::Variable(_) => None, + ast::Statement::Variable(_) => (), } } fn for_dst_id(&self, f: &mut F) { match self { Statement::Label(id) => f(*id), - Statement::Instruction(pred, inst) => { - pred.as_ref().map(|p| p.for_dst_id(f)); + Statement::Instruction(inst) => { inst.for_dst_id(f); } + Statement::Conditional(_) => (), } } fn visit_id(&self, f: &mut F) { match self { - Statement::Label(id) => f(true, id), - Statement::Instruction(pred, inst) => { - pred.as_ref().map(|p| p.visit_id(f)); - inst.visit_id(f); - } + Statement::Label(id) => f(false, id), + Statement::Instruction(inst) => inst.visit_id(f), + Statement::Conditional(bra) => bra.visit_id(f), } } @@ -709,11 +801,9 @@ impl Statement { // otherwise SSA renaming will yield weird results fn visit_id_mut(&mut self, f: &mut F) { match self { - Statement::Label(id) => f(true, id), - Statement::Instruction(pred, inst) => { - pred.as_mut().map(|p| p.visit_id_mut(f)); - inst.visit_id_mut(f); - } + Statement::Label(id) => f(false, id), + Statement::Instruction(inst) => inst.visit_id_mut(f), + Statement::Conditional(bra) => bra.visit_id_mut(f), } } } @@ -1182,10 +1272,10 @@ mod tests { fn get_basic_blocks_miniloop() { let func = vec![ Statement::Label(12), - Statement::Instruction( - None, - ast::Instruction::Bra(ast::BraData {}, ast::Arg1 { src: 12 }), - ), + 
Statement::Instruction(ast::Instruction::Bra( + ast::BraData {}, + ast::Arg1 { src: 12 }, + )), ]; let bbs = get_basic_blocks(&func); assert_eq!( @@ -1390,11 +1480,11 @@ mod tests { mov.u32 k, 0; block_2: setp.ge.u32 p, k, 100; - @p bra block_4; - block_3: + @p bra block_4; // conditional p block_4 if_false1 + // if_false1: setp.ge.u32 q, j, 20; - @q bra block_6; - block_5: + @q bra block_6; // conditional q block_6 if_false2 + // if_false2: mov.u32 j, i; add.u32 k, k, 1; bra block_7; @@ -1563,7 +1653,7 @@ mod tests { assert_eq!(errors.len(), 0); let mut constant_ids = HashMap::new(); collect_label_ids(&mut constant_ids, &fn_ast); - assert_eq!(constant_ids.len(), 6); + assert_eq!(constant_ids.len(), 4); let (normalized_ids, max_id) = normalize_identifiers(fn_ast, &constant_ids); let bbs = get_basic_blocks(&normalized_ids); let rpostorder = to_reverse_postorder(&bbs); @@ -1580,11 +1670,11 @@ mod tests { phi, vec![ HashSet::new(), - to_hashset(vec![7, 8]), + to_hashset(vec![5, 6]), HashSet::new(), HashSet::new(), HashSet::new(), - to_hashset(vec![7, 8]), + to_hashset(vec![5, 6]), HashSet::new() ] ); @@ -1634,12 +1724,12 @@ mod tests { let k5 = get_dst(&func[15]); let p1 = get_dst(&func[4]); let q1 = get_dst(&func[7]); - let block_2 = get_dst(&func[3]); - let block_3 = get_dst(&func[6]); - let block_5 = get_dst(&func[9]); - let block_6 = get_dst(&func[13]); - let block_7 = get_dst(&func[16]); - let block_4 = get_dst(&func[18]); + let block_2 = get_label(&func[3]); + let if_false1 = get_label(&func[6]); + let if_false2 = get_label(&func[9]); + let block_6 = get_label(&func[13]); + let block_7 = get_label(&func[16]); + let block_4 = get_label(&func[18]); { assert_eq!(get_ids(&func[0]), vec![i1]); @@ -1652,13 +1742,13 @@ mod tests { ); assert_eq!(get_ids(&func[3]), vec![block_2]); assert_eq!(get_ids(&func[4]), vec![p1, k2]); - assert_eq!(get_ids(&func[5]), vec![p1, block_4]); + assert_eq!(get_ids(&func[5]), vec![p1, block_4, if_false1]); - assert_eq!(get_ids(&func[6]), 
vec![block_3]); + assert_eq!(get_ids(&func[6]), vec![if_false1]); assert_eq!(get_ids(&func[7]), vec![q1, j2]); - assert_eq!(get_ids(&func[8]), vec![q1, block_6]); + assert_eq!(get_ids(&func[8]), vec![q1, block_6, if_false2]); - assert_eq!(get_ids(&func[9]), vec![block_5]); + assert_eq!(get_ids(&func[9]), vec![if_false2]); assert_eq!(get_ids(&func[10]), vec![j3, i1]); assert_eq!(get_ids(&func[11]), vec![k3, k2]); assert_eq!(get_ids(&func[12]), vec![block_7]); @@ -1739,6 +1829,13 @@ mod tests { result.unwrap() } + fn get_label(s: &Statement) -> spirv::Word { + match s { + Statement::Label(id) => *id, + _ => panic!(), + } + } + fn get_dst_from_src(phi: &[PhiDef], src: spirv::Word) -> spirv::Word { for phi_set in phi { if phi_set.src.contains(&src) {