Prechádzať zdrojové kódy

LibProtocol: Add a Download object so users don't have to manage ID's

LibProtocol::Client::start_download() now gives you a Download object
with convenient hooks (on_finish & on_progress).

Also, the IPC handshake is snuck into the Client constructor, so you
don't need to perform it after instantiating a Client.

This makes using LibProtocol much more pleasant. :^)
Andreas Kling 5 rokov pred
rodič
commit
653e61d9cf

+ 19 - 8
Libraries/LibProtocol/Client.cpp

@@ -1,4 +1,5 @@
 #include <LibProtocol/Client.h>
+#include <LibProtocol/Download.h>
 #include <SharedBuffer.h>
 
 namespace LibProtocol {
@@ -6,6 +7,7 @@ namespace LibProtocol {
 Client::Client()
     : ConnectionNG(*this, "/tmp/psportal")
 {
+    handshake();
 }
 
 void Client::handshake()
@@ -20,27 +22,36 @@ bool Client::is_supported_protocol(const String& protocol)
     return send_sync<ProtocolServer::IsSupportedProtocol>(protocol)->supported();
 }
 
-i32 Client::start_download(const String& url)
+RefPtr<Download> Client::start_download(const String& url)
 {
-    return send_sync<ProtocolServer::StartDownload>(url)->download_id();
+    i32 download_id = send_sync<ProtocolServer::StartDownload>(url)->download_id();
+    auto download = Download::create_from_id({}, *this, download_id);
+    m_downloads.set(download_id, download);
+    return download;
 }
 
-bool Client::stop_download(i32 download_id)
+bool Client::stop_download(Badge<Download>, Download& download)
 {
-    return send_sync<ProtocolServer::StopDownload>(download_id)->success();
+    if (!m_downloads.contains(download.id()))
+        return false;
+    return send_sync<ProtocolServer::StopDownload>(download.id())->success();
 }
 
 void Client::handle(const ProtocolClient::DownloadFinished& message)
 {
-    if (on_download_finish)
-        on_download_finish(message.download_id(), message.success(), message.total_size(), message.shared_buffer_id());
+    RefPtr<Download> download;
+    if ((download = m_downloads.get(message.download_id()).value_or(nullptr))) {
+        download->did_finish({}, message.success(), message.total_size(), message.shared_buffer_id());
+    }
     send_sync<ProtocolServer::DisownSharedBuffer>(message.shared_buffer_id());
+    m_downloads.remove(message.download_id());
 }
 
 void Client::handle(const ProtocolClient::DownloadProgress& message)
 {
-    if (on_download_progress)
-        on_download_progress(message.download_id(), message.total_size(), message.downloaded_size());
+    if (auto download = m_downloads.get(message.download_id()).value_or(nullptr)) {
+        download->did_progress({}, message.total_size(), message.downloaded_size());
+    }
 }
 
 }

+ 6 - 4
Libraries/LibProtocol/Client.h

@@ -6,6 +6,8 @@
 
 namespace LibProtocol {
 
+class Download;
+
 class Client : public IPC::Client::ConnectionNG<ProtocolClientEndpoint, ProtocolServerEndpoint>
     , public ProtocolClientEndpoint {
     C_OBJECT(Client)
@@ -15,15 +17,15 @@ public:
     virtual void handshake() override;
 
     bool is_supported_protocol(const String&);
-    i32 start_download(const String& url);
-    bool stop_download(i32 download_id);
+    RefPtr<Download> start_download(const String& url);
 
-    Function<void(i32 download_id, bool success, u32 total_size, i32 shared_buffer_id)> on_download_finish;
-    Function<void(i32 download_id, u64 total_size, u64 downloaded_size)> on_download_progress;
+    bool stop_download(Badge<Download>, Download&);
 
 private:
     virtual void handle(const ProtocolClient::DownloadProgress&) override;
     virtual void handle(const ProtocolClient::DownloadFinished&) override;
+
+    HashMap<i32, RefPtr<Download>> m_downloads;
 };
 
 }

+ 38 - 0
Libraries/LibProtocol/Download.cpp

