diff --git a/Libraries/LibWebSocket/WebSocket.cpp b/Libraries/LibWebSocket/WebSocket.cpp index cf2548b084a..a098bcb40f2 100644 --- a/Libraries/LibWebSocket/WebSocket.cpp +++ b/Libraries/LibWebSocket/WebSocket.cpp @@ -151,7 +151,10 @@ void WebSocket::drain_read() } auto bytes = result.release_value(); m_buffered_data.append(bytes.data(), bytes.size()); - read_frame(); + do { + if (auto maybe_error = read_frame(); maybe_error.is_error()) + break; + } while (!m_buffered_data.is_empty()); } break; case InternalState::Closed: case InternalState::Errored: { @@ -388,7 +391,7 @@ void WebSocket::read_server_handshake() // If needed, we will keep reading the header on the next drain_read call } -void WebSocket::read_frame() +ErrorOr WebSocket::read_frame() { VERIFY(m_impl); VERIFY(m_state == WebSocket::InternalState::Open || m_state == WebSocket::InternalState::Closing); @@ -408,7 +411,7 @@ void WebSocket::read_frame() set_state(WebSocket::InternalState::Closed); notify_close(m_last_close_code, m_last_close_message, true); discard_connection(); - return; + return AK::Error::from_errno(ECONNABORTED); } auto op_code = (WebSocket::OpCode)(head_bytes[0] & 0x0f); @@ -422,7 +425,7 @@ void WebSocket::read_frame() // A code of 127 means that the next 8 bytes contains the payload length auto actual_bytes = get_buffered_bytes(8); if (actual_bytes.is_null()) - return; + return AK::Error::from_errno(EAGAIN); u64 full_payload_length = (u64)((u64)(actual_bytes[0] & 0xff) << 56) | (u64)((u64)(actual_bytes[1] & 0xff) << 48) | (u64)((u64)(actual_bytes[2] & 0xff) << 40) @@ -437,7 +440,7 @@ void WebSocket::read_frame() // A code of 126 means that the next 2 bytes contains the payload length auto actual_bytes = get_buffered_bytes(2); if (actual_bytes.is_null()) - return; + return AK::Error::from_errno(EAGAIN); payload_length = (size_t)((size_t)(actual_bytes[0] & 0xff) << 8) | (size_t)((size_t)(actual_bytes[1] & 0xff) << 0); } else { @@ -454,7 +457,7 @@ void WebSocket::read_frame() if (is_masked) { auto masking_key_data = get_buffered_bytes(4); if (masking_key_data.is_null()) - return; + return AK::Error::from_errno(EAGAIN); masking_key[0] = masking_key_data[0]; masking_key[1] = masking_key_data[1]; masking_key[2] = masking_key_data[2]; @@ -466,7 +469,7 @@ void WebSocket::read_frame() while (read_length < payload_length) { auto payload_part = get_buffered_bytes(payload_length - read_length); if (payload_part.is_null()) - return; + return AK::Error::from_errno(EAGAIN); // We read at most "actual_length - read" bytes, so this is safe to do. payload.overwrite(read_length, payload_part.data(), payload_part.size()); read_length += payload_part.size(); @@ -496,16 +499,16 @@ void WebSocket::read_frame() m_last_close_message = {}; } close(m_last_close_code, m_last_close_message); - return; + return {}; } if (op_code == WebSocket::OpCode::Ping) { // Immediately send a pong frame as a reply, with the given payload. send_frame(WebSocket::OpCode::Pong, payload, true); - return; + return {}; } if (op_code == WebSocket::OpCode::Pong) { // We can safely ignore the pong - return; + return {}; } if (!is_final_frame) { if (op_code != WebSocket::OpCode::Continuation) { @@ -514,7 +517,7 @@ void WebSocket::read_frame() } // First and next fragmented message m_fragmented_data_buffer.append(payload.data(), payload_length); - return; + return {}; } if (is_final_frame && op_code == WebSocket::OpCode::Continuation) { // Last fragmented message @@ -526,13 +529,14 @@ void WebSocket::read_frame() } if (op_code == WebSocket::OpCode::Text) { notify_message(Message(move(payload), true)); - return; + return {}; } if (op_code == WebSocket::OpCode::Binary) { notify_message(Message(move(payload), false)); - return; + return {}; } dbgln("Websocket: Found unknown opcode {}", (u8)op_code); + return {}; } void WebSocket::send_frame(WebSocket::OpCode op_code, ReadonlyBytes payload, bool is_final) diff --git a/Libraries/LibWebSocket/WebSocket.h b/Libraries/LibWebSocket/WebSocket.h index bdaf1b5db4f..cb3b72054ad 100644 --- a/Libraries/LibWebSocket/WebSocket.h +++ b/Libraries/LibWebSocket/WebSocket.h @@ -89,7 +89,7 @@ private: void send_client_handshake(); void read_server_handshake(); - void read_frame(); + ErrorOr read_frame(); void send_frame(OpCode, ReadonlyBytes, bool is_final); void notify_open(); diff --git a/Services/RequestServer/WebSocketImplCurl.cpp b/Services/RequestServer/WebSocketImplCurl.cpp index b541374c1a2..af9062f9c71 100644 --- a/Services/RequestServer/WebSocketImplCurl.cpp +++ b/Services/RequestServer/WebSocketImplCurl.cpp @@ -160,6 +160,43 @@ void WebSocketImplCurl::discard_connection() } } +void WebSocketImplCurl::read_from_socket() +{ + bool received_data = false; + + // "Wait on the socket only if curl_easy_recv returns CURLE_AGAIN. The reason for this is libcurl or the SSL + // library may internally cache some data, therefore you should call curl_easy_recv until all data is read which + // would include any cached data." + for (;;) { + u8 buffer[65536]; + size_t nread = 0; + CURLcode const result = curl_easy_recv(m_easy_handle, buffer, sizeof(buffer), &nread); + if (result == CURLE_AGAIN) + break; + + if (result != CURLE_OK) { + dbgln("Failed to read from WebSocket: {}", curl_easy_strerror(result)); + on_connection_error(); + return; + } + + // "Reading exactly 0 bytes indicates a closed connection." + if (nread == 0) { + dbgln("Failed to read from WebSocket: Server closed connection"); + on_connection_error(); + return; + } + + received_data = true; + + if (auto const err = m_read_buffer.write_until_depleted({ buffer, nread }); err.is_error()) + on_connection_error(); + } + + if (received_data) + on_ready_to_read(); +} + bool WebSocketImplCurl::did_connect() { curl_socket_t socket_fd = CURL_SOCKET_BAD; @@ -169,21 +206,7 @@ bool WebSocketImplCurl::did_connect() m_read_notifier = Core::Notifier::construct(socket_fd, Core::Notifier::Type::Read); m_read_notifier->on_activation = [this] { - u8 buffer[65536]; - size_t nread = 0; - CURLcode const result = curl_easy_recv(m_easy_handle, buffer, sizeof(buffer), &nread); - if (result == CURLE_AGAIN) - return; - - if (result != CURLE_OK) { - dbgln("Failed to read from WebSocket: {}", curl_easy_strerror(result)); - on_connection_error(); - } - - if (auto const err = m_read_buffer.write_until_depleted({ buffer, nread }); err.is_error()) - on_connection_error(); - - on_ready_to_read(); + read_from_socket(); }; m_error_notifier = Core::Notifier::construct(socket_fd, Core::Notifier::Type::Error | Core::Notifier::Type::HangUp); m_error_notifier->on_activation = [this] { @@ -191,6 +214,11 @@ bool WebSocketImplCurl::did_connect() }; on_connected(); + + // There may be data waiting for us already (e.g. if the server sends us data immediately upon opening a WebSocket), + // so try reading immediately. + read_from_socket(); + return true; } diff --git a/Services/RequestServer/WebSocketImplCurl.h b/Services/RequestServer/WebSocketImplCurl.h index 25f0500e5bc..853767e5662 100644 --- a/Services/RequestServer/WebSocketImplCurl.h +++ b/Services/RequestServer/WebSocketImplCurl.h @@ -34,6 +34,8 @@ public: private: explicit WebSocketImplCurl(CURLM*); + void read_from_socket(); + CURLM* m_multi_handle { nullptr }; CURL* m_easy_handle { nullptr }; RefPtr m_read_notifier;