From 5fc2f3367f0f58449255ee74467c9be0b731c873 Mon Sep 17 00:00:00 2001
From: Andrzej Janik
Date: Fri, 14 Mar 2025 18:56:45 +0000
Subject: [PATCH] Emit rounding mode change instruction

---
Review notes (this region is ignored by `git am`; it is not part of the commit):

* Formatting: this copy of the patch had its line breaks and indentation
  collapsed. They are restored here following rustfmt conventions; the hunk
  line counts (19 -> 45 and 15 -> 15) match the restored layout.
  NOTE(review): the leading whitespace of context lines is reconstructed, not
  read from the tree -- verify against the target files, or apply with
  `git apply --ignore-whitespace`.

* emit_llvm.rs, emit_set_mode: the `ModeRegister::Rounding` arm used to be
  `todo!()`. It now produces a `(hwreg, value)` pair the same way the
  `Denormal` arm does, through four helpers local to the function:
    - hwreg(reg, offset, size) packs `reg | (offset << 6) | ((size - 1) << 11)`.
    - denormal_to_value(ftz) maps `true` to 0 and `false` to 3.
    - rounding_to_value maps NearestEven -> 0, PositiveInf -> 1,
      NegativeInf -> 2, Zero -> 3.
    - merge_regs(f32, f16f64) places the f32 value in the low two bits and the
      f16/f64 value in the two bits above it.
  Denormal uses hwreg(1, 4, 4) and Rounding uses hwreg(1, 0, 4), i.e. the same
  register id with a 4-bit field at offset 4 and offset 0 respectively.
  NOTE(review): register id 1 is presumably the AMDGPU MODE register and the
  numeric encodings its FP_ROUND / FP_DENORM fields -- confirm against the ISA
  documentation.

* rounding_to_value names its parameter `ftz` although it receives an
  `ast::RoundingMode`; a name such as `mode` would be less misleading.

* hwreg evaluates `size - 1` on a u32, so a `size` of 0 would underflow. Both
  call sites in this patch pass 4.

* The `llvm_i32` binding only moves below the `match`; its initializer is
  unchanged.

* insert_ftz_control/mod.rs: the two debug `println!` dumps of the control
  flow graph are commented out rather than removed. Consider deleting them
  outright so no dead code is left behind.

 ptx/src/pass/emit_llvm.rs              | 40 +++++++++++++++++++++-----
 ptx/src/pass/insert_ftz_control/mod.rs | 16 +++++------
 2 files changed, 41 insertions(+), 15 deletions(-)

diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index 66ceb75..734243e 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -2261,19 +2261,45 @@ impl<'a> MethodEmitContext<'a> {
     }
 
     fn emit_set_mode(&mut self, mode_reg: ModeRegister) -> Result<(), TranslateError> {
+        fn hwreg(reg: u32, offset: u32, size: u32) -> u32 {
+            reg | (offset << 6) | ((size - 1) << 11)
+        }
+        fn denormal_to_value(ftz: bool) -> u32 {
+            if ftz {
+                0
+            } else {
+                3
+            }
+        }
+        fn rounding_to_value(ftz: ast::RoundingMode) -> u32 {
+            match ftz {
+                ptx_parser::RoundingMode::NearestEven => 0,
+                ptx_parser::RoundingMode::Zero => 3,
+                ptx_parser::RoundingMode::NegativeInf => 2,
+                ptx_parser::RoundingMode::PositiveInf => 1,
+            }
+        }
+        fn merge_regs(f32: u32, f16f64: u32) -> u32 {
+            f32 | f16f64 << 2
+        }
         let intrinsic = c"llvm.amdgcn.s.setreg";
-        let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32);
         let (hwreg, value) = match mode_reg {
             ModeRegister::Denormal { f32, f16f64 } => {
-                let (reg, offset, size) = (1, 4, 4u32);
-                let hwreg = reg | (offset << 6) | ((size - 1) << 11);
-                let f32 = if f32 { 0 } else { 3 };
-                let f16f64 = if f16f64 { 0 } else { 3 };
-                let value = f32 | f16f64 << 2;
+                let hwreg = hwreg(1, 4, 4);
+                let f32 = denormal_to_value(f32);
+                let f16f64 = denormal_to_value(f16f64);
+                let value = merge_regs(f32, f16f64);
+                (hwreg, value)
+            }
+            ModeRegister::Rounding { f32, f16f64 } => {
+                let hwreg = hwreg(1, 0, 4);
+                let f32 = rounding_to_value(f32);
+                let f16f64 = rounding_to_value(f16f64);
+                let value = merge_regs(f32, f16f64);
                 (hwreg, value)
             }
-            ModeRegister::Rounding { .. } => todo!(),
         };
+        let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32);
         let hwreg_llvm = unsafe { LLVMConstInt(llvm_i32, hwreg as _, 0) };
         let value_llvm = unsafe { LLVMConstInt(llvm_i32, value as _, 0) };
         self.emit_intrinsic(
diff --git a/ptx/src/pass/insert_ftz_control/mod.rs b/ptx/src/pass/insert_ftz_control/mod.rs
index 24120a4..dfaafe3 100644
--- a/ptx/src/pass/insert_ftz_control/mod.rs
+++ b/ptx/src/pass/insert_ftz_control/mod.rs
@@ -770,15 +770,15 @@ pub(crate) fn run<'input>(
             _ => {}
         }
     }
-    println!(
-        "{:?}",
-        petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
-    );
+    //println!(
+    //    "{:?}",
+    //    petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
+    //);
     cfg.fixup_function_calls()?;
-    println!(
-        "{:?}",
-        petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
-    );
+    //println!(
+    //    "{:?}",
+    //    petgraph::dot::Dot::with_config(&cfg.graph, &[petgraph::dot::Config::EdgeNoLabel])
+    //);
     let rounding_f32 = compute_single_mode(&cfg, |node| node.rounding_f32);
     let denormal_f32 = compute_single_mode(&cfg, |node| node.denormal_f32);
     let denormal_f16f64 = compute_single_mode(&cfg, |node| node.denormal_f16f64);