Make stateful optimization build

This commit is contained in:
Andrzej Janik 2021-06-06 18:14:49 +02:00
commit 9ad88ac982

View file

@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements = let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
//let typed_statements = let typed_statements =
// convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements( let ssa_statements = insert_mem_ssa_statements(
typed_statements, typed_statements,
&mut numeric_id_defs, &mut numeric_id_defs,
@ -4307,14 +4307,14 @@ fn expand_map_variables<'a, 'b>(
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after // TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion // argument expansion
// TODO: propagate through calls? // TODO: propagate through calls?
/* fn convert_to_stateful_memory_access<'a, 'input>(
fn convert_to_stateful_memory_access<'a>( func_args: &Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
func_args: &mut SpirvMethodDecl,
func_body: Vec<TypedStatement>, func_body: Vec<TypedStatement>,
id_defs: &mut NumericIdResolver<'a>, id_defs: &mut NumericIdResolver<'a>,
) -> Result<Vec<TypedStatement>, TranslateError> { ) -> Result<Vec<TypedStatement>, TranslateError> {
let func_args_64bit = func_args let mut func_args = func_args.borrow_mut();
.input let func_args_64bit = (*func_args)
.input_arguments
.iter() .iter()
.filter_map(|arg| match arg.v_type { .filter_map(|arg| match arg.v_type {
ast::Type::Scalar(ast::ScalarType::U64) ast::Type::Scalar(ast::ScalarType::U64)
@ -4445,15 +4445,15 @@ fn convert_to_stateful_memory_access<'a>(
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
for reg in regs_ptr_seen { for reg in regs_ptr_seen {
let new_id = id_defs.register_variable( let new_id = id_defs.register_variable(
ast::Type::Pointer(ast::ScalarType::U8), ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
ast::StateSpace::Global, ast::StateSpace::Reg,
); );
result.push(Statement::Variable(ast::Variable { result.push(Statement::Variable(ast::Variable {
align: None, align: None,
name: new_id, name: new_id,
array_init: Vec::new(), array_init: Vec::new(),
v_type: ast::Type::Pointer(ast::ScalarType::U8), v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
state_space: ast::StateSpace::Global, state_space: ast::StateSpace::Reg,
})); }));
remapped_ids.insert(reg, new_id); remapped_ids.insert(reg, new_id);
} }
@ -4515,8 +4515,10 @@ fn convert_to_stateful_memory_access<'a>(
} }
_ => return Err(error_unreachable()), _ => return Err(error_unreachable()),
}; };
let offset_neg = let offset_neg = id_defs.register_intermediate(Some((
id_defs.register_intermediate(Some(ast::Type::Scalar(ast::ScalarType::S64))); ast::Type::Scalar(ast::ScalarType::S64),
ast::StateSpace::Reg,
)));
result.push(Statement::Instruction(ast::Instruction::Neg( result.push(Statement::Instruction(ast::Instruction::Neg(
ast::NegDetails { ast::NegDetails {
typ: ast::ScalarType::S64, typ: ast::ScalarType::S64,
@ -4538,9 +4540,8 @@ fn convert_to_stateful_memory_access<'a>(
} }
Statement::Instruction(inst) => { Statement::Instruction(inst) => {
let mut post_statements = Vec::new(); let mut post_statements = Vec::new();
let new_statement = inst.visit( let new_statement =
&mut |arg_desc: ArgumentDescriptor<spirv::Word>, inst.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
expected_type: Option<&ast::Type>| {
convert_to_stateful_memory_access_postprocess( convert_to_stateful_memory_access_postprocess(
id_defs, id_defs,
&remapped_ids, &remapped_ids,
@ -4550,16 +4551,14 @@ fn convert_to_stateful_memory_access<'a>(
arg_desc, arg_desc,
expected_type, expected_type,
) )
}, })?;
)?;
result.push(new_statement); result.push(new_statement);
result.extend(post_statements); result.extend(post_statements);
} }
Statement::Call(call) => { Statement::Call(call) => {
let mut post_statements = Vec::new(); let mut post_statements = Vec::new();
let new_statement = call.visit( let new_statement =
&mut |arg_desc: ArgumentDescriptor<spirv::Word>, call.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
expected_type: Option<&ast::Type>| {
convert_to_stateful_memory_access_postprocess( convert_to_stateful_memory_access_postprocess(
id_defs, id_defs,
&remapped_ids, &remapped_ids,
@ -4569,16 +4568,14 @@ fn convert_to_stateful_memory_access<'a>(
arg_desc, arg_desc,
expected_type, expected_type,
) )
}, })?;
)?;
result.push(new_statement); result.push(new_statement);
result.extend(post_statements); result.extend(post_statements);
} }
Statement::RepackVector(pack) => { Statement::RepackVector(pack) => {
let mut post_statements = Vec::new(); let mut post_statements = Vec::new();
let new_statement = pack.visit( let new_statement =
&mut |arg_desc: ArgumentDescriptor<spirv::Word>, pack.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
expected_type: Option<&ast::Type>| {
convert_to_stateful_memory_access_postprocess( convert_to_stateful_memory_access_postprocess(
id_defs, id_defs,
&remapped_ids, &remapped_ids,
@ -4588,18 +4585,17 @@ fn convert_to_stateful_memory_access<'a>(
arg_desc, arg_desc,
expected_type, expected_type,
) )
}, })?;
)?;
result.push(new_statement); result.push(new_statement);
result.extend(post_statements); result.extend(post_statements);
} }
_ => return Err(error_unreachable()), _ => return Err(error_unreachable()),
} }
} }
for arg in func_args.input.iter_mut() { for arg in (*func_args).input_arguments.iter_mut() {
if func_args_ptr.contains(&arg.name) { if func_args_ptr.contains(&arg.name) {
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8); arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
arg.state_space = ast::StateSpace::Global; arg.state_space = ast::StateSpace::Reg;
} }
} }
Ok(result) Ok(result)
@ -4612,43 +4608,40 @@ fn convert_to_stateful_memory_access_postprocess(
result: &mut Vec<TypedStatement>, result: &mut Vec<TypedStatement>,
post_statements: &mut Vec<TypedStatement>, post_statements: &mut Vec<TypedStatement>,
arg_desc: ArgumentDescriptor<spirv::Word>, arg_desc: ArgumentDescriptor<spirv::Word>,
expected_type: Option<&ast::Type>, expected_type: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> { ) -> Result<spirv::Word, TranslateError> {
Ok(match remapped_ids.get(&arg_desc.op) { Ok(match remapped_ids.get(&arg_desc.op) {
Some(new_id) => { Some(new_id) => {
// We skip conversion here to trigger PtrAcces in a later pass // We skip conversion here to trigger PtrAcces in a later pass
let old_type = match expected_type { let old_type = match expected_type {
Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id), Some((ast::Type::Pointer(_, ast::StateSpace::Global), ast::StateSpace::Reg)) => {
return Ok(*new_id)
}
_ => id_defs.get_typed(arg_desc.op)?.0, _ => id_defs.get_typed(arg_desc.op)?.0,
}; };
let old_type_clone = old_type.clone(); let old_type_clone = old_type.clone();
let converting_id = id_defs.register_intermediate(Some(old_type_clone)); let converting_id =
id_defs.register_intermediate(Some((old_type_clone, ast::StateSpace::Reg)));
if arg_desc.is_dst { if arg_desc.is_dst {
post_statements.push(Statement::Conversion(ImplicitConversion { post_statements.push(Statement::Conversion(ImplicitConversion {
src: converting_id, src: converting_id,
dst: *new_id, dst: *new_id,
from_type: old_type, from_type: old_type,
to_type: ast::Type::Pointer( from_space: ast::StateSpace::Reg,
ast::PointerType::Scalar(ast::ScalarType::U8), to_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
ast::StateSpace::Global, to_space: ast::StateSpace::Reg,
), kind: ConversionKind::BitToPtr,
kind: ConversionKind::BitToPtr(ast::StateSpace::Global),
src_
dst_sema: arg_desc.sema,
})); }));
converting_id converting_id
} else { } else {
result.push(Statement::Conversion(ImplicitConversion { result.push(Statement::Conversion(ImplicitConversion {
src: *new_id, src: *new_id,
dst: converting_id, dst: converting_id,
from_type: ast::Type::Pointer( from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
ast::PointerType::Scalar(ast::ScalarType::U8), from_space: ast::StateSpace::Reg,
ast::StateSpace::Global,
),
to_type: old_type, to_type: old_type,
kind: ConversionKind::PtrToBit(ast::ScalarType::U64), to_space: ast::StateSpace::Reg,
src_sema: arg_desc.sema, kind: ConversionKind::AddressOf,
dst_
})); }));
converting_id converting_id
} }
@ -4660,22 +4653,23 @@ fn convert_to_stateful_memory_access_postprocess(
} }
// We skip conversion here to trigger PtrAcces in a later pass // We skip conversion here to trigger PtrAcces in a later pass
let old_type = match expected_type { let old_type = match expected_type {
Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id), Some((
ast::Type::Pointer(_, ast::StateSpace::Global),
ast::StateSpace::Reg,
)) => return Ok(*new_id),
_ => id_defs.get_typed(arg_desc.op)?.0, _ => id_defs.get_typed(arg_desc.op)?.0,
}; };
let old_type_clone = old_type.clone(); let old_type_clone = old_type.clone();
let converting_id = id_defs.register_intermediate(Some(old_type)); let converting_id =
id_defs.register_intermediate(Some((old_type, ast::StateSpace::Reg)));
result.push(Statement::Conversion(ImplicitConversion { result.push(Statement::Conversion(ImplicitConversion {
src: *new_id, src: *new_id,
dst: converting_id, dst: converting_id,
from_type: ast::Type::Pointer( from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
ast::PointerType::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), from_space: ast::StateSpace::Reg,
ast::StateSpace::Param,
),
to_type: old_type_clone, to_type: old_type_clone,
kind: ConversionKind::PtrToPtr { spirv_ptr: false }, to_space: ast::StateSpace::Reg,
src_sema: arg_desc.sema, kind: ConversionKind::PtrToPtr,
dst_
})); }));
converting_id converting_id
} }
@ -4710,7 +4704,6 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
_ => false, _ => false,
} }
} }
*/
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
enum PtxSpecialRegister { enum PtxSpecialRegister {