mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-04 15:19:49 +00:00
Port normalize_predicates
This commit is contained in:
parent
3b5efbf88b
commit
e87388bc35
3 changed files with 157 additions and 37 deletions
|
@ -1,5 +1,6 @@
|
||||||
use ptx_parser as ast;
|
use ptx_parser as ast;
|
||||||
use rspirv::{binary::Assemble, dr};
|
use rspirv::{binary::Assemble, dr};
|
||||||
|
use rustc_hash::FxHashMap;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use std::num::NonZeroU8;
|
use std::num::NonZeroU8;
|
||||||
use std::{
|
use std::{
|
||||||
|
@ -27,6 +28,7 @@ mod normalize_identifiers;
|
||||||
mod normalize_identifiers2;
|
mod normalize_identifiers2;
|
||||||
mod normalize_labels;
|
mod normalize_labels;
|
||||||
mod normalize_predicates;
|
mod normalize_predicates;
|
||||||
|
mod normalize_predicates2;
|
||||||
|
|
||||||
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
|
||||||
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||||
|
@ -1690,3 +1692,43 @@ type NormalizedFunction2<'input> = Function2<
|
||||||
),
|
),
|
||||||
ast::ParsedOperand<SpirvWord>,
|
ast::ParsedOperand<SpirvWord>,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
type UnconditionalDirective<'input> = Directive2<
|
||||||
|
'input,
|
||||||
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
|
ast::ParsedOperand<SpirvWord>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
type UnconditionalFunction<'input> = Function2<
|
||||||
|
'input,
|
||||||
|
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
|
||||||
|
ast::ParsedOperand<SpirvWord>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
struct GlobalStringIdentResolver2<'input> {
|
||||||
|
pub(crate) current_id: SpirvWord,
|
||||||
|
pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'input> GlobalStringIdentResolver2<'input> {
|
||||||
|
fn register_intermediate(
|
||||||
|
&mut self,
|
||||||
|
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||||
|
) -> SpirvWord {
|
||||||
|
let new_id = self.current_id;
|
||||||
|
self.ident_map.insert(
|
||||||
|
new_id,
|
||||||
|
IdentEntry {
|
||||||
|
name: None,
|
||||||
|
type_space,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
self.current_id.0 += 1;
|
||||||
|
new_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct IdentEntry<'input> {
|
||||||
|
name: Option<Cow<'input, str>>,
|
||||||
|
type_space: Option<(ast::Type, ast::StateSpace)>,
|
||||||
|
}
|
||||||
|
|
|
@ -3,45 +3,45 @@ use ptx_parser as ast;
|
||||||
use rustc_hash::FxHashMap;
|
use rustc_hash::FxHashMap;
|
||||||
|
|
||||||
pub(crate) fn run<'input>(
|
pub(crate) fn run<'input>(
|
||||||
fn_defs: &mut GlobalStringIdentResolver<'input>,
|
fn_defs: &mut GlobalStringIdentResolver2<'input>,
|
||||||
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
|
||||||
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
|
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
|
||||||
let mut resolver = NameResolver::new(fn_defs);
|
let mut resolver = NameResolver::new(fn_defs);
|
||||||
let result = directives
|
let result = directives
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|directive| remap_directive(&mut resolver, directive))
|
.map(|directive| run_directive(&mut resolver, directive))
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
resolver.end_scope();
|
resolver.end_scope();
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remap_directive<'input, 'b>(
|
fn run_directive<'input, 'b>(
|
||||||
resolver: &mut NameResolver<'input, 'b>,
|
resolver: &mut NameResolver<'input, 'b>,
|
||||||
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
|
||||||
) -> Result<NormalizedDirective2<'input>, TranslateError> {
|
) -> Result<NormalizedDirective2<'input>, TranslateError> {
|
||||||
Ok(match directive {
|
Ok(match directive {
|
||||||
ast::Directive::Variable(linking, var) => {
|
ast::Directive::Variable(linking, var) => {
|
||||||
NormalizedDirective2::Variable(linking, remap_variable(resolver, var)?)
|
NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
|
||||||
}
|
}
|
||||||
ast::Directive::Method(linking, directive) => {
|
ast::Directive::Method(linking, directive) => {
|
||||||
NormalizedDirective2::Method(remap_method(resolver, linking, directive)?)
|
NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remap_method<'input, 'b>(
|
fn run_method<'input, 'b>(
|
||||||
resolver: &mut NameResolver<'input, 'b>,
|
resolver: &mut NameResolver<'input, 'b>,
|
||||||
linkage: ast::LinkingDirective,
|
linkage: ast::LinkingDirective,
|
||||||
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||||
) -> Result<NormalizedFunction2<'input>, TranslateError> {
|
) -> Result<NormalizedFunction2<'input>, TranslateError> {
|
||||||
let name = match method.func_directive.name {
|
let name = match method.func_directive.name {
|
||||||
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
|
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
|
||||||
ast::MethodName::Func(text) => ast::MethodName::Func(
|
ast::MethodName::Func(text) => {
|
||||||
resolver.add(Cow::Borrowed(method.func_directive.name.text()), None)?,
|
ast::MethodName::Func(resolver.add(Cow::Borrowed(text), None)?)
|
||||||
),
|
}
|
||||||
};
|
};
|
||||||
resolver.start_scope();
|
resolver.start_scope();
|
||||||
let func_decl = Rc::new(RefCell::new(remap_function_decl(
|
let func_decl = Rc::new(RefCell::new(run_function_decl(
|
||||||
resolver,
|
resolver,
|
||||||
method.func_directive,
|
method.func_directive,
|
||||||
name,
|
name,
|
||||||
|
@ -50,7 +50,7 @@ fn remap_method<'input, 'b>(
|
||||||
.body
|
.body
|
||||||
.map(|statements| {
|
.map(|statements| {
|
||||||
let mut result = Vec::with_capacity(statements.len());
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
remap_statements(resolver, &mut result, statements)?;
|
run_statements(resolver, &mut result, statements)?;
|
||||||
Ok::<_, TranslateError>(result)
|
Ok::<_, TranslateError>(result)
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
@ -65,7 +65,7 @@ fn remap_method<'input, 'b>(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remap_function_decl<'input, 'b>(
|
fn run_function_decl<'input, 'b>(
|
||||||
resolver: &mut NameResolver<'input, 'b>,
|
resolver: &mut NameResolver<'input, 'b>,
|
||||||
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
func_directive: ast::MethodDeclaration<'input, &'input str>,
|
||||||
name: ast::MethodName<'input, SpirvWord>,
|
name: ast::MethodName<'input, SpirvWord>,
|
||||||
|
@ -74,12 +74,12 @@ fn remap_function_decl<'input, 'b>(
|
||||||
let return_arguments = func_directive
|
let return_arguments = func_directive
|
||||||
.return_arguments
|
.return_arguments
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|var| remap_variable(resolver, var))
|
.map(|var| run_variable(resolver, var))
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
let input_arguments = func_directive
|
let input_arguments = func_directive
|
||||||
.input_arguments
|
.input_arguments
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|var| remap_variable(resolver, var))
|
.map(|var| run_variable(resolver, var))
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
Ok(ast::MethodDeclaration {
|
Ok(ast::MethodDeclaration {
|
||||||
return_arguments,
|
return_arguments,
|
||||||
|
@ -89,7 +89,7 @@ fn remap_function_decl<'input, 'b>(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remap_variable<'input, 'b>(
|
fn run_variable<'input, 'b>(
|
||||||
resolver: &mut NameResolver<'input, 'b>,
|
resolver: &mut NameResolver<'input, 'b>,
|
||||||
variable: ast::Variable<&'input str>,
|
variable: ast::Variable<&'input str>,
|
||||||
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
|
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
|
||||||
|
@ -105,7 +105,7 @@ fn remap_variable<'input, 'b>(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remap_statements<'input, 'b>(
|
fn run_statements<'input, 'b>(
|
||||||
resolver: &mut NameResolver<'input, 'b>,
|
resolver: &mut NameResolver<'input, 'b>,
|
||||||
result: &mut Vec<NormalizedStatement>,
|
result: &mut Vec<NormalizedStatement>,
|
||||||
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||||
|
@ -123,7 +123,7 @@ fn remap_statements<'input, 'b>(
|
||||||
ast::Statement::Label(label) => {
|
ast::Statement::Label(label) => {
|
||||||
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
|
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
|
||||||
}
|
}
|
||||||
ast::Statement::Variable(variable) => remap_multivariable(resolver, result, variable)?,
|
ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
|
||||||
ast::Statement::Instruction(predicate, instruction) => {
|
ast::Statement::Instruction(predicate, instruction) => {
|
||||||
result.push(Statement::Instruction((
|
result.push(Statement::Instruction((
|
||||||
predicate
|
predicate
|
||||||
|
@ -134,12 +134,12 @@ fn remap_statements<'input, 'b>(
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.transpose()?,
|
.transpose()?,
|
||||||
remap_instruction(resolver, instruction)?,
|
run_instruction(resolver, instruction)?,
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
ast::Statement::Block(block) => {
|
ast::Statement::Block(block) => {
|
||||||
resolver.start_scope();
|
resolver.start_scope();
|
||||||
remap_statements(resolver, result, block)?;
|
run_statements(resolver, result, block)?;
|
||||||
resolver.end_scope();
|
resolver.end_scope();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -147,7 +147,7 @@ fn remap_statements<'input, 'b>(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remap_instruction<'input, 'b>(
|
fn run_instruction<'input, 'b>(
|
||||||
resolver: &mut NameResolver<'input, 'b>,
|
resolver: &mut NameResolver<'input, 'b>,
|
||||||
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
|
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
|
||||||
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
|
||||||
|
@ -162,7 +162,7 @@ fn remap_instruction<'input, 'b>(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remap_multivariable<'input, 'b>(
|
fn run_multivariable<'input, 'b>(
|
||||||
resolver: &mut NameResolver<'input, 'b>,
|
resolver: &mut NameResolver<'input, 'b>,
|
||||||
result: &mut Vec<NormalizedStatement>,
|
result: &mut Vec<NormalizedStatement>,
|
||||||
variable: ast::MultiVariable<&'input str>,
|
variable: ast::MultiVariable<&'input str>,
|
||||||
|
@ -203,12 +203,12 @@ fn remap_multivariable<'input, 'b>(
|
||||||
}
|
}
|
||||||
|
|
||||||
struct NameResolver<'input, 'b> {
|
struct NameResolver<'input, 'b> {
|
||||||
flat_resolver: &'b mut GlobalStringIdentResolver<'input>,
|
flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
|
||||||
scopes: Vec<ScopeStringIdentResolver<'input>>,
|
scopes: Vec<ScopeStringIdentResolver<'input>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'input, 'b> NameResolver<'input, 'b> {
|
impl<'input, 'b> NameResolver<'input, 'b> {
|
||||||
fn new(flat_resolver: &'b mut GlobalStringIdentResolver<'input>) -> Self {
|
fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
flat_resolver,
|
flat_resolver,
|
||||||
scopes: vec![ScopeStringIdentResolver::new()],
|
scopes: vec![ScopeStringIdentResolver::new()],
|
||||||
|
@ -239,9 +239,13 @@ impl<'input, 'b> NameResolver<'input, 'b> {
|
||||||
{
|
{
|
||||||
return Err(error_unknown_symbol());
|
return Err(error_unknown_symbol());
|
||||||
}
|
}
|
||||||
current_scope
|
current_scope.ident_map.insert(
|
||||||
.ident_map
|
result,
|
||||||
.insert(result, IdentEntry { name, type_space });
|
IdentEntry {
|
||||||
|
name: Some(name),
|
||||||
|
type_space,
|
||||||
|
},
|
||||||
|
);
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -276,17 +280,7 @@ impl<'input> ScopeStringIdentResolver<'input> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn flush(self, resolver: &mut GlobalStringIdentResolver<'input>) {
|
fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
|
||||||
resolver.ident_map.extend(self.ident_map);
|
resolver.ident_map.extend(self.ident_map);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct GlobalStringIdentResolver<'input> {
|
|
||||||
pub(crate) current_id: SpirvWord,
|
|
||||||
pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct IdentEntry<'input> {
|
|
||||||
name: Cow<'input, str>,
|
|
||||||
type_space: Option<(ast::Type, ast::StateSpace)>,
|
|
||||||
}
|
|
||||||
|
|
84
ptx/src/pass/normalize_predicates2.rs
Normal file
84
ptx/src/pass/normalize_predicates2.rs
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
pub(crate) fn run<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
directives: Vec<NormalizedDirective2<'input>>,
|
||||||
|
) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
|
||||||
|
directives
|
||||||
|
.into_iter()
|
||||||
|
.map(|directive| run_directive(resolver, directive))
|
||||||
|
.collect::<Result<Vec<_>, _>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_directive<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
directive: NormalizedDirective2<'input>,
|
||||||
|
) -> Result<UnconditionalDirective<'input>, TranslateError> {
|
||||||
|
Ok(match directive {
|
||||||
|
Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
|
||||||
|
Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_method<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
method: NormalizedFunction2<'input>,
|
||||||
|
) -> Result<UnconditionalFunction<'input>, TranslateError> {
|
||||||
|
let body = method
|
||||||
|
.body
|
||||||
|
.map(|statements| {
|
||||||
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
|
for statement in statements {
|
||||||
|
run_statement(resolver, &mut result, statement)?;
|
||||||
|
}
|
||||||
|
Ok::<_, TranslateError>(result)
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
Ok(Function2 {
|
||||||
|
func_decl: method.func_decl,
|
||||||
|
globals: method.globals,
|
||||||
|
body,
|
||||||
|
import_as: method.import_as,
|
||||||
|
tuning: method.tuning,
|
||||||
|
linkage: method.linkage,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_statement<'input>(
|
||||||
|
resolver: &mut GlobalStringIdentResolver2<'input>,
|
||||||
|
result: &mut Vec<UnconditionalStatement>,
|
||||||
|
statement: NormalizedStatement,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
Ok(match statement {
|
||||||
|
Statement::Label(label) => result.push(Statement::Label(label)),
|
||||||
|
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
||||||
|
Statement::Instruction((predicate, instruction)) => {
|
||||||
|
if let Some(pred) = predicate {
|
||||||
|
let if_true = resolver.register_intermediate(None);
|
||||||
|
let if_false = resolver.register_intermediate(None);
|
||||||
|
let folded_bra = match &instruction {
|
||||||
|
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
let mut branch = BrachCondition {
|
||||||
|
predicate: pred.label,
|
||||||
|
if_true: folded_bra.unwrap_or(if_true),
|
||||||
|
if_false,
|
||||||
|
};
|
||||||
|
if pred.not {
|
||||||
|
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||||
|
}
|
||||||
|
result.push(Statement::Conditional(branch));
|
||||||
|
if folded_bra.is_none() {
|
||||||
|
result.push(Statement::Label(if_true));
|
||||||
|
result.push(Statement::Instruction(instruction));
|
||||||
|
}
|
||||||
|
result.push(Statement::Label(if_false));
|
||||||
|
} else {
|
||||||
|
result.push(Statement::Instruction(instruction));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
})
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue