From 62d14cdffe57134fc89099672ee2954ee413b440 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 7 Nov 2020 16:14:37 +0100 Subject: [PATCH] Fix ftz behavior slightly --- ptx/src/test/spirv_run/mod.rs | 14 ++- ptx/src/test/spirv_run/mul_ftz.spvtxt | 119 ++++++++++++-------------- ptx/src/translate.rs | 55 ++++++++++-- 3 files changed, 114 insertions(+), 74 deletions(-) diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 5bbe45a..bd74508 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -83,8 +83,11 @@ test_ptx!(extern_shared_call, [121u64], [123u64]); test_ptx!(rcp, [2f32], [0.5f32]); // 0b1_00000000_10000000000000000000000u32 is a large denormal // 0x3f000000 is 0.5 -// TODO: mul_ftz fails because IGC does not yet handle SPV_INTEL_float_controls2 -// test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]); +test_ptx!( + mul_ftz, + [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], + [0b1_00000000_00000000000000000000000u32] +); test_ptx!( mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], @@ -196,7 +199,12 @@ fn run_spirv + ze::SafeRepr + Copy + Debug>( let (module, maybe_log) = match module.should_link_ptx_impl { Some(ptx_impl) => ze::Module::build_link_spirv(&mut ctx, &dev, &[ptx_impl, byte_il]), None => { - let (module, log) = ze::Module::build_spirv(&mut ctx, &dev, byte_il, None); + let (module, log) = ze::Module::build_spirv( + &mut ctx, + &dev, + byte_il, + Some(module.build_options.as_c_str()), + ); (module, Some(log)) } }; diff --git a/ptx/src/test/spirv_run/mul_ftz.spvtxt b/ptx/src/test/spirv_run/mul_ftz.spvtxt index 56cec5a..3e80ae3 100644 --- a/ptx/src/test/spirv_run/mul_ftz.spvtxt +++ b/ptx/src/test/spirv_run/mul_ftz.spvtxt @@ -1,64 +1,55 @@ -; SPIR-V -; Version: 1.3 -; Generator: rspirv -; Bound: 38 -OpCapability GenericPointer -OpCapability Linkage -OpCapability Addresses -OpCapability Kernel -OpCapability Int8 -OpCapability Int16 -OpCapability Int64 -OpCapability Float16 -OpCapability Float64 -; OpCapability FunctionFloatControlINTEL -; OpExtension "SPV_INTEL_float_controls2" -%30 = OpExtInstImport "OpenCL.std" -OpMemoryModel Physical64 OpenCL -OpEntryPoint Kernel %1 "mul_ftz" -OpDecorate %1 FunctionDenormModeINTEL 32 FlushToZero -%31 = OpTypeVoid -%32 = OpTypeInt 64 0 -%33 = OpTypeFunction %31 %32 %32 -%34 = OpTypePointer Function %32 -%35 = OpTypeFloat 32 -%36 = OpTypePointer Function %35 -%37 = OpTypePointer Generic %35 -%23 = OpConstant %32 4 -%1 = OpFunction %31 None %33 -%8 = OpFunctionParameter %32 -%9 = OpFunctionParameter %32 -%28 = OpLabel -%2 = OpVariable %34 Function -%3 = OpVariable %34 Function -%4 = OpVariable %34 Function -%5 = OpVariable %34 Function -%6 = OpVariable %36 Function -%7 = OpVariable %36 Function -OpStore %2 %8 -OpStore %3 %9 -%11 = OpLoad %32 %2 -%10 = OpCopyObject %32 %11 -OpStore %4 %10 -%13 = OpLoad %32 %3 -%12 = OpCopyObject %32 %13 -OpStore %5 %12 -%15 = OpLoad %32 %4 -%25 = OpConvertUToPtr %37 %15 -%14 = OpLoad %35 %25 -OpStore %6 %14 -%17 = OpLoad %32 %4 -%24 = OpIAdd %32 %17 %23 -%26 = OpConvertUToPtr %37 %24 -%16 = OpLoad %35 %26 -OpStore %7 %16 -%19 = OpLoad %35 %6 -%20 = OpLoad %35 %7 -%18 = OpFMul %35 %19 %20 -OpStore %6 %18 -%21 = OpLoad %32 %5 -%22 = OpLoad %35 %6 -%27 = OpConvertUToPtr %37 %21 -OpStore %27 %22 -OpReturn -OpFunctionEnd + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %28 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "mul_ftz" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %31 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Generic_float = OpTypePointer Generic %float + %ulong_4 = OpConstant %ulong 4 + %1 = OpFunction %void None %31 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %26 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_float Function + %7 = OpVariable %_ptr_Function_float Function + OpStore %2 %8 + OpStore %3 %9 + %10 = OpLoad %ulong %2 + OpStore %4 %10 + %11 = OpLoad %ulong %3 + OpStore %5 %11 + %13 = OpLoad %ulong %4 + %23 = OpConvertUToPtr %_ptr_Generic_float %13 + %12 = OpLoad %float %23 + OpStore %6 %12 + %15 = OpLoad %ulong %4 + %22 = OpIAdd %ulong %15 %ulong_4 + %24 = OpConvertUToPtr %_ptr_Generic_float %22 + %14 = OpLoad %float %24 + OpStore %7 %14 + %17 = OpLoad %float %6 + %18 = OpLoad %float %7 + %16 = OpFMul %float %17 %18 + OpStore %6 %16 + %19 = OpLoad %ulong %5 + %20 = OpLoad %float %6 + %25 = OpConvertUToPtr %_ptr_Generic_float %19 + OpStore %25 %20 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 365d1e8..c0e15f2 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,7 @@ use crate::ast; use half::f16; use rspirv::{binary::Disassemble, dr}; -use std::{borrow::Cow, convert::TryFrom, hash::Hash, iter, mem}; +use std::{borrow::Cow, convert::TryFrom, ffi::CString, hash::Hash, iter, mem}; use std::{ collections::{hash_map, HashMap, HashSet}, convert::TryInto, @@ -448,6 +448,7 @@ pub struct Module { pub spirv: dr::Module, pub kernel_info: HashMap, pub should_link_ptx_impl: Option<&'static [u8]>, + pub build_options: CString, } pub struct KernelInfo { @@ -484,6 +485,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result(ast: ast::Module<'a>) -> Result>, + denorm_information: &HashMap>, +) -> CString { + let denorm_counts = denorm_information + .iter() + .map(|(method, meth_denorm)| { + let f16_count = meth_denorm + .get(&(mem::size_of::() as u8)) + .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) + .1; + let f32_count = meth_denorm + .get(&(mem::size_of::() as u8)) + .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) + .1; + (method, (f16_count + f32_count)) + }) + .collect::>(); + let mut flush_over_preserve = 0; + for (kernel, children) in call_map { + flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0); + for child_fn in children { + flush_over_preserve += *denorm_counts + .get(&MethodName::Func(*child_fn)) + .unwrap_or(&0); + } + } + if flush_over_preserve > 0 { + CString::new("-cl-denorms-are-zero").unwrap() + } else { + CString::default() + } +} + fn emit_directives<'input>( builder: &mut dr::Builder, map: &mut TypeWordMap, id_defs: &GlobalStringIdResolver<'input>, opencl_id: spirv::Word, - denorm_information: &HashMap, HashMap>, + denorm_information: &HashMap, HashMap>, call_map: &HashMap<&'input str, HashSet>, directives: Vec, kernel_info: &mut HashMap, @@ -579,6 +617,9 @@ fn get_call_map<'input>( .. }) => { let call_key = MethodName::new(&func_decl); + if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) { + entry.insert(Vec::new()); + } for statement in statements { match statement { Statement::Call(call) => { @@ -895,7 +936,7 @@ fn denorm_count_map_update_impl( // and emit suitable execution mode fn compute_denorm_information<'input>( module: &[Directive<'input>], -) -> HashMap, HashMap> { +) -> HashMap, HashMap> { let mut denorm_methods = HashMap::new(); for directive in module { match directive { @@ -937,13 +978,13 @@ fn compute_denorm_information<'input>( .map(|(name, v)| { let width_to_denorm = v .into_iter() - .map(|(k, ftz_over_preserve)| { - let mode = if ftz_over_preserve > 0 { + .map(|(k, flush_over_preserve)| { + let mode = if flush_over_preserve > 0 { spirv::FPDenormMode::FlushToZero } else { spirv::FPDenormMode::Preserve }; - (k, mode) + (k, (mode, flush_over_preserve)) }) .collect(); (name, width_to_denorm) @@ -999,7 +1040,7 @@ fn emit_function_header<'a>( defined_globals: &GlobalStringIdResolver<'a>, synthetic_globals: &[ast::Variable], func_decl: &SpirvMethodDecl<'a>, - _denorm_information: &HashMap, HashMap>, + _denorm_information: &HashMap, HashMap>, call_map: &HashMap<&'a str, HashSet>, direcitves: &[Directive], kernel_info: &mut HashMap,