From ac55b3beeb30d79f98cde97de67b0bab4e189164 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 3 Feb 2025 14:45:00 +0000 Subject: [PATCH] Write down key algorithm to track mode setting insertion in ftz pass --- Cargo.lock | 39 +++++ ptx/Cargo.toml | 1 + ptx/src/pass/insert_ftz_control.rs | 231 +++++++++++++++++++++++++++++ ptx/src/pass/mod.rs | 8 +- 4 files changed, 276 insertions(+), 3 deletions(-) create mode 100644 ptx/src/pass/insert_ftz_control.rs diff --git a/Cargo.lock b/Cargo.lock index bc0d08a..6e060e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -322,6 +322,12 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.9" @@ -338,6 +344,12 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "fnv" version = "1.0.7" @@ -367,6 +379,12 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" +[[package]] +name = "hashbrown" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" + [[package]] name = "heck" version = "0.5.0" @@ -377,6 +395,16 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" name = "hip_runtime-sys" version = "0.0.0" +[[package]] +name = "indexmap" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "itertools" version = "0.13.0" @@ -561,6 +589,16 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "petgraph" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +dependencies = [ + "fixedbitset", + "indexmap", +] + [[package]] name = "plain" version = "0.2.3" @@ -643,6 +681,7 @@ dependencies = [ "hip_runtime-sys", "llvm_zluda", "paste", + "petgraph", "ptx_parser", "quick-error", "rustc-hash 2.0.0", diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 9f3fa02..aac1e6b 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -17,6 +17,7 @@ bitflags = "1.2" rustc-hash = "2.0.0" strum = "0.26" strum_macros = "0.26" +petgraph = "0.7.1" [dev-dependencies] hip_runtime-sys = { path = "../ext/hip_runtime-sys" } diff --git a/ptx/src/pass/insert_ftz_control.rs b/ptx/src/pass/insert_ftz_control.rs new file mode 100644 index 0000000..6c69f8d --- /dev/null +++ b/ptx/src/pass/insert_ftz_control.rs @@ -0,0 +1,231 @@ +use std::hash::Hash; + +use super::BrachCondition; +use super::Directive2; +use super::Function2; +use super::SpirvWord; +use super::Statement; +use super::TranslateError; +use petgraph::graph::NodeIndex; +use petgraph::visit::IntoNodeReferences; +use petgraph::Direction; +use petgraph::Graph; +use ptx_parser as ast; +use rustc_hash::FxHashMap; +use rustc_hash::FxHashSet; + +struct ControlFlowGraph { + entry_points: FxHashMap, + basic_blocks: FxHashMap, + graph: Graph>, ()>, +} + +impl ControlFlowGraph { + fn new() -> Self { + Self { + entry_points: FxHashMap::default(), + basic_blocks: FxHashMap::default(), + graph: Graph::new(), + } + } + + fn add_entry_basic_block(&mut self, label: SpirvWord) -> NodeIndex { + let idx = self.graph.add_node(Node { + label, + entry: Some(ExtendedMode::Entry(label)), + exit: Some(ExtendedMode::Entry(label)), + }); + assert_eq!(self.entry_points.insert(label, idx), None); + idx + } + + fn get_or_add_basic_block(&mut self, label: SpirvWord) -> NodeIndex { + self.basic_blocks.get(&label).copied().unwrap_or_else(|| { + let idx = self.graph.add_node(Node { + label, + entry: None, + exit: None, + }); + self.basic_blocks.insert(label, idx); + idx + }) + } + + fn add_jump(&mut self, from: NodeIndex, to: SpirvWord) { + let to = self.get_or_add_basic_block(to); + self.graph.add_edge(from, to, ()); + } + + fn set_modes(&mut self, node: NodeIndex, entry: T, exit: T) { + self.graph[node].entry = Some(ExtendedMode::BasicBlock(entry)); + self.graph[node].exit = Some(ExtendedMode::BasicBlock(exit)); + } +} + +#[derive(Debug)] +struct Node { + label: SpirvWord, + entry: Option, + exit: Option, +} + +pub(crate) fn run<'input>( + flat_resolver: &mut super::GlobalStringIdentResolver2<'input>, + directives: Vec, super::SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + let mut cfg = ControlFlowGraph::::new(); + let mut node_idx_to_name = FxHashMap::, SpirvWord>::default(); + for directive in directives.iter() { + match directive { + super::Directive2::Method(Function2 { + func_decl: ast::MethodDeclaration { name, .. }, + body, + .. + }) => { + for statement in body.iter() { + todo!() + } + } + _ => continue, + } + } + todo!() +} + +fn compute(g: ControlFlowGraph) -> FxHashSet { + let mut must_insert_mode = FxHashSet::::default(); + let mut remaining = g + .graph + .node_references() + .rev() + .filter_map(|(index, node)| node.entry.as_ref().map(|mode| (index, node.label, *mode))) + .collect::>(); + 'next_basic_block: while let Some((index, node_id, expected_mode)) = remaining.pop() { + let mut to_visit = UniqueVec::new(g.graph.neighbors_directed(index, Direction::Incoming)); + let mut visited = FxHashSet::default(); + while let Some(current) = to_visit.pop() { + if visited.contains(¤t) { + continue; + } + visited.insert(current); + let exit_mode = g.graph.node_weight(current).unwrap().exit; + match exit_mode { + None => { + for predecessor in g.graph.neighbors_directed(current, Direction::Incoming) { + if !visited.contains(&predecessor) { + to_visit.push(predecessor); + } + } + } + Some(mode) => { + if mode != expected_mode { + must_insert_mode.insert(node_id); + continue 'next_basic_block; + } + } + } + } + } + must_insert_mode +} + +#[derive(Eq, PartialEq, Clone, Copy)] +enum ExtendedMode { + BasicBlock(T), + Entry(SpirvWord), +} + +struct UniqueVec { + set: FxHashSet, + vec: Vec, +} + +impl UniqueVec { + fn new(iter: impl Iterator) -> Self { + let mut set = FxHashSet::default(); + let mut vec = Vec::new(); + for item in iter { + if set.contains(&item) { + continue; + } + set.insert(item); + vec.push(item); + } + Self { set, vec } + } + + fn pop(&mut self) -> Option { + if let Some(t) = self.vec.pop() { + assert!(self.set.remove(&t)); + Some(t) + } else { + None + } + } + + fn push(&mut self, t: T) -> bool { + if self.set.insert(t) { + self.vec.push(t); + true + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn transitive_change() { + let mut graph = ControlFlowGraph::::new(); + let entry_id = SpirvWord(1); + let empty_id = SpirvWord(2); + let false_id = SpirvWord(3); + let entry = graph.add_entry_basic_block(entry_id); + graph.add_jump(entry, empty_id); + let empty = graph.get_or_add_basic_block(empty_id); + graph.add_jump(empty, false_id); + let false_ = graph.get_or_add_basic_block(false_id); + graph.set_modes(false_, false, false); + let result = super::compute(graph); + assert_eq!(result.len(), 1); + assert!(result.contains(&false_id)); + } + + #[test] + fn codependency() { + let mut graph = ControlFlowGraph::::new(); + let entry_id = SpirvWord(1); + let left_f_id = SpirvWord(2); + let right_f_id = SpirvWord(3); + let left_none_id = SpirvWord(4); + let mid_none_id = SpirvWord(5); + let right_none_id = SpirvWord(6); + let entry = graph.add_entry_basic_block(entry_id); + graph.add_jump(entry, left_f_id); + graph.add_jump(entry, right_f_id); + let left_f = graph.get_or_add_basic_block(left_f_id); + graph.set_modes(left_f, false, false); + let right_f = graph.get_or_add_basic_block(right_f_id); + graph.set_modes(right_f, false, false); + graph.add_jump(left_f, left_none_id); + let left_none = graph.get_or_add_basic_block(left_none_id); + graph.add_jump(right_f, right_none_id); + let right_none = graph.get_or_add_basic_block(right_none_id); + graph.add_jump(left_none, mid_none_id); + graph.add_jump(right_none, mid_none_id); + let mid_none = graph.get_or_add_basic_block(mid_none_id); + graph.add_jump(mid_none, left_none_id); + graph.add_jump(mid_none, right_none_id); + //println!( + // "{:?}", + // petgraph::dot::Dot::with_config(&graph.graph, &[petgraph::dot::Config::EdgeNoLabel]) + //); + let result = super::compute(graph); + assert_eq!(result.len(), 2); + assert!(result.contains(&left_f_id)); + assert!(result.contains(&right_f_id)); + } +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index c32cc39..8cc9926 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -17,12 +17,13 @@ mod expand_operands; mod fix_special_registers2; mod hoist_globals; mod insert_explicit_load_store; +mod insert_ftz_control; mod insert_implicit_conversions2; mod normalize_identifiers2; mod normalize_predicates2; mod replace_instructions_with_function_calls; -mod resolve_function_pointers; mod replace_known_functions; +mod resolve_function_pointers; static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_"; @@ -46,11 +47,12 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result>, ptx_parser::ParsedOperand>> = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?; + let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?; let directives = expand_operands::run(&mut flat_resolver, directives)?; let directives = deparamize_functions::run(&mut flat_resolver, directives)?; let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?; + let directives = insert_ftz_control::run(&mut flat_resolver, directives)?; let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?; let directives = hoist_globals::run(directives)?; let llvm_ir = emit_llvm::run(flat_resolver, directives)?; @@ -525,7 +527,7 @@ struct FunctionPointerDetails { src: SpirvWord, } -#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)] pub struct SpirvWord(u32); impl From for SpirvWord {