Slightly improve stateful optimization

This commit is contained in:
Andrzej Janik 2021-06-11 00:00:56 +02:00
parent 994cfb3386
commit f0771e1fb6
2 changed files with 95 additions and 82 deletions

View file

@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
let typed_statements =
convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?;
let (func_decl, typed_statements) =
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
@ -4311,14 +4311,27 @@ fn expand_map_variables<'a, 'b>(
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
// TODO: propagate through calls?
// TODO: propagate out of calls and into calls
fn convert_to_stateful_memory_access<'a, 'input>(
func_args: &Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
func_args: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
func_body: Vec<TypedStatement>,
id_defs: &mut NumericIdResolver<'a>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut func_args = func_args.borrow_mut();
let func_args_64bit = (*func_args)
) -> Result<
(
Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
Vec<TypedStatement>,
),
TranslateError,
> {
let mut method_decl = func_args.borrow_mut();
if !method_decl.name.is_kernel() {
drop(method_decl);
return Ok((func_args, func_body));
}
if Rc::strong_count(&func_args) != 1 {
return Err(error_unreachable());
}
let func_args_64bit = (*method_decl)
.input_arguments
.iter()
.filter_map(|arg| match arg.v_type {
@ -4462,6 +4475,18 @@ fn convert_to_stateful_memory_access<'a, 'input>(
}));
remapped_ids.insert(reg, new_id);
}
for arg in (*method_decl).input_arguments.iter_mut() {
let new_id = id_defs.register_variable(
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
ast::StateSpace::Reg,
);
let old_name = arg.name;
if func_args_ptr.contains(&arg.name) {
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
arg.name = new_id;
}
remapped_ids.insert(old_name, new_id);
}
for statement in func_body {
match statement {
l @ Statement::Label(_) => result.push(l),
@ -4550,7 +4575,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
&func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
@ -4567,7 +4591,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
&func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
@ -4584,7 +4607,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
&func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
@ -4597,81 +4619,69 @@ fn convert_to_stateful_memory_access<'a, 'input>(
_ => return Err(error_unreachable()),
}
}
for arg in (*func_args).input_arguments.iter_mut() {
if func_args_ptr.contains(&arg.name) {
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
arg.state_space = ast::StateSpace::Reg;
}
}
Ok(result)
drop(method_decl);
Ok((func_args, result))
}
fn convert_to_stateful_memory_access_postprocess(
id_defs: &mut NumericIdResolver,
remapped_ids: &HashMap<spirv::Word, spirv::Word>,
func_args_ptr: &HashSet<spirv::Word>,
result: &mut Vec<TypedStatement>,
post_statements: &mut Vec<TypedStatement>,
arg_desc: ArgumentDescriptor<spirv::Word>,
expected_type: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(
match remapped_ids
.get(&arg_desc.op)
.or_else(|| func_args_ptr.get(&arg_desc.op))
{
Some(new_id) => {
let (new_operand_type, new_operand_space, is_variable) =
id_defs.get_typed(*new_id)?;
if let Some((expected_type, expected_space)) = expected_type {
let implicit_conversion = arg_desc
.non_default_implicit_conversion
.unwrap_or(default_implicit_conversion);
if implicit_conversion(
(new_operand_space, &new_operand_type),
(expected_space, expected_type),
)
.is_ok()
{
return Ok(*new_id);
}
}
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?;
let new_operand_type_clone = new_operand_type.clone();
let converting_id = id_defs
.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) {
ConversionKind::Default
} else {
ConversionKind::PtrToPtr
};
if arg_desc.is_dst {
post_statements.push(Statement::Conversion(ImplicitConversion {
src: converting_id,
dst: *new_id,
from_type: old_operand_type,
from_space: old_operand_space,
to_type: new_operand_type,
to_space: new_operand_space,
kind,
}));
converting_id
} else {
result.push(Statement::Conversion(ImplicitConversion {
src: *new_id,
dst: converting_id,
from_type: new_operand_type,
from_space: new_operand_space,
to_type: old_operand_type,
to_space: old_operand_space,
kind,
}));
converting_id
Ok(match remapped_ids.get(&arg_desc.op) {
Some(new_id) => {
let (new_operand_type, new_operand_space, is_variable) = id_defs.get_typed(*new_id)?;
if let Some((expected_type, expected_space)) = expected_type {
let implicit_conversion = arg_desc
.non_default_implicit_conversion
.unwrap_or(default_implicit_conversion);
if implicit_conversion(
(new_operand_space, &new_operand_type),
(expected_space, expected_type),
)
.is_ok()
{
return Ok(*new_id);
}
}
None => arg_desc.op,
},
)
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?;
let new_operand_type_clone = new_operand_type.clone();
let converting_id =
id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) {
ConversionKind::Default
} else {
ConversionKind::PtrToPtr
};
if arg_desc.is_dst {
post_statements.push(Statement::Conversion(ImplicitConversion {
src: converting_id,
dst: *new_id,
from_type: old_operand_type,
from_space: old_operand_space,
to_type: new_operand_type,
to_space: new_operand_space,
kind,
}));
converting_id
} else {
result.push(Statement::Conversion(ImplicitConversion {
src: *new_id,
dst: converting_id,
from_type: new_operand_type,
from_space: new_operand_space,
to_type: old_operand_type,
to_space: old_operand_space,
kind,
}));
converting_id
}
}
None => arg_desc.op,
})
}
fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {

View file

@ -219,15 +219,18 @@ unsafe fn to_str<T>(image: *const T) -> Option<&'static str> {
fn directive_to_kernel(dir: &ast::Directive<ast::ParsedArgParams>) -> Option<(String, Vec<usize>)> {
match dir {
ast::Directive::Method(ast::Function {
func_directive:
ast::MethodDeclaration {
name: ast::MethodName::Kernel(name),
input_arguments,
..
},
..
}) => {
ast::Directive::Method(
_,
ast::Function {
func_directive:
ast::MethodDeclaration {
name: ast::MethodName::Kernel(name),
input_arguments,
..
},
..
},
) => {
let arg_sizes = input_arguments
.iter()
.map(|arg| ast::Type::from(arg.v_type.clone()).size_of())