From 4e9a71ed3884e66db666b0413f5efd4ff9d97a3a Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 20 Jul 2020 20:15:23 +0200 Subject: [PATCH] Update type lookup map when emitting new instructions during translation --- ptx/src/translate.rs | 137 ++++++++++++++++++++++++------------------- 1 file changed, 76 insertions(+), 61 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 6620666..0d86066 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -180,11 +180,23 @@ fn to_ssa<'a>( collect_arg_ids(&mut contant_ids, &mut type_check, &f_args); collect_label_ids(&mut contant_ids, &f_body); let registers = collect_var_definitions(&f_args, &f_body); - let (normalized_ids, unique_ids) = + let (normalized_ids, mut unique_ids) = normalize_identifiers(f_body, &contant_ids, &mut type_check, registers); - let (normalized_stmts, unique_ids) = normalize_statements(normalized_ids, unique_ids); - let (mut func_body, unique_ids) = - insert_implicit_conversions(normalized_stmts, unique_ids, &|x| type_check[&x]); + let type_check = RefCell::new(type_check); + let new_id = &mut |typ: Option| { + let to_insert = unique_ids; + { + let mut type_check = type_check.borrow_mut(); + typ.map(|t| (*type_check).insert(to_insert, t)); + } + unique_ids += 1; + to_insert + }; + let normalized_stmts = normalize_statements(normalized_ids, new_id); + let mut func_body = insert_implicit_conversions(normalized_stmts, new_id, &|x| { + let type_check = type_check.borrow(); + type_check[&x] + }); let bbs = get_basic_blocks(&func_body); let rpostorder = to_reverse_postorder(&bbs); let doms = immediate_dominators(&bbs, &rpostorder); @@ -202,22 +214,16 @@ fn to_ssa<'a>( fn normalize_statements( func: Vec>, - unique_ids: spirv::Word, -) -> (Vec, spirv::Word) { + new_id: &mut impl FnMut(Option) -> spirv::Word, +) -> Vec { let mut result = Vec::with_capacity(func.len()); - let mut id = unique_ids; - let new_id = &mut || { - let to_insert = id; - id += 1; - to_insert - }; for s in func { match s { ast::Statement::Label(id) => result.push(Statement::Label(id)), ast::Statement::Instruction(pred, inst) => { if let Some(pred) = pred { - let mut if_true = new_id(); - let mut if_false = new_id(); + let mut if_true = new_id(None); + let mut if_false = new_id(None); if pred.not { std::mem::swap(&mut if_true, &mut if_false); } @@ -245,13 +251,13 @@ fn normalize_statements( ast::Statement::Variable(_) => unreachable!(), } } - (result, id) + result } #[must_use] fn normalize_insert_instruction( func: &mut Vec, - new_id: &mut impl FnMut() -> spirv::Word, + new_id: &mut impl FnMut(Option) -> spirv::Word, instr: ast::Instruction, ) -> Instruction { match instr { @@ -302,7 +308,7 @@ fn normalize_insert_instruction( fn normalize_expand_arg2( func: &mut Vec, - new_id: &mut impl FnMut() -> spirv::Word, + new_id: &mut impl FnMut(Option) -> spirv::Word, inst_type: &impl Fn() -> Option, a: ast::Arg2, ) -> Arg2 { @@ -314,7 +320,7 @@ fn normalize_expand_arg2( fn normalize_expand_arg2mov( func: &mut Vec, - new_id: &mut impl FnMut() -> spirv::Word, + new_id: &mut impl FnMut(Option) -> spirv::Word, inst_type: &impl Fn() -> Option, a: ast::Arg2Mov, ) -> Arg2 { @@ -326,7 +332,7 @@ fn normalize_expand_arg2mov( fn normalize_expand_arg2st( func: &mut Vec, - new_id: &mut impl FnMut() -> spirv::Word, + new_id: &mut impl FnMut(Option) -> spirv::Word, inst_type: &impl Fn() -> Option, a: ast::Arg2St, ) -> Arg2St { @@ -338,7 +344,7 @@ fn normalize_expand_arg2st( fn normalize_expand_arg3( func: &mut Vec, - new_id: &mut impl FnMut() -> spirv::Word, + new_id: &mut impl FnMut(Option) -> spirv::Word, inst_type: &impl Fn() -> Option, a: ast::Arg3, ) -> Arg3 { @@ -351,7 +357,7 @@ fn normalize_expand_arg3( fn normalize_expand_arg4( func: &mut Vec, - new_id: &mut impl FnMut() -> spirv::Word, + new_id: &mut impl FnMut(Option) -> spirv::Word, inst_type: &impl Fn() -> Option, a: ast::Arg4, ) -> Arg4 { @@ -365,7 +371,7 @@ fn normalize_expand_arg4( fn normalize_expand_arg5( func: &mut Vec, - new_id: &mut impl FnMut() -> spirv::Word, + new_id: &mut impl FnMut(Option) -> spirv::Word, inst_type: &impl Fn() -> Option, a: ast::Arg5, ) -> Arg5 { @@ -380,7 +386,7 @@ fn normalize_expand_arg5( fn normalize_expand_operand( func: &mut Vec, - new_id: &mut impl FnMut() -> spirv::Word, + new_id: &mut impl FnMut(Option) -> spirv::Word, inst_type: &impl Fn() -> Option, opr: ast::Operand, ) -> spirv::Word { @@ -388,7 +394,7 @@ fn normalize_expand_operand( ast::Operand::Reg(r) => r, ast::Operand::Imm(x) => { if let Some(typ) = inst_type() { - let id = new_id(); + let id = new_id(Some(ast::Type::Scalar(typ))); func.push(Statement::Constant(ConstantDefinition { dst: id, typ: typ, @@ -405,7 +411,7 @@ fn normalize_expand_operand( fn normalize_expand_mov_operand( func: &mut Vec, - new_id: &mut impl FnMut() -> spirv::Word, + new_id: &mut impl FnMut(Option) -> spirv::Word, inst_type: &impl Fn() -> Option, opr: ast::MovOperand, ) -> spirv::Word { @@ -456,15 +462,9 @@ fn collect_var_definitions<'a>( */ fn insert_implicit_conversions ast::Type>( normalized_ids: Vec, - unique_ids: spirv::Word, + new_id: &mut impl FnMut(Option) -> spirv::Word, type_check: &TypeCheck, -) -> (Vec, spirv::Word) { - let mut id = unique_ids; - let new_id = &mut || { - let temp = id; - id += 1; - temp - }; +) -> Vec { let mut result = Vec::with_capacity(normalized_ids.len()); for s in normalized_ids.into_iter() { match s { @@ -518,7 +518,7 @@ fn insert_implicit_conversions ast::Type>( Statement::Converison(_) => unreachable!(), } } - (result, id) + result } fn get_function_type( @@ -2007,14 +2007,11 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { } } -fn insert_implicit_conversions_ld_src< - TypeCheck: Fn(spirv::Word) -> ast::Type, - NewId: FnMut() -> spirv::Word, ->( +fn insert_implicit_conversions_ld_src ast::Type>( func: &mut Vec, instr_type: ast::Type, type_check: &TypeCheck, - new_id: &mut NewId, + new_id: &mut impl FnMut(Option) -> spirv::Word, state_space: ast::LdStateSpace, src: spirv::Word, ) -> spirv::Word { @@ -2055,12 +2052,11 @@ fn insert_implicit_conversions_ld_src< fn insert_implicit_conversions_ld_src_impl< TypeCheck: Fn(spirv::Word) -> ast::Type, - NewId: FnMut() -> spirv::Word, ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option, >( func: &mut Vec, type_check: &TypeCheck, - new_id: &mut NewId, + new_id: &mut impl FnMut(Option) -> spirv::Word, instr_type: ast::Type, src: spirv::Word, should_convert: ShouldConvert, @@ -2099,15 +2095,15 @@ fn should_convert_ld_generic_src_to_bitcast( } #[must_use] -fn insert_conversion_src spirv::Word>( +fn insert_conversion_src( func: &mut Vec, - new_id: &mut NewId, + new_id: &mut impl FnMut(Option) -> spirv::Word, src: spirv::Word, src_type: ast::Type, instr_type: ast::Type, conv: ConversionKind, ) -> spirv::Word { - let temp_src = new_id(); + let temp_src = new_id(Some(instr_type)); func.push(Statement::Converison(ImplicitConversion { src: src, dst: temp_src, @@ -2121,7 +2117,6 @@ fn insert_conversion_src spirv::Word>( fn insert_with_implicit_conversion_dst< T, TypeCheck: Fn(spirv::Word) -> ast::Type, - NewId: FnMut() -> spirv::Word, ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option, Setter: Fn(&mut T) -> &mut spirv::Word, ToInstruction: FnOnce(T) -> Instruction, @@ -2129,7 +2124,7 @@ fn insert_with_implicit_conversion_dst< func: &mut Vec, instr_type: ast::ScalarType, type_check: &TypeCheck, - new_id: &mut NewId, + new_id: &mut impl FnMut(Option) -> spirv::Word, should_convert: ShouldConvert, mut t: T, setter: Setter, @@ -2146,15 +2141,15 @@ fn insert_with_implicit_conversion_dst< } #[must_use] -fn get_conversion_dst spirv::Word>( - new_id: &mut NewId, +fn get_conversion_dst( + new_id: &mut impl FnMut(Option) -> spirv::Word, dst: &mut spirv::Word, instr_type: ast::Type, dst_type: ast::Type, kind: ConversionKind, ) -> Statement { let original_dst = *dst; - let temp_dst = new_id(); + let temp_dst = new_id(Some(instr_type)); *dst = temp_dst; Statement::Converison(ImplicitConversion { src: temp_dst, @@ -2250,13 +2245,10 @@ fn should_convert_relaxed_dst( } } -fn insert_implicit_bitcasts< - TypeCheck: Fn(spirv::Word) -> ast::Type, - NewId: FnMut() -> spirv::Word, ->( +fn insert_implicit_bitcasts ast::Type>( func: &mut Vec, type_check: &TypeCheck, - new_id: &mut NewId, + new_id: &mut impl FnMut(Option) -> spirv::Word, mut instr: Instruction, ) { let mut dst_coercion = None; @@ -2662,9 +2654,20 @@ mod tests { let mut constant_ids = HashMap::new(); collect_label_ids(&mut constant_ids, &ast); let registers = collect_var_definitions(&[], &ast); - let (normalized_ids, unique_ids) = - normalize_identifiers(ast, &constant_ids, &mut HashMap::new(), registers); - let (normalized_stmts, _) = normalize_statements(normalized_ids, unique_ids); + let mut type_check = HashMap::new(); + let (normalized_ids, mut unique_ids) = + normalize_identifiers(ast, &constant_ids, &mut type_check, registers); + let type_check = RefCell::new(type_check); + let new_id = &mut |typ: Option| { + let to_insert = unique_ids; + { + let mut type_check = type_check.borrow_mut(); + typ.map(|t| (*type_check).insert(to_insert, t)); + } + unique_ids += 1; + to_insert + }; + let normalized_stmts = normalize_statements(normalized_ids, new_id); let mut bbs = get_basic_blocks(&normalized_stmts); bbs.iter_mut().for_each(sort_pred_succ); assert_eq!( @@ -2811,10 +2814,22 @@ mod tests { let mut constant_ids = HashMap::new(); collect_label_ids(&mut constant_ids, &fn_ast); assert_eq!(constant_ids.len(), 4); + + let mut type_check = HashMap::new(); let registers = collect_var_definitions(&[], &fn_ast); - let (normalized_ids, unique_ids) = - normalize_identifiers(fn_ast, &constant_ids, &mut HashMap::new(), registers); - let (normalized_stmts, max_id) = normalize_statements(normalized_ids, unique_ids); + let (normalized_ids, mut unique_ids) = + normalize_identifiers(fn_ast, &constant_ids, &mut type_check, registers); + let type_check = RefCell::new(type_check); + let new_id = &mut |typ: Option| { + let to_insert = unique_ids; + { + let mut type_check = type_check.borrow_mut(); + typ.map(|t| (*type_check).insert(to_insert, t)); + } + unique_ids += 1; + to_insert + }; + let normalized_stmts = normalize_statements(normalized_ids, new_id); let bbs = get_basic_blocks(&normalized_stmts); let rpostorder = to_reverse_postorder(&bbs); let doms = immediate_dominators(&bbs, &rpostorder); @@ -2822,7 +2837,7 @@ mod tests { let phi = gather_phi_sets( &normalized_stmts, constant_ids.len() as u32, - max_id, + unique_ids, &bbs, &dom_fronts, );