@@ -0,0 +1,38 @@
+#include <LibC/SharedBuffer.h>
+#include <LibProtocol/Client.h>
+#include <LibProtocol/Download.h>
+
+namespace LibProtocol {
+
+Download::Download(Client& client, i32 download_id)
+    : m_client(client.make_weak_ptr())
+    , m_download_id(download_id)
+{
+}
+
+bool Download::stop()
+{
+    return m_client->stop_download({}, *this);
+}
+
+void Download::did_finish(Badge<Client>, bool success, u32 total_size, i32 shared_buffer_id)
+{
+    if (!on_finish)
+        return;
+
+    ByteBuffer payload;
+    RefPtr<SharedBuffer> shared_buffer;
+    if (success && shared_buffer_id != -1) {
+        shared_buffer = SharedBuffer::create_from_shared_buffer_id(shared_buffer_id);
+        payload = ByteBuffer::wrap(shared_buffer->data(), total_size);
+    }
+    on_finish(success, payload, move(shared_buffer));
+}
+
+void Download::did_progress(Badge<Client>, u32 total_size, u32 downloaded_size)
+{
+    if (on_progress)
+        on_progress(total_size, downloaded_size);
+}
+
+}

+ 37 - 0
Libraries/LibProtocol/Download.h

@@ -0,0 +1,37 @@
+#pragma once
+
+#include <AK/Badge.h>
+#include <AK/ByteBuffer.h>
+#include <AK/Function.h>
+#include <AK/RefCounted.h>
+#include <AK/WeakPtr.h>
+
+class SharedBuffer;
+
+namespace LibProtocol {
+
+class Client;
+
+class Download : public RefCounted<Download> {
+public:
+    static NonnullRefPtr<Download> create_from_id(Badge<Client>, Client& client, i32 download_id)
+    {
+        return adopt(*new Download(client, download_id));
+    }
+
+    int id() const { return m_download_id; }
+    bool stop();
+
+    Function<void(bool success, const ByteBuffer& payload, RefPtr<SharedBuffer> payload_storage)> on_finish;
+    Function<void(u32 total_size, u32 downloaded_size)> on_progress;
+
+    void did_finish(Badge<Client>, bool success, u32 total_size, i32 shared_buffer_id);
+    void did_progress(Badge<Client>, u32 total_size, u32 downloaded_size);
+
+private:
+    explicit Download(Client&, i32 download_id);
+    WeakPtr<Client> m_client;
+    int m_download_id { -1 };
+};
+
+}

+ 1 - 0
Libraries/LibProtocol/Makefile

@@ -1,6 +1,7 @@
 include ../../Makefile.common
 
 OBJS = \
+    Download.o \
     Client.o
 
 LIBRARY = libprotocol.a

+ 11 - 16
Userland/pro.cpp

@@ -2,6 +2,7 @@
 #include <LibC/SharedBuffer.h>
 #include <LibCore/CEventLoop.h>
 #include <LibProtocol/Client.h>
+#include <LibProtocol/Download.h>
 #include <stdio.h>
 
 int main(int argc, char** argv)
@@ -20,25 +21,19 @@ int main(int argc, char** argv)
 
     CEventLoop loop;
     auto protocol_client = LibProtocol::Client::construct();
-    protocol_client->handshake();
 
-    protocol_client->on_download_finish = [&](i32 download_id, bool success, u32 total_size, i32 shared_buffer_id) {
-        dbgprintf("download %d finished, success=%u, shared_buffer_id=%d\n", download_id, success, shared_buffer_id);
-        if (success) {
-            ASSERT(shared_buffer_id != -1);
-            auto shared_buffer = SharedBuffer::create_from_shared_buffer_id(shared_buffer_id);
-            auto payload_bytes = ByteBuffer::wrap(shared_buffer->data(), total_size);
-            write(STDOUT_FILENO, payload_bytes.data(), payload_bytes.size());
-        }
-        loop.quit(0);
+    auto download = protocol_client->start_download(url.to_string());
+    download->on_progress = [](u32 total_size, u32 downloaded_size) {
+        dbgprintf("download progress: %u / %u\n", downloaded_size, total_size);
     };
-
-    protocol_client->on_download_progress = [&](i32 download_id, u32 total_size, u32 downloaded_size) {
-        dbgprintf("download %d progress: %u / %u\n", download_id, downloaded_size, total_size);
+    download->on_finish = [&](bool success, auto& payload, auto) {
+        if (success)
+            write(STDOUT_FILENO, payload.data(), payload.size());
+        else
+            fprintf(stderr, "Download failed :(\n");
+        loop.quit(0);
     };
-
-    i32 download_id = protocol_client->start_download(url.to_string());
-    dbgprintf("started download with id %d\n", download_id);
+    dbgprintf("started download with id %d\n", download->id());
 
     return loop.exec();
 }