From 8e409254b3f30577a840885f6d7a56b27f4c2611 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 5 Nov 2020 21:39:34 +0100 Subject: [PATCH] Fix same width float-to-float conversions --- ptx/src/ptx.lalrpop | 4 +- ptx/src/test/spirv_run/cvt_rni.ptx | 25 +++++++++++ ptx/src/test/spirv_run/cvt_rni.spvtxt | 63 +++++++++++++++++++++++++++ ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/translate.rs | 57 ++++++++++++++++++++---- 5 files changed, 139 insertions(+), 11 deletions(-) create mode 100644 ptx/src/test/spirv_run/cvt_rni.ptx create mode 100644 ptx/src/test/spirv_run/cvt_rni.spvtxt diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 584ef84..31c2356 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -1068,7 +1068,7 @@ InstCvt: ast::Instruction> = { } ), a) }, - "cvt" ".f32" ".f32" => { + "cvt" ".f32" ".f32" => { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: r, @@ -1112,7 +1112,7 @@ InstCvt: ast::Instruction> = { } ), a) }, - "cvt" ".f64" ".f64" => { + "cvt" ".f64" ".f64" => { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: r, diff --git a/ptx/src/test/spirv_run/cvt_rni.ptx b/ptx/src/test/spirv_run/cvt_rni.ptx new file mode 100644 index 0000000..ecf20f8 --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rni.ptx @@ -0,0 +1,25 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry cvt_rni( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 temp1; + .reg .f32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 temp1, [in_addr]; + ld.f32 temp2, [in_addr+4]; + cvt.rni.f32.f32 temp1, temp1; + cvt.rni.f32.f32 temp2, temp2; + st.f32 [out_addr], temp1; + st.f32 [out_addr+4], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/cvt_rni.spvtxt b/ptx/src/test/spirv_run/cvt_rni.spvtxt new file mode 100644 index 0000000..cad84a2 --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rni.spvtxt @@ -0,0 +1,63 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %34 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "cvt_rni" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %37 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Generic_float = OpTypePointer Generic %float + %ulong_4 = OpConstant %ulong 4 + %ulong_4_0 = OpConstant %ulong 4 + %1 = OpFunction %void None %37 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %32 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_float Function + %7 = OpVariable %_ptr_Function_float Function + OpStore %2 %8 + OpStore %3 %9 + %10 = OpLoad %ulong %2 + OpStore %4 %10 + %11 = OpLoad %ulong %3 + OpStore %5 %11 + %13 = OpLoad %ulong %4 + %28 = OpConvertUToPtr %_ptr_Generic_float %13 + %12 = OpLoad %float %28 + OpStore %6 %12 + %15 = OpLoad %ulong %4 + %25 = OpIAdd %ulong %15 %ulong_4 + %29 = OpConvertUToPtr %_ptr_Generic_float %25 + %14 = OpLoad %float %29 + OpStore %7 %14 + %17 = OpLoad %float %6 + %16 = OpExtInst %float %34 rint %17 + OpStore %6 %16 + %19 = OpLoad %float %7 + %18 = OpExtInst %float %34 rint %19 + OpStore %7 %18 + %20 = OpLoad %ulong %5 + %21 = OpLoad %float %6 + %30 = OpConvertUToPtr %_ptr_Generic_float %20 + OpStore %30 %21 + %22 = OpLoad %ulong %5 + %23 = OpLoad %float %7 + %27 = OpIAdd %ulong %22 %ulong_4_0 + %31 = OpConvertUToPtr %_ptr_Generic_float %27 + OpStore %31 %23 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 3fa82ba..163caac 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -108,6 +108,7 @@ test_ptx!(sin, [std::f32::consts::PI/2f32], [1f32]); test_ptx!(cos, [std::f32::consts::PI], [-1f32]); test_ptx!(lg2, [512f32], [9f32]); test_ptx!(ex2, [10f32], [1024f32]); +test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]); struct DisplayError { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7a0dd08..9519951 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2813,7 +2813,7 @@ fn emit_function_body_ops( } } ast::Instruction::Cvt(dets, arg) => { - emit_cvt(builder, map, dets, arg)?; + emit_cvt(builder, map, opencl, dets, arg)?; } ast::Instruction::Cvta(_, arg) => { // This would be only meaningful if const/slm/global pointers @@ -3410,21 +3410,63 @@ fn emit_max( fn emit_cvt( builder: &mut dr::Builder, map: &mut TypeWordMap, + opencl: spirv::Word, dets: &ast::CvtDetails, arg: &ast::Arg2, ) -> Result<(), TranslateError> { match dets { ast::CvtDetails::FloatFromFloat(desc) => { - if desc.dst == desc.src { - return Ok(()); - } if desc.saturate { todo!() } let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - builder.f_convert(result_type, Some(arg.dst), arg.src)?; - emit_rounding_decoration(builder, arg.dst, desc.rounding); + if desc.dst == desc.src { + match desc.rounding { + Some(ast::RoundingMode::NearestEven) => { + builder.ext_inst( + result_type, + Some(arg.dst), + opencl, + spirv::CLOp::rint as u32, + [arg.src], + )?; + } + Some(ast::RoundingMode::Zero) => { + builder.ext_inst( + result_type, + Some(arg.dst), + opencl, + spirv::CLOp::trunc as u32, + [arg.src], + )?; + } + Some(ast::RoundingMode::NegativeInf) => { + builder.ext_inst( + result_type, + Some(arg.dst), + opencl, + spirv::CLOp::floor as u32, + [arg.src], + )?; + } + Some(ast::RoundingMode::PositiveInf) => { + builder.ext_inst( + result_type, + Some(arg.dst), + opencl, + spirv::CLOp::ceil as u32, + [arg.src], + )?; + } + None => { + builder.copy_object(result_type, Some(arg.dst), arg.src)?; + } + } + } else { + builder.f_convert(result_type, Some(arg.dst), arg.src)?; + emit_rounding_decoration(builder, arg.dst, desc.rounding); + } } ast::CvtDetails::FloatFromInt(desc) => { if desc.saturate { @@ -3451,9 +3493,6 @@ fn emit_cvt( emit_saturating_decoration(builder, arg.dst, desc.saturate); } ast::CvtDetails::IntFromInt(desc) => { - if desc.dst == desc.src { - return Ok(()); - } let dest_t: ast::ScalarType = desc.dst.into(); let src_t: ast::ScalarType = desc.src.into(); // first do shortening/widening