LibIPC: Introduce async functionality to IPC - part 1

Description:
This patch is part 1 in the draft for asynchronous IPC. This patch is
responsible for 100% for the functionality that is added to make async
IPC possible. The following patches in this series, contains the
changes to code in Ladybird at large, that needs to change in order to
compile and although I know this breaks with convention, I figured this
would be easier to review as a draft. I can squash/refactor, as soon as
some review of this has started. It also allows me to write additional
patches that can show how this should work.

Changes made
Added a sequence id to all messages, which is provided by the endpoint
as a static monotonic id.

Changed the int value for messages and responses (for debuggability
purposes). All messages are uneven values and all responses are that
value + 1, to encode additional data into the value.

Messages that respond with data, now must be changed from:

Messages::SomeEndpoint::ResponseType
foo(Foo arg1, Bar arg2)

to

NonnullRefPtr<Core::Promise<Messages::SomeEndpoint::ResponseType>>
foo(Foo arg1, Bar arg2)

After in-depth discussion, maintainers of the project opted to go for
non-industry standard approach to maintain a logical cohesiveness with
ladybird, where the return of a Promise<T> is more logically suitable,
even if it incurs a little extra verbosity.

Existing IPC methods, now needs to do the following 4 things, and
commit 2 of this Async IPC proposal will do a few of them for existing
messages:

- Change the return type from
  `Messages::EndPointName::MessageNameResponse` to
  `NonnullRefPtr<Messages::EndPointName::MessageName::Promise>`
- Construct the correct Promise type, which can be found in the
  `Messages::EndPointName::MessageName::Promise` typedef.
- Resolve the promise with a value, where appropriate.
- Return the promise

(Note: Under no circumstances are you allowed to configure a
`when_resolved` for this `Promise` as the IPC system relies on using
this to pickle the message and then write the message buffer over the
socket. You are only allowed to to resolve it with a value).
This commit is contained in:
Simon Farre 2025-03-26 22:42:55 +01:00
parent 6f1710121d
commit 3ca0ec9f71
6 changed files with 273 additions and 120 deletions

View file

