LibIPC: Protect underlying socket of TransportSocket with RWLock

This is necessary to prevent the socket from being closed while it is
being used for reading or writing.
This commit is contained in:
Aliaksandr Kalenik 2025-04-10 22:05:11 +02:00
commit 5ed5169d60
2 changed files with 11 additions and 0 deletions

View file

@ -65,6 +65,7 @@ TransportSocket::TransportSocket(NonnullOwnPtr<Core::LocalSocket> socket)
auto [bytes, fds] = send_queue->dequeue(4096); auto [bytes, fds] = send_queue->dequeue(4096);
ReadonlyBytes remaining_to_send_bytes = bytes; ReadonlyBytes remaining_to_send_bytes = bytes;
Threading::RWLockLocker<Threading::LockMode::Read> lock(m_socket_rw_lock);
auto result = send_message(*m_socket, remaining_to_send_bytes, fds); auto result = send_message(*m_socket, remaining_to_send_bytes, fds);
if (result.is_error()) { if (result.is_error()) {
dbgln("TransportSocket::send_thread: {}", result.error()); dbgln("TransportSocket::send_thread: {}", result.error());
@ -104,22 +105,26 @@ TransportSocket::~TransportSocket()
void TransportSocket::set_up_read_hook(Function<void()> hook) void TransportSocket::set_up_read_hook(Function<void()> hook)
{ {
Threading::RWLockLocker<Threading::LockMode::Write> lock(m_socket_rw_lock);
VERIFY(m_socket->is_open()); VERIFY(m_socket->is_open());
m_socket->on_ready_to_read = move(hook); m_socket->on_ready_to_read = move(hook);
} }
bool TransportSocket::is_open() const bool TransportSocket::is_open() const
{ {
Threading::RWLockLocker<Threading::LockMode::Read> lock(m_socket_rw_lock);
return m_socket->is_open(); return m_socket->is_open();
} }
void TransportSocket::close() void TransportSocket::close()
{ {
Threading::RWLockLocker<Threading::LockMode::Write> lock(m_socket_rw_lock);
m_socket->close(); m_socket->close();
} }
void TransportSocket::wait_until_readable() void TransportSocket::wait_until_readable()
{ {
Threading::RWLockLocker<Threading::LockMode::Read> lock(m_socket_rw_lock);
auto maybe_did_become_readable = m_socket->can_read_without_blocking(-1); auto maybe_did_become_readable = m_socket->can_read_without_blocking(-1);
if (maybe_did_become_readable.is_error()) { if (maybe_did_become_readable.is_error()) {
dbgln("TransportSocket::wait_until_readable: {}", maybe_did_become_readable.error()); dbgln("TransportSocket::wait_until_readable: {}", maybe_did_become_readable.error());
@ -194,6 +199,8 @@ ErrorOr<void> TransportSocket::send_message(Core::LocalSocket& socket, ReadonlyB
TransportSocket::ShouldShutdown TransportSocket::read_as_many_messages_as_possible_without_blocking(Function<void(Message)>&& callback) TransportSocket::ShouldShutdown TransportSocket::read_as_many_messages_as_possible_without_blocking(Function<void(Message)>&& callback)
{ {
Threading::RWLockLocker<Threading::LockMode::Read> lock(m_socket_rw_lock);
bool should_shutdown = false; bool should_shutdown = false;
while (is_open()) { while (is_open()) {
u8 buffer[4096]; u8 buffer[4096];
@ -286,11 +293,13 @@ TransportSocket::ShouldShutdown TransportSocket::read_as_many_messages_as_possib
ErrorOr<int> TransportSocket::release_underlying_transport_for_transfer() ErrorOr<int> TransportSocket::release_underlying_transport_for_transfer()
{ {
Threading::RWLockLocker<Threading::LockMode::Write> lock(m_socket_rw_lock);
return m_socket->release_fd(); return m_socket->release_fd();
} }
ErrorOr<IPC::File> TransportSocket::clone_for_transfer() ErrorOr<IPC::File> TransportSocket::clone_for_transfer()
{ {
Threading::RWLockLocker<Threading::LockMode::Write> lock(m_socket_rw_lock);
return IPC::File::clone_fd(m_socket->fd().value()); return IPC::File::clone_fd(m_socket->fd().value());
} }

View file

@ -12,6 +12,7 @@
#include <LibIPC/UnprocessedFileDescriptors.h> #include <LibIPC/UnprocessedFileDescriptors.h>
#include <LibThreading/ConditionVariable.h> #include <LibThreading/ConditionVariable.h>
#include <LibThreading/MutexProtected.h> #include <LibThreading/MutexProtected.h>
#include <LibThreading/RWLock.h>
#include <LibThreading/Thread.h> #include <LibThreading/Thread.h>
namespace IPC { namespace IPC {
@ -104,6 +105,7 @@ private:
static ErrorOr<void> send_message(Core::LocalSocket&, ReadonlyBytes& bytes, Vector<int>& unowned_fds); static ErrorOr<void> send_message(Core::LocalSocket&, ReadonlyBytes& bytes, Vector<int>& unowned_fds);
NonnullOwnPtr<Core::LocalSocket> m_socket; NonnullOwnPtr<Core::LocalSocket> m_socket;
mutable Threading::RWLock m_socket_rw_lock;
ByteBuffer m_unprocessed_bytes; ByteBuffer m_unprocessed_bytes;
UnprocessedFileDescriptors m_unprocessed_fds; UnprocessedFileDescriptors m_unprocessed_fds;