diff --git a/ptx/src/pass/insert_ftz_control.rs b/ptx/src/pass/insert_ftz_control.rs index 2d16015..a048f41 100644 --- a/ptx/src/pass/insert_ftz_control.rs +++ b/ptx/src/pass/insert_ftz_control.rs @@ -1,11 +1,10 @@ -use crate::pass::error_unreachable; - use super::BrachCondition; use super::Directive2; use super::Function2; use super::SpirvWord; use super::Statement; use super::TranslateError; +use crate::pass::error_unreachable; use microlp::OptimizationDirection; use microlp::Problem; use microlp::Variable; @@ -18,8 +17,11 @@ use rustc_hash::FxHashMap; use rustc_hash::FxHashSet; use std::hash::Hash; use std::iter; +use std::mem; +use strum::EnumCount; +use strum_macros::{EnumCount, VariantArray}; -#[derive(Default)] +#[derive(Default, PartialEq, Eq, Clone, Copy, Debug, VariantArray, EnumCount)] enum DenormalMode { #[default] FlushToZero, @@ -36,7 +38,13 @@ impl DenormalMode { } } -#[derive(Default)] +impl Into for DenormalMode { + fn into(self) -> usize { + self as usize + } +} + +#[derive(Default, PartialEq, Eq, Clone, Copy, Debug, VariantArray, EnumCount)] enum RoundingMode { #[default] NearestEven, @@ -65,20 +73,49 @@ impl RoundingMode { } } +impl Into for RoundingMode { + fn into(self) -> usize { + self as usize + } +} + struct InstructionModes { denormal_f32: Option, - denormal_f16_f64: Option, + denormal_f16f64: Option, rounding_f32: Option, - rounding_f16_f64: Option, + rounding_f16f64: Option, } impl InstructionModes { + fn fold_into(self, entry: &mut Self, exit: &mut Self) { + fn set_if_none(source: &mut Option, value: Option) { + match (*source, value) { + (None, Some(x)) => *source = Some(x), + _ => {} + } + } + fn set_if_some(source: &mut Option, value: Option) { + match (source, value) { + (Some(ref mut x), Some(y)) => *x = y, + _ => {} + } + } + set_if_none(&mut entry.denormal_f32, self.denormal_f32); + set_if_none(&mut entry.denormal_f16f64, self.denormal_f16f64); + set_if_none(&mut entry.rounding_f32, self.rounding_f32); + set_if_none(&mut entry.rounding_f16f64, self.rounding_f16f64); + set_if_some(&mut exit.denormal_f32, self.denormal_f32); + set_if_some(&mut exit.denormal_f16f64, self.denormal_f16f64); + set_if_some(&mut exit.rounding_f32, self.rounding_f32); + set_if_some(&mut exit.rounding_f16f64, self.rounding_f16f64); + } + fn none() -> Self { Self { denormal_f32: None, - denormal_f16_f64: None, + denormal_f16f64: None, rounding_f32: None, - rounding_f16_f64: None, + rounding_f16f64: None, } } @@ -89,8 +126,8 @@ impl InstructionModes { ) -> Self { if type_ != ast::ScalarType::F32 { Self { - denormal_f16_f64: denormal, - rounding_f16_f64: rounding, + denormal_f16f64: denormal, + rounding_f16f64: rounding, ..Self::none() } } else { @@ -109,7 +146,7 @@ impl InstructionModes { ) -> Self { if type_ != ast::ScalarType::F32 { Self { - denormal_f16_f64: denormal, + denormal_f16f64: denormal, rounding_f32: rounding, ..Self::none() } @@ -191,13 +228,13 @@ impl InstructionModes { } } -struct ControlFlowGraph { +struct ControlFlowGraph { entry_points: FxHashMap, basic_blocks: FxHashMap, - graph: Graph>, ()>, + graph: Graph, } -impl ControlFlowGraph { +impl ControlFlowGraph { fn new() -> Self { Self { entry_points: FxHashMap::default(), @@ -207,22 +244,14 @@ impl ControlFlowGraph { } fn add_entry_basic_block(&mut self, label: SpirvWord) -> NodeIndex { - let idx = self.graph.add_node(Node { - label, - entry: Some(ExtendedMode::Entry(label)), - exit: Some(ExtendedMode::Entry(label)), - }); + let idx = self.graph.add_node(Node::entry(label)); assert_eq!(self.entry_points.insert(label, idx), None); idx } fn get_or_add_basic_block(&mut self, label: SpirvWord) -> NodeIndex { self.basic_blocks.get(&label).copied().unwrap_or_else(|| { - let idx = self.graph.add_node(Node { - label, - entry: None, - exit: None, - }); + let idx = self.graph.add_node(Node::new(label)); self.basic_blocks.insert(label, idx); idx }) @@ -233,24 +262,90 @@ impl ControlFlowGraph { self.graph.add_edge(from, to, ()); } - fn set_modes(&mut self, node: NodeIndex, entry: T, exit: T) { - self.graph[node].entry = Some(ExtendedMode::BasicBlock(entry)); - self.graph[node].exit = Some(ExtendedMode::BasicBlock(exit)); + fn set_modes(&mut self, node: NodeIndex, entry: InstructionModes, exit: InstructionModes) { + self.graph[node].denormal_f32 = Mode { + entry: entry.denormal_f32.map(ExtendedMode::BasicBlock), + exit: exit.denormal_f32.map(ExtendedMode::BasicBlock), + }; + self.graph[node].denormal_f16f64 = Mode { + entry: entry.denormal_f16f64.map(ExtendedMode::BasicBlock), + exit: exit.denormal_f16f64.map(ExtendedMode::BasicBlock), + }; + self.graph[node].rounding_f32 = Mode { + entry: entry.rounding_f32.map(ExtendedMode::BasicBlock), + exit: exit.rounding_f32.map(ExtendedMode::BasicBlock), + }; + self.graph[node].rounding_f16f64 = Mode { + entry: entry.rounding_f16f64.map(ExtendedMode::BasicBlock), + exit: exit.rounding_f16f64.map(ExtendedMode::BasicBlock), + }; } } -#[derive(Debug)] -struct Node { +#[derive(Clone, Copy)] +struct Mode { + entry: Option>, + exit: Option>, +} + +impl Mode { + fn new() -> Self { + Self { + entry: None, + exit: None, + } + } + + fn entry(label: SpirvWord) -> Self { + Self { + entry: Some(ExtendedMode::Entry(label)), + exit: Some(ExtendedMode::Entry(label)), + } + } +} + +struct Node { label: SpirvWord, - entry: Option, - exit: Option, + denormal_f32: Mode, + denormal_f16f64: Mode, + rounding_f32: Mode, + rounding_f16f64: Mode, +} + +impl Node { + fn entry(label: SpirvWord) -> Self { + Self { + label, + denormal_f32: Mode::entry(label), + denormal_f16f64: Mode::entry(label), + rounding_f32: Mode::entry(label), + rounding_f16f64: Mode::entry(label), + } + } + + fn new(label: SpirvWord) -> Self { + Self { + label, + denormal_f32: Mode::new(), + denormal_f16f64: Mode::new(), + rounding_f32: Mode::new(), + rounding_f16f64: Mode::new(), + } + } +} + +trait EnumTuple { + const LENGTH: usize; + + fn get(&self, x: usize) -> u8; + fn get_mut(&mut self, x: usize) -> &mut u8; } pub(crate) fn run<'input>( flat_resolver: &mut super::GlobalStringIdentResolver2<'input>, directives: Vec, super::SpirvWord>>, ) -> Result, SpirvWord>>, TranslateError> { - let mut cfg = ControlFlowGraph::::new(); + let mut cfg = ControlFlowGraph::new(); for directive in directives.iter() { match directive { super::Directive2::Method(Function2 { @@ -259,11 +354,18 @@ pub(crate) fn run<'input>( .. }) => { let mut basic_block = Some(cfg.add_entry_basic_block(*name)); + let mut entry = InstructionModes::none(); + let mut exit = InstructionModes::none(); for statement in body.iter() { match statement { Statement::Instruction(ast::Instruction::Bra { arguments }) => { let bb_index = basic_block.ok_or_else(error_unreachable)?; cfg.add_jump(bb_index, arguments.src); + cfg.set_modes( + bb_index, + mem::replace(&mut entry, InstructionModes::none()), + mem::replace(&mut exit, InstructionModes::none()), + ); basic_block = None; } Statement::Label(label) => { @@ -275,22 +377,31 @@ pub(crate) fn run<'input>( let bb_index = basic_block.ok_or_else(error_unreachable)?; cfg.add_jump(bb_index, *if_true); cfg.add_jump(bb_index, *if_false); + cfg.set_modes( + bb_index, + mem::replace(&mut entry, InstructionModes::none()), + mem::replace(&mut exit, InstructionModes::none()), + ); basic_block = None; } Statement::Instruction(instruction) => { let modes = get_modes(instruction); + modes.fold_into(&mut entry, &mut exit); } - _ => continue, + _ => {} } } } - _ => continue, + _ => {} } } todo!() } -fn compute(graph: ControlFlowGraph) -> PartialModeInsertion { +fn compute_single_mode( + graph: &ControlFlowGraph, + mut getter: impl FnMut(&Node) -> Mode, +) -> PartialModeInsertion { let mut must_insert_mode = FxHashSet::::default(); let mut maybe_insert_mode = FxHashMap::default(); let mut remaining = graph @@ -298,7 +409,8 @@ fn compute(graph: ControlFlowGraph) -> PartialModeInsertion .node_references() .rev() .filter_map(|(index, node)| { - node.entry + getter(node) + .entry .as_ref() .map(|mode| match mode { ExtendedMode::BasicBlock(mode) => Some((index, node.label, *mode)), @@ -316,7 +428,7 @@ fn compute(graph: ControlFlowGraph) -> PartialModeInsertion continue; } visited.insert(current); - let exit_mode = graph.graph.node_weight(current).unwrap().exit; + let exit_mode = getter(graph.graph.node_weight(current).unwrap()).exit; match exit_mode { None => { for predecessor in graph.graph.neighbors_directed(current, Direction::Incoming) @@ -355,7 +467,7 @@ struct PartialModeInsertion { bb_maybe_insert_mode: FxHashMap)>, } -fn optimize + TryFrom + std::fmt::Debug, const N: usize>( +fn optimize + strum::VariantArray + std::fmt::Debug, const N: usize>( partial: PartialModeInsertion, ) -> ModeInsertions { let mut problem = Problem::new(OptimizationDirection::Maximize); @@ -389,7 +501,7 @@ fn optimize + TryFrom + std::fmt::Debug, const N: u for (kernel, modes) in kernel_modes { for (mode, var) in modes.into_iter().enumerate() { if solution[var] > 0.5 { - kernels.insert(kernel, T::try_from(mode).unwrap_or_else(|_| todo!())); + kernels.insert(kernel, T::VARIANTS[mode]); } } } @@ -642,6 +754,7 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { mod tests { use super::*; use int_enum::IntEnum; + use strum::EnumCount; #[repr(usize)] #[derive(IntEnum, Eq, PartialEq, Copy, Clone, Debug)] @@ -650,9 +763,27 @@ mod tests { True = 1, } + fn ftz() -> InstructionModes { + InstructionModes { + denormal_f32: Some(DenormalMode::FlushToZero), + denormal_f16f64: None, + rounding_f32: None, + rounding_f16f64: None, + } + } + + fn preserve() -> InstructionModes { + InstructionModes { + denormal_f32: Some(DenormalMode::Preserve), + denormal_f16f64: None, + rounding_f32: None, + rounding_f16f64: None, + } + } + #[test] fn transitive_mixed() { - let mut graph = ControlFlowGraph::::new(); + let mut graph = ControlFlowGraph::new(); let entry_id = SpirvWord(1); let false_id = SpirvWord(2); let empty_id = SpirvWord(3); @@ -660,29 +791,29 @@ mod tests { let entry = graph.add_entry_basic_block(entry_id); graph.add_jump(entry, false_id); let false_ = graph.get_or_add_basic_block(false_id); - graph.set_modes(false_, Bool::False, Bool::False); + graph.set_modes(false_, ftz(), ftz()); graph.add_jump(false_, empty_id); let empty = graph.get_or_add_basic_block(empty_id); graph.add_jump(empty, false2_id); let false2_ = graph.get_or_add_basic_block(false2_id); - graph.set_modes(false2_, Bool::False, Bool::False); - let partial_result = super::compute(graph); + graph.set_modes(false2_, ftz(), ftz()); + let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32); assert_eq!(partial_result.bb_must_insert_mode.len(), 0); assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1); assert_eq!( partial_result.bb_maybe_insert_mode[&false_id], - (Bool::False, iter::once(entry_id).collect()) + (DenormalMode::FlushToZero, iter::once(entry_id).collect()) ); - let result = optimize::(partial_result); + let result = optimize::(partial_result); assert_eq!(result.basic_blocks.len(), 0); assert_eq!(result.kernels.len(), 1); - assert_eq!(result.kernels[&entry_id], Bool::False); + assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero); } #[test] fn transitive_change_twice() { - let mut graph = ControlFlowGraph::::new(); + let mut graph = ControlFlowGraph::new(); let entry_id = SpirvWord(1); let false_id = SpirvWord(2); let empty_id = SpirvWord(3); @@ -690,30 +821,30 @@ mod tests { let entry = graph.add_entry_basic_block(entry_id); graph.add_jump(entry, false_id); let false_ = graph.get_or_add_basic_block(false_id); - graph.set_modes(false_, Bool::False, Bool::False); + graph.set_modes(false_, ftz(), ftz()); graph.add_jump(false_, empty_id); let empty = graph.get_or_add_basic_block(empty_id); graph.add_jump(empty, true_id); let true_ = graph.get_or_add_basic_block(true_id); - graph.set_modes(true_, Bool::True, Bool::True); - let partial_result = super::compute(graph); + graph.set_modes(true_, preserve(), preserve()); + let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32); assert_eq!(partial_result.bb_must_insert_mode.len(), 1); assert!(partial_result.bb_must_insert_mode.contains(&true_id)); assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1); assert_eq!( partial_result.bb_maybe_insert_mode[&false_id], - (Bool::False, iter::once(entry_id).collect()) + (DenormalMode::FlushToZero, iter::once(entry_id).collect()) ); - let result = optimize::(partial_result); + let result = optimize::(partial_result); assert_eq!(result.basic_blocks, iter::once(true_id).collect()); assert_eq!(result.kernels.len(), 1); - assert_eq!(result.kernels[&entry_id], Bool::False); + assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero); } #[test] fn transitive_change() { - let mut graph = ControlFlowGraph::::new(); + let mut graph = ControlFlowGraph::new(); let entry_id = SpirvWord(1); let empty_id = SpirvWord(2); let true_id = SpirvWord(3); @@ -722,24 +853,24 @@ mod tests { let empty = graph.get_or_add_basic_block(empty_id); graph.add_jump(empty, true_id); let true_ = graph.get_or_add_basic_block(true_id); - graph.set_modes(true_, Bool::True, Bool::True); - let partial_result = super::compute(graph); + graph.set_modes(true_, preserve(), preserve()); + let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32); assert_eq!(partial_result.bb_must_insert_mode.len(), 0); assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1); assert_eq!( partial_result.bb_maybe_insert_mode[&true_id], - (Bool::True, iter::once(entry_id).collect()) + (DenormalMode::Preserve, iter::once(entry_id).collect()) ); - let result = optimize::(partial_result); + let result = optimize::(partial_result); assert_eq!(result.basic_blocks.len(), 0); assert_eq!(result.kernels.len(), 1); - assert_eq!(result.kernels[&entry_id], Bool::True); + assert_eq!(result.kernels[&entry_id], DenormalMode::Preserve); } #[test] fn codependency() { - let mut graph = ControlFlowGraph::::new(); + let mut graph = ControlFlowGraph::new(); let entry_id = SpirvWord(1); let left_f_id = SpirvWord(2); let right_f_id = SpirvWord(3); @@ -750,9 +881,9 @@ mod tests { graph.add_jump(entry, left_f_id); graph.add_jump(entry, right_f_id); let left_f = graph.get_or_add_basic_block(left_f_id); - graph.set_modes(left_f, Bool::False, Bool::False); + graph.set_modes(left_f, ftz(), ftz()); let right_f = graph.get_or_add_basic_block(right_f_id); - graph.set_modes(right_f, Bool::False, Bool::False); + graph.set_modes(right_f, ftz(), ftz()); graph.add_jump(left_f, left_none_id); let left_none = graph.get_or_add_basic_block(left_none_id); graph.add_jump(right_f, right_none_id); @@ -766,21 +897,21 @@ mod tests { // "{:?}", // petgraph::dot::Dot::with_config(&graph.graph, &[petgraph::dot::Config::EdgeNoLabel]) //); - let partial_result = super::compute(graph); + let partial_result = super::compute_single_mode(&graph, |node| node.denormal_f32); assert_eq!(partial_result.bb_must_insert_mode.len(), 0); assert_eq!(partial_result.bb_maybe_insert_mode.len(), 2); assert_eq!( partial_result.bb_maybe_insert_mode[&left_f_id], - (Bool::False, iter::once(entry_id).collect()) + (DenormalMode::FlushToZero, iter::once(entry_id).collect()) ); assert_eq!( partial_result.bb_maybe_insert_mode[&right_f_id], - (Bool::False, iter::once(entry_id).collect()) + (DenormalMode::FlushToZero, iter::once(entry_id).collect()) ); - let result = optimize::(partial_result); + let result = optimize::(partial_result); assert_eq!(result.basic_blocks.len(), 0); assert_eq!(result.kernels.len(), 1); - assert_eq!(result.kernels[&entry_id], Bool::False); + assert_eq!(result.kernels[&entry_id], DenormalMode::FlushToZero); } }