Fix how full-precision fp32 sqrt and div are handled (#467)

Previously, when compiling full-precision `sqrt`/`div`, we'd leave it to LLVM. LLVM looks at the module's `denormal-fp-math-f32` mode, which is incompatible with how we handle denormals and could give wrong results in certain edge cases.
Instead, handle it fully inside ZLUDA.
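
For context, a minimal host-side sketch (not part of this change, plain C++) of the kind of edge case involved: the square root of an fp32 subnormal is a perfectly normal number, so preserving versus flushing denormal inputs changes the result from roughly 2^-70 to 0.

#include <cmath>
#include <cstdio>

int main()
{
    float x = 0x1.0p-140f; // subnormal fp32 input
    // Prints about 0x1p-70 when subnormals are preserved; under flush-to-zero the
    // input would be treated as 0 and the square root would be 0.
    std::printf("sqrt(%a) = %a\n", x, std::sqrt(x));
    return 0;
}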
Andrzej Janik 2025-08-15 02:24:40 +02:00 committed by GitHub
commit 65367f04ee
18 changed files with 1092 additions and 139 deletions


@@ -10,7 +10,9 @@
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
#define DECLARE_ATTR(TYPE, NAME) extern const TYPE ATTR(NAME) __device__
#define DECLARE_ATTR(TYPE, NAME) \
extern const TYPE ATTR(NAME) \
__device__
extern "C"
{
@@ -100,19 +102,6 @@ extern "C"
}
}
static __device__ uint32_t sub_sat(uint32_t x, uint32_t y)
{
uint32_t result;
if (__builtin_sub_overflow(x, y, &result))
{
return 0;
}
else
{
return result;
}
}
int64_t FUNC(bfe_s64)(int64_t base, uint32_t pos, uint32_t len)
{
// NVIDIA docs are incorrect. In 64 bit `bfe` both `pos` and `len`
@@ -122,7 +111,7 @@ extern "C"
if (pos >= 64)
return (base >> 63U);
if (add_sat(pos, len) >= 64)
len = sub_sat(64, pos);
len = 64 - pos;
return (base << (64U - pos - len)) >> (64U - len);
}
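// (Illustrative example, not part of the diff: pos = 60, len = 8 would run past bit 63, so
// len is clamped to 64 - pos = 4 and the extract returns the sign-extended top four bits.)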
@@ -174,11 +163,8 @@ extern "C"
BAR_RED_IMPL(and);
BAR_RED_IMPL(or);
struct ShflSyncResult
{
uint32_t output;
bool in_bounds;
};
typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
// shfl.sync opts consists of two values, the warp end ID and the subsection mask.
//
@@ -192,7 +178,6 @@ extern "C"
// The warp end ID is the max lane ID for a specific mode. For the CUDA __shfl_sync
// intrinsics, it is always 31 for idx, bfly, and down, and 0 for up. This is used for the
// bounds check.
#define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \
ShflSyncResult FUNC(shfl_sync_##mode##_b32_pred)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask __attribute__((unused))) \
{ \
@@ -208,12 +193,12 @@ extern "C"
idx = self; \
} \
int32_t output = __builtin_amdgcn_ds_bpermute(idx << 2, (int32_t)input); \
return {(uint32_t)output, !out_of_bounds}; \
return {(uint32_t)output, uint32_t(!out_of_bounds)}; \
} \
\
uint32_t FUNC(shfl_sync_##mode##_b32)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask) \
{ \
return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).output; \
return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).x; \
}
// We are using the HIP __shfl intrinsics to implement these, rather than the __shfl_sync
@@ -226,7 +211,8 @@ extern "C"
SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >);
DECLARE_ATTR(uint32_t, CLOCK_RATE);
void FUNC(nanosleep_u32)(uint32_t nanoseconds) {
void FUNC(nanosleep_u32)(uint32_t nanoseconds)
{
// clock_rate is in kHz
uint64_t cycles_per_ns = ATTR(CLOCK_RATE) / 1000000;
uint64_t cycles = nanoseconds * cycles_per_ns;
@@ -335,4 +321,157 @@ extern "C"
else
return value;
}
// Logic taken from legalizeFSQRTF32/lowerFSQRTF32 in LLVM AMDGPU target
__device__ static float precise_square_root(float x, bool needs_denorm_handling)
{
// Constants for denormal handling
const float scale_threshold = 0x1.0p-96f; // Very small value threshold
const float scale_up_factor = 0x1.0p+32f; // 2^32
const float scale_down_factor = 0x1.0p-16f; // 2^-16
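// (Illustrative note, not part of the diff: scaling the input by 2^32 makes the square root
// exactly 2^16 times too large, since sqrt(2^32 * x) = 2^16 * sqrt(x); the 2^-16 factor
// undoes that at the end.)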
// Check if input needs scaling (for very small values)
bool need_scale = scale_threshold > x;
auto scaled = scale_up_factor * x;
// Scale up input if needed
float sqrt_x = need_scale ? scaled : x;
float sqrt_s;
// Check if we need special denormal handling
if (needs_denorm_handling)
{
// Use hardware sqrt as initial approximation
sqrt_s = __builtin_sqrtf(sqrt_x); // Or equivalent hardware instruction
// Bit manipulations to get next values up/down
uint32_t sqrt_s_bits = std::bit_cast<uint32_t>(sqrt_s);
// Next value down (subtract 1 from bit pattern)
uint32_t sqrt_s_next_down_bits = sqrt_s_bits - 1;
float sqrt_s_next_down = std::bit_cast<float>(sqrt_s_next_down_bits);
// Calculate residual: x - sqrt_next_down * sqrt
float neg_sqrt_s_next_down = -sqrt_s_next_down;
float sqrt_vp = std::fma(neg_sqrt_s_next_down, sqrt_s, sqrt_x);
// Next value up (add 1 to bit pattern)
uint32_t sqrt_s_next_up_bits = sqrt_s_bits + 1;
float sqrt_s_next_up = std::bit_cast<float>(sqrt_s_next_up_bits);
// Calculate residual: x - sqrt_next_up * sqrt
float neg_sqrt_s_next_up = -sqrt_s_next_up;
float sqrt_vs = std::fma(neg_sqrt_s_next_up, sqrt_s, sqrt_x);
// Select correctly rounded result
if (sqrt_vp <= 0.0f)
{
sqrt_s = sqrt_s_next_down;
}
if (sqrt_vs > 0.0f)
{
sqrt_s = sqrt_s_next_up;
}
}
else
{
// Use Newton-Raphson method with reciprocal square root
// Initial approximation
float sqrt_r = __builtin_amdgcn_rsqf(sqrt_x); // Or equivalent hardware 1/sqrt instruction
sqrt_s = sqrt_x * sqrt_r;
// Refine approximation
float half = 0.5f;
float sqrt_h = sqrt_r * half;
float neg_sqrt_h = -sqrt_h;
// Calculate error term
float sqrt_e = std::fma(neg_sqrt_h, sqrt_s, half);
// First refinement
sqrt_h = std::fma(sqrt_h, sqrt_e, sqrt_h);
sqrt_s = std::fma(sqrt_s, sqrt_e, sqrt_s);
// Second refinement
float neg_sqrt_s = -sqrt_s;
float sqrt_d = std::fma(neg_sqrt_s, sqrt_s, sqrt_x);
sqrt_s = std::fma(sqrt_d, sqrt_h, sqrt_s);
}
// Scale back if input was scaled
if (need_scale)
{
sqrt_s *= scale_down_factor;
}
// Special case handling for zero and infinity
bool is_zero_or_inf = __builtin_isfpclass(sqrt_x, __FPCLASS_POSINF | __FPCLASS_POSZERO | __FPCLASS_NEGZERO);
return is_zero_or_inf ? sqrt_x : sqrt_s;
}
float FUNC(sqrt_rn_f32)(float x)
{
return precise_square_root(x, true);
}
float FUNC(sqrt_rn_ftz_f32)(float x)
{
return precise_square_root(x, false);
}
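// (Illustrative note, not part of the diff: the `true` path backs `sqrt.rn.f32`, where subnormal
// inputs must be preserved, so the hardware estimate is corrected by testing its +/-1 ulp
// neighbours; the `false` path backs `sqrt.rn.ftz.f32`, where flush-to-zero makes the cheaper
// rsq-based Newton-Raphson refinement sufficient.)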
struct DivRnFtzF32Part1Result
{
float fma_4;
float fma_1;
float fma_3;
uint8_t numerator_scaled_flag;
};
DivRnFtzF32Part1Result FUNC(div_f32_part1)(float lhs, float rhs)
{
float one = 1.0f;
// Division scale operations
bool denominator_scaled_flag;
float denominator_scaled = __builtin_amdgcn_div_scalef(lhs, rhs, false, &denominator_scaled_flag);
bool numerator_scaled_flag;
float numerator_scaled = __builtin_amdgcn_div_scalef(lhs, rhs, true, &numerator_scaled_flag);
// Reciprocal approximation
float approx_rcp = __builtin_amdgcn_rcpf(denominator_scaled);
float neg_div_scale0 = -denominator_scaled;
// Perform division approximation steps
float fma_0 = fmaf(neg_div_scale0, approx_rcp, one);
float fma_1 = fmaf(fma_0, approx_rcp, approx_rcp);
float mul = numerator_scaled * fma_1;
float fma_2 = fmaf(neg_div_scale0, mul, numerator_scaled);
float fma_3 = fmaf(fma_2, fma_1, mul);
float fma_4 = fmaf(neg_div_scale0, fma_3, numerator_scaled);
return {fma_4, fma_1, fma_3, numerator_scaled_flag};
}
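// (Illustrative summary, not part of the diff: fma_0/fma_1 apply one Newton-Raphson step to the
// hardware reciprocal of the scaled denominator, mul/fma_2/fma_3 build and refine a quotient
// estimate, and fma_4 is that estimate's final residual; part2 below hands these to
// div_fmas/div_fixup for the correctly rounded, fixed-up result.)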
__device__ static float div_f32_part2(float x, float y, DivRnFtzF32Part1Result part1)
{
float fmas = __builtin_amdgcn_div_fmasf(part1.fma_4, part1.fma_1, part1.fma_3, part1.numerator_scaled_flag);
float result = __builtin_amdgcn_div_fixupf(fmas, y, x);
return result;
}
float FUNC(div_f32_part2)(float x, float y,
float fma_4,
float fma_1,
float fma_3,
uint8_t numerator_scaled_flag)
{
return div_f32_part2(x, y, {fma_4, fma_1, fma_3, numerator_scaled_flag});
}
}
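
For illustration only (not part of the diff), the two halves are intended to be chained by the compiler-generated code. A hedged sketch of the composition, assuming ZLUDA adjusts the f32 denormal mode between the two calls (the wrapper name here is hypothetical; the callee names come from the diff above):

__device__ static float div_rn_f32_sketch(float x, float y)
{
    // Part 1: scaled reciprocal refinement and residuals.
    DivRnFtzF32Part1Result part1 = __zluda_ptx_impl_div_f32_part1(x, y);
    // ...the emitted code presumably switches the f32 denormal mode here...
    // Part 2: div_fmas + div_fixup produce the correctly rounded quotient.
    return __zluda_ptx_impl_div_f32_part2(x, y, part1.fma_4, part1.fma_1, part1.fma_3,
                                          part1.numerator_scaled_flag);
}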