@ -1,5 +1,6 @@
/*
* Copyright (c) 2021-2024, Andreas Kling <andreas@ladybird.org>
* Copyright (c) 2025, Simon Farre <simon.farre.cx@gmail.com>
* Copyright (c) 2022, the SerenityOS developers.
*
* SPDX-License-Identifier: BSD-2-Clause
@ -9,7 +10,6 @@
#include <LibCore/Socket.h>
#include <LibCore/Timer.h>
#include <LibIPC/Connection.h>
#include <LibIPC/Message.h>
#include <LibIPC/Stub.h>
namespace IPC {
@ -18,6 +18,7 @@ ConnectionBase::ConnectionBase(IPC::Stub& local_stub, Transport transport, u32 l
: m_local_stub(local_stub)
, m_transport(move(transport))
, m_local_endpoint_magic(local_endpoint_magic)
, m_event_loop(Core::EventLoop::current())
{
m_responsiveness_timer = Core::Timer::create_single_shot(3000, [this] { may_have_become_unresponsive(); });
@ -25,7 +26,7 @@ ConnectionBase::ConnectionBase(IPC::Stub& local_stub, Transport transport, u32 l
NonnullRefPtr protect = *this;
// FIXME: Do something about errors.
(void)drain_messages_from_peer();
handle_messages();
process_messages();
});
m_send_queue = adopt_ref(*new SendQueue);
@ -102,26 +103,6 @@ void ConnectionBase::shutdown_with_error(Error const& error)
shutdown();
}
void ConnectionBase::handle_messages()
{
auto messages = move(m_unprocessed_messages);
for (auto& message : messages) {
if (message->endpoint_magic() == m_local_endpoint_magic) {
auto handler_result = m_local_stub.handle(move(message));
if (handler_result.is_error()) {
dbgln("IPC::ConnectionBase::handle_messages: {}", handler_result.error());
continue;
}
if (auto response = handler_result.release_value()) {
if (auto post_result = post_message(*response); post_result.is_error()) {
dbgln("IPC::ConnectionBase::handle_messages: {}", post_result.error());
}
}
}
}
}
void ConnectionBase::wait_for_transport_to_become_readable()
{
m_transport.wait_until_readable();
@ -160,6 +141,33 @@ ErrorOr<Vector<u8>> ConnectionBase::read_as_much_as_possible_from_transport_with
return bytes;
}
void ConnectionBase::process_messages()
{
if (m_unprocessed_messages.is_empty()) {
return;
}
Vector<NonnullOwnPtr<Message>> messages;
swap(m_unprocessed_messages, messages);
for (auto& message : messages) {
if (message->is_response()) {
auto it = find_if(m_resolvers.begin(), m_resolvers.end(), [id = message->ipc_request_id()](auto const& res) { return res.request_id == id; });
if (it == m_resolvers.end()) {
continue;
}
it->promise_completer(move(message));
m_resolvers.remove(std::distance(m_resolvers.begin(), it));
} else {
// Handle message on thread this connection was created on.
// We don't necessarily need to defer invoke here, but in a future where we may want
// to have a dedicated IO thread, it needs to post the work back to the thread it needs to run on.
m_event_loop.deferred_invoke([this, msg = move(message)]() mutable {
m_local_stub.handle_ipc_message(NonnullRefPtr<IPC::ConnectionBase> { *this }, move(msg));
});
}
}
}
ErrorOr<void> ConnectionBase::drain_messages_from_peer()
{
auto bytes = TRY(read_as_much_as_possible_from_transport_without_blocking());
@ -181,13 +189,13 @@ ErrorOr<void> ConnectionBase::drain_messages_from_peer()
if (!m_unprocessed_messages.is_empty()) {
deferred_invoke([this] {
handle_messages();
process_messages();
});
}
return {};
}
OwnPtr<IPC::Message> ConnectionBase::wait_for_specific_endpoint_message_impl(u32 endpoint_magic, int message_id)
OwnPtr<IPC::Message> ConnectionBase::wait_for_specific_endpoint_message_impl(u64 request_id, u32 endpoint_magic, int message_id)
{
for (;;) {
// Double check we don't already have the event waiting for us.
@ -196,7 +204,7 @@ OwnPtr<IPC::Message> ConnectionBase::wait_for_specific_endpoint_message_impl(u32
auto& message = m_unprocessed_messages[i];
if (message->endpoint_magic() != endpoint_magic)
continue;
if (message->message_id() == message_id)
if (message->message_id() == message_id && message->ipc_request_id() == request_id)
return m_unprocessed_messages.take(i);
}

View file

@ -1,5 +1,6 @@
/*
* Copyright (c) 2018-2024, Andreas Kling <andreas@ladybird.org>
* Copyright (c) 2025, Simon Farre <simon.farre.cx@gmail.com>
* Copyright (c) 2022, the SerenityOS developers.
*
* SPDX-License-Identifier: BSD-2-Clause
@ -9,9 +10,14 @@
#include <AK/Forward.h>
#include <AK/Queue.h>
#include <LibCore/Event.h>
#include <LibCore/EventLoop.h>
#include <LibCore/EventReceiver.h>
#include <LibCore/Promise.h>
#include <LibCore/ThreadEventQueue.h>
#include <LibIPC/File.h>
#include <LibIPC/Forward.h>
#include <LibIPC/Message.h>
#include <LibIPC/Transport.h>
#include <LibThreading/ConditionVariable.h>
#include <LibThreading/MutexProtected.h>
@ -19,9 +25,16 @@
namespace IPC {
struct Completer {
u64 request_id { 0 };
Function<void(OwnPtr<IPC::Message>)> promise_completer;
};
class ConnectionBase : public Core::EventReceiver {
C_OBJECT_ABSTRACT(ConnectionBase);
void process_messages();
public:
virtual ~ConnectionBase() override;
@ -34,6 +47,8 @@ public:
Transport& transport() { return m_transport; }
ErrorOr<void> drain_messages_from_peer();
protected:
explicit ConnectionBase(IPC::Stub&, Transport, u32 local_endpoint_magic);
@ -42,18 +57,14 @@ protected:
virtual void shutdown_with_error(Error const&);
virtual OwnPtr<Message> try_parse_message(ReadonlyBytes, Queue<IPC::File>&) = 0;
OwnPtr<IPC::Message> wait_for_specific_endpoint_message_impl(u32 endpoint_magic, int message_id);
OwnPtr<IPC::Message> wait_for_specific_endpoint_message_impl(u64 request_id, u32 endpoint_magic, int message_id);
void wait_for_transport_to_become_readable();
ErrorOr<Vector<u8>> read_as_much_as_possible_from_transport_without_blocking();
ErrorOr<void> drain_messages_from_peer();
void try_parse_messages(Vector<u8> const& bytes, size_t& index);
void handle_messages();
IPC::Stub& m_local_stub;
Transport m_transport;
RefPtr<Core::Timer> m_responsiveness_timer;
Vector<NonnullOwnPtr<Message>> m_unprocessed_messages;
@ -71,6 +82,9 @@ protected:
RefPtr<Threading::Thread> m_send_thread;
RefPtr<SendQueue> m_send_queue;
// Arbitrary inline size.
Vector<Completer, 16> m_resolvers;
Core::EventLoop& m_event_loop;
};
template<typename LocalEndpoint, typename PeerEndpoint>
@ -81,6 +95,8 @@ public:
{
}
~Connection() override = default;
template<typename MessageType>
OwnPtr<MessageType> wait_for_specific_message()
{
@ -90,25 +106,41 @@ public:
template<typename RequestType, typename... Args>
NonnullOwnPtr<typename RequestType::ResponseType> send_sync(Args&&... args)
{
MUST(post_message(RequestType(forward<Args>(args)...)));
auto response = wait_for_specific_endpoint_message<typename RequestType::ResponseType, PeerEndpoint>();
auto const request_id = LocalEndpoint::next_ipc_request_id();
MUST(post_message(RequestType(request_id, forward<Args>(args)...)));
auto response = wait_for_specific_endpoint_message<typename RequestType::ResponseType, PeerEndpoint>(request_id);
VERIFY(response);
return response.release_nonnull();
}
template<typename RequestType, typename... Args>
NonnullRefPtr<Core::Promise<OwnPtr<typename RequestType::ResponseType>>> send(Args&&... args)
{
using Promise = Core::Promise<OwnPtr<typename RequestType::ResponseType>>;
auto promise = Promise::construct();
auto const msg = RequestType { LocalEndpoint::next_ipc_request_id(), forward<Args>(args)... };
MUST(post_message(msg));
m_resolvers.empend(msg.ipc_request_id(), [promise, msg_id = msg.message_id()](OwnPtr<IPC::Message> msg) {
ASSERT(msg->message_id() == msg_id + 1 && msg->endpoint_magic() == PeerEndpoint::static_magic() && msg->ipc_request_id() == msg_id);
promise->resolve(msg.release_nonnull<typename RequestType::ResponseType>());
});
return promise;
}
template<typename RequestType, typename... Args>
OwnPtr<typename RequestType::ResponseType> send_sync_but_allow_failure(Args&&... args)
{
if (post_message(RequestType(forward<Args>(args)...)).is_error())
auto const request_id = LocalEndpoint::next_ipc_request_id();
if (post_message(RequestType(request_id, forward<Args>(args)...)).is_error())
return nullptr;
return wait_for_specific_endpoint_message<typename RequestType::ResponseType, PeerEndpoint>();
return wait_for_specific_endpoint_message<typename RequestType::ResponseType, PeerEndpoint>(request_id);
}
protected:
template<typename MessageType, typename Endpoint>
OwnPtr<MessageType> wait_for_specific_endpoint_message()
OwnPtr<MessageType> wait_for_specific_endpoint_message(u64 request_id)
{
if (auto message = wait_for_specific_endpoint_message_impl(Endpoint::static_magic(), MessageType::static_message_id()))
if (auto message = wait_for_specific_endpoint_message_impl(request_id, Endpoint::static_magic(), MessageType::static_message_id()))
return message.template release_nonnull<MessageType>();
return {};
}
@ -126,5 +158,4 @@ protected:
return nullptr;
}
};
}

View file

@ -40,6 +40,13 @@ class MessageBuffer {
public:
MessageBuffer();
// Do not copy message buffers.
MessageBuffer(MessageBuffer const&) = delete;
MessageBuffer& operator=(MessageBuffer const&) = delete;
MessageBuffer(MessageBuffer&&) noexcept = default;
MessageBuffer& operator=(MessageBuffer&&) noexcept = default;
ErrorOr<void> extend_data_capacity(size_t capacity);
ErrorOr<void> append_data(u8 const* values, size_t count);
@ -70,9 +77,17 @@ public:
virtual int message_id() const = 0;
virtual char const* message_name() const = 0;
virtual ErrorOr<MessageBuffer> encode() const = 0;
u64 ipc_request_id() const { return m_request_id; }
void set_ipc_request_id(u64 request_id) { m_request_id = request_id; }
constexpr bool is_response() const { return (message_id() % 2) == 0; }
protected:
Message() = default;
};
explicit Message(u64 request_id)
: m_request_id(request_id)
{
}
private:
u64 m_request_id;
};
}

View file

@ -17,13 +17,15 @@ class BufferStream;
namespace IPC {
class ConnectionBase;
class Stub {
public:
virtual ~Stub() = default;
virtual u32 magic() const = 0;
virtual ByteString name() const = 0;
virtual ErrorOr<OwnPtr<MessageBuffer>> handle(NonnullOwnPtr<Message>) = 0;
virtual void handle_ipc_message(NonnullRefPtr<IPC::ConnectionBase> conn, NonnullOwnPtr<IPC::Message> message) = 0;
protected:
Stub() = default;

View file

@ -1,9 +1,11 @@
/*
* Copyright (c) 2018-2020, Andreas Kling <andreas@ladybird.org>
* Copyright (c) 2025, Simon Farre <simon.farre.cx@gmail.com>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/Forward.h>
#include <AK/Debug.h>
#include <AK/Enumerate.h>
#include <AK/Function.h>
@ -245,7 +247,10 @@ Vector<Endpoint> parse(ByteBuffer const& file_contents)
consume_whitespace();
// We have an unfortunate naming of "is_synchronous" when it really means "has response"
message.inputs.prepend(Parameter { .attributes = {}, .type = "u64", .type_for_encoding = "u64", .name = "ipc_request_id" });
if (message.is_synchronous) {
message.outputs.prepend(Parameter { .attributes = {}, .type = "u64", .type_for_encoding = "u64", .name = "ipc_request_id" });
assert_specific('(');
parse_parameters(message.outputs, message.name);
assert_specific(')');
@ -312,59 +317,60 @@ HashMap<ByteString, int> build_message_ids_for_endpoint(SourceGenerator generato
HashMap<ByteString, int> message_ids;
generator.appendln("\nenum class MessageID : i32 {");
auto id = 1;
for (auto const& message : endpoint.messages) {
message_ids.set(message.name, message_ids.size() + 1);
message_ids.set(message.name, id);
generator.set("message.pascal_name", pascal_case(message.name));
generator.set("message.id", ByteString::number(message_ids.size()));
generator.set("message.id", ByteString::number(id++));
generator.appendln(" @message.pascal_name@ = @message.id@,");
if (message.is_synchronous) {
message_ids.set(message.response_name(), message_ids.size() + 1);
message_ids.set(message.response_name(), id);
generator.set("message.pascal_name", pascal_case(message.response_name()));
generator.set("message.id", ByteString::number(message_ids.size()));
generator.set("message.id", ByteString::number(id++));
generator.appendln(" @message.pascal_name@ = @message.id@,");
}
id += message.is_synchronous ? 0 : 1;
}
generator.appendln("};");
return message_ids;
}
ByteString constructor_for_message(ByteString const& name, Vector<Parameter> const& parameters)
ByteString constructor_for_message(bool is_response, ByteString const& name, ReadonlySpan<Parameter const> parameters)
{
StringBuilder builder;
builder.append(name);
if (parameters.is_empty()) {
builder.append("() {}"sv);
return builder.to_byte_string();
}
builder.append('(');
for (size_t i = 0; i < parameters.size(); ++i) {
for (size_t i = is_response ? 1 : 0; i < parameters.size(); ++i) {
auto const& parameter = parameters[i];
builder.appendff("{} {}", parameter.type, parameter.name);
if (i != parameters.size() - 1)
builder.append(", "sv);
}
builder.append(") : "sv);
for (size_t i = 0; i < parameters.size(); ++i) {
if (is_response)
builder.append(") : Message(0) "sv);
else
builder.append(") : Message(ipc_request_id) "sv);
for (size_t i = 1; i < parameters.size(); ++i) {
builder.append(", "sv);
auto const& parameter = parameters[i];
builder.appendff("m_{}(move({}))", parameter.name, parameter.name);
if (i != parameters.size() - 1)
builder.append(", "sv);
}
builder.append(" {}"sv);
return builder.to_byte_string();
}
void do_message(SourceGenerator message_generator, ByteString const& name, Vector<Parameter> const& parameters, ByteString const& response_type = {})
void do_message(bool is_response, SourceGenerator message_generator, ByteString const& name, ReadonlySpan<Parameter const> parameters, ByteString const& response_type = {})
{
auto pascal_name = pascal_case(name);
message_generator.set("message.name", name);
message_generator.set("message.pascal_name", pascal_name);
message_generator.set("message.response_type", response_type);
message_generator.set("message.constructor", constructor_for_message(pascal_name, parameters));
message_generator.set("message.constructor", constructor_for_message(is_response, pascal_name, parameters));
message_generator.append(R"~~~(
class @message.pascal_name@ final : public IPC::Message {
@ -372,24 +378,30 @@ public:)~~~");
if (!response_type.is_empty())
message_generator.appendln(R"~~~(
typedef class @message.response_type@ ResponseType;)~~~");
typedef class @message.response_type@ ResponseType;
using Promise = Core::Promise<ResponseType>;)~~~");
if (is_response) {
message_generator.appendln(R"~~~(
using Promise = Core::Promise<@message.pascal_name@>;)~~~");
}
message_generator.appendln(R"~~~(
@message.pascal_name@(@message.pascal_name@ const&) = default;
@message.pascal_name@(@message.pascal_name@&&) = default;
@message.pascal_name@& operator=(@message.pascal_name@ const&) = default;
@message.pascal_name@(@message.pascal_name@ const&) = delete;
@message.pascal_name@& operator=(@message.pascal_name@ const&) = delete;
@message.constructor@)~~~");
if (parameters.size() == 1) {
auto const& parameter = parameters[0];
if (parameters.size() == 2) {
auto const& parameter = parameters[1];
message_generator.set("parameter.type"sv, parameter.type);
message_generator.set("parameter.name"sv, parameter.name);
message_generator.appendln(R"~~~(
template <typename WrappedReturnType>
requires(!SameAs<WrappedReturnType, @parameter.type@>)
@message.pascal_name@(WrappedReturnType&& value)
: m_@parameter.name@(forward<WrappedReturnType>(value))
@message.pascal_name@(u64 sequence_id, WrappedReturnType&& value)
: Message(sequence_id), m_@parameter.name@(forward<WrappedReturnType>(value))
{
})~~~");
}
@ -430,7 +442,7 @@ public:)~~~");
}
StringBuilder builder;
for (size_t i = 0; i < parameters.size(); ++i) {
for (size_t i = is_response ? 1 : 0; i < parameters.size(); ++i) {
auto const& parameter = parameters[i];
builder.appendff("move({})", parameter.name);
if (i != parameters.size() - 1)
@ -438,9 +450,17 @@ public:)~~~");
}
message_generator.set("message.constructor_call_parameters", builder.to_byte_string());
message_generator.appendln(R"~~~(
return make<@message.pascal_name@>(@message.constructor_call_parameters@);
if (is_response) {
message_generator.appendln(R"~~~(
auto result = make<@message.pascal_name@>(@message.constructor_call_parameters@);
result->set_ipc_request_id(ipc_request_id);
return result;
})~~~");
} else {
message_generator.appendln(R"~~~(
return make<@message.pascal_name@>(@message.constructor_call_parameters@);
})~~~");
}
message_generator.append(R"~~~(
static ErrorOr<IPC::MessageBuffer> static_encode()~~~");
@ -476,20 +496,19 @@ public:)~~~");
message_generator.append(R"~~~(
virtual ErrorOr<IPC::MessageBuffer> encode() const override
{
return static_encode()~~~");
for (auto const& [i, parameter] : enumerate(parameters)) {
return static_encode(ipc_request_id())~~~");
auto parameters_without_seqid = parameters.slice(1);
for (auto const& parameter : parameters_without_seqid) {
auto parameter_generator = message_generator.fork();
parameter_generator.append(", ");
parameter_generator.set("parameter.name", parameter.name);
parameter_generator.append("m_@parameter.name@");
if (i != parameters.size() - 1)
parameter_generator.append(", ");
}
message_generator.appendln(R"~~~();
})~~~");
for (auto const& parameter : parameters) {
for (auto const& parameter : parameters_without_seqid) {
auto parameter_generator = message_generator.fork();
parameter_generator.set("parameter.type", parameter.type);
parameter_generator.set("parameter.name", parameter.name);
@ -506,8 +525,7 @@ public:)~~~");
message_generator.append(R"~~~(
private:)~~~");
for (auto const& parameter : parameters) {
for (auto const& parameter : parameters_without_seqid) {
auto parameter_generator = message_generator.fork();
parameter_generator.set("parameter.type", parameter.type);
parameter_generator.set("parameter.name", parameter.name);
@ -518,7 +536,54 @@ private:)~~~");
message_generator.appendln("\n};");
}
void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& endpoint, Message const& message, ByteString const& name, Vector<Parameter> const& parameters, bool is_synchronous, bool is_try, bool is_utf8_string_overload = false)
void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& endpoint, Message const& message, ByteString const& name, ReadonlySpan<Parameter const> parameters, bool is_synchronous, bool is_try, bool is_utf8_string_overload);
void generate_proxy_awaitable_method(SourceGenerator& message_generator, Message const& message, ByteString const& name, ReadonlySpan<Parameter const> parameters_without_seqid, bool is_utf8_string_overload = false)
{
ASSERT(message.is_synchronous);
message_generator.set("message.name", message.name);
message_generator.set("message.pascal_name", pascal_case(message.name));
message_generator.set("handler_name", name);
message_generator.append(R"~~~(
auto send_@handler_name@()~~~");
for (auto const& [i, parameter] : enumerate(parameters_without_seqid)) {
ByteString type;
if (is_utf8_string_overload)
type = make_argument_type(parameter.type);
else
type = make_argument_type(parameter.type_for_encoding);
auto argument_generator = message_generator.fork();
argument_generator.set("argument.type", type);
argument_generator.set("argument.name", parameter.name);
argument_generator.append("@argument.type@ @argument.name@");
if (i != parameters_without_seqid.size() - 1)
argument_generator.append(", ");
}
message_generator.append(") {");
message_generator.append(R"~~~(
return m_connection. template send<Messages::@endpoint.name@::@message.pascal_name@>()~~~");
for (auto const& [i, parameter] : enumerate(parameters_without_seqid)) {
auto argument_generator = message_generator.fork();
argument_generator.set("argument.name", parameter.name);
if (is_primitive_or_simple_type(parameter.type))
argument_generator.append("@argument.name@");
else
argument_generator.append("move(@argument.name@)");
if (i != parameters_without_seqid.size() - 1)
argument_generator.append(", ");
}
message_generator.appendln(R"~~~();
})~~~");
}
void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& endpoint, Message const& message, ByteString const& name, ReadonlySpan<Parameter const> parameters, bool is_synchronous, bool is_try, bool is_utf8_string_overload = false)
{
// FIXME: For String parameters, we want to retain the property that all tranferred String objects are strictly UTF-8.
// So instead of generating a single proxy method that accepts StringView parameters, we generate two overloads.
@ -531,9 +596,9 @@ void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& e
ByteString return_type = "void";
if (is_synchronous) {
if (message.outputs.size() == 1)
return_type = message.outputs[0].type;
else if (!message.outputs.is_empty())
if (message.outputs.size() == 2)
return_type = message.outputs[1].type;
else if (message.outputs.size() > 1)
return_type = message_name(endpoint.name, message.name, true);
}
ByteString inner_return_type = return_type;
@ -550,7 +615,8 @@ void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& e
message_generator.append(R"~~~(
@message.complex_return_type@ @try_prefix_maybe@@async_prefix_maybe@@handler_name@()~~~");
for (auto const& [i, parameter] : enumerate(parameters)) {
auto parameters_without_seqid = parameters.slice(1);
for (auto const& [i, parameter] : enumerate(parameters_without_seqid)) {
ByteString type;
if (is_synchronous || is_try)
type = parameter.type;
@ -563,7 +629,7 @@ void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& e
argument_generator.set("argument.type", type);
argument_generator.set("argument.name", parameter.name);
argument_generator.append("@argument.type@ @argument.name@");
if (i != parameters.size() - 1)
if (i != parameters_without_seqid.size() - 1)
argument_generator.append(", ");
}
@ -584,11 +650,12 @@ void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& e
}
}
bool was_static = false;
if (is_synchronous && !is_try) {
if (return_type != "void") {
message_generator.append(R"~~~(
return )~~~");
if (message.outputs.size() != 1)
if (message.outputs.size() != 2)
message_generator.append("move(*");
} else {
message_generator.append(R"~~~(
@ -600,20 +667,25 @@ void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& e
message_generator.append(R"~~~(
auto result = m_connection.template send_sync_but_allow_failure<Messages::@endpoint.name@::@message.pascal_name@>()~~~");
} else {
was_static = true;
message_generator.append(R"~~~(
auto message_buffer = MUST(Messages::@endpoint.name@::@message.pascal_name@::static_encode()~~~");
auto message_buffer = MUST(Messages::@endpoint.name@::@message.pascal_name@::static_encode(LocalEndpoint::next_ipc_request_id())~~~");
}
for (auto const& [i, parameter] : enumerate(parameters)) {
for (auto const& [i, parameter] : enumerate(parameters_without_seqid)) {
auto const& type = is_synchronous || is_try ? parameter.type : parameter.type_for_encoding;
auto argument_generator = message_generator.fork();
argument_generator.set("argument.name", parameter.name);
if (was_static) {
argument_generator.append(", ");
was_static = false;
}
if (is_primitive_or_simple_type(type))
argument_generator.append("@argument.name@");
else
argument_generator.append("move(@argument.name@)");
if (i != parameters.size() - 1)
if (i != parameters_without_seqid.size() - 1)
argument_generator.append(", ");
}
@ -622,13 +694,13 @@ void generate_proxy_method(SourceGenerator& message_generator, Endpoint const& e
message_generator.append(")");
}
if (message.outputs.size() == 1) {
if (is_primitive_or_simple_type(message.outputs[0].type))
if (message.outputs.size() == 2) {
if (is_primitive_or_simple_type(message.outputs[1].type))
message_generator.append("->");
else
message_generator.append("->take_");
message_generator.append(message.outputs[0].name);
message_generator.append(message.outputs[1].name);
message_generator.append("()");
} else
message_generator.append(")");
@ -665,6 +737,7 @@ void do_message_for_proxy(SourceGenerator message_generator, Endpoint const& end
if (message.is_synchronous) {
generate_proxy_method(message_generator, endpoint, message, message.name, message.inputs, false, false);
generate_proxy_method(message_generator, endpoint, message, message.name, message.inputs, true, true);
generate_proxy_awaitable_method(message_generator, message, message.name, ReadonlySpan<Parameter const> { message.inputs }.slice(1));
}
}
@ -681,9 +754,10 @@ void build_endpoint(SourceGenerator generator, Endpoint const& endpoint)
ByteString response_name;
if (message.is_synchronous) {
response_name = message.response_name();
do_message(generator.fork(), response_name, message.outputs);
VERIFY(message.outputs.size() >= 1);
do_message(true, generator.fork(), response_name, message.outputs);
}
do_message(generator.fork(), message.name, message.inputs, response_name);
do_message(false, generator.fork(), message.name, message.inputs, response_name);
}
generator.appendln(R"~~~(
@ -714,6 +788,11 @@ class @endpoint.name@Stub;
class @endpoint.name@Endpoint {
public:
static u64 next_ipc_request_id() {
static u64 request_id = 1;
return request_id++;
}
template<typename LocalEndpoint>
using Proxy = @endpoint.name@Proxy<LocalEndpoint, @endpoint.name@Endpoint>;
using Stub = @endpoint.name@Stub;
@ -781,15 +860,16 @@ public:
virtual u32 magic() const override { return @endpoint.magic@; }
virtual ByteString name() const override { return "@endpoint.name@"; }
virtual ErrorOr<OwnPtr<IPC::MessageBuffer>> handle(NonnullOwnPtr<IPC::Message> message) override
virtual void handle_ipc_message([[maybe_unused]] NonnullRefPtr<IPC::ConnectionBase> conn, NonnullOwnPtr<IPC::Message> message) override
{
switch (message->message_id()) {)~~~");
for (auto const& message : endpoint.messages) {
auto do_handle_message = [&](ByteString const& name, Vector<Parameter> const& parameters, bool returns_something) {
auto do_handle_message = [&](ByteString const& name, ReadonlySpan<Parameter const> parameters, ReadonlySpan<Parameter const> response_args, bool returns_something) {
auto message_generator = generator.fork();
StringBuilder argument_generator;
for (auto const& [i, parameter] : enumerate(parameters)) {
auto parameters_without_seqid = parameters.slice(1);
for (auto const& [i, parameter] : enumerate(parameters_without_seqid)) {
if (is_primitive_or_simple_type(parameter.type))
argument_generator.append("request."sv);
else
@ -797,60 +877,75 @@ public:
argument_generator.append(parameter.name);
argument_generator.append("()"sv);
if (i != parameters.size() - 1)
if (i != parameters_without_seqid.size() - 1) {
argument_generator.append(", "sv);
}
}
StringBuilder arguments_types_generator;
if (returns_something) {
for (auto const& parameter : response_args.slice(1)) {
arguments_types_generator.append(", const "sv);
arguments_types_generator.append(parameter.type);
arguments_types_generator.append("&"sv);
}
}
message_generator.set("message.pascal_name", pascal_case(name));
message_generator.set("message.response_type", pascal_case(message.response_name()));
message_generator.set("handler_name", name);
message_generator.set("arguments_types", arguments_types_generator.to_byte_string());
message_generator.set("arguments", argument_generator.to_byte_string());
message_generator.append(R"~~~(
case (int)Messages::@endpoint.name@::MessageID::@message.pascal_name@: {)~~~");
if (returns_something) {
if (message.outputs.is_empty()) {
message_generator.append(R"~~~(
message_generator.append(R"~~~(
[[maybe_unused]] auto& request = static_cast<Messages::@endpoint.name@::@message.pascal_name@&>(*message);
@handler_name@(@arguments@);
auto response = Messages::@endpoint.name@::@message.response_type@ { };
return make<IPC::MessageBuffer>(TRY(response.encode()));)~~~");
} else {
message_generator.append(R"~~~(
[[maybe_unused]] auto& request = static_cast<Messages::@endpoint.name@::@message.pascal_name@&>(*message);
auto response = @handler_name@(@arguments@);
return make<IPC::MessageBuffer>(TRY(response.encode()));)~~~");
auto promise = @handler_name@(@arguments@);
promise->when_resolved([id = message->ipc_request_id(), conn](auto&& response) {
response.set_ipc_request_id(id);
auto res = response.encode();
if (auto const post_result = conn->post_message(res.release_value()); post_result.is_error()) {
dbgln("IPC::ConnectionBase::handle_messages: {}", post_result.error());
}
});
break;)~~~");
} else {
message_generator.append(R"~~~(
[[maybe_unused]] auto& request = static_cast<Messages::@endpoint.name@::@message.pascal_name@&>(*message);
@handler_name@(@arguments@);
return nullptr;)~~~");
break;)~~~");
}
message_generator.append(R"~~~(
})~~~");
};
do_handle_message(message.name, message.inputs, message.is_synchronous);
do_handle_message(message.name, message.inputs, message.outputs, message.is_synchronous);
}
generator.appendln(R"~~~(
default:
return Error::from_string_literal("Unknown message ID for @endpoint.name@ endpoint");
dbgln("Unknown message ID for @endpoint.name@ endpoint");
}
})~~~");
for (auto const& message : endpoint.messages) {
auto message_generator = generator.fork();
ByteString return_type = "void";
if (message.is_synchronous) {
message_generator.set("response.type", pascal_case(message.response_name()));
}
auto do_handle_message_decl = [&](ByteString const& name, Vector<Parameter> const& parameters) {
ByteString return_type = "void";
if (message.is_synchronous && !message.outputs.is_empty())
return_type = message_name(endpoint.name, message.name, true);
message_generator.set("message.complex_return_type", return_type);
message_generator.set("handler_name", name);
message_generator.append(R"~~~(
virtual @message.complex_return_type@ @handler_name@()~~~");
for (size_t i = 0; i < parameters.size(); ++i) {
if (message.is_synchronous) {
message_generator.append(R"~~~(
virtual NonnullRefPtr<Messages::@endpoint.name@::@response.type@::Promise> @handler_name@()~~~");
} else {
message_generator.append(R"~~~(
virtual void @handler_name@()~~~");
}
// The IPC request sequence id should not be a part of the signature. The user should
// not have to be involved with implementation details like that, so we skip it here.
for (size_t i = 1; i < parameters.size(); ++i) {
auto const& parameter = parameters[i];
auto argument_generator = message_generator.fork();
argument_generator.set("argument.type", parameter.type);
@ -859,7 +954,6 @@ public:
if (i != parameters.size() - 1)
argument_generator.append(", ");
}
message_generator.append(") = 0;");
};
@ -892,12 +986,14 @@ void build(StringBuilder& builder, Vector<Endpoint> const& endpoints)
#include <AK/OwnPtr.h>
#include <AK/Result.h>
#include <AK/Utf8View.h>
#include <AK/NonnullRefPtr.h>
#include <LibIPC/Connection.h>
#include <LibIPC/Decoder.h>
#include <LibIPC/Encoder.h>
#include <LibIPC/File.h>
#include <LibIPC/Message.h>
#include <LibIPC/Stub.h>
#include <LibCore/Promise.h>
#if defined(AK_COMPILER_CLANG)
#pragma clang diagnostic push

View file

@ -1,3 +1,4 @@
#include <LibGfx/Color.h>
#include <LibGfx/Rect.h>
#include <LibIPC/File.h>
#include <LibURL/URL.h>