diff --git a/ptx/src/pass/insert_ftz_control.rs b/ptx/src/pass/insert_ftz_control.rs index f21a804..25eca97 100644 --- a/ptx/src/pass/insert_ftz_control.rs +++ b/ptx/src/pass/insert_ftz_control.rs @@ -349,21 +349,9 @@ pub(crate) fn run<'input>( is_kernel, .. }) => { - // TODO: implement for non-kernels - if !*is_kernel { - todo!() - } - let entry_index = cfg.add_entry_basic_block(*name); - let mut bb_state = BasicBlockState::new(&mut cfg); - let mut body_iter = body.iter(); - match body_iter.next() { - Some(Statement::Label(label)) => { - bb_state.cfg.add_jump(entry_index, *label); - bb_state.start(*label); - } - _ => return Err(error_unreachable()), - }; - for statement in body.iter() { + let (mut bb_state, body_iter) = + BasicBlockState::new(&mut cfg, *name, body, *is_kernel)?; + for statement in body_iter { match statement { Statement::Instruction(ast::Instruction::Bra { arguments }) => { bb_state.end(&[arguments.src]); @@ -712,13 +700,39 @@ struct BasicBlockState<'a> { } impl<'a> BasicBlockState<'a> { - fn new(cfg: &'a mut ControlFlowGraph) -> BasicBlockState<'a> { - Self { + #[must_use] + fn new<'x>( + cfg: &'a mut ControlFlowGraph, + fn_name: SpirvWord, + body: &'x Vec, SpirvWord>>, + is_kernel: bool, + ) -> Result< + ( + BasicBlockState<'a>, + impl Iterator, SpirvWord>>, + ), + TranslateError, + > { + let entry_index = if is_kernel { + cfg.add_entry_basic_block(fn_name) + } else { + cfg.get_or_add_basic_block(fn_name) + }; + let mut body_iter = body.iter(); + let mut bb_state = Self { cfg, node_index: None, entry: InstructionModes::none(), exit: InstructionModes::none(), - } + }; + match body_iter.next() { + Some(Statement::Label(label)) => { + bb_state.cfg.add_jump(entry_index, *label); + bb_state.start(*label); + } + _ => return Err(error_unreachable()), + }; + Ok((bb_state, body_iter)) } fn start(&mut self, label: SpirvWord) {