Update type lookup map when emitting new instructions during translation

This commit is contained in:
Andrzej Janik 2020-07-20 20:15:23 +02:00
parent 872d69c714
commit 4e9a71ed38

View file

@ -180,11 +180,23 @@ fn to_ssa<'a>(
collect_arg_ids(&mut contant_ids, &mut type_check, &f_args);
collect_label_ids(&mut contant_ids, &f_body);
let registers = collect_var_definitions(&f_args, &f_body);
let (normalized_ids, unique_ids) =
let (normalized_ids, mut unique_ids) =
normalize_identifiers(f_body, &contant_ids, &mut type_check, registers);
let (normalized_stmts, unique_ids) = normalize_statements(normalized_ids, unique_ids);
let (mut func_body, unique_ids) =
insert_implicit_conversions(normalized_stmts, unique_ids, &|x| type_check[&x]);
let type_check = RefCell::new(type_check);
let new_id = &mut |typ: Option<ast::Type>| {
let to_insert = unique_ids;
{
let mut type_check = type_check.borrow_mut();
typ.map(|t| (*type_check).insert(to_insert, t));
}
unique_ids += 1;
to_insert
};
let normalized_stmts = normalize_statements(normalized_ids, new_id);
let mut func_body = insert_implicit_conversions(normalized_stmts, new_id, &|x| {
let type_check = type_check.borrow();
type_check[&x]
});
let bbs = get_basic_blocks(&func_body);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
@ -202,22 +214,16 @@ fn to_ssa<'a>(
fn normalize_statements(
func: Vec<ast::Statement<spirv::Word>>,
unique_ids: spirv::Word,
) -> (Vec<Statement>, spirv::Word) {
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
) -> Vec<Statement> {
let mut result = Vec::with_capacity(func.len());
let mut id = unique_ids;
let new_id = &mut || {
let to_insert = id;
id += 1;
to_insert
};
for s in func {
match s {
ast::Statement::Label(id) => result.push(Statement::Label(id)),
ast::Statement::Instruction(pred, inst) => {
if let Some(pred) = pred {
let mut if_true = new_id();
let mut if_false = new_id();
let mut if_true = new_id(None);
let mut if_false = new_id(None);
if pred.not {
std::mem::swap(&mut if_true, &mut if_false);
}
@ -245,13 +251,13 @@ fn normalize_statements(
ast::Statement::Variable(_) => unreachable!(),
}
}
(result, id)
result
}
#[must_use]
fn normalize_insert_instruction(
func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
instr: ast::Instruction<spirv::Word>,
) -> Instruction {
match instr {
@ -302,7 +308,7 @@ fn normalize_insert_instruction(
fn normalize_expand_arg2(
func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2<spirv::Word>,
) -> Arg2 {
@ -314,7 +320,7 @@ fn normalize_expand_arg2(
fn normalize_expand_arg2mov(
func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2Mov<spirv::Word>,
) -> Arg2 {
@ -326,7 +332,7 @@ fn normalize_expand_arg2mov(
fn normalize_expand_arg2st(
func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2St<spirv::Word>,
) -> Arg2St {
@ -338,7 +344,7 @@ fn normalize_expand_arg2st(
fn normalize_expand_arg3(
func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg3<spirv::Word>,
) -> Arg3 {
@ -351,7 +357,7 @@ fn normalize_expand_arg3(
fn normalize_expand_arg4(
func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg4<spirv::Word>,
) -> Arg4 {
@ -365,7 +371,7 @@ fn normalize_expand_arg4(
fn normalize_expand_arg5(
func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg5<spirv::Word>,
) -> Arg5 {
@ -380,7 +386,7 @@ fn normalize_expand_arg5(
fn normalize_expand_operand(
func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
opr: ast::Operand<spirv::Word>,
) -> spirv::Word {
@ -388,7 +394,7 @@ fn normalize_expand_operand(
ast::Operand::Reg(r) => r,
ast::Operand::Imm(x) => {
if let Some(typ) = inst_type() {
let id = new_id();
let id = new_id(Some(ast::Type::Scalar(typ)));
func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: typ,
@ -405,7 +411,7 @@ fn normalize_expand_operand(
fn normalize_expand_mov_operand(
func: &mut Vec<Statement>,
new_id: &mut impl FnMut() -> spirv::Word,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
opr: ast::MovOperand<spirv::Word>,
) -> spirv::Word {
@ -456,15 +462,9 @@ fn collect_var_definitions<'a>(
*/
fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
normalized_ids: Vec<Statement>,
unique_ids: spirv::Word,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
type_check: &TypeCheck,
) -> (Vec<Statement>, spirv::Word) {
let mut id = unique_ids;
let new_id = &mut || {
let temp = id;
id += 1;
temp
};
) -> Vec<Statement> {
let mut result = Vec::with_capacity(normalized_ids.len());
for s in normalized_ids.into_iter() {
match s {
@ -518,7 +518,7 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
Statement::Converison(_) => unreachable!(),
}
}
(result, id)
result
}
fn get_function_type(
@ -2007,14 +2007,11 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
}
}
fn insert_implicit_conversions_ld_src<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
>(
fn insert_implicit_conversions_ld_src<TypeCheck: Fn(spirv::Word) -> ast::Type>(
func: &mut Vec<Statement>,
instr_type: ast::Type,
type_check: &TypeCheck,
new_id: &mut NewId,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
state_space: ast::LdStateSpace,
src: spirv::Word,
) -> spirv::Word {
@ -2055,12 +2052,11 @@ fn insert_implicit_conversions_ld_src<
fn insert_implicit_conversions_ld_src_impl<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
>(
func: &mut Vec<Statement>,
type_check: &TypeCheck,
new_id: &mut NewId,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
instr_type: ast::Type,
src: spirv::Word,
should_convert: ShouldConvert,
@ -2099,15 +2095,15 @@ fn should_convert_ld_generic_src_to_bitcast(
}
#[must_use]
fn insert_conversion_src<NewId: FnMut() -> spirv::Word>(
fn insert_conversion_src(
func: &mut Vec<Statement>,
new_id: &mut NewId,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
src: spirv::Word,
src_type: ast::Type,
instr_type: ast::Type,
conv: ConversionKind,
) -> spirv::Word {
let temp_src = new_id();
let temp_src = new_id(Some(instr_type));
func.push(Statement::Converison(ImplicitConversion {
src: src,
dst: temp_src,
@ -2121,7 +2117,6 @@ fn insert_conversion_src<NewId: FnMut() -> spirv::Word>(
fn insert_with_implicit_conversion_dst<
T,
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word,
ToInstruction: FnOnce(T) -> Instruction,
@ -2129,7 +2124,7 @@ fn insert_with_implicit_conversion_dst<
func: &mut Vec<Statement>,
instr_type: ast::ScalarType,
type_check: &TypeCheck,
new_id: &mut NewId,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
should_convert: ShouldConvert,
mut t: T,
setter: Setter,
@ -2146,15 +2141,15 @@ fn insert_with_implicit_conversion_dst<
}
#[must_use]
fn get_conversion_dst<NewId: FnMut() -> spirv::Word>(
new_id: &mut NewId,
fn get_conversion_dst(
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
dst: &mut spirv::Word,
instr_type: ast::Type,
dst_type: ast::Type,
kind: ConversionKind,
) -> Statement {
let original_dst = *dst;
let temp_dst = new_id();
let temp_dst = new_id(Some(instr_type));
*dst = temp_dst;
Statement::Converison(ImplicitConversion {
src: temp_dst,
@ -2250,13 +2245,10 @@ fn should_convert_relaxed_dst(
}
}
fn insert_implicit_bitcasts<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
>(
fn insert_implicit_bitcasts<TypeCheck: Fn(spirv::Word) -> ast::Type>(
func: &mut Vec<Statement>,
type_check: &TypeCheck,
new_id: &mut NewId,
new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
mut instr: Instruction,
) {
let mut dst_coercion = None;
@ -2662,9 +2654,20 @@ mod tests {
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &ast);
let registers = collect_var_definitions(&[], &ast);
let (normalized_ids, unique_ids) =
normalize_identifiers(ast, &constant_ids, &mut HashMap::new(), registers);
let (normalized_stmts, _) = normalize_statements(normalized_ids, unique_ids);
let mut type_check = HashMap::new();
let (normalized_ids, mut unique_ids) =
normalize_identifiers(ast, &constant_ids, &mut type_check, registers);
let type_check = RefCell::new(type_check);
let new_id = &mut |typ: Option<ast::Type>| {
let to_insert = unique_ids;
{
let mut type_check = type_check.borrow_mut();
typ.map(|t| (*type_check).insert(to_insert, t));
}
unique_ids += 1;
to_insert
};
let normalized_stmts = normalize_statements(normalized_ids, new_id);
let mut bbs = get_basic_blocks(&normalized_stmts);
bbs.iter_mut().for_each(sort_pred_succ);
assert_eq!(
@ -2811,10 +2814,22 @@ mod tests {
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &fn_ast);
assert_eq!(constant_ids.len(), 4);
let mut type_check = HashMap::new();
let registers = collect_var_definitions(&[], &fn_ast);
let (normalized_ids, unique_ids) =
normalize_identifiers(fn_ast, &constant_ids, &mut HashMap::new(), registers);
let (normalized_stmts, max_id) = normalize_statements(normalized_ids, unique_ids);
let (normalized_ids, mut unique_ids) =
normalize_identifiers(fn_ast, &constant_ids, &mut type_check, registers);
let type_check = RefCell::new(type_check);
let new_id = &mut |typ: Option<ast::Type>| {
let to_insert = unique_ids;
{
let mut type_check = type_check.borrow_mut();
typ.map(|t| (*type_check).insert(to_insert, t));
}
unique_ids += 1;
to_insert
};
let normalized_stmts = normalize_statements(normalized_ids, new_id);
let bbs = get_basic_blocks(&normalized_stmts);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
@ -2822,7 +2837,7 @@ mod tests {
let phi = gather_phi_sets(
&normalized_stmts,
constant_ids.len() as u32,
max_id,
unique_ids,
&bbs,
&dom_fronts,
);