diff --git a/rpcs3/util/atomic.cpp b/rpcs3/util/atomic.cpp
index 9c0be06044..bcb17a8705 100644
--- a/rpcs3/util/atomic.cpp
+++ b/rpcs3/util/atomic.cpp
@@ -20,6 +20,7 @@
 #include <random>
 
 #include "asm.hpp"
+#include "endian.hpp"
 
 // Total number of entries, should be a power of 2.
 static constexpr std::size_t s_hashtable_size = 1u << 18;
@@ -30,13 +31,18 @@ static thread_local bool(*s_tls_wait_cb)(const void* data) = [](const void*){ re
 // Callback for notification functions for optimizations
 static thread_local void(*s_tls_notify_cb)(const void* data, u64 progress) = [](const void*, u64){};
 
+static inline bool operator &(atomic_wait::op lhs, atomic_wait::op_flag rhs)
+{
+	return !!(static_cast<u8>(lhs) & static_cast<u8>(rhs));
+}
+
 // Compare data in memory with old value, and return true if they are equal
 template <bool CheckCb = true>
 static NEVER_INLINE bool
 #ifdef _WIN32
 __vectorcall
 #endif
-ptr_cmp(const void* data, u32 size, __m128i old128, __m128i mask128, atomic_wait::info* ext = nullptr)
+ptr_cmp(const void* data, u32 _size, __m128i old128, __m128i mask128, atomic_wait::info* ext = nullptr)
 {
 	if constexpr (CheckCb)
 	{
@@ -46,32 +52,138 @@ ptr_cmp(const void* data, u32 size, __m128i old128, __m128i mask128, atomic_wait
 		}
 	}
 
-	const u64 old_value = _mm_cvtsi128_si64(old128);
-	const u64 mask = _mm_cvtsi128_si64(mask128);
+	using atomic_wait::op;
+	using atomic_wait::op_flag;
+
+	const u8 size = static_cast<u8>(_size);
+	const op flag{static_cast<u8>(_size >> 8)};
 
 	bool result = false;
 
-	switch (size)
+	if (size <= 8)
 	{
-	case 1: result = (reinterpret_cast<const atomic_t<u8>*>(data)->load() & mask) == (old_value & mask); break;
-	case 2: result = (reinterpret_cast<const atomic_t<u16>*>(data)->load() & mask) == (old_value & mask); break;
-	case 4: result = (reinterpret_cast<const atomic_t<u32>*>(data)->load() & mask) == (old_value & mask); break;
-	case 8: result = (reinterpret_cast<const atomic_t<u64>*>(data)->load() & mask) == (old_value & mask); break;
-	case 16:
-	{
-		const auto v0 = std::bit_cast<__m128i>(atomic_storage<u128>::load(*reinterpret_cast<const u128*>(data)));
-		const auto v1 = _mm_xor_si128(v0, old128);
-		const auto v2 = _mm_and_si128(v1, mask128);
-		const auto v3 = _mm_packs_epi16(v2, v2);
+		u64 new_value = 0;
+		u64 old_value = _mm_cvtsi128_si64(old128);
+		u64 mask = _mm_cvtsi128_si64(mask128) & (UINT64_MAX >> ((64 - size * 8) & 63));
 
-		result = _mm_cvtsi128_si64(v3) == 0;
-		break;
+		switch (size)
+		{
+		case 1: new_value = reinterpret_cast<const atomic_t<u8>*>(data)->load(); break;
+		case 2: new_value = reinterpret_cast<const atomic_t<u16>*>(data)->load(); break;
+		case 4: new_value = reinterpret_cast<const atomic_t<u32>*>(data)->load(); break;
+		case 8: new_value = reinterpret_cast<const atomic_t<u64>*>(data)->load(); break;
+		default:
+		{
+			fprintf(stderr, "ptr_cmp(): bad size (arg=0x%x)" HERE "\n", _size);
+			std::abort();
+		}
+		}
+
+		if (flag & op_flag::bit_not)
+		{
+			new_value = ~new_value;
+		}
+
+		if (!mask) [[unlikely]]
+		{
+			new_value = 0;
+			old_value = 0;
+		}
+		else
+		{
+			if (flag & op_flag::byteswap)
+			{
+				switch (size)
+				{
+				case 2:
+				{
+					new_value = stx::se_storage<u16>::swap(static_cast<u16>(new_value));
+					old_value = stx::se_storage<u16>::swap(static_cast<u16>(old_value));
+					mask = stx::se_storage<u16>::swap(static_cast<u16>(mask));
+					break;
+				}
+				case 4:
+				{
+					new_value = stx::se_storage<u32>::swap(static_cast<u32>(new_value));
+					old_value = stx::se_storage<u32>::swap(static_cast<u32>(old_value));
+					mask = stx::se_storage<u32>::swap(static_cast<u32>(mask));
+					break;
+				}
+				case 8:
+				{
+					new_value = stx::se_storage<u64>::swap(new_value);
+					old_value = stx::se_storage<u64>::swap(old_value);
+					mask = stx::se_storage<u64>::swap(mask);
+					break;
+				}
+				default:
+				{
+					break;
+				}
+				}
+			}
+
+			// Make the most significant bit the sign bit
+			const auto shv = std::countl_zero(mask);
+			new_value &= mask;
+			old_value &= mask;
+			new_value <<= shv;
+			old_value <<= shv;
+		}
+
+		s64 news = new_value;
+		s64 olds = old_value;
+
+		u64 newa = news < 0 ? (0ull - new_value) : new_value;
+		u64 olda = olds < 0 ? (0ull - old_value) : old_value;
+
+		switch (op{static_cast<u8>(static_cast<u8>(flag) & 0xf)})
+		{
+		case op::eq: result = old_value == new_value; break;
+		case op::slt: result = olds < news; break;
+		case op::sgt: result = olds > news; break;
+		case op::ult: result = old_value < new_value; break;
+		case op::ugt: result = old_value > new_value; break;
+		case op::alt: result = olda < newa; break;
+		case op::agt: result = olda > newa; break;
+		case op::pop:
+		{
+			// Count is taken from the least significant byte and ignores some flags
+			const u64 count = _mm_cvtsi128_si64(old128) & 0xff;
+
+			u64 bitc = new_value;
+			bitc = (bitc & 0xaaaaaaaaaaaaaaaa) / 2 + (bitc & 0x5555555555555555);
+			bitc = (bitc & 0xcccccccccccccccc) / 4 + (bitc & 0x3333333333333333);
+			bitc = (bitc & 0xf0f0f0f0f0f0f0f0) / 16 + (bitc & 0x0f0f0f0f0f0f0f0f);
+			bitc = (bitc & 0xff00ff00ff00ff00) / 256 + (bitc & 0x00ff00ff00ff00ff);
+			bitc = ((bitc & 0xffff0000ffff0000) >> 16) + (bitc & 0x0000ffff0000ffff);
+			bitc = (bitc >> 32) + bitc;
+
+			result = count < bitc;
+			break;
+		}
+		default:
+		{
+			fmt::raw_error("ptr_cmp(): unrecognized atomic wait operation.");
+		}
+		}
 	}
-	default:
+	else if (size == 16 && (flag == op::eq || flag == (op::eq | op_flag::inverse)))
 	{
-		fprintf(stderr, "ptr_cmp(): bad size (size=%u)" HERE "\n", size);
-		std::abort();
+		u128 new_value = atomic_storage<u128>::load(*reinterpret_cast<const u128*>(data));
+		u128 old_value = std::bit_cast<u128>(old128);
+		u128 mask = std::bit_cast<u128>(mask128);
+
+		// TODO
+		result = !((old_value ^ new_value) & mask);
 	}
+	else if (size == 16)
+	{
+		fmt::raw_error("ptr_cmp(): no alternative operations are supported for 16-byte atomic wait yet.");
+	}
+
+	if (flag & op_flag::inverse)
+	{
+		result = !result;
+	}
 
 	// Check other wait variables if provided
@@ -101,16 +213,8 @@ __vectorcall
 #endif
 cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m128i val2)
 {
-	// In force wake up, one of the size arguments is zero (obsolete)
-	const u32 size = std::min(size1, size2);
-
-	if (!size) [[unlikely]]
-	{
-		return 2;
-	}
-
 	// Compare only masks, new value is not available in this mode
-	if ((size1 | size2) == umax)
+	if (size1 == umax)
 	{
 		// Simple mask overlap
 		const auto v0 = _mm_and_si128(mask1, mask2);
@@ -121,6 +225,17 @@ cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m12
 	// Generate masked value inequality bits
 	const auto v0 = _mm_and_si128(_mm_and_si128(mask1, mask2), _mm_xor_si128(val1, val2));
 
+	using atomic_wait::op;
+	using atomic_wait::op_flag;
+
+	const u8 size = std::min(static_cast<u8>(size2), static_cast<u8>(size1));
+	const op flag{static_cast<u8>(size2 >> 8)};
+
+	if (flag != op::eq && flag != (op::eq | op_flag::inverse))
+	{
+		fmt::raw_error("cmp_mask(): no operations are supported for notification with forced value yet.");
+	}
+
 	if (size <= 8)
 	{
 		// Generate sized mask
@@ -128,14 +243,14 @@ cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m12
 		if (!(_mm_cvtsi128_si64(v0) & mask))
 		{
-			return 0;
+			return flag & op_flag::inverse ? 2 : 0;
 		}
 	}
 	else if (size == 16)
 	{
 		if (!_mm_cvtsi128_si64(_mm_packs_epi16(v0, v0)))
 		{
-			return 0;
+			return flag & op_flag::inverse ? 2 : 0;
 		}
 	}
 	else
 	{
@@ -145,7 +260,7 @@ cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m12
 	}
 
 	// Use force wake-up
-	return 2;
+	return flag & op_flag::inverse ? 0 : 2;
 }
 
 static atomic_t<u64> s_min_tsc{0};
@@ -227,7 +342,8 @@ namespace atomic_wait
 	// Temporarily reduced unique tsc stamp to 48 bits to make space for refs (TODO)
 	u64 tsc0 : 48 = 0;
 	u64 link : 16 = 0;
-	u16 size{};
+	u8 size{};
+	u8 flag{};
 	atomic_t<u16> refs{};
 	atomic_t<u32> sync{};
 
@@ -262,6 +378,7 @@ namespace atomic_wait
 		tsc0 = 0;
 		link = 0;
 		size = 0;
+		flag = 0;
 		sync = 0;
 
 #ifdef USE_STD
@@ -868,7 +985,8 @@ atomic_wait_engine::wait(const void* data, u32 size, __m128i old_value, u64 time
 
 	// Store some info for notifiers (some may be unused)
 	cond->link = 0;
-	cond->size = static_cast<u16>(size);
+	cond->size = static_cast<u8>(size);
+	cond->flag = static_cast<u8>(size >> 8);
 	cond->mask = mask;
 	cond->oldv = old_value;
 	cond->tsc0 = stamp0;
@@ -877,7 +995,8 @@ atomic_wait_engine::wait(const void* data, u32 size, __m128i old_value, u64 time
 	{
 		// Extensions point to original cond_id, copy remaining info
 		cond_ext[i]->link = cond_id;
-		cond_ext[i]->size = static_cast<u16>(ext[i].size);
+		cond_ext[i]->size = static_cast<u8>(ext[i].size);
+		cond_ext[i]->flag = static_cast<u8>(ext[i].size >> 8);
 		cond_ext[i]->mask = ext[i].mask;
 		cond_ext[i]->oldv = ext[i].old;
 		cond_ext[i]->tsc0 = stamp0;
@@ -1058,7 +1177,7 @@ alert_sema(u32 cond_id, const void* data, u64 info, u32 size, __m128i mask, __m1
 
 	u32 cmp_res = 0;
 
-	if (cond->sync && (!size ? (!info || cond->tid == info) : (cond->ptr == data && ((cmp_res = cmp_mask(size, mask, new_value, cond->size, cond->mask, cond->oldv))))))
+	if (cond->sync && (!size ? (!info || cond->tid == info) : (cond->ptr == data && ((cmp_res = cmp_mask(size, mask, new_value, cond->size | (cond->flag << 8), cond->mask, cond->oldv))))))
 	{
 		// Redirect if necessary
 		const auto _old = cond;
diff --git a/rpcs3/util/atomic.hpp b/rpcs3/util/atomic.hpp
index 8a81ccff56..681bfef601 100644
--- a/rpcs3/util/atomic.hpp
+++ b/rpcs3/util/atomic.hpp
@@ -14,14 +14,56 @@ enum class atomic_wait_timeout : u64
 	inf = 0xffffffffffffffff,
 };
 
-// Unused externally
+// Various extensions for atomic_t::wait
 namespace atomic_wait
 {
+	// Max number of simultaneous atomic variables to wait on (can be extended if really necessary)
 	constexpr uint max_list = 8;
 
 	struct root_info;
 	struct sema_handle;
 
+	enum class op : u8
+	{
+		eq, // Wait while value is bitwise equal to
+		slt, // Wait while signed value is less than
+		sgt, // Wait while signed value is greater than
+		ult, // Wait while unsigned value is less than
+		ugt, // Wait while unsigned value is greater than
+		alt, // Wait while absolute value is less than
+		agt, // Wait while absolute value is greater than
+		pop, // Wait while set bit count of the value is less than
+		__max
+	};
+
+	static_assert(static_cast<u8>(op::__max) == 8);
+
+	enum class op_flag : u8
+	{
+		inverse = 1 << 4, // Perform inverse operation (negate the result)
+		bit_not = 1 << 5, // Perform bitwise NOT on loaded value before operation
+		byteswap = 1 << 6, // Perform byteswap on both arguments and masks when applicable
+	};
+
+	constexpr op_flag op_be = std::endian::native == std::endian::little ? op_flag::byteswap : op_flag{0};
+	constexpr op_flag op_le = std::endian::native == std::endian::little ? op_flag{0} : op_flag::byteswap;
+
+	constexpr op operator |(op_flag lhs, op_flag rhs)
+	{
+		return op{static_cast<u8>(static_cast<u8>(lhs) | static_cast<u8>(rhs))};
+	}
+
+	constexpr op operator |(op_flag lhs, op rhs)
+	{
+		return op{static_cast<u8>(static_cast<u8>(lhs) | static_cast<u8>(rhs))};
+	}
+
+	constexpr op operator |(op lhs, op_flag rhs)
+	{
+		return op{static_cast<u8>(static_cast<u8>(lhs) | static_cast<u8>(rhs))};
+	}
+
+	constexpr op op_ne = op::eq | op_flag::inverse;
+
 	struct info
 	{
 		const void* data;
@@ -114,24 +156,24 @@ namespace atomic_wait
 			return *this;
 		}
 
-		template <uint Index, typename T2, typename U>
+		template <uint Index, op Flags = op::eq, typename T2, typename U>
 		constexpr void set(atomic_t<T2>& var, U value)
 		{
 			static_assert(Index < Max);
 
 			m_info[Index].data = &var.raw();
-			m_info[Index].size = sizeof(T2);
+			m_info[Index].size = sizeof(T2) | (static_cast<u8>(Flags) << 8);
 			m_info[Index].template set_value<T2>(value);
 			m_info[Index].mask = _mm_set1_epi64x(-1);
 		}
 
-		template <uint Index, typename T2, typename U, typename V>
+		template <uint Index, op Flags = op::eq, typename T2, typename U, typename V>
 		constexpr void set(atomic_t<T2>& var, U value, V mask)
 		{
 			static_assert(Index < Max);
 
 			m_info[Index].data = &var.raw();
-			m_info[Index].size = sizeof(T2);
+			m_info[Index].size = sizeof(T2) | (static_cast<u8>(Flags) << 8);
 			m_info[Index].template set_value<T2>(value);
 			m_info[Index].template set_mask<T2>(mask);
 		}
@@ -1387,34 +1429,36 @@ public:
 	}
 
 	// Timeout is discouraged
+	template <atomic_wait::op Flags = atomic_wait::op::eq>
 	void wait(type old_value, atomic_wait_timeout timeout = atomic_wait_timeout::inf) const noexcept
 	{
 		if constexpr (sizeof(T) <= 8)
 		{
 			const __m128i old = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(old_value));
-			atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
+			atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
 		}
 		else if constexpr (sizeof(T) == 16)
 		{
 			const __m128i old = std::bit_cast<__m128i>(old_value);
-			atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
+			atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
 		}
 	}
 
 	// Overload with mask (only selected bits are checked), timeout is discouraged
+	template <atomic_wait::op Flags = atomic_wait::op::eq>
 	void wait(type old_value, type mask_value, atomic_wait_timeout timeout = atomic_wait_timeout::inf)
 	{
 		if constexpr (sizeof(T) <= 8)
 		{
 			const __m128i old = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(old_value));
 			const __m128i mask = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(mask_value));
-			atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), mask);
+			atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), mask);
 		}
 		else if constexpr (sizeof(T) == 16)
 		{
 			const __m128i old = std::bit_cast<__m128i>(old_value);
 			const __m128i mask = std::bit_cast<__m128i>(mask_value);
-			atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), mask);
+			atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), mask);
 		}
 	}
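
Reviewer note, not part of the patch: a minimal usage sketch of the extended interface. The variable name is hypothetical, and it assumes rpcs3's u32/atomic_t aliases with this patch applied.

    #include "util/atomic.hpp"

    atomic_t<u32> state{0};

    // Unchanged default: the implicit op::eq blocks while state == 0
    state.wait(0);

    // New: op_ne (op::eq | op_flag::inverse) blocks while state != 0
    state.wait<atomic_wait::op_ne>(0);

    // New: ops compose with flags via the constexpr operator |; with the
    // mask overload this blocks while (state & 0xff) != 0
    state.wait<atomic_wait::op::eq | atomic_wait::op_flag::inverse>(0, 0xff);

    // The waking side is unchanged
    state.store(1);
    state.notify_one();

The encoding that makes this work is the second argument of atomic_wait_engine::wait(): the op/flag byte travels in bits 8-15 of the size argument (sizeof(T) | (Flags << 8)), which is also why cond_handle splits its former u16 size into u8 size plus u8 flag.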