mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Make stateful optimization build
This commit is contained in:
parent
e940b9400f
commit
9ad88ac982
1 changed files with 51 additions and 58 deletions
|
@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>(
|
|||
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
||||
//let typed_statements =
|
||||
// convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
|
||||
let typed_statements =
|
||||
convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?;
|
||||
let ssa_statements = insert_mem_ssa_statements(
|
||||
typed_statements,
|
||||
&mut numeric_id_defs,
|
||||
|
@ -4307,14 +4307,14 @@ fn expand_map_variables<'a, 'b>(
|
|||
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
|
||||
// argument expansion
|
||||
// TODO: propagate through calls?
|
||||
/*
|
||||
fn convert_to_stateful_memory_access<'a>(
|
||||
func_args: &mut SpirvMethodDecl,
|
||||
fn convert_to_stateful_memory_access<'a, 'input>(
|
||||
func_args: &Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
||||
func_body: Vec<TypedStatement>,
|
||||
id_defs: &mut NumericIdResolver<'a>,
|
||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||
let func_args_64bit = func_args
|
||||
.input
|
||||
let mut func_args = func_args.borrow_mut();
|
||||
let func_args_64bit = (*func_args)
|
||||
.input_arguments
|
||||
.iter()
|
||||
.filter_map(|arg| match arg.v_type {
|
||||
ast::Type::Scalar(ast::ScalarType::U64)
|
||||
|
@ -4445,15 +4445,15 @@ fn convert_to_stateful_memory_access<'a>(
|
|||
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
|
||||
for reg in regs_ptr_seen {
|
||||
let new_id = id_defs.register_variable(
|
||||
ast::Type::Pointer(ast::ScalarType::U8),
|
||||
ast::StateSpace::Global,
|
||||
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||
ast::StateSpace::Reg,
|
||||
);
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: new_id,
|
||||
array_init: Vec::new(),
|
||||
v_type: ast::Type::Pointer(ast::ScalarType::U8),
|
||||
state_space: ast::StateSpace::Global,
|
||||
v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
}));
|
||||
remapped_ids.insert(reg, new_id);
|
||||
}
|
||||
|
@ -4515,8 +4515,10 @@ fn convert_to_stateful_memory_access<'a>(
|
|||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let offset_neg =
|
||||
id_defs.register_intermediate(Some(ast::Type::Scalar(ast::ScalarType::S64)));
|
||||
let offset_neg = id_defs.register_intermediate(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::S64),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
result.push(Statement::Instruction(ast::Instruction::Neg(
|
||||
ast::NegDetails {
|
||||
typ: ast::ScalarType::S64,
|
||||
|
@ -4538,9 +4540,8 @@ fn convert_to_stateful_memory_access<'a>(
|
|||
}
|
||||
Statement::Instruction(inst) => {
|
||||
let mut post_statements = Vec::new();
|
||||
let new_statement = inst.visit(
|
||||
&mut |arg_desc: ArgumentDescriptor<spirv::Word>,
|
||||
expected_type: Option<&ast::Type>| {
|
||||
let new_statement =
|
||||
inst.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
|
||||
convert_to_stateful_memory_access_postprocess(
|
||||
id_defs,
|
||||
&remapped_ids,
|
||||
|
@ -4550,16 +4551,14 @@ fn convert_to_stateful_memory_access<'a>(
|
|||
arg_desc,
|
||||
expected_type,
|
||||
)
|
||||
},
|
||||
)?;
|
||||
})?;
|
||||
result.push(new_statement);
|
||||
result.extend(post_statements);
|
||||
}
|
||||
Statement::Call(call) => {
|
||||
let mut post_statements = Vec::new();
|
||||
let new_statement = call.visit(
|
||||
&mut |arg_desc: ArgumentDescriptor<spirv::Word>,
|
||||
expected_type: Option<&ast::Type>| {
|
||||
let new_statement =
|
||||
call.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
|
||||
convert_to_stateful_memory_access_postprocess(
|
||||
id_defs,
|
||||
&remapped_ids,
|
||||
|
@ -4569,16 +4568,14 @@ fn convert_to_stateful_memory_access<'a>(
|
|||
arg_desc,
|
||||
expected_type,
|
||||
)
|
||||
},
|
||||
)?;
|
||||
})?;
|
||||
result.push(new_statement);
|
||||
result.extend(post_statements);
|
||||
}
|
||||
Statement::RepackVector(pack) => {
|
||||
let mut post_statements = Vec::new();
|
||||
let new_statement = pack.visit(
|
||||
&mut |arg_desc: ArgumentDescriptor<spirv::Word>,
|
||||
expected_type: Option<&ast::Type>| {
|
||||
let new_statement =
|
||||
pack.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
|
||||
convert_to_stateful_memory_access_postprocess(
|
||||
id_defs,
|
||||
&remapped_ids,
|
||||
|
@ -4588,18 +4585,17 @@ fn convert_to_stateful_memory_access<'a>(
|
|||
arg_desc,
|
||||
expected_type,
|
||||
)
|
||||
},
|
||||
)?;
|
||||
})?;
|
||||
result.push(new_statement);
|
||||
result.extend(post_statements);
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
for arg in func_args.input.iter_mut() {
|
||||
for arg in (*func_args).input_arguments.iter_mut() {
|
||||
if func_args_ptr.contains(&arg.name) {
|
||||
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8);
|
||||
arg.state_space = ast::StateSpace::Global;
|
||||
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
|
||||
arg.state_space = ast::StateSpace::Reg;
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
|
@ -4612,43 +4608,40 @@ fn convert_to_stateful_memory_access_postprocess(
|
|||
result: &mut Vec<TypedStatement>,
|
||||
post_statements: &mut Vec<TypedStatement>,
|
||||
arg_desc: ArgumentDescriptor<spirv::Word>,
|
||||
expected_type: Option<&ast::Type>,
|
||||
expected_type: Option<(&ast::Type, ast::StateSpace)>,
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
Ok(match remapped_ids.get(&arg_desc.op) {
|
||||
Some(new_id) => {
|
||||
// We skip conversion here to trigger PtrAcces in a later pass
|
||||
let old_type = match expected_type {
|
||||
Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id),
|
||||
Some((ast::Type::Pointer(_, ast::StateSpace::Global), ast::StateSpace::Reg)) => {
|
||||
return Ok(*new_id)
|
||||
}
|
||||
_ => id_defs.get_typed(arg_desc.op)?.0,
|
||||
};
|
||||
let old_type_clone = old_type.clone();
|
||||
let converting_id = id_defs.register_intermediate(Some(old_type_clone));
|
||||
let converting_id =
|
||||
id_defs.register_intermediate(Some((old_type_clone, ast::StateSpace::Reg)));
|
||||
if arg_desc.is_dst {
|
||||
post_statements.push(Statement::Conversion(ImplicitConversion {
|
||||
src: converting_id,
|
||||
dst: *new_id,
|
||||
from_type: old_type,
|
||||
to_type: ast::Type::Pointer(
|
||||
ast::PointerType::Scalar(ast::ScalarType::U8),
|
||||
ast::StateSpace::Global,
|
||||
),
|
||||
kind: ConversionKind::BitToPtr(ast::StateSpace::Global),
|
||||
src_
|
||||
dst_sema: arg_desc.sema,
|
||||
from_space: ast::StateSpace::Reg,
|
||||
to_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||
to_space: ast::StateSpace::Reg,
|
||||
kind: ConversionKind::BitToPtr,
|
||||
}));
|
||||
converting_id
|
||||
} else {
|
||||
result.push(Statement::Conversion(ImplicitConversion {
|
||||
src: *new_id,
|
||||
dst: converting_id,
|
||||
from_type: ast::Type::Pointer(
|
||||
ast::PointerType::Scalar(ast::ScalarType::U8),
|
||||
ast::StateSpace::Global,
|
||||
),
|
||||
from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||
from_space: ast::StateSpace::Reg,
|
||||
to_type: old_type,
|
||||
kind: ConversionKind::PtrToBit(ast::ScalarType::U64),
|
||||
src_sema: arg_desc.sema,
|
||||
dst_
|
||||
to_space: ast::StateSpace::Reg,
|
||||
kind: ConversionKind::AddressOf,
|
||||
}));
|
||||
converting_id
|
||||
}
|
||||
|
@ -4660,22 +4653,23 @@ fn convert_to_stateful_memory_access_postprocess(
|
|||
}
|
||||
// We skip conversion here to trigger PtrAcces in a later pass
|
||||
let old_type = match expected_type {
|
||||
Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id),
|
||||
Some((
|
||||
ast::Type::Pointer(_, ast::StateSpace::Global),
|
||||
ast::StateSpace::Reg,
|
||||
)) => return Ok(*new_id),
|
||||
_ => id_defs.get_typed(arg_desc.op)?.0,
|
||||
};
|
||||
let old_type_clone = old_type.clone();
|
||||
let converting_id = id_defs.register_intermediate(Some(old_type));
|
||||
let converting_id =
|
||||
id_defs.register_intermediate(Some((old_type, ast::StateSpace::Reg)));
|
||||
result.push(Statement::Conversion(ImplicitConversion {
|
||||
src: *new_id,
|
||||
dst: converting_id,
|
||||
from_type: ast::Type::Pointer(
|
||||
ast::PointerType::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||
ast::StateSpace::Param,
|
||||
),
|
||||
from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||
from_space: ast::StateSpace::Reg,
|
||||
to_type: old_type_clone,
|
||||
kind: ConversionKind::PtrToPtr { spirv_ptr: false },
|
||||
src_sema: arg_desc.sema,
|
||||
dst_
|
||||
to_space: ast::StateSpace::Reg,
|
||||
kind: ConversionKind::PtrToPtr,
|
||||
}));
|
||||
converting_id
|
||||
}
|
||||
|
@ -4710,7 +4704,6 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
|
|||
_ => false,
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
|
||||
enum PtxSpecialRegister {
|
||||
|
|
Loading…
Add table
Reference in a new issue