Fix min.ftz.nan.f16 for ROCm 6.3.4

This commit is contained in:
Violet 2025-09-11 22:33:19 +00:00
commit 49a41eed62
6 changed files with 177 additions and 9 deletions

View file

@ -2266,22 +2266,50 @@ impl<'a> MethodEmitContext<'a> {
let llvm_prefix = match data {
ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin",
ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin",
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
"llvm.minimum"
}
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
};
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
let llvm_type = get_scalar_type(self.context, data.type_());
self.emit_intrinsic(
let a = self.resolver.value(arguments.src1)?;
let b = self.resolver.value(arguments.src2)?;
let min = self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
None,
Some(&data.type_().into()),
vec![
(self.resolver.value(arguments.src1)?, llvm_type),
(self.resolver.value(arguments.src2)?, llvm_type),
],
vec![(a, llvm_type), (b, llvm_type)],
)?;
if let ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat {
nan: true, type_, ..
}) = data
{
let is_nan = unsafe {
LLVMBuildFCmp(
self.builder,
LLVMRealPredicate::LLVMRealUNO,
a,
b,
LLVM_UNNAMED.as_ptr(),
)
};
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildSelect(
self.builder,
is_nan,
LLVMConstReal(get_scalar_type(self.context, type_), f64::NAN),
min,
dst,
)
});
} else {
self.resolver.with_result(arguments.dst, |dst| unsafe {
let dst = CStr::from_ptr(dst);
LLVMSetValueName2(min, dst.as_ptr(), dst.count_bytes());
min
});
}
Ok(())
}

View file

@ -0,0 +1,43 @@
define amdgpu_kernel void @min_f16(ptr addrspace(4) byref(i64) %"36", ptr addrspace(4) byref(i64) %"37") #0 {
%"38" = alloca i64, align 8, addrspace(5)
%"39" = alloca i64, align 8, addrspace(5)
%"40" = alloca half, align 2, addrspace(5)
%"41" = alloca half, align 2, addrspace(5)
br label %1
1: ; preds = %0
br label %"35"
"35": ; preds = %1
%"42" = load i64, ptr addrspace(4) %"36", align 8
store i64 %"42", ptr addrspace(5) %"38", align 8
%"43" = load i64, ptr addrspace(4) %"37", align 8
store i64 %"43", ptr addrspace(5) %"39", align 8
%"45" = load i64, ptr addrspace(5) %"38", align 8
%"54" = inttoptr i64 %"45" to ptr
%"53" = load i16, ptr %"54", align 2
%"44" = bitcast i16 %"53" to half
store half %"44", ptr addrspace(5) %"40", align 2
%"46" = load i64, ptr addrspace(5) %"38", align 8
%"55" = inttoptr i64 %"46" to ptr
%"34" = getelementptr inbounds i8, ptr %"55", i64 2
%"56" = load i16, ptr %"34", align 2
%"47" = bitcast i16 %"56" to half
store half %"47", ptr addrspace(5) %"41", align 2
%"49" = load half, ptr addrspace(5) %"40", align 2
%"50" = load half, ptr addrspace(5) %"41", align 2
%"48" = call half @llvm.minnum.f16(half %"49", half %"50")
store half %"48", ptr addrspace(5) %"40", align 2
%"51" = load i64, ptr addrspace(5) %"39", align 8
%"52" = load half, ptr addrspace(5) %"40", align 2
%"57" = inttoptr i64 %"51" to ptr
%"58" = bitcast half %"52" to i16
store i16 %"58", ptr %"57", align 2
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare half @llvm.minnum.f16(half, half) #1
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

View file

@ -0,0 +1,45 @@
define amdgpu_kernel void @min_nan_f16(ptr addrspace(4) byref(i64) %"36", ptr addrspace(4) byref(i64) %"37") #0 {
%"38" = alloca i64, align 8, addrspace(5)
%"39" = alloca i64, align 8, addrspace(5)
%"40" = alloca half, align 2, addrspace(5)
%"41" = alloca half, align 2, addrspace(5)
br label %1
1: ; preds = %0
br label %"35"
"35": ; preds = %1
%"42" = load i64, ptr addrspace(4) %"36", align 8
store i64 %"42", ptr addrspace(5) %"38", align 8
%"43" = load i64, ptr addrspace(4) %"37", align 8
store i64 %"43", ptr addrspace(5) %"39", align 8
%"45" = load i64, ptr addrspace(5) %"38", align 8
%"54" = inttoptr i64 %"45" to ptr
%"53" = load i16, ptr %"54", align 2
%"44" = bitcast i16 %"53" to half
store half %"44", ptr addrspace(5) %"40", align 2
%"46" = load i64, ptr addrspace(5) %"38", align 8
%"55" = inttoptr i64 %"46" to ptr
%"34" = getelementptr inbounds i8, ptr %"55", i64 2
%"56" = load i16, ptr %"34", align 2
%"47" = bitcast i16 %"56" to half
store half %"47", ptr addrspace(5) %"41", align 2
%"49" = load half, ptr addrspace(5) %"40", align 2
%"50" = load half, ptr addrspace(5) %"41", align 2
%2 = call half @llvm.minnum.f16(half %"49", half %"50")
%3 = fcmp uno half %"49", %"50"
%"48" = select i1 %3, half 0xH7E00, half %2
store half %"48", ptr addrspace(5) %"40", align 2
%"51" = load i64, ptr addrspace(5) %"39", align 8
%"52" = load half, ptr addrspace(5) %"40", align 2
%"57" = inttoptr i64 %"51" to ptr
%"58" = bitcast half %"52" to i16
store i16 %"58", ptr %"57", align 2
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare half @llvm.minnum.f16(half, half) #1
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

View file

@ -0,0 +1,23 @@
.version 7.0
.target sm_80
.address_size 64
.visible .entry min_f16(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f16 temp1;
.reg .f16 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.b16 temp1, [in_addr];
ld.b16 temp2, [in_addr+2];
min.f16 temp1, temp1, temp2;
st.b16 [out_addr], temp1;
ret;
}

View file

@ -0,0 +1,23 @@
.version 7.0
.target sm_80
.address_size 64
.visible .entry min_nan_f16(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f16 temp1;
.reg .f16 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.b16 temp1, [in_addr];
ld.b16 temp2, [in_addr+2];
min.NaN.f16 temp1, temp1, temp2;
st.b16 [out_addr], temp1;
ret;
}

View file

@ -143,6 +143,12 @@ test_ptx!(shr_oob, [-32768i16], [-1i16]);
test_ptx!(or, [1u64, 2u64], [3u64]);
test_ptx!(sub, [2u64], [1u64]);
test_ptx!(min, [555i32, 444i32], [444i32]);
test_ptx!(
min_f16,
[half::f16::NAN, half::f16::from_f64(123.0)],
[half::f16::from_f64(123.0)]
);
test_ptx!(min_nan_f16);
test_ptx!(max, [555i32, 444i32], [555i32]);
test_ptx!(global_array, [0xDEADu32], [1u32]);
test_ptx!(global_array_f32, [0x0], [0f32]);