diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 73be00a..524196a 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -451,13 +451,14 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result(ast: ast::Module<'a>) -> Result( call_map: &HashMap<&str, HashSet>, denorm_information: &HashMap< ast::MethodName<'input, spirv::Word>, HashMap, >, -) -> CString { +) -> (CString, bool) { let denorm_counts = denorm_information .iter() .map(|(method, meth_denorm)| { @@ -509,9 +510,12 @@ fn emit_denorm_build_string<'input>( } } if flush_over_preserve > 0 { - CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap() + ( + CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap(), + true, + ) } else { - CString::new("-ze-take-global-address").unwrap() + (CString::new("-ze-take-global-address").unwrap(), false) } } @@ -520,10 +524,7 @@ fn emit_directives<'input>( map: &mut TypeWordMap, id_defs: &GlobalStringIdResolver<'input>, opencl_id: spirv::Word, - denorm_information: &HashMap< - ast::MethodName<'input, spirv::Word>, - HashMap, - >, + should_flush_denorms: bool, call_map: &HashMap<&'input str, HashSet>, directives: Vec>, kernel_info: &mut HashMap, @@ -555,12 +556,28 @@ fn emit_directives<'input>( &id_defs, &f.globals, &*func_decl, - &denorm_information, call_map, &directives, kernel_info, )?; if func_decl.name.is_kernel() { + if should_flush_denorms { + builder.execution_mode( + fn_id, + spirv_headers::ExecutionMode::DenormFlushToZero, + [16], + ); + builder.execution_mode( + fn_id, + spirv_headers::ExecutionMode::DenormFlushToZero, + [32], + ); + builder.execution_mode( + fn_id, + spirv_headers::ExecutionMode::DenormFlushToZero, + [64], + ); + } // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx) builder.execution_mode(fn_id, spirv_headers::ExecutionMode::ContractionOff, []); for t in f.tuning.iter() { @@ -1017,10 +1034,6 @@ fn emit_function_header<'a>( defined_globals: &GlobalStringIdResolver<'a>, synthetic_globals: &[ast::Variable], func_decl: &ast::MethodDeclaration<'a, spirv::Word>, - _denorm_information: &HashMap< - ast::MethodName<'a, spirv::Word>, - HashMap, - >, call_map: &HashMap<&'a str, HashSet>, direcitves: &[Directive], kernel_info: &mut HashMap, @@ -1095,21 +1108,6 @@ fn emit_function_header<'a>( spirv::FunctionControl::NONE, func_type, )?; - // TODO: re-enable when Intel float control extension works - /* - if let Some(denorm_modes) = denorm_information.get(&func_decl.name) { - for (size_of, denorm_mode) in denorm_modes { - builder.decorate( - fn_id, - spirv::Decoration::FunctionDenormModeINTEL, - [ - dr::Operand::LiteralInt32((*size_of as u32) * 8), - dr::Operand::FPDenormMode(*denorm_mode), - ], - ) - } - } - */ for (name, typ) in func_decl.effective_input_arguments() { let result_type = map.get_or_add(builder, typ); builder.function_parameter(Some(name), result_type)?;