Fix floating point min/max (#399)
Some checks failed
ZLUDA / Build (Linux) (push) Has been cancelled
ZLUDA / Build (Windows) (push) Has been cancelled

This commit is contained in:
Andrzej Janik 2025-07-01 15:58:16 -07:00 committed by GitHub
commit 6d56fa8c34
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 74 additions and 3 deletions

View file

@ -2181,7 +2181,7 @@ impl<'a> MethodEmitContext<'a> {
ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin",
ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin",
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
return Err(error_todo())
"llvm.minimum"
}
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
};
@ -2208,7 +2208,7 @@ impl<'a> MethodEmitContext<'a> {
ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax",
ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax",
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
return Err(error_todo())
"llvm.maximum"
}
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
};

45
ptx/src/test/ll/fmax.ll Normal file
View file

@ -0,0 +1,45 @@
define amdgpu_kernel void @fmax(ptr addrspace(4) byref(i64) %"35", ptr addrspace(4) byref(i64) %"36") #0 {
%"37" = alloca i64, align 8, addrspace(5)
%"38" = alloca i64, align 8, addrspace(5)
%"39" = alloca half, align 2, addrspace(5)
%"40" = alloca half, align 2, addrspace(5)
%"41" = alloca half, align 2, addrspace(5)
%"42" = alloca half, align 2, addrspace(5)
br label %1
1: ; preds = %0
br label %"34"
"34": ; preds = %1
%"43" = load i64, ptr addrspace(4) %"35", align 4
store i64 %"43", ptr addrspace(5) %"37", align 4
%"44" = load i64, ptr addrspace(4) %"36", align 4
store i64 %"44", ptr addrspace(5) %"38", align 4
%"46" = load i64, ptr addrspace(5) %"37", align 4
%"55" = inttoptr i64 %"46" to ptr
%"54" = load i16, ptr %"55", align 2
%"45" = bitcast i16 %"54" to half
store half %"45", ptr addrspace(5) %"39", align 2
%"47" = load i64, ptr addrspace(5) %"37", align 4
%"56" = inttoptr i64 %"47" to ptr
%"33" = getelementptr inbounds i8, ptr %"56", i64 2
%"57" = load i16, ptr %"33", align 2
%"48" = bitcast i16 %"57" to half
store half %"48", ptr addrspace(5) %"40", align 2
%"50" = load half, ptr addrspace(5) %"40", align 2
%"51" = load half, ptr addrspace(5) %"39", align 2
%"49" = call half @llvm.maxnum.f16(half %"50", half %"51")
store half %"49", ptr addrspace(5) %"41", align 2
%"52" = load i64, ptr addrspace(5) %"38", align 4
%"53" = load half, ptr addrspace(5) %"41", align 2
%"58" = inttoptr i64 %"52" to ptr
%"59" = bitcast half %"53" to i16
store i16 %"59", ptr %"58", align 2
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare half @llvm.maxnum.f16(half, half) #1
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

View file

@ -0,0 +1,25 @@
.version 7.0
.target sm_80
.address_size 64
.visible .entry fmax(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f16 temp1;
.reg .f16 temp2;
.reg .f16 result1;
.reg .f16 result2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.b16 temp1, [in_addr];
ld.b16 temp2, [in_addr+2];
max.f16 result1, temp2, temp1;
st.b16 [out_addr], result1;
ret;
}

View file

@ -18,7 +18,7 @@ use std::str;
macro_rules! read_test_file {
($file:expr) => {
{
// CARGO_MANIFEST_DIR is the crate directory (ptx), but file! is relative to the workspace root (and therefore also includes ptx).
// CARGO_MANIFEST_DIR is the crate directory (ptx), but file! is relative to the workspace root (and therefore also includes ptx).
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.pop();
path.push(file!());
@ -175,6 +175,7 @@ test_ptx!(sin, [std::f32::consts::PI / 2f32], [1f32]);
test_ptx!(cos, [std::f32::consts::PI], [-1f32]);
test_ptx!(lg2, [512f32], [9f32]);
test_ptx!(ex2, [10f32], [1024f32]);
test_ptx!(fmax, [0u16, half::f16::NAN.to_bits()], [0u16]);
test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]);
test_ptx!(cvt_rzi, [-13.8f32, 12.9f32], [-13f32, 12f32]);
test_ptx!(cvt_s32_f32, [-13.8f32, 12.9f32], [-13i32, 13i32]);