Write down key algorithm to track mode setting insertion in ftz pass

This commit is contained in:
Andrzej Janik 2025-02-03 14:45:00 +00:00
commit ac55b3beeb
4 changed files with 276 additions and 3 deletions

39
Cargo.lock generated
View file

@ -322,6 +322,12 @@ version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]]
name = "equivalent"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "errno"
version = "0.3.9"
@ -338,6 +344,12 @@ version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
[[package]]
name = "fixedbitset"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99"
[[package]]
name = "fnv"
version = "1.0.7"
@ -367,6 +379,12 @@ version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
[[package]]
name = "hashbrown"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
[[package]]
name = "heck"
version = "0.5.0"
@ -377,6 +395,16 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
name = "hip_runtime-sys"
version = "0.0.0"
[[package]]
name = "indexmap"
version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652"
dependencies = [
"equivalent",
"hashbrown",
]
[[package]]
name = "itertools"
version = "0.13.0"
@ -561,6 +589,16 @@ version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "petgraph"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772"
dependencies = [
"fixedbitset",
"indexmap",
]
[[package]]
name = "plain"
version = "0.2.3"
@ -643,6 +681,7 @@ dependencies = [
"hip_runtime-sys",
"llvm_zluda",
"paste",
"petgraph",
"ptx_parser",
"quick-error",
"rustc-hash 2.0.0",

View file

@ -17,6 +17,7 @@ bitflags = "1.2"
rustc-hash = "2.0.0"
strum = "0.26"
strum_macros = "0.26"
petgraph = "0.7.1"
[dev-dependencies]
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }

View file

