From 06faa7b160515329b22340f06b6749eb55ed61b4 Mon Sep 17 00:00:00 2001 From: Andrew Kaster Date: Thu, 20 Feb 2025 05:21:22 -0700 Subject: [PATCH] LibWebSocket+RequestServer: Resolve WebSocket hosts using our resolver --- Libraries/LibWebSocket/CMakeLists.txt | 2 +- Libraries/LibWebSocket/ConnectionInfo.h | 5 + .../RequestServer/ConnectionFromClient.cpp | 107 +++++++++++------- Services/RequestServer/ConnectionFromClient.h | 2 + Services/RequestServer/WebSocketImplCurl.cpp | 8 ++ 5 files changed, 82 insertions(+), 42 deletions(-) diff --git a/Libraries/LibWebSocket/CMakeLists.txt b/Libraries/LibWebSocket/CMakeLists.txt index f37687c82cf..5a1846360db 100644 --- a/Libraries/LibWebSocket/CMakeLists.txt +++ b/Libraries/LibWebSocket/CMakeLists.txt @@ -6,4 +6,4 @@ set(SOURCES ) serenity_lib(LibWebSocket websocket) -target_link_libraries(LibWebSocket PRIVATE LibCore LibCrypto LibTLS LibURL) +target_link_libraries(LibWebSocket PRIVATE LibCore LibCrypto LibTLS LibURL LibDNS) diff --git a/Libraries/LibWebSocket/ConnectionInfo.h b/Libraries/LibWebSocket/ConnectionInfo.h index 62d7c7aa6ce..208a63289b3 100644 --- a/Libraries/LibWebSocket/ConnectionInfo.h +++ b/Libraries/LibWebSocket/ConnectionInfo.h @@ -7,6 +7,7 @@ #pragma once #include +#include #include #include @@ -33,6 +34,9 @@ public: Optional const& root_certificates_path() const { return m_root_certificates_path; } void set_root_certificates_path(Optional root_certificates_path) { m_root_certificates_path = move(root_certificates_path); } + Optional dns_result() const { return m_dns_result ? Optional(*m_dns_result) : OptionalNone {}; } + void set_dns_result(NonnullRefPtr dns_result) { m_dns_result = move(dns_result); } + // secure flag - defined in RFC 6455 Section 3 bool is_secure() const; @@ -46,6 +50,7 @@ private: Vector m_extensions {}; HTTP::HeaderMap m_headers; Optional m_root_certificates_path; + RefPtr m_dns_result; }; } diff --git a/Services/RequestServer/ConnectionFromClient.cpp b/Services/RequestServer/ConnectionFromClient.cpp index 45bd14d3da1..48ed36e4292 100644 --- a/Services/RequestServer/ConnectionFromClient.cpp +++ b/Services/RequestServer/ConnectionFromClient.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -73,6 +74,24 @@ static NonnullRefPtr default_resolver() return resolver; } +ByteString build_curl_resolve_list(DNS::LookupResult const& dns_result, StringView host, u16 port) +{ + StringBuilder resolve_opt_builder; + resolve_opt_builder.appendff("{}:{}:", host, port); + auto first = true; + for (auto& addr : dns_result.cached_addresses()) { + auto formatted_address = addr.visit( + [&](IPv4Address const& ipv4) { return ipv4.to_byte_string(); }, + [&](IPv6Address const& ipv6) { return MUST(ipv6.to_string()).to_byte_string(); }); + if (!first) + resolve_opt_builder.append(','); + first = false; + resolve_opt_builder.append(formatted_address); + } + + return resolve_opt_builder.to_byte_string(); +} + struct ConnectionFromClient::ActiveRequest { CURLM* multi { nullptr }; CURL* easy { nullptr }; @@ -463,20 +482,7 @@ void ConnectionFromClient::start_request(i32 request_id, ByteString const& metho set_option(CURLOPT_HEADERFUNCTION, &on_header_received); set_option(CURLOPT_HEADERDATA, reinterpret_cast(request.ptr())); - StringBuilder resolve_opt_builder; - resolve_opt_builder.appendff("{}:{}:", host, url.port_or_default()); - auto first = true; - for (auto& addr : dns_result->cached_addresses()) { - auto formatted_address = addr.visit( - [&](IPv4Address const& ipv4) { return ipv4.to_byte_string(); }, - [&](IPv6Address const& ipv6) { return MUST(ipv6.to_string()).to_byte_string(); }); - if (!first) - resolve_opt_builder.append(','); - first = false; - resolve_opt_builder.append(formatted_address); - } - - auto formatted_address = resolve_opt_builder.to_byte_string(); + auto formatted_address = build_curl_resolve_list(*dns_result, host, url.port_or_default()); if (curl_slist* resolve_list = curl_slist_append(nullptr, formatted_address.characters())) { set_option(CURLOPT_RESOLVE, resolve_list); request->curl_string_lists.append(resolve_list); @@ -653,37 +659,56 @@ void ConnectionFromClient::ensure_connection(URL::URL const& url, ::RequestServe void ConnectionFromClient::websocket_connect(i64 websocket_id, URL::URL const& url, ByteString const& origin, Vector const& protocols, Vector const& extensions, HTTP::HeaderMap const& additional_request_headers) { - // FIXME: Use our DNS resolver to resolve the hostname - WebSocket::ConnectionInfo connection_info(url); - connection_info.set_origin(origin); - connection_info.set_protocols(protocols); - connection_info.set_extensions(extensions); - connection_info.set_headers(additional_request_headers); + auto host = url.serialized_host().to_byte_string(); - if (!g_default_certificate_path.is_empty()) - connection_info.set_root_certificates_path(g_default_certificate_path); + // Check if host has the bracket notation for IPV6 addresses and remove them + if (host.starts_with("["sv) && host.ends_with("]"sv)) + host = host.substring(1, host.length() - 2); - auto impl = WebSocketImplCurl::create(m_curl_multi); - auto connection = WebSocket::WebSocket::create(move(connection_info), move(impl)); + m_resolver->dns.lookup(host, DNS::Messages::Class::IN, { DNS::Messages::ResourceType::A, DNS::Messages::ResourceType::AAAA }) + ->when_rejected([this, websocket_id](auto const& error) { + dbgln("WebSocketConnect: DNS lookup failed: {}", error); + async_websocket_errored(websocket_id, static_cast(Requests::WebSocket::Error::CouldNotEstablishConnection)); + }) + .when_resolved([this, websocket_id, host, url, origin, protocols, extensions, additional_request_headers](auto dns_result) { + if (dns_result->records().is_empty() || dns_result->cached_addresses().is_empty()) { + dbgln("WebSocketConnect: DNS lookup failed for '{}'", host); + async_websocket_errored(websocket_id, static_cast(Requests::WebSocket::Error::CouldNotEstablishConnection)); + return; + } - connection->on_open = [this, websocket_id]() { - async_websocket_connected(websocket_id); - }; - connection->on_message = [this, websocket_id](auto message) { - async_websocket_received(websocket_id, message.is_text(), message.data()); - }; - connection->on_error = [this, websocket_id](auto message) { - async_websocket_errored(websocket_id, (i32)message); - }; - connection->on_close = [this, websocket_id](u16 code, ByteString reason, bool was_clean) { - async_websocket_closed(websocket_id, code, move(reason), was_clean); - }; - connection->on_ready_state_change = [this, websocket_id](auto state) { - async_websocket_ready_state_changed(websocket_id, (u32)state); - }; + WebSocket::ConnectionInfo connection_info(url); + connection_info.set_origin(origin); + connection_info.set_protocols(protocols); + connection_info.set_extensions(extensions); + connection_info.set_headers(additional_request_headers); + connection_info.set_dns_result(move(dns_result)); - connection->start(); - m_websockets.set(websocket_id, move(connection)); + if (!g_default_certificate_path.is_empty()) + connection_info.set_root_certificates_path(g_default_certificate_path); + + auto impl = WebSocketImplCurl::create(m_curl_multi); + auto connection = WebSocket::WebSocket::create(move(connection_info), move(impl)); + + connection->on_open = [this, websocket_id]() { + async_websocket_connected(websocket_id); + }; + connection->on_message = [this, websocket_id](auto message) { + async_websocket_received(websocket_id, message.is_text(), message.data()); + }; + connection->on_error = [this, websocket_id](auto message) { + async_websocket_errored(websocket_id, (i32)message); + }; + connection->on_close = [this, websocket_id](u16 code, ByteString reason, bool was_clean) { + async_websocket_closed(websocket_id, code, move(reason), was_clean); + }; + connection->on_ready_state_change = [this, websocket_id](auto state) { + async_websocket_ready_state_changed(websocket_id, (u32)state); + }; + + connection->start(); + m_websockets.set(websocket_id, move(connection)); + }); } void ConnectionFromClient::websocket_send(i64 websocket_id, bool is_text, ByteBuffer const& data) diff --git a/Services/RequestServer/ConnectionFromClient.h b/Services/RequestServer/ConnectionFromClient.h index cde4d05bbeb..2404037c516 100644 --- a/Services/RequestServer/ConnectionFromClient.h +++ b/Services/RequestServer/ConnectionFromClient.h @@ -77,6 +77,8 @@ private: NonnullRefPtr m_resolver; }; +// FIXME: Find a good home for this +ByteString build_curl_resolve_list(DNS::LookupResult const&, StringView host, u16 port); constexpr inline uintptr_t websocket_private_tag = 0x1; } diff --git a/Services/RequestServer/WebSocketImplCurl.cpp b/Services/RequestServer/WebSocketImplCurl.cpp index 3d6d7d1e25d..fa496738f48 100644 --- a/Services/RequestServer/WebSocketImplCurl.cpp +++ b/Services/RequestServer/WebSocketImplCurl.cpp @@ -4,6 +4,8 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include "ConnectionFromClient.h" + #include namespace RequestServer { @@ -96,6 +98,12 @@ void WebSocketImplCurl::connect(WebSocket::ConnectionInfo const& info) set_option(CURLOPT_HTTPHEADER, curl_headers); m_curl_string_lists.append(curl_headers); + if (auto const& dns_info = info.dns_result(); dns_info.has_value()) { + auto* resolve_list = curl_slist_append(nullptr, build_curl_resolve_list(*dns_info, url.serialized_host(), url.port_or_default()).characters()); + set_option(CURLOPT_RESOLVE, resolve_list); + m_curl_string_lists.append(resolve_list); + } + CURLMcode const err = curl_multi_add_handle(m_multi_handle, m_easy_handle); VERIFY(err == CURLM_OK); }