diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index f7cdcc3..b045a83 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -518,13 +518,13 @@ pub struct MadFloatDesc {} #[derive(Copy, Clone)] pub struct AbsDetails { - pub flush_to_zero: bool, + pub flush_to_zero: Option, pub typ: ScalarType, } #[derive(Copy, Clone)] pub struct RcpDetails { pub rounding: Option, - pub flush_to_zero: bool, + pub flush_to_zero: Option, pub is_f64: bool, } @@ -769,7 +769,7 @@ pub struct AddIntDesc { pub struct SetpData { pub typ: ScalarType, - pub flush_to_zero: bool, + pub flush_to_zero: Option, pub cmp_op: SetpCompareOp, } @@ -799,7 +799,7 @@ pub enum SetpBoolPostOp { pub struct SetpBoolData { pub typ: ScalarType, - pub flush_to_zero: bool, + pub flush_to_zero: Option, pub cmp_op: SetpCompareOp, pub bool_op: SetpBoolPostOp, } @@ -831,7 +831,7 @@ pub struct CvtIntToIntDesc { pub struct CvtDesc { pub rounding: Option, - pub flush_to_zero: bool, + pub flush_to_zero: Option, pub saturate: bool, pub dst: Dst, pub src: Src, @@ -873,7 +873,7 @@ impl CvtDetails { dst, src, saturate, - flush_to_zero, + flush_to_zero: Some(flush_to_zero), rounding: Some(rounding), }) } @@ -893,7 +893,7 @@ impl CvtDetails { dst, src, saturate, - flush_to_zero, + flush_to_zero: Some(flush_to_zero), rounding: Some(rounding), }) } @@ -1009,7 +1009,7 @@ pub struct ArithSInt { pub struct ArithFloat { pub typ: FloatType, pub rounding: Option, - pub flush_to_zero: bool, + pub flush_to_zero: Option, pub saturate: bool, } @@ -1022,7 +1022,7 @@ pub enum MinMaxDetails { #[derive(Copy, Clone)] pub struct MinMaxFloat { - pub ftz: bool, + pub flush_to_zero: Option, pub nan: bool, pub typ: FloatType, } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index a132705..163a233 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -740,17 +740,29 @@ InstSetp: ast::Instruction> = { }; SetpMode: ast::SetpData = { - => ast::SetpData{ + => ast::SetpData { typ: t, - flush_to_zero: ftz.is_some(), + flush_to_zero: None, + cmp_op: cmp_op, + }, + ".f32" => ast::SetpData { + typ: ast::ScalarType::F32, + flush_to_zero: Some(ftz.is_some()), cmp_op: cmp_op, } + }; SetpBoolMode: ast::SetpBoolData = { - => ast::SetpBoolData{ + => ast::SetpBoolData { typ: t, - flush_to_zero: ftz.is_some(), + flush_to_zero: None, + cmp_op: cmp_op, + bool_op: bool_op, + }, + ".f32" => ast::SetpBoolData { + typ: ast::ScalarType::F32, + flush_to_zero: Some(ftz.is_some()), cmp_op: cmp_op, bool_op: bool_op, } @@ -783,7 +795,7 @@ SetpBoolPostOp: ast::SetpBoolPostOp = { ".xor" => ast::SetpBoolPostOp::Xor, }; -SetpType: ast::ScalarType = { +SetpTypeNoF32: ast::ScalarType = { ".b16" => ast::ScalarType::B16, ".b32" => ast::ScalarType::B32, ".b64" => ast::ScalarType::B64, @@ -793,7 +805,6 @@ SetpType: ast::ScalarType = { ".s16" => ast::ScalarType::S16, ".s32" => ast::ScalarType::S32, ".s64" => ast::ScalarType::S64, - ".f32" => ast::ScalarType::F32, ".f64" => ast::ScalarType::F64, }; @@ -857,7 +868,7 @@ InstCvt: ast::Instruction> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: r, - flush_to_zero: false, + flush_to_zero: None, saturate: s.is_some(), dst: ast::FloatType::F16, src: ast::FloatType::F16 @@ -868,7 +879,7 @@ InstCvt: ast::Instruction> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: None, - flush_to_zero: f.is_some(), + flush_to_zero: Some(f.is_some()), saturate: s.is_some(), dst: ast::FloatType::F32, src: ast::FloatType::F16 @@ -879,7 +890,7 @@ InstCvt: ast::Instruction> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: None, - flush_to_zero: false, + flush_to_zero: None, saturate: s.is_some(), dst: ast::FloatType::F64, src: ast::FloatType::F16 @@ -890,7 +901,7 @@ InstCvt: ast::Instruction> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: Some(r), - flush_to_zero: f.is_some(), + flush_to_zero: Some(f.is_some()), saturate: s.is_some(), dst: ast::FloatType::F16, src: ast::FloatType::F32 @@ -901,7 +912,7 @@ InstCvt: ast::Instruction> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: r, - flush_to_zero: f.is_some(), + flush_to_zero: Some(f.is_some()), saturate: s.is_some(), dst: ast::FloatType::F32, src: ast::FloatType::F32 @@ -912,7 +923,7 @@ InstCvt: ast::Instruction> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: None, - flush_to_zero: false, + flush_to_zero: None, saturate: s.is_some(), dst: ast::FloatType::F64, src: ast::FloatType::F32 @@ -923,7 +934,7 @@ InstCvt: ast::Instruction> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: Some(r), - flush_to_zero: false, + flush_to_zero: None, saturate: s.is_some(), dst: ast::FloatType::F16, src: ast::FloatType::F64 @@ -934,7 +945,7 @@ InstCvt: ast::Instruction> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: Some(r), - flush_to_zero: s.is_some(), + flush_to_zero: Some(s.is_some()), saturate: s.is_some(), dst: ast::FloatType::F32, src: ast::FloatType::F64 @@ -945,7 +956,7 @@ InstCvt: ast::Instruction> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: r, - flush_to_zero: false, + flush_to_zero: None, saturate: s.is_some(), dst: ast::FloatType::F64, src: ast::FloatType::F64 @@ -1082,19 +1093,19 @@ InstCall: ast::Instruction> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs InstAbs: ast::Instruction> = { "abs" => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: false, typ: t }, a) + ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: None, typ: t }, a) }, "abs" ".f32" => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: f.is_some(), typ: ast::ScalarType::F32 }, a) + ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: Some(f.is_some()), typ: ast::ScalarType::F32 }, a) }, "abs" ".f64" => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: false, typ: ast::ScalarType::F64 }, a) + ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: None, typ: ast::ScalarType::F64 }, a) }, "abs" ".f16" => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: f.is_some(), typ: ast::ScalarType::F16 }, a) + ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: Some(f.is_some()), typ: ast::ScalarType::F16 }, a) }, "abs" ".f16x2" => { - todo!() + ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: Some(f.is_some()), typ: ast::ScalarType::F16x2 }, a) }, }; @@ -1128,7 +1139,7 @@ InstRcp: ast::Instruction> = { "rcp" ".f32" => { let details = ast::RcpDetails { rounding, - flush_to_zero: ftz.is_some(), + flush_to_zero: Some(ftz.is_some()), is_f64: false, }; ast::Instruction::Rcp(details, a) @@ -1136,7 +1147,7 @@ InstRcp: ast::Instruction> = { "rcp" ".f64" => { let details = ast::RcpDetails { rounding: Some(rn), - flush_to_zero: false, + flush_to_zero: None, is_f64: true, }; ast::Instruction::Rcp(details, a) @@ -1173,16 +1184,16 @@ MinMaxDetails: ast::MinMaxDetails = { => ast::MinMaxDetails::Unsigned(t), => ast::MinMaxDetails::Signed(t), ".f32" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F32 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F32 } ), ".f64" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ ftz: false, nan: false, typ: ast::FloatType::F64 } + ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::FloatType::F64 } ), ".f16" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16 } ), ".f16x2" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16x2 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16x2 } ) } @@ -1203,25 +1214,25 @@ ArithFloat: ast::ArithFloat = { ".f32" => ast::ArithFloat { typ: ast::FloatType::F32, rounding: rn, - flush_to_zero: ftz.is_some(), + flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".f64" => ast::ArithFloat { typ: ast::FloatType::F64, rounding: rn, - flush_to_zero: false, + flush_to_zero: None, saturate: false, }, ".f16" => ast::ArithFloat { typ: ast::FloatType::F16, rounding: rn.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: ftz.is_some(), + flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".f16x2" => ast::ArithFloat { typ: ast::FloatType::F16x2, rounding: rn.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: ftz.is_some(), + flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index b4ae149..1b27ecc 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -81,6 +81,10 @@ test_ptx!(global_array, [0xDEADu32], [1u32]); test_ptx!(extern_shared, [127u64], [127u64]); test_ptx!(extern_shared_call, [121u64], [123u64]); test_ptx!(rcp, [2f32], [0.5f32]); +// 0b1_00000000_10000000000000000000000u32 is a large denormal +// 0x3f000000 is 0.5 +test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]); +test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0b1_00000000_01000000000000000000000u32]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/mul_ftz.ptx b/ptx/src/test/spirv_run/mul_ftz.ptx new file mode 100644 index 0000000..eb24215 --- /dev/null +++ b/ptx/src/test/spirv_run/mul_ftz.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry mul_ftz( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 temp1; + .reg .f32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 temp1, [in_addr]; + ld.f32 temp2, [in_addr+4]; + mul.ftz.f32 temp1, temp1, temp2; + st.f32 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/mul_ftz.spvtxt b/ptx/src/test/spirv_run/mul_ftz.spvtxt new file mode 100644 index 0000000..e114374 --- /dev/null +++ b/ptx/src/test/spirv_run/mul_ftz.spvtxt @@ -0,0 +1,46 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %25 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "mul_lo" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %28 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_2 = OpConstant %ulong 2 + %1 = OpFunction %void None %28 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %23 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %21 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpIMul %ulong %17 %ulong_2 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %22 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mul_non_ftz.ptx b/ptx/src/test/spirv_run/mul_non_ftz.ptx new file mode 100644 index 0000000..31cd14c --- /dev/null +++ b/ptx/src/test/spirv_run/mul_non_ftz.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry mul_non_ftz( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 temp1; + .reg .f32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 temp1, [in_addr]; + ld.f32 temp2, [in_addr+4]; + mul.f32 temp1, temp1, temp2; + st.f32 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/mul_non_ftz.spvtxt b/ptx/src/test/spirv_run/mul_non_ftz.spvtxt new file mode 100644 index 0000000..78153aa --- /dev/null +++ b/ptx/src/test/spirv_run/mul_non_ftz.spvtxt @@ -0,0 +1,61 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + OpCapability DenormFlushToZero + OpCapability DenormPreserve + OpExtension "SPV_KHR_float_controls" + %30 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "mul_non_ftz" + OpExecutionMode %1 DenormPreserve 32 + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %33 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Generic_float = OpTypePointer Generic %float + %ulong_4 = OpConstant %ulong 4 + %1 = OpFunction %void None %33 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %28 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_float Function + %7 = OpVariable %_ptr_Function_float Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %25 = OpConvertUToPtr %_ptr_Generic_float %15 + %14 = OpLoad %float %25 + OpStore %6 %14 + %17 = OpLoad %ulong %4 + %24 = OpIAdd %ulong %17 %ulong_4 + %26 = OpConvertUToPtr %_ptr_Generic_float %24 + %16 = OpLoad %float %26 + OpStore %7 %16 + %19 = OpLoad %float %6 + %20 = OpLoad %float %7 + %18 = OpFMul %float %19 %20 + OpStore %6 %18 + %21 = OpLoad %ulong %5 + %22 = OpLoad %float %6 + %27 = OpConvertUToPtr %_ptr_Generic_float %21 + OpStore %27 %22 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/rcp.spvtxt b/ptx/src/test/spirv_run/rcp.spvtxt index 08b3e6e..fd10ff1 100644 --- a/ptx/src/test/spirv_run/rcp.spvtxt +++ b/ptx/src/test/spirv_run/rcp.spvtxt @@ -7,9 +7,11 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 + OpExtension "SPV_KHR_float_controls" %23 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "rcp" + OpExecutionMode %1 DenormPreserve 32 OpDecorate %15 FPFastMathMode AllowRecip %void = OpTypeVoid %ulong = OpTypeInt 64 0 diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index cccf6ad..604b4ef 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,7 @@ use crate::ast; use half::f16; use rspirv::{binary::Disassemble, dr}; -use std::{borrow::Cow, iter, mem}; +use std::{borrow::Cow, hash::Hash, iter, mem}; use std::{ collections::{hash_map, HashMap, HashSet}, convert::TryFrom, @@ -438,6 +438,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result(ast: ast::Module<'a>) -> Result( globals, body: Some(statements), }) => { - let call_key = match func_decl { - ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name), - ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id), - }; + let call_key = CallgraphKey::new(&func_decl); let statements = statements .into_iter() .map(|statement| match statement { @@ -563,10 +562,7 @@ fn convert_dynamic_shared_memory_usage<'input>( globals, body: Some(statements), }) => { - let call_key = match func_decl { - ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name), - ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id), - }; + let call_key = CallgraphKey::new(&func_decl); if !methods_using_extern_shared.contains(&call_key) { return Directive::Method(Function { func_decl, @@ -726,12 +722,171 @@ fn get_callers_of_extern_shared_single<'a>( } } +type DenormCountMap = HashMap; + +fn denorm_count_map_update(map: &mut DenormCountMap, key: T, value: bool) { + let num_value = if value { 1 } else { -1 }; + denorm_count_map_update_impl(map, key, num_value); +} + +fn denorm_count_map_update_impl( + map: &mut DenormCountMap, + key: T, + num_value: isize, +) { + match map.entry(key) { + hash_map::Entry::Occupied(mut counter) => { + *(counter.get_mut()) += num_value; + } + hash_map::Entry::Vacant(entry) => { + entry.insert(num_value); + } + } +} + +fn denorm_count_map_merge( + dst: &mut DenormCountMap, + src: &DenormCountMap, +) { + for (k, count) in src { + denorm_count_map_update_impl(dst, *k, *count); + } +} + +// HACK ALERT! +// This function is a "good enough" heuristic of whetever to mark f16/f32 operations +// in the kernel as flushing denorms to zero or preserving them +// PTX support per-instruction ftz information. Unfortunately SPIR-V has no +// such capability, so instead we guesstimate which use is more common in the kernel +// and emit suitable execution mode +fn compute_denorm_information<'input>( + module: &[Directive<'input>], +) -> HashMap<&'input str, HashMap> { + let mut direct_func_calls = MultiHashMap::new(); + let mut denorm_methods = HashMap::new(); + for directive in module.iter() { + match directive { + Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {} + Directive::Method(Function { + func_decl, + body: Some(statements), + .. + }) => { + let mut flush_counter = DenormCountMap::new(); + let method_key = CallgraphKey::new(func_decl); + for statement in statements { + match statement { + Statement::Instruction(inst) => { + if let Some((flush, width)) = inst.flush_to_zero() { + denorm_count_map_update(&mut flush_counter, width, flush); + } + } + Statement::LoadVar(_, _) => {} + Statement::StoreVar(_, _) => {} + Statement::Call(ResolvedCall { func, .. }) => { + multi_hash_map_append(&mut direct_func_calls, method_key, *func); + } + Statement::Composite(_) => {} + Statement::Conditional(_) => {} + Statement::Conversion(_) => {} + Statement::Constant(_) => {} + Statement::RetValue(_, _) => {} + Statement::Undef(_, _) => {} + Statement::Label(_) => {} + Statement::Variable(_) => {} + } + } + denorm_methods.insert(method_key, flush_counter); + } + } + } + let summed_denorm_methods = sum_up_denorm_use(module, denorm_methods, &direct_func_calls); + summed_denorm_methods + .into_iter() + .filter_map(|(name, v)| { + let width_to_denorm = v + .into_iter() + .map(|(k, ftz_over_preserve)| { + let mode = if ftz_over_preserve > 0 { + spirv::ExecutionMode::DenormFlushToZero + } else { + spirv::ExecutionMode::DenormPreserve + }; + (k, mode) + }) + .collect(); + Some((name, width_to_denorm)) + }) + .collect() +} + +fn sum_up_denorm_use<'input>( + module: &[Directive<'input>], + denorm_methods: HashMap, DenormCountMap>, + direct_func_calls: &MultiHashMap, spirv::Word>, +) -> HashMap<&'input str, DenormCountMap> { + let mut result = HashMap::new(); + let empty = Vec::new(); + for (method_key, denorm_map) in denorm_methods.iter() { + match method_key { + CallgraphKey::Kernel(name) => { + let mut sum = denorm_map.clone(); + let mut visited = HashSet::new(); + for child in direct_func_calls + .get(&CallgraphKey::Kernel(name)) + .unwrap_or(&empty) + { + sum_up_denorm_use_single( + &denorm_methods, + direct_func_calls, + &mut sum, + &mut visited, + *child, + ); + } + result.insert(*name, sum); + } + CallgraphKey::Func(_) => {} + } + } + result +} + +fn sum_up_denorm_use_single<'input>( + denorm_methods: &HashMap, DenormCountMap>, + direct_func_calls: &MultiHashMap, spirv::Word>, + sum: &mut DenormCountMap, + visited: &mut HashSet, + current: spirv::Word, +) { + if !visited.insert(current) { + return; + } + if let Some(denorm_map) = denorm_methods.get(&CallgraphKey::Func(current)) { + denorm_count_map_merge(sum, denorm_map); + } + if let Some(children) = direct_func_calls.get(&CallgraphKey::Func(current)) { + for child in children { + sum_up_denorm_use_single(denorm_methods, direct_func_calls, sum, visited, *child); + } + } +} + #[derive(Hash, PartialEq, Eq, Copy, Clone)] enum CallgraphKey<'input> { Kernel(&'input str), Func(spirv::Word), } +impl<'input> CallgraphKey<'input> { + fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self { + match decl { + ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name), + ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(*id), + } + } +} + fn emit_builtins( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -764,6 +919,7 @@ fn emit_function_header<'a>( map: &mut TypeWordMap, global: &GlobalStringIdResolver<'a>, func_directive: ast::MethodDecl, + denorm_information: &HashMap<&'a str, HashMap>, kernel_info: &mut HashMap, ) -> Result<(), TranslateError> { if let ast::MethodDecl::Kernel { @@ -797,6 +953,11 @@ fn emit_function_header<'a>( .collect::>(); global_variables.append(&mut interface); builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables); + if let Some(exec_modes) = denorm_information.get(name) { + for (size_of, exec_mode) in exec_modes { + builder.execution_mode(fn_id, *exec_mode, [(*size_of as u32) * 8]) + } + } fn_id } ast::MethodDecl::Func(_, name, _) => name, @@ -844,9 +1005,14 @@ fn emit_capabilities(builder: &mut dr::Builder) { builder.capability(spirv::Capability::Int64); builder.capability(spirv::Capability::Float16); builder.capability(spirv::Capability::Float64); + builder.capability(spirv::Capability::DenormFlushToZero); + builder.capability(spirv::Capability::DenormPreserve); } -fn emit_extensions(_: &mut dr::Builder) {} +// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html +fn emit_extensions(builder: &mut dr::Builder) { + builder.extension("SPV_KHR_float_controls"); +} fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word { builder.ext_inst_import("OpenCL.std") @@ -2088,7 +2254,7 @@ fn emit_function_body_ops( ast::MulDetails::Unsigned(ref ctr) => { emit_mul_uint(builder, map, opencl, ctr, arg)? } - ast::MulDetails::Float(_) => todo!(), + ast::MulDetails::Float(ref ctr) => emit_mul_float(builder, map, ctr, arg)?, }, ast::Instruction::Add(add, arg) => match add { ast::ArithDetails::Signed(ref desc) => { @@ -2215,15 +2381,27 @@ fn emit_function_body_ops( Ok(()) } +fn emit_mul_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + ctr: &ast::ArithFloat, + arg: &ast::Arg3, +) -> Result<(), dr::Error> { + if ctr.saturate { + todo!() + } + let result_type = map.get_or_add_scalar(builder, ctr.typ.into()); + builder.f_mul(result_type, Some(arg.dst), arg.src1, arg.src2)?; + emit_rounding_decoration(builder, arg.dst, ctr.rounding); + Ok(()) +} + fn emit_rcp( builder: &mut dr::Builder, map: &mut TypeWordMap, desc: &ast::RcpDetails, a: &ast::Arg2, ) -> Result<(), TranslateError> { - if desc.flush_to_zero { - todo!() - } let (instr_type, constant) = if desc.is_f64 { (ast::ScalarType::F64, vec_repr(1.0f64)) } else { @@ -2360,9 +2538,6 @@ fn emit_add_float( desc: &ast::ArithFloat, arg: &ast::Arg3, ) -> Result<(), dr::Error> { - if desc.flush_to_zero { - todo!() - } let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); builder.f_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; emit_rounding_decoration(builder, arg.dst, desc.rounding); @@ -2375,9 +2550,6 @@ fn emit_sub_float( desc: &ast::ArithFloat, arg: &ast::Arg3, ) -> Result<(), dr::Error> { - if desc.flush_to_zero { - todo!() - } let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?; emit_rounding_decoration(builder, arg.dst, desc.rounding); @@ -2441,7 +2613,7 @@ fn emit_cvt( if desc.dst == desc.src { return Ok(()); } - if desc.saturate || desc.flush_to_zero { + if desc.saturate { todo!() } let dest_t: ast::ScalarType = desc.dst.into(); @@ -2450,7 +2622,7 @@ fn emit_cvt( emit_rounding_decoration(builder, arg.dst, desc.rounding); } ast::CvtDetails::FloatFromInt(desc) => { - if desc.saturate || desc.flush_to_zero { + if desc.saturate { todo!() } let dest_t: ast::ScalarType = desc.dst.into(); @@ -2463,9 +2635,6 @@ fn emit_cvt( emit_rounding_decoration(builder, arg.dst, desc.rounding); } ast::CvtDetails::IntFromFloat(desc) => { - if desc.flush_to_zero { - todo!() - } let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); if desc.dst.is_signed() { @@ -2561,9 +2730,6 @@ fn emit_setp( setp: &ast::SetpData, arg: &ast::Arg4Setp, ) -> Result<(), dr::Error> { - if setp.flush_to_zero { - todo!() - } let result_type = map.get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred)); let result_id = Some(arg.dst1); let operand_1 = arg.src1; @@ -4122,6 +4288,73 @@ impl ast::Instruction { | ast::Instruction::Mad(_, _) => None, } } + + // .wide instructions don't support ftz, so it's enough to just look at the + // type declared by the instruction + fn flush_to_zero(&self) -> Option<(bool, u8)> { + match self { + ast::Instruction::Ld(_, _) => None, + ast::Instruction::St(_, _) => None, + ast::Instruction::Mov(_, _) => None, + ast::Instruction::Not(_, _) => None, + ast::Instruction::Bra(_, _) => None, + ast::Instruction::Shl(_, _) => None, + ast::Instruction::Shr(_, _) => None, + ast::Instruction::Ret(_) => None, + ast::Instruction::Call(_) => None, + ast::Instruction::Or(_, _) => None, + ast::Instruction::Cvta(_, _) => None, + ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None, + ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None, + ast::Instruction::Add(ast::ArithDetails::Signed(_), _) => None, + ast::Instruction::Add(ast::ArithDetails::Unsigned(_), _) => None, + ast::Instruction::Mul(ast::MulDetails::Unsigned(_), _) => None, + ast::Instruction::Mul(ast::MulDetails::Signed(_), _) => None, + ast::Instruction::Mad(ast::MulDetails::Unsigned(_), _) => None, + ast::Instruction::Mad(ast::MulDetails::Signed(_), _) => None, + ast::Instruction::Min(ast::MinMaxDetails::Signed(_), _) => None, + ast::Instruction::Min(ast::MinMaxDetails::Unsigned(_), _) => None, + ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _) => None, + ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _) => None, + ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _) => None, + ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _) + | ast::Instruction::Add(ast::ArithDetails::Float(float_control), _) + | ast::Instruction::Mul(ast::MulDetails::Float(float_control), _) + | ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => float_control + .flush_to_zero + .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())), + ast::Instruction::Setp(details, _) => details + .flush_to_zero + .map(|ftz| (ftz, details.typ.size_of())), + ast::Instruction::SetpBool(details, _) => details + .flush_to_zero + .map(|ftz| (ftz, details.typ.size_of())), + ast::Instruction::Abs(details, _) => details + .flush_to_zero + .map(|ftz| (ftz, details.typ.size_of())), + ast::Instruction::Min(ast::MinMaxDetails::Float(float_control), _) + | ast::Instruction::Max(ast::MinMaxDetails::Float(float_control), _) => float_control + .flush_to_zero + .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())), + ast::Instruction::Rcp(details, _) => details + .flush_to_zero + .map(|ftz| (ftz, if details.is_f64 { 8 } else { 4 })), + // Modifier .ftz can only be specified when either .dtype or .atype + // is .f32 and applies only to single precision (.f32) inputs and results. + ast::Instruction::Cvt( + ast::CvtDetails::FloatFromFloat(ast::CvtDesc { flush_to_zero, .. }), + _, + ) + | ast::Instruction::Cvt( + ast::CvtDetails::FloatFromInt(ast::CvtDesc { flush_to_zero, .. }), + _, + ) + | ast::Instruction::Cvt( + ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }), + _, + ) => flush_to_zero.map(|ftz| (ftz, 4)), + } + } } impl VisitVariableExpanded for ast::Instruction {