diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc
index e9d602c..6dbf916 100644
Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ
diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp
index e9cf904..60a33e1 100644
--- a/ptx/lib/zluda_ptx_impl.cpp
+++ b/ptx/lib/zluda_ptx_impl.cpp
@@ -33,6 +33,26 @@ extern "C"
         return __ockl_bfe_u32(base, pos, len);
     }
 
+    // LLVM contains mentions of llvm.amdgcn.ubfe.i64 and llvm.amdgcn.sbfe.i64,
+    // but using it only leads to LLVM crashes on RDNA2
+    //
+    // Unsigned 64 bit bit-field extract: returns `len` bits of `base` starting
+    // at bit `pos`, zero-extended. Bits past bit 63 read as zero.
+    uint64_t FUNC(bfe_u64)(uint64_t base, uint32_t pos, uint32_t len)
+    {
+        // NVIDIA docs are incorrect. In 64 bit `bfe` both `pos` and `len`
+        // parameters use whole 32 bit number and not just bottom 8 bits
+        if (pos >= 64)
+            return 0;
+        // Field reaches (or passes) the top bit: no mask needed, and a
+        // shift by `len` >= 64 below would be undefined
+        if (len >= 64)
+            return base >> pos;
+        // `len` is in 0..=63 here. The mask is built from an explicit 64 bit
+        // one, because `1UL` is only 32 bits wide on LLP64 targets
+        return (base >> pos) & ((uint64_t(1) << len) - uint64_t(1));
+    }
+
     int32_t __ockl_bfe_i32(int32_t, uint32_t, uint32_t) __attribute__((device));
     int32_t FUNC(bfe_s32)(int32_t base, uint32_t pos_32, uint32_t len_32)
     {
@@ -49,23 +69,55 @@ extern "C"
         return __ockl_bfe_i32(base, pos, len);
     }
 
-    // LLVM contains mentions of llvm.amdgcn.ubfe.i64 and llvm.amdgcn.sbfe.i64,
-    // but using it only leads to LLVM crashes on RDNA2
-    uint64_t FUNC(bfe_u64)(uint64_t base, uint32_t b, uint32_t c)
+    // Returns `x + y`, clamped to UINT32_MAX if the sum does not fit in 32 bits
+    static __device__ uint32_t add_sat(uint32_t x, uint32_t y)
     {
-        uint8_t pos = uint8_t(b);
-        uint8_t len = uint8_t(c);
-        if (len == 0)
-            return 0;
-        return (base >> pos) & ((1U << len) - 1U);
+        uint32_t result;
+        if (__builtin_add_overflow(x, y, &result))
+        {
+            return UINT32_MAX;
+        }
+        else
+        {
+            return result;
+        }
     }
 
-    int64_t FUNC(bfe_s64)(int64_t base, uint32_t b, uint32_t c)
+    // Returns `x - y`, clamped to 0 if `y` is greater than `x`
+    static __device__ uint32_t sub_sat(uint32_t x, uint32_t y)
     {
-        uint8_t pos = uint8_t(b);
-        uint8_t len = uint8_t(c);
+        uint32_t result;
+        if (__builtin_sub_overflow(x, y, &result))
+        {
+            return 0;
+        }
+        else
+        {
+            return result;
+        }
+    }
+
+    // Signed 64 bit bit-field extract: returns `len` bits of `base` starting
+    // at bit `pos`, sign-extended from the topmost bit of the extracted field.
+    // A field that runs past bit 63 is cut off there
+    int64_t FUNC(bfe_s64)(int64_t base, uint32_t pos, uint32_t len)
+    {
+        // NVIDIA docs are incorrect. In 64 bit `bfe` both `pos` and `len`
+        // parameters use whole 32 bit number and not just bottom 8 bits
         if (len == 0)
             return 0;
-        return (base >> pos) & ((1U << len) - 1U);
+        // Field starts past the top bit: the result is the sign bit of `base`
+        // copied into every position
+        if (pos >= 64)
+            return (base >> 63U);
+        // Field would run past bit 63: shorten it so it ends exactly there
+        if (add_sat(pos, len) >= 64)
+            len = sub_sat(64, pos);
+        // Move the top bit of the field to bit 63, then shift back down so that
+        // the signed right shift fills the upper bits with it. The left shift is
+        // done on the unsigned value: left-shifting a negative signed integer
+        // is undefined behavior before C++20
+        uint64_t aligned = uint64_t(base) << (64U - pos - len);
+        return int64_t(aligned) >> (64U - len);
     }
 }