Make stateful optimization build

This commit is contained in:
Andrzej Janik 2021-06-06 18:14:49 +02:00
parent e940b9400f
commit 9ad88ac982

View file

@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
//let typed_statements =
// convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
let typed_statements =
convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
@ -4307,14 +4307,14 @@ fn expand_map_variables<'a, 'b>(
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
// TODO: propagate through calls?
/*
fn convert_to_stateful_memory_access<'a>(
func_args: &mut SpirvMethodDecl,
fn convert_to_stateful_memory_access<'a, 'input>(
func_args: &Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
func_body: Vec<TypedStatement>,
id_defs: &mut NumericIdResolver<'a>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let func_args_64bit = func_args
.input
let mut func_args = func_args.borrow_mut();
let func_args_64bit = (*func_args)
.input_arguments
.iter()
.filter_map(|arg| match arg.v_type {
ast::Type::Scalar(ast::ScalarType::U64)
@ -4445,15 +4445,15 @@ fn convert_to_stateful_memory_access<'a>(
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
for reg in regs_ptr_seen {
let new_id = id_defs.register_variable(
ast::Type::Pointer(ast::ScalarType::U8),
ast::StateSpace::Global,
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
ast::StateSpace::Reg,
);
result.push(Statement::Variable(ast::Variable {
align: None,
name: new_id,
array_init: Vec::new(),
v_type: ast::Type::Pointer(ast::ScalarType::U8),
state_space: ast::StateSpace::Global,
v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
state_space: ast::StateSpace::Reg,
}));
remapped_ids.insert(reg, new_id);
}
@ -4515,8 +4515,10 @@ fn convert_to_stateful_memory_access<'a>(
}
_ => return Err(error_unreachable()),
};
let offset_neg =
id_defs.register_intermediate(Some(ast::Type::Scalar(ast::ScalarType::S64)));
let offset_neg = id_defs.register_intermediate(Some((
ast::Type::Scalar(ast::ScalarType::S64),
ast::StateSpace::Reg,
)));
result.push(Statement::Instruction(ast::Instruction::Neg(
ast::NegDetails {
typ: ast::ScalarType::S64,
@ -4538,9 +4540,8 @@ fn convert_to_stateful_memory_access<'a>(
}
Statement::Instruction(inst) => {
let mut post_statements = Vec::new();
let new_statement = inst.visit(
&mut |arg_desc: ArgumentDescriptor<spirv::Word>,
expected_type: Option<&ast::Type>| {
let new_statement =
inst.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
@ -4550,16 +4551,14 @@ fn convert_to_stateful_memory_access<'a>(
arg_desc,
expected_type,
)
},
)?;
})?;
result.push(new_statement);
result.extend(post_statements);
}
Statement::Call(call) => {
let mut post_statements = Vec::new();
let new_statement = call.visit(
&mut |arg_desc: ArgumentDescriptor<spirv::Word>,
expected_type: Option<&ast::Type>| {
let new_statement =
call.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
@ -4569,16 +4568,14 @@ fn convert_to_stateful_memory_access<'a>(
arg_desc,
expected_type,
)
},
)?;
})?;
result.push(new_statement);
result.extend(post_statements);
}
Statement::RepackVector(pack) => {
let mut post_statements = Vec::new();
let new_statement = pack.visit(
&mut |arg_desc: ArgumentDescriptor<spirv::Word>,
expected_type: Option<&ast::Type>| {
let new_statement =
pack.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
@ -4588,18 +4585,17 @@ fn convert_to_stateful_memory_access<'a>(
arg_desc,
expected_type,
)
},
)?;
})?;
result.push(new_statement);
result.extend(post_statements);
}
_ => return Err(error_unreachable()),
}
}
for arg in func_args.input.iter_mut() {
for arg in (*func_args).input_arguments.iter_mut() {
if func_args_ptr.contains(&arg.name) {
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8);
arg.state_space = ast::StateSpace::Global;
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
arg.state_space = ast::StateSpace::Reg;
}
}
Ok(result)
@ -4612,43 +4608,40 @@ fn convert_to_stateful_memory_access_postprocess(
result: &mut Vec<TypedStatement>,
post_statements: &mut Vec<TypedStatement>,
arg_desc: ArgumentDescriptor<spirv::Word>,
expected_type: Option<&ast::Type>,
expected_type: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(match remapped_ids.get(&arg_desc.op) {
Some(new_id) => {
// We skip conversion here to trigger PtrAcces in a later pass
let old_type = match expected_type {
Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id),
Some((ast::Type::Pointer(_, ast::StateSpace::Global), ast::StateSpace::Reg)) => {
return Ok(*new_id)
}
_ => id_defs.get_typed(arg_desc.op)?.0,
};
let old_type_clone = old_type.clone();
let converting_id = id_defs.register_intermediate(Some(old_type_clone));
let converting_id =
id_defs.register_intermediate(Some((old_type_clone, ast::StateSpace::Reg)));
if arg_desc.is_dst {
post_statements.push(Statement::Conversion(ImplicitConversion {
src: converting_id,
dst: *new_id,
from_type: old_type,
to_type: ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::U8),
ast::StateSpace::Global,
),
kind: ConversionKind::BitToPtr(ast::StateSpace::Global),
src_
dst_sema: arg_desc.sema,
from_space: ast::StateSpace::Reg,
to_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
to_space: ast::StateSpace::Reg,
kind: ConversionKind::BitToPtr,
}));
converting_id
} else {
result.push(Statement::Conversion(ImplicitConversion {
src: *new_id,
dst: converting_id,
from_type: ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::U8),
ast::StateSpace::Global,
),
from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
from_space: ast::StateSpace::Reg,
to_type: old_type,
kind: ConversionKind::PtrToBit(ast::ScalarType::U64),
src_sema: arg_desc.sema,
dst_
to_space: ast::StateSpace::Reg,
kind: ConversionKind::AddressOf,
}));
converting_id
}
@ -4660,22 +4653,23 @@ fn convert_to_stateful_memory_access_postprocess(
}
// We skip conversion here to trigger PtrAcces in a later pass
let old_type = match expected_type {
Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id),
Some((
ast::Type::Pointer(_, ast::StateSpace::Global),
ast::StateSpace::Reg,
)) => return Ok(*new_id),
_ => id_defs.get_typed(arg_desc.op)?.0,
};
let old_type_clone = old_type.clone();
let converting_id = id_defs.register_intermediate(Some(old_type));
let converting_id =
id_defs.register_intermediate(Some((old_type, ast::StateSpace::Reg)));
result.push(Statement::Conversion(ImplicitConversion {
src: *new_id,
dst: converting_id,
from_type: ast::Type::Pointer(
ast::PointerType::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
ast::StateSpace::Param,
),
from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
from_space: ast::StateSpace::Reg,
to_type: old_type_clone,
kind: ConversionKind::PtrToPtr { spirv_ptr: false },
src_sema: arg_desc.sema,
dst_
to_space: ast::StateSpace::Reg,
kind: ConversionKind::PtrToPtr,
}));
converting_id
}
@ -4710,7 +4704,6 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
_ => false,
}
}
*/
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
enum PtxSpecialRegister {