Port normalize_predicates

This commit is contained in:
Andrzej Janik 2024-09-16 17:08:12 +02:00
commit e87388bc35
3 changed files with 157 additions and 37 deletions

View file

@ -1,5 +1,6 @@
use ptx_parser as ast;
use rspirv::{binary::Assemble, dr};
use rustc_hash::FxHashMap;
use std::hash::Hash;
use std::num::NonZeroU8;
use std::{
@ -27,6 +28,7 @@ mod normalize_identifiers;
mod normalize_identifiers2;
mod normalize_labels;
mod normalize_predicates;
mod normalize_predicates2;
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
@ -1690,3 +1692,43 @@ type NormalizedFunction2<'input> = Function2<
),
ast::ParsedOperand<SpirvWord>,
>;
type UnconditionalDirective<'input> = Directive2<
'input,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>;
type UnconditionalFunction<'input> = Function2<
'input,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
ast::ParsedOperand<SpirvWord>,
>;
struct GlobalStringIdentResolver2<'input> {
pub(crate) current_id: SpirvWord,
pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
}
impl<'input> GlobalStringIdentResolver2<'input> {
fn register_intermediate(
&mut self,
type_space: Option<(ast::Type, ast::StateSpace)>,
) -> SpirvWord {
let new_id = self.current_id;
self.ident_map.insert(
new_id,
IdentEntry {
name: None,
type_space,
},
);
self.current_id.0 += 1;
new_id
}
}
struct IdentEntry<'input> {
name: Option<Cow<'input, str>>,
type_space: Option<(ast::Type, ast::StateSpace)>,
}

View file

@ -3,45 +3,45 @@ use ptx_parser as ast;
use rustc_hash::FxHashMap;
pub(crate) fn run<'input>(
fn_defs: &mut GlobalStringIdentResolver<'input>,
fn_defs: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
let mut resolver = NameResolver::new(fn_defs);
let result = directives
.into_iter()
.map(|directive| remap_directive(&mut resolver, directive))
.map(|directive| run_directive(&mut resolver, directive))
.collect::<Result<Vec<_>, _>>()?;
resolver.end_scope();
Ok(result)
}
fn remap_directive<'input, 'b>(
fn run_directive<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<NormalizedDirective2<'input>, TranslateError> {
Ok(match directive {
ast::Directive::Variable(linking, var) => {
NormalizedDirective2::Variable(linking, remap_variable(resolver, var)?)
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
}
ast::Directive::Method(linking, directive) => {
NormalizedDirective2::Method(remap_method(resolver, linking, directive)?)
NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
}
})
}
fn remap_method<'input, 'b>(
fn run_method<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
linkage: ast::LinkingDirective,
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<NormalizedFunction2<'input>, TranslateError> {
let name = match method.func_directive.name {
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
ast::MethodName::Func(text) => ast::MethodName::Func(
resolver.add(Cow::Borrowed(method.func_directive.name.text()), None)?,
),
ast::MethodName::Func(text) => {
ast::MethodName::Func(resolver.add(Cow::Borrowed(text), None)?)
}
};
resolver.start_scope();
let func_decl = Rc::new(RefCell::new(remap_function_decl(
let func_decl = Rc::new(RefCell::new(run_function_decl(
resolver,
method.func_directive,
name,
@ -50,7 +50,7 @@ fn remap_method<'input, 'b>(
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
remap_statements(resolver, &mut result, statements)?;
run_statements(resolver, &mut result, statements)?;
Ok::<_, TranslateError>(result)
})
.transpose()?;
@ -65,7 +65,7 @@ fn remap_method<'input, 'b>(
})
}
fn remap_function_decl<'input, 'b>(
fn run_function_decl<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
func_directive: ast::MethodDeclaration<'input, &'input str>,
name: ast::MethodName<'input, SpirvWord>,
@ -74,12 +74,12 @@ fn remap_function_decl<'input, 'b>(
let return_arguments = func_directive
.return_arguments
.into_iter()
.map(|var| remap_variable(resolver, var))
.map(|var| run_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?;
let input_arguments = func_directive
.input_arguments
.into_iter()
.map(|var| remap_variable(resolver, var))
.map(|var| run_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?;
Ok(ast::MethodDeclaration {
return_arguments,
@ -89,7 +89,7 @@ fn remap_function_decl<'input, 'b>(
})
}
fn remap_variable<'input, 'b>(
fn run_variable<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
variable: ast::Variable<&'input str>,
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
@ -105,7 +105,7 @@ fn remap_variable<'input, 'b>(
})
}
fn remap_statements<'input, 'b>(
fn run_statements<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
@ -123,7 +123,7 @@ fn remap_statements<'input, 'b>(
ast::Statement::Label(label) => {
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
}
ast::Statement::Variable(variable) => remap_multivariable(resolver, result, variable)?,
ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
ast::Statement::Instruction(predicate, instruction) => {
result.push(Statement::Instruction((
predicate
@ -134,12 +134,12 @@ fn remap_statements<'input, 'b>(
})
})
.transpose()?,
remap_instruction(resolver, instruction)?,
run_instruction(resolver, instruction)?,
)))
}
ast::Statement::Block(block) => {
resolver.start_scope();
remap_statements(resolver, result, block)?;
run_statements(resolver, result, block)?;
resolver.end_scope();
}
}
@ -147,7 +147,7 @@ fn remap_statements<'input, 'b>(
Ok(())
}
fn remap_instruction<'input, 'b>(
fn run_instruction<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
@ -162,7 +162,7 @@ fn remap_instruction<'input, 'b>(
})
}
fn remap_multivariable<'input, 'b>(
fn run_multivariable<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
variable: ast::MultiVariable<&'input str>,
@ -203,12 +203,12 @@ fn remap_multivariable<'input, 'b>(
}
struct NameResolver<'input, 'b> {
flat_resolver: &'b mut GlobalStringIdentResolver<'input>,
flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
scopes: Vec<ScopeStringIdentResolver<'input>>,
}
impl<'input, 'b> NameResolver<'input, 'b> {
fn new(flat_resolver: &'b mut GlobalStringIdentResolver<'input>) -> Self {
fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
Self {
flat_resolver,
scopes: vec![ScopeStringIdentResolver::new()],
@ -239,9 +239,13 @@ impl<'input, 'b> NameResolver<'input, 'b> {
{
return Err(error_unknown_symbol());
}
current_scope
.ident_map
.insert(result, IdentEntry { name, type_space });
current_scope.ident_map.insert(
result,
IdentEntry {
name: Some(name),
type_space,
},
);
Ok(result)
}
@ -276,17 +280,7 @@ impl<'input> ScopeStringIdentResolver<'input> {
}
}
fn flush(self, resolver: &mut GlobalStringIdentResolver<'input>) {
fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
resolver.ident_map.extend(self.ident_map);
}
}
struct GlobalStringIdentResolver<'input> {
pub(crate) current_id: SpirvWord,
pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
}
struct IdentEntry<'input> {
name: Cow<'input, str>,
type_space: Option<(ast::Type, ast::StateSpace)>,
}

View file

@ -0,0 +1,84 @@
use super::*;
use ptx_parser as ast;
pub(crate) fn run<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<NormalizedDirective2<'input>>,
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
directives
.into_iter()
.map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()
}
fn run_directive<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
directive: NormalizedDirective2<'input>,
) -> Result<UnconditionalDirective<'input>, TranslateError> {
Ok(match directive {
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
})
}
fn run_method<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
method: NormalizedFunction2<'input>,
) -> Result<UnconditionalFunction<'input>, TranslateError> {
let body = method
.body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
run_statement(resolver, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
.transpose()?;
Ok(Function2 {
func_decl: method.func_decl,
globals: method.globals,
body,
import_as: method.import_as,
tuning: method.tuning,
linkage: method.linkage,
})
}
fn run_statement<'input>(
resolver: &mut GlobalStringIdentResolver2<'input>,
result: &mut Vec<UnconditionalStatement>,
statement: NormalizedStatement,
) -> Result<(), TranslateError> {
Ok(match statement {
Statement::Label(label) => result.push(Statement::Label(label)),
Statement::Variable(var) => result.push(Statement::Variable(var)),
Statement::Instruction((predicate, instruction)) => {
if let Some(pred) = predicate {
let if_true = resolver.register_intermediate(None);
let if_false = resolver.register_intermediate(None);
let folded_bra = match &instruction {
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
_ => None,
};
let mut branch = BrachCondition {
predicate: pred.label,
if_true: folded_bra.unwrap_or(if_true),
if_false,
};
if pred.not {
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
}
result.push(Statement::Conditional(branch));
if folded_bra.is_none() {
result.push(Statement::Label(if_true));
result.push(Statement::Instruction(instruction));
}
result.push(Statement::Label(if_false));
} else {
result.push(Statement::Instruction(instruction));
}
}
_ => return Err(error_unreachable()),
})
}