diff --git a/Utilities/Thread.cpp b/Utilities/Thread.cpp index 77f876ea0f..e372d1e924 100644 --- a/Utilities/Thread.cpp +++ b/Utilities/Thread.cpp @@ -1609,6 +1609,8 @@ static void prepare_throw_access_violation(x64_context* context, const char* cau RIP(context) = (u64)std::addressof(throw_access_violation); } +static void _handle_interrupt(x64_context* ctx); + #ifdef _WIN32 static LONG exception_handler(PEXCEPTION_POINTERS pExp) @@ -1698,6 +1700,11 @@ static void signal_handler(int sig, siginfo_t* info, void* uct) { x64_context* context = (ucontext_t*)uct; + if (sig == SIGUSR1) + { + return _handle_interrupt(context); + } + #ifdef __APPLE__ const bool is_writing = context->uc_mcontext->__es.__err & 0x2; #else @@ -1735,7 +1742,15 @@ const bool s_exception_handler_set = []() -> bool if (::sigaction(SIGSEGV, &sa, NULL) == -1) { - std::printf("sigaction() failed (0x%x).", errno); + std::printf("sigaction(SIGSEGV) failed (0x%x).", errno); + std::abort(); + } + + sa.sa_sigaction = signal_handler; + + if (::sigaction(SIGUSR1, &sa, NULL) == -1) + { + std::printf("sigaction(SIGUSR1) failed (0x%x).", errno); std::abort(); } @@ -1767,13 +1782,23 @@ struct thread_ctrl::internal { std::mutex mutex; std::condition_variable cond; - std::condition_variable join; // Allows simultaneous joining + std::condition_variable jcv; // Allows simultaneous joining + std::condition_variable icv; task_stack atexit; - std::exception_ptr exception; // Caught exception + std::exception_ptr exception; // Stored exception std::chrono::high_resolution_clock::time_point time_limit; + +#ifdef _WIN32 + DWORD thread_id = 0; + x64_context _context{}; +#endif + + x64_context* thread_ctx{}; + + atomic_t interrupt{}; // Interrupt function }; thread_local thread_ctrl::internal* g_tls_internal = nullptr; @@ -1804,7 +1829,6 @@ void thread_ctrl::start(const std::shared_ptr& ctrl, task_stack tas } catch (...) { - ctrl->initialize_once(); ctrl->m_data->exception = std::current_exception(); } @@ -1814,15 +1838,11 @@ void thread_ctrl::start(const std::shared_ptr& ctrl, task_stack tas void thread_ctrl::wait_start(u64 timeout) { - initialize_once(); - m_data->time_limit = std::chrono::high_resolution_clock::now() + std::chrono::microseconds(timeout); } bool thread_ctrl::wait_wait(u64 timeout) { - initialize_once(); - std::unique_lock lock(m_data->mutex, std::adopt_lock); if (timeout && m_data->cond.wait_until(lock, m_data->time_limit) == std::cv_status::timeout) @@ -1846,11 +1866,12 @@ void thread_ctrl::test() void thread_ctrl::initialize() { - initialize_once(); // TODO (temporarily) - // Initialize TLS variable g_tls_this_thread = this; g_tls_internal = this->m_data; +#ifdef _WIN32 + m_data->thread_id = GetCurrentThreadId(); +#endif g_tls_log_prefix = [] { @@ -1892,6 +1913,10 @@ void thread_ctrl::initialize() void thread_ctrl::finalize() noexcept { + // Disable and discard possible interrupts + interrupt_disable(); + test_interrupt(); + // TODO vm::reservation_free(); @@ -1909,7 +1934,6 @@ void thread_ctrl::finalize() noexcept void thread_ctrl::push_atexit(task_stack task) { - initialize_once(); m_data->atexit.push(std::move(task)); } @@ -1922,6 +1946,8 @@ thread_ctrl::thread_ctrl(std::string&& name) #undef new new (&m_thread) std::thread; #pragma pop_macro("new") + + initialize_once(); } thread_ctrl::~thread_ctrl() @@ -1967,24 +1993,20 @@ void thread_ctrl::join() // Notify others if necessary if (UNLIKELY(m_joining.exchange(0x80000000) != 1)) { - initialize_once(); - // Serialize for reliable notification m_data->mutex.lock(); m_data->mutex.unlock(); - m_data->join.notify_all(); + m_data->jcv.notify_all(); } } else { // Hard way - initialize_once(); - std::unique_lock lock(m_data->mutex); - m_data->join.wait(lock, WRAP_EXPR(m_joining >= 0x80000000)); + m_data->jcv.wait(lock, WRAP_EXPR(m_joining >= 0x80000000)); } - if (UNLIKELY(m_data && m_data->exception)) + if (UNLIKELY(m_data && m_data->exception && !std::uncaught_exception())) { std::rethrow_exception(m_data->exception); } @@ -1992,7 +2014,6 @@ void thread_ctrl::join() void thread_ctrl::lock() { - initialize_once(); m_data->mutex.lock(); } @@ -2008,8 +2029,6 @@ void thread_ctrl::lock_notify() return; } - initialize_once(); - // Serialize for reliable notification, condition is assumed to be changed externally m_data->mutex.lock(); m_data->mutex.unlock(); @@ -2026,6 +2045,116 @@ void thread_ctrl::set_exception(std::exception_ptr e) m_data->exception = e; } +static void _handle_interrupt(x64_context* ctx) +{ + g_tls_internal->thread_ctx = ctx; + thread_ctrl::handle_interrupt(); +} + +void thread_ctrl::handle_interrupt() +{ + const auto _this = g_tls_this_thread; + const auto ctx = g_tls_internal->thread_ctx; + + if (_this->m_guard & 0x80000000) + { + // Discard interrupt if interrupts are disabled + if (g_tls_internal->interrupt.exchange(nullptr)) + { + _this->lock(); + _this->unlock(); + g_tls_internal->icv.notify_one(); + } + } + else if (_this->m_guard == 0) + { + // Set interrupt immediately if no guard set + if (const auto handler = g_tls_internal->interrupt.exchange(nullptr)) + { + _this->lock(); + _this->unlock(); + g_tls_internal->icv.notify_one(); + + // Install function call + *--(u64*&)(RSP(ctx)) = RIP(ctx); + RIP(ctx) = (u64)handler; + } + } + else + { + // Set delayed interrupt otherwise + _this->m_guard |= 0x40000000; + } + +#ifdef _WIN32 + RtlRestoreContext(ctx, nullptr); +#endif +} + +void thread_ctrl::interrupt(void(*handler)()) +{ + VERIFY(this != g_tls_this_thread); // TODO: self-interrupt + VERIFY(m_data->interrupt.compare_and_swap_test(nullptr, handler)); // TODO: multiple interrupts + +#ifdef _WIN32 + const auto ctx = &m_data->_context; + m_data->thread_ctx = ctx; + + const HANDLE nt = OpenThread(THREAD_ALL_ACCESS, FALSE, m_data->thread_id); + VERIFY(nt); + VERIFY(SuspendThread(nt) != -1); + + ctx->ContextFlags = CONTEXT_FULL; + VERIFY(GetThreadContext(nt, ctx)); + + ctx->ContextFlags = CONTEXT_FULL; + const u64 _rip = RIP(ctx); + RIP(ctx) = (u64)std::addressof(thread_ctrl::handle_interrupt); + VERIFY(SetThreadContext(nt, ctx)); + + RIP(ctx) = _rip; + VERIFY(ResumeThread(nt) != -1); + CloseHandle(nt); +#else + pthread_kill(reinterpret_cast(m_thread).native_handle(), SIGUSR1); +#endif + + std::unique_lock lock(m_data->mutex, std::adopt_lock); + + while (m_data->interrupt) + { + m_data->icv.wait(lock); + } + + lock.release(); +} + +void thread_ctrl::test_interrupt() +{ + if (m_guard & 0x80000000) + { + if (m_data->interrupt.exchange(nullptr)) + { + lock(), unlock(), m_data->icv.notify_one(); + } + + return; + } + + if (m_guard == 0x40000000 && !std::uncaught_exception()) + { + m_guard = 0; + + // Execute delayed interrupt handler + if (const auto handler = m_data->interrupt.exchange(nullptr)) + { + lock(), unlock(), m_data->icv.notify_one(); + + return handler(); + } + } +} + void thread_ctrl::sleep(u64 useconds) { std::this_thread::sleep_for(std::chrono::microseconds(useconds)); diff --git a/Utilities/Thread.h b/Utilities/Thread.h index 46782c8c98..4e560d6f14 100644 --- a/Utilities/Thread.h +++ b/Utilities/Thread.h @@ -94,6 +94,9 @@ private: // Thread join contention counter atomic_t m_joining{}; + // Thread interrupt guard counter + u32 m_guard = 0x80000000; + // Thread internals atomic_t m_data{}; @@ -187,6 +190,42 @@ public: // Set exception (internal data must be initialized, thread mutex must be locked) void set_exception(std::exception_ptr); + // Internal + static void handle_interrupt(); + + // Interrupt thread with specified handler call (thread mutex must be locked) + void interrupt(void(*handler)()); + + // Interrupt guard recursive enter + void guard_enter() + { + m_guard++; + } + + // Interrupt guard recursive leave + void guard_leave() + { + if (UNLIKELY(--m_guard & 0x40000000)) + { + test_interrupt(); + } + } + + // Allow interrupts + void interrupt_enable() + { + m_guard &= ~0x80000000; + } + + // Disable and discard any interrupt + void interrupt_disable() + { + m_guard |= 0x80000000; + } + + // Check interrupt if delayed by guard scope + void test_interrupt(); + // Current thread sleeps for specified amount of microseconds. // Wrapper for std::this_thread::sleep, doesn't require valid thread_ctrl. [[deprecated]] static void sleep(u64 useconds); @@ -352,6 +391,36 @@ public: } }; +// Interrupt guard scope +class thread_guard final +{ + thread_ctrl* m_thread; + +public: + thread_guard(const thread_guard&) = delete; + + thread_guard(thread_ctrl* thread) + : m_thread(thread) + { + m_thread->guard_enter(); + } + + thread_guard(named_thread& thread) + : thread_guard(thread.operator->()) + { + } + + thread_guard() + : thread_guard(thread_ctrl::get_current()) + { + } + + ~thread_guard() noexcept(false) + { + m_thread->guard_leave(); + } +}; + // Wrapper for named thread, joins automatically in the destructor, can only be used in function scope class scope_thread final {