diff --git a/ptx/src/pass/insert_ftz_control/mod.rs b/ptx/src/pass/insert_ftz_control/mod.rs index afd2bcb..5aab249 100644 --- a/ptx/src/pass/insert_ftz_control/mod.rs +++ b/ptx/src/pass/insert_ftz_control/mod.rs @@ -108,6 +108,115 @@ struct InstructionModes { rounding_f16f64: Option, } +struct ResolvedInstructionModes { + denormal_f32: Resolved, + denormal_f16f64: Resolved, + rounding_f32: Resolved, + rounding_f16f64: Resolved, +} + +/* +struct ExitInstructionModes { + denormal_f32: Resolved, + denormal_f16f64: Resolved, + rounding_f32: Resolved, + rounding_f16f64: Resolved, +} + +impl ExitInstructionModes { + fn from_node( + denormal: &TwinModeInsertions, + rounding: &TwinModeInsertions, + Node { + label: ret_block_name, + denormal_f32, + denormal_f16f64, + rounding_f32, + rounding_f16f64, + }: &Node, + ) -> Result { + let denormal_entry = &denormal.basic_blocks; + let rounding_entry = &rounding.basic_blocks; + let denormal_f32 = match denormal_f32.exit { + Some(ExtendedMode::Entry(kernel)) => Resolved::Value( + denormal_entry + .get(&kernel) + .ok_or_else(error_unreachable)? + .twin_mode + .ok_or_else(error_unreachable)? + .f32 + .to_ftz(), + ), + Some(ExtendedMode::BasicBlock(value)) => Resolved::Value(value.to_ftz()), + None => denormal_entry + .get(ret_block_name) + .ok_or_else(error_unreachable)? + .twin_mode + .map(|m| m.f32.to_ftz()), + }; + /* + let denormal_f16f64 = match denormal_f16f64.exit { + None => denormal_entry + .get(ret_block_name) + .ok_or_else(error_unreachable)? + .twin_mode + .map(|m| m.f16f64.to_ftz()), + Some(ExtendedMode::Entry(kernel)) => Some( + denormal_entry + .get(&kernel) + .ok_or_else(error_unreachable)? + .twin_mode + .unwrap() + .f16f64 + .to_ftz(), + ), + Some(ExtendedMode::BasicBlock(value)) => Some(value.to_ftz()), + }; + let rounding_f32 = match rounding_f32.exit { + None => rounding_entry + .get(ret_block_name) + .ok_or_else(error_unreachable)? + .twin_mode + .map(|m| m.f32.to_ast()), + Some(ExtendedMode::Entry(kernel)) => Some( + rounding_entry + .get(&kernel) + .ok_or_else(error_unreachable)? + .twin_mode + .unwrap() + .f32 + .to_ast(), + ), + Some(ExtendedMode::BasicBlock(value)) => Some(value.to_ast()), + }; + let rounding_f16f64 = match rounding_f16f64.exit { + None => rounding_entry + .get(ret_block_name) + .ok_or_else(error_unreachable)? + .twin_mode + .map(|m| m.f16f64.to_ast()), + Some(ExtendedMode::Entry(kernel)) => Some( + rounding_entry + .get(&kernel) + .ok_or_else(error_unreachable)? + .twin_mode + .unwrap() + .f16f64 + .to_ast(), + ), + Some(ExtendedMode::BasicBlock(value)) => Some(value.to_ast()), + }; + */ + Ok(Self { + denormal_f32, + denormal_f16f64, + rounding_f32, + rounding_f16f64, + }) + } +} + */ + impl InstructionModes { fn fold_into(self, entry: &mut Self, exit: &mut Self) { fn set_if_none(source: &mut Option, value: Option) { @@ -248,7 +357,7 @@ struct ControlFlowGraph { basic_blocks: FxHashMap, // map function -> return label call_returns: FxHashMap>, - // map function -> return basic blocks + // map function -> return basic block functions_rets: FxHashMap, graph: Graph, } @@ -312,6 +421,207 @@ impl ControlFlowGraph { } } +struct ResolvedControlFlowGraph { + entry_points: FxHashMap, + basic_blocks: FxHashMap, + // map function -> return label + call_returns: FxHashMap>, + // map function -> return basic block + functions_rets: FxHashMap, + graph: Graph, +} + +impl ResolvedControlFlowGraph { + fn new( + cfg: ControlFlowGraph, + f32_denormal_kernels: &FxHashMap, + f16f64_denormal_kernels: &FxHashMap, + f32_rounding_kernels: &FxHashMap, + f16f64_rounding_kernels: &FxHashMap, + ) -> Result { + fn get_incoming_mode( + cfg: &ControlFlowGraph, + kernels: &FxHashMap, + node: NodeIndex, + mut exit_getter: impl FnMut(&Node) -> Option>, + ) -> Result, TranslateError> { + let mut mode: Option = None; + let mut visited = iter::once(node).collect::>(); + let mut to_visit = cfg + .graph + .neighbors_directed(node, Direction::Incoming) + .map(|x| x) + .collect::>(); + while let Some(node) = to_visit.pop() { + if !visited.insert(node) { + continue; + } + let node_data = &cfg.graph[node]; + match (mode, exit_getter(node_data)) { + (_, None) => { + for next in cfg.graph.neighbors_directed(node, Direction::Incoming) { + if !visited.contains(&next) { + to_visit.push(next); + } + } + } + (existing_mode, Some(new_mode)) => { + let new_mode = match new_mode { + ExtendedMode::BasicBlock(new_mode) => new_mode, + ExtendedMode::Entry(kernel) => { + kernels.get(&kernel).copied().unwrap_or_default() + } + }; + if let Some(existing_mode) = existing_mode { + if existing_mode != new_mode { + return Ok(Resolved::Conflict); + } + } + mode = Some(new_mode); + } + } + } + mode.map(Resolved::Value).ok_or_else(error_unreachable) + } + fn resolve_mode( + cfg: &ControlFlowGraph, + kernels: &FxHashMap, + node: NodeIndex, + exit_getter: impl FnMut(&Node) -> Option>, + mode: &Mode, + ) -> Result, TranslateError> { + let entry = match mode.entry { + Some(ExtendedMode::Entry(kernel)) => { + Resolved::Value(kernels.get(&kernel).copied().unwrap_or_default()) + } + Some(ExtendedMode::BasicBlock(bb)) => Resolved::Value(bb), + None => get_incoming_mode(cfg, kernels, node, exit_getter)?, + }; + let exit = match mode.entry { + Some(ExtendedMode::BasicBlock(bb)) => Resolved::Value(bb), + Some(ExtendedMode::Entry(_)) | None => entry, + }; + Ok(ResolvedMode { entry, exit }) + } + fn resolve_node_impl( + cfg: &ControlFlowGraph, + f32_denormal_kernels: &FxHashMap, + f16f64_denormal_kernels: &FxHashMap, + f32_rounding_kernels: &FxHashMap, + f16f64_rounding_kernels: &FxHashMap, + index: NodeIndex, + node: &Node, + ) -> Result { + let denormal_f32 = resolve_mode( + cfg, + f32_denormal_kernels, + index, + |node| node.denormal_f32.exit, + &node.denormal_f32, + )?; + let denormal_f16f64 = resolve_mode( + cfg, + f16f64_denormal_kernels, + index, + |node| node.denormal_f16f64.exit, + &node.denormal_f16f64, + )?; + let rounding_f32 = resolve_mode( + cfg, + f32_rounding_kernels, + index, + |node| node.rounding_f32.exit, + &node.rounding_f32, + )?; + let rounding_f16f64 = resolve_mode( + cfg, + f16f64_rounding_kernels, + index, + |node| node.rounding_f16f64.exit, + &node.rounding_f16f64, + )?; + Ok(ResolvedNode { + label: node.label, + denormal_f32, + denormal_f16f64, + rounding_f32, + rounding_f16f64, + }) + } + fn resolve_node( + cfg: &ControlFlowGraph, + f32_denormal_kernels: &FxHashMap, + f16f64_denormal_kernels: &FxHashMap, + f32_rounding_kernels: &FxHashMap, + f16f64_rounding_kernels: &FxHashMap, + index: NodeIndex, + node: &Node, + error: &mut bool, + ) -> ResolvedNode { + match resolve_node_impl( + cfg, + f32_denormal_kernels, + f16f64_denormal_kernels, + f32_rounding_kernels, + f16f64_rounding_kernels, + index, + node, + ) { + Ok(node) => node, + Err(_) => { + *error = true; + ResolvedNode { + label: SpirvWord(u32::MAX), + denormal_f32: ResolvedMode { + entry: Resolved::Conflict, + exit: Resolved::Conflict, + }, + denormal_f16f64: ResolvedMode { + entry: Resolved::Conflict, + exit: Resolved::Conflict, + }, + rounding_f32: ResolvedMode { + entry: Resolved::Conflict, + exit: Resolved::Conflict, + }, + rounding_f16f64: ResolvedMode { + entry: Resolved::Conflict, + exit: Resolved::Conflict, + }, + } + } + } + } + let mut error = false; + let graph = cfg.graph.map( + |index, node| { + resolve_node( + &cfg, + f32_denormal_kernels, + f16f64_denormal_kernels, + f32_rounding_kernels, + f16f64_rounding_kernels, + index, + node, + &mut error, + ) + }, + |_, ()| (), + ); + if error { + Err(error_unreachable()) + } else { + Ok(Self { + entry_points: cfg.entry_points, + basic_blocks: cfg.basic_blocks, + call_returns: cfg.call_returns, + functions_rets: cfg.functions_rets, + graph, + }) + } + } +} + #[derive(Clone, Copy)] //#[cfg_attr(test, derive(Debug))] #[derive(Debug)] @@ -336,6 +646,12 @@ impl Mode { } } +#[derive(Copy, Clone)] +struct ResolvedMode { + entry: Resolved, + exit: Resolved, +} + //#[cfg_attr(test, derive(Debug))] #[derive(Debug)] struct Node { @@ -346,6 +662,14 @@ struct Node { rounding_f16f64: Mode, } +struct ResolvedNode { + label: SpirvWord, + denormal_f32: ResolvedMode, + denormal_f16f64: ResolvedMode, + rounding_f32: ResolvedMode, + rounding_f16f64: ResolvedMode, +} + impl Node { fn entry(label: SpirvWord) -> Self { Self { @@ -440,40 +764,137 @@ pub(crate) fn run<'input>( "{:?}", petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel]) ); + let rounding_f32 = compute_single_mode(&cfg, |node| node.rounding_f32); let denormal_f32 = compute_single_mode(&cfg, |node| node.denormal_f32); let denormal_f16f64 = compute_single_mode(&cfg, |node| node.denormal_f16f64); - let rounding_f32 = compute_single_mode(&cfg, |node| node.rounding_f32); let rounding_f16f64 = compute_single_mode(&cfg, |node| node.rounding_f16f64); let denormal_f32 = optimize::(denormal_f32); let denormal_f16f64 = optimize::(denormal_f16f64); let rounding_f32 = optimize::(rounding_f32); let rounding_f16f64: MandatoryModeInsertions = optimize::(rounding_f16f64); + let cfg = ResolvedControlFlowGraph::new( + cfg, + &denormal_f32.kernels, + &denormal_f16f64.kernels, + &rounding_f32.kernels, + &rounding_f16f64.kernels, + )?; + let temp = join_modes2( + flat_resolver, + cfg, + denormal_f32, + denormal_f16f64, + rounding_f32, + rounding_f16f64, + )?; + + /* let denormal = join_modes( flat_resolver, &cfg, denormal_f32, - |node| node.denormal_f32.entry, - |node| node.denormal_f32.exit, + |node| node.denormal_f32, denormal_f16f64, - |node| node.denormal_f16f64.entry, - |node| node.denormal_f16f64.exit, + |node| node.denormal_f16f64, )?; let rounding = join_modes( flat_resolver, &cfg, rounding_f32, - |node| node.rounding_f32.entry, - |node| node.rounding_f32.exit, + |node| node.rounding_f32, rounding_f16f64, - |node| node.rounding_f16f64.entry, - |node| node.rounding_f16f64.exit, + |node| node.rounding_f16f64, )?; let all_modes = FullModeInsertion::new(flat_resolver, denormal, rounding)?; - let directives = insert_mode_control(directives, all_modes)?; + */ + let directives = insert_mode_control(directives, temp)?; Ok(directives) } +fn join_modes2( + flat_resolver: &mut super::GlobalStringIdentResolver2, + cfg: ResolvedControlFlowGraph, + mandatory_denormal_f32: MandatoryModeInsertions, + mandatory_denormal_f16f64: MandatoryModeInsertions, + mandatory_rounding_f32: MandatoryModeInsertions, + mandatory_rounding_f16f64: MandatoryModeInsertions, +) -> Result { + let basic_blocks = cfg + .graph + .node_weights() + .map(|basic_block| { + let denormal_prologue = if mandatory_denormal_f32 + .basic_blocks + .contains(&basic_block.label) + || mandatory_denormal_f16f64 + .basic_blocks + .contains(&basic_block.label) + { + Some(flat_resolver.register_unnamed(None)) + } else { + None + }; + let rounding_prologue = if mandatory_rounding_f32 + .basic_blocks + .contains(&basic_block.label) + || mandatory_rounding_f16f64 + .basic_blocks + .contains(&basic_block.label) + { + Some(flat_resolver.register_unnamed(None)) + } else { + None + }; + let dual_prologue = if denormal_prologue.is_some() && rounding_prologue.is_some() { + Some(flat_resolver.register_unnamed(None)) + } else { + None + }; + let denormal = BasicBlockEntryState { + prologue: denormal_prologue, + twin_mode: TwinMode { + f32: basic_block.denormal_f32.entry, + f16f64: basic_block.denormal_f16f64.entry, + }, + }; + let rounding = BasicBlockEntryState { + prologue: rounding_prologue, + twin_mode: TwinMode { + f32: basic_block.rounding_f32.entry, + f16f64: basic_block.rounding_f16f64.entry, + }, + }; + Ok(( + basic_block.label, + FullBasicBlockEntryState { + dual_prologue, + denormal, + rounding, + }, + )) + }) + .collect::, _>>()?; + let functions_exit_modes = cfg + .functions_rets + .into_iter() + .map(|(bb, node)| { + let weights = cfg.graph.node_weight(node).ok_or_else(error_unreachable)?; + let modes = ResolvedInstructionModes { + denormal_f32: weights.denormal_f32.exit.map(DenormalMode::to_ftz), + denormal_f16f64: weights.denormal_f16f64.exit.map(DenormalMode::to_ftz), + rounding_f32: weights.rounding_f32.exit.map(RoundingMode::to_ast), + rounding_f16f64: weights.rounding_f16f64.exit.map(RoundingMode::to_ast), + }; + Ok((bb, modes)) + }) + .collect::, _>>()?; + Ok(FullModeInsertion2 { + basic_blocks, + functions_exit_modes, + }) +} + // For every basic block this pass computes: // - Name of mode prologue basic block. Mode prologue is a basic block which // contains single instruction that sets mode to the desired value. It will @@ -486,66 +907,19 @@ pub(crate) fn run<'input>( // We don't need to compute exit mode because this will be computed naturally // when emitting instructions in a basic block. We need exit mode to know if we // jump directly to the next bb or jump to mode prologue +/* fn join_modes<'input, T: Eq + PartialEq + Copy + Default>( flat_resolver: &mut super::GlobalStringIdentResolver2<'input>, - cfg: &ControlFlowGraph, + cfg: &ResolvedControlFlowGraph, f32_insertions: MandatoryModeInsertions, - mut f32_enter_view: impl FnMut(&Node) -> Option>, - mut f32_exit_view: impl FnMut(&Node) -> Option>, + mut f32_view: impl FnMut(&ResolvedNode) -> ResolvedMode, f16f64_insertions: MandatoryModeInsertions, - mut f16f64_enter_view: impl FnMut(&Node) -> Option>, - mut f16f64_exit_view: impl FnMut(&Node) -> Option>, + mut f16f64_view: impl FnMut(&ResolvedNode) -> ResolvedMode, ) -> Result, TranslateError> { - // Returns None if there are multiple conflicting modes - fn get_incoming_mode( - cfg: &ControlFlowGraph, - kernels: &FxHashMap, - node: NodeIndex, - mut exit_getter: impl FnMut(&Node) -> Option>, - ) -> Result, TranslateError> { - let mut mode: Option = None; - let mut visited = iter::once(node).collect::>(); - let mut to_visit = cfg - .graph - .neighbors_directed(node, Direction::Incoming) - .map(|x| x) - .collect::>(); - while let Some(node) = to_visit.pop() { - if !visited.insert(node) { - continue; - } - let node_data = &cfg.graph[node]; - match (mode, exit_getter(node_data)) { - (_, None) => { - for next in cfg.graph.neighbors_directed(node, Direction::Incoming) { - if !visited.contains(&next) { - to_visit.push(next); - } - } - } - (existing_mode, Some(new_mode)) => { - let new_mode = match new_mode { - ExtendedMode::BasicBlock(new_mode) => new_mode, - ExtendedMode::Entry(kernel) => { - kernels.get(&kernel).copied().unwrap_or_default() - } - }; - if let Some(existing_mode) = existing_mode { - if existing_mode != new_mode { - return Ok(None); - } - } - mode = Some(new_mode); - } - } - } - mode.map(Some).ok_or_else(error_unreachable) - } let basic_blocks = cfg .graph - .node_references() - .into_iter() - .map(|(node, basic_block)| { + .node_weights() + .map(|basic_block| { let requires_prologue = f32_insertions.basic_blocks.contains(&basic_block.label) || f16f64_insertions.basic_blocks.contains(&basic_block.label); let prologue: Option = if requires_prologue { @@ -553,49 +927,14 @@ fn join_modes<'input, T: Eq + PartialEq + Copy + Default>( } else { None }; - let f32 = match f32_enter_view(&basic_block) { - Some(ExtendedMode::BasicBlock(mode)) => Some(mode), - Some(ExtendedMode::Entry(kernel)) => Some( - f32_insertions - .kernels - .get(&kernel) - .copied() - .unwrap_or_default(), - ), - // None means that no instruction in the basic block sets mode, but - // another basic block might rely on this instruction transitively - // passing a mode - None => None, - }; - let f16f64 = match f16f64_enter_view(&basic_block) { - Some(ExtendedMode::BasicBlock(mode)) => Some(mode), - Some(ExtendedMode::Entry(kernel)) => Some( - f16f64_insertions - .kernels - .get(&kernel) - .copied() - .unwrap_or_default(), - ), - None => None, - }; - let twin_mode = match (f32, f16f64) { - (Some(f32), Some(f16f64)) => Some(TwinMode { f32, f16f64 }), - (None, Some(f16f64)) => { - let f32 = get_incoming_mode(cfg, &f32_insertions.kernels, node, |node| { - f32_exit_view(node) - })?; - let f32 = f32.unwrap_or_default(); - Some(TwinMode { f32, f16f64 }) - } - (Some(f32), None) => { - let f16f64 = - get_incoming_mode(cfg, &f16f64_insertions.kernels, node, |node| { - f16f64_exit_view(node) - })?; - let f16f64 = f16f64.unwrap_or_default(); - Some(TwinMode { f32, f16f64 }) - } - (None, None) => None, + let f32 = f32_view(basic_block); + let f16f64 = f16f64_view(basic_block); + let twin_mode = match (f32.entry, f16f64.entry) { + (Resolved::Conflict, Resolved::Conflict) => Resolved::Conflict, + (f32, f16f64) => Resolved::Value(TwinMode { + f32: f32.unwrap_of_default(), + f16f64: f16f64.unwrap_of_default(), + }), }; Ok(( basic_block.label, @@ -608,11 +947,17 @@ fn join_modes<'input, T: Eq + PartialEq + Copy + Default>( .collect::, _>>()?; Ok(TwinModeInsertions { basic_blocks }) } + */ struct TwinModeInsertions { basic_blocks: FxHashMap>, } +struct FullModeInsertion2 { + basic_blocks: FxHashMap, + functions_exit_modes: FxHashMap, +} + struct FullModeInsertion { basic_blocks: FxHashMap, } @@ -623,15 +968,18 @@ impl FullModeInsertion { denormal: TwinModeInsertions, rounding: TwinModeInsertions, ) -> Result { - let denormal = denormal.basic_blocks; - let rounding = rounding.basic_blocks; - if denormal.len() != rounding.len() { + if denormal.basic_blocks.len() != rounding.basic_blocks.len() { return Err(error_unreachable()); } let basic_blocks = denormal + .basic_blocks .into_iter() .map(|(bb, denormal)| { - let rounding = rounding.get(&bb).copied().ok_or_else(error_unreachable)?; + let rounding = rounding + .basic_blocks + .get(&bb) + .copied() + .ok_or_else(error_unreachable)?; let dual_prologue = if denormal.prologue.is_some() && rounding.prologue.is_some() { Some(flat_resolver.register_unnamed(None)) } else { @@ -660,19 +1008,18 @@ struct FullBasicBlockEntryState { #[derive(Clone, Copy)] struct BasicBlockEntryState { prologue: Option, - // It is None in case where no instructions in the basic block uses mode - twin_mode: Option>, + twin_mode: TwinMode>, } -#[derive(Clone, Copy, Default)] +#[derive(Clone, Copy)] struct TwinMode { f32: T, f16f64: T, } -fn insert_mode_control<'input>( +fn insert_mode_control( directives: Vec, SpirvWord>>, - global_modes: FullModeInsertion, + global_modes: FullModeInsertion2, ) -> Result, SpirvWord>>, TranslateError> { let directives_len = directives.len(); directives @@ -697,33 +1044,18 @@ fn insert_mode_control<'input>( .basic_blocks .get(&name) .ok_or_else(error_unreachable)?; - *flush_to_zero_f32 = initial_mode - .denormal - .twin_mode - .unwrap_or_default() - .f32 - .to_ftz(); - *flush_to_zero_f16f64 = initial_mode - .denormal - .twin_mode - .unwrap_or_default() - .f16f64 - .to_ftz(); - *rounding_mode_f32 = initial_mode - .rounding - .twin_mode - .unwrap_or_default() - .f32 - .to_ast(); - *rounding_mode_f16f64 = initial_mode - .rounding - .twin_mode - .unwrap_or_default() - .f16f64 - .to_ast(); + let denormal_mode = initial_mode.denormal.twin_mode; + let rounding_mode = initial_mode.rounding.twin_mode; + *flush_to_zero_f32 = denormal_mode.f32.ok_or_else(error_unreachable)?.to_ftz(); + *flush_to_zero_f16f64 = + denormal_mode.f16f64.ok_or_else(error_unreachable)?.to_ftz(); + *rounding_mode_f32 = rounding_mode.f32.ok_or_else(error_unreachable)?.to_ast(); + *rounding_mode_f16f64 = + rounding_mode.f16f64.ok_or_else(error_unreachable)?.to_ast(); (name, initial_mode, body) } }; + emit_mode_prelude(fn_name, &mut new_directives); let old_body = mem::replace(body_ptr, Vec::new()); let mut result = Vec::with_capacity(old_body.len()); let mut bb_state = BasicBlockControlState::new(&global_modes, fn_name, initial_mode); @@ -760,13 +1092,26 @@ fn insert_mode_control<'input>( } result.push(statement); if let Some(call_target) = call_target { - if let Some(Statement::Instruction(ast::Instruction::Bra { - arguments: ast::BraArgs { src: post_call_label }, - })) = old_body.next() + let mut post_call_bra = old_body.next().ok_or_else(error_unreachable)?; + if let Statement::Instruction(ast::Instruction::Bra { + arguments: + ast::BraArgs { + src: ref mut post_call_label, + }, + }) = post_call_bra { - // get return block for the function, if there is a mode - // change between caller and callee then apply it here - todo!() + let node_exit_mode = global_modes + .functions_exit_modes + .get(&call_target) + .ok_or_else(error_unreachable)?; + redirect_jump_impl( + &bb_state.global_modes, + node_exit_mode, + post_call_label, + )?; + result.push(post_call_bra); + } else { + return Err(error_unreachable()); } } } @@ -780,8 +1125,40 @@ fn insert_mode_control<'input>( }) } +fn emit_mode_prelude( + fn_name: SpirvWord, + global_modes: FullModeInsertion2, + new_directives: &mut SmallVec<[Directive2, SpirvWord>; 4]>, +) -> Result<(), TranslateError> { + let fn_mode_state = global_modes.basic_blocks.get(&fn_name).ok_or_else(error_unreachable)?; + if let Some(dual_prologue) = fn_mode_state.dual_prologue { + new_directives.push(Directive2::Method( + Function2 { + return_arguments: todo!(), + name: dual_prologue, + input_arguments: todo!(), + body: todo!(), + is_kernel: false, + import_as: None, + tuning: Vec::new(), + linkage: ast::LinkingDirective::NONE, + flush_to_zero_f32: false, + flush_to_zero_f16f64: false, + rounding_mode_f32: ptx_parser::RoundingMode::NearestEven, + rounding_mode_f16f64: ptx_parser::RoundingMode::NearestEven, + } + )); + } + if let Some(prologue) = fn_mode_state.denormal.prologue { + todo!() + } + if let Some(prologue) = fn_mode_state.rounding.prologue { + todo!() + } +} + struct BasicBlockControlState<'a> { - global_modes: &'a FullModeInsertion, + global_modes: &'a FullModeInsertion2, denormal_f32: RegisterState, denormal_f16f64: RegisterState, rounding_f32: RegisterState, @@ -791,7 +1168,7 @@ struct BasicBlockControlState<'a> { #[derive(Clone, Copy)] struct RegisterState { - current_value: Option, + current_value: Resolved, // This is slightly subtle: this value is Some iff there's a SetMode in this // basic block setting this mode, but on which no instruciton relies last_foldable: Option, @@ -801,39 +1178,38 @@ impl RegisterState { fn single(t: T) -> Self { RegisterState { last_foldable: None, - current_value: Some(t), + current_value: Resolved::Value(t), } } - fn empty() -> Self { + fn conflict() -> Self { RegisterState { last_foldable: None, - current_value: None, + current_value: Resolved::Conflict, } } - fn new(computed: &BasicBlockEntryState) -> (RegisterState, RegisterState) + fn new(value: Resolved) -> RegisterState where U: Into, { - match computed.twin_mode { - Some(ref mode) => ( - RegisterState::single(mode.f32.into()), - RegisterState::single(mode.f16f64.into()), - ), - None => (RegisterState::empty(), RegisterState::empty()), + RegisterState { + current_value: value.map(Into::into), + last_foldable: None, } } } impl<'a> BasicBlockControlState<'a> { fn new( - global_modes: &'a FullModeInsertion, + global_modes: &'a FullModeInsertion2, current_bb: SpirvWord, initial_mode: &FullBasicBlockEntryState, ) -> Self { - let (denormal_f32, denormal_f16f64) = RegisterState::new(&initial_mode.denormal); - let (rounding_f32, rounding_f16f64) = RegisterState::new(&initial_mode.rounding); + let denormal_f32 = RegisterState::new(initial_mode.denormal.twin_mode.f32); + let denormal_f16f64 = RegisterState::new(initial_mode.denormal.twin_mode.f16f64); + let rounding_f32 = RegisterState::new(initial_mode.rounding.twin_mode.f32); + let rounding_f16f64 = RegisterState::new(initial_mode.rounding.twin_mode.f16f64); BasicBlockControlState { global_modes, denormal_f32, @@ -855,23 +1231,33 @@ impl<'a> BasicBlockControlState<'a> { .get(&basic_block) .ok_or_else(error_unreachable)?; - let (denormal_f32, denormal_f16f64) = RegisterState::new(&bb_state.denormal); + let denormal_f32 = RegisterState::new(bb_state.denormal.twin_mode.f32); + let denormal_f16f64 = RegisterState::new(bb_state.denormal.twin_mode.f16f64); self.denormal_f32 = denormal_f32; self.denormal_f16f64 = denormal_f16f64; - let (rounding_f32, rounding_f16f64) = RegisterState::new(&bb_state.rounding); + let rounding_f32 = RegisterState::new(bb_state.rounding.twin_mode.f32); + let rounding_f16f64 = RegisterState::new(bb_state.rounding.twin_mode.f16f64); self.rounding_f32 = rounding_f32; self.rounding_f16f64 = rounding_f16f64; if let Some(prologue) = bb_state.dual_prologue { statements.push(Statement::Label(prologue)); - let denormal = bb_state.denormal.twin_mode.ok_or_else(error_unreachable)?; statements.push(Statement::SetMode(ModeRegister::Denormal { - f32: denormal.f32.to_ftz(), - f16f64: denormal.f16f64.to_ftz(), + f32: bb_state.denormal.twin_mode.f32.unwrap_of_default().to_ftz(), + f16f64: bb_state + .denormal + .twin_mode + .f16f64 + .unwrap_of_default() + .to_ftz(), })); - let rounding = bb_state.rounding.twin_mode.ok_or_else(error_unreachable)?; statements.push(Statement::SetMode(ModeRegister::Rounding { - f32: rounding.f32.to_ast(), - f16f64: rounding.f16f64.to_ast(), + f32: bb_state.rounding.twin_mode.f32.unwrap_of_default().to_ast(), + f16f64: bb_state + .rounding + .twin_mode + .f16f64 + .unwrap_of_default() + .to_ast(), })); statements.push(Statement::Instruction(ast::Instruction::Bra { arguments: ast::BraArgs { src: basic_block }, @@ -879,10 +1265,14 @@ impl<'a> BasicBlockControlState<'a> { } if let Some(prologue) = bb_state.denormal.prologue { statements.push(Statement::Label(prologue)); - let denormal = bb_state.denormal.twin_mode.ok_or_else(error_unreachable)?; statements.push(Statement::SetMode(ModeRegister::Denormal { - f32: denormal.f32.to_ftz(), - f16f64: denormal.f16f64.to_ftz(), + f32: bb_state.denormal.twin_mode.f32.unwrap_of_default().to_ftz(), + f16f64: bb_state + .denormal + .twin_mode + .f16f64 + .unwrap_of_default() + .to_ftz(), })); statements.push(Statement::Instruction(ast::Instruction::Bra { arguments: ast::BraArgs { src: basic_block }, @@ -890,10 +1280,14 @@ impl<'a> BasicBlockControlState<'a> { } if let Some(prologue) = bb_state.rounding.prologue { statements.push(Statement::Label(prologue)); - let rounding = bb_state.rounding.twin_mode.ok_or_else(error_unreachable)?; statements.push(Statement::SetMode(ModeRegister::Rounding { - f32: rounding.f32.to_ast(), - f16f64: rounding.f16f64.to_ast(), + f32: bb_state.rounding.twin_mode.f32.unwrap_of_default().to_ast(), + f16f64: bb_state + .rounding + .twin_mode + .f16f64 + .unwrap_of_default() + .to_ast(), })); statements.push(Statement::Instruction(ast::Instruction::Bra { arguments: ast::BraArgs { src: basic_block }, @@ -902,29 +1296,6 @@ impl<'a> BasicBlockControlState<'a> { Ok(()) } - /* - fn add_or_fold_mode_set( - &mut self, - result: &mut Vec, SpirvWord>>, - new_mode: bool, - ) -> Option { - // try and fold into the other mode set - if let RegisterState::Value(Some(other_index), other_value) = self.denormal_f16f64 { - if let Some(Statement::SetMode(ModeRegister::DenormalF16F64(_))) = - result.get_mut(other_index) - { - result[other_index] = Statement::SetMode(ModeRegister::DenormalBoth { - f32: new_mode, - f16f64: other_value, - }); - return None; - } - } - result.push(Statement::SetMode(ModeRegister::DenormalF32(new_mode))); - Some(result.len() - 1) - } - */ - fn insert( &mut self, result: &mut Vec, SpirvWord>>, @@ -935,6 +1306,11 @@ impl<'a> BasicBlockControlState<'a> { result, modes.denormal_f16f64.map(DenormalMode::to_ftz), )?; + self.insert_one::(result, modes.rounding_f32.map(RoundingMode::to_ast))?; + self.insert_one::( + result, + modes.rounding_f16f64.map(RoundingMode::to_ast), + )?; Ok(()) } @@ -949,10 +1325,12 @@ impl<'a> BasicBlockControlState<'a> { View::set_register(bb, reg); } let new_mode = unwrap_some_or!(mode, return Ok(())); - // if let Some(new_mode) = mode { let register_state = View::get_register(self); match register_state.current_value { - Some(old) if old == new_mode => { + Resolved::Conflict => { + return Err(error_unreachable()); + } + Resolved::Value(old) if old == new_mode => { set_fold_index::(self, None); } _ => match register_state.last_foldable { @@ -977,81 +1355,114 @@ impl<'a> BasicBlockControlState<'a> { } }, } - //} Ok(()) } - // Return the index of the last insertion of SetMode with this mode - /* - fn add_or_fold_mode_set2( - &self, - result: &mut Vec, SpirvWord>>, - new_mode: View::Value, - ) -> Result<(), TranslateError> { - // try and fold into the other mode set instruction - View::get_register(bb) - if let RegisterState { last_foldable: } = View::TwinView::get_register(self) { - if let Some(Statement::SetMode(register_mode)) = result.get_mut(twin_index) { - if let Some(twin_mode) = View::TwinView::get_single_mode(register_mode) { - *register_mode = View::new_mode(new_mode, Some(twin_mode)); - return twin_index; - } - } - } - result.push(Statement::SetMode(View::new_mode(new_mode, None))); - result.len() - 1 - } - */ - fn redirect_jump(&self, jump_target: &mut SpirvWord) -> Result<(), TranslateError> { - let target = self - .global_modes - .basic_blocks - .get(jump_target) - .ok_or_else(error_unreachable)?; - let jump_to_denormal = match ( - self.denormal_f32.current_value, - self.denormal_f16f64.current_value, - ) { - (None, None) => false, - (Some(current_f32), Some(current_f16f64)) => { - if let Some(target_mode) = target.denormal.twin_mode { - target_mode.f32.to_ftz() != current_f32 - || target_mode.f16f64.to_ftz() != current_f16f64 - } else { - false - } - } - _ => return Err(error_unreachable()), + let current_mode = ResolvedInstructionModes { + denormal_f32: self.denormal_f32.current_value, + denormal_f16f64: self.denormal_f16f64.current_value, + rounding_f32: self.rounding_f32.current_value, + rounding_f16f64: self.rounding_f16f64.current_value, }; - let jump_to_rounding = match ( - self.rounding_f32.current_value, - self.rounding_f16f64.current_value, - ) { - (None, None) => false, - (Some(current_f32), Some(current_f16f64)) => { - if let Some(target_mode) = target.rounding.twin_mode { - target_mode.f32.to_ast() != current_f32 - || target_mode.f16f64.to_ast() != current_f16f64 - } else { - false - } - } - _ => return Err(error_unreachable()), - }; - match (jump_to_denormal, jump_to_rounding) { - (true, false) => { - *jump_target = target.denormal.prologue.ok_or_else(error_unreachable)?; - } - (false, true) => { - *jump_target = target.rounding.prologue.ok_or_else(error_unreachable)?; - } - (true, true) => { - *jump_target = target.dual_prologue.ok_or_else(error_unreachable)?; - } - (false, false) => {} + redirect_jump_impl(self.global_modes, ¤t_mode, jump_target) + } +} + +fn redirect_jump_impl( + global_modes: &FullModeInsertion2, + current_mode: &ResolvedInstructionModes, + jump_target: &mut SpirvWord, +) -> Result<(), TranslateError> { + let target = global_modes + .basic_blocks + .get(jump_target) + .ok_or_else(error_unreachable)?; + let jump_to_denormal_prelude = current_mode + .denormal_f32 + .mode_change(target.denormal.twin_mode.f32.map(DenormalMode::to_ftz)) + || current_mode + .denormal_f16f64 + .mode_change(target.denormal.twin_mode.f16f64.map(DenormalMode::to_ftz)); + let jump_to_rounding_prelude = current_mode + .rounding_f32 + .mode_change(target.rounding.twin_mode.f32.map(RoundingMode::to_ast)) + || current_mode + .rounding_f16f64 + .mode_change(target.rounding.twin_mode.f16f64.map(RoundingMode::to_ast)); + match (jump_to_denormal_prelude, jump_to_rounding_prelude) { + (true, false) => { + *jump_target = target.denormal.prologue.ok_or_else(error_unreachable)?; + } + (false, true) => { + *jump_target = target.rounding.prologue.ok_or_else(error_unreachable)?; + } + (true, true) => { + *jump_target = target.dual_prologue.ok_or_else(error_unreachable)?; + } + (false, false) => {} + } + Ok(()) +} + +struct ModeJumpTargets { + dual_prologue: Option, + denormal: Option, + rounding: Option, +} + +#[derive(Copy, Clone)] +enum Resolved { + Conflict, + Value(T), +} + +impl Resolved { + fn unwrap_of_default(self) -> T { + match self { + Resolved::Conflict => T::default(), + Resolved::Value(t) => t, + } + } +} + +impl Resolved { + fn mode_change(self, target: Self) -> bool { + match (self, target) { + (Resolved::Conflict, Resolved::Conflict) => false, + (Resolved::Conflict, Resolved::Value(_)) => true, + (Resolved::Value(_), Resolved::Conflict) => false, + (Resolved::Value(x), Resolved::Value(y)) => x != y, + } + } +} + +impl Resolved { + fn map(self, f: F) -> Resolved + where + F: FnOnce(T) -> U, + { + match self { + Resolved::Value(x) => Resolved::Value(f(x)), + Resolved::Conflict => Resolved::Conflict, + } + } + + fn ok_or_else(self, err: F) -> Result + where + F: FnOnce() -> E, + { + match self { + Resolved::Value(v) => Ok(v), + Resolved::Conflict => Err(err()), + } + } + + fn has_value(&self) -> bool { + match self { + Resolved::Value(_) => true, + Resolved::Conflict => false, } - Ok(()) } } @@ -1119,6 +1530,60 @@ impl ModeView for DenormalF16F64View { } } +struct RoundingF32View; + +impl ModeView for RoundingF32View { + type Value = ast::RoundingMode; + type TwinView = RoundingF16F64View; + + fn get_register(bb: &BasicBlockControlState) -> RegisterState { + bb.rounding_f32 + } + + fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState) { + bb.rounding_f32 = reg; + } + + fn new_mode(f32: Self::Value, f16f64: Self::Value) -> ModeRegister { + ModeRegister::Rounding { f32, f16f64 } + } + + fn set_single_mode(reg: &mut ModeRegister, x: Self::Value) -> Result<(), TranslateError> { + match reg { + ModeRegister::Rounding { f32, f16f64: _ } => *f32 = x, + ModeRegister::Denormal { .. } => return Err(error_unreachable()), + } + Ok(()) + } +} + +struct RoundingF16F64View; + +impl ModeView for RoundingF16F64View { + type Value = ast::RoundingMode; + type TwinView = RoundingF32View; + + fn get_register(bb: &BasicBlockControlState) -> RegisterState { + bb.rounding_f16f64 + } + + fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState) { + bb.rounding_f16f64 = reg; + } + + fn new_mode(f16f64: Self::Value, f32: Self::Value) -> ModeRegister { + ModeRegister::Rounding { f32, f16f64 } + } + + fn set_single_mode(reg: &mut ModeRegister, x: Self::Value) -> Result<(), TranslateError> { + match reg { + ModeRegister::Rounding { f32: _, f16f64 } => *f16f64 = x, + ModeRegister::Denormal { .. } => return Err(error_unreachable()), + } + Ok(()) + } +} + struct BasicBlockState<'a> { cfg: &'a mut ControlFlowGraph, node_index: Option, diff --git a/ptx/src/pass/insert_ftz_control/test.rs b/ptx/src/pass/insert_ftz_control/test.rs index 355d706..05c1fc8 100644 --- a/ptx/src/pass/insert_ftz_control/test.rs +++ b/ptx/src/pass/insert_ftz_control/test.rs @@ -214,10 +214,12 @@ static CALL_WITH_MODE_PTX: &'static str = include_str!("call_with_mode.ptx"); #[test] fn call_with_mode() { let methods = compile_methods(CALL_WITH_MODE_PTX); + assert!(matches!(methods[0].body, None)); + let method_1 = methods[1].body.as_ref().unwrap(); assert!(matches!( - &**methods[1].body.as_ref().unwrap(), + &**method_1, [ Statement::Label(..), Statement::Variable(..), @@ -254,4 +256,71 @@ fn call_with_mode() { Statement::Instruction(ast::Instruction::Ret { .. }), ] )); + let [to_fn0] = calls(method_1); + let [_, dual_prelude, _, _, add] = labels(method_1); + let [post_call, post_prelude_0, post_prelude_1, post_prelude_2] = branches(method_1); + assert_eq!(methods[0].name, to_fn0); + assert_eq!(post_call, dual_prelude); + assert_eq!(post_prelude_0, add); + assert_eq!(post_prelude_1, add); + assert_eq!(post_prelude_2, add); + + let method_2 = methods[2].body.as_ref().unwrap(); + assert!(matches!( + &**method_2, + [ + Statement::Label(..), + Statement::SetMode(ModeRegister::Denormal { + f32: true, + f16f64: true + }), + Statement::SetMode(ModeRegister::Rounding { + f32: ast::RoundingMode::PositiveInf, + f16f64: ast::RoundingMode::NearestEven + }), + Statement::Instruction(ast::Instruction::Call { .. }), + Statement::Instruction(ast::Instruction::Ret { .. }), + ] + )); +} + +fn branches( + fn_: &Vec, SpirvWord>>, +) -> [SpirvWord; N] { + fn_.iter() + .filter_map(|s| match s { + Statement::Instruction(ast::Instruction::Bra { + arguments: ast::BraArgs { src }, + }) => Some(*src), + _ => None, + }) + .collect::>() + .try_into() + .unwrap() +} + +fn labels( + fn_: &Vec, SpirvWord>>, +) -> [SpirvWord; N] { + fn_.iter() + .filter_map(|s: &Statement, SpirvWord>| match s { + Statement::Label(label) => Some(*label), + _ => None, + }) + .collect::>() + .try_into() + .unwrap() +} + +fn calls( + fn_: &Vec, SpirvWord>>, +) -> [SpirvWord; N] { + fn_.iter() + .filter_map(|s| match s { + Statement::Instruction(ast::Instruction::Call { arguments: ast::CallArgs { func,.. }, .. }) => Some(*func), + _ => None, + }) + .collect::>() + .try_into() + .unwrap() }