diff --git a/Libraries/LibIPC/Connection.cpp b/Libraries/LibIPC/Connection.cpp index a93d25875fb..247eb412167 100644 --- a/Libraries/LibIPC/Connection.cpp +++ b/Libraries/LibIPC/Connection.cpp @@ -1,5 +1,6 @@ /* * Copyright (c) 2021-2024, Andreas Kling + * Copyright (c) 2025, Simon Farre * Copyright (c) 2022, the SerenityOS developers. * * SPDX-License-Identifier: BSD-2-Clause @@ -9,7 +10,6 @@ #include #include #include -#include #include namespace IPC { @@ -18,6 +18,7 @@ ConnectionBase::ConnectionBase(IPC::Stub& local_stub, Transport transport, u32 l : m_local_stub(local_stub) , m_transport(move(transport)) , m_local_endpoint_magic(local_endpoint_magic) + , m_event_loop(Core::EventLoop::current()) { m_responsiveness_timer = Core::Timer::create_single_shot(3000, [this] { may_have_become_unresponsive(); }); @@ -25,7 +26,7 @@ ConnectionBase::ConnectionBase(IPC::Stub& local_stub, Transport transport, u32 l NonnullRefPtr protect = *this; // FIXME: Do something about errors. (void)drain_messages_from_peer(); - handle_messages(); + process_messages(); }); m_send_queue = adopt_ref(*new SendQueue); @@ -102,26 +103,6 @@ void ConnectionBase::shutdown_with_error(Error const& error) shutdown(); } -void ConnectionBase::handle_messages() -{ - auto messages = move(m_unprocessed_messages); - for (auto& message : messages) { - if (message->endpoint_magic() == m_local_endpoint_magic) { - auto handler_result = m_local_stub.handle(move(message)); - if (handler_result.is_error()) { - dbgln("IPC::ConnectionBase::handle_messages: {}", handler_result.error()); - continue; - } - - if (auto response = handler_result.release_value()) { - if (auto post_result = post_message(*response); post_result.is_error()) { - dbgln("IPC::ConnectionBase::handle_messages: {}", post_result.error()); - } - } - } - } -} - void ConnectionBase::wait_for_transport_to_become_readable() { m_transport.wait_until_readable(); @@ -160,6 +141,33 @@ ErrorOr> ConnectionBase::read_as_much_as_possible_from_transport_with return bytes; } +void ConnectionBase::process_messages() +{ + if (m_unprocessed_messages.is_empty()) { + return; + } + Vector> messages; + swap(m_unprocessed_messages, messages); + + for (auto& message : messages) { + if (message->is_response()) { + auto it = find_if(m_resolvers.begin(), m_resolvers.end(), [id = message->ipc_request_id()](auto const& res) { return res.request_id == id; }); + if (it == m_resolvers.end()) { + continue; + } + it->promise_completer(move(message)); + m_resolvers.remove(std::distance(m_resolvers.begin(), it)); + } else { + // Handle message on thread this connection was created on. + // We don't necessarily need to defer invoke here, but in a future where we may want + // to have a dedicated IO thread, it needs to post the work back to the thread it needs to run on. + m_event_loop.deferred_invoke([this, msg = move(message)]() mutable { + m_local_stub.handle_ipc_message(NonnullRefPtr { *this }, move(msg)); + }); + } + } +} + ErrorOr ConnectionBase::drain_messages_from_peer() { auto bytes = TRY(read_as_much_as_possible_from_transport_without_blocking()); @@ -181,13 +189,13 @@ ErrorOr ConnectionBase::drain_messages_from_peer() if (!m_unprocessed_messages.is_empty()) { deferred_invoke([this] { - handle_messages(); + process_messages(); }); } return {}; } -OwnPtr ConnectionBase::wait_for_specific_endpoint_message_impl(u32 endpoint_magic, int message_id) +OwnPtr ConnectionBase::wait_for_specific_endpoint_message_impl(u64 request_id, u32 endpoint_magic, int message_id) { for (;;) { // Double check we don't already have the event waiting for us. @@ -196,7 +204,7 @@ OwnPtr ConnectionBase::wait_for_specific_endpoint_message_impl(u32 auto& message = m_unprocessed_messages[i]; if (message->endpoint_magic() != endpoint_magic) continue; - if (message->message_id() == message_id) + if (message->message_id() == message_id && message->ipc_request_id() == request_id) return m_unprocessed_messages.take(i); } diff --git a/Libraries/LibIPC/Connection.h b/Libraries/LibIPC/Connection.h index 9d100df40d5..b7f9d9d2635 100644 --- a/Libraries/LibIPC/Connection.h +++ b/Libraries/LibIPC/Connection.h @@ -1,5 +1,6 @@ /* * Copyright (c) 2018-2024, Andreas Kling + * Copyright (c) 2025, Simon Farre * Copyright (c) 2022, the SerenityOS developers. * * SPDX-License-Identifier: BSD-2-Clause @@ -9,9 +10,14 @@ #include #include +#include +#include #include +#include +#include #include #include +#include #include #include #include @@ -19,9 +25,16 @@ namespace IPC { +struct Completer { + u64 request_id { 0 }; + Function)> promise_completer; +}; + class ConnectionBase : public Core::EventReceiver { C_OBJECT_ABSTRACT(ConnectionBase); + void process_messages(); + public: virtual ~ConnectionBase() override; @@ -34,6 +47,8 @@ public: Transport& transport() { return m_transport; } + ErrorOr drain_messages_from_peer(); + protected: explicit ConnectionBase(IPC::Stub&, Transport, u32 local_endpoint_magic); @@ -42,18 +57,14 @@ protected: virtual void shutdown_with_error(Error const&); virtual OwnPtr try_parse_message(ReadonlyBytes, Queue&) = 0; - OwnPtr wait_for_specific_endpoint_message_impl(u32 endpoint_magic, int message_id); + OwnPtr wait_for_specific_endpoint_message_impl(u64 request_id, u32 endpoint_magic, int message_id); void wait_for_transport_to_become_readable(); ErrorOr> read_as_much_as_possible_from_transport_without_blocking(); - ErrorOr drain_messages_from_peer(); + void try_parse_messages(Vector const& bytes, size_t& index); - void handle_messages(); - IPC::Stub& m_local_stub; - Transport m_transport; - RefPtr m_responsiveness_timer; Vector> m_unprocessed_messages; @@ -71,6 +82,9 @@ protected: RefPtr m_send_thread; RefPtr m_send_queue; + // Arbitrary inline size. + Vector m_resolvers; + Core::EventLoop& m_event_loop; }; template @@ -81,6 +95,8 @@ public: { } + ~Connection() override = default; + template OwnPtr wait_for_specific_message() { @@ -90,25 +106,41 @@ public: template NonnullOwnPtr send_sync(Args&&... args) { - MUST(post_message(RequestType(forward(args)...))); - auto response = wait_for_specific_endpoint_message(); + auto const request_id = LocalEndpoint::next_ipc_request_id(); + MUST(post_message(RequestType(request_id, forward(args)...))); + auto response = wait_for_specific_endpoint_message(request_id); VERIFY(response); return response.release_nonnull(); } + template + NonnullRefPtr>> send(Args&&... args) + { + using Promise = Core::Promise>; + auto promise = Promise::construct(); + auto const msg = RequestType { LocalEndpoint::next_ipc_request_id(), forward(args)... }; + MUST(post_message(msg)); + m_resolvers.empend(msg.ipc_request_id(), [promise, msg_id = msg.message_id()](OwnPtr msg) { + ASSERT(msg->message_id() == msg_id + 1 && msg->endpoint_magic() == PeerEndpoint::static_magic() && msg->ipc_request_id() == msg_id); + promise->resolve(msg.release_nonnull()); + }); + return promise; + } + template OwnPtr send_sync_but_allow_failure(Args&&... args) { - if (post_message(RequestType(forward(args)...)).is_error()) + auto const request_id = LocalEndpoint::next_ipc_request_id(); + if (post_message(RequestType(request_id, forward(args)...)).is_error()) return nullptr; - return wait_for_specific_endpoint_message(); + return wait_for_specific_endpoint_message(request_id); } protected: template - OwnPtr wait_for_specific_endpoint_message() + OwnPtr wait_for_specific_endpoint_message(u64 request_id) { - if (auto message = wait_for_specific_endpoint_message_impl(Endpoint::static_magic(), MessageType::static_message_id())) + if (auto message = wait_for_specific_endpoint_message_impl(request_id, Endpoint::static_magic(), MessageType::static_message_id())) return message.template release_nonnull(); return {}; } @@ -126,5 +158,4 @@ protected: return nullptr; } }; - } diff --git a/Libraries/LibIPC/Message.h b/Libraries/LibIPC/Message.h index 8321973269f..e76de578a3c 100644 --- a/Libraries/LibIPC/Message.h +++ b/Libraries/LibIPC/Message.h @@ -40,6 +40,13 @@ class MessageBuffer { public: MessageBuffer(); + // Do not copy message buffers. + MessageBuffer(MessageBuffer const&) = delete; + MessageBuffer& operator=(MessageBuffer const&) = delete; + + MessageBuffer(MessageBuffer&&) noexcept = default; + MessageBuffer& operator=(MessageBuffer&&) noexcept = default; + ErrorOr extend_data_capacity(size_t capacity); ErrorOr append_data(u8 const* values, size_t count); @@ -70,9 +77,17 @@ public: virtual int message_id() const = 0; virtual char const* message_name() const = 0; virtual ErrorOr encode() const = 0; + u64 ipc_request_id() const { return m_request_id; } + void set_ipc_request_id(u64 request_id) { m_request_id = request_id; } + constexpr bool is_response() const { return (message_id() % 2) == 0; } protected: - Message() = default; -}; + explicit Message(u64 request_id) + : m_request_id(request_id) + { + } +private: + u64 m_request_id; +}; } diff --git a/Libraries/LibIPC/Stub.h b/Libraries/LibIPC/Stub.h index 4a444cb4c79..966c6311ac9 100644 --- a/Libraries/LibIPC/Stub.h +++ b/Libraries/LibIPC/Stub.h @@ -17,13 +17,15 @@ class BufferStream; namespace IPC { +class ConnectionBase; + class Stub { public: virtual ~Stub() = default; virtual u32 magic() const = 0; virtual ByteString name() const = 0; - virtual ErrorOr> handle(NonnullOwnPtr) = 0; + virtual void handle_ipc_message(NonnullRefPtr conn, NonnullOwnPtr message) = 0; protected: Stub() = default; diff --git a/Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp b/Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp index c7fe62dfa4a..8284911d37a 100644 --- a/Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp +++ b/Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp @@ -1,9 +1,11 @@ /* * Copyright (c) 2018-2020, Andreas Kling + * Copyright (c) 2025, Simon Farre * * SPDX-License-Identifier: BSD-2-Clause */ +#include #include #include #include @@ -245,7 +247,10 @@ Vector parse(ByteBuffer const& file_contents) consume_whitespace(); + // We have an unfortunate naming of "is_synchronous" when it really means "has response" + message.inputs.prepend(Parameter { .attributes = {}, .type = "u64", .type_for_encoding = "u64", .name = "ipc_request_id" }); if (message.is_synchronous) { + message.outputs.prepend(Parameter { .attributes = {}, .type = "u64", .type_for_encoding = "u64", .name = "ipc_request_id" }); assert_specific('('); parse_parameters(message.outputs, message.name); assert_specific(')'); @@ -312,59 +317,60 @@ HashMap build_message_ids_for_endpoint(SourceGenerator generato HashMap message_ids; generator.appendln("\nenum class MessageID : i32 {"); + auto id = 1; for (auto const& message : endpoint.messages) { - - message_ids.set(message.name, message_ids.size() + 1); + message_ids.set(message.name, id); generator.set("message.pascal_name", pascal_case(message.name)); - generator.set("message.id", ByteString::number(message_ids.size())); + generator.set("message.id", ByteString::number(id++)); generator.appendln(" @message.pascal_name@ = @message.id@,"); if (message.is_synchronous) { - message_ids.set(message.response_name(), message_ids.size() + 1); + message_ids.set(message.response_name(), id); generator.set("message.pascal_name", pascal_case(message.response_name())); - generator.set("message.id", ByteString::number(message_ids.size())); - + generator.set("message.id", ByteString::number(id++)); generator.appendln(" @message.pascal_name@ = @message.id@,"); } + + id += message.is_synchronous ? 0 : 1; } generator.appendln("};"); return message_ids; } -ByteString constructor_for_message(ByteString const& name, Vector const& parameters) +ByteString constructor_for_message(bool is_response, ByteString const& name, ReadonlySpan parameters) { StringBuilder builder; builder.append(name); - if (parameters.is_empty()) { - builder.append("() {}"sv); - return builder.to_byte_string(); - } builder.append('('); - for (size_t i = 0; i < parameters.size(); ++i) { + for (size_t i = is_response ? 1 : 0; i < parameters.size(); ++i) { auto const& parameter = parameters[i]; builder.appendff("{} {}", parameter.type, parameter.name); if (i != parameters.size() - 1) builder.append(", "sv); } - builder.append(") : "sv); - for (size_t i = 0; i < parameters.size(); ++i) { + + if (is_response) + builder.append(") : Message(0) "sv); + else + builder.append(") : Message(ipc_request_id) "sv); + + for (size_t i = 1; i < parameters.size(); ++i) { + builder.append(", "sv); auto const& parameter = parameters[i]; builder.appendff("m_{}(move({}))", parameter.name, parameter.name); - if (i != parameters.size() - 1) - builder.append(", "sv); } builder.append(" {}"sv); return builder.to_byte_string(); } -void do_message(SourceGenerator message_generator, ByteString const& name, Vector const& parameters, ByteString const& response_type = {}) +void do_message(bool is_response, SourceGenerator message_generator, ByteString const& name, ReadonlySpan parameters, ByteString const& response_type = {}) { auto pascal_name = pascal_case(name); message_generator.set("message.name", name); message_generator.set("message.pascal_name", pascal_name); message_generator.set("message.response_type", response_type); - message_generator.set("message.constructor", constructor_for_message(pascal_name, parameters)); + message_generator.set("message.constructor", constructor_for_message(is_response, pascal_name, parameters)); message_generator.append(R"~~~( class @message.pascal_name@ final : public IPC::Message { @@ -372,24 +378,30 @@ public:)~~~"); if (!response_type.is_empty()) message_generator.appendln(R"~~~( - typedef class @message.response_type@ ResponseType;)~~~"); + typedef class @message.response_type@ ResponseType; + using Promise = Core::Promise;)~~~"); + + if (is_response) { + message_generator.appendln(R"~~~( + using Promise = Core::Promise<@message.pascal_name@>;)~~~"); + } message_generator.appendln(R"~~~( - @message.pascal_name@(@message.pascal_name@ const&) = default; @message.pascal_name@(@message.pascal_name@&&) = default; - @message.pascal_name@& operator=(@message.pascal_name@ const&) = default; + @message.pascal_name@(@message.pascal_name@ const&) = delete; + @message.pascal_name@& operator=(@message.pascal_name@ const&) = delete; @message.constructor@)~~~"); - if (parameters.size() == 1) { - auto const& parameter = parameters[0]; + if (parameters.size() == 2) { + auto const& parameter = parameters[1]; message_generator.set("parameter.type"sv, parameter.type); message_generator.set("parameter.name"sv, parameter.name); message_generator.appendln(R"~~~( template requires(!SameAs) - @message.pascal_name@(WrappedReturnType&& value) - : m_@parameter.name@(forward(value)) + @message.pascal_name@(u64 sequence_id, WrappedReturnType&& value) + : Message(sequence_id), m_@parameter.name@(forward(value)) { })~~~"); } @@ -430,7 +442,7 @@ public:)~~~"); } StringBuilder builder; - for (size_t i = 0; i < parameters.size(); ++i) { + for (size_t i = is_response ? 1 : 0; i < parameters.size(); ++i) { auto const& parameter = parameters[i]; builder.appendff("move({})", parameter.name); if (i != parameters.size() - 1) @@ -438,9 +450,17 @@ public:)~~~"); } message_generator.set("message.constructor_call_parameters", builder.to_byte_string()); - message_generator.appendln(R"~~~( - return make<@message.pascal_name@>(@message.constructor_call_parameters@); + if (is_response) { + message_generator.appendln(R"~~~( + auto result = make<@message.pascal_name@>(@message.constructor_call_parameters@); + result->set_ipc_request_id(ipc_request_id); + return result; })~~~"); + } else { + message_generator.appendln(R"~~~( + return make<@message.pascal_name@>(@message.constructor_call_parameters@); + })~~~"); + } message_generator.append(R"~~~( static ErrorOr static_encode()~~~"); @@ -476,20 +496,19 @@ public:)~~~"); message_generator.append(R"~~~( virtual ErrorOr encode() const override { - return static_encode()~~~"); - - for (auto const& [i, parameter] : enumerate(parameters)) { + return static_encode(ipc_request_id())~~~"); + auto parameters_without_seqid = parameters.slice(1); + for (auto const& parameter : parameters_without_seqid) { auto parameter_generator = message_generator.fork(); + parameter_generator.append(", "); parameter_generator.set("parameter.name", parameter.name); parameter_generator.append("m_@parameter.name@"); - if (i != parameters.size() - 1) - parameter_generator.append(", "); } message_generator.appendln(R"~~~(); })~~~"); - for (auto const& parameter : parameters) { + for (auto const& parameter : parameters_without_seqid) { auto parameter_generator = message_generator.fork(); parameter_generator.set("parameter.type", parameter.type); parameter_generator.set("parameter.name", parameter.name); @@ -506,8 +525,7 @@ public:)~~~"); message_generator.append(R"~~~( private:)~~~"); - - for (auto const& parameter : parameters) { + for (auto const& parameter : parameters_without_seqid) { auto parameter_generator = message_generator.fork(); parameter_generator.set("parameter.type", parameter.type); parameter_generator.set("parameter.name", parameter.name); @@ -518,7 +536,54 @@ private:)~~~"); message_generator.appendln("\n};"); } -void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& endpoint, Message const& message, ByteString const& name, Vector const& parameters, bool is_synchronous, bool is_try, bool is_utf8_string_overload = false) +void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& endpoint, Message const& message, ByteString const& name, ReadonlySpan parameters, bool is_synchronous, bool is_try, bool is_utf8_string_overload); + +void generate_proxy_awaitable_method(SourceGenerator& message_generator, Message const& message, ByteString const& name, ReadonlySpan parameters_without_seqid, bool is_utf8_string_overload = false) +{ + ASSERT(message.is_synchronous); + + message_generator.set("message.name", message.name); + message_generator.set("message.pascal_name", pascal_case(message.name)); + + message_generator.set("handler_name", name); + message_generator.append(R"~~~( + auto send_@handler_name@()~~~"); + + for (auto const& [i, parameter] : enumerate(parameters_without_seqid)) { + ByteString type; + if (is_utf8_string_overload) + type = make_argument_type(parameter.type); + else + type = make_argument_type(parameter.type_for_encoding); + + auto argument_generator = message_generator.fork(); + argument_generator.set("argument.type", type); + argument_generator.set("argument.name", parameter.name); + argument_generator.append("@argument.type@ @argument.name@"); + if (i != parameters_without_seqid.size() - 1) + argument_generator.append(", "); + } + + message_generator.append(") {"); + message_generator.append(R"~~~( + return m_connection. template send()~~~"); + + for (auto const& [i, parameter] : enumerate(parameters_without_seqid)) { + auto argument_generator = message_generator.fork(); + argument_generator.set("argument.name", parameter.name); + if (is_primitive_or_simple_type(parameter.type)) + argument_generator.append("@argument.name@"); + else + argument_generator.append("move(@argument.name@)"); + if (i != parameters_without_seqid.size() - 1) + argument_generator.append(", "); + } + + message_generator.appendln(R"~~~(); + })~~~"); +} + +void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& endpoint, Message const& message, ByteString const& name, ReadonlySpan parameters, bool is_synchronous, bool is_try, bool is_utf8_string_overload = false) { // FIXME: For String parameters, we want to retain the property that all tranferred String objects are strictly UTF-8. // So instead of generating a single proxy method that accepts StringView parameters, we generate two overloads. @@ -531,9 +596,9 @@ void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& e ByteString return_type = "void"; if (is_synchronous) { - if (message.outputs.size() == 1) - return_type = message.outputs[0].type; - else if (!message.outputs.is_empty()) + if (message.outputs.size() == 2) + return_type = message.outputs[1].type; + else if (message.outputs.size() > 1) return_type = message_name(endpoint.name, message.name, true); } ByteString inner_return_type = return_type; @@ -550,7 +615,8 @@ void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& e message_generator.append(R"~~~( @message.complex_return_type@ @try_prefix_maybe@@async_prefix_maybe@@handler_name@()~~~"); - for (auto const& [i, parameter] : enumerate(parameters)) { + auto parameters_without_seqid = parameters.slice(1); + for (auto const& [i, parameter] : enumerate(parameters_without_seqid)) { ByteString type; if (is_synchronous || is_try) type = parameter.type; @@ -563,7 +629,7 @@ void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& e argument_generator.set("argument.type", type); argument_generator.set("argument.name", parameter.name); argument_generator.append("@argument.type@ @argument.name@"); - if (i != parameters.size() - 1) + if (i != parameters_without_seqid.size() - 1) argument_generator.append(", "); } @@ -584,11 +650,12 @@ void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& e } } + bool was_static = false; if (is_synchronous && !is_try) { if (return_type != "void") { message_generator.append(R"~~~( return )~~~"); - if (message.outputs.size() != 1) + if (message.outputs.size() != 2) message_generator.append("move(*"); } else { message_generator.append(R"~~~( @@ -600,20 +667,25 @@ void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& e message_generator.append(R"~~~( auto result = m_connection.template send_sync_but_allow_failure()~~~"); } else { + was_static = true; message_generator.append(R"~~~( - auto message_buffer = MUST(Messages::@endpoint.name@::@message.pascal_name@::static_encode()~~~"); + auto message_buffer = MUST(Messages::@endpoint.name@::@message.pascal_name@::static_encode(LocalEndpoint::next_ipc_request_id())~~~"); } - for (auto const& [i, parameter] : enumerate(parameters)) { + for (auto const& [i, parameter] : enumerate(parameters_without_seqid)) { auto const& type = is_synchronous || is_try ? parameter.type : parameter.type_for_encoding; auto argument_generator = message_generator.fork(); argument_generator.set("argument.name", parameter.name); + if (was_static) { + argument_generator.append(", "); + was_static = false; + } if (is_primitive_or_simple_type(type)) argument_generator.append("@argument.name@"); else argument_generator.append("move(@argument.name@)"); - if (i != parameters.size() - 1) + if (i != parameters_without_seqid.size() - 1) argument_generator.append(", "); } @@ -622,13 +694,13 @@ void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& e message_generator.append(")"); } - if (message.outputs.size() == 1) { - if (is_primitive_or_simple_type(message.outputs[0].type)) + if (message.outputs.size() == 2) { + if (is_primitive_or_simple_type(message.outputs[1].type)) message_generator.append("->"); else message_generator.append("->take_"); - message_generator.append(message.outputs[0].name); + message_generator.append(message.outputs[1].name); message_generator.append("()"); } else message_generator.append(")"); @@ -665,6 +737,7 @@ void do_message_for_proxy(SourceGenerator message_generator, Endpoint const& end if (message.is_synchronous) { generate_proxy_method(message_generator, endpoint, message, message.name, message.inputs, false, false); generate_proxy_method(message_generator, endpoint, message, message.name, message.inputs, true, true); + generate_proxy_awaitable_method(message_generator, message, message.name, ReadonlySpan { message.inputs }.slice(1)); } } @@ -681,9 +754,10 @@ void build_endpoint(SourceGenerator generator, Endpoint const& endpoint) ByteString response_name; if (message.is_synchronous) { response_name = message.response_name(); - do_message(generator.fork(), response_name, message.outputs); + VERIFY(message.outputs.size() >= 1); + do_message(true, generator.fork(), response_name, message.outputs); } - do_message(generator.fork(), message.name, message.inputs, response_name); + do_message(false, generator.fork(), message.name, message.inputs, response_name); } generator.appendln(R"~~~( @@ -714,6 +788,11 @@ class @endpoint.name@Stub; class @endpoint.name@Endpoint { public: + static u64 next_ipc_request_id() { + static u64 request_id = 1; + return request_id++; + } + template using Proxy = @endpoint.name@Proxy; using Stub = @endpoint.name@Stub; @@ -781,15 +860,16 @@ public: virtual u32 magic() const override { return @endpoint.magic@; } virtual ByteString name() const override { return "@endpoint.name@"; } - virtual ErrorOr> handle(NonnullOwnPtr message) override + virtual void handle_ipc_message([[maybe_unused]] NonnullRefPtr conn, NonnullOwnPtr message) override { switch (message->message_id()) {)~~~"); for (auto const& message : endpoint.messages) { - auto do_handle_message = [&](ByteString const& name, Vector const& parameters, bool returns_something) { + auto do_handle_message = [&](ByteString const& name, ReadonlySpan parameters, ReadonlySpan response_args, bool returns_something) { auto message_generator = generator.fork(); - StringBuilder argument_generator; - for (auto const& [i, parameter] : enumerate(parameters)) { + auto parameters_without_seqid = parameters.slice(1); + for (auto const& [i, parameter] : enumerate(parameters_without_seqid)) { + if (is_primitive_or_simple_type(parameter.type)) argument_generator.append("request."sv); else @@ -797,60 +877,75 @@ public: argument_generator.append(parameter.name); argument_generator.append("()"sv); - if (i != parameters.size() - 1) + if (i != parameters_without_seqid.size() - 1) { argument_generator.append(", "sv); + } + } + + StringBuilder arguments_types_generator; + if (returns_something) { + for (auto const& parameter : response_args.slice(1)) { + arguments_types_generator.append(", const "sv); + arguments_types_generator.append(parameter.type); + arguments_types_generator.append("&"sv); + } } message_generator.set("message.pascal_name", pascal_case(name)); message_generator.set("message.response_type", pascal_case(message.response_name())); message_generator.set("handler_name", name); + message_generator.set("arguments_types", arguments_types_generator.to_byte_string()); message_generator.set("arguments", argument_generator.to_byte_string()); message_generator.append(R"~~~( case (int)Messages::@endpoint.name@::MessageID::@message.pascal_name@: {)~~~"); if (returns_something) { - if (message.outputs.is_empty()) { - message_generator.append(R"~~~( + message_generator.append(R"~~~( [[maybe_unused]] auto& request = static_cast(*message); - @handler_name@(@arguments@); - auto response = Messages::@endpoint.name@::@message.response_type@ { }; - return make(TRY(response.encode()));)~~~"); - } else { - message_generator.append(R"~~~( - [[maybe_unused]] auto& request = static_cast(*message); - auto response = @handler_name@(@arguments@); - return make(TRY(response.encode()));)~~~"); + auto promise = @handler_name@(@arguments@); + promise->when_resolved([id = message->ipc_request_id(), conn](auto&& response) { + response.set_ipc_request_id(id); + auto res = response.encode(); + if (auto const post_result = conn->post_message(res.release_value()); post_result.is_error()) { + dbgln("IPC::ConnectionBase::handle_messages: {}", post_result.error()); } + }); + break;)~~~"); } else { message_generator.append(R"~~~( [[maybe_unused]] auto& request = static_cast(*message); @handler_name@(@arguments@); - return nullptr;)~~~"); + break;)~~~"); } message_generator.append(R"~~~( })~~~"); }; - do_handle_message(message.name, message.inputs, message.is_synchronous); + do_handle_message(message.name, message.inputs, message.outputs, message.is_synchronous); } generator.appendln(R"~~~( default: - return Error::from_string_literal("Unknown message ID for @endpoint.name@ endpoint"); + dbgln("Unknown message ID for @endpoint.name@ endpoint"); } })~~~"); for (auto const& message : endpoint.messages) { auto message_generator = generator.fork(); + ByteString return_type = "void"; + if (message.is_synchronous) { + message_generator.set("response.type", pascal_case(message.response_name())); + } auto do_handle_message_decl = [&](ByteString const& name, Vector const& parameters) { - ByteString return_type = "void"; - if (message.is_synchronous && !message.outputs.is_empty()) - return_type = message_name(endpoint.name, message.name, true); - message_generator.set("message.complex_return_type", return_type); - message_generator.set("handler_name", name); - message_generator.append(R"~~~( - virtual @message.complex_return_type@ @handler_name@()~~~"); - - for (size_t i = 0; i < parameters.size(); ++i) { + if (message.is_synchronous) { + message_generator.append(R"~~~( + virtual NonnullRefPtr @handler_name@()~~~"); + } else { + message_generator.append(R"~~~( + virtual void @handler_name@()~~~"); + } + // The IPC request sequence id should not be a part of the signature. The user should + // not have to be involved with implementation details like that, so we skip it here. + for (size_t i = 1; i < parameters.size(); ++i) { auto const& parameter = parameters[i]; auto argument_generator = message_generator.fork(); argument_generator.set("argument.type", parameter.type); @@ -859,7 +954,6 @@ public: if (i != parameters.size() - 1) argument_generator.append(", "); } - message_generator.append(") = 0;"); }; @@ -892,12 +986,14 @@ void build(StringBuilder& builder, Vector const& endpoints) #include #include #include +#include #include #include #include #include #include #include +#include #if defined(AK_COMPILER_CLANG) #pragma clang diagnostic push diff --git a/Services/WebContent/WebContentServer.ipc b/Services/WebContent/WebContentServer.ipc index 82637b6d83c..ca3ed3cdff1 100644 --- a/Services/WebContent/WebContentServer.ipc +++ b/Services/WebContent/WebContentServer.ipc @@ -1,3 +1,4 @@ +#include #include #include #include