diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index 976be8a..345d7fd 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -2266,22 +2266,50 @@ impl<'a> MethodEmitContext<'a> { let llvm_prefix = match data { ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin", ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin", - ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => { - "llvm.minimum" - } ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum", }; let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_())); let llvm_type = get_scalar_type(self.context, data.type_()); - self.emit_intrinsic( + + let a = self.resolver.value(arguments.src1)?; + let b = self.resolver.value(arguments.src2)?; + + let min = self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, - Some(arguments.dst), + None, Some(&data.type_().into()), - vec![ - (self.resolver.value(arguments.src1)?, llvm_type), - (self.resolver.value(arguments.src2)?, llvm_type), - ], + vec![(a, llvm_type), (b, llvm_type)], )?; + + if let ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { + nan: true, type_, .. + }) = data + { + let is_nan = unsafe { + LLVMBuildFCmp( + self.builder, + LLVMRealPredicate::LLVMRealUNO, + a, + b, + LLVM_UNNAMED.as_ptr(), + ) + }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildSelect( + self.builder, + is_nan, + LLVMConstReal(get_scalar_type(self.context, type_), f64::NAN), + min, + dst, + ) + }); + } else { + self.resolver.with_result(arguments.dst, |dst| unsafe { + let dst = CStr::from_ptr(dst); + LLVMSetValueName2(min, dst.as_ptr(), dst.count_bytes()); + min + }); + } Ok(()) } diff --git a/ptx/src/test/ll/min_f16.ll b/ptx/src/test/ll/min_f16.ll new file mode 100644 index 0000000..d5e4f92 --- /dev/null +++ b/ptx/src/test/ll/min_f16.ll @@ -0,0 +1,43 @@ +define amdgpu_kernel void @min_f16(ptr addrspace(4) byref(i64) %"36", ptr addrspace(4) byref(i64) %"37") #0 { + %"38" = alloca i64, align 8, addrspace(5) + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca half, align 2, addrspace(5) + %"41" = alloca half, align 2, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"35" + +"35": ; preds = %1 + %"42" = load i64, ptr addrspace(4) %"36", align 8 + store i64 %"42", ptr addrspace(5) %"38", align 8 + %"43" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"43", ptr addrspace(5) %"39", align 8 + %"45" = load i64, ptr addrspace(5) %"38", align 8 + %"54" = inttoptr i64 %"45" to ptr + %"53" = load i16, ptr %"54", align 2 + %"44" = bitcast i16 %"53" to half + store half %"44", ptr addrspace(5) %"40", align 2 + %"46" = load i64, ptr addrspace(5) %"38", align 8 + %"55" = inttoptr i64 %"46" to ptr + %"34" = getelementptr inbounds i8, ptr %"55", i64 2 + %"56" = load i16, ptr %"34", align 2 + %"47" = bitcast i16 %"56" to half + store half %"47", ptr addrspace(5) %"41", align 2 + %"49" = load half, ptr addrspace(5) %"40", align 2 + %"50" = load half, ptr addrspace(5) %"41", align 2 + %"48" = call half @llvm.minnum.f16(half %"49", half %"50") + store half %"48", ptr addrspace(5) %"40", align 2 + %"51" = load i64, ptr addrspace(5) %"39", align 8 + %"52" = load half, ptr addrspace(5) %"40", align 2 + %"57" = inttoptr i64 %"51" to ptr + %"58" = bitcast half %"52" to i16 + store i16 %"58", ptr %"57", align 2 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare half @llvm.minnum.f16(half, half) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } diff --git a/ptx/src/test/ll/min_nan_f16.ll b/ptx/src/test/ll/min_nan_f16.ll new file mode 100644 index 0000000..0ffd981 --- /dev/null +++ b/ptx/src/test/ll/min_nan_f16.ll @@ -0,0 +1,45 @@ +define amdgpu_kernel void @min_nan_f16(ptr addrspace(4) byref(i64) %"36", ptr addrspace(4) byref(i64) %"37") #0 { + %"38" = alloca i64, align 8, addrspace(5) + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca half, align 2, addrspace(5) + %"41" = alloca half, align 2, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"35" + +"35": ; preds = %1 + %"42" = load i64, ptr addrspace(4) %"36", align 8 + store i64 %"42", ptr addrspace(5) %"38", align 8 + %"43" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"43", ptr addrspace(5) %"39", align 8 + %"45" = load i64, ptr addrspace(5) %"38", align 8 + %"54" = inttoptr i64 %"45" to ptr + %"53" = load i16, ptr %"54", align 2 + %"44" = bitcast i16 %"53" to half + store half %"44", ptr addrspace(5) %"40", align 2 + %"46" = load i64, ptr addrspace(5) %"38", align 8 + %"55" = inttoptr i64 %"46" to ptr + %"34" = getelementptr inbounds i8, ptr %"55", i64 2 + %"56" = load i16, ptr %"34", align 2 + %"47" = bitcast i16 %"56" to half + store half %"47", ptr addrspace(5) %"41", align 2 + %"49" = load half, ptr addrspace(5) %"40", align 2 + %"50" = load half, ptr addrspace(5) %"41", align 2 + %2 = call half @llvm.minnum.f16(half %"49", half %"50") + %3 = fcmp uno half %"49", %"50" + %"48" = select i1 %3, half 0xH7E00, half %2 + store half %"48", ptr addrspace(5) %"40", align 2 + %"51" = load i64, ptr addrspace(5) %"39", align 8 + %"52" = load half, ptr addrspace(5) %"40", align 2 + %"57" = inttoptr i64 %"51" to ptr + %"58" = bitcast half %"52" to i16 + store i16 %"58", ptr %"57", align 2 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare half @llvm.minnum.f16(half, half) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } diff --git a/ptx/src/test/spirv_run/min_f16.ptx b/ptx/src/test/spirv_run/min_f16.ptx new file mode 100644 index 0000000..61d30c6 --- /dev/null +++ b/ptx/src/test/spirv_run/min_f16.ptx @@ -0,0 +1,23 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry min_f16( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f16 temp1; + .reg .f16 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b16 temp1, [in_addr]; + ld.b16 temp2, [in_addr+2]; + min.f16 temp1, temp1, temp2; + st.b16 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/min_nan_f16.ptx b/ptx/src/test/spirv_run/min_nan_f16.ptx new file mode 100644 index 0000000..3c1ca60 --- /dev/null +++ b/ptx/src/test/spirv_run/min_nan_f16.ptx @@ -0,0 +1,23 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry min_nan_f16( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f16 temp1; + .reg .f16 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b16 temp1, [in_addr]; + ld.b16 temp2, [in_addr+2]; + min.NaN.f16 temp1, temp1, temp2; + st.b16 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 345b112..a6e62c2 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -143,6 +143,12 @@ test_ptx!(shr_oob, [-32768i16], [-1i16]); test_ptx!(or, [1u64, 2u64], [3u64]); test_ptx!(sub, [2u64], [1u64]); test_ptx!(min, [555i32, 444i32], [444i32]); +test_ptx!( + min_f16, + [half::f16::NAN, half::f16::from_f64(123.0)], + [half::f16::from_f64(123.0)] +); +test_ptx!(min_nan_f16); test_ptx!(max, [555i32, 444i32], [555i32]); test_ptx!(global_array, [0xDEADu32], [1u32]); test_ptx!(global_array_f32, [0x0], [0f32]);