Explorar el Código

ProtocolServer: Implement and handle download progress

Also updates `pro` to display download progress and speed on stderr
AnotherTest hace 5 años
padre
commit
06cf9d3fb7

+ 10 - 0
Libraries/LibCore/NetworkJob.cpp

@@ -78,6 +78,16 @@ void NetworkJob::did_fail(Error error)
     shutdown();
 }
 
+void NetworkJob::did_progress(Optional<u32> total_size, u32 downloaded)
+{
+    // NOTE: We protect ourselves here, since the callback may otherwise
+    //       trigger destruction of this job somehow.
+    NonnullRefPtr<NetworkJob> protector(*this);
+
+    if (on_progress)
+        on_progress(total_size, downloaded);
+}
+
 const char* to_string(NetworkJob::Error error)
 {
     switch (error) {

+ 2 - 0
Libraries/LibCore/NetworkJob.h

@@ -44,6 +44,7 @@ public:
     virtual ~NetworkJob() override;
 
     Function<void(bool success)> on_finish;
+    Function<void(Optional<u32>, u32)> on_progress;
 
     bool is_cancelled() const { return m_error == Error::Cancelled; }
     bool has_error() const { return m_error != Error::None; }
@@ -64,6 +65,7 @@ protected:
     NetworkJob();
     void did_finish(NonnullRefPtr<NetworkResponse>&&);
     void did_fail(Error);
+    void did_progress(Optional<u32> total_size, u32 downloaded);
 
 private:
     RefPtr<NetworkResponse> m_response;

+ 15 - 3
Libraries/LibHTTP/HttpJob.cpp

@@ -156,11 +156,23 @@ void HttpJob::on_socket_connected()
         m_received_size += payload.size();
 
         auto content_length_header = m_headers.get("Content-Length");
+        Optional<u32> content_length {};
+
         if (content_length_header.has_value()) {
             bool ok;
-            auto content_length = content_length_header.value().to_uint(ok);
-            if (ok && m_received_size >= content_length) {
-                m_received_size = content_length;
+            auto length = content_length_header.value().to_uint(ok);
+            if (ok)
+                content_length = length;
+        }
+
+        deferred_invoke([this, content_length](auto&) {
+            did_progress(content_length, m_received_size);
+        });
+
+        if (content_length.has_value()) {
+            auto length = content_length.value();
+            if (m_received_size >= length) {
+                m_received_size = length;
                 finish_up();
             }
         }

+ 15 - 6
Libraries/LibHTTP/HttpsJob.cpp

@@ -166,16 +166,25 @@ void HttpsJob::on_socket_connected()
         m_received_size += payload.size();
 
         auto content_length_header = m_headers.get("Content-Length");
+        Optional<u32> content_length {};
+
         if (content_length_header.has_value()) {
             bool ok;
-            auto content_length = content_length_header.value().to_uint(ok);
-            if (ok && m_received_size >= content_length) {
-                m_received_size = content_length;
+            auto length = content_length_header.value().to_uint(ok);
+            if (ok)
+                content_length = length;
+        }
+
+        // This needs to be synchronous
+        // FIXME: Somehow enforce that this should not modify anything
+        did_progress(content_length, m_received_size);
+
+        if (content_length.has_value()) {
+            auto length = content_length.value();
+            if (m_received_size >= length) {
+                m_received_size = length;
                 finish_up();
             }
-        } else {
-            // no content-length, assume closed connection
-            finish_up();
         }
     };
 }

+ 1 - 1
Libraries/LibProtocol/Download.cpp

@@ -55,7 +55,7 @@ void Download::did_finish(Badge<Client>, bool success, u32 total_size, i32 shbuf
     on_finish(success, payload, move(shared_buffer));
 }
 
-void Download::did_progress(Badge<Client>, u32 total_size, u32 downloaded_size)
+void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size)
 {
     if (on_progress)
         on_progress(total_size, downloaded_size);

+ 2 - 2
Libraries/LibProtocol/Download.h

@@ -47,10 +47,10 @@ public:
     bool stop();
 
     Function<void(bool success, const ByteBuffer& payload, RefPtr<SharedBuffer> payload_storage)> on_finish;
-    Function<void(u32 total_size, u32 downloaded_size)> on_progress;
+    Function<void(Optional<u32> total_size, u32 downloaded_size)> on_progress;
 
     void did_finish(Badge<Client>, bool success, u32 total_size, i32 shbuf_id);
-    void did_progress(Badge<Client>, u32 total_size, u32 downloaded_size);
+    void did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size);
 
 private:
     explicit Download(Client&, i32 download_id);

+ 1 - 1
Servers/ProtocolServer/Download.cpp

@@ -74,7 +74,7 @@ void Download::did_finish(bool success)
     all_downloads().remove(m_id);
 }
 
