LibProtocol: Add a Download object so users don't have to manage ID's

LibProtocol::Client::start_download() now gives you a Download object
with convenient hooks (on_finish & on_progress).

Also, the IPC handshake is snuck into the Client constructor, so you
don't need to perform it after instantiating a Client.

This makes using LibProtocol much more pleasant. :^)
This commit is contained in:
Andreas Kling 2019-11-24 13:20:44 +01:00
parent 3dc87be891
commit 653e61d9cf
Notes: sideshowbarker 2024-07-19 11:05:54 +09:00
6 changed files with 112 additions and 28 deletions

View file

@ -1,4 +1,5 @@
#include <LibProtocol/Client.h>
#include <LibProtocol/Download.h>
#include <SharedBuffer.h>
namespace LibProtocol {
@ -6,6 +7,7 @@ namespace LibProtocol {
Client::Client()
: ConnectionNG(*this, "/tmp/psportal")
{
handshake();
}
void Client::handshake()
@ -20,27 +22,36 @@ bool Client::is_supported_protocol(const String& protocol)
return send_sync<ProtocolServer::IsSupportedProtocol>(protocol)->supported();
}
i32 Client::start_download(const String& url)
RefPtr<Download> Client::start_download(const String& url)
{
return send_sync<ProtocolServer::StartDownload>(url)->download_id();
i32 download_id = send_sync<ProtocolServer::StartDownload>(url)->download_id();
auto download = Download::create_from_id({}, *this, download_id);
m_downloads.set(download_id, download);
return download;
}
bool Client::stop_download(i32 download_id)
bool Client::stop_download(Badge<Download>, Download& download)
{
return send_sync<ProtocolServer::StopDownload>(download_id)->success();
if (!m_downloads.contains(download.id()))
return false;
return send_sync<ProtocolServer::StopDownload>(download.id())->success();
}
void Client::handle(const ProtocolClient::DownloadFinished& message)
{
if (on_download_finish)
on_download_finish(message.download_id(), message.success(), message.total_size(), message.shared_buffer_id());
RefPtr<Download> download;
if ((download = m_downloads.get(message.download_id()).value_or(nullptr))) {
download->did_finish({}, message.success(), message.total_size(), message.shared_buffer_id());
}
send_sync<ProtocolServer::DisownSharedBuffer>(message.shared_buffer_id());
m_downloads.remove(message.download_id());
}
void Client::handle(const ProtocolClient::DownloadProgress& message)
{
if (on_download_progress)
on_download_progress(message.download_id(), message.total_size(), message.downloaded_size());
if (auto download = m_downloads.get(message.download_id()).value_or(nullptr)) {
download->did_progress({}, message.total_size(), message.downloaded_size());
}
}
}

View file

@ -6,6 +6,8 @@
namespace LibProtocol {
class Download;
class Client : public IPC::Client::ConnectionNG<ProtocolClientEndpoint, ProtocolServerEndpoint>
, public ProtocolClientEndpoint {
C_OBJECT(Client)
@ -15,15 +17,15 @@ public:
virtual void handshake() override;
bool is_supported_protocol(const String&);
i32 start_download(const String& url);
bool stop_download(i32 download_id);
RefPtr<Download> start_download(const String& url);
Function<void(i32 download_id, bool success, u32 total_size, i32 shared_buffer_id)> on_download_finish;
Function<void(i32 download_id, u64 total_size, u64 downloaded_size)> on_download_progress;
bool stop_download(Badge<Download>, Download&);
private:
virtual void handle(const ProtocolClient::DownloadProgress&) override;
virtual void handle(const ProtocolClient::DownloadFinished&) override;
HashMap<i32, RefPtr<Download>> m_downloads;
};
}

View file

