mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-05 15:49:24 +00:00
Write down key algorithm to track mode setting insertion in ftz pass
This commit is contained in:
parent
a6e6454d8b
commit
ac55b3beeb
4 changed files with 276 additions and 3 deletions
39
Cargo.lock
generated
39
Cargo.lock
generated
|
@ -322,6 +322,12 @@ version = "1.13.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.9"
|
||||
|
@ -338,6 +344,12 @@ version = "2.1.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
|
||||
|
||||
[[package]]
|
||||
name = "fixedbitset"
|
||||
version = "0.5.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99"
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
|
@ -367,6 +379,12 @@ version = "1.8.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.5.0"
|
||||
|
@ -377,6 +395,16 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
|||
name = "hip_runtime-sys"
|
||||
version = "0.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.13.0"
|
||||
|
@ -561,6 +589,16 @@ version = "1.0.15"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||
|
||||
[[package]]
|
||||
name = "petgraph"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772"
|
||||
dependencies = [
|
||||
"fixedbitset",
|
||||
"indexmap",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "plain"
|
||||
version = "0.2.3"
|
||||
|
@ -643,6 +681,7 @@ dependencies = [
|
|||
"hip_runtime-sys",
|
||||
"llvm_zluda",
|
||||
"paste",
|
||||
"petgraph",
|
||||
"ptx_parser",
|
||||
"quick-error",
|
||||
"rustc-hash 2.0.0",
|
||||
|
|
|
@ -17,6 +17,7 @@ bitflags = "1.2"
|
|||
rustc-hash = "2.0.0"
|
||||
strum = "0.26"
|
||||
strum_macros = "0.26"
|
||||
petgraph = "0.7.1"
|
||||
|
||||
[dev-dependencies]
|
||||
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
||||
|
|
231
ptx/src/pass/insert_ftz_control.rs
Normal file
231
ptx/src/pass/insert_ftz_control.rs
Normal file
|
@ -0,0 +1,231 @@
|
|||
use std::hash::Hash;
|
||||
|
||||
use super::BrachCondition;
|
||||
use super::Directive2;
|
||||
use super::Function2;
|
||||
use super::SpirvWord;
|
||||
use super::Statement;
|
||||
use super::TranslateError;
|
||||
use petgraph::graph::NodeIndex;
|
||||
use petgraph::visit::IntoNodeReferences;
|
||||
use petgraph::Direction;
|
||||
use petgraph::Graph;
|
||||
use ptx_parser as ast;
|
||||
use rustc_hash::FxHashMap;
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
struct ControlFlowGraph<T: Eq + PartialEq> {
|
||||
entry_points: FxHashMap<SpirvWord, NodeIndex>,
|
||||
basic_blocks: FxHashMap<SpirvWord, NodeIndex>,
|
||||
graph: Graph<Node<ExtendedMode<T>>, ()>,
|
||||
}
|
||||
|
||||
impl<T: Eq + PartialEq> ControlFlowGraph<T> {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
entry_points: FxHashMap::default(),
|
||||
basic_blocks: FxHashMap::default(),
|
||||
graph: Graph::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_entry_basic_block(&mut self, label: SpirvWord) -> NodeIndex {
|
||||
let idx = self.graph.add_node(Node {
|
||||
label,
|
||||
entry: Some(ExtendedMode::Entry(label)),
|
||||
exit: Some(ExtendedMode::Entry(label)),
|
||||
});
|
||||
assert_eq!(self.entry_points.insert(label, idx), None);
|
||||
idx
|
||||
}
|
||||
|
||||
fn get_or_add_basic_block(&mut self, label: SpirvWord) -> NodeIndex {
|
||||
self.basic_blocks.get(&label).copied().unwrap_or_else(|| {
|
||||
let idx = self.graph.add_node(Node {
|
||||
label,
|
||||
entry: None,
|
||||
exit: None,
|
||||
});
|
||||
self.basic_blocks.insert(label, idx);
|
||||
idx
|
||||
})
|
||||
}
|
||||
|
||||
fn add_jump(&mut self, from: NodeIndex, to: SpirvWord) {
|
||||
let to = self.get_or_add_basic_block(to);
|
||||
self.graph.add_edge(from, to, ());
|
||||
}
|
||||
|
||||
fn set_modes(&mut self, node: NodeIndex, entry: T, exit: T) {
|
||||
self.graph[node].entry = Some(ExtendedMode::BasicBlock(entry));
|
||||
self.graph[node].exit = Some(ExtendedMode::BasicBlock(exit));
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Node<T> {
|
||||
label: SpirvWord,
|
||||
entry: Option<T>,
|
||||
exit: Option<T>,
|
||||
}
|
||||
|
||||
pub(crate) fn run<'input>(
|
||||
flat_resolver: &mut super::GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<super::Directive2<'input, ast::Instruction<SpirvWord>, super::SpirvWord>>,
|
||||
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let mut cfg = ControlFlowGraph::<bool>::new();
|
||||
let mut node_idx_to_name = FxHashMap::<NodeIndex<u32>, SpirvWord>::default();
|
||||
for directive in directives.iter() {
|
||||
match directive {
|
||||
super::Directive2::Method(Function2 {
|
||||
func_decl: ast::MethodDeclaration { name, .. },
|
||||
body,
|
||||
..
|
||||
}) => {
|
||||
for statement in body.iter() {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn compute<T: Copy + Eq>(g: ControlFlowGraph<T>) -> FxHashSet<SpirvWord> {
|
||||
let mut must_insert_mode = FxHashSet::<SpirvWord>::default();
|
||||
let mut remaining = g
|
||||
.graph
|
||||
.node_references()
|
||||
.rev()
|
||||
.filter_map(|(index, node)| node.entry.as_ref().map(|mode| (index, node.label, *mode)))
|
||||
.collect::<Vec<_>>();
|
||||
'next_basic_block: while let Some((index, node_id, expected_mode)) = remaining.pop() {
|
||||
let mut to_visit = UniqueVec::new(g.graph.neighbors_directed(index, Direction::Incoming));
|
||||
let mut visited = FxHashSet::default();
|
||||
while let Some(current) = to_visit.pop() {
|
||||
if visited.contains(¤t) {
|
||||
continue;
|
||||
}
|
||||
visited.insert(current);
|
||||
let exit_mode = g.graph.node_weight(current).unwrap().exit;
|
||||
match exit_mode {
|
||||
None => {
|
||||
for predecessor in g.graph.neighbors_directed(current, Direction::Incoming) {
|
||||
if !visited.contains(&predecessor) {
|
||||
to_visit.push(predecessor);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(mode) => {
|
||||
if mode != expected_mode {
|
||||
must_insert_mode.insert(node_id);
|
||||
continue 'next_basic_block;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
must_insert_mode
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq, Clone, Copy)]
|
||||
enum ExtendedMode<T: Eq + PartialEq> {
|
||||
BasicBlock(T),
|
||||
Entry(SpirvWord),
|
||||
}
|
||||
|
||||
struct UniqueVec<T: Copy + Eq + Hash> {
|
||||
set: FxHashSet<T>,
|
||||
vec: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T: Copy + Eq + Hash> UniqueVec<T> {
|
||||
fn new(iter: impl Iterator<Item = T>) -> Self {
|
||||
let mut set = FxHashSet::default();
|
||||
let mut vec = Vec::new();
|
||||
for item in iter {
|
||||
if set.contains(&item) {
|
||||
continue;
|
||||
}
|
||||
set.insert(item);
|
||||
vec.push(item);
|
||||
}
|
||||
Self { set, vec }
|
||||
}
|
||||
|
||||
fn pop(&mut self) -> Option<T> {
|
||||
if let Some(t) = self.vec.pop() {
|
||||
assert!(self.set.remove(&t));
|
||||
Some(t)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn push(&mut self, t: T) -> bool {
|
||||
if self.set.insert(t) {
|
||||
self.vec.push(t);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Chain `entry -> empty -> false`: the middle block records no mode, so
    /// the block requiring `false` sees the entry point's mode through it and
    /// must be reported.
    #[test]
    fn transitive_change() {
        let (entry_id, empty_id, false_id) = (SpirvWord(1), SpirvWord(2), SpirvWord(3));
        let mut cfg = ControlFlowGraph::<bool>::new();

        let entry_node = cfg.add_entry_basic_block(entry_id);
        cfg.add_jump(entry_node, empty_id);
        let empty_node = cfg.get_or_add_basic_block(empty_id);
        cfg.add_jump(empty_node, false_id);
        let false_node = cfg.get_or_add_basic_block(false_id);
        cfg.set_modes(false_node, false, false);

        let needs_mode = compute(cfg);
        assert_eq!(needs_mode.len(), 1);
        assert!(needs_mode.contains(&false_id));
    }

    /// Two blocks with mode `false` hang directly off the entry point and
    /// feed a cycle of three blocks without any mode. Only the two blocks
    /// with a mode may be reported, and the cycle must not hang the search.
    #[test]
    fn codependency() {
        let entry_id = SpirvWord(1);
        let left_f_id = SpirvWord(2);
        let right_f_id = SpirvWord(3);
        let left_none_id = SpirvWord(4);
        let mid_none_id = SpirvWord(5);
        let right_none_id = SpirvWord(6);
        let mut cfg = ControlFlowGraph::<bool>::new();

        let entry_node = cfg.add_entry_basic_block(entry_id);
        cfg.add_jump(entry_node, left_f_id);
        cfg.add_jump(entry_node, right_f_id);

        let left_f_node = cfg.get_or_add_basic_block(left_f_id);
        cfg.set_modes(left_f_node, false, false);
        let right_f_node = cfg.get_or_add_basic_block(right_f_id);
        cfg.set_modes(right_f_node, false, false);

        cfg.add_jump(left_f_node, left_none_id);
        let left_none_node = cfg.get_or_add_basic_block(left_none_id);
        cfg.add_jump(right_f_node, right_none_id);
        let right_none_node = cfg.get_or_add_basic_block(right_none_id);

        cfg.add_jump(left_none_node, mid_none_id);
        cfg.add_jump(right_none_node, mid_none_id);
        let mid_none_node = cfg.get_or_add_basic_block(mid_none_id);
        cfg.add_jump(mid_none_node, left_none_id);
        cfg.add_jump(mid_none_node, right_none_id);

        // Debugging aid, kept disabled: dump the graph in Graphviz format.
        //println!(
        //    "{:?}",
        //    petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
        //);

        let needs_mode = compute(cfg);
        assert_eq!(needs_mode.len(), 2);
        assert!(needs_mode.contains(&left_f_id));
        assert!(needs_mode.contains(&right_f_id));
    }
}
|
|
@ -17,12 +17,13 @@ mod expand_operands;
|
|||
mod fix_special_registers2;
|
||||
mod hoist_globals;
|
||||
mod insert_explicit_load_store;
|
||||
mod insert_ftz_control;
|
||||
mod insert_implicit_conversions2;
|
||||
mod normalize_identifiers2;
|
||||
mod normalize_predicates2;
|
||||
mod replace_instructions_with_function_calls;
|
||||
mod resolve_function_pointers;
|
||||
mod replace_known_functions;
|
||||
mod resolve_function_pointers;
|
||||
|
||||
static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||
const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_";
|
||||
|
@ -46,11 +47,12 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
|
|||
let directives = replace_known_functions::run(&flat_resolver, directives);
|
||||
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
|
||||
let directives = resolve_function_pointers::run(directives)?;
|
||||
let directives: Vec<Directive2<'_, ptx_parser::Instruction<ptx_parser::ParsedOperand<SpirvWord>>, ptx_parser::ParsedOperand<SpirvWord>>> = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
|
||||
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
|
||||
let directives = expand_operands::run(&mut flat_resolver, directives)?;
|
||||
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
|
||||
let directives = insert_ftz_control::run(&mut flat_resolver, directives)?;
|
||||
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
|
||||
let directives = hoist_globals::run(directives)?;
|
||||
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
|
||||
|
@ -525,7 +527,7 @@ struct FunctionPointerDetails {
|
|||
src: SpirvWord,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
|
||||
pub struct SpirvWord(u32);
|
||||
|
||||
impl From<u32> for SpirvWord {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue