From 241cf43a52c88ea791bb8ce75a37de6b96721e71 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 12 Feb 2025 18:48:46 +0100 Subject: [PATCH] Add optimization of initial ftz placement in kernel --- Cargo.lock | 135 ++++++++++++++++ ptx/Cargo.toml | 2 + ptx/src/pass/insert_ftz_control.rs | 247 ++++++++++++++++++++++++++--- 3 files changed, 359 insertions(+), 25 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6e060e4..66e9625 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,6 +52,12 @@ dependencies = [ "serde", ] +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + [[package]] name = "beef" version = "0.5.2" @@ -405,6 +411,19 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "int-enum" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a37a9c11c6ecfec8b9bed97337dfecff3686d02ba8f52e8addad2829d047128" +dependencies = [ + "proc-macro2", + "proc-macro2-diagnostics", + "quote", + "syn 2.0.89", + "version_check", +] + [[package]] name = "itertools" version = "0.13.0" @@ -521,6 +540,16 @@ dependencies = [ "libc", ] +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.4" @@ -536,12 +565,37 @@ dependencies = [ "libc", ] +[[package]] +name = "microlp" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edaa5264bc1f7668bc12e10757f8f529a526656c796cc2106cf2be10c5b8d483" +dependencies = [ + "log", + "sprs", +] + [[package]] name = "minimal-lexical" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "nom" version = "7.1.3" @@ -552,6 +606,33 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "num_enum" version = "0.4.3" @@ -611,6 +692,15 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "prettyplease" version = "0.2.25" @@ -669,6 +759,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proc-macro2-diagnostics" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", + "version_check", + "yansi", +] + [[package]] name = "ptx" version = "0.0.0" @@ -679,7 +782,9 @@ dependencies = [ "cuda-driver-sys", "half", "hip_runtime-sys", + "int-enum", "llvm_zluda", + "microlp", "paste", "petgraph", "ptx_parser", @@ -742,6 +847,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "regex" version = "1.11.0" @@ -881,6 +992,24 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "sprs" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bff8419009a08f6cb7519a602c5590241fbff1446bcc823c07af15386eb801b" +dependencies = [ + "ndarray", + "num-complex", + "num-traits", + "smallvec", +] + [[package]] name = "strum" version = "0.26.3" @@ -1153,6 +1282,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "zluda" version = "0.0.0" diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index aac1e6b..143b562 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -18,6 +18,8 @@ rustc-hash = "2.0.0" strum = "0.26" strum_macros = "0.26" petgraph = "0.7.1" +microlp = "0.2.10" +int-enum = "1.1" [dev-dependencies] hip_runtime-sys = { path = "../ext/hip_runtime-sys" } diff --git a/ptx/src/pass/insert_ftz_control.rs b/ptx/src/pass/insert_ftz_control.rs index 6c69f8d..4fc4136 100644 --- a/ptx/src/pass/insert_ftz_control.rs +++ b/ptx/src/pass/insert_ftz_control.rs @@ -1,11 +1,12 @@ -use std::hash::Hash; - use super::BrachCondition; use super::Directive2; use super::Function2; use super::SpirvWord; use super::Statement; use super::TranslateError; +use microlp::OptimizationDirection; +use microlp::Problem; +use microlp::Variable; use petgraph::graph::NodeIndex; use petgraph::visit::IntoNodeReferences; use petgraph::Direction; @@ -13,6 +14,8 @@ use petgraph::Graph; use ptx_parser as ast; use rustc_hash::FxHashMap; use rustc_hash::FxHashSet; +use std::hash::Hash; +use std::iter; struct ControlFlowGraph { entry_points: FxHashMap, @@ -92,41 +95,145 @@ pub(crate) fn run<'input>( todo!() } -fn compute(g: ControlFlowGraph) -> FxHashSet { +fn compute(graph: ControlFlowGraph) -> PartialModeInsertion { let mut must_insert_mode = FxHashSet::::default(); - let mut remaining = g + let mut maybe_insert_mode = FxHashMap::default(); + let mut remaining = graph .graph .node_references() .rev() - .filter_map(|(index, node)| node.entry.as_ref().map(|mode| (index, node.label, *mode))) + .filter_map(|(index, node)| { + node.entry + .as_ref() + .map(|mode| match mode { + ExtendedMode::BasicBlock(mode) => Some((index, node.label, *mode)), + ExtendedMode::Entry(_) => None, + }) + .flatten() + }) .collect::>(); 'next_basic_block: while let Some((index, node_id, expected_mode)) = remaining.pop() { - let mut to_visit = UniqueVec::new(g.graph.neighbors_directed(index, Direction::Incoming)); + let mut to_visit = + UniqueVec::new(graph.graph.neighbors_directed(index, Direction::Incoming)); let mut visited = FxHashSet::default(); while let Some(current) = to_visit.pop() { if visited.contains(¤t) { continue; } visited.insert(current); - let exit_mode = g.graph.node_weight(current).unwrap().exit; + let exit_mode = graph.graph.node_weight(current).unwrap().exit; match exit_mode { None => { - for predecessor in g.graph.neighbors_directed(current, Direction::Incoming) { + for predecessor in graph.graph.neighbors_directed(current, Direction::Incoming) + { if !visited.contains(&predecessor) { to_visit.push(predecessor); } } } - Some(mode) => { + Some(ExtendedMode::BasicBlock(mode)) => { if mode != expected_mode { + maybe_insert_mode.remove(&node_id); must_insert_mode.insert(node_id); continue 'next_basic_block; } } + Some(ExtendedMode::Entry(kernel)) => match maybe_insert_mode.entry(node_id) { + std::collections::hash_map::Entry::Vacant(entry) => { + entry.insert((expected_mode, iter::once(kernel).collect::>())); + } + std::collections::hash_map::Entry::Occupied(mut entry) => { + entry.get_mut().1.insert(kernel); + } + }, } } } - must_insert_mode + PartialModeInsertion { + bb_must_insert_mode: must_insert_mode, + bb_maybe_insert_mode: maybe_insert_mode, + } +} + +struct PartialModeInsertion { + bb_must_insert_mode: FxHashSet, + bb_maybe_insert_mode: FxHashMap)>, +} + +fn optimize + TryFrom + std::fmt::Debug, const N: usize>( + partial: PartialModeInsertion, +) -> ModeInsertions { + let mut problem = Problem::new(OptimizationDirection::Maximize); + let mut kernel_modes = FxHashMap::default(); + let basic_block_variables = partial + .bb_maybe_insert_mode + .into_iter() + .map(|(basic_block, (value, entry_points))| { + let modes = entry_points + .iter() + .map(|entry_point| { + let kernel_modes = kernel_modes + .entry(*entry_point) + .or_insert_with(|| one_of::(&mut problem)); + kernel_modes[value.into()] + }) + .collect::>(); + let bb = and(&mut problem, &*modes); + (basic_block, bb) + }) + .collect::>(); + // TODO: add fallback on Error + let solution = problem.solve().unwrap(); + let mut basic_blocks = partial.bb_must_insert_mode; + for (basic_block, variable) in basic_block_variables { + if solution[variable] < 0.5 { + basic_blocks.insert(basic_block); + } + } + let mut kernels = FxHashMap::default(); + for (kernel, modes) in kernel_modes { + for (mode, var) in modes.into_iter().enumerate() { + if solution[var] > 0.5 { + kernels.insert(kernel, T::try_from(mode).unwrap_or_else(|_| todo!())); + } + } + } + ModeInsertions { + basic_blocks, + kernels, + } +} + +fn and(problem: &mut Problem, variables: &[Variable]) -> Variable { + let result = problem.add_binary_var(1.0); + for var in variables { + problem.add_constraint( + &[(result, 1.0), (*var, -1.0)], + microlp::ComparisonOp::Le, + 0.0, + ); + } + problem.add_constraint( + iter::once((result, 1.0)).chain(variables.iter().map(|var| (*var, -1.0))), + microlp::ComparisonOp::Ge, + -((variables.len() - 1) as f64), + ); + result +} + +fn one_of(problem: &mut Problem) -> [Variable; N] { + let result = std::array::from_fn(|_| problem.add_binary_var(0.0)); + problem.add_constraint( + result.into_iter().map(|var| (var, 1.0)), + microlp::ComparisonOp::Eq, + 1.0, + ); + result +} + +struct ModeInsertions { + basic_blocks: FxHashSet, + kernels: FxHashMap, } #[derive(Eq, PartialEq, Clone, Copy)] @@ -176,27 +283,105 @@ impl UniqueVec { #[cfg(test)] mod tests { use super::*; + use int_enum::IntEnum; + + #[repr(usize)] + #[derive(IntEnum, Eq, PartialEq, Copy, Clone, Debug)] + enum Bool { + False = 0, + True = 1, + } + + #[test] + fn transitive_mixed() { + let mut graph = ControlFlowGraph::::new(); + let entry_id = SpirvWord(1); + let false_id = SpirvWord(2); + let empty_id = SpirvWord(3); + let false2_id = SpirvWord(4); + let entry = graph.add_entry_basic_block(entry_id); + graph.add_jump(entry, false_id); + let false_ = graph.get_or_add_basic_block(false_id); + graph.set_modes(false_, Bool::False, Bool::False); + graph.add_jump(false_, empty_id); + let empty = graph.get_or_add_basic_block(empty_id); + graph.add_jump(empty, false2_id); + let false2_ = graph.get_or_add_basic_block(false2_id); + graph.set_modes(false2_, Bool::False, Bool::False); + let partial_result = super::compute(graph); + assert_eq!(partial_result.bb_must_insert_mode.len(), 0); + assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1); + assert_eq!( + partial_result.bb_maybe_insert_mode[&false_id], + (Bool::False, iter::once(entry_id).collect()) + ); + + let result = optimize::(partial_result); + assert_eq!(result.basic_blocks.len(), 0); + assert_eq!(result.kernels.len(), 1); + assert_eq!(result.kernels[&entry_id], Bool::False); + } + + #[test] + fn transitive_change_twice() { + let mut graph = ControlFlowGraph::::new(); + let entry_id = SpirvWord(1); + let false_id = SpirvWord(2); + let empty_id = SpirvWord(3); + let true_id = SpirvWord(4); + let entry = graph.add_entry_basic_block(entry_id); + graph.add_jump(entry, false_id); + let false_ = graph.get_or_add_basic_block(false_id); + graph.set_modes(false_, Bool::False, Bool::False); + graph.add_jump(false_, empty_id); + let empty = graph.get_or_add_basic_block(empty_id); + graph.add_jump(empty, true_id); + let true_ = graph.get_or_add_basic_block(true_id); + graph.set_modes(true_, Bool::True, Bool::True); + let partial_result = super::compute(graph); + assert_eq!(partial_result.bb_must_insert_mode.len(), 1); + assert!(partial_result.bb_must_insert_mode.contains(&true_id)); + assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1); + assert_eq!( + partial_result.bb_maybe_insert_mode[&false_id], + (Bool::False, iter::once(entry_id).collect()) + ); + + let result = optimize::(partial_result); + assert_eq!(result.basic_blocks, iter::once(true_id).collect()); + assert_eq!(result.kernels.len(), 1); + assert_eq!(result.kernels[&entry_id], Bool::False); + } #[test] fn transitive_change() { - let mut graph = ControlFlowGraph::::new(); + let mut graph = ControlFlowGraph::::new(); let entry_id = SpirvWord(1); let empty_id = SpirvWord(2); - let false_id = SpirvWord(3); + let true_id = SpirvWord(3); let entry = graph.add_entry_basic_block(entry_id); graph.add_jump(entry, empty_id); let empty = graph.get_or_add_basic_block(empty_id); - graph.add_jump(empty, false_id); - let false_ = graph.get_or_add_basic_block(false_id); - graph.set_modes(false_, false, false); - let result = super::compute(graph); - assert_eq!(result.len(), 1); - assert!(result.contains(&false_id)); + graph.add_jump(empty, true_id); + let true_ = graph.get_or_add_basic_block(true_id); + graph.set_modes(true_, Bool::True, Bool::True); + let partial_result = super::compute(graph); + assert_eq!(partial_result.bb_must_insert_mode.len(), 0); + assert_eq!(partial_result.bb_maybe_insert_mode.len(), 1); + assert_eq!( + partial_result.bb_maybe_insert_mode[&true_id], + (Bool::True, iter::once(entry_id).collect()) + ); + + let result = optimize::(partial_result); + assert_eq!(result.basic_blocks.len(), 0); + assert_eq!(result.kernels.len(), 1); + assert_eq!(result.kernels[&entry_id], Bool::True); } #[test] fn codependency() { - let mut graph = ControlFlowGraph::::new(); + let mut graph = ControlFlowGraph::::new(); let entry_id = SpirvWord(1); let left_f_id = SpirvWord(2); let right_f_id = SpirvWord(3); @@ -207,9 +392,9 @@ mod tests { graph.add_jump(entry, left_f_id); graph.add_jump(entry, right_f_id); let left_f = graph.get_or_add_basic_block(left_f_id); - graph.set_modes(left_f, false, false); + graph.set_modes(left_f, Bool::False, Bool::False); let right_f = graph.get_or_add_basic_block(right_f_id); - graph.set_modes(right_f, false, false); + graph.set_modes(right_f, Bool::False, Bool::False); graph.add_jump(left_f, left_none_id); let left_none = graph.get_or_add_basic_block(left_none_id); graph.add_jump(right_f, right_none_id); @@ -223,9 +408,21 @@ mod tests { // "{:?}", // petgraph::dot::Dot::with_config(&graph.graph, &[petgraph::dot::Config::EdgeNoLabel]) //); - let result = super::compute(graph); - assert_eq!(result.len(), 2); - assert!(result.contains(&left_f_id)); - assert!(result.contains(&right_f_id)); + let partial_result = super::compute(graph); + assert_eq!(partial_result.bb_must_insert_mode.len(), 0); + assert_eq!(partial_result.bb_maybe_insert_mode.len(), 2); + assert_eq!( + partial_result.bb_maybe_insert_mode[&left_f_id], + (Bool::False, iter::once(entry_id).collect()) + ); + assert_eq!( + partial_result.bb_maybe_insert_mode[&right_f_id], + (Bool::False, iter::once(entry_id).collect()) + ); + + let result = optimize::(partial_result); + assert_eq!(result.basic_blocks.len(), 0); + assert_eq!(result.kernels.len(), 1); + assert_eq!(result.kernels[&entry_id], Bool::False); } }