mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-03 14:50:53 +00:00
Slightly improve stateful optimization
This commit is contained in:
parent
994cfb3386
commit
f0771e1fb6
2 changed files with 95 additions and 82 deletions
|
@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>(
|
||||||
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
||||||
let typed_statements =
|
let typed_statements =
|
||||||
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
||||||
let typed_statements =
|
let (func_decl, typed_statements) =
|
||||||
convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?;
|
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
|
||||||
let ssa_statements = insert_mem_ssa_statements(
|
let ssa_statements = insert_mem_ssa_statements(
|
||||||
typed_statements,
|
typed_statements,
|
||||||
&mut numeric_id_defs,
|
&mut numeric_id_defs,
|
||||||
|
@ -4311,14 +4311,27 @@ fn expand_map_variables<'a, 'b>(
|
||||||
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
|
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
|
||||||
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
|
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
|
||||||
// argument expansion
|
// argument expansion
|
||||||
// TODO: propagate through calls?
|
// TODO: propagate out of calls and into calls
|
||||||
fn convert_to_stateful_memory_access<'a, 'input>(
|
fn convert_to_stateful_memory_access<'a, 'input>(
|
||||||
func_args: &Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
func_args: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
||||||
func_body: Vec<TypedStatement>,
|
func_body: Vec<TypedStatement>,
|
||||||
id_defs: &mut NumericIdResolver<'a>,
|
id_defs: &mut NumericIdResolver<'a>,
|
||||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
) -> Result<
|
||||||
let mut func_args = func_args.borrow_mut();
|
(
|
||||||
let func_args_64bit = (*func_args)
|
Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
||||||
|
Vec<TypedStatement>,
|
||||||
|
),
|
||||||
|
TranslateError,
|
||||||
|
> {
|
||||||
|
let mut method_decl = func_args.borrow_mut();
|
||||||
|
if !method_decl.name.is_kernel() {
|
||||||
|
drop(method_decl);
|
||||||
|
return Ok((func_args, func_body));
|
||||||
|
}
|
||||||
|
if Rc::strong_count(&func_args) != 1 {
|
||||||
|
return Err(error_unreachable());
|
||||||
|
}
|
||||||
|
let func_args_64bit = (*method_decl)
|
||||||
.input_arguments
|
.input_arguments
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|arg| match arg.v_type {
|
.filter_map(|arg| match arg.v_type {
|
||||||
|
@ -4462,6 +4475,18 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
||||||
}));
|
}));
|
||||||
remapped_ids.insert(reg, new_id);
|
remapped_ids.insert(reg, new_id);
|
||||||
}
|
}
|
||||||
|
for arg in (*method_decl).input_arguments.iter_mut() {
|
||||||
|
let new_id = id_defs.register_variable(
|
||||||
|
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
);
|
||||||
|
let old_name = arg.name;
|
||||||
|
if func_args_ptr.contains(&arg.name) {
|
||||||
|
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
|
||||||
|
arg.name = new_id;
|
||||||
|
}
|
||||||
|
remapped_ids.insert(old_name, new_id);
|
||||||
|
}
|
||||||
for statement in func_body {
|
for statement in func_body {
|
||||||
match statement {
|
match statement {
|
||||||
l @ Statement::Label(_) => result.push(l),
|
l @ Statement::Label(_) => result.push(l),
|
||||||
|
@ -4550,7 +4575,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
||||||
convert_to_stateful_memory_access_postprocess(
|
convert_to_stateful_memory_access_postprocess(
|
||||||
id_defs,
|
id_defs,
|
||||||
&remapped_ids,
|
&remapped_ids,
|
||||||
&func_args_ptr,
|
|
||||||
&mut result,
|
&mut result,
|
||||||
&mut post_statements,
|
&mut post_statements,
|
||||||
arg_desc,
|
arg_desc,
|
||||||
|
@ -4567,7 +4591,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
||||||
convert_to_stateful_memory_access_postprocess(
|
convert_to_stateful_memory_access_postprocess(
|
||||||
id_defs,
|
id_defs,
|
||||||
&remapped_ids,
|
&remapped_ids,
|
||||||
&func_args_ptr,
|
|
||||||
&mut result,
|
&mut result,
|
||||||
&mut post_statements,
|
&mut post_statements,
|
||||||
arg_desc,
|
arg_desc,
|
||||||
|
@ -4584,7 +4607,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
||||||
convert_to_stateful_memory_access_postprocess(
|
convert_to_stateful_memory_access_postprocess(
|
||||||
id_defs,
|
id_defs,
|
||||||
&remapped_ids,
|
&remapped_ids,
|
||||||
&func_args_ptr,
|
|
||||||
&mut result,
|
&mut result,
|
||||||
&mut post_statements,
|
&mut post_statements,
|
||||||
arg_desc,
|
arg_desc,
|
||||||
|
@ -4597,32 +4619,21 @@ fn convert_to_stateful_memory_access<'a, 'input>(
|
||||||
_ => return Err(error_unreachable()),
|
_ => return Err(error_unreachable()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for arg in (*func_args).input_arguments.iter_mut() {
|
drop(method_decl);
|
||||||
if func_args_ptr.contains(&arg.name) {
|
Ok((func_args, result))
|
||||||
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
|
|
||||||
arg.state_space = ast::StateSpace::Reg;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_to_stateful_memory_access_postprocess(
|
fn convert_to_stateful_memory_access_postprocess(
|
||||||
id_defs: &mut NumericIdResolver,
|
id_defs: &mut NumericIdResolver,
|
||||||
remapped_ids: &HashMap<spirv::Word, spirv::Word>,
|
remapped_ids: &HashMap<spirv::Word, spirv::Word>,
|
||||||
func_args_ptr: &HashSet<spirv::Word>,
|
|
||||||
result: &mut Vec<TypedStatement>,
|
result: &mut Vec<TypedStatement>,
|
||||||
post_statements: &mut Vec<TypedStatement>,
|
post_statements: &mut Vec<TypedStatement>,
|
||||||
arg_desc: ArgumentDescriptor<spirv::Word>,
|
arg_desc: ArgumentDescriptor<spirv::Word>,
|
||||||
expected_type: Option<(&ast::Type, ast::StateSpace)>,
|
expected_type: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
) -> Result<spirv::Word, TranslateError> {
|
) -> Result<spirv::Word, TranslateError> {
|
||||||
Ok(
|
Ok(match remapped_ids.get(&arg_desc.op) {
|
||||||
match remapped_ids
|
|
||||||
.get(&arg_desc.op)
|
|
||||||
.or_else(|| func_args_ptr.get(&arg_desc.op))
|
|
||||||
{
|
|
||||||
Some(new_id) => {
|
Some(new_id) => {
|
||||||
let (new_operand_type, new_operand_space, is_variable) =
|
let (new_operand_type, new_operand_space, is_variable) = id_defs.get_typed(*new_id)?;
|
||||||
id_defs.get_typed(*new_id)?;
|
|
||||||
if let Some((expected_type, expected_space)) = expected_type {
|
if let Some((expected_type, expected_space)) = expected_type {
|
||||||
let implicit_conversion = arg_desc
|
let implicit_conversion = arg_desc
|
||||||
.non_default_implicit_conversion
|
.non_default_implicit_conversion
|
||||||
|
@ -4638,8 +4649,8 @@ fn convert_to_stateful_memory_access_postprocess(
|
||||||
}
|
}
|
||||||
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?;
|
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?;
|
||||||
let new_operand_type_clone = new_operand_type.clone();
|
let new_operand_type_clone = new_operand_type.clone();
|
||||||
let converting_id = id_defs
|
let converting_id =
|
||||||
.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
|
id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
|
||||||
let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) {
|
let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) {
|
||||||
ConversionKind::Default
|
ConversionKind::Default
|
||||||
} else {
|
} else {
|
||||||
|
@ -4670,8 +4681,7 @@ fn convert_to_stateful_memory_access_postprocess(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => arg_desc.op,
|
None => arg_desc.op,
|
||||||
},
|
})
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
|
fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
|
||||||
|
|
|
@ -219,7 +219,9 @@ unsafe fn to_str<T>(image: *const T) -> Option<&'static str> {
|
||||||
|
|
||||||
fn directive_to_kernel(dir: &ast::Directive<ast::ParsedArgParams>) -> Option<(String, Vec<usize>)> {
|
fn directive_to_kernel(dir: &ast::Directive<ast::ParsedArgParams>) -> Option<(String, Vec<usize>)> {
|
||||||
match dir {
|
match dir {
|
||||||
ast::Directive::Method(ast::Function {
|
ast::Directive::Method(
|
||||||
|
_,
|
||||||
|
ast::Function {
|
||||||
func_directive:
|
func_directive:
|
||||||
ast::MethodDeclaration {
|
ast::MethodDeclaration {
|
||||||
name: ast::MethodName::Kernel(name),
|
name: ast::MethodName::Kernel(name),
|
||||||
|
@ -227,7 +229,8 @@ fn directive_to_kernel(dir: &ast::Directive<ast::ParsedArgParams>) -> Option<(St
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
..
|
..
|
||||||
}) => {
|
},
|
||||||
|
) => {
|
||||||
let arg_sizes = input_arguments
|
let arg_sizes = input_arguments
|
||||||
.iter()
|
.iter()
|
||||||
.map(|arg| ast::Type::from(arg.v_type.clone()).size_of())
|
.map(|arg| ast::Type::from(arg.v_type.clone()).size_of())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue