From 09be47a9193118f03f08f2c88f9294b679e888e1 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 23 Jul 2020 01:26:40 +0200 Subject: [PATCH] Start refactoring code to not use homemade strict-SSA translator --- ptx/src/ast.rs | 1 + ptx/src/translate.rs | 2156 +++++++++++------------------------------- 2 files changed, 559 insertions(+), 1598 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 0efc37c..979bedf 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -112,6 +112,7 @@ pub struct Variable { pub count: Option, } +#[derive(Copy, Clone, PartialEq, Eq)] pub enum StateSpace { Reg, Sreg, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 0d86066..7cce63c 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -151,16 +151,16 @@ fn emit_function<'a>( if f.kernel { builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[]); } - let (mut func_body, bbs, _, unique_ids) = to_ssa(&f.args, f.body); + let (mut func_body, unique_ids) = to_ssa(&f.args, f.body); let id_offset = builder.reserve_ids(unique_ids); emit_function_args(builder, id_offset, map, &f.args); apply_id_offset(&mut func_body, id_offset); - emit_function_body_ops(builder, map, &func_body, &bbs)?; + emit_function_body_ops(builder, map, &func_body)?; builder.end_function()?; Ok(func_id) } -fn apply_id_offset(func_body: &mut Vec, id_offset: u32) { +fn apply_id_offset(func_body: &mut Vec, id_offset: u32) { for s in func_body { s.visit_id_mut(&mut |_, id| *id += id_offset); } @@ -169,61 +169,27 @@ fn apply_id_offset(func_body: &mut Vec, id_offset: u32) { fn to_ssa<'a>( f_args: &[ast::Argument], f_body: Vec>, -) -> ( - Vec, - Vec, - Vec>, - spirv::Word, -) { - let mut contant_ids = HashMap::new(); - let mut type_check = HashMap::new(); - collect_arg_ids(&mut contant_ids, &mut type_check, &f_args); - collect_label_ids(&mut contant_ids, &f_body); - let registers = collect_var_definitions(&f_args, &f_body); - let (normalized_ids, mut unique_ids) = - normalize_identifiers(f_body, &contant_ids, &mut type_check, registers); - let type_check = RefCell::new(type_check); - let new_id = &mut |typ: Option| { - let to_insert = unique_ids; - { - let mut type_check = type_check.borrow_mut(); - typ.map(|t| (*type_check).insert(to_insert, t)); - } - unique_ids += 1; - to_insert - }; - let normalized_stmts = normalize_statements(normalized_ids, new_id); - let mut func_body = insert_implicit_conversions(normalized_stmts, new_id, &|x| { - let type_check = type_check.borrow(); - type_check[&x] - }); - let bbs = get_basic_blocks(&func_body); - let rpostorder = to_reverse_postorder(&bbs); - let doms = immediate_dominators(&bbs, &rpostorder); - let dom_fronts = dominance_frontiers(&bbs, &doms); - let (phis, unique_ids) = ssa_legalize( - &mut func_body, - contant_ids.len() as u32, - unique_ids, - &bbs, - &doms, - &dom_fronts, - ); - (func_body, bbs, phis, unique_ids) +) -> (Vec, spirv::Word) { + let (normalized_ids, mut id_def) = normalize_identifiers(&f_args, f_body); + let normalized_statements = normalize_predicates(normalized_ids, &mut id_def); + let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut id_def); + let expanded_statements = expand_arguments(ssa_statements, &mut id_def); + let expanded_statements = insert_implicit_conversions(expanded_statements, &mut id_def); + (expanded_statements, id_def.ids_count()) } -fn normalize_statements( +fn normalize_predicates( func: Vec>, - new_id: &mut impl FnMut(Option) -> spirv::Word, -) -> Vec { + id_def: &mut NumericIdResolver, +) -> Vec> { let mut result = Vec::with_capacity(func.len()); for s in func { match s { ast::Statement::Label(id) => result.push(Statement::Label(id)), ast::Statement::Instruction(pred, inst) => { if let Some(pred) = pred { - let mut if_true = new_id(None); - let mut if_false = new_id(None); + let mut if_true = id_def.new_id(None); + let mut if_false = id_def.new_id(None); if pred.not { std::mem::swap(&mut if_true, &mut if_false); } @@ -239,16 +205,82 @@ fn normalize_statements( result.push(Statement::Conditional(branch)); if folded_bra.is_none() { result.push(Statement::Label(if_true)); - let instr = normalize_insert_instruction(&mut result, new_id, inst); - result.push(Statement::Instruction(instr)); + result.push(Statement::Instruction(Instruction::from_ast(inst))); } result.push(Statement::Label(if_false)); } else { - let instr = normalize_insert_instruction(&mut result, new_id, inst); - result.push(Statement::Instruction(instr)); + result.push(Statement::Instruction(Instruction::from_ast(inst))); } } - ast::Statement::Variable(_) => unreachable!(), + ast::Statement::Variable(var) => result.push(Statement::Variable(var.name, var.v_type)), + } + } + result +} + +fn insert_mem_ssa_statements( + func: Vec>, + id_def: &mut NumericIdResolver, +) -> Vec> { + let mut result = Vec::with_capacity(func.len()); + for s in func { + match s { + Statement::Instruction(mut inst) => { + let inst_type = inst.get_type(); + let mut post_statements = Vec::new(); + inst.visit_id_mut(&mut |is_dst, id| { + let inst_type = inst_type.unwrap(); + let generated_id = id_def.new_id(Some(inst_type)); + if !is_dst { + result.push(Statement::LoadVar( + Arg2 { + dst: generated_id, + src: *id, + }, + inst_type, + )); + } else { + post_statements.push(Statement::StoreVar( + Arg2St { + src1: *id, + src2: generated_id, + }, + inst_type, + )); + } + *id = generated_id; + }); + result.push(Statement::Instruction(inst)); + result.append(&mut post_statements); + } + s @ Statement::Variable(_, _) + | s @ Statement::Label(_) + | s @ Statement::Conditional(_) => result.push(s), + Statement::LoadVar(_, _) + | Statement::StoreVar(_, _) + | Statement::Converison(_) + | Statement::Constant(_) => unreachable!(), + } + } + result +} + +fn expand_arguments( + func: Vec>, + id_def: &mut NumericIdResolver, +) -> Vec { + let mut result = Vec::with_capacity(func.len()); + for s in func { + match s { + Statement::Instruction(inst) => { + normalize_insert_instruction(&mut result, id_def, inst); + } + Statement::Variable(id, typ) => result.push(Statement::Variable(id, typ)), + Statement::Label(id) => result.push(Statement::Label(id)), + Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), + Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), + Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)), + Statement::Converison(_) | Statement::Constant(_) => unreachable!(), } } result @@ -256,137 +288,137 @@ fn normalize_statements( #[must_use] fn normalize_insert_instruction( - func: &mut Vec, - new_id: &mut impl FnMut(Option) -> spirv::Word, - instr: ast::Instruction, -) -> Instruction { + func: &mut Vec, + id_def: &mut NumericIdResolver, + instr: Instruction, +) -> Instruction { match instr { - ast::Instruction::Ld(d, a) => { - let arg = normalize_expand_arg2(func, new_id, &|| Some(d.typ), a); + Instruction::Ld(d, a) => { + let arg = normalize_expand_arg2(func, id_def, &|| Some(d.typ), a); Instruction::Ld(d, arg) } - ast::Instruction::Mov(d, a) => { - let arg = normalize_expand_arg2mov(func, new_id, &|| d.typ.try_as_scalar(), a); + Instruction::Mov(d, a) => { + let arg = normalize_expand_arg2mov(func, id_def, &|| d.typ.try_as_scalar(), a); Instruction::Mov(d, arg) } - ast::Instruction::Mul(d, a) => { - let arg = normalize_expand_arg3(func, new_id, &|| d.typ.try_as_scalar(), a); + Instruction::Mul(d, a) => { + let arg = normalize_expand_arg3(func, id_def, &|| d.typ.try_as_scalar(), a); Instruction::Mul(d, arg) } - ast::Instruction::Add(d, a) => { - let arg = normalize_expand_arg3(func, new_id, &|| Some(d.typ), a); + Instruction::Add(d, a) => { + let arg = normalize_expand_arg3(func, id_def, &|| Some(d.typ), a); Instruction::Add(d, arg) } - ast::Instruction::Setp(d, a) => { - let arg = normalize_expand_arg4(func, new_id, &|| Some(d.typ), a); + Instruction::Setp(d, a) => { + let arg = normalize_expand_arg4(func, id_def, &|| Some(d.typ), a); Instruction::Setp(d, arg) } - ast::Instruction::SetpBool(d, a) => { - let arg = normalize_expand_arg5(func, new_id, &|| Some(d.typ), a); + Instruction::SetpBool(d, a) => { + let arg = normalize_expand_arg5(func, id_def, &|| Some(d.typ), a); Instruction::SetpBool(d, arg) } - ast::Instruction::Not(d, a) => { - let arg = normalize_expand_arg2(func, new_id, &|| todo!(), a); + Instruction::Not(d, a) => { + let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a); Instruction::Not(d, arg) } - ast::Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }), - ast::Instruction::Cvt(d, a) => { - let arg = normalize_expand_arg2(func, new_id, &|| todo!(), a); + Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }), + Instruction::Cvt(d, a) => { + let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a); Instruction::Cvt(d, arg) } - ast::Instruction::Shl(d, a) => { - let arg = normalize_expand_arg3(func, new_id, &|| todo!(), a); + Instruction::Shl(d, a) => { + let arg = normalize_expand_arg3(func, id_def, &|| todo!(), a); Instruction::Shl(d, arg) } - ast::Instruction::St(d, a) => { - let arg = normalize_expand_arg2st(func, new_id, &|| todo!(), a); + Instruction::St(d, a) => { + let arg = normalize_expand_arg2st(func, id_def, &|| todo!(), a); Instruction::St(d, arg) } - ast::Instruction::Ret(d) => Instruction::Ret(d), + Instruction::Ret(d) => Instruction::Ret(d), } } fn normalize_expand_arg2( - func: &mut Vec, - new_id: &mut impl FnMut(Option) -> spirv::Word, + func: &mut Vec, + id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, a: ast::Arg2, ) -> Arg2 { Arg2 { dst: a.dst, - src: normalize_expand_operand(func, new_id, inst_type, a.src), + src: normalize_expand_operand(func, id_def, inst_type, a.src), } } fn normalize_expand_arg2mov( - func: &mut Vec, - new_id: &mut impl FnMut(Option) -> spirv::Word, + func: &mut Vec, + id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, a: ast::Arg2Mov, ) -> Arg2 { Arg2 { dst: a.dst, - src: normalize_expand_mov_operand(func, new_id, inst_type, a.src), + src: normalize_expand_mov_operand(func, id_def, inst_type, a.src), } } fn normalize_expand_arg2st( - func: &mut Vec, - new_id: &mut impl FnMut(Option) -> spirv::Word, + func: &mut Vec, + id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, a: ast::Arg2St, ) -> Arg2St { Arg2St { - src1: normalize_expand_operand(func, new_id, inst_type, a.src1), - src2: normalize_expand_operand(func, new_id, inst_type, a.src2), + src1: normalize_expand_operand(func, id_def, inst_type, a.src1), + src2: normalize_expand_operand(func, id_def, inst_type, a.src2), } } fn normalize_expand_arg3( - func: &mut Vec, - new_id: &mut impl FnMut(Option) -> spirv::Word, + func: &mut Vec, + id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, a: ast::Arg3, ) -> Arg3 { Arg3 { dst: a.dst, - src1: normalize_expand_operand(func, new_id, inst_type, a.src1), - src2: normalize_expand_operand(func, new_id, inst_type, a.src2), + src1: normalize_expand_operand(func, id_def, inst_type, a.src1), + src2: normalize_expand_operand(func, id_def, inst_type, a.src2), } } fn normalize_expand_arg4( - func: &mut Vec, - new_id: &mut impl FnMut(Option) -> spirv::Word, + func: &mut Vec, + id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, a: ast::Arg4, ) -> Arg4 { Arg4 { dst1: a.dst1, dst2: a.dst2, - src1: normalize_expand_operand(func, new_id, inst_type, a.src1), - src2: normalize_expand_operand(func, new_id, inst_type, a.src2), + src1: normalize_expand_operand(func, id_def, inst_type, a.src1), + src2: normalize_expand_operand(func, id_def, inst_type, a.src2), } } fn normalize_expand_arg5( - func: &mut Vec, - new_id: &mut impl FnMut(Option) -> spirv::Word, + func: &mut Vec, + id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, a: ast::Arg5, ) -> Arg5 { Arg5 { dst1: a.dst1, dst2: a.dst2, - src1: normalize_expand_operand(func, new_id, inst_type, a.src1), - src2: normalize_expand_operand(func, new_id, inst_type, a.src2), - src3: normalize_expand_operand(func, new_id, inst_type, a.src3), + src1: normalize_expand_operand(func, id_def, inst_type, a.src1), + src2: normalize_expand_operand(func, id_def, inst_type, a.src2), + src3: normalize_expand_operand(func, id_def, inst_type, a.src3), } } fn normalize_expand_operand( - func: &mut Vec, - new_id: &mut impl FnMut(Option) -> spirv::Word, + func: &mut Vec, + id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, opr: ast::Operand, ) -> spirv::Word { @@ -394,7 +426,7 @@ fn normalize_expand_operand( ast::Operand::Reg(r) => r, ast::Operand::Imm(x) => { if let Some(typ) = inst_type() { - let id = new_id(Some(ast::Type::Scalar(typ))); + let id = id_def.new_id(Some(ast::Type::Scalar(typ))); func.push(Statement::Constant(ConstantDefinition { dst: id, typ: typ, @@ -410,43 +442,17 @@ fn normalize_expand_operand( } fn normalize_expand_mov_operand( - func: &mut Vec, - new_id: &mut impl FnMut(Option) -> spirv::Word, + func: &mut Vec, + id_def: &mut NumericIdResolver, inst_type: &impl Fn() -> Option, opr: ast::MovOperand, ) -> spirv::Word { match opr { - ast::MovOperand::Op(opr) => normalize_expand_operand(func, new_id, inst_type, opr), + ast::MovOperand::Op(opr) => normalize_expand_operand(func, id_def, inst_type, opr), _ => todo!(), } } -fn collect_var_definitions<'a>( - args: &[ast::Argument<'a>], - body: &[ast::Statement<&'a str>], -) -> HashMap, ast::Type> { - let mut result = HashMap::new(); - for param in args { - result.insert(Cow::Borrowed(param.name), ast::Type::Scalar(param.a_type)); - } - for s in body { - match s { - ast::Statement::Variable(var) => match var.count { - Some(count) => { - for i in 0..count { - result.insert(Cow::Owned(format!("{}{}", var.name, i)), var.v_type); - } - } - None => { - result.insert(Cow::Borrowed(var.name), var.v_type); - } - }, - ast::Statement::Label(_) | ast::Statement::Instruction(_, _) => (), - } - } - result -} - /* There are several kinds of implicit conversions in PTX: * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands @@ -460,11 +466,10 @@ fn collect_var_definitions<'a>( - generic st: for instruction `st [x], y`, x must be of type b64/u64/s64, which is bitcast to a pointer */ -fn insert_implicit_conversions ast::Type>( - normalized_ids: Vec, - new_id: &mut impl FnMut(Option) -> spirv::Word, - type_check: &TypeCheck, -) -> Vec { +fn insert_implicit_conversions( + normalized_ids: Vec, + id_def: &mut NumericIdResolver, +) -> Vec { let mut result = Vec::with_capacity(normalized_ids.len()); for s in normalized_ids.into_iter() { match s { @@ -473,16 +478,14 @@ fn insert_implicit_conversions ast::Type>( arg.src = insert_implicit_conversions_ld_src( &mut result, ast::Type::Scalar(ld.typ), - type_check, - new_id, + id_def, ld.state_space, arg.src, ); insert_with_implicit_conversion_dst( &mut result, ld.typ, - type_check, - new_id, + id_def, should_convert_relaxed_dst, arg, |arg| &mut arg.dst, @@ -490,11 +493,11 @@ fn insert_implicit_conversions ast::Type>( ); } Instruction::St(st, mut arg) => { - let arg_src2_type = type_check(arg.src2); + let arg_src2_type = id_def.get_type(arg.src2); if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) { arg.src2 = insert_conversion_src( &mut result, - new_id, + id_def, arg.src2, arg_src2_type, ast::Type::Scalar(st.typ), @@ -504,17 +507,19 @@ fn insert_implicit_conversions ast::Type>( arg.src1 = insert_implicit_conversions_ld_src( &mut result, ast::Type::Scalar(st.typ), - type_check, - new_id, + id_def, st.state_space.to_ld_ss(), arg.src1, ); result.push(Statement::Instruction(Instruction::St(st, arg))); } - inst @ _ => insert_implicit_bitcasts(&mut result, type_check, new_id, inst), + inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst), }, s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s), - Statement::Constant(_) => (), + Statement::Constant(_) + | Statement::Variable(_, _) + | Statement::LoadVar(_, _) + | Statement::StoreVar(_, _) => (), Statement::Converison(_) => unreachable!(), } } @@ -582,76 +587,62 @@ fn collect_label_ids<'a>( fn emit_function_body_ops( builder: &mut dr::Builder, map: &mut TypeWordMap, - func: &[Statement], - cfg: &[BasicBlock], + func: &[ExpandedStatement], ) -> Result<(), dr::Error> { - // TODO: entry basic block can't be target of jumps, - // we need to emit additional BB for this purpose - for bb_idx in 0..cfg.len() { - let body = get_bb_body(func, cfg, BBIndex(bb_idx)); - if body.len() == 0 { - continue; - } - let header_id = if let Statement::Label(id) = body[0] { - Some(id) - } else { - None - }; - builder.begin_block(header_id)?; - for s in body { - match s { - // If block starts with a label it has already been emitted, - // all other labels in the block are unused - Statement::Label(_) => (), - Statement::Constant(_) => todo!(), - Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?, - Statement::Conditional(bra) => { - builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?; - } - Statement::Instruction(inst) => match inst { - // SPIR-V does not support marking jumps as guaranteed-converged - Instruction::Bra(_, arg) => { - builder.branch(arg.src)?; - } - Instruction::Ld(data, arg) => { - if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() { - todo!() - } - let result_type = map.get_or_add_scalar(builder, data.typ); - match data.state_space { - ast::LdStateSpace::Generic => { - builder.load(result_type, Some(arg.dst), arg.src, None, [])?; - } - ast::LdStateSpace::Param => { - builder.copy_object(result_type, Some(arg.dst), arg.src)?; - } - _ => todo!(), - } - } - Instruction::St(data, arg) => { - if data.qualifier != ast::LdStQualifier::Weak - || data.vector.is_some() - || data.state_space != ast::StStateSpace::Generic - { - todo!() - } - builder.store(arg.src1, arg.src2, None, &[])?; - } - // SPIR-V does not support ret as guaranteed-converged - Instruction::Ret(_) => builder.ret()?, - Instruction::Mov(mov, arg) => { - let result_type = map.get_or_add(builder, SpirvType::from(mov.typ)); - builder.copy_object(result_type, Some(arg.dst), arg.src)?; - } - Instruction::Mul(mul, arg) => match mul.desc { - ast::MulDescriptor::Int(ref ctr) => { - emit_mul_int(builder, map, mul.typ, ctr, arg) - } - ast::MulDescriptor::Float(_) => todo!(), - }, - _ => todo!(), - }, + for s in func { + match s { + // If block starts with a label it has already been emitted, + // all other labels in the block are unused + Statement::Label(_) => (), + Statement::Constant(_) => todo!(), + Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?, + Statement::Conditional(bra) => { + builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?; } + Statement::Instruction(inst) => match inst { + // SPIR-V does not support marking jumps as guaranteed-converged + Instruction::Bra(_, arg) => { + builder.branch(arg.src)?; + } + Instruction::Ld(data, arg) => { + if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() { + todo!() + } + let result_type = map.get_or_add_scalar(builder, data.typ); + match data.state_space { + ast::LdStateSpace::Generic => { + builder.load(result_type, Some(arg.dst), arg.src, None, [])?; + } + ast::LdStateSpace::Param => { + builder.copy_object(result_type, Some(arg.dst), arg.src)?; + } + _ => todo!(), + } + } + Instruction::St(data, arg) => { + if data.qualifier != ast::LdStQualifier::Weak + || data.vector.is_some() + || data.state_space != ast::StStateSpace::Generic + { + todo!() + } + builder.store(arg.src1, arg.src2, None, &[])?; + } + // SPIR-V does not support ret as guaranteed-converged + Instruction::Ret(_) => builder.ret()?, + Instruction::Mov(mov, arg) => { + let result_type = map.get_or_add(builder, SpirvType::from(mov.typ)); + builder.copy_object(result_type, Some(arg.dst), arg.src)?; + } + Instruction::Mul(mul, arg) => match mul.desc { + ast::MulDescriptor::Int(ref ctr) => { + emit_mul_int(builder, map, mul.typ, ctr, arg) + } + ast::MulDescriptor::Float(_) => todo!(), + }, + _ => todo!(), + }, + _ => todo!(), } } Ok(()) @@ -734,567 +725,225 @@ fn emit_implicit_conversion( // TODO: support scopes fn normalize_identifiers<'a>( + args: &'a [ast::Argument<'a>], func: Vec>, - constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined - type_map: &mut HashMap, - types: HashMap, ast::Type>, -) -> (Vec>, spirv::Word) { - let mut id: u32 = constant_identifiers.len() as u32; - let mut remapped_ids = HashMap::new(); - let mut get_or_add = |key| { - constant_identifiers.get(key).map_or_else( - || { - *remapped_ids.entry(key).or_insert_with(|| { - let to_insert = id; - id += 1; - to_insert - }) - }, - |id| *id, - ) - }; - let result = func - .into_iter() - .filter_map(|s| Statement::from_ast(s, &mut get_or_add)) - .collect::>(); - type_map.extend( - remapped_ids - .into_iter() - .map(|(old_id, new_id)| (new_id, types[old_id])), - ); - (result, id) -} - -fn ssa_legalize( - func: &mut [Statement], - constant_ids: spirv::Word, - unique_ids: spirv::Word, - bbs: &[BasicBlock], - doms: &[BBIndex], - dom_fronts: &[HashSet], -) -> (Vec>, spirv::Word) { - let phis = gather_phi_sets(&func, constant_ids, unique_ids, &bbs, dom_fronts); - apply_ssa_renaming(func, &bbs, doms, constant_ids, unique_ids, &phis) -} - -/* "Modern Compiler Implementation in Java" - Algorithm 19.7 - * This algorithm modifies passed function body in-place by renumbering ids, - * result ids can be divided into following categories - * - if id < constant_ids - * it's a non-redefinable id - * - if id >= constant_ids && id < all_ids - * then it's an undefined id (a0, b0, c0) - * - if id >= all_ids - * then it's a normally redefined id - */ -fn apply_ssa_renaming( - func: &mut [Statement], - bbs: &[BasicBlock], - doms: &[BBIndex], - constant_ids: spirv::Word, - all_ids: spirv::Word, - old_phi: &[HashSet], -) -> (Vec>, spirv::Word) { - let mut dom_tree = vec![Vec::new(); bbs.len()]; - for (bb, idom) in doms.iter().enumerate().skip(1) { - dom_tree[idom.0].push(BBIndex(bb)); +) -> (Vec>, NumericIdResolver) { + let mut id_defs = StringIdResolver::new(); + for arg in args { + id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type))); } - let mut old_dst_id = vec![Vec::new(); bbs.len()]; - for bb in 0..bbs.len() { - for s in get_bb_body(func, bbs, BBIndex(bb)) { - s.visit_id(&mut |is_dst, id| { - if is_dst { - old_dst_id[bb].push(id) - } - }); - } + let mut result = Vec::new(); + for s in func { + expand_map_ids(&mut id_defs, &mut result, s); } - let mut new_phi = old_phi - .iter() - .map(|ids| { - ids.iter() - .map(|id| (*id, (u32::max_value(), HashSet::new()))) - .collect::>() - }) - .collect::>(); - let mut ssa_state = SSARewriteState::new(constant_ids, all_ids); - // once again, we do explicit stack - let mut state = Vec::new(); - state.push((BBIndex(0), 0)); - loop { - if let Some((BBIndex(bb), dom_succ_idx)) = state.last_mut() { - let bb = *bb; - if *dom_succ_idx == 0 { - rename_phi_dst(&mut ssa_state, &mut new_phi[bb]); - rename_bb_body(&mut ssa_state, func, bbs, BBIndex(bb)); - for BBIndex(succ_idx) in bbs[bb].succ.iter() { - rename_succesor_phi_src(&ssa_state, &mut new_phi[*succ_idx]); - } - } - if let Some(s) = dom_tree[bb].get(*dom_succ_idx) { - *dom_succ_idx += 1; - state.push((*s, 0)); - } else { - state.pop(); - pop_stacks(&mut ssa_state, &old_phi[bb], &old_dst_id[bb]); - } - } else { - break; - } - } - let phi = new_phi - .into_iter() - .map(|map| { - map.into_iter() - .map(|(_, (new_id, defs))| PhiDef { - dst: new_id, - src: defs, - }) - .collect::>() - }) - .collect::>(); - (phi, ssa_state.next_id()) + (result, id_defs.finish()) } -// before ssa-renaming every phi is x <- phi(x,x,x,x) -#[derive(Debug, PartialEq)] -struct PhiDef { - dst: spirv::Word, - src: HashSet, -} - -fn rename_phi_dst( - rewriter: &mut SSARewriteState, - phi: &mut HashMap)>, +fn expand_map_ids<'a>( + id_defs: &mut StringIdResolver<'a>, + result: &mut Vec>, + s: ast::Statement<&'a str>, ) { - for (old_k, (new_k, _)) in phi.iter_mut() { - *new_k = rewriter.redefine(*old_k); - } -} - -fn rename_bb_body( - ssa_state: &mut SSARewriteState, - func: &mut [Statement], - all_bb: &[BasicBlock], - bb: BBIndex, -) { - for s in get_bb_body_mut(func, all_bb, bb) { - s.visit_id_mut(&mut |is_dst, id| { - if is_dst { - *id = ssa_state.redefine(*id); - } else { - *id = ssa_state.get(*id); - } - }); - } -} - -fn rename_succesor_phi_src( - ssa_state: &SSARewriteState, - phi: &mut HashMap)>, -) { - for (id, (_, v)) in phi.iter_mut() { - v.insert(ssa_state.get(*id)); - } -} - -fn pop_stacks( - ssa_state: &mut SSARewriteState, - old_phi: &HashSet, - old_ids: &[spirv::Word], -) { - for id in old_phi.iter().chain(old_ids) { - ssa_state.pop(*id); - } -} - -fn get_bb_body_mut<'a>( - func: &'a mut [Statement], - all_bb: &[BasicBlock], - bb: BBIndex, -) -> &'a mut [Statement] { - let (start, end) = get_bb_body_idx(func, all_bb, bb); - &mut func[start..end] -} - -fn get_bb_body<'a>(func: &'a [Statement], all_bb: &[BasicBlock], bb: BBIndex) -> &'a [Statement] { - let (start, end) = get_bb_body_idx(func, all_bb, bb); - &func[start..end] -} - -fn get_bb_body_idx(func: &[Statement], all_bb: &[BasicBlock], bb: BBIndex) -> (usize, usize) { - let BBIndex(bb_idx) = bb; - let start = all_bb[bb_idx].start.0; - let end = if bb_idx == all_bb.len() - 1 { - func.len() - } else { - all_bb[bb_idx + 1].start.0 - }; - (start, end) -} - -// We assume here that the variables are defined in the dense sequence 0..max -struct SSARewriteState { - next: spirv::Word, - constant_ids: spirv::Word, - stack: Vec>, -} - -impl<'a> SSARewriteState { - fn new(constant_ids: spirv::Word, all_ids: spirv::Word) -> Self { - let to_redefine = all_ids - constant_ids; - let stack = (0..to_redefine) - .map(|x| vec![x + constant_ids]) - .collect::>(); - SSARewriteState { - next: all_ids, - constant_ids: constant_ids, - stack, + match s { + ast::Statement::Label(name) => { + result.push(ast::Statement::Label(id_defs.add_def(name, None))) } - } - - fn get(&self, x: spirv::Word) -> spirv::Word { - if x < self.constant_ids { - x - } else { - *self.stack[(x - self.constant_ids) as usize].last().unwrap() - } - } - - fn redefine(&mut self, x: spirv::Word) -> spirv::Word { - if x < self.constant_ids { - x - } else { - let result = self.next; - self.next += 1; - self.stack[(x - self.constant_ids) as usize].push(result); - result - } - } - - fn pop(&mut self, x: spirv::Word) { - if x >= self.constant_ids { - self.stack[(x - self.constant_ids) as usize].pop(); - } - } - - fn next_id(&self) -> spirv::Word { - self.next - } -} - -// "Engineering a Compiler" - Figure 9.9 -// Calculates semi-pruned phis -fn gather_phi_sets( - func: &[Statement], - constant_ids: spirv::Word, - all_ids: spirv::Word, - cfg: &[BasicBlock], - dom_fronts: &[HashSet], -) -> Vec> { - let mut result = vec![HashSet::new(); cfg.len()]; - let mut globals = HashSet::new(); - let mut blocks = vec![(Vec::new(), HashSet::new()); (all_ids - constant_ids) as usize]; - for bb in 0..cfg.len() { - let mut var_kill = HashSet::new(); - let mut visitor = |is_dst, id: spirv::Word| { - if id >= constant_ids { - let id = id - constant_ids; - if is_dst { - var_kill.insert(id); - let (ref mut stack, ref mut set) = blocks[id as usize]; - stack.push(BBIndex(bb)); - set.insert(BBIndex(bb)); - } else { - if !var_kill.contains(&id) { - globals.insert(id); - } + ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction( + p.map(|p| p.map_id(&mut |id| id_defs.get_id(id))), + i.map_id1(&mut |id| id_defs.get_id(id)), + )), + ast::Statement::Variable(var) => match var.count { + Some(count) => { + for new_id in id_defs.add_defs(var.name, count, var.v_type) { + result.push(ast::Statement::Variable(ast::Variable { + space: var.space, + v_type: var.v_type, + name: new_id, + count: None, + })) } } - }; - // We try to avoid adding labels to the global-visbility set. - // We are not 100% precise (we add jump targets in bra), but it shouldn't be a problem - for s in get_bb_body(func, cfg, BBIndex(bb)) { - match s { - Statement::Instruction(inst) => inst.visit_id(&mut visitor), - Statement::Conditional(brc) => visitor(false, brc.predicate), - Statement::Converison(conv) => conv.visit_id(&mut visitor), - Statement::Constant(cons) => cons.visit_id(&mut visitor), - // label redefinition is a compile-time error - Statement::Label(_) => (), + None => { + let new_id = id_defs.add_def(var.name, Some(var.v_type)); + result.push(ast::Statement::Variable(ast::Variable { + space: var.space, + v_type: var.v_type, + name: new_id, + count: None, + })); } - } - } - for id in globals { - let (ref mut work_stack, ref mut work_set) = &mut blocks[id as usize]; - loop { - if let Some(bb) = work_stack.pop() { - work_set.remove(&bb); - for d_bb in &dom_fronts[bb.0] { - if result[d_bb.0].insert(id + constant_ids) { - if work_set.insert(*d_bb) { - work_stack.push(*d_bb); - } - } - } - } else { - break; - } - } - } - result -} - -fn get_basic_blocks(fun: &[Statement]) -> Vec { - // edge signify pred/succ relationship between bbs - let mut unresolved_bb_edge = Vec::new(); - // bb start means that a bb is starting at this statement, but there's no predecessor - let mut bb_start = Vec::new(); - let mut labels = HashMap::new(); - for (idx, s) in fun.iter().enumerate() { - match s { - Statement::Instruction(i) => { - if let Some(id) = i.jump_target() { - unresolved_bb_edge.push((StmtIndex(idx), id)); - if idx + 1 < fun.len() { - bb_start.push(StmtIndex(idx + 1)); - } - } else if i.is_terminal() && idx + 1 < fun.len() { - bb_start.push(StmtIndex(idx + 1)); - } - } - Statement::Label(id) => { - labels.insert(id, StmtIndex(idx)); - } - Statement::Conditional(bra) => { - unresolved_bb_edge.push((StmtIndex(idx), bra.if_false)); - unresolved_bb_edge.push((StmtIndex(idx), bra.if_true)); - } - Statement::Constant(_) => (), - Statement::Converison(_) => (), - }; - } - let mut bb_edge = HashSet::new(); - // Resolve every into - // TODO: handle jumps into nowhere - for (idx, id) in unresolved_bb_edge { - let target = labels[&id]; - bb_edge.insert((idx, target)); - bb_start.push(target); - // now check if there is an edge target-1 -> target - if target != StmtIndex(0) { - match &fun[target.0 - 1] { - Statement::Instruction(i) => { - if !(i.jump_target().is_some() || i.is_terminal()) { - bb_edge.insert((StmtIndex(target.0 - 1), target)); - } - } - Statement::Converison(_) | Statement::Constant(_) | Statement::Label(_) => { - bb_edge.insert((StmtIndex(target.0 - 1), target)); - } - // This is already in `unresolved_bb_edge` - Statement::Conditional(_) => (), - } - } - } - // Create list of bbs without succ/pred - let mut bbs_map = BTreeMap::new(); - bbs_map.insert( - StmtIndex(0), - BasicBlock { - start: StmtIndex(0), - pred: Vec::new(), - succ: Vec::new(), }, - ); - for bb_first_stmt in bb_start { - bbs_map.entry(bb_first_stmt).or_insert_with(|| BasicBlock { - start: bb_first_stmt, - pred: Vec::new(), - succ: Vec::new(), - }); - } - // Populate succ/pred - let indexed_bbs_map = bbs_map - .into_iter() - .enumerate() - .map(|(idx, (key, val))| (key, (BBIndex(idx), RefCell::new(val)))) - .collect::>(); - for (from, to) in bb_edge { - let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=from).next_back().unwrap(); - let (to_idx, to_ref) = indexed_bbs_map.get(&to).unwrap(); - { - from_ref.borrow_mut().succ.push(*to_idx); - } - { - to_ref.borrow_mut().pred.push(*from_idx); - } - } - indexed_bbs_map - .into_iter() - .map(|(_, (_, bb))| bb.into_inner()) - .collect::>() -} - -// "A Simple, Fast Dominance Algorithm" - Keith D. Cooper, Timothy J. Harvey, and Ken Kennedy -// https://www.cs.rice.edu/~keith/EMBED/dom.pdf -fn dominance_frontiers(bbs: &[BasicBlock], doms: &[BBIndex]) -> Vec> { - let mut result = vec![HashSet::new(); bbs.len()]; - for (bb_idx, b) in bbs.iter().enumerate() { - if b.pred.len() < 2 { - continue; - } - for p in b.pred.iter() { - let mut runner = *p; - while runner != doms[bb_idx] { - result[runner.0].insert(BBIndex(bb_idx)); - runner = doms[runner.0]; - } - } - } - result -} - -fn immediate_dominators(bbs: &Vec, order: &Vec) -> Vec { - let undefined = BBIndex(usize::max_value()); - let mut doms = vec![undefined; bbs.len()]; - doms[0] = BBIndex(0); - let mut changed = true; - while changed { - changed = false; - for BBIndex(bb_idx) in order.iter().skip(1) { - let bb = &bbs[*bb_idx]; - if let Some(first_pred) = bb.pred.iter().find(|bb| doms[bb.0] != undefined) { - let mut new_idom = *first_pred; - for BBIndex(p_idx) in bb.pred.iter().copied().filter(|bb| bb != first_pred) { - if doms[p_idx] != BBIndex(usize::max_value()) { - new_idom = intersect(&mut doms, BBIndex(p_idx), new_idom); - } - } - if doms[*bb_idx] != new_idom { - doms[*bb_idx] = new_idom; - changed = true; - } - } - } - } - return doms; -} - -// Original paper uses reverse indexing: their entry node has index n, -// that's why the compares are reversed -fn intersect(doms: &mut Vec, b1: BBIndex, b2: BBIndex) -> BBIndex { - let mut finger1 = b1; - let mut finger2 = b2; - while finger1 != finger2 { - while finger1 > finger2 { - finger1 = doms[finger1.0]; - } - while finger2 > finger1 { - finger2 = doms[finger2.0]; - } - } - finger1 -} - -// "A Simple Algorithm for Global Data Flow Analysis Problems" - Hecht, M. S., & Ullman, J. D. (1975) -fn to_reverse_postorder(input: &Vec) -> Vec { - let mut i = input.len(); - let mut old = BitVec::from_elem(input.len(), false); - let mut result = vec![BBIndex(usize::max_value()); input.len()]; - // original uses recursion and implicit stack, we do it explictly - let mut state = Vec::new(); - state.push((BBIndex(0), 0usize)); - loop { - if let Some((BBIndex(bb), succ_iter_idx)) = state.last_mut() { - let bb = *bb; - if *succ_iter_idx == 0 { - old.set(bb, true); - } - if let Some(BBIndex(succ)) = &input[bb].succ.get(*succ_iter_idx) { - *succ_iter_idx += 1; - if !old.get(*succ).unwrap() { - state.push((BBIndex(*succ), 0)); - } - } else { - state.pop(); - i = i - 1; - result[i] = BBIndex(bb); - } - } else { - break; - } - } - result -} - -#[derive(Eq, PartialEq, Debug, Clone)] -struct BasicBlock { - start: StmtIndex, - pred: Vec, - succ: Vec, -} - -#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)] -struct StmtIndex(pub usize); - -impl fmt::Display for StmtIndex { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.0.fmt(f) } } -#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)] -struct BBIndex(pub usize); +struct StringIdResolver<'a> { + current_id: spirv::Word, + variables: HashMap, spirv::Word>, + type_check: HashMap, +} -impl fmt::Display for BBIndex { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.0.fmt(f) +impl<'a> StringIdResolver<'a> { + fn new() -> Self { + StringIdResolver { + current_id: 0u32, + variables: HashMap::new(), + type_check: HashMap::new(), + } + } + + fn finish(self) -> NumericIdResolver { + NumericIdResolver { + current_id: self.current_id, + type_check: self.type_check, + } + } + + fn get_id(&self, id: &'a str) -> spirv::Word { + self.variables[id] + } + + #[must_use] + fn add_def(&mut self, id: &'a str, typ: Option) -> spirv::Word { + let numeric_id = self.current_id; + self.variables.insert(Cow::Borrowed(id), numeric_id); + if let Some(typ) = typ { + self.type_check.insert(numeric_id, typ); + } + self.current_id += 1; + numeric_id + } + + #[must_use] + fn add_defs( + &mut self, + base_id: &'a str, + count: u32, + typ: ast::Type, + ) -> impl Iterator { + let numeric_id = self.current_id; + for i in 0..count { + self.variables + .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); + self.type_check.insert(numeric_id + i, typ); + } + self.current_id += count; + (0..count).into_iter().map(move |i| i + numeric_id) } } -enum Statement { +struct NumericIdResolver { + current_id: spirv::Word, + type_check: HashMap, +} + +impl NumericIdResolver { + fn get_type(&self, id: spirv::Word) -> ast::Type { + self.type_check[&id] + } + + fn new_id(&mut self, typ: Option) -> spirv::Word { + let new_id = self.current_id; + if let Some(typ) = typ { + self.type_check.insert(new_id, typ); + } + self.current_id += 1; + new_id + } + + fn ids_count(&self) -> spirv::Word { + self.current_id + } +} + +enum Statement { + Variable(spirv::Word, ast::Type), + LoadVar(Arg2, ast::Type), + StoreVar(Arg2St, ast::Type), Label(u32), - Instruction(Instruction), + Instruction(Instruction), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), Converison(ImplicitConversion), Constant(ConstantDefinition), } -enum Instruction { - Ld(ast::LdData, Arg2), - Mov(ast::MovData, Arg2), - Mul(ast::MulData, Arg3), - Add(ast::AddData, Arg3), - Setp(ast::SetpData, Arg4), - SetpBool(ast::SetpBoolData, Arg5), - Not(ast::NotData, Arg2), - Bra(ast::BraData, Arg1), - Cvt(ast::CvtData, Arg2), - Shl(ast::ShlData, Arg3), - St(ast::StData, Arg2St), - Ret(ast::RetData), -} - -impl Instruction { - fn visit_id(&self, f: &mut F) { +impl Statement { + fn visit_id_mut(&mut self, f: &mut F) { match self { - Instruction::Ld(_, a) => a.visit_id(f), - Instruction::Mov(_, a) => a.visit_id(f), - Instruction::Mul(_, a) => a.visit_id(f), - Instruction::Add(_, a) => a.visit_id(f), - Instruction::Setp(_, a) => a.visit_id(f), - Instruction::SetpBool(_, a) => a.visit_id(f), - Instruction::Not(_, a) => a.visit_id(f), - Instruction::Cvt(_, a) => a.visit_id(f), - Instruction::Shl(_, a) => a.visit_id(f), - Instruction::St(_, a) => a.visit_id(f), - Instruction::Bra(_, a) => a.visit_id(f), - Instruction::Ret(_) => (), + Statement::Variable(id, _) => f(true, id), + Statement::LoadVar(a, _) => a.visit_id_mut(f), + Statement::StoreVar(a, _) => a.visit_id_mut(f), + Statement::Label(id) => f(false, id), + Statement::Instruction(inst) => inst.visit_id_mut(f), + Statement::Conditional(bra) => bra.visit_id_mut(f), + Statement::Converison(conv) => conv.visit_id_mut(f), + Statement::Constant(cons) => cons.visit_id_mut(f), } } + fn get_type(&self) -> Option { + todo!() + } +} + +trait Args { + type Arg1: Arg; + type Arg2: Arg; + type Arg2St: Arg; + type Arg2Mov: Arg; + type Arg3: Arg; + type Arg4: Arg; + type Arg5: Arg; +} + +trait Arg { + fn visit_id(&self, f: &mut F); + fn visit_id_mut(&mut self, f: &mut F); +} + +enum NormalizedArgs {} + +impl Args for NormalizedArgs { + type Arg1 = ast::Arg1; + type Arg2 = ast::Arg2; + type Arg2St = ast::Arg2St; + type Arg2Mov = ast::Arg2Mov; + type Arg3 = ast::Arg3; + type Arg4 = ast::Arg4; + type Arg5 = ast::Arg5; +} + +enum ExpandedArgs {} + +impl Args for ExpandedArgs { + type Arg1 = Arg1; + type Arg2 = Arg2; + type Arg2St = Arg2St; + type Arg2Mov = Arg2; + type Arg3 = Arg3; + type Arg4 = Arg4; + type Arg5 = Arg5; +} + +type NormalizedStatement = Statement; +type ExpandedStatement = Statement; + +enum Instruction { + Ld(ast::LdData, A::Arg2), + Mov(ast::MovData, A::Arg2Mov), + Mul(ast::MulData, A::Arg3), + Add(ast::AddData, A::Arg3), + Setp(ast::SetpData, A::Arg4), + SetpBool(ast::SetpBoolData, A::Arg5), + Not(ast::NotData, A::Arg2), + Bra(ast::BraData, A::Arg1), + Cvt(ast::CvtData, A::Arg2), + Shl(ast::ShlData, A::Arg3), + St(ast::StData, A::Arg2St), + Ret(ast::RetData), +} + +impl Instruction { fn visit_id_mut(&mut self, f: &mut F) { match self { Instruction::Ld(_, a) => a.visit_id_mut(f), @@ -1324,23 +973,6 @@ impl Instruction { } } - fn jump_target(&self) -> Option { - match self { - Instruction::Bra(_, a) => Some(a.src), - Instruction::Ld(_, _) - | Instruction::Mov(_, _) - | Instruction::Mul(_, _) - | Instruction::Add(_, _) - | Instruction::Setp(_, _) - | Instruction::SetpBool(_, _) - | Instruction::Not(_, _) - | Instruction::Cvt(_, _) - | Instruction::Shl(_, _) - | Instruction::St(_, _) - | Instruction::Ret(_) => None, - } - } - fn is_terminal(&self) -> bool { match self { Instruction::Ret(_) => true, @@ -1359,11 +991,66 @@ impl Instruction { } } +impl Instruction { + fn from_ast(s: ast::Instruction) -> Self { + match s { + ast::Instruction::Ld(d, a) => Instruction::Ld(d, a), + ast::Instruction::Mov(d, a) => Instruction::Mov(d, a), + ast::Instruction::Mul(d, a) => Instruction::Mul(d, a), + ast::Instruction::Add(d, a) => Instruction::Add(d, a), + ast::Instruction::Setp(d, a) => Instruction::Setp(d, a), + ast::Instruction::SetpBool(d, a) => Instruction::SetpBool(d, a), + ast::Instruction::Not(d, a) => Instruction::Not(d, a), + ast::Instruction::Cvt(d, a) => Instruction::Cvt(d, a), + ast::Instruction::Shl(d, a) => Instruction::Shl(d, a), + ast::Instruction::St(d, a) => Instruction::St(d, a), + ast::Instruction::Bra(d, a) => Instruction::Bra(d, a), + ast::Instruction::Ret(d) => Instruction::Ret(d), + } + } +} + +impl Instruction { + fn visit_id(&self, f: &mut F) { + match self { + Instruction::Ld(_, a) => a.visit_id(f), + Instruction::Mov(_, a) => a.visit_id(f), + Instruction::Mul(_, a) => a.visit_id(f), + Instruction::Add(_, a) => a.visit_id(f), + Instruction::Setp(_, a) => a.visit_id(f), + Instruction::SetpBool(_, a) => a.visit_id(f), + Instruction::Not(_, a) => a.visit_id(f), + Instruction::Cvt(_, a) => a.visit_id(f), + Instruction::Shl(_, a) => a.visit_id(f), + Instruction::St(_, a) => a.visit_id(f), + Instruction::Bra(_, a) => a.visit_id(f), + Instruction::Ret(_) => (), + } + } + + fn jump_target(&self) -> Option { + match self { + Instruction::Bra(_, a) => Some(a.src), + Instruction::Ld(_, _) + | Instruction::Mov(_, _) + | Instruction::Mul(_, _) + | Instruction::Add(_, _) + | Instruction::Setp(_, _) + | Instruction::SetpBool(_, _) + | Instruction::Not(_, _) + | Instruction::Cvt(_, _) + | Instruction::Shl(_, _) + | Instruction::St(_, _) + | Instruction::Ret(_) => None, + } + } +} + struct Arg1 { pub src: spirv::Word, } -impl Arg1 { +impl Arg for Arg1 { fn visit_id(&self, f: &mut F) { f(false, self.src); } @@ -1378,7 +1065,7 @@ struct Arg2 { pub src: spirv::Word, } -impl Arg2 { +impl Arg for Arg2 { fn visit_id(&self, f: &mut F) { f(true, self.dst); f(false, self.src); @@ -1395,7 +1082,7 @@ pub struct Arg2St { pub src2: spirv::Word, } -impl Arg2St { +impl Arg for Arg2St { fn visit_id(&self, f: &mut F) { f(false, self.src1); f(false, self.src2); @@ -1413,7 +1100,7 @@ struct Arg3 { pub src2: spirv::Word, } -impl Arg3 { +impl Arg for Arg3 { fn visit_id(&self, f: &mut F) { f(true, self.dst); f(false, self.src1); @@ -1434,7 +1121,7 @@ struct Arg4 { pub src2: spirv::Word, } -impl Arg4 { +impl Arg for Arg4 { fn visit_id(&self, f: &mut F) { f(true, self.dst1); self.dst2.map(|dst2| f(true, dst2)); @@ -1458,7 +1145,7 @@ struct Arg5 { pub src3: spirv::Word, } -impl Arg5 { +impl Arg for Arg5 { fn visit_id(&self, f: &mut F) { f(true, self.dst1); self.dst2.map(|dst2| f(true, dst2)); @@ -1540,44 +1227,6 @@ impl ImplicitConversion { } } -impl Statement { - fn from_ast<'a, F: FnMut(&'a str) -> u32>( - s: ast::Statement<&'a str>, - get_id: &mut F, - ) -> Option> { - match s { - ast::Statement::Label(name) => Some(ast::Statement::Label(get_id(name))), - ast::Statement::Instruction(p, i) => Some(ast::Statement::Instruction( - p.map(|p| p.map_id(get_id)), - i.map_id(get_id), - )), - ast::Statement::Variable(_) => None, - } - } - - fn visit_id(&self, f: &mut F) { - match self { - Statement::Label(id) => f(false, *id), - Statement::Instruction(inst) => inst.visit_id(f), - Statement::Conditional(bra) => bra.visit_id(f), - Statement::Converison(conv) => conv.visit_id(f), - Statement::Constant(cons) => cons.visit_id(f), - } - } - - // WARNING: It is very important to first visit src operands and then dst operands, - // otherwise SSA renaming will yield weird results - fn visit_id_mut(&mut self, f: &mut F) { - match self { - Statement::Label(id) => f(false, id), - Statement::Instruction(inst) => inst.visit_id_mut(f), - Statement::Conditional(bra) => bra.visit_id_mut(f), - Statement::Converison(conv) => conv.visit_id_mut(f), - Statement::Constant(cons) => cons.visit_id_mut(f), - } - } -} - impl ast::PredAt { fn map_id U>(self, f: &mut F) -> ast::PredAt { ast::PredAt { @@ -1588,7 +1237,7 @@ impl ast::PredAt { } impl ast::Instruction { - fn map_id U>(self, f: &mut F) -> ast::Instruction { + fn map_id1 U>(self, f: &mut F) -> ast::Instruction { match self { ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)), ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)), @@ -1605,9 +1254,28 @@ impl ast::Instruction { } } - fn visit_id(&self, f: &mut F) { + fn map_id spirv::Word>(self, f: &mut F) -> Instruction { match self { - ast::Instruction::Ld(_, a) => a.visit_id(f), + ast::Instruction::Ld(d, a) => Instruction::Ld(d, a.map_id(f)), + ast::Instruction::Mov(d, a) => Instruction::Mov(d, a.map_id(f)), + ast::Instruction::Mul(d, a) => Instruction::Mul(d, a.map_id(f)), + ast::Instruction::Add(d, a) => Instruction::Add(d, a.map_id(f)), + ast::Instruction::Setp(d, a) => Instruction::Setp(d, a.map_id(f)), + ast::Instruction::SetpBool(d, a) => Instruction::SetpBool(d, a.map_id(f)), + ast::Instruction::Not(d, a) => Instruction::Not(d, a.map_id(f)), + ast::Instruction::Bra(d, a) => Instruction::Bra(d, a.map_id(f)), + ast::Instruction::Cvt(d, a) => Instruction::Cvt(d, a.map_id(f)), + ast::Instruction::Shl(d, a) => Instruction::Shl(d, a.map_id(f)), + ast::Instruction::St(d, a) => Instruction::St(d, a.map_id(f)), + ast::Instruction::Ret(d) => Instruction::Ret(d), + } + } +} + +impl ast::Instruction { + fn visit_id(&self, f: &mut F) { + match self { + ast::Instruction::Ld(_, a) => Arg::visit_id(a, f), ast::Instruction::Mov(_, a) => a.visit_id(f), ast::Instruction::Mul(_, a) => a.visit_id(f), ast::Instruction::Add(_, a) => a.visit_id(f), @@ -1622,7 +1290,7 @@ impl ast::Instruction { } } - fn visit_id_mut(&mut self, f: &mut F) { + fn visit_id_mut(&mut self, f: &mut F) { match self { ast::Instruction::Ld(_, a) => a.visit_id_mut(f), ast::Instruction::Mov(_, a) => a.visit_id_mut(f), @@ -1692,12 +1360,14 @@ impl ast::Arg1 { fn map_id U>(self, f: &mut F) -> ast::Arg1 { ast::Arg1 { src: f(self.src) } } +} - fn visit_id(&self, f: &mut F) { - f(false, &self.src); +impl Arg for ast::Arg1 { + fn visit_id(&self, f: &mut F) { + f(false, self.src); } - fn visit_id_mut(&mut self, f: &mut F) { + fn visit_id_mut(&mut self, f: &mut F) { f(false, &mut self.src); } } @@ -1709,13 +1379,15 @@ impl ast::Arg2 { src: self.src.map_id(f), } } +} - fn visit_id(&self, f: &mut F) { - f(true, &self.dst); +impl Arg for ast::Arg2 { + fn visit_id(&self, f: &mut F) { + f(true, self.dst); self.src.visit_id(f); } - fn visit_id_mut(&mut self, f: &mut F) { + fn visit_id_mut(&mut self, f: &mut F) { self.src.visit_id_mut(f); f(true, &mut self.dst); } @@ -1728,13 +1400,15 @@ impl ast::Arg2St { src2: self.src2.map_id(f), } } +} - fn visit_id(&self, f: &mut F) { +impl Arg for ast::Arg2St { + fn visit_id(&self, f: &mut F) { self.src1.visit_id(f); self.src2.visit_id(f); } - fn visit_id_mut(&mut self, f: &mut F) { + fn visit_id_mut(&mut self, f: &mut F) { self.src1.visit_id_mut(f); self.src2.visit_id_mut(f); } @@ -1747,13 +1421,15 @@ impl ast::Arg2Mov { src: self.src.map_id(f), } } +} - fn visit_id(&self, f: &mut F) { - f(true, &self.dst); +impl Arg for ast::Arg2Mov { + fn visit_id(&self, f: &mut F) { + f(true, self.dst); self.src.visit_id(f); } - fn visit_id_mut(&mut self, f: &mut F) { + fn visit_id_mut(&mut self, f: &mut F) { self.src.visit_id_mut(f); f(true, &mut self.dst); } @@ -1767,14 +1443,16 @@ impl ast::Arg3 { src2: self.src2.map_id(f), } } +} - fn visit_id(&self, f: &mut F) { - f(true, &self.dst); +impl Arg for ast::Arg3 { + fn visit_id(&self, f: &mut F) { + f(true, self.dst); self.src1.visit_id(f); self.src2.visit_id(f); } - fn visit_id_mut(&mut self, f: &mut F) { + fn visit_id_mut(&mut self, f: &mut F) { self.src1.visit_id_mut(f); self.src2.visit_id_mut(f); f(true, &mut self.dst); @@ -1790,15 +1468,17 @@ impl ast::Arg4 { src2: self.src2.map_id(f), } } +} - fn visit_id(&self, f: &mut F) { - f(true, &self.dst1); - self.dst2.as_ref().map(|i| f(true, i)); +impl Arg for ast::Arg4 { + fn visit_id(&self, f: &mut F) { + f(true, self.dst1); + self.dst2.map(|i| f(true, i)); self.src1.visit_id(f); self.src2.visit_id(f); } - fn visit_id_mut(&mut self, f: &mut F) { + fn visit_id_mut(&mut self, f: &mut F) { self.src1.visit_id_mut(f); self.src2.visit_id_mut(f); f(true, &mut self.dst1); @@ -1816,16 +1496,18 @@ impl ast::Arg5 { src3: self.src3.map_id(f), } } +} - fn visit_id(&self, f: &mut F) { - f(true, &self.dst1); - self.dst2.as_ref().map(|i| f(true, i)); +impl Arg for ast::Arg5 { + fn visit_id(&self, f: &mut F) { + f(true, self.dst1); + self.dst2.map(|i| f(true, i)); self.src1.visit_id(f); self.src2.visit_id(f); self.src3.visit_id(f); } - fn visit_id_mut(&mut self, f: &mut F) { + fn visit_id_mut(&mut self, f: &mut F) { self.src1.visit_id_mut(f); self.src2.visit_id_mut(f); self.src3.visit_id_mut(f); @@ -1842,11 +1524,13 @@ impl ast::Operand { ast::Operand::Imm(v) => ast::Operand::Imm(v), } } +} - fn visit_id(&self, f: &mut F) { +impl ast::Operand { + fn visit_id(&self, f: &mut F) { match self { - ast::Operand::Reg(i) => f(false, i), - ast::Operand::RegOffset(i, _) => f(false, i), + ast::Operand::Reg(i) => f(false, *i), + ast::Operand::RegOffset(i, _) => f(false, *i), ast::Operand::Imm(_) => (), } } @@ -1867,18 +1551,20 @@ impl ast::MovOperand { ast::MovOperand::Vec(s1, s2) => ast::MovOperand::Vec(s1, s2), } } +} - fn visit_id(&self, f: &mut F) { +impl ast::MovOperand { + fn visit_id(&self, f: &mut F) { match self { ast::MovOperand::Op(o) => o.visit_id(f), - ast::MovOperand::Vec(_, _) => (), + ast::MovOperand::Vec(_, _) => todo!(), } } fn visit_id_mut(&mut self, f: &mut F) { match self { ast::MovOperand::Op(o) => o.visit_id_mut(f), - ast::MovOperand::Vec(_, _) => (), + ast::MovOperand::Vec(_, _) => todo!(), } } } @@ -2007,19 +1693,17 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { } } -fn insert_implicit_conversions_ld_src ast::Type>( - func: &mut Vec, +fn insert_implicit_conversions_ld_src( + func: &mut Vec, instr_type: ast::Type, - type_check: &TypeCheck, - new_id: &mut impl FnMut(Option) -> spirv::Word, + id_def: &mut NumericIdResolver, state_space: ast::LdStateSpace, src: spirv::Word, ) -> spirv::Word { match state_space { ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl( func, - type_check, - new_id, + id_def, instr_type, src, should_convert_ld_param_src, @@ -2031,15 +1715,14 @@ fn insert_implicit_conversions_ld_src ast::Type>( )); let new_src = insert_implicit_conversions_ld_src_impl( func, - type_check, - new_id, + id_def, new_src_type, src, should_convert_ld_generic_src_to_bitcast, ); insert_conversion_src( func, - new_id, + id_def, new_src, new_src_type, instr_type, @@ -2051,19 +1734,17 @@ fn insert_implicit_conversions_ld_src ast::Type>( } fn insert_implicit_conversions_ld_src_impl< - TypeCheck: Fn(spirv::Word) -> ast::Type, ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option, >( - func: &mut Vec, - type_check: &TypeCheck, - new_id: &mut impl FnMut(Option) -> spirv::Word, + func: &mut Vec, + id_def: &mut NumericIdResolver, instr_type: ast::Type, src: spirv::Word, should_convert: ShouldConvert, ) -> spirv::Word { - let src_type = type_check(src); + let src_type = id_def.get_type(src); if let Some(conv) = should_convert(src_type, instr_type) { - insert_conversion_src(func, new_id, src, src_type, instr_type, conv) + insert_conversion_src(func, id_def, src, src_type, instr_type, conv) } else { src } @@ -2096,14 +1777,14 @@ fn should_convert_ld_generic_src_to_bitcast( #[must_use] fn insert_conversion_src( - func: &mut Vec, - new_id: &mut impl FnMut(Option) -> spirv::Word, + func: &mut Vec, + id_def: &mut NumericIdResolver, src: spirv::Word, src_type: ast::Type, instr_type: ast::Type, conv: ConversionKind, ) -> spirv::Word { - let temp_src = new_id(Some(instr_type)); + let temp_src = id_def.new_id(Some(instr_type)); func.push(Statement::Converison(ImplicitConversion { src: src, dst: temp_src, @@ -2116,24 +1797,22 @@ fn insert_conversion_src( fn insert_with_implicit_conversion_dst< T, - TypeCheck: Fn(spirv::Word) -> ast::Type, ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option, Setter: Fn(&mut T) -> &mut spirv::Word, - ToInstruction: FnOnce(T) -> Instruction, + ToInstruction: FnOnce(T) -> Instruction, >( - func: &mut Vec, + func: &mut Vec, instr_type: ast::ScalarType, - type_check: &TypeCheck, - new_id: &mut impl FnMut(Option) -> spirv::Word, + id_def: &mut NumericIdResolver, should_convert: ShouldConvert, mut t: T, setter: Setter, to_inst: ToInstruction, ) { let dst = setter(&mut t); - let dst_type = type_check(*dst); + let dst_type = id_def.get_type(*dst); let dst_coercion = should_convert(dst_type, instr_type) - .map(|conv| get_conversion_dst(new_id, dst, ast::Type::Scalar(instr_type), dst_type, conv)); + .map(|conv| get_conversion_dst(id_def, dst, ast::Type::Scalar(instr_type), dst_type, conv)); func.push(Statement::Instruction(to_inst(t))); if let Some(conv) = dst_coercion { func.push(conv); @@ -2142,14 +1821,14 @@ fn insert_with_implicit_conversion_dst< #[must_use] fn get_conversion_dst( - new_id: &mut impl FnMut(Option) -> spirv::Word, + id_def: &mut NumericIdResolver, dst: &mut spirv::Word, instr_type: ast::Type, dst_type: ast::Type, kind: ConversionKind, -) -> Statement { +) -> ExpandedStatement { let original_dst = *dst; - let temp_dst = new_id(Some(instr_type)); + let temp_dst = id_def.new_id(Some(instr_type)); *dst = temp_dst; Statement::Converison(ImplicitConversion { src: temp_dst, @@ -2245,20 +1924,19 @@ fn should_convert_relaxed_dst( } } -fn insert_implicit_bitcasts ast::Type>( - func: &mut Vec, - type_check: &TypeCheck, - new_id: &mut impl FnMut(Option) -> spirv::Word, - mut instr: Instruction, +fn insert_implicit_bitcasts( + func: &mut Vec, + id_def: &mut NumericIdResolver, + mut instr: Instruction, ) { let mut dst_coercion = None; if let Some(instr_type) = instr.get_type() { instr.visit_id_mut(&mut |is_dst, id| { - let id_type = type_check(*id); - if should_bitcast(instr_type, type_check(*id)) { + let id_type = id_def.get_type(*id); + if should_bitcast(instr_type, id_def.get_type(*id)) { if is_dst { dst_coercion = Some(get_conversion_dst( - new_id, + id_def, id, instr_type, id_type, @@ -2267,7 +1945,7 @@ fn insert_implicit_bitcasts ast::Type>( } else { *id = insert_conversion_src( func, - new_id, + id_def, *id, id_type, instr_type, @@ -2290,724 +1968,6 @@ mod tests { use crate::ast; use crate::ptx; - // page 411 - #[test] - fn to_reverse_postorder1() { - let input = vec![ - BasicBlock { - // A - start: StmtIndex(0), - pred: vec![], - succ: vec![BBIndex(1), BBIndex(2)], - }, - BasicBlock { - // B - start: StmtIndex(1), - pred: vec![BBIndex(0), BBIndex(11)], - succ: vec![BBIndex(3), BBIndex(6)], - }, - BasicBlock { - // C - start: StmtIndex(2), - pred: vec![BBIndex(0), BBIndex(4)], - succ: vec![BBIndex(4), BBIndex(7)], - }, - BasicBlock { - // D - start: StmtIndex(3), - pred: vec![BBIndex(1)], - succ: vec![BBIndex(5), BBIndex(6)], - }, - BasicBlock { - // E - start: StmtIndex(4), - pred: vec![BBIndex(2)], - succ: vec![BBIndex(2), BBIndex(7)], - }, - BasicBlock { - // F - start: StmtIndex(5), - pred: vec![BBIndex(3)], - succ: vec![BBIndex(8), BBIndex(10)], - }, - BasicBlock { - // G - start: StmtIndex(6), - pred: vec![BBIndex(1), BBIndex(3)], - succ: vec![BBIndex(9)], - }, - BasicBlock { - // H - start: StmtIndex(7), - pred: vec![BBIndex(2), BBIndex(4)], - succ: vec![BBIndex(12)], - }, - BasicBlock { - // I - start: StmtIndex(8), - pred: vec![BBIndex(5), BBIndex(9)], - succ: vec![BBIndex(11)], - }, - BasicBlock { - // J - start: StmtIndex(9), - pred: vec![BBIndex(6)], - succ: vec![BBIndex(8)], - }, - BasicBlock { - // K - start: StmtIndex(10), - pred: vec![BBIndex(5)], - succ: vec![BBIndex(11)], - }, - BasicBlock { - // L - start: StmtIndex(11), - pred: vec![BBIndex(8), BBIndex(10)], - succ: vec![BBIndex(1), BBIndex(12)], - }, - BasicBlock { - // M - start: StmtIndex(12), - pred: vec![BBIndex(7), BBIndex(11)], - succ: vec![], - }, - ]; - let rpostord = to_reverse_postorder(&input); - assert_eq!( - rpostord, - vec![ - BBIndex(0), // A - BBIndex(2), // C - BBIndex(4), // E - BBIndex(7), // H - BBIndex(1), // B - BBIndex(3), // D - BBIndex(6), // G - BBIndex(9), // J - BBIndex(5), // F - BBIndex(10), // K - BBIndex(8), // I - BBIndex(11), // L - BBIndex(12), // M - ] - ); - } - - #[test] - fn get_basic_blocks_empty() { - let func = Vec::new(); - let bbs = get_basic_blocks(&func); - assert_eq!( - bbs, - vec![BasicBlock { - start: StmtIndex(0), - pred: vec![], - succ: vec![], - }] - ); - } - - #[test] - fn get_basic_blocks_miniloop() { - let func = vec![ - Statement::Label(12), - Statement::Instruction(Instruction::Bra( - ast::BraData { uniform: false }, - Arg1 { src: 12 }, - )), - ]; - let bbs = get_basic_blocks(&func); - assert_eq!( - bbs, - vec![BasicBlock { - start: StmtIndex(0), - pred: vec![BBIndex(0)], - succ: vec![BBIndex(0)], - }] - ); - } - - // "A Simple, Fast Dominance Algorithm" - Fig. 4 - fn simple_fast_dom_fig4() -> Vec { - vec![ - BasicBlock { - start: StmtIndex(6), - pred: vec![], - succ: vec![BBIndex(1), BBIndex(2)], - }, - BasicBlock { - start: StmtIndex(5), - pred: vec![BBIndex(0)], - succ: vec![BBIndex(5)], - }, - BasicBlock { - start: StmtIndex(4), - pred: vec![BBIndex(0)], - succ: vec![BBIndex(3), BBIndex(4)], - }, - BasicBlock { - start: StmtIndex(3), - pred: vec![BBIndex(2), BBIndex(4)], - succ: vec![BBIndex(4)], - }, - BasicBlock { - start: StmtIndex(2), - pred: vec![BBIndex(2), BBIndex(3), BBIndex(5)], - succ: vec![BBIndex(3), BBIndex(5)], - }, - BasicBlock { - start: StmtIndex(1), - pred: vec![BBIndex(1), BBIndex(4)], - succ: vec![BBIndex(4)], - }, - ] - } - - #[test] - fn immediate_dominators1() { - let input = simple_fast_dom_fig4(); - let reverse_postorder = vec![ - BBIndex(0), - BBIndex(1), - BBIndex(2), - BBIndex(3), - BBIndex(4), - BBIndex(5), - ]; - let imm_dominators = immediate_dominators(&input, &reverse_postorder); - assert_eq!( - imm_dominators, - vec![ - BBIndex(0), - BBIndex(0), - BBIndex(0), - BBIndex(0), - BBIndex(0), - BBIndex(0) - ] - ); - } - - // page 411 - #[test] - fn immediate_dominators2() { - let input = vec![ - BasicBlock { - // A - start: StmtIndex(0), - pred: vec![], - succ: vec![BBIndex(1), BBIndex(2)], - }, - BasicBlock { - // B - start: StmtIndex(1), - pred: vec![BBIndex(0), BBIndex(11)], - succ: vec![BBIndex(3), BBIndex(6)], - }, - BasicBlock { - // C - start: StmtIndex(2), - pred: vec![BBIndex(0), BBIndex(4)], - succ: vec![BBIndex(4), BBIndex(7)], - }, - BasicBlock { - // D - start: StmtIndex(3), - pred: vec![BBIndex(1)], - succ: vec![BBIndex(5), BBIndex(6)], - }, - BasicBlock { - // E - start: StmtIndex(4), - pred: vec![BBIndex(2)], - succ: vec![BBIndex(2), BBIndex(7)], - }, - BasicBlock { - // F - start: StmtIndex(5), - pred: vec![BBIndex(3)], - succ: vec![BBIndex(8), BBIndex(10)], - }, - BasicBlock { - // G - start: StmtIndex(6), - pred: vec![BBIndex(1), BBIndex(3)], - succ: vec![BBIndex(9)], - }, - BasicBlock { - // H - start: StmtIndex(7), - pred: vec![BBIndex(2), BBIndex(4)], - succ: vec![BBIndex(12)], - }, - BasicBlock { - // I - start: StmtIndex(8), - pred: vec![BBIndex(5), BBIndex(9)], - succ: vec![BBIndex(11)], - }, - BasicBlock { - // J - start: StmtIndex(9), - pred: vec![BBIndex(6)], - succ: vec![BBIndex(8)], - }, - BasicBlock { - // K - start: StmtIndex(10), - pred: vec![BBIndex(5)], - succ: vec![BBIndex(11)], - }, - BasicBlock { - // L - start: StmtIndex(11), - pred: vec![BBIndex(8), BBIndex(10)], - succ: vec![BBIndex(1), BBIndex(12)], - }, - BasicBlock { - // M - start: StmtIndex(12), - pred: vec![BBIndex(7), BBIndex(11)], - succ: vec![], - }, - ]; - let reverse_postorder = vec![ - BBIndex(0), // A - BBIndex(2), // C - BBIndex(4), // E - BBIndex(7), // H - BBIndex(1), // B - BBIndex(3), // D - BBIndex(6), // G - BBIndex(9), // J - BBIndex(5), // F - BBIndex(10), // K - BBIndex(8), // I - BBIndex(11), // L - BBIndex(12), // M - ]; - let imm_dominators = immediate_dominators(&input, &reverse_postorder); - assert_eq!( - imm_dominators, - vec![ - BBIndex(0), - BBIndex(0), - BBIndex(0), - BBIndex(1), - BBIndex(2), - BBIndex(3), - BBIndex(1), - BBIndex(2), - BBIndex(1), - BBIndex(6), - BBIndex(5), - BBIndex(1), - BBIndex(0) - ] - ); - } - - fn sort_pred_succ(bb: &mut BasicBlock) { - bb.pred.sort(); - bb.succ.sort(); - } - - // page 403 - const FIG_19_4: &'static str = "{ - .reg.u32 i; - .reg.u32 j; - .reg.u32 k; - .reg.pred p; - .reg.pred q; - - mov.u32 i, 1; - mov.u32 j, 1; - mov.u32 k, 0; - block_2: - setp.ge.u32 p, k, 100; - @p bra block_4; // conditional p block_4 if_false1 - // if_false1: - setp.ge.u32 q, j, 20; - @q bra block_6; // conditional q block_6 if_false2 - // if_false2: - mov.u32 j, i; - add.u32 k, k, 1; - bra block_7; - block_6: - mov.u32 j, k; - add.u32 k, k, 2; - block_7: - bra block_2; - block_4: - ret; - }"; - - #[test] - fn get_basic_blocks_fig_19_4() { - let func = FIG_19_4; - let mut errors = Vec::new(); - let ast = ptx::FunctionBodyParser::new() - .parse(&mut errors, func) - .unwrap(); - assert_eq!(errors.len(), 0); - let mut constant_ids = HashMap::new(); - collect_label_ids(&mut constant_ids, &ast); - let registers = collect_var_definitions(&[], &ast); - let mut type_check = HashMap::new(); - let (normalized_ids, mut unique_ids) = - normalize_identifiers(ast, &constant_ids, &mut type_check, registers); - let type_check = RefCell::new(type_check); - let new_id = &mut |typ: Option| { - let to_insert = unique_ids; - { - let mut type_check = type_check.borrow_mut(); - typ.map(|t| (*type_check).insert(to_insert, t)); - } - unique_ids += 1; - to_insert - }; - let normalized_stmts = normalize_statements(normalized_ids, new_id); - let mut bbs = get_basic_blocks(&normalized_stmts); - bbs.iter_mut().for_each(sort_pred_succ); - assert_eq!( - bbs, - vec![ - BasicBlock { - start: StmtIndex(0), - pred: vec![], - succ: vec![BBIndex(1)], - }, - BasicBlock { - start: StmtIndex(6), - pred: vec![BBIndex(0), BBIndex(5)], - succ: vec![BBIndex(2), BBIndex(6)], - }, - BasicBlock { - start: StmtIndex(10), - pred: vec![BBIndex(1)], - succ: vec![BBIndex(3), BBIndex(4)], - }, - BasicBlock { - start: StmtIndex(14), - pred: vec![BBIndex(2)], - succ: vec![BBIndex(5)], - }, - BasicBlock { - start: StmtIndex(19), - pred: vec![BBIndex(2)], - succ: vec![BBIndex(5)], - }, - BasicBlock { - start: StmtIndex(23), - pred: vec![BBIndex(3), BBIndex(4)], - succ: vec![BBIndex(1)], - }, - BasicBlock { - start: StmtIndex(25), - pred: vec![BBIndex(1)], - succ: vec![], - }, - ] - ); - } - - fn cfg_fig_19_4() -> Vec { - vec![ - BasicBlock { - start: StmtIndex(0), - pred: vec![], - succ: vec![BBIndex(1)], - }, - BasicBlock { - start: StmtIndex(3), - pred: vec![BBIndex(0), BBIndex(5)], - succ: vec![BBIndex(2), BBIndex(6)], - }, - BasicBlock { - start: StmtIndex(6), - pred: vec![BBIndex(1)], - succ: vec![BBIndex(3), BBIndex(4)], - }, - BasicBlock { - start: StmtIndex(9), - pred: vec![BBIndex(2)], - succ: vec![BBIndex(5)], - }, - BasicBlock { - start: StmtIndex(13), - pred: vec![BBIndex(2)], - succ: vec![BBIndex(5)], - }, - BasicBlock { - start: StmtIndex(16), - pred: vec![BBIndex(3), BBIndex(4)], - succ: vec![BBIndex(1)], - }, - BasicBlock { - start: StmtIndex(18), - pred: vec![BBIndex(1)], - succ: vec![], - }, - ] - } - - // cfg from 19.4 with slighlty shuffled order of succ/pred - #[test] - fn reverse_postorder_fig_19_4() { - let mut cfg = cfg_fig_19_4(); - cfg[1].pred.swap(0, 1); - cfg[2].succ.swap(0, 1); - let rpostorder = vec![ - BBIndex(0), - BBIndex(1), - BBIndex(6), - BBIndex(2), - BBIndex(3), - BBIndex(4), - BBIndex(5), - ]; - let doms = immediate_dominators(&cfg, &rpostorder); - assert_eq!( - doms, - vec![ - BBIndex(0), - BBIndex(0), - BBIndex(1), - BBIndex(2), - BBIndex(2), - BBIndex(2), - BBIndex(1) - ] - ); - } - - #[test] - fn dominance_frontiers_fig_19_4() { - let cfg = cfg_fig_19_4(); - let order = to_reverse_postorder(&cfg); - let doms = immediate_dominators(&cfg, &order); - let dom_fronts = dominance_frontiers(&cfg, &doms) - .into_iter() - .map(|hs| hs.into_iter().collect::>()) - .collect::>(); - let should = vec![ - vec![], - vec![BBIndex(1)], - vec![BBIndex(1)], - vec![BBIndex(5)], - vec![BBIndex(5)], - vec![BBIndex(1)], - vec![], - ]; - assert_eq!(dom_fronts, should); - } - - #[test] - fn gather_phi_sets_fig_19_4() { - let func = FIG_19_4; - let mut errors = Vec::new(); - let fn_ast = ptx::FunctionBodyParser::new() - .parse(&mut errors, func) - .unwrap(); - assert_eq!(errors.len(), 0); - let mut constant_ids = HashMap::new(); - collect_label_ids(&mut constant_ids, &fn_ast); - assert_eq!(constant_ids.len(), 4); - - let mut type_check = HashMap::new(); - let registers = collect_var_definitions(&[], &fn_ast); - let (normalized_ids, mut unique_ids) = - normalize_identifiers(fn_ast, &constant_ids, &mut type_check, registers); - let type_check = RefCell::new(type_check); - let new_id = &mut |typ: Option| { - let to_insert = unique_ids; - { - let mut type_check = type_check.borrow_mut(); - typ.map(|t| (*type_check).insert(to_insert, t)); - } - unique_ids += 1; - to_insert - }; - let normalized_stmts = normalize_statements(normalized_ids, new_id); - let bbs = get_basic_blocks(&normalized_stmts); - let rpostorder = to_reverse_postorder(&bbs); - let doms = immediate_dominators(&bbs, &rpostorder); - let dom_fronts = dominance_frontiers(&bbs, &doms); - let phi = gather_phi_sets( - &normalized_stmts, - constant_ids.len() as u32, - unique_ids, - &bbs, - &dom_fronts, - ); - assert_eq!( - phi, - vec![ - HashSet::new(), - to_hashset(vec![5, 6]), - HashSet::new(), - HashSet::new(), - HashSet::new(), - to_hashset(vec![5, 6]), - HashSet::new() - ] - ); - } - - fn to_hashset(v: Vec) -> HashSet { - v.into_iter().collect::>() - } - - #[test] - fn ssa_rename_19_4() { - let func = FIG_19_4; - let mut errors = Vec::new(); - let fn_ast = ptx::FunctionBodyParser::new() - .parse(&mut errors, func) - .unwrap(); - assert_eq!(errors.len(), 0); - let (func, _, mut ssa_phis, unique_ids) = to_ssa(&[], fn_ast); - assert_phi_dst_id(unique_ids, &ssa_phis); - assert_dst_unique(&func, &ssa_phis); - sort_phi(&mut ssa_phis); - - let i1 = unique_ids; - let j1 = unique_ids + 1; - let j2 = get_dst_from_src(&ssa_phis[1], j1); - let j3 = get_dst(&func[10]); - let j4 = get_dst_from_src(&ssa_phis[5], j3); - let j5 = get_dst(&func[14]); - let k1 = unique_ids + 2; - let k2 = get_dst_from_src(&ssa_phis[1], k1); - let k3 = get_dst(&func[11]); - let k4 = get_dst_from_src(&ssa_phis[5], k3); - let k5 = get_dst(&func[15]); - let p1 = get_dst(&func[4]); - let q1 = get_dst(&func[7]); - let block_2 = get_label(&func[3]); - let if_false1 = get_label(&func[6]); - let if_false2 = get_label(&func[9]); - let block_6 = get_label(&func[13]); - let block_7 = get_label(&func[16]); - let block_4 = get_label(&func[18]); - - { - assert_eq!(get_ids(&func[0]), vec![i1]); - assert_eq!(get_ids(&func[1]), vec![j1]); - assert_eq!(get_ids(&func[2]), vec![k1]); - - assert_eq!( - ssa_phis[1], - to_phi(vec![(j2, vec![j4, j1]), (k2, vec![k4, k1])]) - ); - assert_eq!(get_ids(&func[3]), vec![block_2]); - assert_eq!(get_ids(&func[4]), vec![p1, k2]); - assert_eq!(get_ids(&func[5]), vec![p1, block_4, if_false1]); - - assert_eq!(get_ids(&func[6]), vec![if_false1]); - assert_eq!(get_ids(&func[7]), vec![q1, j2]); - assert_eq!(get_ids(&func[8]), vec![q1, block_6, if_false2]); - - assert_eq!(get_ids(&func[9]), vec![if_false2]); - assert_eq!(get_ids(&func[10]), vec![j3, i1]); - assert_eq!(get_ids(&func[11]), vec![k3, k2]); - assert_eq!(get_ids(&func[12]), vec![block_7]); - - assert_eq!(get_ids(&func[13]), vec![block_6]); - assert_eq!(get_ids(&func[14]), vec![j5, k2]); - assert_eq!(get_ids(&func[15]), vec![k5, k2]); - - assert_eq!( - ssa_phis[5], - to_phi(vec![(j4, vec![j3, j5]), (k4, vec![k3, k5])]) - ); - assert_eq!(get_ids(&func[16]), vec![block_7]); - assert_eq!(get_ids(&func[17]), vec![block_2]); - - assert_eq!(get_ids(&func[18]), vec![block_4]); - assert_eq!(get_ids(&func[19]), vec![]); - } - } - - fn assert_phi_dst_id(max_id: spirv::Word, phis: &[Vec]) { - for phi_set in phis { - for phi in phi_set { - assert!(phi.dst > max_id); - } - } - } - - fn assert_dst_unique(func: &[Statement], phis: &[Vec]) { - let mut seen = HashSet::new(); - for s in func { - s.visit_id(&mut |is_dst, id| { - if is_dst { - assert!(seen.insert(id)); - } - }); - } - for phi_set in phis { - for phi in phi_set { - assert!(seen.insert(phi.dst)); - } - } - } - - fn get_ids(s: &Statement) -> Vec { - let mut result = Vec::new(); - s.visit_id(&mut |_, id| { - result.push(id); - }); - result - } - - fn sort_phi(phis: &mut [Vec]) { - for phi_set in phis { - phi_set.sort_by_key(|phi| phi.dst); - } - } - - fn to_phi(raw: Vec<(spirv::Word, Vec)>) -> Vec { - let result = raw - .into_iter() - .map(|(dst, src)| PhiDef { - dst: dst, - src: src.into_iter().collect::>(), - }) - .collect::>(); - let mut result = [result]; - sort_phi(&mut result); - let [result] = result; - result - } - - fn get_dst(s: &Statement) -> spirv::Word { - let mut result = None; - s.visit_id(&mut |is_dst, id| { - if is_dst { - assert_eq!(result.replace(id), None); - } - }); - result.unwrap() - } - - fn get_label(s: &Statement) -> spirv::Word { - match s { - Statement::Label(id) => *id, - _ => panic!(), - } - } - - fn get_dst_from_src(phi: &[PhiDef], src: spirv::Word) -> spirv::Word { - for phi_set in phi { - if phi_set.src.contains(&src) { - return phi_set.dst; - } - } - panic!() - } - static SCALAR_TYPES: [ast::ScalarType; 15] = [ ast::ScalarType::B8, ast::ScalarType::B16,