diff --git a/Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp b/Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp index 47c42a438e5..4299a505e80 100644 --- a/Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp +++ b/Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp @@ -372,9 +372,9 @@ public:)~~~"); static i32 static_message_id() { return (int)MessageID::@message.pascal_name@; } virtual const char* message_name() const override { return "@endpoint.name@::@message.pascal_name@"; } - static ErrorOr> decode(Stream& stream, Core::LocalSocket& socket) + static ErrorOr> decode(Stream& stream, Queue& files) { - IPC::Decoder decoder { stream, socket };)~~~"); + IPC::Decoder decoder { stream, files };)~~~"); for (auto const& parameter : parameters) { auto parameter_generator = message_generator.fork(); @@ -620,7 +620,7 @@ public: static u32 static_magic() { return @endpoint.magic@; } - static ErrorOr> decode_message(ReadonlyBytes buffer, [[maybe_unused]] Core::LocalSocket& socket) + static ErrorOr> decode_message(ReadonlyBytes buffer, [[maybe_unused]] Queue& files) { FixedMemoryStream stream { buffer }; auto message_endpoint_magic = TRY(stream.read_value());)~~~"); @@ -649,7 +649,7 @@ public: message_generator.append(R"~~~( case (int)Messages::@endpoint.name@::MessageID::@message.pascal_name@: - return TRY(Messages::@endpoint.name@::@message.pascal_name@::decode(stream, socket));)~~~"); + return TRY(Messages::@endpoint.name@::@message.pascal_name@::decode(stream, files));)~~~"); }; do_decode_message(message.name); diff --git a/Userland/Libraries/LibIPC/Connection.cpp b/Userland/Libraries/LibIPC/Connection.cpp index 7355cd7cd35..6ae16408995 100644 --- a/Userland/Libraries/LibIPC/Connection.cpp +++ b/Userland/Libraries/LibIPC/Connection.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -35,18 +36,6 @@ void ConnectionBase::set_deferred_invoker(NonnullOwnPtr deferre m_deferred_invoker = move(deferred_invoker); } -void ConnectionBase::set_fd_passing_socket(NonnullOwnPtr socket) -{ - m_fd_passing_socket = move(socket); -} - -Core::LocalSocket& ConnectionBase::fd_passing_socket() -{ - if (m_fd_passing_socket) - return *m_fd_passing_socket; - return *m_socket; -} - ErrorOr ConnectionBase::post_message(Message const& message) { return post_message(TRY(message.encode())); @@ -59,7 +48,7 @@ ErrorOr ConnectionBase::post_message(MessageBuffer buffer) if (!m_socket->is_open()) return Error::from_string_literal("Trying to post_message during IPC shutdown"); - if (auto result = buffer.transfer_message(fd_passing_socket(), *m_socket); result.is_error()) { + if (auto result = buffer.transfer_message(*m_socket); result.is_error()) { shutdown_with_error(result.error()); return result.release_error(); } @@ -122,6 +111,7 @@ ErrorOr> ConnectionBase::read_as_much_as_possible_from_socket_without } u8 buffer[4096]; + Vector received_fds; bool should_shut_down = false; auto schedule_shutdown = [this, &should_shut_down]() { @@ -132,7 +122,7 @@ ErrorOr> ConnectionBase::read_as_much_as_possible_from_socket_without }; while (m_socket->is_open()) { - auto maybe_bytes_read = m_socket->read_without_waiting({ buffer, 4096 }); + auto maybe_bytes_read = m_socket->receive_message({ buffer, 4096 }, MSG_DONTWAIT, received_fds); if (maybe_bytes_read.is_error()) { auto error = maybe_bytes_read.release_error(); if (error.is_syscall() && error.code() == EAGAIN) { @@ -156,6 +146,8 @@ ErrorOr> ConnectionBase::read_as_much_as_possible_from_socket_without } bytes.append(bytes_read.data(), bytes_read.size()); + for (auto const& fd : received_fds) + m_unprocessed_fds.enqueue(IPC::File::adopt_fd(fd)); } if (!bytes.is_empty()) { diff --git a/Userland/Libraries/LibIPC/Connection.h b/Userland/Libraries/LibIPC/Connection.h index 717414ed1da..6a7b3006ed3 100644 --- a/Userland/Libraries/LibIPC/Connection.h +++ b/Userland/Libraries/LibIPC/Connection.h @@ -8,12 +8,14 @@ #pragma once #include +#include #include #include #include #include #include #include +#include #include #include #include @@ -38,7 +40,7 @@ class ConnectionBase : public Core::EventReceiver { public: virtual ~ConnectionBase() override = default; - void set_fd_passing_socket(NonnullOwnPtr); + void set_fd_passing_socket(NonnullOwnPtr) { } void set_deferred_invoker(NonnullOwnPtr); DeferredInvoker& deferred_invoker() { return *m_deferred_invoker; } @@ -49,7 +51,7 @@ public: virtual void die() { } Core::LocalSocket& socket() { return *m_socket; } - Core::LocalSocket& fd_passing_socket(); + Core::LocalSocket const& fd_passing_socket() const { return *m_socket; } protected: explicit ConnectionBase(IPC::Stub&, NonnullOwnPtr, u32 local_endpoint_magic); @@ -70,11 +72,11 @@ protected: IPC::Stub& m_local_stub; NonnullOwnPtr m_socket; - OwnPtr m_fd_passing_socket; RefPtr m_responsiveness_timer; Vector> m_unprocessed_messages; + Queue m_unprocessed_fds; ByteBuffer m_unprocessed_bytes; u32 m_local_endpoint_magic { 0 }; @@ -138,13 +140,13 @@ protected: index += sizeof(message_size); auto remaining_bytes = ReadonlyBytes { bytes.data() + index, message_size }; - auto local_message = LocalEndpoint::decode_message(remaining_bytes, fd_passing_socket()); + auto local_message = LocalEndpoint::decode_message(remaining_bytes, m_unprocessed_fds); if (!local_message.is_error()) { m_unprocessed_messages.append(local_message.release_value()); continue; } - auto peer_message = PeerEndpoint::decode_message(remaining_bytes, fd_passing_socket()); + auto peer_message = PeerEndpoint::decode_message(remaining_bytes, m_unprocessed_fds); if (!peer_message.is_error()) { m_unprocessed_messages.append(peer_message.release_value()); continue; diff --git a/Userland/Libraries/LibIPC/Decoder.cpp b/Userland/Libraries/LibIPC/Decoder.cpp index 1abd370ce5c..ed54a7d8894 100644 --- a/Userland/Libraries/LibIPC/Decoder.cpp +++ b/Userland/Libraries/LibIPC/Decoder.cpp @@ -90,8 +90,12 @@ ErrorOr decode(Decoder& decoder) template<> ErrorOr decode(Decoder& decoder) { - int fd = TRY(decoder.socket().receive_fd(O_CLOEXEC)); - return File::adopt_fd(fd); + auto file = TRY(decoder.files().try_dequeue()); + auto fd = file.fd(); + + auto fd_flags = TRY(Core::System::fcntl(fd, F_GETFD)); + TRY(Core::System::fcntl(fd, F_SETFD, fd_flags | FD_CLOEXEC)); + return file; } template<> diff --git a/Userland/Libraries/LibIPC/Decoder.h b/Userland/Libraries/LibIPC/Decoder.h index dd6c6b9b379..a32481de206 100644 --- a/Userland/Libraries/LibIPC/Decoder.h +++ b/Userland/Libraries/LibIPC/Decoder.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -35,9 +36,9 @@ inline ErrorOr decode(Decoder&) class Decoder { public: - Decoder(Stream& stream, Core::LocalSocket& socket) + Decoder(Stream& stream, Queue& files) : m_stream(stream) - , m_socket(socket) + , m_files(files) { } @@ -60,11 +61,11 @@ public: ErrorOr decode_size(); Stream& stream() { return m_stream; } - Core::LocalSocket& socket() { return m_socket; } + Queue& files() { return m_files; } private: Stream& m_stream; - Core::LocalSocket& m_socket; + Queue& m_files; }; template diff --git a/Userland/Libraries/LibIPC/Message.cpp b/Userland/Libraries/LibIPC/Message.cpp index ef96322b71e..a040964246b 100644 --- a/Userland/Libraries/LibIPC/Message.cpp +++ b/Userland/Libraries/LibIPC/Message.cpp @@ -37,7 +37,7 @@ ErrorOr MessageBuffer::append_file_descriptor(int fd) return {}; } -ErrorOr MessageBuffer::transfer_message(Core::LocalSocket& fd_passing_socket, Core::LocalSocket& data_socket) +ErrorOr MessageBuffer::transfer_message(Core::LocalSocket& socket) { Checked checked_message_size { m_data.size() }; checked_message_size -= sizeof(MessageSizeType); @@ -45,17 +45,30 @@ ErrorOr MessageBuffer::transfer_message(Core::LocalSocket& fd_passing_sock if (checked_message_size.has_overflow()) return Error::from_string_literal("Message is too large for IPC encoding"); - auto message_size = checked_message_size.value(); + MessageSizeType const message_size = checked_message_size.value(); m_data.span().overwrite(0, reinterpret_cast(&message_size), sizeof(message_size)); - for (auto const& fd : m_fds) - TRY(fd_passing_socket.send_fd(fd->value())); + auto raw_fds = Vector {}; + auto num_fds_to_transfer = m_fds.size(); + if (num_fds_to_transfer > 0) { + raw_fds.ensure_capacity(num_fds_to_transfer); + for (auto& owned_fd : m_fds) { + raw_fds.unchecked_append(owned_fd->value()); + } + } ReadonlyBytes bytes_to_write { m_data.span() }; size_t writes_done = 0; while (!bytes_to_write.is_empty()) { - auto maybe_nwritten = data_socket.write_some(bytes_to_write); + ErrorOr maybe_nwritten = 0; + if (num_fds_to_transfer > 0) { + maybe_nwritten = socket.send_message(bytes_to_write, 0, raw_fds); + if (!maybe_nwritten.is_error()) + num_fds_to_transfer = 0; + } else { + maybe_nwritten = socket.write_some(bytes_to_write); + } ++writes_done; if (maybe_nwritten.is_error()) { diff --git a/Userland/Libraries/LibIPC/Message.h b/Userland/Libraries/LibIPC/Message.h index cf07a3b109b..faf378ef1b7 100644 --- a/Userland/Libraries/LibIPC/Message.h +++ b/Userland/Libraries/LibIPC/Message.h @@ -44,7 +44,7 @@ public: ErrorOr append_file_descriptor(int fd); - ErrorOr transfer_message(Core::LocalSocket& fd_passing_socket, Core::LocalSocket& data_socket); + ErrorOr transfer_message(Core::LocalSocket& socket); private: Vector m_data; diff --git a/Userland/Libraries/LibWeb/HTML/MessagePort.cpp b/Userland/Libraries/LibWeb/HTML/MessagePort.cpp index 6cf22e73735..d2fae5f876e 100644 --- a/Userland/Libraries/LibWeb/HTML/MessagePort.cpp +++ b/Userland/Libraries/LibWeb/HTML/MessagePort.cpp @@ -5,6 +5,7 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include #include #include #include @@ -259,7 +260,7 @@ ErrorOr MessagePort::send_message_on_socket(SerializedTransferRecord const IPC::Encoder encoder(buffer); MUST(encoder.encode(serialize_with_transfer_result)); - TRY(buffer.transfer_message(*m_fd_passing_socket, *m_socket)); + TRY(buffer.transfer_message(*m_socket)); return {}; } @@ -276,41 +277,76 @@ void MessagePort::post_port_message(SerializedTransferRecord serialize_with_tran }); } -void MessagePort::read_from_socket() +ErrorOr MessagePort::parse_message() { - auto num_bytes_ready = MUST(m_socket->pending_bytes()); + static constexpr size_t HEADER_SIZE = sizeof(u32); + + auto num_bytes_ready = m_buffered_data.size(); switch (m_socket_state) { case SocketState::Header: { - if (num_bytes_ready < sizeof(u32)) - break; - m_socket_incoming_message_size = MUST(m_socket->read_value()); - num_bytes_ready -= sizeof(u32); + if (num_bytes_ready < HEADER_SIZE) + return ParseDecision::NotEnoughData; + + m_socket_incoming_message_size = ByteReader::load32(m_buffered_data.data()); + // NOTE: We don't decrement the number of ready bytes because we want to remove the entire + // message + header from the buffer in one go on success m_socket_state = SocketState::Data; - } [[fallthrough]]; + } case SocketState::Data: { if (num_bytes_ready < m_socket_incoming_message_size) - break; + return ParseDecision::NotEnoughData; - Vector data; - data.resize(m_socket_incoming_message_size, true); - MUST(m_socket->read_until_filled(data)); + auto payload = m_buffered_data.span().slice(HEADER_SIZE, m_socket_incoming_message_size); - FixedMemoryStream stream { data, FixedMemoryStream::Mode::ReadOnly }; - IPC::Decoder decoder(stream, *m_fd_passing_socket); + FixedMemoryStream stream { payload, FixedMemoryStream::Mode::ReadOnly }; + IPC::Decoder decoder { stream, m_unprocessed_fds }; - auto serialize_with_transfer_result = MUST(decoder.decode()); + auto serialized_transfer_record = TRY(decoder.decode()); // Make sure to advance our state machine before dispatching the MessageEvent, // as dispatching events can run arbitrary JS (and cause us to receive another message!) m_socket_state = SocketState::Header; - post_message_task_steps(serialize_with_transfer_result); + m_buffered_data.remove(0, HEADER_SIZE + m_socket_incoming_message_size); + + post_message_task_steps(serialized_transfer_record); + break; } case SocketState::Error: - VERIFY_NOT_REACHED(); - break; + return Error::from_errno(ENOMSG); + } + + return ParseDecision::ParseNextMessage; +} + +void MessagePort::read_from_socket() +{ + u8 buffer[4096] {}; + + Vector fds; + // FIXME: What if pending bytes is > 4096? Should we loop here? + auto maybe_bytes = m_socket->receive_message(buffer, MSG_NOSIGNAL, fds); + if (maybe_bytes.is_error()) { + dbgln("MessagePort::read_from_socket(): Failed to receive message: {}", maybe_bytes.error()); + return; + } + auto bytes = maybe_bytes.release_value(); + + m_buffered_data.append(bytes.data(), bytes.size()); + + for (auto fd : fds) + m_unprocessed_fds.enqueue(IPC::File::adopt_fd(fd)); + + while (true) { + auto parse_decision_or_error = parse_message(); + if (parse_decision_or_error.is_error()) { + dbgln("MessagePort::read_from_socket(): Failed to parse message: {}", parse_decision_or_error.error()); + return; + } + if (parse_decision_or_error.value() == ParseDecision::NotEnoughData) + break; } } diff --git a/Userland/Libraries/LibWeb/HTML/MessagePort.h b/Userland/Libraries/LibWeb/HTML/MessagePort.h index 10caba667af..ccbe35a8fe8 100644 --- a/Userland/Libraries/LibWeb/HTML/MessagePort.h +++ b/Userland/Libraries/LibWeb/HTML/MessagePort.h @@ -78,6 +78,12 @@ private: ErrorOr send_message_on_socket(SerializedTransferRecord const&); void read_from_socket(); + enum class ParseDecision { + NotEnoughData, + ParseNextMessage, + }; + ErrorOr parse_message(); + // The HTML spec implies(!) that this is MessagePort.[[RemotePort]] JS::GCPtr m_remote_port; @@ -93,6 +99,8 @@ private: Error, } m_socket_state { SocketState::Header }; size_t m_socket_incoming_message_size { 0 }; + Queue m_unprocessed_fds; + Vector m_buffered_data; JS::GCPtr m_worker_event_target; };