@ -0,0 +1,38 @@
#include <LibC/SharedBuffer.h>
#include <LibProtocol/Client.h>
#include <LibProtocol/Download.h>
namespace LibProtocol {
Download::Download(Client& client, i32 download_id)
: m_client(client.make_weak_ptr())
, m_download_id(download_id)
{
}
bool Download::stop()
{
return m_client->stop_download({}, *this);
}
void Download::did_finish(Badge<Client>, bool success, u32 total_size, i32 shared_buffer_id)
{
if (!on_finish)
return;
ByteBuffer payload;
RefPtr<SharedBuffer> shared_buffer;
if (success && shared_buffer_id != -1) {
shared_buffer = SharedBuffer::create_from_shared_buffer_id(shared_buffer_id);
payload = ByteBuffer::wrap(shared_buffer->data(), total_size);
}
on_finish(success, payload, move(shared_buffer));
}
void Download::did_progress(Badge<Client>, u32 total_size, u32 downloaded_size)
{
if (on_progress)
on_progress(total_size, downloaded_size);
}
}

View file

@ -0,0 +1,37 @@
#pragma once
#include <AK/Badge.h>
#include <AK/ByteBuffer.h>
#include <AK/Function.h>
#include <AK/RefCounted.h>
#include <AK/WeakPtr.h>
class SharedBuffer;
namespace LibProtocol {
class Client;
class Download : public RefCounted<Download> {
public:
static NonnullRefPtr<Download> create_from_id(Badge<Client>, Client& client, i32 download_id)
{
return adopt(*new Download(client, download_id));
}
int id() const { return m_download_id; }
bool stop();
Function<void(bool success, const ByteBuffer& payload, RefPtr<SharedBuffer> payload_storage)> on_finish;
Function<void(u32 total_size, u32 downloaded_size)> on_progress;
void did_finish(Badge<Client>, bool success, u32 total_size, i32 shared_buffer_id);
void did_progress(Badge<Client>, u32 total_size, u32 downloaded_size);
private:
explicit Download(Client&, i32 download_id);
WeakPtr<Client> m_client;
int m_download_id { -1 };
};
}

View file

@ -1,6 +1,7 @@
include ../../Makefile.common
OBJS = \
Download.o \
Client.o
LIBRARY = libprotocol.a

View file

@ -2,6 +2,7 @@
#include <LibC/SharedBuffer.h>
#include <LibCore/CEventLoop.h>
#include <LibProtocol/Client.h>
#include <LibProtocol/Download.h>
#include <stdio.h>
int main(int argc, char** argv)
@ -20,25 +21,19 @@ int main(int argc, char** argv)
CEventLoop loop;
auto protocol_client = LibProtocol::Client::construct();
protocol_client->handshake();
protocol_client->on_download_finish = [&](i32 download_id, bool success, u32 total_size, i32 shared_buffer_id) {
dbgprintf("download %d finished, success=%u, shared_buffer_id=%d\n", download_id, success, shared_buffer_id);
if (success) {
ASSERT(shared_buffer_id != -1);
auto shared_buffer = SharedBuffer::create_from_shared_buffer_id(shared_buffer_id);
auto payload_bytes = ByteBuffer::wrap(shared_buffer->data(), total_size);
write(STDOUT_FILENO, payload_bytes.data(), payload_bytes.size());
}
auto download = protocol_client->start_download(url.to_string());
download->on_progress = [](u32 total_size, u32 downloaded_size) {
dbgprintf("download progress: %u / %u\n", downloaded_size, total_size);
};
download->on_finish = [&](bool success, auto& payload, auto) {
if (success)
write(STDOUT_FILENO, payload.data(), payload.size());
else
fprintf(stderr, "Download failed :(\n");
loop.quit(0);
};
protocol_client->on_download_progress = [&](i32 download_id, u32 total_size, u32 downloaded_size) {
dbgprintf("download %d progress: %u / %u\n", download_id, downloaded_size, total_size);
};
i32 download_id = protocol_client->start_download(url.to_string());
dbgprintf("started download with id %d\n", download_id);
dbgprintf("started download with id %d\n", download->id());
return loop.exec();
}