mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-10-04 15:19:37 +00:00
Fix how full-precision fp32 sqrt and div are handled (#467)
Previously, when compiling full-precision `sqrt`/`div`, we'd leave it to LLVM. LLVM looks at the module's `denormal-fp-math-f32` mode, which is incompatible with how we handle denormals and could give wrong results in certain edge cases. Instead, handle it fully inside ZLUDA.
This commit is contained in:
parent
a420601128
commit
65367f04ee
18 changed files with 1092 additions and 139 deletions
|
@ -10,7 +10,9 @@
|
|||
|
||||
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
|
||||
#define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
|
||||
#define DECLARE_ATTR(TYPE, NAME) extern const TYPE ATTR(NAME) __device__
|
||||
#define DECLARE_ATTR(TYPE, NAME) \
|
||||
extern const TYPE ATTR(NAME) \
|
||||
__device__
|
||||
|
||||
extern "C"
|
||||
{
|
||||
|
@ -100,19 +102,6 @@ extern "C"
|
|||
}
|
||||
}
|
||||
|
||||
// Saturating unsigned 32-bit subtraction: x - y, clamped to 0 instead of
// wrapping when y > x.
static __device__ uint32_t sub_sat(uint32_t x, uint32_t y)
{
    uint32_t difference;
    // For unsigned subtraction the builtin reports overflow exactly when y > x.
    bool underflowed = __builtin_sub_overflow(x, y, &difference);
    return underflowed ? 0 : difference;
}
|
||||
|
||||
int64_t FUNC(bfe_s64)(int64_t base, uint32_t pos, uint32_t len)
|
||||
{
|
||||
// NVIDIA docs are incorrect. In 64 bit `bfe` both `pos` and `len`
|
||||
|
@ -122,7 +111,7 @@ extern "C"
|
|||
if (pos >= 64)
|
||||
return (base >> 63U);
|
||||
if (add_sat(pos, len) >= 64)
|
||||
len = sub_sat(64, pos);
|
||||
len = 64 - pos;
|
||||
return (base << (64U - pos - len)) >> (64U - len);
|
||||
}
|
||||
|
||||
|
@ -174,11 +163,8 @@ extern "C"
|
|||
BAR_RED_IMPL(and);
|
||||
BAR_RED_IMPL(or);
|
||||
|
||||
struct ShflSyncResult
|
||||
{
|
||||
uint32_t output;
|
||||
bool in_bounds;
|
||||
};
|
||||
|
||||
typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2)));
|
||||
|
||||
// shfl.sync opts consists of two values, the warp end ID and the subsection mask.
|
||||
//
|
||||
|
@ -192,7 +178,6 @@ extern "C"
|
|||
// The warp end ID is the max lane ID for a specific mode. For the CUDA __shfl_sync
|
||||
// intrinsics, it is always 31 for idx, bfly, and down, and 0 for up. This is used for the
|
||||
// bounds check.
|
||||
|
||||
#define SHFL_SYNC_IMPL(mode, calculate_index, CMP) \
|
||||
ShflSyncResult FUNC(shfl_sync_##mode##_b32_pred)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask __attribute__((unused))) \
|
||||
{ \
|
||||
|
@ -208,12 +193,12 @@ extern "C"
|
|||
idx = self; \
|
||||
} \
|
||||
int32_t output = __builtin_amdgcn_ds_bpermute(idx << 2, (int32_t)input); \
|
||||
return {(uint32_t)output, !out_of_bounds}; \
|
||||
return {(uint32_t)output, uint32_t(!out_of_bounds)}; \
|
||||
} \
|
||||
\
|
||||
uint32_t FUNC(shfl_sync_##mode##_b32)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask) \
|
||||
{ \
|
||||
return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).output; \
|
||||
return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).x; \
|
||||
}
|
||||
|
||||
// We are using the HIP __shfl intrinsics to implement these, rather than the __shfl_sync
|
||||
|
@ -226,7 +211,8 @@ extern "C"
|
|||
SHFL_SYNC_IMPL(idx, (delta & ~section_mask) | subsection, >);
|
||||
|
||||
DECLARE_ATTR(uint32_t, CLOCK_RATE);
|
||||
void FUNC(nanosleep_u32)(uint32_t nanoseconds) {
|
||||
void FUNC(nanosleep_u32)(uint32_t nanoseconds)
|
||||
{
|
||||
// clock_rate is in kHz
|
||||
uint64_t cycles_per_ns = ATTR(CLOCK_RATE) / 1000000;
|
||||
uint64_t cycles = nanoseconds * cycles_per_ns;
|
||||
|
@ -335,4 +321,157 @@ extern "C"
|
|||
else
|
||||
return value;
|
||||
}
|
||||
|
||||
// Correctly-rounded fp32 square root.
// Logic mirrors legalizeFSQRTF32/lowerFSQRTF32 in the LLVM AMDGPU target so
// the result does not depend on the module's `denormal-fp-math-f32` mode.
__device__ static float precise_square_root(float x, bool needs_denorm_handling)
{
    // Inputs below 2^-96 are pre-multiplied by 2^32 so the hardware
    // instructions see a normalized value; the result is rescaled by 2^-16
    // afterwards (sqrt halves the exponent adjustment).
    const float kSmallThreshold = 0x1.0p-96f;
    const float kScaleUp = 0x1.0p+32f;   // 2^32
    const float kScaleDown = 0x1.0p-16f; // 2^-16

    const bool was_scaled = kSmallThreshold > x;
    const float scaled_input = kScaleUp * x;
    const float arg = was_scaled ? scaled_input : x;

    float result;

    if (needs_denorm_handling)
    {
        // Start from the hardware sqrt, then nudge one ULP in each direction
        // and pick whichever candidate is correctly rounded.
        result = __builtin_sqrtf(arg);

        const uint32_t result_bits = std::bit_cast<uint32_t>(result);
        const float one_ulp_down = std::bit_cast<float>(result_bits - 1);
        const float one_ulp_up = std::bit_cast<float>(result_bits + 1);

        // Residuals arg - candidate * result, each computed with a single fma.
        const float residual_down = std::fma(-one_ulp_down, result, arg);
        const float residual_up = std::fma(-one_ulp_up, result, arg);

        if (residual_down <= 0.0f)
        {
            result = one_ulp_down;
        }
        if (residual_up > 0.0f)
        {
            result = one_ulp_up;
        }
    }
    else
    {
        // Newton-Raphson refinement seeded by the hardware reciprocal sqrt.
        const float half = 0.5f;
        const float rsq = __builtin_amdgcn_rsqf(arg);
        result = arg * rsq;
        float half_rsq = rsq * half;

        // Error term of the current approximation.
        const float err = std::fma(-half_rsq, result, half);

        // First refinement step.
        half_rsq = std::fma(half_rsq, err, half_rsq);
        result = std::fma(result, err, result);

        // Second refinement step.
        const float delta = std::fma(-result, result, arg);
        result = std::fma(delta, half_rsq, result);
    }

    // Undo the input scaling, if any.
    if (was_scaled)
    {
        result *= kScaleDown;
    }

    // +-0 and +inf must pass through unchanged.
    const bool pass_through = __builtin_isfpclass(arg, __FPCLASS_POSINF | __FPCLASS_POSZERO | __FPCLASS_NEGZERO);
    return pass_through ? arg : result;
}
|
||||
|
||||
// PTX `sqrt.rn.f32`: round-to-nearest square root with denormal support.
float FUNC(sqrt_rn_f32)(float x)
{
    return precise_square_root(x, /*needs_denorm_handling=*/true);
}
|
||||
|
||||
// PTX `sqrt.rn.ftz.f32`: round-to-nearest square root, flush-to-zero variant.
float FUNC(sqrt_rn_ftz_f32)(float x)
{
    return precise_square_root(x, /*needs_denorm_handling=*/false);
}
|
||||
|
||||
// Intermediate state of the full-precision fp32 division expansion, produced
// by `div_f32_part1` and consumed by `div_f32_part2`. Field names follow the
// fma chain positions in part 1.
struct DivRnFtzF32Part1Result
{
    // Final residual: numerator - denominator * refined quotient.
    float fma_4;
    // Refined reciprocal of the scaled denominator.
    float fma_1;
    // Refined quotient approximation.
    float fma_3;
    // Numerator scale flag reported by __builtin_amdgcn_div_scalef
    // (bool widened to uint8_t).
    uint8_t numerator_scaled_flag;
};
|
||||
|
||||
// First half of the full-precision fp32 division expansion (mirrors the
// AMDGPU v_div_scale / reciprocal-refinement lowering). Produces the fma
// chain that `div_f32_part2` feeds into div_fmas/div_fixup.
DivRnFtzF32Part1Result FUNC(div_f32_part1)(float lhs, float rhs)
{
    // Scale denominator and numerator into a range where the reciprocal
    // approximation is accurate; the hardware reports whether it scaled.
    bool denom_was_scaled; // reported but unused: only the numerator flag feeds part 2
    const float denom = __builtin_amdgcn_div_scalef(lhs, rhs, false, &denom_was_scaled);

    bool numer_was_scaled;
    const float numer = __builtin_amdgcn_div_scalef(lhs, rhs, true, &numer_was_scaled);

    // Hardware reciprocal approximation of the scaled denominator.
    const float rcp = __builtin_amdgcn_rcpf(denom);
    const float neg_denom = -denom;

    // Newton-Raphson refinement of the reciprocal, then of the quotient.
    const float rcp_err = fmaf(neg_denom, rcp, 1.0f);
    const float refined_rcp = fmaf(rcp_err, rcp, rcp);
    const float quotient = numer * refined_rcp;
    const float quot_err = fmaf(neg_denom, quotient, numer);
    const float refined_quotient = fmaf(quot_err, refined_rcp, quotient);
    const float residual = fmaf(neg_denom, refined_quotient, numer);

    return {residual, refined_rcp, refined_quotient, numer_was_scaled};
}
|
||||
|
||||
// Second half of the fp32 division expansion: applies the hardware
// div_fmas/div_fixup steps to the state gathered in part 1.
__device__ static float div_f32_part2(float x, float y, DivRnFtzF32Part1Result part1)
{
    const float fmas = __builtin_amdgcn_div_fmasf(part1.fma_4, part1.fma_1, part1.fma_3, part1.numerator_scaled_flag);
    return __builtin_amdgcn_div_fixupf(fmas, y, x);
}
|
||||
|
||||
// Flattened entry point: the PTX lowering passes the part-1 state as
// scalars, so re-pack them and defer to the struct-based static helper.
float FUNC(div_f32_part2)(float x, float y,
                          float fma_4,
                          float fma_1,
                          float fma_3,
                          uint8_t numerator_scaled_flag)
{
    DivRnFtzF32Part1Result part1 = {fma_4, fma_1, fma_3, numerator_scaled_flag};
    return div_f32_part2(x, y, part1);
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue