mirror of
https://github.com/LadybirdBrowser/ladybird.git
synced 2025-04-21 12:05:15 +00:00
ProtocolServer: Attach downloads and their lifecycles to clients
Previously a download lived independently of the client connection it came from. This was the source of several undesirable behaviours, including the potential for clients to influence downloads they didn't start, and downloads living longer than their associated client connections. Now we attach downloads to client connections, which means they're cleaned up automatically when the client goes away, and there's significantly less risk of clients interfering with each other.
This commit is contained in:
parent
184ee8ac77
commit
f2621f37a4
Notes:
sideshowbarker
2024-07-19 06:34:13 +09:00
Author: https://github.com/deoxxa Commit: https://github.com/SerenityOS/serenity/commit/f2621f37a4c Pull-request: https://github.com/SerenityOS/serenity/pull/2218 Reviewed-by: https://github.com/alimpfard Reviewed-by: https://github.com/awesomekling Reviewed-by: https://github.com/bugaevc
17 changed files with 48 additions and 57 deletions
|
@ -31,22 +31,10 @@
|
|||
// FIXME: What about rollover?
|
||||
static i32 s_next_id = 1;
|
||||
|
||||
static HashMap<i32, RefPtr<Download>>& all_downloads()
|
||||
{
|
||||
static HashMap<i32, RefPtr<Download>> map;
|
||||
return map;
|
||||
}
|
||||
|
||||
Download* Download::find_by_id(i32 id)
|
||||
{
|
||||
return const_cast<Download*>(all_downloads().get(id).value_or(nullptr));
|
||||
}
|
||||
|
||||
Download::Download(PSClientConnection& client)
|
||||
: m_id(s_next_id++)
|
||||
, m_client(client.make_weak_ptr())
|
||||
: m_client(client)
|
||||
, m_id(s_next_id++)
|
||||
{
|
||||
all_downloads().set(m_id, this);
|
||||
}
|
||||
|
||||
Download::~Download()
|
||||
|
@ -55,7 +43,7 @@ Download::~Download()
|
|||
|
||||
void Download::stop()
|
||||
{
|
||||
all_downloads().remove(m_id);
|
||||
m_client.did_finish_download({}, *this, false);
|
||||
}
|
||||
|
||||
void Download::set_payload(const ByteBuffer& payload)
|
||||
|
@ -71,22 +59,12 @@ void Download::set_response_headers(const HashMap<String, String, CaseInsensitiv
|
|||
|
||||
void Download::did_finish(bool success)
|
||||
{
|
||||
if (!m_client) {
|
||||
dbg() << "Download::did_finish() after the client already disconnected.";
|
||||
return;
|
||||
}
|
||||
m_client->did_finish_download({}, *this, success);
|
||||
all_downloads().remove(m_id);
|
||||
m_client.did_finish_download({}, *this, success);
|
||||
}
|
||||
|
||||
void Download::did_progress(Optional<u32> total_size, u32 downloaded_size)
|
||||
{
|
||||
if (!m_client) {
|
||||
// FIXME: We should also abort the download in this situation, I guess!
|
||||
dbg() << "Download::did_progress() after the client already disconnected.";
|
||||
return;
|
||||
}
|
||||
m_total_size = total_size;
|
||||
m_downloaded_size = downloaded_size;
|
||||
m_client->did_progress_download({}, *this);
|
||||
m_client.did_progress_download({}, *this);
|
||||
}
|
||||
|
|
|
@ -31,16 +31,13 @@
|
|||
#include <AK/Optional.h>
|
||||
#include <AK/RefCounted.h>
|
||||
#include <AK/URL.h>
|
||||
#include <AK/WeakPtr.h>
|
||||
|
||||
class PSClientConnection;
|
||||
|
||||
class Download : public RefCounted<Download> {
|
||||
class Download {
|
||||
public:
|
||||
virtual ~Download();
|
||||
|
||||
static Download* find_by_id(i32);
|
||||
|
||||
i32 id() const { return m_id; }
|
||||
URL url() const { return m_url; }
|
||||
|
||||
|
@ -60,11 +57,11 @@ protected:
|
|||
void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&);
|
||||
|
||||
private:
|
||||
i32 m_id;
|
||||
PSClientConnection& m_client;
|
||||
i32 m_id { 0 };
|
||||
URL m_url;
|
||||
Optional<u32> m_total_size {};
|
||||
size_t m_downloaded_size { 0 };
|
||||
ByteBuffer m_payload;
|
||||
HashMap<String, String, CaseInsensitiveStringTraits> m_response_headers;
|
||||
WeakPtr<PSClientConnection> m_client;
|
||||
};
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
#include <LibGemini/GeminiJob.h>
|
||||
#include <ProtocolServer/GeminiDownload.h>
|
||||
|
||||
GeminiDownload::GeminiDownload(PSClientConnection& client, NonnullRefPtr<Gemini::GeminiJob>&& job)
|
||||
GeminiDownload::GeminiDownload(PSClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job)
|
||||
: Download(client)
|
||||
, m_job(job)
|
||||
{
|
||||
|
@ -55,9 +55,12 @@ GeminiDownload::GeminiDownload(PSClientConnection& client, NonnullRefPtr<Gemini:
|
|||
|
||||
GeminiDownload::~GeminiDownload()
|
||||
{
|
||||
m_job->on_finish = nullptr;
|
||||
m_job->on_progress = nullptr;
|
||||
m_job->shutdown();
|
||||
}
|
||||
|
||||
NonnullRefPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, PSClientConnection& client, NonnullRefPtr<Gemini::GeminiJob>&& job)
|
||||
NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, PSClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job)
|
||||
{
|
||||
return adopt(*new GeminiDownload(client, move(job)));
|
||||
return adopt_own(*new GeminiDownload(client, move(job)));
|
||||
}
|
||||
|
|
|
@ -36,10 +36,10 @@ class GeminiProtocol;
|
|||
class GeminiDownload final : public Download {
|
||||
public:
|
||||
virtual ~GeminiDownload() override;
|
||||
static NonnullRefPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, PSClientConnection&, NonnullRefPtr<Gemini::GeminiJob>&&);
|
||||
static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, PSClientConnection&, NonnullRefPtr<Gemini::GeminiJob>);
|
||||
|
||||
private:
|
||||
explicit GeminiDownload(PSClientConnection&, NonnullRefPtr<Gemini::GeminiJob>&&);
|
||||
explicit GeminiDownload(PSClientConnection&, NonnullRefPtr<Gemini::GeminiJob>);
|
||||
|
||||
NonnullRefPtr<Gemini::GeminiJob> m_job;
|
||||
};
|
||||
|
|
|
@ -38,7 +38,7 @@ GeminiProtocol::~GeminiProtocol()
|
|||
{
|
||||
}
|
||||
|
||||
RefPtr<Download> GeminiProtocol::start_download(PSClientConnection& client, const URL& url)
|
||||
OwnPtr<Download> GeminiProtocol::start_download(PSClientConnection& client, const URL& url)
|
||||
{
|
||||
Gemini::GeminiRequest request;
|
||||
request.set_url(url);
|
||||
|
|
|
@ -33,5 +33,5 @@ public:
|
|||
GeminiProtocol();
|
||||
virtual ~GeminiProtocol() override;
|
||||
|
||||
virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) override;
|
||||
virtual OwnPtr<Download> start_download(PSClientConnection&, const URL&) override;
|
||||
};
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
#include <LibHTTP/HttpResponse.h>
|
||||
#include <ProtocolServer/HttpDownload.h>
|
||||
|
||||
HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpJob>&& job)
|
||||
HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job)
|
||||
: Download(client)
|
||||
, m_job(job)
|
||||
{
|
||||
|
@ -52,9 +52,12 @@ HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpJ
|
|||
|
||||
HttpDownload::~HttpDownload()
|
||||
{
|
||||
m_job->on_finish = nullptr;
|
||||
m_job->on_progress = nullptr;
|
||||
m_job->shutdown();
|
||||
}
|
||||
|
||||
NonnullRefPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, PSClientConnection& client, NonnullRefPtr<HTTP::HttpJob>&& job)
|
||||
NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, PSClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job)
|
||||
{
|
||||
return adopt(*new HttpDownload(client, move(job)));
|
||||
return adopt_own(*new HttpDownload(client, move(job)));
|
||||
}
|
||||
|
|
|
@ -36,10 +36,10 @@ class HttpProtocol;
|
|||
class HttpDownload final : public Download {
|
||||
public:
|
||||
virtual ~HttpDownload() override;
|
||||
static NonnullRefPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, PSClientConnection&, NonnullRefPtr<HTTP::HttpJob>&&);
|
||||
static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, PSClientConnection&, NonnullRefPtr<HTTP::HttpJob>);
|
||||
|
||||
private:
|
||||
explicit HttpDownload(PSClientConnection&, NonnullRefPtr<HTTP::HttpJob>&&);
|
||||
explicit HttpDownload(PSClientConnection&, NonnullRefPtr<HTTP::HttpJob>);
|
||||
|
||||
NonnullRefPtr<HTTP::HttpJob> m_job;
|
||||
};
|
||||
|
|
|
@ -38,7 +38,7 @@ HttpProtocol::~HttpProtocol()
|
|||
{
|
||||
}
|
||||
|
||||
RefPtr<Download> HttpProtocol::start_download(PSClientConnection& client, const URL& url)
|
||||
OwnPtr<Download> HttpProtocol::start_download(PSClientConnection& client, const URL& url)
|
||||
{
|
||||
HTTP::HttpRequest request;
|
||||
request.set_method(HTTP::HttpRequest::Method::GET);
|
||||
|
|
|
@ -33,5 +33,5 @@ public:
|
|||
HttpProtocol();
|
||||
virtual ~HttpProtocol() override;
|
||||
|
||||
virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) override;
|
||||
virtual OwnPtr<Download> start_download(PSClientConnection&, const URL&) override;
|
||||
};
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
#include <LibHTTP/HttpsJob.h>
|
||||
#include <ProtocolServer/HttpsDownload.h>
|
||||
|
||||
HttpsDownload::HttpsDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpsJob>&& job)
|
||||
HttpsDownload::HttpsDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job)
|
||||
: Download(client)
|
||||
, m_job(job)
|
||||
{
|
||||
|
@ -52,9 +52,12 @@ HttpsDownload::HttpsDownload(PSClientConnection& client, NonnullRefPtr<HTTP::Htt
|
|||
|
||||
HttpsDownload::~HttpsDownload()
|
||||
{
|
||||
m_job->on_finish = nullptr;
|
||||
m_job->on_progress = nullptr;
|
||||
m_job->shutdown();
|
||||
}
|
||||
|
||||
NonnullRefPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, PSClientConnection& client, NonnullRefPtr<HTTP::HttpsJob>&& job)
|
||||
NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, PSClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job)
|
||||
{
|
||||
return adopt(*new HttpsDownload(client, move(job)));
|
||||
return adopt_own(*new HttpsDownload(client, move(job)));
|
||||
}
|
||||
|
|
|
@ -36,10 +36,10 @@ class HttpsProtocol;
|
|||
class HttpsDownload final : public Download {
|
||||
public:
|
||||
virtual ~HttpsDownload() override;
|
||||
static NonnullRefPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, PSClientConnection&, NonnullRefPtr<HTTP::HttpsJob>&&);
|
||||
static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, PSClientConnection&, NonnullRefPtr<HTTP::HttpsJob>);
|
||||
|
||||
private:
|
||||
explicit HttpsDownload(PSClientConnection&, NonnullRefPtr<HTTP::HttpsJob>&&);
|
||||
explicit HttpsDownload(PSClientConnection&, NonnullRefPtr<HTTP::HttpsJob>);
|
||||
|
||||
NonnullRefPtr<HTTP::HttpsJob> m_job;
|
||||
};
|
||||
|
|
|
@ -38,7 +38,7 @@ HttpsProtocol::~HttpsProtocol()
|
|||
{
|
||||
}
|
||||
|
||||
RefPtr<Download> HttpsProtocol::start_download(PSClientConnection& client, const URL& url)
|
||||
OwnPtr<Download> HttpsProtocol::start_download(PSClientConnection& client, const URL& url)
|
||||
{
|
||||
HTTP::HttpRequest request;
|
||||
request.set_method(HTTP::HttpRequest::Method::GET);
|
||||
|
|
|
@ -33,5 +33,5 @@ public:
|
|||
HttpsProtocol();
|
||||
virtual ~HttpsProtocol() override;
|
||||
|
||||
virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) override;
|
||||
virtual OwnPtr<Download> start_download(PSClientConnection&, const URL&) override;
|
||||
};
|
||||
|
|
|
@ -63,12 +63,16 @@ OwnPtr<Messages::ProtocolServer::StartDownloadResponse> PSClientConnection::hand
|
|||
if (!protocol)
|
||||
return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
|
||||
auto download = protocol->start_download(*this, url);
|
||||
return make<Messages::ProtocolServer::StartDownloadResponse>(download->id());
|
||||
if (!download)
|
||||
return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
|
||||
auto id = download->id();
|
||||
m_downloads.set(id, move(download));
|
||||
return make<Messages::ProtocolServer::StartDownloadResponse>(id);
|
||||
}
|
||||
|
||||
OwnPtr<Messages::ProtocolServer::StopDownloadResponse> PSClientConnection::handle(const Messages::ProtocolServer::StopDownload& message)
|
||||
{
|
||||
auto* download = Download::find_by_id(message.download_id());
|
||||
auto* download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr));
|
||||
bool success = false;
|
||||
if (download) {
|
||||
download->stop();
|
||||
|
@ -93,6 +97,8 @@ void PSClientConnection::did_finish_download(Badge<Download>, Download& download
|
|||
for (auto& it : download.response_headers())
|
||||
response_headers.add(it.key, it.value);
|
||||
post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.total_size().value(), buffer ? buffer->shbuf_id() : -1, response_headers));
|
||||
|
||||
m_downloads.remove(download.id());
|
||||
}
|
||||
|
||||
void PSClientConnection::did_progress_download(Badge<Download>, Download& download)
|
||||
|
|
|
@ -51,5 +51,6 @@ private:
|
|||
virtual OwnPtr<Messages::ProtocolServer::StopDownloadResponse> handle(const Messages::ProtocolServer::StopDownload&) override;
|
||||
virtual OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> handle(const Messages::ProtocolServer::DisownSharedBuffer&) override;
|
||||
|
||||
HashMap<i32, OwnPtr<Download>> m_downloads;
|
||||
HashMap<i32, RefPtr<AK::SharedBuffer>> m_shared_buffers;
|
||||
};
|
||||
|
|
|
@ -37,7 +37,7 @@ public:
|
|||
virtual ~Protocol();
|
||||
|
||||
const String& name() const { return m_name; }
|
||||
virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) = 0;
|
||||
virtual OwnPtr<Download> start_download(PSClientConnection&, const URL&) = 0;
|
||||
|
||||
static Protocol* find_by_name(const String&);
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue