Slightly improve stateful optimization

This commit is contained in:
Andrzej Janik 2021-06-11 00:00:56 +02:00
commit f0771e1fb6
2 changed files with 95 additions and 82 deletions

View file

@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
let typed_statements =
convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?;
let (func_decl, typed_statements) =
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
@ -4311,14 +4311,27 @@ fn expand_map_variables<'a, 'b>(
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
// TODO: propagate through calls?
// TODO: propagate out of calls and into calls
fn convert_to_stateful_memory_access<'a, 'input>(
func_args: &Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
func_args: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
func_body: Vec<TypedStatement>,
id_defs: &mut NumericIdResolver<'a>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut func_args = func_args.borrow_mut();
let func_args_64bit = (*func_args)
) -> Result<
(
Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
Vec<TypedStatement>,
),
TranslateError,
> {
let mut method_decl = func_args.borrow_mut();
if !method_decl.name.is_kernel() {
drop(method_decl);
return Ok((func_args, func_body));
}
if Rc::strong_count(&func_args) != 1 {
return Err(error_unreachable());
}
let func_args_64bit = (*method_decl)
.input_arguments
.iter()
.filter_map(|arg| match arg.v_type {
@ -4462,6 +4475,18 @@ fn convert_to_stateful_memory_access<'a, 'input>(
}));
remapped_ids.insert(reg, new_id);
}
for arg in (*method_decl).input_arguments.iter_mut() {
let new_id = id_defs.register_variable(
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
ast::StateSpace::Reg,
);
let old_name = arg.name;
if func_args_ptr.contains(&arg.name) {
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
arg.name = new_id;
}
remapped_ids.insert(old_name, new_id);
}
for statement in func_body {
match statement {
l @ Statement::Label(_) => result.push(l),
@ -4550,7 +4575,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
&func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
@ -4567,7 +4591,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
&func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
@ -4584,7 +4607,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
&func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
@ -4597,32 +4619,21 @@ fn convert_to_stateful_memory_access<'a, 'input>(
_ => return Err(error_unreachable()),
}
}
for arg in (*func_args).input_arguments.iter_mut() {
if func_args_ptr.contains(&arg.name) {
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
arg.state_space = ast::StateSpace::Reg;
}
}
Ok(result)
drop(method_decl);
Ok((func_args, result))
}
fn convert_to_stateful_memory_access_postprocess(
id_defs: &mut NumericIdResolver,
remapped_ids: &HashMap<spirv::Word, spirv::Word>,
func_args_ptr: &HashSet<spirv::Word>,
result: &mut Vec<TypedStatement>,
post_statements: &mut Vec<TypedStatement>,
arg_desc: ArgumentDescriptor<spirv::Word>,
expected_type: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(
match remapped_ids
.get(&arg_desc.op)
.or_else(|| func_args_ptr.get(&arg_desc.op))
{
Ok(match remapped_ids.get(&arg_desc.op) {
Some(new_id) => {
let (new_operand_type, new_operand_space, is_variable) =
id_defs.get_typed(*new_id)?;
let (new_operand_type, new_operand_space, is_variable) = id_defs.get_typed(*new_id)?;
if let Some((expected_type, expected_space)) = expected_type {
let implicit_conversion = arg_desc
.non_default_implicit_conversion
@ -4638,8 +4649,8 @@ fn convert_to_stateful_memory_access_postprocess(
}
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?;
let new_operand_type_clone = new_operand_type.clone();
let converting_id = id_defs
.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
let converting_id =
id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) {
ConversionKind::Default
} else {
@ -4670,8 +4681,7 @@ fn convert_to_stateful_memory_access_postprocess(
}
}
None => arg_desc.op,
},
)
})
}
fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {

View file

@ -219,7 +219,9 @@ unsafe fn to_str<T>(image: *const T) -> Option<&'static str> {
fn directive_to_kernel(dir: &ast::Directive<ast::ParsedArgParams>) -> Option<(String, Vec<usize>)> {
match dir {
ast::Directive::Method(ast::Function {
ast::Directive::Method(
_,
ast::Function {
func_directive:
ast::MethodDeclaration {
name: ast::MethodName::Kernel(name),
@ -227,7 +229,8 @@ fn directive_to_kernel(dir: &ast::Directive<ast::ParsedArgParams>) -> Option<(St
..
},
..
}) => {
},
) => {
let arg_sizes = input_arguments
.iter()
.map(|arg| ast::Type::from(arg.v_type.clone()).size_of())