From e205723b95aed068503a200140a56da215efc30c Mon Sep 17 00:00:00 2001 From: Andreas Kling Date: Wed, 18 Sep 2024 10:28:55 +0200 Subject: [PATCH] RequestServer: Make WebSocket IPC APIs asynchronous This fixes deadlocking when interacting with WebSockets while RequestServer is trying to stream downloaded data to WebContent. --- .../Libraries/LibRequests/RequestClient.cpp | 46 ++++++++++------ .../Libraries/LibRequests/RequestClient.h | 16 +++--- Userland/Libraries/LibRequests/WebSocket.cpp | 24 ++++++--- Userland/Libraries/LibRequests/WebSocket.h | 14 +++-- Userland/Libraries/LibWebSocket/WebSocket.cpp | 33 ++++++++---- Userland/Libraries/LibWebSocket/WebSocket.h | 4 ++ .../RequestServer/ConnectionFromClient.cpp | 54 +++++++------------ .../RequestServer/ConnectionFromClient.h | 10 ++-- .../Services/RequestServer/RequestClient.ipc | 12 +++-- .../Services/RequestServer/RequestServer.ipc | 10 ++-- 10 files changed, 129 insertions(+), 94 deletions(-) diff --git a/Userland/Libraries/LibRequests/RequestClient.cpp b/Userland/Libraries/LibRequests/RequestClient.cpp index d0e3e3c780f..6ad8586917c 100644 --- a/Userland/Libraries/LibRequests/RequestClient.cpp +++ b/Userland/Libraries/LibRequests/RequestClient.cpp @@ -96,45 +96,61 @@ void RequestClient::certificate_requested(i32 request_id) RefPtr RequestClient::websocket_connect(const URL::URL& url, ByteString const& origin, Vector const& protocols, Vector const& extensions, HTTP::HeaderMap const& request_headers) { - auto connection_id = IPCProxy::websocket_connect(url, origin, protocols, extensions, request_headers); - if (connection_id < 0) - return nullptr; - auto connection = WebSocket::create_from_id({}, *this, connection_id); - m_websockets.set(connection_id, connection); + auto websocket_id = m_next_websocket_id++; + IPCProxy::async_websocket_connect(websocket_id, url, origin, protocols, extensions, request_headers); + auto connection = WebSocket::create_from_id({}, *this, websocket_id); + m_websockets.set(websocket_id, connection); return connection; } -void RequestClient::websocket_connected(i32 connection_id) +void RequestClient::websocket_connected(i64 websocket_id) { - auto maybe_connection = m_websockets.get(connection_id); + auto maybe_connection = m_websockets.get(websocket_id); if (maybe_connection.has_value()) maybe_connection.value()->did_open({}); } -void RequestClient::websocket_received(i32 connection_id, bool is_text, ByteBuffer const& data) +void RequestClient::websocket_received(i64 websocket_id, bool is_text, ByteBuffer const& data) { - auto maybe_connection = m_websockets.get(connection_id); + auto maybe_connection = m_websockets.get(websocket_id); if (maybe_connection.has_value()) maybe_connection.value()->did_receive({}, data, is_text); } -void RequestClient::websocket_errored(i32 connection_id, i32 message) +void RequestClient::websocket_errored(i64 websocket_id, i32 message) { - auto maybe_connection = m_websockets.get(connection_id); + auto maybe_connection = m_websockets.get(websocket_id); if (maybe_connection.has_value()) maybe_connection.value()->did_error({}, message); } -void RequestClient::websocket_closed(i32 connection_id, u16 code, ByteString const& reason, bool clean) +void RequestClient::websocket_closed(i64 websocket_id, u16 code, ByteString const& reason, bool clean) { - auto maybe_connection = m_websockets.get(connection_id); + auto maybe_connection = m_websockets.get(websocket_id); if (maybe_connection.has_value()) maybe_connection.value()->did_close({}, code, reason, clean); } -void RequestClient::websocket_certificate_requested(i32 connection_id) +void RequestClient::websocket_ready_state_changed(i64 websocket_id, u32 ready_state) { - auto maybe_connection = m_websockets.get(connection_id); + auto maybe_connection = m_websockets.get(websocket_id); + if (maybe_connection.has_value()) { + VERIFY(ready_state <= static_cast(WebSocket::ReadyState::Closed)); + maybe_connection.value()->set_ready_state(static_cast(ready_state)); + } +} + +void RequestClient::websocket_subprotocol(i64 websocket_id, ByteString const& subprotocol) +{ + auto maybe_connection = m_websockets.get(websocket_id); + if (maybe_connection.has_value()) { + maybe_connection.value()->set_subprotocol_in_use(subprotocol); + } +} + +void RequestClient::websocket_certificate_requested(i64 websocket_id) +{ + auto maybe_connection = m_websockets.get(websocket_id); if (maybe_connection.has_value()) maybe_connection.value()->did_request_certificates({}); } diff --git a/Userland/Libraries/LibRequests/RequestClient.h b/Userland/Libraries/LibRequests/RequestClient.h index f5ca6fa14ef..d445ffd0ae4 100644 --- a/Userland/Libraries/LibRequests/RequestClient.h +++ b/Userland/Libraries/LibRequests/RequestClient.h @@ -44,14 +44,18 @@ private: virtual void certificate_requested(i32) override; virtual void headers_became_available(i32, HTTP::HeaderMap const&, Optional const&) override; - virtual void websocket_connected(i32) override; - virtual void websocket_received(i32, bool, ByteBuffer const&) override; - virtual void websocket_errored(i32, i32) override; - virtual void websocket_closed(i32, u16, ByteString const&, bool) override; - virtual void websocket_certificate_requested(i32) override; + virtual void websocket_connected(i64 websocket_id) override; + virtual void websocket_received(i64 websocket_id, bool, ByteBuffer const&) override; + virtual void websocket_errored(i64 websocket_id, i32) override; + virtual void websocket_closed(i64 websocket_id, u16, ByteString const&, bool) override; + virtual void websocket_ready_state_changed(i64 websocket_id, u32 ready_state) override; + virtual void websocket_subprotocol(i64 websocket_id, ByteString const& subprotocol) override; + virtual void websocket_certificate_requested(i64 websocket_id) override; HashMap> m_requests; - HashMap> m_websockets; + HashMap> m_websockets; + + i64 m_next_websocket_id { 0 }; }; } diff --git a/Userland/Libraries/LibRequests/WebSocket.cpp b/Userland/Libraries/LibRequests/WebSocket.cpp index 8d51b38b286..7fdc1d67f2c 100644 --- a/Userland/Libraries/LibRequests/WebSocket.cpp +++ b/Userland/Libraries/LibRequests/WebSocket.cpp @@ -9,25 +9,35 @@ namespace Requests { -WebSocket::WebSocket(RequestClient& client, i32 connection_id) +WebSocket::WebSocket(RequestClient& client, i64 connection_id) : m_client(client) - , m_connection_id(connection_id) + , m_websocket_id(connection_id) { } WebSocket::ReadyState WebSocket::ready_state() { - return static_cast(m_client->websocket_ready_state(m_connection_id)); + return m_ready_state; +} + +void WebSocket::set_ready_state(ReadyState ready_state) +{ + m_ready_state = ready_state; } ByteString WebSocket::subprotocol_in_use() { - return m_client->websocket_subprotocol_in_use(m_connection_id); + return m_subprotocol; +} + +void WebSocket::set_subprotocol_in_use(ByteString subprotocol) +{ + m_subprotocol = move(subprotocol); } void WebSocket::send(ByteBuffer binary_or_text_message, bool is_text) { - m_client->async_websocket_send(m_connection_id, is_text, move(binary_or_text_message)); + m_client->async_websocket_send(m_websocket_id, is_text, move(binary_or_text_message)); } void WebSocket::send(StringView text_message) @@ -37,7 +47,7 @@ void WebSocket::send(StringView text_message) void WebSocket::close(u16 code, ByteString reason) { - m_client->async_websocket_close(m_connection_id, code, move(reason)); + m_client->async_websocket_close(m_websocket_id, code, move(reason)); } void WebSocket::did_open(Badge) @@ -68,7 +78,7 @@ void WebSocket::did_request_certificates(Badge) { if (on_certificate_requested) { auto result = on_certificate_requested(); - if (!m_client->websocket_set_certificate(m_connection_id, result.certificate, result.key)) + if (!m_client->websocket_set_certificate(m_websocket_id, result.certificate, result.key)) dbgln("WebSocket: set_certificate failed"); } } diff --git a/Userland/Libraries/LibRequests/WebSocket.h b/Userland/Libraries/LibRequests/WebSocket.h index b6e8691df26..519e91ab86c 100644 --- a/Userland/Libraries/LibRequests/WebSocket.h +++ b/Userland/Libraries/LibRequests/WebSocket.h @@ -44,16 +44,18 @@ public: Closed = 3, }; - static NonnullRefPtr create_from_id(Badge, RequestClient& client, i32 connection_id) + static NonnullRefPtr create_from_id(Badge, RequestClient& client, i64 websocket_id) { - return adopt_ref(*new WebSocket(client, connection_id)); + return adopt_ref(*new WebSocket(client, websocket_id)); } - int id() const { return m_connection_id; } + i64 id() const { return m_websocket_id; } ReadyState ready_state(); + void set_ready_state(ReadyState); ByteString subprotocol_in_use(); + void set_subprotocol_in_use(ByteString); void send(ByteBuffer binary_or_text_message, bool is_text); void send(StringView text_message); @@ -72,9 +74,11 @@ public: void did_request_certificates(Badge); private: - explicit WebSocket(RequestClient&, i32 connection_id); + explicit WebSocket(RequestClient&, i64 websocket_id); WeakPtr m_client; - int m_connection_id { -1 }; + ReadyState m_ready_state { ReadyState::Connecting }; + ByteString m_subprotocol; + i64 m_websocket_id { -1 }; }; } diff --git a/Userland/Libraries/LibWebSocket/WebSocket.cpp b/Userland/Libraries/LibWebSocket/WebSocket.cpp index 3720c3e99cc..0dc0cfd4280 100644 --- a/Userland/Libraries/LibWebSocket/WebSocket.cpp +++ b/Userland/Libraries/LibWebSocket/WebSocket.cpp @@ -42,14 +42,14 @@ void WebSocket::start() m_impl->on_connected = [this] { if (m_state != WebSocket::InternalState::EstablishingProtocolConnection) return; - m_state = WebSocket::InternalState::SendingClientHandshake; + set_state(WebSocket::InternalState::SendingClientHandshake); send_client_handshake(); drain_read(); }; m_impl->on_ready_to_read = [this] { drain_read(); }; - m_state = WebSocket::InternalState::EstablishingProtocolConnection; + set_state(WebSocket::InternalState::EstablishingProtocolConnection); m_impl->connect(m_connection); } @@ -100,7 +100,7 @@ void WebSocket::close(u16 code, ByteString const& message) case InternalState::SendingClientHandshake: case InternalState::WaitingForServerHandshake: // FIXME: Fail the connection. - m_state = InternalState::Closing; + set_state(InternalState::Closing); break; case InternalState::Open: { auto message_bytes = message.bytes(); @@ -108,7 +108,7 @@ void WebSocket::close(u16 code, ByteString const& message) close_payload.overwrite(0, (u8*)&code, 2); close_payload.overwrite(2, message_bytes.data(), message_bytes.size()); send_frame(WebSocket::OpCode::ConnectionClose, close_payload, true); - m_state = InternalState::Closing; + set_state(InternalState::Closing); break; } default: @@ -120,7 +120,7 @@ void WebSocket::drain_read() { if (m_impl->eof()) { // The connection got closed by the server - m_state = WebSocket::InternalState::Closed; + set_state(WebSocket::InternalState::Closed); notify_close(m_last_close_code, m_last_close_message, true); discard_connection(); return; @@ -218,7 +218,7 @@ void WebSocket::send_client_handshake() builder.append("\r\n"sv); - m_state = WebSocket::InternalState::WaitingForServerHandshake; + set_state(WebSocket::InternalState::WaitingForServerHandshake); auto success = m_impl->send(builder.string_view().bytes()); VERIFY(success); } @@ -282,7 +282,7 @@ void WebSocket::read_server_handshake() return; } - m_state = WebSocket::InternalState::Open; + set_state(WebSocket::InternalState::Open); notify_open(); return; } @@ -400,7 +400,7 @@ void WebSocket::read_frame() auto head_bytes = get_buffered_bytes(2); if (head_bytes.is_null() || head_bytes.is_empty()) { // The connection got closed. - m_state = WebSocket::InternalState::Closed; + set_state(WebSocket::InternalState::Closed); notify_close(m_last_close_code, m_last_close_message, true); discard_connection(); return; @@ -487,7 +487,7 @@ void WebSocket::read_frame() m_last_close_code = (((u16)(payload[0] & 0xff) << 8) | ((u16)(payload[1] & 0xff))); m_last_close_message = ByteString(ReadonlyBytes(payload.offset_pointer(2), payload.size() - 2)); } - m_state = WebSocket::InternalState::Closing; + set_state(WebSocket::InternalState::Closing); return; } if (op_code == WebSocket::OpCode::Ping) { @@ -608,7 +608,7 @@ void WebSocket::send_frame(WebSocket::OpCode op_code, ReadonlyBytes payload, boo void WebSocket::fatal_error(WebSocket::Error error) { - m_state = WebSocket::InternalState::Errored; + set_state(WebSocket::InternalState::Errored); notify_error(error); discard_connection(); } @@ -653,4 +653,17 @@ void WebSocket::notify_message(Message message) on_message(move(message)); } +void WebSocket::set_state(InternalState state) +{ + if (m_state == state) + return; + auto old_ready_state = ready_state(); + m_state = state; + auto new_ready_state = ready_state(); + if (old_ready_state != new_ready_state) { + if (on_ready_state_change) + on_ready_state_change(ready_state()); + } +} + } diff --git a/Userland/Libraries/LibWebSocket/WebSocket.h b/Userland/Libraries/LibWebSocket/WebSocket.h index c695f234520..1aee8484221 100644 --- a/Userland/Libraries/LibWebSocket/WebSocket.h +++ b/Userland/Libraries/LibWebSocket/WebSocket.h @@ -46,6 +46,8 @@ public: Function on_open; Function on_close; Function on_message; + Function on_ready_state_change; + Function on_subprotocol; enum class Error { CouldNotEstablishConnection, @@ -97,6 +99,8 @@ private: InternalState m_state { InternalState::NotStarted }; + void set_state(InternalState); + ByteString m_subprotocol_in_use { ByteString::empty() }; ByteString m_websocket_key; diff --git a/Userland/Services/RequestServer/ConnectionFromClient.cpp b/Userland/Services/RequestServer/ConnectionFromClient.cpp index f2ed85484bf..4ab5578119a 100644 --- a/Userland/Services/RequestServer/ConnectionFromClient.cpp +++ b/Userland/Services/RequestServer/ConnectionFromClient.cpp @@ -386,12 +386,11 @@ void ConnectionFromClient::ensure_connection(URL::URL const& url, ::RequestServe dbgln("FIXME: EnsureConnection: Pre-connect to {}", url); } -static i32 s_next_websocket_id = 1; -Messages::RequestServer::WebsocketConnectResponse ConnectionFromClient::websocket_connect(URL::URL const& url, ByteString const& origin, Vector const& protocols, Vector const& extensions, HTTP::HeaderMap const& additional_request_headers) +void ConnectionFromClient::websocket_connect(i64 websocket_id, URL::URL const& url, ByteString const& origin, Vector const& protocols, Vector const& extensions, HTTP::HeaderMap const& additional_request_headers) { if (!url.is_valid()) { dbgln("WebSocket::Connect: Invalid URL requested: '{}'", url); - return -1; + return; } WebSocket::ConnectionInfo connection_info(url); @@ -400,56 +399,43 @@ Messages::RequestServer::WebsocketConnectResponse ConnectionFromClient::websocke connection_info.set_extensions(extensions); connection_info.set_headers(additional_request_headers); - auto id = ++s_next_websocket_id; auto connection = WebSocket::WebSocket::create(move(connection_info)); - connection->on_open = [this, id]() { - async_websocket_connected(id); + connection->on_open = [this, websocket_id]() { + async_websocket_connected(websocket_id); }; - connection->on_message = [this, id](auto message) { - async_websocket_received(id, message.is_text(), message.data()); + connection->on_message = [this, websocket_id](auto message) { + async_websocket_received(websocket_id, message.is_text(), message.data()); }; - connection->on_error = [this, id](auto message) { - async_websocket_errored(id, (i32)message); + connection->on_error = [this, websocket_id](auto message) { + async_websocket_errored(websocket_id, (i32)message); }; - connection->on_close = [this, id](u16 code, ByteString reason, bool was_clean) { - async_websocket_closed(id, code, move(reason), was_clean); + connection->on_close = [this, websocket_id](u16 code, ByteString reason, bool was_clean) { + async_websocket_closed(websocket_id, code, move(reason), was_clean); + }; + connection->on_ready_state_change = [this, websocket_id](auto state) { + async_websocket_ready_state_changed(websocket_id, (u32)state); }; connection->start(); - m_websockets.set(id, move(connection)); - return id; + m_websockets.set(websocket_id, move(connection)); } -Messages::RequestServer::WebsocketReadyStateResponse ConnectionFromClient::websocket_ready_state(i32 connection_id) +void ConnectionFromClient::websocket_send(i64 websocket_id, bool is_text, ByteBuffer const& data) { - if (auto connection = m_websockets.get(connection_id).value_or({})) - return (u32)connection->ready_state(); - return (u32)WebSocket::ReadyState::Closed; -} - -Messages::RequestServer::WebsocketSubprotocolInUseResponse ConnectionFromClient::websocket_subprotocol_in_use(i32 connection_id) -{ - if (auto connection = m_websockets.get(connection_id).value_or({})) - return connection->subprotocol_in_use(); - return ByteString::empty(); -} - -void ConnectionFromClient::websocket_send(i32 connection_id, bool is_text, ByteBuffer const& data) -{ - if (auto connection = m_websockets.get(connection_id).value_or({}); connection && connection->ready_state() == WebSocket::ReadyState::Open) + if (auto connection = m_websockets.get(websocket_id).value_or({}); connection && connection->ready_state() == WebSocket::ReadyState::Open) connection->send(WebSocket::Message { data, is_text }); } -void ConnectionFromClient::websocket_close(i32 connection_id, u16 code, ByteString const& reason) +void ConnectionFromClient::websocket_close(i64 websocket_id, u16 code, ByteString const& reason) { - if (auto connection = m_websockets.get(connection_id).value_or({}); connection && connection->ready_state() == WebSocket::ReadyState::Open) + if (auto connection = m_websockets.get(websocket_id).value_or({}); connection && connection->ready_state() == WebSocket::ReadyState::Open) connection->close(code, reason); } -Messages::RequestServer::WebsocketSetCertificateResponse ConnectionFromClient::websocket_set_certificate(i32 connection_id, ByteString const&, ByteString const&) +Messages::RequestServer::WebsocketSetCertificateResponse ConnectionFromClient::websocket_set_certificate(i64 websocket_id, ByteString const&, ByteString const&) { auto success = false; - if (auto connection = m_websockets.get(connection_id).value_or({}); connection) { + if (auto connection = m_websockets.get(websocket_id).value_or({}); connection) { // NO OP here // connection->set_certificate(certificate, key); success = true; diff --git a/Userland/Services/RequestServer/ConnectionFromClient.h b/Userland/Services/RequestServer/ConnectionFromClient.h index 41f9c8e326e..545851bfb62 100644 --- a/Userland/Services/RequestServer/ConnectionFromClient.h +++ b/Userland/Services/RequestServer/ConnectionFromClient.h @@ -42,12 +42,10 @@ private: virtual Messages::RequestServer::SetCertificateResponse set_certificate(i32, ByteString const&, ByteString const&) override; virtual void ensure_connection(URL::URL const& url, ::RequestServer::CacheLevel const& cache_level) override; - virtual Messages::RequestServer::WebsocketConnectResponse websocket_connect(URL::URL const&, ByteString const&, Vector const&, Vector const&, HTTP::HeaderMap const&) override; - virtual Messages::RequestServer::WebsocketReadyStateResponse websocket_ready_state(i32) override; - virtual Messages::RequestServer::WebsocketSubprotocolInUseResponse websocket_subprotocol_in_use(i32) override; - virtual void websocket_send(i32, bool, ByteBuffer const&) override; - virtual void websocket_close(i32, u16, ByteString const&) override; - virtual Messages::RequestServer::WebsocketSetCertificateResponse websocket_set_certificate(i32, ByteString const&, ByteString const&) override; + virtual void websocket_connect(i64 websocket_id, URL::URL const&, ByteString const&, Vector const&, Vector const&, HTTP::HeaderMap const&) override; + virtual void websocket_send(i64 websocket_id, bool, ByteBuffer const&) override; + virtual void websocket_close(i64 websocket_id, u16, ByteString const&) override; + virtual Messages::RequestServer::WebsocketSetCertificateResponse websocket_set_certificate(i64, ByteString const&, ByteString const&) override; HashMap> m_websockets; diff --git a/Userland/Services/RequestServer/RequestClient.ipc b/Userland/Services/RequestServer/RequestClient.ipc index 8ea4ac8dcea..772beb9912e 100644 --- a/Userland/Services/RequestServer/RequestClient.ipc +++ b/Userland/Services/RequestServer/RequestClient.ipc @@ -9,11 +9,13 @@ endpoint RequestClient // Websocket API // FIXME: See if this can be merged with the regular APIs - websocket_connected(i32 connection_id) =| - websocket_received(i32 connection_id, bool is_text, ByteBuffer data) =| - websocket_errored(i32 connection_id, i32 message) =| - websocket_closed(i32 connection_id, u16 code, ByteString reason, bool clean) =| - websocket_certificate_requested(i32 request_id) =| + websocket_connected(i64 websocket_id) =| + websocket_received(i64 websocket_id, bool is_text, ByteBuffer data) =| + websocket_errored(i64 websocket_id, i32 message) =| + websocket_closed(i64 websocket_id, u16 code, ByteString reason, bool clean) =| + websocket_ready_state_changed(i64 websocket_id, u32 ready_state) =| + websocket_subprotocol(i64 websocket_id, ByteString subprotocol) =| + websocket_certificate_requested(i64 websocket_id) =| // Certificate requests certificate_requested(i32 request_id) =| diff --git a/Userland/Services/RequestServer/RequestServer.ipc b/Userland/Services/RequestServer/RequestServer.ipc index 6916675cb6f..0e588136087 100644 --- a/Userland/Services/RequestServer/RequestServer.ipc +++ b/Userland/Services/RequestServer/RequestServer.ipc @@ -16,11 +16,9 @@ endpoint RequestServer ensure_connection(URL::URL url, ::RequestServer::CacheLevel cache_level) =| // Websocket Connection API - websocket_connect(URL::URL url, ByteString origin, Vector protocols, Vector extensions, HTTP::HeaderMap additional_request_headers) => (i32 connection_id) - websocket_ready_state(i32 connection_id) => (u32 ready_state) - websocket_subprotocol_in_use(i32 connection_id) => (ByteString subprotocol_in_use) - websocket_send(i32 connection_id, bool is_text, ByteBuffer data) =| - websocket_close(i32 connection_id, u16 code, ByteString reason) =| - websocket_set_certificate(i32 request_id, ByteString certificate, ByteString key) => (bool success) + websocket_connect(i64 websocket_id, URL::URL url, ByteString origin, Vector protocols, Vector extensions, HTTP::HeaderMap additional_request_headers) =| + websocket_send(i64 websocket_id, bool is_text, ByteBuffer data) =| + websocket_close(i64 websocket_id, u16 code, ByteString reason) =| + websocket_set_certificate(i64 request_id, ByteString certificate, ByteString key) => (bool success) }