diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs index a2a60dc..e203394 100644 --- a/ptx/src/pass/deparamize_functions.rs +++ b/ptx/src/pass/deparamize_functions.rs @@ -95,16 +95,7 @@ fn run_method<'input>( Ok::<_, TranslateError>(body) }) .transpose()?; - Ok(Function2 { - return_arguments: method.return_arguments, - name: method.name, - input_arguments: method.input_arguments, - body, - import_as: method.import_as, - tuning: method.tuning, - linkage: method.linkage, - is_kernel: method.is_kernel, - }) + Ok(Function2 { body, ..method }) } fn run_statement<'input>( diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 255aee0..90c6b8b 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -243,6 +243,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { } if !method.is_kernel { self.resolver.register(method.name, fn_); + self.emit_fn_attribute(fn_, "denormal-fp-math-f32", "dynamic"); + self.emit_fn_attribute(fn_, "denormal-fp-math", "dynamic"); + } else { + self.emit_fn_attribute( + fn_, + "denormal-fp-math-f32", + llvm_ftz(method.flush_to_zero_f32), + ); + self.emit_fn_attribute( + fn_, + "denormal-fp-math", + llvm_ftz(method.flush_to_zero_f16f64), + ); } for (i, param) in method.input_arguments.iter().enumerate() { let value = unsafe { LLVMGetParam(fn_, i as u32) }; @@ -413,6 +426,14 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { } } +fn llvm_ftz(ftz: bool) -> &'static str { + if ftz { + "preserve-sign" + } else { + "ieee" + } +} + fn get_input_argument_type( context: LLVMContextRef, v_type: &ast::Type, @@ -469,6 +490,7 @@ impl<'a> MethodEmitContext<'a> { Statement::FunctionPointer(_) => todo!(), Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?, Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?, + Statement::SetMode(mode_reg) => self.emit_set_mode(mode_reg)?, }) } @@ -1124,7 +1146,7 @@ impl<'a> MethodEmitContext<'a> { let cos = self.emit_intrinsic( c"llvm.cos.f32", Some(arguments.dst), - &ast::ScalarType::F32.into(), + Some(&ast::ScalarType::F32.into()), vec![(self.resolver.value(arguments.src)?, llvm_f32)], )?; unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) } @@ -1377,7 +1399,7 @@ impl<'a> MethodEmitContext<'a> { let sin = self.emit_intrinsic( c"llvm.sin.f32", Some(arguments.dst), - &ast::ScalarType::F32.into(), + Some(&ast::ScalarType::F32.into()), vec![(self.resolver.value(arguments.src)?, llvm_f32)], )?; unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) } @@ -1388,12 +1410,12 @@ impl<'a> MethodEmitContext<'a> { &mut self, name: &CStr, dst: Option, - return_type: &ast::Type, + return_type: Option<&ast::Type>, arguments: Vec<(LLVMValueRef, LLVMTypeRef)>, ) -> Result { let fn_type = get_function_type( self.context, - iter::once(return_type), + return_type.into_iter(), arguments.iter().map(|(_, type_)| Ok(*type_)), )?; let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; @@ -1612,7 +1634,7 @@ impl<'a> MethodEmitContext<'a> { let clamped = self.emit_intrinsic( c"llvm.umin", None, - &from.into(), + Some(&from.into()), vec![ (self.resolver.value(arguments.src)?, from_llvm), (max, from_llvm), @@ -1642,7 +1664,7 @@ impl<'a> MethodEmitContext<'a> { let zero_clamped = self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) }, None, - &from.into(), + Some(&from.into()), vec![ (self.resolver.value(arguments.src)?, from_llvm), (zero, from_llvm), @@ -1661,7 +1683,7 @@ impl<'a> MethodEmitContext<'a> { let fully_clamped = self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) }, None, - &from.into(), + Some(&from.into()), vec![(zero_clamped, from_llvm), (max, from_llvm)], )?; let resize_fn = if to.layout().size() >= from.layout().size() { @@ -1701,7 +1723,7 @@ impl<'a> MethodEmitContext<'a> { let rounded_float = self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, None, - &from.into(), + Some(&from.into()), vec![( self.resolver.value(arguments.src)?, get_scalar_type(self.context, from), @@ -1770,7 +1792,7 @@ impl<'a> MethodEmitContext<'a> { self.emit_intrinsic( intrinsic, Some(arguments.dst), - &data.type_.into(), + Some(&data.type_.into()), vec![(self.resolver.value(arguments.src)?, type_)], )?; Ok(()) @@ -1791,7 +1813,7 @@ impl<'a> MethodEmitContext<'a> { self.emit_intrinsic( intrinsic, Some(arguments.dst), - &data.type_.into(), + Some(&data.type_.into()), vec![(self.resolver.value(arguments.src)?, type_)], )?; Ok(()) @@ -1813,7 +1835,7 @@ impl<'a> MethodEmitContext<'a> { self.emit_intrinsic( intrinsic, Some(arguments.dst), - &data.type_.into(), + Some(&data.type_.into()), vec![(self.resolver.value(arguments.src)?, type_)], )?; Ok(()) @@ -1935,7 +1957,7 @@ impl<'a> MethodEmitContext<'a> { self.emit_intrinsic( intrinsic, Some(arguments.dst), - &data.type_.into(), + Some(&data.type_.into()), vec![( self.resolver.value(arguments.src)?, get_scalar_type(self.context, data.type_), @@ -1952,7 +1974,7 @@ impl<'a> MethodEmitContext<'a> { self.emit_intrinsic( c"llvm.amdgcn.log.f32", Some(arguments.dst), - &ast::ScalarType::F32.into(), + Some(&ast::ScalarType::F32.into()), vec![( self.resolver.value(arguments.src)?, get_scalar_type(self.context, ast::ScalarType::F32.into()), @@ -2007,7 +2029,7 @@ impl<'a> MethodEmitContext<'a> { self.emit_intrinsic( intrinsic, Some(arguments.dst), - &type_.into(), + Some(&type_.into()), vec![(self.resolver.value(arguments.src)?, llvm_type)], )?; Ok(()) @@ -2031,7 +2053,7 @@ impl<'a> MethodEmitContext<'a> { self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, Some(arguments.dst), - &data.type_().into(), + Some(&data.type_().into()), vec![ (self.resolver.value(arguments.src1)?, llvm_type), (self.resolver.value(arguments.src2)?, llvm_type), @@ -2058,7 +2080,7 @@ impl<'a> MethodEmitContext<'a> { self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, Some(arguments.dst), - &data.type_().into(), + Some(&data.type_().into()), vec![ (self.resolver.value(arguments.src1)?, llvm_type), (self.resolver.value(arguments.src2)?, llvm_type), @@ -2076,7 +2098,7 @@ impl<'a> MethodEmitContext<'a> { self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, Some(arguments.dst), - &data.type_.into(), + Some(&data.type_.into()), vec![ ( self.resolver.value(arguments.src1)?, @@ -2197,12 +2219,49 @@ impl<'a> MethodEmitContext<'a> { self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) }, Some(arguments.dst), - &data.type_.into(), + Some(&data.type_.into()), intrinsic_arguments, )?; Ok(()) } + fn emit_set_mode(&mut self, mode_reg: ModeRegister) -> Result<(), TranslateError> { + let intrinsic = c"llvm.amdgcn.s.setreg"; + let llvm_i32 = get_scalar_type(self.context, ast::ScalarType::B32); + let (hwreg, value) = match mode_reg { + ModeRegister::DenormalF32(ftz) => { + let (reg, offset, size) = (1, 4, 2u32); + let hwreg = reg | (offset << 6) | ((size - 1) << 11); + (hwreg, if ftz { 0u32 } else { 3 }) + } + ModeRegister::DenormalF16F64(ftz) => { + let (reg, offset, size) = (1, 6, 2u32); + let hwreg = reg | (offset << 6) | ((size - 1) << 11); + (hwreg, if ftz { 0 } else { 3 }) + } + ModeRegister::DenormalBoth { f32, f16f64 } => { + let (reg, offset, size) = (1, 4, 4u32); + let hwreg = reg | (offset << 6) | ((size - 1) << 11); + let f32 = if f32 { 0 } else { 3 }; + let f16f64 = if f16f64 { 0 } else { 3 }; + let value = f32 | f16f64 << 2; + (hwreg, value) + } + ModeRegister::RoundingF32(rounding_mode) => todo!(), + ModeRegister::RoundingF16F64(rounding_mode) => todo!(), + ModeRegister::RoundingBoth { f32, f16f64 } => todo!(), + }; + let hwreg_llvm = unsafe { LLVMConstInt(llvm_i32, hwreg as _, 0) }; + let value_llvm = unsafe { LLVMConstInt(llvm_i32, value as _, 0) }; + self.emit_intrinsic( + intrinsic, + None, + None, + vec![(hwreg_llvm, llvm_i32), (value_llvm, llvm_i32)], + )?; + Ok(()) + } + /* // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` // Should be available in LLVM 19 diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs index 07806f9..c87dd92 100644 --- a/ptx/src/pass/expand_operands.rs +++ b/ptx/src/pass/expand_operands.rs @@ -41,14 +41,18 @@ fn run_method<'input>( }) .transpose()?; Ok(Function2 { + body, return_arguments: method.return_arguments, name: method.name, input_arguments: method.input_arguments, - body, import_as: method.import_as, tuning: method.tuning, linkage: method.linkage, is_kernel: method.is_kernel, + flush_to_zero_f32: method.flush_to_zero_f32, + flush_to_zero_f16f64: method.flush_to_zero_f16f64, + roundind_mode_f32: method.roundind_mode_f32, + roundind_mode_f16f64: method.roundind_mode_f16f64, }) } diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs index 3323305..ad484fd 100644 --- a/ptx/src/pass/fix_special_registers2.rs +++ b/ptx/src/pass/fix_special_registers2.rs @@ -20,6 +20,10 @@ pub(super) fn run<'a, 'input>( tuning: Vec::new(), linkage: ast::LinkingDirective::EXTERN, is_kernel: false, + flush_to_zero_f32: false, + flush_to_zero_f16f64: false, + roundind_mode_f32: ptx_parser::RoundingMode::NearestEven, + roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven, })); sreg_to_function.insert(sreg, name); }, @@ -60,16 +64,7 @@ fn run_method<'a, 'input>( Ok::<_, TranslateError>(result) }) .transpose()?; - Ok(Function2 { - return_arguments: method.return_arguments, - name: method.name, - input_arguments: method.input_arguments, - body, - import_as: method.import_as, - tuning: method.tuning, - linkage: method.linkage, - is_kernel: method.is_kernel, - }) + Ok(Function2 { body, ..method }) } fn run_statement<'a, 'input>( diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index 014c49b..935e78d 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -64,16 +64,7 @@ fn run_method<'a, 'input>( Ok::<_, TranslateError>(result) }) .transpose()?; - Ok(Function2 { - return_arguments: method.return_arguments, - name: method.name, - input_arguments: method.input_arguments, - body, - import_as: method.import_as, - tuning: method.tuning, - linkage: method.linkage, - is_kernel: method.is_kernel, - }) + Ok(Function2 { body, ..method }) } fn run_statement<'a, 'input>( diff --git a/ptx/src/pass/insert_ftz_control.rs b/ptx/src/pass/insert_ftz_control.rs index a048f41..f21a804 100644 --- a/ptx/src/pass/insert_ftz_control.rs +++ b/ptx/src/pass/insert_ftz_control.rs @@ -1,6 +1,7 @@ use super::BrachCondition; use super::Directive2; use super::Function2; +use super::ModeRegister; use super::SpirvWord; use super::Statement; use super::TranslateError; @@ -18,6 +19,7 @@ use rustc_hash::FxHashSet; use std::hash::Hash; use std::iter; use std::mem; +use std::u32; use strum::EnumCount; use strum_macros::{EnumCount, VariantArray}; @@ -36,6 +38,13 @@ impl DenormalMode { DenormalMode::Preserve } } + + fn to_ftz(self) -> bool { + match self { + DenormalMode::FlushToZero => true, + DenormalMode::Preserve => false, + } + } } impl Into for DenormalMode { @@ -94,20 +103,19 @@ impl InstructionModes { _ => {} } } - fn set_if_some(source: &mut Option, value: Option) { - match (source, value) { - (Some(ref mut x), Some(y)) => *x = y, - _ => {} + fn set_if_any(source: &mut Option, value: Option) { + if let Some(x) = value { + *source = Some(x); } } set_if_none(&mut entry.denormal_f32, self.denormal_f32); set_if_none(&mut entry.denormal_f16f64, self.denormal_f16f64); set_if_none(&mut entry.rounding_f32, self.rounding_f32); set_if_none(&mut entry.rounding_f16f64, self.rounding_f16f64); - set_if_some(&mut exit.denormal_f32, self.denormal_f32); - set_if_some(&mut exit.denormal_f16f64, self.denormal_f16f64); - set_if_some(&mut exit.rounding_f32, self.rounding_f32); - set_if_some(&mut exit.rounding_f16f64, self.rounding_f16f64); + set_if_any(&mut exit.denormal_f32, self.denormal_f32); + set_if_any(&mut exit.denormal_f16f64, self.denormal_f16f64); + set_if_any(&mut exit.rounding_f32, self.rounding_f32); + set_if_any(&mut exit.rounding_f16f64, self.rounding_f16f64); } fn none() -> Self { @@ -209,18 +217,12 @@ impl InstructionModes { flush_to_zero.map(DenormalMode::from_ftz), Some(RoundingMode::from_ast(rounding)), ), - ast::CvtMode::SignedFromFP { - flush_to_zero, - rounding, + // float to int contains rounding field, but it's not a rounding + // mode but rather round-to-int operation that will be applied + ast::CvtMode::SignedFromFP { flush_to_zero, .. } + | ast::CvtMode::UnsignedFromFP { flush_to_zero, .. } => { + Self::new(cvt.from, flush_to_zero.map(DenormalMode::from_ftz), None) } - | ast::CvtMode::UnsignedFromFP { - flush_to_zero, - rounding, - } => Self::new( - cvt.from, - flush_to_zero.map(DenormalMode::from_ftz), - Some(RoundingMode::from_ast(rounding)), - ), ast::CvtMode::FPFromSigned(rnd) | ast::CvtMode::FPFromUnsigned(rnd) => { Self::new(cvt.to, None, Some(RoundingMode::from_ast(rnd))) } @@ -263,22 +265,15 @@ impl ControlFlowGraph { } fn set_modes(&mut self, node: NodeIndex, entry: InstructionModes, exit: InstructionModes) { - self.graph[node].denormal_f32 = Mode { - entry: entry.denormal_f32.map(ExtendedMode::BasicBlock), - exit: exit.denormal_f32.map(ExtendedMode::BasicBlock), - }; - self.graph[node].denormal_f16f64 = Mode { - entry: entry.denormal_f16f64.map(ExtendedMode::BasicBlock), - exit: exit.denormal_f16f64.map(ExtendedMode::BasicBlock), - }; - self.graph[node].rounding_f32 = Mode { - entry: entry.rounding_f32.map(ExtendedMode::BasicBlock), - exit: exit.rounding_f32.map(ExtendedMode::BasicBlock), - }; - self.graph[node].rounding_f16f64 = Mode { - entry: entry.rounding_f16f64.map(ExtendedMode::BasicBlock), - exit: exit.rounding_f16f64.map(ExtendedMode::BasicBlock), - }; + let node = &mut self.graph[node]; + node.denormal_f32.entry = entry.denormal_f32.map(ExtendedMode::BasicBlock); + node.denormal_f16f64.entry = entry.denormal_f16f64.map(ExtendedMode::BasicBlock); + node.rounding_f32.entry = entry.rounding_f32.map(ExtendedMode::BasicBlock); + node.rounding_f16f64.entry = entry.rounding_f16f64.map(ExtendedMode::BasicBlock); + node.denormal_f32.exit = exit.denormal_f32.map(ExtendedMode::BasicBlock); + node.denormal_f16f64.exit = exit.denormal_f16f64.map(ExtendedMode::BasicBlock); + node.rounding_f32.exit = exit.rounding_f32.map(ExtendedMode::BasicBlock); + node.rounding_f16f64.exit = exit.rounding_f16f64.map(ExtendedMode::BasicBlock); } } @@ -343,7 +338,7 @@ trait EnumTuple { pub(crate) fn run<'input>( flat_resolver: &mut super::GlobalStringIdentResolver2<'input>, - directives: Vec, super::SpirvWord>>, + mut directives: Vec, super::SpirvWord>>, ) -> Result, SpirvWord>>, TranslateError> { let mut cfg = ControlFlowGraph::new(); for directive in directives.iter() { @@ -351,42 +346,39 @@ pub(crate) fn run<'input>( super::Directive2::Method(Function2 { name, body: Some(body), + is_kernel, .. }) => { - let mut basic_block = Some(cfg.add_entry_basic_block(*name)); - let mut entry = InstructionModes::none(); - let mut exit = InstructionModes::none(); + // TODO: implement for non-kernels + if !*is_kernel { + todo!() + } + let entry_index = cfg.add_entry_basic_block(*name); + let mut bb_state = BasicBlockState::new(&mut cfg); + let mut body_iter = body.iter(); + match body_iter.next() { + Some(Statement::Label(label)) => { + bb_state.cfg.add_jump(entry_index, *label); + bb_state.start(*label); + } + _ => return Err(error_unreachable()), + }; for statement in body.iter() { match statement { Statement::Instruction(ast::Instruction::Bra { arguments }) => { - let bb_index = basic_block.ok_or_else(error_unreachable)?; - cfg.add_jump(bb_index, arguments.src); - cfg.set_modes( - bb_index, - mem::replace(&mut entry, InstructionModes::none()), - mem::replace(&mut exit, InstructionModes::none()), - ); - basic_block = None; + bb_state.end(&[arguments.src]); } Statement::Label(label) => { - basic_block = Some(cfg.get_or_add_basic_block(*label)); + bb_state.start(*label); } Statement::Conditional(BrachCondition { if_true, if_false, .. }) => { - let bb_index = basic_block.ok_or_else(error_unreachable)?; - cfg.add_jump(bb_index, *if_true); - cfg.add_jump(bb_index, *if_false); - cfg.set_modes( - bb_index, - mem::replace(&mut entry, InstructionModes::none()), - mem::replace(&mut exit, InstructionModes::none()), - ); - basic_block = None; + bb_state.end(&[*if_true, *if_false]); } Statement::Instruction(instruction) => { let modes = get_modes(instruction); - modes.fold_into(&mut entry, &mut exit); + bb_state.append(modes); } _ => {} } @@ -395,7 +387,370 @@ pub(crate) fn run<'input>( _ => {} } } - todo!() + let denormal_f32 = compute_single_mode(&cfg, |node| node.denormal_f32); + let denormal_f16f64 = compute_single_mode(&cfg, |node| node.denormal_f16f64); + let rounding_f32 = compute_single_mode(&cfg, |node| node.rounding_f32); + let rounding_f16f64 = compute_single_mode(&cfg, |node| node.rounding_f16f64); + let denormal_f32 = optimize::(denormal_f32); + let denormal_f16f64 = optimize::(denormal_f16f64); + let rounding_f32 = optimize::(rounding_f32); + let rounding_f16f64 = optimize::(rounding_f16f64); + insert_mode_control( + flat_resolver, + &mut directives, + &cfg, + denormal_f32, + denormal_f16f64, + rounding_f32, + rounding_f16f64, + )?; + Ok(directives) +} + +fn insert_mode_control<'input>( + flat_resolver: &mut super::GlobalStringIdentResolver2<'input>, + directives: &mut [Directive2, SpirvWord>], + cfg: &ControlFlowGraph, + denormal_f32: ModeInsertions, + denormal_f16f64: ModeInsertions, + rounding_f32: ModeInsertions, + rounding_f16f64: ModeInsertions, +) -> Result<(), TranslateError> { + for directive in directives.iter_mut() { + let body_ptr = match directive { + Directive2::Variable(..) | Directive2::Method(Function2 { body: None, .. }) => continue, + Directive2::Method(Function2 { + name, + body: Some(body), + flush_to_zero_f32, + flush_to_zero_f16f64, + roundind_mode_f32: rounding_mode_f32, + roundind_mode_f16f64: rounding_mode_f16f64, + .. + }) => { + *flush_to_zero_f32 = denormal_f32 + .kernels + .get(name) + .copied() + .unwrap_or(DenormalMode::default()) + .to_ftz(); + *flush_to_zero_f16f64 = denormal_f16f64 + .kernels + .get(name) + .copied() + .unwrap_or(DenormalMode::default()) + .to_ftz(); + *rounding_mode_f32 = rounding_f32 + .kernels + .get(name) + .copied() + .unwrap_or(RoundingMode::default()) + .to_ast(); + *rounding_mode_f16f64 = rounding_f16f64 + .kernels + .get(name) + .copied() + .unwrap_or(RoundingMode::default()) + .to_ast(); + body + } + }; + let mut old_body = mem::replace(body_ptr, Vec::new()); + let mut result = Vec::with_capacity(old_body.len()); + let mut bb_state = BasicBlockControlState::new( + &denormal_f32, + &denormal_f16f64, + &rounding_f32, + &rounding_f16f64, + ); + for statement in old_body.into_iter() { + match &statement { + Statement::Label(label) => { + bb_state.start(*label); + } + Statement::Instruction(instruction) => { + let modes = get_modes(&instruction); + bb_state.insert(&mut result, modes)?; + } + _ => {} + } + result.push(statement); + } + *body_ptr = result; + } + Ok(()) +} + +struct BasicBlockControlState<'a> { + global_denormal_f32: &'a ModeInsertions, + global_denormal_f16f64: &'a ModeInsertions, + global_rounding_f32: &'a ModeInsertions, + global_rounding_f16f64: &'a ModeInsertions, + basic_block: SpirvWord, + denormal_f32: RegisterState, + denormal_f16f64: RegisterState, + foldable_rounding_f32: Option, + foldable_rounding_f16f64: Option, +} + +#[derive(Clone, Copy)] +enum RegisterState { + Inherited, + Unknown, + Value(Option, T), +} + +impl RegisterState { + fn empty() -> Self { + Self::Unknown + } + + fn new(must_insert: bool) -> Self { + if must_insert { + Self::Unknown + } else { + Self::Inherited + } + } +} + +impl<'a> BasicBlockControlState<'a> { + fn new( + global_denormal_f32: &'a ModeInsertions, + global_denormal_f16f64: &'a ModeInsertions, + global_rounding_f32: &'a ModeInsertions, + global_rounding_f16f64: &'a ModeInsertions, + ) -> Self { + BasicBlockControlState { + global_denormal_f32, + global_denormal_f16f64, + global_rounding_f32, + global_rounding_f16f64, + basic_block: SpirvWord(u32::MAX), + denormal_f32: RegisterState::empty(), + denormal_f16f64: RegisterState::empty(), + foldable_rounding_f32: None, + foldable_rounding_f16f64: None, + } + } + + fn start(&mut self, label: SpirvWord) { + self.denormal_f32 = + RegisterState::new(self.global_denormal_f32.basic_blocks.contains(&label)); + self.denormal_f32 = + RegisterState::new(self.global_denormal_f16f64.basic_blocks.contains(&label)); + self.foldable_rounding_f32 = None; + self.foldable_rounding_f16f64 = None; + self.basic_block = label; + } + + fn add_or_fold_mode_set( + &mut self, + result: &mut Vec, SpirvWord>>, + new_mode: bool, + ) -> Option { + // try and fold into the other mode set + if let RegisterState::Value(Some(other_index), other_value) = self.denormal_f16f64 { + if let Some(Statement::SetMode(ModeRegister::DenormalF16F64(_))) = + result.get_mut(other_index) + { + result[other_index] = Statement::SetMode(ModeRegister::DenormalBoth { + f32: new_mode, + f16f64: other_value, + }); + return None; + } + } + result.push(Statement::SetMode(ModeRegister::DenormalF32(new_mode))); + Some(result.len() - 1) + } + + fn insert( + &mut self, + result: &mut Vec, SpirvWord>>, + modes: InstructionModes, + ) -> Result<(), TranslateError> { + self.insert_one::(result, modes.denormal_f32.map(DenormalMode::to_ftz))?; + self.insert_one::( + result, + modes.denormal_f16f64.map(DenormalMode::to_ftz), + )?; + Ok(()) + } + + fn insert_one( + &mut self, + result: &mut Vec, SpirvWord>>, + mode: Option, + ) -> Result<(), TranslateError> { + if let Some(new_mode) = mode { + let register_state = View::get_register(self); + match register_state { + RegisterState::Inherited => { + View::set_register(self, RegisterState::Value(None, new_mode)); + } + RegisterState::Unknown => { + View::set_register( + self, + RegisterState::Value( + Some(self.add_or_fold_mode_set2::(result, new_mode)), + new_mode, + ), + ); + } + RegisterState::Value(_, old_value) => { + if new_mode == old_value { + return Ok(()); + } + View::set_register( + self, + RegisterState::Value( + Some(self.add_or_fold_mode_set2::(result, new_mode)), + new_mode, + ), + ); + } + } + } + Ok(()) + } + + // Return the index of the last insertion of SetMode with this mode + fn add_or_fold_mode_set2( + &self, + result: &mut Vec, SpirvWord>>, + new_mode: View::Value, + ) -> usize { + // try and fold into the other mode set in struction + if let RegisterState::Value(Some(twin_index), _) = View::TwinView::get_register(self) { + if let Some(Statement::SetMode(register_mode)) = result.get_mut(twin_index) { + if let Some(twin_mode) = View::TwinView::get_single_mode(register_mode) { + *register_mode = View::new_mode(new_mode, Some(twin_mode)); + return twin_index; + } + } + } + result.push(Statement::SetMode(View::new_mode(new_mode, None))); + result.len() - 1 + } +} + +trait ModeView { + type Value: PartialEq + Eq + Copy + Clone; + type TwinView: ModeView; + + fn get_register(bb: &BasicBlockControlState) -> RegisterState; + fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState); + fn new_mode(t: Self::Value, other: Option) -> ModeRegister; + fn get_single_mode(reg: &ModeRegister) -> Option; +} + +struct DenormalF32View; + +impl ModeView for DenormalF32View { + type Value = bool; + type TwinView = DenormalF16F64View; + + fn get_register(bb: &BasicBlockControlState) -> RegisterState { + bb.denormal_f32 + } + + fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState) { + bb.denormal_f32 = reg; + } + + fn new_mode(f32: Self::Value, f16f64: Option) -> ModeRegister { + match f16f64 { + Some(f16f64) => ModeRegister::DenormalBoth { f32, f16f64 }, + None => ModeRegister::DenormalF32(f32), + } + } + + fn get_single_mode(reg: &ModeRegister) -> Option { + match reg { + ModeRegister::DenormalF32(value) => Some(*value), + _ => None, + } + } +} + +struct DenormalF16F64View; + +impl ModeView for DenormalF16F64View { + type Value = bool; + type TwinView = DenormalF32View; + + fn get_register(bb: &BasicBlockControlState) -> RegisterState { + bb.denormal_f16f64 + } + + fn set_register(bb: &mut BasicBlockControlState, reg: RegisterState) { + bb.denormal_f16f64 = reg; + } + + fn new_mode(f16f64: Self::Value, f32: Option) -> ModeRegister { + match f32 { + Some(f32) => ModeRegister::DenormalBoth { f16f64, f32 }, + None => ModeRegister::DenormalF16F64(f16f64), + } + } + + fn get_single_mode(reg: &ModeRegister) -> Option { + match reg { + ModeRegister::DenormalF16F64(value) => Some(*value), + _ => None, + } + } +} + +struct BasicBlockState<'a> { + cfg: &'a mut ControlFlowGraph, + node_index: Option, + // If it's a kernel basic block then we don't track entry instruction mode + entry: InstructionModes, + exit: InstructionModes, +} + +impl<'a> BasicBlockState<'a> { + fn new(cfg: &'a mut ControlFlowGraph) -> BasicBlockState<'a> { + Self { + cfg, + node_index: None, + entry: InstructionModes::none(), + exit: InstructionModes::none(), + } + } + + fn start(&mut self, label: SpirvWord) { + self.end(&[]); + self.node_index = Some(self.cfg.get_or_add_basic_block(label)); + } + + fn end(&mut self, jumps: &[SpirvWord]) { + let node_index = self.node_index.take(); + let node_index = match node_index { + Some(x) => x, + None => return, + }; + for target in jumps { + self.cfg.add_jump(node_index, *target); + } + self.cfg.set_modes( + node_index, + mem::replace(&mut self.entry, InstructionModes::none()), + mem::replace(&mut self.exit, InstructionModes::none()), + ); + } + + fn append(&mut self, modes: InstructionModes) { + modes.fold_into(&mut self.entry, &mut self.exit); + } +} + +impl<'a> Drop for BasicBlockState<'a> { + fn drop(&mut self) { + self.end(&[]); + } } fn compute_single_mode( @@ -424,10 +779,9 @@ fn compute_single_mode( UniqueVec::new(graph.graph.neighbors_directed(index, Direction::Incoming)); let mut visited = FxHashSet::default(); while let Some(current) = to_visit.pop() { - if visited.contains(¤t) { + if !visited.insert(current) { continue; } - visited.insert(current); let exit_mode = getter(graph.graph.node_weight(current).unwrap()).exit; match exit_mode { None => { @@ -462,6 +816,7 @@ fn compute_single_mode( } } +#[derive(Debug)] struct PartialModeInsertion { bb_must_insert_mode: FxHashSet, bb_maybe_insert_mode: FxHashMap)>, @@ -498,10 +853,11 @@ fn optimize + strum::VariantArray + std::fmt::Debug, const } } let mut kernels = FxHashMap::default(); - for (kernel, modes) in kernel_modes { + 'iterate_kernels: for (kernel, modes) in kernel_modes { for (mode, var) in modes.into_iter().enumerate() { if solution[var] > 0.5 { kernels.insert(kernel, T::VARIANTS[mode]); + continue 'iterate_kernels; } } } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 442b1e7..9eda5f3 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -19,6 +19,7 @@ mod hoist_globals; mod insert_explicit_load_store; mod insert_ftz_control; mod insert_implicit_conversions2; +mod normalize_basic_blocks; mod normalize_identifiers2; mod normalize_predicates2; mod replace_instructions_with_function_calls; @@ -52,6 +53,7 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result { FunctionPointer(FunctionPointerDetails), VectorRead(VectorRead), VectorWrite(VectorWrite), + SetMode(ModeRegister), +} + +enum ModeRegister { + DenormalF32(bool), + DenormalF16F64(bool), + DenormalBoth { + f32: bool, + f16f64: bool, + }, + RoundingF32(ast::RoundingMode), + RoundingF16F64(ast::RoundingMode), + RoundingBoth { + f32: ast::RoundingMode, + f16f64: ast::RoundingMode, + }, } impl> Statement, T> { @@ -469,6 +487,7 @@ impl> Statement, T> { let src = visitor.visit_ident(src, None, false, false)?; Statement::FunctionPointer(FunctionPointerDetails { dst, src }) } + Statement::SetMode(mode_register) => Statement::SetMode(mode_register), }) } } @@ -573,6 +592,10 @@ struct Function2 { import_as: Option, tuning: Vec, linkage: ast::LinkingDirective, + flush_to_zero_f32: bool, + flush_to_zero_f16f64: bool, + roundind_mode_f32: ast::RoundingMode, + roundind_mode_f16f64: ast::RoundingMode, } type NormalizedDirective2 = Directive2< diff --git a/ptx/src/pass/normalize_basic_blocks.rs b/ptx/src/pass/normalize_basic_blocks.rs new file mode 100644 index 0000000..c87a8ad --- /dev/null +++ b/ptx/src/pass/normalize_basic_blocks.rs @@ -0,0 +1,52 @@ +use super::*; + +// This pass normalized ptx modules in two ways that makes mode computation pass +// and code emissions passes much simpler: +// * Inserts label at the start of every function +// This makes control flow graph simpler in mode computation block: we can +// represent kernels as separate nodes with its own separate entry/exit mode +// * Inserts label at the start of every basic block + +pub(crate) fn run( + flat_resolver: &mut GlobalStringIdentResolver2<'_>, + mut directives: Vec, SpirvWord>>, +) -> Vec, SpirvWord>> { + for directive in directives.iter_mut() { + let body_ref = match directive { + Directive2::Method(Function2 { + body: Some(body), .. + }) => body, + _ => continue, + }; + let body = std::mem::replace(body_ref, Vec::new()); + let mut result = Vec::with_capacity(body.len()); + let mut needs_label = false; + let mut body_iterator = body.into_iter(); + match body_iterator.next() { + Some(Statement::Label(_)) => {} + Some(statement) => { + result.push(Statement::Label(flat_resolver.register_unnamed(None))); + result.push(statement); + } + None => {} + } + for statement in body_iterator { + if needs_label && !matches!(statement, Statement::Label(..)) { + result.push(Statement::Label(flat_resolver.register_unnamed(None))); + } + needs_label = is_block_terminator(&statement); + result.push(statement); + } + *body_ref = result; + } + directives +} + +fn is_block_terminator(instruction: &Statement, SpirvWord>) -> bool { + match instruction { + Statement::Conditional(..) + | Statement::Instruction(ast::Instruction::Bra { .. }) + | Statement::Instruction(ast::Instruction::Ret { .. }) => true, + _ => false, + } +} diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs index 4d94897..f5ef55c 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -52,9 +52,13 @@ fn run_method<'input, 'b>( input_arguments, body, import_as: None, - tuning: method.tuning, linkage, is_kernel, + tuning: method.tuning, + flush_to_zero_f32: false, + flush_to_zero_f16f64: false, + roundind_mode_f32: ptx_parser::RoundingMode::NearestEven, + roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven, }) } diff --git a/ptx/src/pass/normalize_predicates2.rs b/ptx/src/pass/normalize_predicates2.rs index f8505b6..f8be688 100644 --- a/ptx/src/pass/normalize_predicates2.rs +++ b/ptx/src/pass/normalize_predicates2.rs @@ -36,14 +36,18 @@ fn run_method<'input>( }) .transpose()?; Ok(Function2 { + body, return_arguments: method.return_arguments, name: method.name, input_arguments: method.input_arguments, - body, import_as: method.import_as, tuning: method.tuning, linkage: method.linkage, is_kernel: method.is_kernel, + flush_to_zero_f32: method.flush_to_zero_f32, + flush_to_zero_f16f64: method.flush_to_zero_f16f64, + roundind_mode_f32: method.roundind_mode_f32, + roundind_mode_f16f64: method.roundind_mode_f16f64, }) } diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs index 089d276..f54c134 100644 --- a/ptx/src/pass/replace_instructions_with_function_calls.rs +++ b/ptx/src/pass/replace_instructions_with_function_calls.rs @@ -21,6 +21,10 @@ pub(super) fn run<'input>( tuning: Vec::new(), linkage: ast::LinkingDirective::EXTERN, is_kernel: false, + flush_to_zero_f32: false, + flush_to_zero_f16f64: false, + roundind_mode_f32: ptx_parser::RoundingMode::NearestEven, + roundind_mode_f16f64: ptx_parser::RoundingMode::NearestEven, }) }) .collect::>(); diff --git a/ptx/src/pass/resolve_function_pointers.rs b/ptx/src/pass/resolve_function_pointers.rs index 1721afd..81b9f0a 100644 --- a/ptx/src/pass/resolve_function_pointers.rs +++ b/ptx/src/pass/resolve_function_pointers.rs @@ -40,16 +40,7 @@ fn run_method<'input>( .collect::, _>>() }) .transpose()?; - Ok(Function2 { - return_arguments: method.return_arguments, - name: method.name, - input_arguments: method.input_arguments, - body, - import_as: method.import_as, - tuning: method.tuning, - linkage: method.linkage, - is_kernel: method.is_kernel, - }) + Ok(Function2 { body, ..method }) } fn run_statement<'input>( diff --git a/ptx/src/test/spirv_run/malformed_label.ptx b/ptx/src/test/spirv_run/malformed_label.ptx new file mode 100644 index 0000000..cb41a7c --- /dev/null +++ b/ptx/src/test/spirv_run/malformed_label.ptx @@ -0,0 +1,27 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry malformed_label( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + bra BB0; +// this basic block does not start with a label + ld.u64 temp, [out_addr]; + +BB0: + ld.u64 temp, [in_addr]; + add.u64 temp2, temp, 1; + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 99573f6..e1c1670 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -186,6 +186,8 @@ test_ptx!( [0x800000u32, 0xFFFFFF] ); +test_ptx!(malformed_label, [2u64], [3u64]); + test_ptx!(assertfail); test_ptx!(func_ptr); test_ptx!(lanemask_lt);