mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-05 07:41:25 +00:00
Make stateful optimization build
This commit is contained in:
parent
e940b9400f
commit
9ad88ac982
1 changed files with 51 additions and 58 deletions
|
@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>(
|
||||||
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
|
||||||
let typed_statements =
|
let typed_statements =
|
||||||
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
|
||||||
//let typed_statements =
|
let typed_statements =
|
||||||
// convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
|
convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?;
|
||||||
let ssa_statements = insert_mem_ssa_statements(
|
let ssa_statements = insert_mem_ssa_statements(
|
||||||
typed_statements,
|
typed_statements,
|
||||||
&mut numeric_id_defs,
|
&mut numeric_id_defs,
|
||||||
|
@ -4307,14 +4307,14 @@ fn expand_map_variables<'a, 'b>(
|
||||||
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
|
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
|
||||||
// argument expansion
|
// argument expansion
|
||||||
// TODO: propagate through calls?
|
// TODO: propagate through calls?
|
||||||
/*
|
fn convert_to_stateful_memory_access<'a, 'input>(
|
||||||
fn convert_to_stateful_memory_access<'a>(
|
func_args: &Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
|
||||||
func_args: &mut SpirvMethodDecl,
|
|
||||||
func_body: Vec<TypedStatement>,
|
func_body: Vec<TypedStatement>,
|
||||||
id_defs: &mut NumericIdResolver<'a>,
|
id_defs: &mut NumericIdResolver<'a>,
|
||||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||||
let func_args_64bit = func_args
|
let mut func_args = func_args.borrow_mut();
|
||||||
.input
|
let func_args_64bit = (*func_args)
|
||||||
|
.input_arguments
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|arg| match arg.v_type {
|
.filter_map(|arg| match arg.v_type {
|
||||||
ast::Type::Scalar(ast::ScalarType::U64)
|
ast::Type::Scalar(ast::ScalarType::U64)
|
||||||
|
@ -4445,15 +4445,15 @@ fn convert_to_stateful_memory_access<'a>(
|
||||||
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
|
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
|
||||||
for reg in regs_ptr_seen {
|
for reg in regs_ptr_seen {
|
||||||
let new_id = id_defs.register_variable(
|
let new_id = id_defs.register_variable(
|
||||||
ast::Type::Pointer(ast::ScalarType::U8),
|
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||||
ast::StateSpace::Global,
|
ast::StateSpace::Reg,
|
||||||
);
|
);
|
||||||
result.push(Statement::Variable(ast::Variable {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
align: None,
|
align: None,
|
||||||
name: new_id,
|
name: new_id,
|
||||||
array_init: Vec::new(),
|
array_init: Vec::new(),
|
||||||
v_type: ast::Type::Pointer(ast::ScalarType::U8),
|
v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||||
state_space: ast::StateSpace::Global,
|
state_space: ast::StateSpace::Reg,
|
||||||
}));
|
}));
|
||||||
remapped_ids.insert(reg, new_id);
|
remapped_ids.insert(reg, new_id);
|
||||||
}
|
}
|
||||||
|
@ -4515,8 +4515,10 @@ fn convert_to_stateful_memory_access<'a>(
|
||||||
}
|
}
|
||||||
_ => return Err(error_unreachable()),
|
_ => return Err(error_unreachable()),
|
||||||
};
|
};
|
||||||
let offset_neg =
|
let offset_neg = id_defs.register_intermediate(Some((
|
||||||
id_defs.register_intermediate(Some(ast::Type::Scalar(ast::ScalarType::S64)));
|
ast::Type::Scalar(ast::ScalarType::S64),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)));
|
||||||
result.push(Statement::Instruction(ast::Instruction::Neg(
|
result.push(Statement::Instruction(ast::Instruction::Neg(
|
||||||
ast::NegDetails {
|
ast::NegDetails {
|
||||||
typ: ast::ScalarType::S64,
|
typ: ast::ScalarType::S64,
|
||||||
|
@ -4538,9 +4540,8 @@ fn convert_to_stateful_memory_access<'a>(
|
||||||
}
|
}
|
||||||
Statement::Instruction(inst) => {
|
Statement::Instruction(inst) => {
|
||||||
let mut post_statements = Vec::new();
|
let mut post_statements = Vec::new();
|
||||||
let new_statement = inst.visit(
|
let new_statement =
|
||||||
&mut |arg_desc: ArgumentDescriptor<spirv::Word>,
|
inst.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
|
||||||
expected_type: Option<&ast::Type>| {
|
|
||||||
convert_to_stateful_memory_access_postprocess(
|
convert_to_stateful_memory_access_postprocess(
|
||||||
id_defs,
|
id_defs,
|
||||||
&remapped_ids,
|
&remapped_ids,
|
||||||
|
@ -4550,16 +4551,14 @@ fn convert_to_stateful_memory_access<'a>(
|
||||||
arg_desc,
|
arg_desc,
|
||||||
expected_type,
|
expected_type,
|
||||||
)
|
)
|
||||||
},
|
})?;
|
||||||
)?;
|
|
||||||
result.push(new_statement);
|
result.push(new_statement);
|
||||||
result.extend(post_statements);
|
result.extend(post_statements);
|
||||||
}
|
}
|
||||||
Statement::Call(call) => {
|
Statement::Call(call) => {
|
||||||
let mut post_statements = Vec::new();
|
let mut post_statements = Vec::new();
|
||||||
let new_statement = call.visit(
|
let new_statement =
|
||||||
&mut |arg_desc: ArgumentDescriptor<spirv::Word>,
|
call.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
|
||||||
expected_type: Option<&ast::Type>| {
|
|
||||||
convert_to_stateful_memory_access_postprocess(
|
convert_to_stateful_memory_access_postprocess(
|
||||||
id_defs,
|
id_defs,
|
||||||
&remapped_ids,
|
&remapped_ids,
|
||||||
|
@ -4569,16 +4568,14 @@ fn convert_to_stateful_memory_access<'a>(
|
||||||
arg_desc,
|
arg_desc,
|
||||||
expected_type,
|
expected_type,
|
||||||
)
|
)
|
||||||
},
|
})?;
|
||||||
)?;
|
|
||||||
result.push(new_statement);
|
result.push(new_statement);
|
||||||
result.extend(post_statements);
|
result.extend(post_statements);
|
||||||
}
|
}
|
||||||
Statement::RepackVector(pack) => {
|
Statement::RepackVector(pack) => {
|
||||||
let mut post_statements = Vec::new();
|
let mut post_statements = Vec::new();
|
||||||
let new_statement = pack.visit(
|
let new_statement =
|
||||||
&mut |arg_desc: ArgumentDescriptor<spirv::Word>,
|
pack.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
|
||||||
expected_type: Option<&ast::Type>| {
|
|
||||||
convert_to_stateful_memory_access_postprocess(
|
convert_to_stateful_memory_access_postprocess(
|
||||||
id_defs,
|
id_defs,
|
||||||
&remapped_ids,
|
&remapped_ids,
|
||||||
|
@ -4588,18 +4585,17 @@ fn convert_to_stateful_memory_access<'a>(
|
||||||
arg_desc,
|
arg_desc,
|
||||||
expected_type,
|
expected_type,
|
||||||
)
|
)
|
||||||
},
|
})?;
|
||||||
)?;
|
|
||||||
result.push(new_statement);
|
result.push(new_statement);
|
||||||
result.extend(post_statements);
|
result.extend(post_statements);
|
||||||
}
|
}
|
||||||
_ => return Err(error_unreachable()),
|
_ => return Err(error_unreachable()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for arg in func_args.input.iter_mut() {
|
for arg in (*func_args).input_arguments.iter_mut() {
|
||||||
if func_args_ptr.contains(&arg.name) {
|
if func_args_ptr.contains(&arg.name) {
|
||||||
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8);
|
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
|
||||||
arg.state_space = ast::StateSpace::Global;
|
arg.state_space = ast::StateSpace::Reg;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(result)
|
Ok(result)
|
||||||
|
@ -4612,43 +4608,40 @@ fn convert_to_stateful_memory_access_postprocess(
|
||||||
result: &mut Vec<TypedStatement>,
|
result: &mut Vec<TypedStatement>,
|
||||||
post_statements: &mut Vec<TypedStatement>,
|
post_statements: &mut Vec<TypedStatement>,
|
||||||
arg_desc: ArgumentDescriptor<spirv::Word>,
|
arg_desc: ArgumentDescriptor<spirv::Word>,
|
||||||
expected_type: Option<&ast::Type>,
|
expected_type: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
) -> Result<spirv::Word, TranslateError> {
|
) -> Result<spirv::Word, TranslateError> {
|
||||||
Ok(match remapped_ids.get(&arg_desc.op) {
|
Ok(match remapped_ids.get(&arg_desc.op) {
|
||||||
Some(new_id) => {
|
Some(new_id) => {
|
||||||
// We skip conversion here to trigger PtrAcces in a later pass
|
// We skip conversion here to trigger PtrAcces in a later pass
|
||||||
let old_type = match expected_type {
|
let old_type = match expected_type {
|
||||||
Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id),
|
Some((ast::Type::Pointer(_, ast::StateSpace::Global), ast::StateSpace::Reg)) => {
|
||||||
|
return Ok(*new_id)
|
||||||
|
}
|
||||||
_ => id_defs.get_typed(arg_desc.op)?.0,
|
_ => id_defs.get_typed(arg_desc.op)?.0,
|
||||||
};
|
};
|
||||||
let old_type_clone = old_type.clone();
|
let old_type_clone = old_type.clone();
|
||||||
let converting_id = id_defs.register_intermediate(Some(old_type_clone));
|
let converting_id =
|
||||||
|
id_defs.register_intermediate(Some((old_type_clone, ast::StateSpace::Reg)));
|
||||||
if arg_desc.is_dst {
|
if arg_desc.is_dst {
|
||||||
post_statements.push(Statement::Conversion(ImplicitConversion {
|
post_statements.push(Statement::Conversion(ImplicitConversion {
|
||||||
src: converting_id,
|
src: converting_id,
|
||||||
dst: *new_id,
|
dst: *new_id,
|
||||||
from_type: old_type,
|
from_type: old_type,
|
||||||
to_type: ast::Type::Pointer(
|
from_space: ast::StateSpace::Reg,
|
||||||
ast::PointerType::Scalar(ast::ScalarType::U8),
|
to_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||||
ast::StateSpace::Global,
|
to_space: ast::StateSpace::Reg,
|
||||||
),
|
kind: ConversionKind::BitToPtr,
|
||||||
kind: ConversionKind::BitToPtr(ast::StateSpace::Global),
|
|
||||||
src_
|
|
||||||
dst_sema: arg_desc.sema,
|
|
||||||
}));
|
}));
|
||||||
converting_id
|
converting_id
|
||||||
} else {
|
} else {
|
||||||
result.push(Statement::Conversion(ImplicitConversion {
|
result.push(Statement::Conversion(ImplicitConversion {
|
||||||
src: *new_id,
|
src: *new_id,
|
||||||
dst: converting_id,
|
dst: converting_id,
|
||||||
from_type: ast::Type::Pointer(
|
from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||||
ast::PointerType::Scalar(ast::ScalarType::U8),
|
from_space: ast::StateSpace::Reg,
|
||||||
ast::StateSpace::Global,
|
|
||||||
),
|
|
||||||
to_type: old_type,
|
to_type: old_type,
|
||||||
kind: ConversionKind::PtrToBit(ast::ScalarType::U64),
|
to_space: ast::StateSpace::Reg,
|
||||||
src_sema: arg_desc.sema,
|
kind: ConversionKind::AddressOf,
|
||||||
dst_
|
|
||||||
}));
|
}));
|
||||||
converting_id
|
converting_id
|
||||||
}
|
}
|
||||||
|
@ -4660,22 +4653,23 @@ fn convert_to_stateful_memory_access_postprocess(
|
||||||
}
|
}
|
||||||
// We skip conversion here to trigger PtrAcces in a later pass
|
// We skip conversion here to trigger PtrAcces in a later pass
|
||||||
let old_type = match expected_type {
|
let old_type = match expected_type {
|
||||||
Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id),
|
Some((
|
||||||
|
ast::Type::Pointer(_, ast::StateSpace::Global),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)) => return Ok(*new_id),
|
||||||
_ => id_defs.get_typed(arg_desc.op)?.0,
|
_ => id_defs.get_typed(arg_desc.op)?.0,
|
||||||
};
|
};
|
||||||
let old_type_clone = old_type.clone();
|
let old_type_clone = old_type.clone();
|
||||||
let converting_id = id_defs.register_intermediate(Some(old_type));
|
let converting_id =
|
||||||
|
id_defs.register_intermediate(Some((old_type, ast::StateSpace::Reg)));
|
||||||
result.push(Statement::Conversion(ImplicitConversion {
|
result.push(Statement::Conversion(ImplicitConversion {
|
||||||
src: *new_id,
|
src: *new_id,
|
||||||
dst: converting_id,
|
dst: converting_id,
|
||||||
from_type: ast::Type::Pointer(
|
from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||||
ast::PointerType::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
from_space: ast::StateSpace::Reg,
|
||||||
ast::StateSpace::Param,
|
|
||||||
),
|
|
||||||
to_type: old_type_clone,
|
to_type: old_type_clone,
|
||||||
kind: ConversionKind::PtrToPtr { spirv_ptr: false },
|
to_space: ast::StateSpace::Reg,
|
||||||
src_sema: arg_desc.sema,
|
kind: ConversionKind::PtrToPtr,
|
||||||
dst_
|
|
||||||
}));
|
}));
|
||||||
converting_id
|
converting_id
|
||||||
}
|
}
|
||||||
|
@ -4710,7 +4704,6 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
|
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
|
||||||
enum PtxSpecialRegister {
|
enum PtxSpecialRegister {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue