mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Slightly improve stateful optimization
This commit is contained in:
parent
994cfb3386
commit
f0771e1fb6
2 changed files with 95 additions and 82 deletions
|
@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>(
|
|||
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?;
|
||||
let (func_decl, typed_statements) =
|
||||
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
|
||||
let ssa_statements = insert_mem_ssa_statements(
|
||||
typed_statements,
|
||||
&mut numeric_id_defs,
|
||||
|
@ -4311,14 +4311,27 @@ fn expand_map_variables<'a, 'b>(
|
|||
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
|
||||
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
|
||||
// argument expansion
|
||||
// TODO: propagate through calls?
|
||||
// TODO: propagate out of calls and into calls
|
||||
fn convert_to_stateful_memory_access<'a, 'input>(
|
||||
func_args: &Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
||||
func_args: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
||||
func_body: Vec<TypedStatement>,
|
||||
id_defs: &mut NumericIdResolver<'a>,
|
||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||
let mut func_args = func_args.borrow_mut();
|
||||
let func_args_64bit = (*func_args)
|
||||
) -> Result<
|
||||
(
|
||||
Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
||||
Vec<TypedStatement>,
|
||||
),
|
||||
TranslateError,
|
||||
> {
|
||||
let mut method_decl = func_args.borrow_mut();
|
||||
if !method_decl.name.is_kernel() {
|
||||
drop(method_decl);
|
||||
return Ok((func_args, func_body));
|
||||
}
|
||||
if Rc::strong_count(&func_args) != 1 {
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
let func_args_64bit = (*method_decl)
|
||||
.input_arguments
|
||||
.iter()
|
||||
.filter_map(|arg| match arg.v_type {
|
||||
|
@ -4462,6 +4475,18 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
|||
}));
|
||||
remapped_ids.insert(reg, new_id);
|
||||
}
|
||||
for arg in (*method_decl).input_arguments.iter_mut() {
|
||||
let new_id = id_defs.register_variable(
|
||||
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||
ast::StateSpace::Reg,
|
||||
);
|
||||
let old_name = arg.name;
|
||||
if func_args_ptr.contains(&arg.name) {
|
||||
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
|
||||
arg.name = new_id;
|
||||
}
|
||||
remapped_ids.insert(old_name, new_id);
|
||||
}
|
||||
for statement in func_body {
|
||||
match statement {
|
||||
l @ Statement::Label(_) => result.push(l),
|
||||
|
@ -4550,7 +4575,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
|||
convert_to_stateful_memory_access_postprocess(
|
||||
id_defs,
|
||||
&remapped_ids,
|
||||
&func_args_ptr,
|
||||
&mut result,
|
||||
&mut post_statements,
|
||||
arg_desc,
|
||||
|
@ -4567,7 +4591,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
|||
convert_to_stateful_memory_access_postprocess(
|
||||
id_defs,
|
||||
&remapped_ids,
|
||||
&func_args_ptr,
|
||||
&mut result,
|
||||
&mut post_statements,
|
||||
arg_desc,
|
||||
|
@ -4584,7 +4607,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
|||
convert_to_stateful_memory_access_postprocess(
|
||||
id_defs,
|
||||
&remapped_ids,
|
||||
&func_args_ptr,
|
||||
&mut result,
|
||||
&mut post_statements,
|
||||
arg_desc,
|
||||
|
@ -4597,81 +4619,69 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
|||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
for arg in (*func_args).input_arguments.iter_mut() {
|
||||
if func_args_ptr.contains(&arg.name) {
|
||||
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
|
||||
arg.state_space = ast::StateSpace::Reg;
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
drop(method_decl);
|
||||
Ok((func_args, result))
|
||||
}
|
||||
|
||||
fn convert_to_stateful_memory_access_postprocess(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
remapped_ids: &HashMap<spirv::Word, spirv::Word>,
|
||||
func_args_ptr: &HashSet<spirv::Word>,
|
||||
result: &mut Vec<TypedStatement>,
|
||||
post_statements: &mut Vec<TypedStatement>,
|
||||
arg_desc: ArgumentDescriptor<spirv::Word>,
|
||||
expected_type: Option<(&ast::Type, ast::StateSpace)>,
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
Ok(
|
||||
match remapped_ids
|
||||
.get(&arg_desc.op)
|
||||
.or_else(|| func_args_ptr.get(&arg_desc.op))
|
||||
{
|
||||
Some(new_id) => {
|
||||
let (new_operand_type, new_operand_space, is_variable) =
|
||||
id_defs.get_typed(*new_id)?;
|
||||
if let Some((expected_type, expected_space)) = expected_type {
|
||||
let implicit_conversion = arg_desc
|
||||
.non_default_implicit_conversion
|
||||
.unwrap_or(default_implicit_conversion);
|
||||
if implicit_conversion(
|
||||
(new_operand_space, &new_operand_type),
|
||||
(expected_space, expected_type),
|
||||
)
|
||||
.is_ok()
|
||||
{
|
||||
return Ok(*new_id);
|
||||
}
|
||||
}
|
||||
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?;
|
||||
let new_operand_type_clone = new_operand_type.clone();
|
||||
let converting_id = id_defs
|
||||
.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
|
||||
let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) {
|
||||
ConversionKind::Default
|
||||
} else {
|
||||
ConversionKind::PtrToPtr
|
||||
};
|
||||
if arg_desc.is_dst {
|
||||
post_statements.push(Statement::Conversion(ImplicitConversion {
|
||||
src: converting_id,
|
||||
dst: *new_id,
|
||||
from_type: old_operand_type,
|
||||
from_space: old_operand_space,
|
||||
to_type: new_operand_type,
|
||||
to_space: new_operand_space,
|
||||
kind,
|
||||
}));
|
||||
converting_id
|
||||
} else {
|
||||
result.push(Statement::Conversion(ImplicitConversion {
|
||||
src: *new_id,
|
||||
dst: converting_id,
|
||||
from_type: new_operand_type,
|
||||
from_space: new_operand_space,
|
||||
to_type: old_operand_type,
|
||||
to_space: old_operand_space,
|
||||
kind,
|
||||
}));
|
||||
converting_id
|
||||
Ok(match remapped_ids.get(&arg_desc.op) {
|
||||
Some(new_id) => {
|
||||
let (new_operand_type, new_operand_space, is_variable) = id_defs.get_typed(*new_id)?;
|
||||
if let Some((expected_type, expected_space)) = expected_type {
|
||||
let implicit_conversion = arg_desc
|
||||
.non_default_implicit_conversion
|
||||
.unwrap_or(default_implicit_conversion);
|
||||
if implicit_conversion(
|
||||
(new_operand_space, &new_operand_type),
|
||||
(expected_space, expected_type),
|
||||
)
|
||||
.is_ok()
|
||||
{
|
||||
return Ok(*new_id);
|
||||
}
|
||||
}
|
||||
None => arg_desc.op,
|
||||
},
|
||||
)
|
||||
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?;
|
||||
let new_operand_type_clone = new_operand_type.clone();
|
||||
let converting_id =
|
||||
id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
|
||||
let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) {
|
||||
ConversionKind::Default
|
||||
} else {
|
||||
ConversionKind::PtrToPtr
|
||||
};
|
||||
if arg_desc.is_dst {
|
||||
post_statements.push(Statement::Conversion(ImplicitConversion {
|
||||
src: converting_id,
|
||||
dst: *new_id,
|
||||
from_type: old_operand_type,
|
||||
from_space: old_operand_space,
|
||||
to_type: new_operand_type,
|
||||
to_space: new_operand_space,
|
||||
kind,
|
||||
}));
|
||||
converting_id
|
||||
} else {
|
||||
result.push(Statement::Conversion(ImplicitConversion {
|
||||
src: *new_id,
|
||||
dst: converting_id,
|
||||
from_type: new_operand_type,
|
||||
from_space: new_operand_space,
|
||||
to_type: old_operand_type,
|
||||
to_space: old_operand_space,
|
||||
kind,
|
||||
}));
|
||||
converting_id
|
||||
}
|
||||
}
|
||||
None => arg_desc.op,
|
||||
})
|
||||
}
|
||||
|
||||
fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
|
||||
|
|
|
@ -219,15 +219,18 @@ unsafe fn to_str<T>(image: *const T) -> Option<&'static str> {
|
|||
|
||||
fn directive_to_kernel(dir: &ast::Directive<ast::ParsedArgParams>) -> Option<(String, Vec<usize>)> {
|
||||
match dir {
|
||||
ast::Directive::Method(ast::Function {
|
||||
func_directive:
|
||||
ast::MethodDeclaration {
|
||||
name: ast::MethodName::Kernel(name),
|
||||
input_arguments,
|
||||
..
|
||||
},
|
||||
..
|
||||
}) => {
|
||||
ast::Directive::Method(
|
||||
_,
|
||||
ast::Function {
|
||||
func_directive:
|
||||
ast::MethodDeclaration {
|
||||
name: ast::MethodName::Kernel(name),
|
||||
input_arguments,
|
||||
..
|
||||
},
|
||||
..
|
||||
},
|
||||
) => {
|
||||
let arg_sizes = input_arguments
|
||||
.iter()
|
||||
.map(|arg| ast::Type::from(arg.v_type.clone()).size_of())
|
||||
|
|
Loading…
Add table
Reference in a new issue