diff --git a/Libraries/LibIPC/Connection.cpp b/Libraries/LibIPC/Connection.cpp index 2e0d103977a..b50401da84a 100644 --- a/Libraries/LibIPC/Connection.cpp +++ b/Libraries/LibIPC/Connection.cpp @@ -14,10 +14,11 @@ namespace IPC { -ConnectionBase::ConnectionBase(IPC::Stub& local_stub, Transport transport, u32 local_endpoint_magic) +ConnectionBase::ConnectionBase(IPC::Stub& local_stub, Transport transport, u32 local_endpoint_magic, u32 peer_endpoint_magic) : m_local_stub(local_stub) , m_transport(move(transport)) , m_local_endpoint_magic(local_endpoint_magic) + , m_peer_endpoint_magic(peer_endpoint_magic) { m_responsiveness_timer = Core::Timer::create_single_shot(3000, [this] { may_have_become_unresponsive(); }); @@ -29,21 +30,27 @@ ConnectionBase::ConnectionBase(IPC::Stub& local_stub, Transport transport, u32 l }); m_send_queue = adopt_ref(*new SendQueue); - m_send_thread = Threading::Thread::construct([this, queue = m_send_queue]() -> intptr_t { + m_acknowledgement_wait_queue = adopt_ref(*new AcknowledgementWaitQueue); + m_send_thread = Threading::Thread::construct([this, send_queue = m_send_queue, acknowledgement_wait_queue = m_acknowledgement_wait_queue]() -> intptr_t { for (;;) { - queue->mutex.lock(); - while (queue->messages.is_empty() && queue->running) - queue->condition.wait(); + send_queue->mutex.lock(); + while (send_queue->messages.is_empty() && send_queue->running) + send_queue->condition.wait(); - if (!queue->running) { - queue->mutex.unlock(); + if (!send_queue->running) { + send_queue->mutex.unlock(); break; } - auto message = queue->messages.take_first(); - queue->mutex.unlock(); + auto [message_buffer, needs_acknowledgement] = send_queue->messages.take_first(); + send_queue->mutex.unlock(); - if (auto result = message.transfer_message(m_transport); result.is_error()) { + if (needs_acknowledgement == MessageNeedsAcknowledgement::Yes) { + Threading::MutexLocker lock(acknowledgement_wait_queue->mutex); + acknowledgement_wait_queue->messages.append(message_buffer); + } + + if (auto result = message_buffer.transfer_message(m_transport); result.is_error()) { dbgln("ConnectionBase::send_thread: {}", result.error()); continue; } @@ -73,7 +80,7 @@ ErrorOr ConnectionBase::post_message(Message const& message) return post_message(message.endpoint_magic(), TRY(message.encode())); } -ErrorOr ConnectionBase::post_message(u32 endpoint_magic, MessageBuffer buffer) +ErrorOr ConnectionBase::post_message(u32 endpoint_magic, MessageBuffer buffer, MessageNeedsAcknowledgement needs_acknowledgement) { // NOTE: If this connection is being shut down, but has not yet been destroyed, // the socket will be closed. Don't try to send more messages. @@ -87,7 +94,7 @@ ErrorOr ConnectionBase::post_message(u32 endpoint_magic, MessageBuffer buf { Threading::MutexLocker locker(m_send_queue->mutex); - m_send_queue->messages.append(move(buffer)); + m_send_queue->messages.append({ move(buffer), needs_acknowledgement }); m_send_queue->condition.signal(); } @@ -218,6 +225,8 @@ OwnPtr ConnectionBase::wait_for_specific_endpoint_message_impl(u32 void ConnectionBase::try_parse_messages(Vector const& bytes, size_t& index) { u32 message_size = 0; + u32 pending_ack_count = 0; + u32 received_ack_count = 0; for (; index + sizeof(message_size) < bytes.size(); index += message_size) { memcpy(&message_size, bytes.data() + index, sizeof(message_size)); if (message_size == 0 || bytes.size() - index - sizeof(uint32_t) < message_size) @@ -232,9 +241,19 @@ void ConnectionBase::try_parse_messages(Vector const& bytes, size_t& index) m_unprocessed_fds.return_fds_to_front_of_queue(wrapper->take_fds()); auto parsed_message = try_parse_message(wrapped_message, m_unprocessed_fds); VERIFY(parsed_message); + VERIFY(parsed_message->message_id() != Acknowledgement::MESSAGE_ID); + pending_ack_count++; m_unprocessed_messages.append(parsed_message.release_nonnull()); continue; } + + if (message->message_id() == Acknowledgement::MESSAGE_ID) { + VERIFY(message->endpoint_magic() == m_local_endpoint_magic); + received_ack_count += static_cast(message.ptr())->ack_count(); + continue; + } + + pending_ack_count++; m_unprocessed_messages.append(message.release_nonnull()); continue; } @@ -243,6 +262,17 @@ void ConnectionBase::try_parse_messages(Vector const& bytes, size_t& index) dbgln("{:hex-dump}", remaining_bytes); break; } + + if (received_ack_count > 0) { + Threading::MutexLocker lock(m_acknowledgement_wait_queue->mutex); + for (size_t i = 0; i < received_ack_count; ++i) + m_acknowledgement_wait_queue->messages.take_first(); + } + + if (is_open() && pending_ack_count > 0) { + auto acknowledgement = Acknowledgement::create(m_peer_endpoint_magic, pending_ack_count); + MUST(post_message(m_peer_endpoint_magic, MUST(acknowledgement->encode()), MessageNeedsAcknowledgement::No)); + } } } diff --git a/Libraries/LibIPC/Connection.h b/Libraries/LibIPC/Connection.h index e489f64ac1f..8577d529f3d 100644 --- a/Libraries/LibIPC/Connection.h +++ b/Libraries/LibIPC/Connection.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -26,9 +27,14 @@ class ConnectionBase : public Core::EventReceiver { public: virtual ~ConnectionBase() override; + enum class MessageNeedsAcknowledgement { + No, + Yes, + }; + [[nodiscard]] bool is_open() const; ErrorOr post_message(Message const&); - ErrorOr post_message(u32 endpoint_magic, MessageBuffer); + ErrorOr post_message(u32 endpoint_magic, MessageBuffer, MessageNeedsAcknowledgement = MessageNeedsAcknowledgement::Yes); void shutdown(); virtual void die() { } @@ -36,7 +42,7 @@ public: Transport& transport() { return m_transport; } protected: - explicit ConnectionBase(IPC::Stub&, Transport, u32 local_endpoint_magic); + explicit ConnectionBase(IPC::Stub&, Transport, u32 local_endpoint_magic, u32 peer_endpoint_magic); virtual void may_have_become_unresponsive() { } virtual void did_become_responsive() { } @@ -62,23 +68,38 @@ protected: ByteBuffer m_unprocessed_bytes; u32 m_local_endpoint_magic { 0 }; + u32 m_peer_endpoint_magic { 0 }; + + struct MessageToSend { + MessageBuffer buffer; + MessageNeedsAcknowledgement needs_acknowledgement { MessageNeedsAcknowledgement::Yes }; + }; struct SendQueue : public AtomicRefCounted { - AK::SinglyLinkedList messages; + AK::SinglyLinkedList messages; Threading::Mutex mutex; Threading::ConditionVariable condition { mutex }; bool running { true }; }; + // After a message is sent, it is moved to the acknowledgement wait queue until an acknowledgement is received from the peer. + // This is necessary to handle a specific behavior of the macOS kernel, which may prematurely garbage-collect the file + // descriptor contained in the message before the peer receives it. https://openradar.me/9477351 + struct AcknowledgementWaitQueue : public AtomicRefCounted { + AK::SinglyLinkedList messages; + Threading::Mutex mutex; + }; + RefPtr m_send_thread; RefPtr m_send_queue; + RefPtr m_acknowledgement_wait_queue; }; template class Connection : public ConnectionBase { public: Connection(IPC::Stub& local_stub, Transport transport) - : ConnectionBase(local_stub, move(transport), LocalEndpoint::static_magic()) + : ConnectionBase(local_stub, move(transport), LocalEndpoint::static_magic(), PeerEndpoint::static_magic()) { } diff --git a/Libraries/LibIPC/Message.cpp b/Libraries/LibIPC/Message.cpp index 063ed4d27ff..0bb186e0669 100644 --- a/Libraries/LibIPC/Message.cpp +++ b/Libraries/LibIPC/Message.cpp @@ -111,4 +111,32 @@ ErrorOr> LargeMessageWrapper::decode(u32 endp return make(endpoint_magic, wrapped_message_data, move(wrapped_fds)); } +Acknowledgement::Acknowledgement(u32 endpoint_magic, u32 ack_count) + : m_endpoint_magic(endpoint_magic) + , m_ack_count(ack_count) +{ +} + +NonnullOwnPtr Acknowledgement::create(u32 endpoint_magic, u32 ack_count) +{ + return make(endpoint_magic, ack_count); +} + +ErrorOr Acknowledgement::encode() const +{ + MessageBuffer buffer; + Encoder stream { buffer }; + TRY(stream.encode(m_endpoint_magic)); + TRY(stream.encode(MESSAGE_ID)); + TRY(stream.encode(m_ack_count)); + return buffer; +} + +ErrorOr> Acknowledgement::decode(u32 endpoint_magic, Stream& stream, UnprocessedFileDescriptors& files) +{ + Decoder decoder { stream, files }; + auto ack_count = TRY(decoder.decode()); + return make(endpoint_magic, ack_count); +} + } diff --git a/Libraries/LibIPC/Message.h b/Libraries/LibIPC/Message.h index 7d1b741225a..6a4ccae1940 100644 --- a/Libraries/LibIPC/Message.h +++ b/Libraries/LibIPC/Message.h @@ -119,4 +119,28 @@ private: Vector m_wrapped_fds; }; +class Acknowledgement : public Message { +public: + ~Acknowledgement() override = default; + + static constexpr int MESSAGE_ID = 0xFFFFFFFF; + + static NonnullOwnPtr create(u32 endpoint_magic, u32 ack_count); + + u32 endpoint_magic() const override { return m_endpoint_magic; } + int message_id() const override { return MESSAGE_ID; } + char const* message_name() const override { return "Acknowledgement"; } + ErrorOr encode() const override; + + static ErrorOr> decode(u32 endpoint_magic, Stream& stream, UnprocessedFileDescriptors& files); + + u32 ack_count() const { return m_ack_count; } + + Acknowledgement(u32 endpoint_magic, u32 number_of_acknowledged_messages); + +private: + u32 m_endpoint_magic { 0 }; + u32 m_ack_count { 0 }; +}; + } diff --git a/Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp b/Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp index 4d041620171..0c7c316b877 100644 --- a/Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp +++ b/Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp @@ -760,6 +760,8 @@ public: generator.append(R"~~~( case (int)IPC::LargeMessageWrapper::MESSAGE_ID: return TRY(IPC::LargeMessageWrapper::decode(message_endpoint_magic, stream, files)); + case (int)IPC::Acknowledgement::MESSAGE_ID: + return TRY(IPC::Acknowledgement::decode(message_endpoint_magic, stream, files)); )~~~"); generator.append(R"~~~(