diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 20a716e..1e138dd 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -22,6 +22,7 @@ mod insert_implicit_conversions2; mod normalize_basic_blocks; mod normalize_identifiers2; mod normalize_predicates2; +mod remove_unreachable_basic_blocks; mod replace_instructions_with_function_calls; mod replace_known_functions; mod resolve_function_pointers; @@ -52,6 +53,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + for directive in directives.iter_mut() { + match directive { + Directive2::Method(Function2 { + body: Some(body), .. + }) => { + let old_body = std::mem::replace(body, Vec::new()); + let mut cfg = ControlFlowGraph::new(); + let mut old_body_iter = old_body.iter(); + let mut current_bb = match old_body_iter.next() { + Some(Statement::Label(label)) => cfg.add_or_get_node(*label), + _ => return Err(error_unreachable()), + }; + let first_bb = current_bb; + for statement in old_body_iter { + match statement { + Statement::Label(label) => { + current_bb = cfg.add_or_get_node(*label); + } + Statement::Conditional(branch) => { + cfg.add_branch(current_bb, branch.if_true); + cfg.add_branch(current_bb, branch.if_false); + } + Statement::Instruction(ast::Instruction::Bra { + arguments: ast::BraArgs { src }, + }) => { + cfg.add_branch(current_bb, *src); + } + _ => {} + } + } + let mut bfs = Bfs::new(&cfg.graph, first_bb); + while let Some(_) = bfs.next(&cfg.graph) {} + let mut visited = true; + *body = try_filter_to_vec(old_body.into_iter(), |statement| { + match statement { + Statement::Label(label) => { + visited = bfs + .discovered + .is_visited(cfg.nodes.get(label).ok_or_else(error_unreachable)?); + } + _ => {} + } + Ok(visited) + })?; + } + _ => {} + } + } + Ok(directives) +} + +fn try_filter_to_vec( + mut iter: impl ExactSizeIterator, + mut filter: impl FnMut(&T) -> Result, +) -> Result, E> { + iter.try_fold(Vec::with_capacity(iter.len()), |mut vec, item| { + match filter(&item) { + Ok(true) => vec.push(item), + Ok(false) => {} + Err(err) => return Err(err), + } + Ok(vec) + }) +} + +struct ControlFlowGraph { + graph: Graph, + nodes: FxHashMap, +} + +impl ControlFlowGraph { + fn new() -> Self { + Self { + graph: Graph::new(), + nodes: FxHashMap::default(), + } + } + + fn add_or_get_node(&mut self, id: SpirvWord) -> NodeIndex { + *self + .nodes + .entry(id) + .or_insert_with(|| self.graph.add_node(id)) + } + + fn add_branch(&mut self, from: NodeIndex, to: SpirvWord) -> NodeIndex { + let to = self.add_or_get_node(to); + self.graph.add_edge(from, to, ()); + to + } +}