mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-03 06:40:21 +00:00
Port normalize_predicates
This commit is contained in:
parent
3b5efbf88b
commit
e87388bc35
3 changed files with 157 additions and 37 deletions
|
@ -1,5 +1,6 @@
|
|||
use ptx_parser as ast;
|
||||
use rspirv::{binary::Assemble, dr};
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::hash::Hash;
|
||||
use std::num::NonZeroU8;
|
||||
use std::{
|
||||
|
@ -27,6 +28,7 @@ mod normalize_identifiers;
|
|||
mod normalize_identifiers2;
|
||||
mod normalize_labels;
|
||||
mod normalize_predicates;
|
||||
mod normalize_predicates2;
|
||||
|
||||
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
||||
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||
|
@ -1690,3 +1692,43 @@ type NormalizedFunction2<'input> = Function2<
|
|||
),
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
type UnconditionalDirective<'input> = Directive2<
|
||||
'input,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
type UnconditionalFunction<'input> = Function2<
|
||||
'input,
|
||||
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||
ast::ParsedOperand<SpirvWord>,
|
||||
>;
|
||||
|
||||
struct GlobalStringIdentResolver2<'input> {
|
||||
pub(crate) current_id: SpirvWord,
|
||||
pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
|
||||
}
|
||||
|
||||
impl<'input> GlobalStringIdentResolver2<'input> {
|
||||
fn register_intermediate(
|
||||
&mut self,
|
||||
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||
) -> SpirvWord {
|
||||
let new_id = self.current_id;
|
||||
self.ident_map.insert(
|
||||
new_id,
|
||||
IdentEntry {
|
||||
name: None,
|
||||
type_space,
|
||||
},
|
||||
);
|
||||
self.current_id.0 += 1;
|
||||
new_id
|
||||
}
|
||||
}
|
||||
|
||||
struct IdentEntry<'input> {
|
||||
name: Option<Cow<'input, str>>,
|
||||
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||
}
|
||||
|
|
|
@ -3,45 +3,45 @@ use ptx_parser as ast;
|
|||
use rustc_hash::FxHashMap;
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
fn_defs: &mut GlobalStringIdentResolver<'input>,
|
||||
fn_defs: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
|
||||
let mut resolver = NameResolver::new(fn_defs);
|
||||
let result = directives
|
||||
.into_iter()
|
||||
.map(|directive| remap_directive(&mut resolver, directive))
|
||||
.map(|directive| run_directive(&mut resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
resolver.end_scope();
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn remap_directive<'input, 'b>(
|
||||
fn run_directive<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
||||
) -> Result<NormalizedDirective2<'input>, TranslateError> {
|
||||
Ok(match directive {
|
||||
ast::Directive::Variable(linking, var) => {
|
||||
NormalizedDirective2::Variable(linking, remap_variable(resolver, var)?)
|
||||
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
|
||||
}
|
||||
ast::Directive::Method(linking, directive) => {
|
||||
NormalizedDirective2::Method(remap_method(resolver, linking, directive)?)
|
||||
NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn remap_method<'input, 'b>(
|
||||
fn run_method<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
linkage: ast::LinkingDirective,
|
||||
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<NormalizedFunction2<'input>, TranslateError> {
|
||||
let name = match method.func_directive.name {
|
||||
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
|
||||
ast::MethodName::Func(text) => ast::MethodName::Func(
|
||||
resolver.add(Cow::Borrowed(method.func_directive.name.text()), None)?,
|
||||
),
|
||||
ast::MethodName::Func(text) => {
|
||||
ast::MethodName::Func(resolver.add(Cow::Borrowed(text), None)?)
|
||||
}
|
||||
};
|
||||
resolver.start_scope();
|
||||
let func_decl = Rc::new(RefCell::new(remap_function_decl(
|
||||
let func_decl = Rc::new(RefCell::new(run_function_decl(
|
||||
resolver,
|
||||
method.func_directive,
|
||||
name,
|
||||
|
@ -50,7 +50,7 @@ fn remap_method<'input, 'b>(
|
|||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
remap_statements(resolver, &mut result, statements)?;
|
||||
run_statements(resolver, &mut result, statements)?;
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
|
@ -65,7 +65,7 @@ fn remap_method<'input, 'b>(
|
|||
})
|
||||
}
|
||||
|
||||
fn remap_function_decl<'input, 'b>(
|
||||
fn run_function_decl<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
||||
name: ast::MethodName<'input, SpirvWord>,
|
||||
|
@ -74,12 +74,12 @@ fn remap_function_decl<'input, 'b>(
|
|||
let return_arguments = func_directive
|
||||
.return_arguments
|
||||
.into_iter()
|
||||
.map(|var| remap_variable(resolver, var))
|
||||
.map(|var| run_variable(resolver, var))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let input_arguments = func_directive
|
||||
.input_arguments
|
||||
.into_iter()
|
||||
.map(|var| remap_variable(resolver, var))
|
||||
.map(|var| run_variable(resolver, var))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(ast::MethodDeclaration {
|
||||
return_arguments,
|
||||
|
@ -89,7 +89,7 @@ fn remap_function_decl<'input, 'b>(
|
|||
})
|
||||
}
|
||||
|
||||
fn remap_variable<'input, 'b>(
|
||||
fn run_variable<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
variable: ast::Variable<&'input str>,
|
||||
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
|
||||
|
@ -105,7 +105,7 @@ fn remap_variable<'input, 'b>(
|
|||
})
|
||||
}
|
||||
|
||||
fn remap_statements<'input, 'b>(
|
||||
fn run_statements<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
|
@ -123,7 +123,7 @@ fn remap_statements<'input, 'b>(
|
|||
ast::Statement::Label(label) => {
|
||||
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
|
||||
}
|
||||
ast::Statement::Variable(variable) => remap_multivariable(resolver, result, variable)?,
|
||||
ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
|
||||
ast::Statement::Instruction(predicate, instruction) => {
|
||||
result.push(Statement::Instruction((
|
||||
predicate
|
||||
|
@ -134,12 +134,12 @@ fn remap_statements<'input, 'b>(
|
|||
})
|
||||
})
|
||||
.transpose()?,
|
||||
remap_instruction(resolver, instruction)?,
|
||||
run_instruction(resolver, instruction)?,
|
||||
)))
|
||||
}
|
||||
ast::Statement::Block(block) => {
|
||||
resolver.start_scope();
|
||||
remap_statements(resolver, result, block)?;
|
||||
run_statements(resolver, result, block)?;
|
||||
resolver.end_scope();
|
||||
}
|
||||
}
|
||||
|
@ -147,7 +147,7 @@ fn remap_statements<'input, 'b>(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn remap_instruction<'input, 'b>(
|
||||
fn run_instruction<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
|
||||
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
||||
|
@ -162,7 +162,7 @@ fn remap_instruction<'input, 'b>(
|
|||
})
|
||||
}
|
||||
|
||||
fn remap_multivariable<'input, 'b>(
|
||||
fn run_multivariable<'input, 'b>(
|
||||
resolver: &mut NameResolver<'input, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
variable: ast::MultiVariable<&'input str>,
|
||||
|
@ -203,12 +203,12 @@ fn remap_multivariable<'input, 'b>(
|
|||
}
|
||||
|
||||
struct NameResolver<'input, 'b> {
|
||||
flat_resolver: &'b mut GlobalStringIdentResolver<'input>,
|
||||
flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
|
||||
scopes: Vec<ScopeStringIdentResolver<'input>>,
|
||||
}
|
||||
|
||||
impl<'input, 'b> NameResolver<'input, 'b> {
|
||||
fn new(flat_resolver: &'b mut GlobalStringIdentResolver<'input>) -> Self {
|
||||
fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
|
||||
Self {
|
||||
flat_resolver,
|
||||
scopes: vec![ScopeStringIdentResolver::new()],
|
||||
|
@ -239,9 +239,13 @@ impl<'input, 'b> NameResolver<'input, 'b> {
|
|||
{
|
||||
return Err(error_unknown_symbol());
|
||||
}
|
||||
current_scope
|
||||
.ident_map
|
||||
.insert(result, IdentEntry { name, type_space });
|
||||
current_scope.ident_map.insert(
|
||||
result,
|
||||
IdentEntry {
|
||||
name: Some(name),
|
||||
type_space,
|
||||
},
|
||||
);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
|
@ -276,17 +280,7 @@ impl<'input> ScopeStringIdentResolver<'input> {
|
|||
}
|
||||
}
|
||||
|
||||
fn flush(self, resolver: &mut GlobalStringIdentResolver<'input>) {
|
||||
fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
|
||||
resolver.ident_map.extend(self.ident_map);
|
||||
}
|
||||
}
|
||||
|
||||
struct GlobalStringIdentResolver<'input> {
|
||||
pub(crate) current_id: SpirvWord,
|
||||
pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
|
||||
}
|
||||
|
||||
struct IdentEntry<'input> {
|
||||
name: Cow<'input, str>,
|
||||
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||
}
|
||||
|
|
84
ptx/src/pass/normalize_predicates2.rs
Normal file
84
ptx/src/pass/normalize_predicates2.rs
Normal file
|
@ -0,0 +1,84 @@
|
|||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<NormalizedDirective2<'input>>,
|
||||
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|directive| run_directive(resolver, directive))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
fn run_directive<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
directive: NormalizedDirective2<'input>,
|
||||
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
||||
Ok(match directive {
|
||||
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_method<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
method: NormalizedFunction2<'input>,
|
||||
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
||||
let body = method
|
||||
.body
|
||||
.map(|statements| {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
run_statement(resolver, &mut result, statement)?;
|
||||
}
|
||||
Ok::<_, TranslateError>(result)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Function2 {
|
||||
func_decl: method.func_decl,
|
||||
globals: method.globals,
|
||||
body,
|
||||
import_as: method.import_as,
|
||||
tuning: method.tuning,
|
||||
linkage: method.linkage,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_statement<'input>(
|
||||
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||
result: &mut Vec<UnconditionalStatement>,
|
||||
statement: NormalizedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
Ok(match statement {
|
||||
Statement::Label(label) => result.push(Statement::Label(label)),
|
||||
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
||||
Statement::Instruction((predicate, instruction)) => {
|
||||
if let Some(pred) = predicate {
|
||||
let if_true = resolver.register_intermediate(None);
|
||||
let if_false = resolver.register_intermediate(None);
|
||||
let folded_bra = match &instruction {
|
||||
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
||||
_ => None,
|
||||
};
|
||||
let mut branch = BrachCondition {
|
||||
predicate: pred.label,
|
||||
if_true: folded_bra.unwrap_or(if_true),
|
||||
if_false,
|
||||
};
|
||||
if pred.not {
|
||||
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||
}
|
||||
result.push(Statement::Conditional(branch));
|
||||
if folded_bra.is_none() {
|
||||
result.push(Statement::Label(if_true));
|
||||
result.push(Statement::Instruction(instruction));
|
||||
}
|
||||
result.push(Statement::Label(if_false));
|
||||
} else {
|
||||
result.push(Statement::Instruction(instruction));
|
||||
}
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
})
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue