Refactor type-of-function resolution

This commit is contained in:
Andrzej Janik 2024-09-16 17:20:46 +02:00
parent e87388bc35
commit c84d257bb7
2 changed files with 83 additions and 0 deletions

View file

@ -29,6 +29,7 @@ mod normalize_identifiers2;
mod normalize_labels;
mod normalize_predicates;
mod normalize_predicates2;
mod resolve_function_pointers;
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");

View file

@ -0,0 +1,82 @@
use super::*;
use ptx_parser as ast;
use rustc_hash::FxHashSet;
pub(crate) fn run<'input>(
directives: Vec<UnconditionalDirective<'input>>,
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
let mut functions = FxHashSet::default();
directives
.into_iter()
.map(|directive| run_directive(&mut functions, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
functions: &mut FxHashSet<SpirvWord>,
directive: UnconditionalDirective<'input>,
) -> Result<UnconditionalDirective<'input>, TranslateError> {
Ok(match directive {
var @ Directive2::Variable(..) => var,
Directive2::Method(method) => {
{
let func_decl = method.func_decl.borrow();
match func_decl.name {
ptx_parser::MethodName::Kernel(_) => {}
ptx_parser::MethodName::Func(name) => {
functions.insert(name);
}
}
}
Directive2::Method(run_method(functions, method)?)
}
})
}
fn run_method<'input>(
functions: &mut FxHashSet<SpirvWord>,
method: UnconditionalFunction<'input>,
) -> Result<UnconditionalFunction<'input>, TranslateError> {
let body = method
.body
.map(|statements| {
statements
.into_iter()
.map(|statement| run_statement(functions, statement))
.collect::<Result<Vec<_>, _>>()
})
.transpose()?;
Ok(Function2 {
func_decl: method.func_decl,
globals: method.globals,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
})
}
fn run_statement<'input>(
functions: &mut FxHashSet<SpirvWord>,
statement: UnconditionalStatement,
) -> Result<UnconditionalStatement, TranslateError> {
Ok(match statement {
Statement::Instruction(ast::Instruction::Mov {
data,
arguments:
ast::MovArgs {
dst: ast::ParsedOperand::Reg(dst_reg),
src: ast::ParsedOperand::Reg(src_reg),
},
}) if functions.contains(&src_reg) => {
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
return Err(error_mismatched_type());
}
UnconditionalStatement::FunctionPointer(FunctionPointerDetails {
dst: dst_reg,
src: src_reg,
})
}
s => s,
})
}