From c86473b39696820453c6507c15aad396226d0368 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 10 Mar 2025 22:27:08 +0000 Subject: [PATCH] Bugfixing --- .../insert_ftz_control/call_with_mode.ptx | 8 ++ ptx/src/pass/insert_ftz_control/mod.rs | 88 ++++++++++----- ptx/src/pass/insert_ftz_control/test.rs | 33 +++++- ptx/src/pass/mod.rs | 2 +- ptx/src/pass/normalize_basic_blocks.rs | 105 +++++++++++++++--- 5 files changed, 192 insertions(+), 44 deletions(-) diff --git a/ptx/src/pass/insert_ftz_control/call_with_mode.ptx b/ptx/src/pass/insert_ftz_control/call_with_mode.ptx index cfff97c..506145a 100644 --- a/ptx/src/pass/insert_ftz_control/call_with_mode.ptx +++ b/ptx/src/pass/insert_ftz_control/call_with_mode.ptx @@ -10,12 +10,20 @@ add.rz.ftz.f32 temp, temp, temp; call use_modes; + add.rp.ftz.f32 temp, temp, temp; ret; } .func use_modes() { .reg .f32 temp; + .reg .pred pred; + @pred bra SET_RM; + @!pred bra SET_RZ; +SET_RM: add.rm.f32 temp, temp, temp; ret; +SET_RZ: + add.rz.f32 temp, temp, temp; + ret; } diff --git a/ptx/src/pass/insert_ftz_control/mod.rs b/ptx/src/pass/insert_ftz_control/mod.rs index a7097dd..afd2bcb 100644 --- a/ptx/src/pass/insert_ftz_control/mod.rs +++ b/ptx/src/pass/insert_ftz_control/mod.rs @@ -20,7 +20,6 @@ use smallvec::SmallVec; use std::hash::Hash; use std::iter; use std::mem; -use std::u32; use strum::EnumCount; use strum_macros::{EnumCount, VariantArray}; use unwrap_or::unwrap_some_or; @@ -250,7 +249,7 @@ struct ControlFlowGraph { // map function -> return label call_returns: FxHashMap>, // map function -> return basic blocks - function_rets: FxHashMap>, + functions_rets: FxHashMap, graph: Graph, } @@ -260,7 +259,7 @@ impl ControlFlowGraph { entry_points: FxHashMap::default(), basic_blocks: FxHashMap::default(), call_returns: FxHashMap::default(), - function_rets: FxHashMap::default(), + functions_rets: FxHashMap::default(), graph: Graph::new(), } } @@ -298,7 +297,7 @@ impl ControlFlowGraph { } fn fixup_function_calls(&mut self) { - for (function, sources) in self.function_rets.iter() { + for (function, source) in self.functions_rets.iter() { for target in self .call_returns .get(function) @@ -307,15 +306,15 @@ impl ControlFlowGraph { .flatten() .copied() { - for source in sources { - self.graph.add_edge(*source, target, ()); - } + self.graph.add_edge(*source, target, ()); } } } } #[derive(Clone, Copy)] +//#[cfg_attr(test, derive(Debug))] +#[derive(Debug)] struct Mode { entry: Option>, exit: Option>, @@ -337,6 +336,8 @@ impl Mode { } } +//#[cfg_attr(test, derive(Debug))] +#[derive(Debug)] struct Node { label: SpirvWord, denormal_f32: Mode, @@ -376,7 +377,7 @@ trait EnumTuple { pub(crate) fn run<'input>( flat_resolver: &mut super::GlobalStringIdentResolver2<'input>, - mut directives: Vec, super::SpirvWord>>, + directives: Vec, super::SpirvWord>>, ) -> Result, SpirvWord>>, TranslateError> { let mut cfg = ControlFlowGraph::new(); for directive in directives.iter() { @@ -398,13 +399,17 @@ pub(crate) fn run<'input>( arguments: ast::CallArgs { func, .. }, .. }) => { - let after_call_label = match body_iter.peek() { - Some(Statement::Label(l)) => *l, + let after_call_label = match body_iter.next() { + Some(Statement::Instruction(ast::Instruction::Bra { + arguments: ast::BraArgs { src }, + })) => *src, _ => return Err(error_unreachable()), }; bb_state.record_call(*func, after_call_label)?; + //body_iter.next(); } - Statement::Instruction(ast::Instruction::Ret { .. }) => { + Statement::RetValue(..) + | Statement::Instruction(ast::Instruction::Ret { .. }) => { bb_state.record_ret(*name)?; } Statement::Label(label) => { @@ -426,7 +431,15 @@ pub(crate) fn run<'input>( _ => {} } } + println!( + "{:?}", + petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel]) + ); cfg.fixup_function_calls(); + println!( + "{:?}", + petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel]) + ); let denormal_f32 = compute_single_mode(&cfg, |node| node.denormal_f32); let denormal_f16f64 = compute_single_mode(&cfg, |node| node.denormal_f16f64); let rounding_f32 = compute_single_mode(&cfg, |node| node.rounding_f32); @@ -434,7 +447,8 @@ pub(crate) fn run<'input>( let denormal_f32 = optimize::(denormal_f32); let denormal_f16f64 = optimize::(denormal_f16f64); let rounding_f32 = optimize::(rounding_f32); - let rounding_f16f64 = optimize::(rounding_f16f64); + let rounding_f16f64: MandatoryModeInsertions = + optimize::(rounding_f16f64); let denormal = join_modes( flat_resolver, &cfg, @@ -483,7 +497,7 @@ fn join_modes<'input, T: Eq + PartialEq + Copy + Default>( mut f16f64_exit_view: impl FnMut(&Node) -> Option>, ) -> Result, TranslateError> { // Returns None if there are multiple conflicting modes - fn get_incoming_mode( + fn get_incoming_mode( cfg: &ControlFlowGraph, kernels: &FxHashMap, node: NodeIndex, @@ -500,11 +514,11 @@ fn join_modes<'input, T: Eq + PartialEq + Copy + Default>( if !visited.insert(node) { continue; } - let x = &cfg.graph[node]; - match (mode, exit_getter(x)) { + let node_data = &cfg.graph[node]; + match (mode, exit_getter(node_data)) { (_, None) => { for next in cfg.graph.neighbors_directed(node, Direction::Incoming) { - if !visited.insert(next) { + if !visited.contains(&next) { to_visit.push(next); } } @@ -513,7 +527,7 @@ fn join_modes<'input, T: Eq + PartialEq + Copy + Default>( let new_mode = match new_mode { ExtendedMode::BasicBlock(new_mode) => new_mode, ExtendedMode::Entry(kernel) => { - *kernels.get(&kernel).ok_or_else(error_unreachable)? + kernels.get(&kernel).copied().unwrap_or_default() } }; if let Some(existing_mode) = existing_mode { @@ -546,7 +560,7 @@ fn join_modes<'input, T: Eq + PartialEq + Copy + Default>( .kernels .get(&kernel) .copied() - .ok_or_else(error_unreachable)?, + .unwrap_or_default(), ), // None means that no instruction in the basic block sets mode, but // another basic block might rely on this instruction transitively @@ -560,7 +574,7 @@ fn join_modes<'input, T: Eq + PartialEq + Copy + Default>( .kernels .get(&kernel) .copied() - .ok_or_else(error_unreachable)?, + .unwrap_or_default(), ), None => None, }; @@ -713,7 +727,9 @@ fn insert_mode_control<'input>( let old_body = mem::replace(body_ptr, Vec::new()); let mut result = Vec::with_capacity(old_body.len()); let mut bb_state = BasicBlockControlState::new(&global_modes, fn_name, initial_mode); - for mut statement in old_body.into_iter() { + let mut old_body = old_body.into_iter(); + while let Some(mut statement) = old_body.next() { + let mut call_target = None; match &mut statement { Statement::Label(label) => { bb_state.start(*label, &mut result)?; @@ -723,6 +739,7 @@ fn insert_mode_control<'input>( .. }) => { bb_state.redirect_jump(func)?; + call_target = Some(*func); } Statement::Conditional(BrachCondition { if_true, if_false, .. @@ -742,6 +759,16 @@ fn insert_mode_control<'input>( _ => {} } result.push(statement); + if let Some(call_target) = call_target { + if let Some(Statement::Instruction(ast::Instruction::Bra { + arguments: ast::BraArgs { src: post_call_label }, + })) = old_body.next() + { + // get return block for the function, if there is a mode + // change between caller and callee then apply it here + todo!() + } + } } *body_ptr = result; new_directives.push(directive); @@ -1165,8 +1192,8 @@ impl<'a> BasicBlockState<'a> { fn_call: SpirvWord, after_call_label: SpirvWord, ) -> Result<(), TranslateError> { - let node_index = self.node_index.ok_or_else(error_unreachable)?; - let after_call_label = self.cfg.add_jump(node_index, after_call_label); + self.end(&[fn_call]).ok_or_else(error_unreachable)?; + let after_call_label = self.cfg.get_or_add_basic_block(after_call_label); let call_returns = self .cfg .call_returns @@ -1178,8 +1205,11 @@ impl<'a> BasicBlockState<'a> { fn record_ret(&mut self, fn_name: SpirvWord) -> Result<(), TranslateError> { let node_index = self.node_index.ok_or_else(error_unreachable)?; - let function_rets = self.cfg.function_rets.entry(fn_name).or_insert(Vec::new()); - function_rets.push(node_index); + let previous_function_ret = self.cfg.functions_rets.insert(fn_name, node_index); + // This pass relies on there being only a single `ret;` in a function + if previous_function_ret.is_some() { + return Err(error_unreachable()); + } Ok(()) } @@ -1263,7 +1293,11 @@ struct PartialModeInsertion { bb_maybe_insert_mode: FxHashMap)>, } -fn optimize + strum::VariantArray + std::fmt::Debug, const N: usize>( +// Only returns kernel mode insertions if a kernel is relevant to the optimization problem +fn optimize< + T: Copy + Into + strum::VariantArray + std::fmt::Debug + Default, + const N: usize, +>( partial: PartialModeInsertion, ) -> MandatoryModeInsertions { let mut problem = Problem::new(OptimizationDirection::Maximize); @@ -1341,6 +1375,8 @@ struct MandatoryModeInsertions { } #[derive(Eq, PartialEq, Clone, Copy)] +//#[cfg_attr(test, derive(Debug))] +#[derive(Debug)] enum ExtendedMode { BasicBlock(T), Entry(SpirvWord), @@ -1549,4 +1585,4 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { } #[cfg(test)] -mod test; \ No newline at end of file +mod test; diff --git a/ptx/src/pass/insert_ftz_control/test.rs b/ptx/src/pass/insert_ftz_control/test.rs index 33ab51c..355d706 100644 --- a/ptx/src/pass/insert_ftz_control/test.rs +++ b/ptx/src/pass/insert_ftz_control/test.rs @@ -198,7 +198,7 @@ fn compile_methods(ptx: &str) -> Vec, Spir let directives = normalize_identifiers2::run(&mut scoped_resolver, module.directives).unwrap(); let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); - let directives = normalize_basic_blocks::run(&mut flat_resolver, directives); + let directives = normalize_basic_blocks::run(&mut flat_resolver, directives).unwrap(); let directives = super::run(&mut flat_resolver, directives).unwrap(); directives .into_iter() @@ -220,10 +220,37 @@ fn call_with_mode() { &**methods[1].body.as_ref().unwrap(), [ Statement::Label(..), + Statement::Variable(..), + Statement::Instruction(ast::Instruction::Add { .. }), + Statement::Instruction(ast::Instruction::Call { .. }), + Statement::Instruction(ast::Instruction::Bra { .. }), + Statement::Label(..), + // Dual prelude Statement::SetMode(ModeRegister::Denormal { - f32: false, - f16f64: false + f32: true, + f16f64: true }), + Statement::SetMode(ModeRegister::Rounding { + f32: ast::RoundingMode::PositiveInf, + f16f64: ast::RoundingMode::NearestEven + }), + Statement::Instruction(ast::Instruction::Bra { .. }), + // Denormal prelude + Statement::Label(..), + Statement::SetMode(ModeRegister::Denormal { + f32: true, + f16f64: true + }), + Statement::Instruction(ast::Instruction::Bra { .. }), + // Rounding prelude + Statement::Label(..), + Statement::SetMode(ModeRegister::Rounding { + f32: ast::RoundingMode::PositiveInf, + f16f64: ast::RoundingMode::NearestEven + }), + Statement::Instruction(ast::Instruction::Bra { .. }), + Statement::Label(..), + Statement::Instruction(ast::Instruction::Add { .. }), Statement::Instruction(ast::Instruction::Ret { .. }), ] )); diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index d6e9aa4..10741c0 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -51,9 +51,9 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result, mut directives: Vec, SpirvWord>>, -) -> Vec, SpirvWord>> { +) -> Result, SpirvWord>>, TranslateError> { for directive in directives.iter_mut() { let body_ref = match directive { Directive2::Method(Function2 { @@ -20,8 +27,9 @@ pub(crate) fn run( }; let body = std::mem::replace(body_ref, Vec::new()); let mut result = Vec::with_capacity(body.len()); - let mut needs_label = false; + let mut previous_instruction_was_terminator = TerminatorKind::Not; let mut body_iterator = body.into_iter(); + let mut return_statements = Vec::new(); match body_iterator.next() { Some(Statement::Label(_)) => {} Some(statement) => { @@ -31,25 +39,94 @@ pub(crate) fn run( None => {} } for statement in body_iterator { - if needs_label && !matches!(statement, Statement::Label(..)) { - result.push(Statement::Label(flat_resolver.register_unnamed(None))); + match previous_instruction_was_terminator { + TerminatorKind::Not => match statement { + Statement::Label(label) => { + result.push(Statement::Instruction(ast::Instruction::Bra { + arguments: ast::BraArgs { src: label }, + })) + } + _ => {} + }, + TerminatorKind::Real => { + if !matches!(statement, Statement::Label(..)) { + result.push(Statement::Label(flat_resolver.register_unnamed(None))); + } + } + TerminatorKind::Fake => match statement { + // if it happens that there is a label after a call just reuse it + Statement::Label(label) => { + result.push(Statement::Instruction(ast::Instruction::Bra { + arguments: ast::BraArgs { src: label }, + })) + } + _ => { + let label = flat_resolver.register_unnamed(None); + result.push(Statement::Instruction(ast::Instruction::Bra { + arguments: ast::BraArgs { src: label }, + })); + result.push(Statement::Label(label)); + } + }, } - needs_label = is_block_terminator(&statement); + match statement { + Statement::RetValue(..) => { + return Err(error_unreachable()); + } + Statement::Instruction(ast::Instruction::Ret { .. }) => { + return_statements.push(result.len()) + } + _ => {} + } + previous_instruction_was_terminator = is_block_terminator(&statement); result.push(statement); } + convert_from_multiple_returns_to_single_return( + flat_resolver, + &mut result, + return_statements, + )?; *body_ref = result; } - directives + Ok(directives) } -fn is_block_terminator(instruction: &Statement, SpirvWord>) -> bool { - match instruction { +enum TerminatorKind { + Not, + Real, + Fake, +} + +fn convert_from_multiple_returns_to_single_return( + flat_resolver: &mut GlobalStringIdentResolver2<'_>, + result: &mut Vec, SpirvWord>>, + return_statements: Vec, +) -> Result<(), TranslateError> { + Ok(if return_statements.len() > 1 { + let ret_bb = flat_resolver.register_unnamed(None); + result.push(Statement::Label(ret_bb)); + result.push(Statement::Instruction(ast::Instruction::Ret { + data: ast::RetData { uniform: false }, + })); + for ret_index in return_statements { + let statement = result.get_mut(ret_index).ok_or_else(error_unreachable)?; + *statement = Statement::Instruction(ast::Instruction::Bra { + arguments: ast::BraArgs { src: ret_bb }, + }); + } + }) +} + +fn is_block_terminator( + statement: &Statement, SpirvWord>, +) -> TerminatorKind { + match statement { Statement::Conditional(..) | Statement::Instruction(ast::Instruction::Bra { .. }) // Normally call is not a terminator, but we treat it as such because it - // makes the instruction modes to global modes pass possible - | Statement::Instruction(ast::Instruction::Call { .. }) - | Statement::Instruction(ast::Instruction::Ret { .. }) => true, - _ => false, + // makes the "instruction modes to global modes" pass possible + | Statement::Instruction(ast::Instruction::Ret { .. }) => TerminatorKind::Real, + Statement::Instruction(ast::Instruction::Call { .. }) => TerminatorKind::Fake, + _ => TerminatorKind::Not, } }