-void Download::did_progress(size_t total_size, size_t downloaded_size)
+void Download::did_progress(Optional<u32> total_size, u32 downloaded_size)
 {
     if (!m_client) {
         // FIXME: We should also abort the download in this situation, I guess!

+ 4 - 3
Servers/ProtocolServer/Download.h

@@ -27,6 +27,7 @@
 #pragma once
 
 #include <AK/ByteBuffer.h>
+#include <AK/Optional.h>
 #include <AK/RefCounted.h>
 #include <AK/URL.h>
 #include <AK/WeakPtr.h>
@@ -42,7 +43,7 @@ public:
     i32 id() const { return m_id; }
     URL url() const { return m_url; }
 
-    size_t total_size() const { return m_total_size; }
+    Optional<u32> total_size() const { return m_total_size; }
     size_t downloaded_size() const { return m_downloaded_size; }
     const ByteBuffer& payload() const { return m_payload; }
 
@@ -52,13 +53,13 @@ protected:
     explicit Download(PSClientConnection&);
 
     void did_finish(bool success);
-    void did_progress(size_t total_size, size_t downloaded_size);
+    void did_progress(Optional<u32> total_size, u32 downloaded_size);
     void set_payload(const ByteBuffer&);
 
 private:
     i32 m_id;
     URL m_url;
-    size_t m_total_size { 0 };
+    Optional<u32> m_total_size {};
     size_t m_downloaded_size { 0 };
     ByteBuffer m_payload;
     WeakPtr<PSClientConnection> m_client;

+ 9 - 0
Servers/ProtocolServer/HttpDownload.cpp

@@ -35,8 +35,17 @@ HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpJ
     m_job->on_finish = [this](bool success) {
         if (m_job->response())
             set_payload(m_job->response()->payload());
+
+        // if we didn't know the total size, pretend that the download finished successfully
+        // and set the total size to the downloaded size
+        if (!total_size().has_value())
+            did_progress(downloaded_size(), downloaded_size());
+
         did_finish(success);
     };
+    m_job->on_progress = [this](Optional<u32> total, u32 current) {
+        did_progress(total, current);
+    };
 }
 
 HttpDownload::~HttpDownload()

+ 9 - 0
Servers/ProtocolServer/HttpsDownload.cpp

@@ -35,8 +35,17 @@ HttpsDownload::HttpsDownload(PSClientConnection& client, NonnullRefPtr<HTTP::Htt
     m_job->on_finish = [this](bool success) {
         if (m_job->response())
             set_payload(m_job->response()->payload());
+
+        // if we didn't know the total size, pretend that the download finished successfully
+        // and set the total size to the downloaded size
+        if (!total_size().has_value())
+            did_progress(downloaded_size(), downloaded_size());
+
         did_finish(success);
     };
+    m_job->on_progress = [this](Optional<u32> total, u32 current) {
+        did_progress(total, current);
+    };
 }
 
 HttpsDownload::~HttpsDownload()

+ 2 - 1
Servers/ProtocolServer/HttpsProtocol.cpp

@@ -44,6 +44,7 @@ RefPtr<Download> HttpsProtocol::start_download(PSClientConnection& client, const
     request.set_method(HTTP::HttpRequest::Method::GET);
     request.set_url(url);
     auto job = HTTP::HttpsJob::construct(request);
+    auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job);
     job->start();
-    return HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job);
+    return download;
 }

+ 2 - 1
Servers/ProtocolServer/PSClientConnection.cpp

@@ -86,7 +86,8 @@ void PSClientConnection::did_finish_download(Badge<Download>, Download& download
         buffer->share_with(client_pid());
         m_shared_buffers.set(buffer->shbuf_id(), buffer);
     }
-    post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.total_size(), buffer ? buffer->shbuf_id() : -1));
+    ASSERT(download.total_size().has_value());
+    post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.total_size().value(), buffer ? buffer->shbuf_id() : -1));
 }
 
 void PSClientConnection::did_progress_download(Badge<Download>, Download& download)

+ 1 - 1
Servers/ProtocolServer/ProtocolClient.ipc

@@ -1,6 +1,6 @@
 endpoint ProtocolClient = 13
 {
     // Download notifications
-    DownloadProgress(i32 download_id, u32 total_size, u32 downloaded_size) =|
+    DownloadProgress(i32 download_id, Optional<u32> total_size, u32 downloaded_size) =|
     DownloadFinished(i32 download_id, bool success, u32 total_size, i32 shbuf_id) =|
 }

+ 24 - 3
Userland/pro.cpp

@@ -24,8 +24,9 @@
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  */
 
-#include <AK/URL.h>
+#include <AK/NumberFormat.h>
 #include <AK/SharedBuffer.h>
+#include <AK/URL.h>
 #include <LibCore/EventLoop.h>
 #include <LibProtocol/Client.h>
 #include <LibProtocol/Download.h>
@@ -53,10 +54,30 @@ int main(int argc, char** argv)
         fprintf(stderr, "Failed to start download for '%s'\n", url_string.characters());
         return 1;
     }
-    download->on_progress = [](u32 total_size, u32 downloaded_size) {
-        dbgprintf("download progress: %u / %u\n", downloaded_size, total_size);
+    u32 previous_downloaded_size { 0 };
+    timeval prev_time, current_time, time_diff;
+    gettimeofday(&prev_time, nullptr);
+
+    download->on_progress = [&](Optional<u32> maybe_total_size, u32 downloaded_size) {
+        fprintf(stderr, "\r\033[2K");
+        if (maybe_total_size.has_value())
+            fprintf(stderr, "Download progress: %s / %s", human_readable_size(downloaded_size).characters(), human_readable_size(maybe_total_size.value()).characters());
+        else
+            fprintf(stderr, "Download progress: %s / ???", human_readable_size(downloaded_size).characters());
+
+        gettimeofday(&current_time, nullptr);
+        timersub(&current_time, &prev_time, &time_diff);
+
+        auto time_diff_ms = time_diff.tv_sec * 1000 + time_diff.tv_usec / 1000;
+        auto size_diff = downloaded_size - previous_downloaded_size;
+
+        fprintf(stderr, " at %s/s", human_readable_size(((float)size_diff / (float)time_diff_ms) * 1000).characters());
+
+        previous_downloaded_size = downloaded_size;
+        prev_time = current_time;
     };
     download->on_finish = [&](bool success, auto& payload, auto) {
+        fprintf(stderr, "\n");
         if (success)
             write(STDOUT_FILENO, payload.data(), payload.size());
         else