@ -0,0 +1,231 @@
use std::hash::Hash;
use super::BrachCondition;
use super::Directive2;
use super::Function2;
use super::SpirvWord;
use super::Statement;
use super::TranslateError;
use petgraph::graph::NodeIndex;
use petgraph::visit::IntoNodeReferences;
use petgraph::Direction;
use petgraph::Graph;
use ptx_parser as ast;
use rustc_hash::FxHashMap;
use rustc_hash::FxHashSet;
struct ControlFlowGraph<T: Eq + PartialEq> {
entry_points: FxHashMap<SpirvWord, NodeIndex>,
basic_blocks: FxHashMap<SpirvWord, NodeIndex>,
graph: Graph<Node<ExtendedMode<T>>, ()>,
}
impl<T: Eq + PartialEq> ControlFlowGraph<T> {
fn new() -> Self {
Self {
entry_points: FxHashMap::default(),
basic_blocks: FxHashMap::default(),
graph: Graph::new(),
}
}
fn add_entry_basic_block(&mut self, label: SpirvWord) -> NodeIndex {
let idx = self.graph.add_node(Node {
label,
entry: Some(ExtendedMode::Entry(label)),
exit: Some(ExtendedMode::Entry(label)),
});
assert_eq!(self.entry_points.insert(label, idx), None);
idx
}
fn get_or_add_basic_block(&mut self, label: SpirvWord) -> NodeIndex {
self.basic_blocks.get(&label).copied().unwrap_or_else(|| {
let idx = self.graph.add_node(Node {
label,
entry: None,
exit: None,
});
self.basic_blocks.insert(label, idx);
idx
})
}
fn add_jump(&mut self, from: NodeIndex, to: SpirvWord) {
let to = self.get_or_add_basic_block(to);
self.graph.add_edge(from, to, ());
}
fn set_modes(&mut self, node: NodeIndex, entry: T, exit: T) {
self.graph[node].entry = Some(ExtendedMode::BasicBlock(entry));
self.graph[node].exit = Some(ExtendedMode::BasicBlock(exit));
}
}
#[derive(Debug)]
struct Node<T> {
label: SpirvWord,
entry: Option<T>,
exit: Option<T>,
}
pub(crate) fn run<'input>(
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
directives: Vec<super::Directive2<'input, ast::Instruction<SpirvWord>, super::SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut cfg = ControlFlowGraph::<bool>::new();
let mut node_idx_to_name = FxHashMap::<NodeIndex<u32>, SpirvWord>::default();
for directive in directives.iter() {
match directive {
super::Directive2::Method(Function2 {
func_decl: ast::MethodDeclaration { name, .. },
body,
..
}) => {
for statement in body.iter() {
todo!()
}
}
_ => continue,
}
}
todo!()
}
fn compute<T: Copy + Eq>(g: ControlFlowGraph<T>) -> FxHashSet<SpirvWord> {
let mut must_insert_mode = FxHashSet::<SpirvWord>::default();
let mut remaining = g
.graph
.node_references()
.rev()
.filter_map(|(index, node)| node.entry.as_ref().map(|mode| (index, node.label, *mode)))
.collect::<Vec<_>>();
'next_basic_block: while let Some((index, node_id, expected_mode)) = remaining.pop() {
let mut to_visit = UniqueVec::new(g.graph.neighbors_directed(index, Direction::Incoming));
let mut visited = FxHashSet::default();
while let Some(current) = to_visit.pop() {
if visited.contains(&current) {
continue;
}
visited.insert(current);
let exit_mode = g.graph.node_weight(current).unwrap().exit;
match exit_mode {
None => {
for predecessor in g.graph.neighbors_directed(current, Direction::Incoming) {
if !visited.contains(&predecessor) {
to_visit.push(predecessor);
}
}
}
Some(mode) => {
if mode != expected_mode {
must_insert_mode.insert(node_id);
continue 'next_basic_block;
}
}
}
}
}
must_insert_mode
}
#[derive(Eq, PartialEq, Clone, Copy)]
enum ExtendedMode<T: Eq + PartialEq> {
BasicBlock(T),
Entry(SpirvWord),
}
struct UniqueVec<T: Copy + Eq + Hash> {
set: FxHashSet<T>,
vec: Vec<T>,
}
impl<T: Copy + Eq + Hash> UniqueVec<T> {
fn new(iter: impl Iterator<Item = T>) -> Self {
let mut set = FxHashSet::default();
let mut vec = Vec::new();
for item in iter {
if set.contains(&item) {
continue;
}
set.insert(item);
vec.push(item);
}
Self { set, vec }
}
fn pop(&mut self) -> Option<T> {
if let Some(t) = self.vec.pop() {
assert!(self.set.remove(&t));
Some(t)
} else {
None
}
}
fn push(&mut self, t: T) -> bool {
if self.set.insert(t) {
self.vec.push(t);
true
} else {
false
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn transitive_change() {
let mut graph = ControlFlowGraph::<bool>::new();
let entry_id = SpirvWord(1);
let empty_id = SpirvWord(2);
let false_id = SpirvWord(3);
let entry = graph.add_entry_basic_block(entry_id);
graph.add_jump(entry, empty_id);
let empty = graph.get_or_add_basic_block(empty_id);
graph.add_jump(empty, false_id);
let false_ = graph.get_or_add_basic_block(false_id);
graph.set_modes(false_, false, false);
let result = super::compute(graph);
assert_eq!(result.len(), 1);
assert!(result.contains(&false_id));
}
#[test]
fn codependency() {
let mut graph = ControlFlowGraph::<bool>::new();
let entry_id = SpirvWord(1);
let left_f_id = SpirvWord(2);
let right_f_id = SpirvWord(3);
let left_none_id = SpirvWord(4);
let mid_none_id = SpirvWord(5);
let right_none_id = SpirvWord(6);
let entry = graph.add_entry_basic_block(entry_id);
graph.add_jump(entry, left_f_id);
graph.add_jump(entry, right_f_id);
let left_f = graph.get_or_add_basic_block(left_f_id);
graph.set_modes(left_f, false, false);
let right_f = graph.get_or_add_basic_block(right_f_id);
graph.set_modes(right_f, false, false);
graph.add_jump(left_f, left_none_id);
let left_none = graph.get_or_add_basic_block(left_none_id);
graph.add_jump(right_f, right_none_id);
let right_none = graph.get_or_add_basic_block(right_none_id);
graph.add_jump(left_none, mid_none_id);
graph.add_jump(right_none, mid_none_id);
let mid_none = graph.get_or_add_basic_block(mid_none_id);
graph.add_jump(mid_none, left_none_id);
graph.add_jump(mid_none, right_none_id);
//println!(
// "{:?}",
// petgraph::dot::Dot::with_config(&graph.graph, &[petgraph::dot::Config::EdgeNoLabel])
//);
let result = super::compute(graph);
assert_eq!(result.len(), 2);
assert!(result.contains(&left_f_id));
assert!(result.contains(&right_f_id));
}
}

View file

@ -17,12 +17,13 @@ mod expand_operands;
mod fix_special_registers2;
mod hoist_globals;
mod insert_explicit_load_store;
mod insert_ftz_control;
mod insert_implicit_conversions2;
mod normalize_identifiers2;
mod normalize_predicates2;
mod replace_instructions_with_function_calls;
mod resolve_function_pointers;
mod replace_known_functions;
mod resolve_function_pointers;
static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_";
@ -46,11 +47,12 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
let directives = replace_known_functions::run(&flat_resolver, directives);
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
let directives = resolve_function_pointers::run(directives)?;
let directives: Vec<Directive2<'_, ptx_parser::Instruction<ptx_parser::ParsedOperand<SpirvWord>>, ptx_parser::ParsedOperand<SpirvWord>>> = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
let directives = expand_operands::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
let directives = hoist_globals::run(directives)?;
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
@ -525,7 +527,7 @@ struct FunctionPointerDetails {
src: SpirvWord,
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
pub struct SpirvWord(u32);
impl From<u32> for SpirvWord {