From e87388bc352601201960458c2768b571c5947696 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 16 Sep 2024 17:08:12 +0200 Subject: [PATCH] Port normalize_predicates --- ptx/src/pass/mod.rs | 42 +++++++++++++ ptx/src/pass/normalize_identifiers2.rs | 68 ++++++++++----------- ptx/src/pass/normalize_predicates2.rs | 84 ++++++++++++++++++++++++++ 3 files changed, 157 insertions(+), 37 deletions(-) create mode 100644 ptx/src/pass/normalize_predicates2.rs diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 409425f..9277de4 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,5 +1,6 @@ use ptx_parser as ast; use rspirv::{binary::Assemble, dr}; +use rustc_hash::FxHashMap; use std::hash::Hash; use std::num::NonZeroU8; use std::{ @@ -27,6 +28,7 @@ mod normalize_identifiers; mod normalize_identifiers2; mod normalize_labels; mod normalize_predicates; +mod normalize_predicates2; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); @@ -1690,3 +1692,43 @@ type NormalizedFunction2<'input> = Function2< ), ast::ParsedOperand, >; + +type UnconditionalDirective<'input> = Directive2< + 'input, + ast::Instruction>, + ast::ParsedOperand, +>; + +type UnconditionalFunction<'input> = Function2< + 'input, + ast::Instruction>, + ast::ParsedOperand, +>; + +struct GlobalStringIdentResolver2<'input> { + pub(crate) current_id: SpirvWord, + pub(crate) ident_map: FxHashMap>, +} + +impl<'input> GlobalStringIdentResolver2<'input> { + fn register_intermediate( + &mut self, + type_space: Option<(ast::Type, ast::StateSpace)>, + ) -> SpirvWord { + let new_id = self.current_id; + self.ident_map.insert( + new_id, + IdentEntry { + name: None, + type_space, + }, + ); + self.current_id.0 += 1; + new_id + } +} + +struct IdentEntry<'input> { + name: Option>, + type_space: Option<(ast::Type, ast::StateSpace)>, +} diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs index 925feb7..e3fb88d 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -3,45 +3,45 @@ use ptx_parser as ast; use rustc_hash::FxHashMap; pub(crate) fn run<'input>( - fn_defs: &mut GlobalStringIdentResolver<'input>, + fn_defs: &mut GlobalStringIdentResolver2<'input>, directives: Vec>>, ) -> Result>, TranslateError> { let mut resolver = NameResolver::new(fn_defs); let result = directives .into_iter() - .map(|directive| remap_directive(&mut resolver, directive)) + .map(|directive| run_directive(&mut resolver, directive)) .collect::, _>>()?; resolver.end_scope(); Ok(result) } -fn remap_directive<'input, 'b>( +fn run_directive<'input, 'b>( resolver: &mut NameResolver<'input, 'b>, directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>, ) -> Result, TranslateError> { Ok(match directive { ast::Directive::Variable(linking, var) => { - NormalizedDirective2::Variable(linking, remap_variable(resolver, var)?) + NormalizedDirective2::Variable(linking, run_variable(resolver, var)?) } ast::Directive::Method(linking, directive) => { - NormalizedDirective2::Method(remap_method(resolver, linking, directive)?) + NormalizedDirective2::Method(run_method(resolver, linking, directive)?) } }) } -fn remap_method<'input, 'b>( +fn run_method<'input, 'b>( resolver: &mut NameResolver<'input, 'b>, linkage: ast::LinkingDirective, method: ast::Function<'input, &'input str, ast::Statement>>, ) -> Result, TranslateError> { let name = match method.func_directive.name { ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), - ast::MethodName::Func(text) => ast::MethodName::Func( - resolver.add(Cow::Borrowed(method.func_directive.name.text()), None)?, - ), + ast::MethodName::Func(text) => { + ast::MethodName::Func(resolver.add(Cow::Borrowed(text), None)?) + } }; resolver.start_scope(); - let func_decl = Rc::new(RefCell::new(remap_function_decl( + let func_decl = Rc::new(RefCell::new(run_function_decl( resolver, method.func_directive, name, @@ -50,7 +50,7 @@ fn remap_method<'input, 'b>( .body .map(|statements| { let mut result = Vec::with_capacity(statements.len()); - remap_statements(resolver, &mut result, statements)?; + run_statements(resolver, &mut result, statements)?; Ok::<_, TranslateError>(result) }) .transpose()?; @@ -65,7 +65,7 @@ fn remap_method<'input, 'b>( }) } -fn remap_function_decl<'input, 'b>( +fn run_function_decl<'input, 'b>( resolver: &mut NameResolver<'input, 'b>, func_directive: ast::MethodDeclaration<'input, &'input str>, name: ast::MethodName<'input, SpirvWord>, @@ -74,12 +74,12 @@ fn remap_function_decl<'input, 'b>( let return_arguments = func_directive .return_arguments .into_iter() - .map(|var| remap_variable(resolver, var)) + .map(|var| run_variable(resolver, var)) .collect::, _>>()?; let input_arguments = func_directive .input_arguments .into_iter() - .map(|var| remap_variable(resolver, var)) + .map(|var| run_variable(resolver, var)) .collect::, _>>()?; Ok(ast::MethodDeclaration { return_arguments, @@ -89,7 +89,7 @@ fn remap_function_decl<'input, 'b>( }) } -fn remap_variable<'input, 'b>( +fn run_variable<'input, 'b>( resolver: &mut NameResolver<'input, 'b>, variable: ast::Variable<&'input str>, ) -> Result, TranslateError> { @@ -105,7 +105,7 @@ fn remap_variable<'input, 'b>( }) } -fn remap_statements<'input, 'b>( +fn run_statements<'input, 'b>( resolver: &mut NameResolver<'input, 'b>, result: &mut Vec, statements: Vec>>, @@ -123,7 +123,7 @@ fn remap_statements<'input, 'b>( ast::Statement::Label(label) => { result.push(Statement::Label(resolver.get_in_current_scope(label)?)) } - ast::Statement::Variable(variable) => remap_multivariable(resolver, result, variable)?, + ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?, ast::Statement::Instruction(predicate, instruction) => { result.push(Statement::Instruction(( predicate @@ -134,12 +134,12 @@ fn remap_statements<'input, 'b>( }) }) .transpose()?, - remap_instruction(resolver, instruction)?, + run_instruction(resolver, instruction)?, ))) } ast::Statement::Block(block) => { resolver.start_scope(); - remap_statements(resolver, result, block)?; + run_statements(resolver, result, block)?; resolver.end_scope(); } } @@ -147,7 +147,7 @@ fn remap_statements<'input, 'b>( Ok(()) } -fn remap_instruction<'input, 'b>( +fn run_instruction<'input, 'b>( resolver: &mut NameResolver<'input, 'b>, instruction: ast::Instruction>, ) -> Result>, TranslateError> { @@ -162,7 +162,7 @@ fn remap_instruction<'input, 'b>( }) } -fn remap_multivariable<'input, 'b>( +fn run_multivariable<'input, 'b>( resolver: &mut NameResolver<'input, 'b>, result: &mut Vec, variable: ast::MultiVariable<&'input str>, @@ -203,12 +203,12 @@ fn remap_multivariable<'input, 'b>( } struct NameResolver<'input, 'b> { - flat_resolver: &'b mut GlobalStringIdentResolver<'input>, + flat_resolver: &'b mut GlobalStringIdentResolver2<'input>, scopes: Vec>, } impl<'input, 'b> NameResolver<'input, 'b> { - fn new(flat_resolver: &'b mut GlobalStringIdentResolver<'input>) -> Self { + fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self { Self { flat_resolver, scopes: vec![ScopeStringIdentResolver::new()], @@ -239,9 +239,13 @@ impl<'input, 'b> NameResolver<'input, 'b> { { return Err(error_unknown_symbol()); } - current_scope - .ident_map - .insert(result, IdentEntry { name, type_space }); + current_scope.ident_map.insert( + result, + IdentEntry { + name: Some(name), + type_space, + }, + ); Ok(result) } @@ -276,17 +280,7 @@ impl<'input> ScopeStringIdentResolver<'input> { } } - fn flush(self, resolver: &mut GlobalStringIdentResolver<'input>) { + fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) { resolver.ident_map.extend(self.ident_map); } } - -struct GlobalStringIdentResolver<'input> { - pub(crate) current_id: SpirvWord, - pub(crate) ident_map: FxHashMap>, -} - -struct IdentEntry<'input> { - name: Cow<'input, str>, - type_space: Option<(ast::Type, ast::StateSpace)>, -} diff --git a/ptx/src/pass/normalize_predicates2.rs b/ptx/src/pass/normalize_predicates2.rs new file mode 100644 index 0000000..2d15bba --- /dev/null +++ b/ptx/src/pass/normalize_predicates2.rs @@ -0,0 +1,84 @@ +use super::*; +use ptx_parser as ast; + +pub(crate) fn run<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec>, +) -> Result>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: NormalizedDirective2<'input>, +) -> Result, TranslateError> { + Ok(match directive { + Directive2::Variable(linking, var) => Directive2::Variable(linking, var), + Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), + }) +} + +fn run_method<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + method: NormalizedFunction2<'input>, +) -> Result, TranslateError> { + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(resolver, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { + func_decl: method.func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + result: &mut Vec, + statement: NormalizedStatement, +) -> Result<(), TranslateError> { + Ok(match statement { + Statement::Label(label) => result.push(Statement::Label(label)), + Statement::Variable(var) => result.push(Statement::Variable(var)), + Statement::Instruction((predicate, instruction)) => { + if let Some(pred) = predicate { + let if_true = resolver.register_intermediate(None); + let if_false = resolver.register_intermediate(None); + let folded_bra = match &instruction { + ast::Instruction::Bra { arguments, .. } => Some(arguments.src), + _ => None, + }; + let mut branch = BrachCondition { + predicate: pred.label, + if_true: folded_bra.unwrap_or(if_true), + if_false, + }; + if pred.not { + std::mem::swap(&mut branch.if_true, &mut branch.if_false); + } + result.push(Statement::Conditional(branch)); + if folded_bra.is_none() { + result.push(Statement::Label(if_true)); + result.push(Statement::Instruction(instruction)); + } + result.push(Statement::Label(if_false)); + } else { + result.push(Statement::Instruction(instruction)); + } + } + _ => return Err(error_unreachable()), + }) +}