diff --git a/Applications/Browser/DownloadWidget.cpp b/Applications/Browser/DownloadWidget.cpp index 9441e15a80f..ed7796c3ac1 100644 --- a/Applications/Browser/DownloadWidget.cpp +++ b/Applications/Browser/DownloadWidget.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -61,9 +62,19 @@ DownloadWidget::DownloadWidget(const URL& url) m_download->on_progress = [this](Optional total_size, u32 downloaded_size) { did_progress(total_size.value(), downloaded_size); }; - m_download->on_finish = [this](bool success, auto payload, auto payload_storage, auto& response_headers, auto) { - did_finish(success, payload, payload_storage, response_headers); - }; + + { + auto file_or_error = Core::File::open(m_destination_path, Core::IODevice::WriteOnly); + if (file_or_error.is_error()) { + GUI::MessageBox::show(window(), String::formatted("Cannot open {} for writing", m_destination_path), "Download failed", GUI::MessageBox::Type::Error); + window()->close(); + return; + } + m_output_file_stream = make(*file_or_error.value()); + } + + m_download->on_finish = [this](bool success, auto) { did_finish(success); }; + m_download->stream_into(*m_output_file_stream); set_fill_with_background_color(true); auto& layout = set_layout(); @@ -149,7 +160,7 @@ void DownloadWidget::did_progress(Optional total_size, u32 downloaded_size) } } -void DownloadWidget::did_finish(bool success, [[maybe_unused]] ReadonlyBytes payload, [[maybe_unused]] RefPtr payload_storage, [[maybe_unused]] const HashMap& response_headers) +void DownloadWidget::did_finish(bool success) { dbg() << "did_finish, success=" << success; @@ -166,17 +177,6 @@ void DownloadWidget::did_finish(bool success, [[maybe_unused]] ReadonlyBytes pay window()->close(); return; } - - auto file_or_error = Core::File::open(m_destination_path, Core::IODevice::WriteOnly); - if (file_or_error.is_error()) { - GUI::MessageBox::show(window(), String::formatted("Cannot open {} for writing", m_destination_path), "Download failed", GUI::MessageBox::Type::Error); - window()->close(); - return; - } - - auto& file = *file_or_error.value(); - bool write_success = file.write(payload.data(), payload.size()); - ASSERT(write_success); } } diff --git a/Applications/Browser/DownloadWidget.h b/Applications/Browser/DownloadWidget.h index 2b69af5d266..44fc21c60d5 100644 --- a/Applications/Browser/DownloadWidget.h +++ b/Applications/Browser/DownloadWidget.h @@ -28,6 +28,7 @@ #include #include +#include #include #include #include @@ -44,7 +45,7 @@ private: explicit DownloadWidget(const URL&); void did_progress(Optional total_size, u32 downloaded_size); - void did_finish(bool success, ReadonlyBytes payload, RefPtr payload_storage, const HashMap& response_headers); + void did_finish(bool success); URL m_url; String m_destination_path; @@ -53,6 +54,7 @@ private: RefPtr m_progress_label; RefPtr m_cancel_button; RefPtr m_close_button; + OwnPtr m_output_file_stream; Core::ElapsedTimer m_elapsed_timer; }; diff --git a/Applications/Browser/main.cpp b/Applications/Browser/main.cpp index 3624ca60767..07b79ed505e 100644 --- a/Applications/Browser/main.cpp +++ b/Applications/Browser/main.cpp @@ -68,7 +68,7 @@ int main(int argc, char** argv) return 1; } - if (pledge("stdio shared_buffer accept unix cpath rpath wpath fattr", nullptr) < 0) { + if (pledge("stdio shared_buffer accept unix cpath rpath wpath fattr sendfd recvfd", nullptr) < 0) { perror("pledge"); return 1; } @@ -86,7 +86,7 @@ int main(int argc, char** argv) Web::ResourceLoader::the(); // FIXME: Once there is a standalone Download Manager, we can drop the "unix" pledge. - if (pledge("stdio shared_buffer accept unix cpath rpath wpath", nullptr) < 0) { + if (pledge("stdio shared_buffer accept unix cpath rpath wpath sendfd recvfd", nullptr) < 0) { perror("pledge"); return 1; } diff --git a/Libraries/LibCore/NetworkJob.cpp b/Libraries/LibCore/NetworkJob.cpp index 901aeb3d025..0b48a93a724 100644 --- a/Libraries/LibCore/NetworkJob.cpp +++ b/Libraries/LibCore/NetworkJob.cpp @@ -32,7 +32,8 @@ namespace Core { -NetworkJob::NetworkJob() +NetworkJob::NetworkJob(OutputStream& output_stream) + : m_output_stream(output_stream) { } diff --git a/Libraries/LibCore/NetworkJob.h b/Libraries/LibCore/NetworkJob.h index 8e2f57d8a1b..94ead448056 100644 --- a/Libraries/LibCore/NetworkJob.h +++ b/Libraries/LibCore/NetworkJob.h @@ -27,6 +27,7 @@ #pragma once #include +#include #include namespace Core { @@ -43,6 +44,8 @@ public: }; virtual ~NetworkJob() override; + // Could fire twice, after Headers and after Trailers! + Function& response_headers, Optional response_code)> on_headers_received; Function on_finish; Function, u32)> on_progress; @@ -62,13 +65,16 @@ public: } protected: - NetworkJob(); + NetworkJob(OutputStream&); void did_finish(NonnullRefPtr&&); void did_fail(Error); void did_progress(Optional total_size, u32 downloaded); + size_t do_write(ReadonlyBytes bytes) { return m_output_stream.write(bytes); } + private: RefPtr m_response; + OutputStream& m_output_stream; Error m_error { Error::None }; }; diff --git a/Libraries/LibCore/NetworkResponse.cpp b/Libraries/LibCore/NetworkResponse.cpp index 70654c670da..ebaf7eff7d8 100644 --- a/Libraries/LibCore/NetworkResponse.cpp +++ b/Libraries/LibCore/NetworkResponse.cpp @@ -28,8 +28,7 @@ namespace Core { -NetworkResponse::NetworkResponse(ByteBuffer&& payload) - : m_payload(payload) +NetworkResponse::NetworkResponse() { } diff --git a/Libraries/LibCore/NetworkResponse.h b/Libraries/LibCore/NetworkResponse.h index d39e2832eb5..d2ff33a5695 100644 --- a/Libraries/LibCore/NetworkResponse.h +++ b/Libraries/LibCore/NetworkResponse.h @@ -36,13 +36,11 @@ public: virtual ~NetworkResponse(); bool is_error() const { return m_error; } - const ByteBuffer& payload() const { return m_payload; } protected: - explicit NetworkResponse(ByteBuffer&&); + explicit NetworkResponse(); bool m_error { false }; - ByteBuffer m_payload; }; } diff --git a/Libraries/LibGemini/GeminiJob.cpp b/Libraries/LibGemini/GeminiJob.cpp index fcd19000cfc..7e5863be00c 100644 --- a/Libraries/LibGemini/GeminiJob.cpp +++ b/Libraries/LibGemini/GeminiJob.cpp @@ -142,9 +142,9 @@ bool GeminiJob::eof() const return m_socket->eof(); } -bool GeminiJob::write(const ByteBuffer& data) +bool GeminiJob::write(ReadonlyBytes bytes) { - return m_socket->write(data); + return m_socket->write(bytes); } } diff --git a/Libraries/LibGemini/GeminiJob.h b/Libraries/LibGemini/GeminiJob.h index a48a8fa7439..531a6368d50 100644 --- a/Libraries/LibGemini/GeminiJob.h +++ b/Libraries/LibGemini/GeminiJob.h @@ -37,8 +37,8 @@ namespace Gemini { class GeminiJob final : public Job { C_OBJECT(GeminiJob) public: - explicit GeminiJob(const GeminiRequest& request, const Vector* override_certificates = nullptr) - : Job(request) + explicit GeminiJob(const GeminiRequest& request, OutputStream& output_stream, const Vector* override_certificates = nullptr) + : Job(request, output_stream) , m_override_ca_certificates(override_certificates) { } @@ -61,7 +61,7 @@ protected: virtual bool can_read() const override; virtual ByteBuffer receive(size_t) override; virtual bool eof() const override; - virtual bool write(const ByteBuffer&) override; + virtual bool write(ReadonlyBytes) override; virtual bool is_established() const override { return m_socket->is_established(); } virtual bool should_fail_on_empty_payload() const override { return false; } virtual void read_while_data_available(Function) override; diff --git a/Libraries/LibGemini/GeminiResponse.cpp b/Libraries/LibGemini/GeminiResponse.cpp index 8aeecb34a52..e9ed7afd320 100644 --- a/Libraries/LibGemini/GeminiResponse.cpp +++ b/Libraries/LibGemini/GeminiResponse.cpp @@ -28,9 +28,8 @@ namespace Gemini { -GeminiResponse::GeminiResponse(int status, String meta, ByteBuffer&& payload) - : Core::NetworkResponse(move(payload)) - , m_status(status) +GeminiResponse::GeminiResponse(int status, String meta) + : m_status(status) , m_meta(meta) { } diff --git a/Libraries/LibGemini/GeminiResponse.h b/Libraries/LibGemini/GeminiResponse.h index ff3254a5877..ec410af91bf 100644 --- a/Libraries/LibGemini/GeminiResponse.h +++ b/Libraries/LibGemini/GeminiResponse.h @@ -34,16 +34,16 @@ namespace Gemini { class GeminiResponse : public Core::NetworkResponse { public: virtual ~GeminiResponse() override; - static NonnullRefPtr create(int status, String meta, ByteBuffer&& payload) + static NonnullRefPtr create(int status, String meta) { - return adopt(*new GeminiResponse(status, meta, move(payload))); + return adopt(*new GeminiResponse(status, meta)); } int status() const { return m_status; } String meta() const { return m_meta; } private: - GeminiResponse(int status, String, ByteBuffer&&); + GeminiResponse(int status, String); int m_status { 0 }; String m_meta; diff --git a/Libraries/LibGemini/Job.cpp b/Libraries/LibGemini/Job.cpp index 6ff386d5f4b..d167c7eaef4 100644 --- a/Libraries/LibGemini/Job.cpp +++ b/Libraries/LibGemini/Job.cpp @@ -33,8 +33,9 @@ namespace Gemini { -Job::Job(const GeminiRequest& request) - : m_request(request) +Job::Job(const GeminiRequest& request, OutputStream& output_stream) + : Core::NetworkJob(output_stream) + , m_request(request) { } @@ -42,6 +43,23 @@ Job::~Job() { } +void Job::flush_received_buffers() +{ + for (size_t i = 0; i < m_received_buffers.size(); ++i) { + auto& payload = m_received_buffers[i]; + auto written = do_write(payload); + m_received_size -= written; + if (written == payload.size()) { + // FIXME: Make this a take-first-friendly object? + m_received_buffers.take_first(); + continue; + } + ASSERT(written < payload.size()); + payload = payload.slice(written, payload.size() - written); + return; + } +} + void Job::on_socket_connected() { register_on_ready_to_write([this] { @@ -126,6 +144,7 @@ void Job::on_socket_connected() m_received_buffers.append(payload); m_received_size += payload.size(); + flush_received_buffers(); deferred_invoke([this](auto&) { did_progress({}, m_received_size); }); @@ -144,15 +163,17 @@ void Job::on_socket_connected() void Job::finish_up() { m_state = State::Finished; - auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size); - u8* flat_ptr = flattened_buffer.data(); - for (auto& received_buffer : m_received_buffers) { - memcpy(flat_ptr, received_buffer.data(), received_buffer.size()); - flat_ptr += received_buffer.size(); + flush_received_buffers(); + if (m_received_size != 0) { + // FIXME: What do we do? ignore it? + // "Transmission failed" is not strictly correct, but let's roll with it for now. + deferred_invoke([this](auto&) { + did_fail(Error::TransmissionFailed); + }); + return; } - m_received_buffers.clear(); - auto response = GeminiResponse::create(m_status, m_meta, move(flattened_buffer)); + auto response = GeminiResponse::create(m_status, m_meta); deferred_invoke([this, response](auto&) { did_finish(move(response)); }); diff --git a/Libraries/LibGemini/Job.h b/Libraries/LibGemini/Job.h index 8101b3ef663..b1380786099 100644 --- a/Libraries/LibGemini/Job.h +++ b/Libraries/LibGemini/Job.h @@ -36,7 +36,7 @@ namespace Gemini { class Job : public Core::NetworkJob { public: - explicit Job(const GeminiRequest&); + explicit Job(const GeminiRequest&, OutputStream&); virtual ~Job() override; virtual void start() override = 0; @@ -48,6 +48,7 @@ public: protected: void finish_up(); void on_socket_connected(); + void flush_received_buffers(); virtual void register_on_ready_to_read(Function) = 0; virtual void register_on_ready_to_write(Function) = 0; virtual bool can_read_line() const = 0; @@ -55,7 +56,7 @@ protected: virtual bool can_read() const = 0; virtual ByteBuffer receive(size_t) = 0; virtual bool eof() const = 0; - virtual bool write(const ByteBuffer&) = 0; + virtual bool write(ReadonlyBytes) = 0; virtual bool is_established() const = 0; virtual bool should_fail_on_empty_payload() const { return false; } virtual void read_while_data_available(Function read) { read(); }; @@ -70,7 +71,7 @@ protected: State m_state { State::InStatus }; int m_status { -1 }; String m_meta; - Vector m_received_buffers; + Vector m_received_buffers; size_t m_received_size { 0 }; bool m_sent_data { false }; bool m_should_have_payload { false }; diff --git a/Libraries/LibHTTP/HttpJob.cpp b/Libraries/LibHTTP/HttpJob.cpp index 78ba90b7432..9013ea1862e 100644 --- a/Libraries/LibHTTP/HttpJob.cpp +++ b/Libraries/LibHTTP/HttpJob.cpp @@ -98,9 +98,9 @@ bool HttpJob::eof() const return m_socket->eof(); } -bool HttpJob::write(const ByteBuffer& data) +bool HttpJob::write(ReadonlyBytes bytes) { - return m_socket->write(data); + return m_socket->write(bytes); } } diff --git a/Libraries/LibHTTP/HttpJob.h b/Libraries/LibHTTP/HttpJob.h index 98dae2b3075..924ff6bf5f9 100644 --- a/Libraries/LibHTTP/HttpJob.h +++ b/Libraries/LibHTTP/HttpJob.h @@ -38,8 +38,8 @@ namespace HTTP { class HttpJob final : public Job { C_OBJECT(HttpJob) public: - explicit HttpJob(const HttpRequest& request) - : Job(request) + explicit HttpJob(const HttpRequest& request, OutputStream& output_stream) + : Job(request, output_stream) { } @@ -59,7 +59,7 @@ protected: virtual bool can_read() const override; virtual ByteBuffer receive(size_t) override; virtual bool eof() const override; - virtual bool write(const ByteBuffer&) override; + virtual bool write(ReadonlyBytes) override; virtual bool is_established() const override { return true; } private: diff --git a/Libraries/LibHTTP/HttpRequest.cpp b/Libraries/LibHTTP/HttpRequest.cpp index e5eaeefa97b..8ea42a4e841 100644 --- a/Libraries/LibHTTP/HttpRequest.cpp +++ b/Libraries/LibHTTP/HttpRequest.cpp @@ -71,11 +71,12 @@ ByteBuffer HttpRequest::to_raw_request() const builder.append(header.value); builder.append("\r\n"); } - builder.append("Connection: close\r\n\r\n"); + builder.append("Connection: close\r\n"); if (!m_body.is_empty()) { + builder.appendff("Content-Length: {}\r\n\r\n", m_body.size()); builder.append((const char*)m_body.data(), m_body.size()); - builder.append("\r\n"); } + builder.append("\r\n"); return builder.to_byte_buffer(); } diff --git a/Libraries/LibHTTP/HttpRequest.h b/Libraries/LibHTTP/HttpRequest.h index 8c1b158aa8d..c1a92d33eb7 100644 --- a/Libraries/LibHTTP/HttpRequest.h +++ b/Libraries/LibHTTP/HttpRequest.h @@ -62,7 +62,8 @@ public: void set_method(Method method) { m_method = method; } const ByteBuffer& body() const { return m_body; } - void set_body(const ByteBuffer& body) { m_body = body; } + void set_body(ReadonlyBytes body) { m_body = ByteBuffer::copy(body); } + void set_body(ByteBuffer&& body) { m_body = move(body); } String method_name() const; ByteBuffer to_raw_request() const; diff --git a/Libraries/LibHTTP/HttpResponse.cpp b/Libraries/LibHTTP/HttpResponse.cpp index db5711ff5c0..61571577ca8 100644 --- a/Libraries/LibHTTP/HttpResponse.cpp +++ b/Libraries/LibHTTP/HttpResponse.cpp @@ -28,9 +28,8 @@ namespace HTTP { -HttpResponse::HttpResponse(int code, HashMap&& headers, ByteBuffer&& payload) - : Core::NetworkResponse(move(payload)) - , m_code(code) +HttpResponse::HttpResponse(int code, HashMap&& headers) + : m_code(code) , m_headers(move(headers)) { } diff --git a/Libraries/LibHTTP/HttpResponse.h b/Libraries/LibHTTP/HttpResponse.h index f36b98bd13d..a5e9fa74dc4 100644 --- a/Libraries/LibHTTP/HttpResponse.h +++ b/Libraries/LibHTTP/HttpResponse.h @@ -35,16 +35,16 @@ namespace HTTP { class HttpResponse : public Core::NetworkResponse { public: virtual ~HttpResponse() override; - static NonnullRefPtr create(int code, HashMap&& headers, ByteBuffer&& payload) + static NonnullRefPtr create(int code, HashMap&& headers) { - return adopt(*new HttpResponse(code, move(headers), move(payload))); + return adopt(*new HttpResponse(code, move(headers))); } int code() const { return m_code; } const HashMap& headers() const { return m_headers; } private: - HttpResponse(int code, HashMap&&, ByteBuffer&&); + HttpResponse(int code, HashMap&&); int m_code { 0 }; HashMap m_headers; diff --git a/Libraries/LibHTTP/HttpsJob.cpp b/Libraries/LibHTTP/HttpsJob.cpp index 0b1b5855da6..20e50d9503f 100644 --- a/Libraries/LibHTTP/HttpsJob.cpp +++ b/Libraries/LibHTTP/HttpsJob.cpp @@ -143,7 +143,7 @@ bool HttpsJob::eof() const return m_socket->eof(); } -bool HttpsJob::write(const ByteBuffer& data) +bool HttpsJob::write(ReadonlyBytes data) { return m_socket->write(data); } diff --git a/Libraries/LibHTTP/HttpsJob.h b/Libraries/LibHTTP/HttpsJob.h index 391df6198f4..7ead5fbf48e 100644 --- a/Libraries/LibHTTP/HttpsJob.h +++ b/Libraries/LibHTTP/HttpsJob.h @@ -38,8 +38,8 @@ namespace HTTP { class HttpsJob final : public Job { C_OBJECT(HttpsJob) public: - explicit HttpsJob(const HttpRequest& request, const Vector* override_certs = nullptr) - : Job(request) + explicit HttpsJob(const HttpRequest& request, OutputStream& output_stream, const Vector* override_certs = nullptr) + : Job(request, output_stream) , m_override_ca_certificates(override_certs) { } @@ -62,7 +62,7 @@ protected: virtual bool can_read() const override; virtual ByteBuffer receive(size_t) override; virtual bool eof() const override; - virtual bool write(const ByteBuffer&) override; + virtual bool write(ReadonlyBytes) override; virtual bool is_established() const override { return m_socket->is_established(); } virtual bool should_fail_on_empty_payload() const override { return false; } virtual void read_while_data_available(Function) override; diff --git a/Libraries/LibHTTP/Job.cpp b/Libraries/LibHTTP/Job.cpp index 7996d521572..c60c701bee9 100644 --- a/Libraries/LibHTTP/Job.cpp +++ b/Libraries/LibHTTP/Job.cpp @@ -68,8 +68,9 @@ static ByteBuffer handle_content_encoding(const ByteBuffer& buf, const String& c return buf; } -Job::Job(const HttpRequest& request) - : m_request(request) +Job::Job(const HttpRequest& request, OutputStream& output_stream) + : Core::NetworkJob(output_stream) + , m_request(request) { } @@ -77,6 +78,35 @@ Job::~Job() { } +void Job::flush_received_buffers() +{ + if (!m_can_stream_response || m_buffered_size == 0) + return; +#ifdef JOB_DEBUG + dbg() << "Job: Flushing received buffers: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers"; +#endif + for (size_t i = 0; i < m_received_buffers.size(); ++i) { + auto& payload = m_received_buffers[i]; + auto written = do_write(payload); + m_buffered_size -= written; + if (written == payload.size()) { + // FIXME: Make this a take-first-friendly object? + m_received_buffers.take_first(); + --i; + continue; + } + ASSERT(written < payload.size()); + payload = payload.slice(written, payload.size() - written); +#ifdef JOB_DEBUG + dbg() << "Job: Flushing received buffers done: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers"; +#endif + return; + } +#ifdef JOB_DEBUG + dbg() << "Job: Flushing received buffers done: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers"; +#endif +} + void Job::on_socket_connected() { register_on_ready_to_write([&] { @@ -135,6 +165,8 @@ void Job::on_socket_connected() if (m_state == State::Trailers) { return finish_up(); } else { + if (on_headers_received) + on_headers_received(m_headers, m_code > 0 ? m_code : Optional {}); m_state = State::InBody; } return; @@ -163,6 +195,13 @@ void Job::on_socket_connected() } auto value = line.substring(name.length() + 2, line.length() - name.length() - 2); m_headers.set(name, value); + if (name.equals_ignoring_case("Content-Encoding")) { + // Assume that any content-encoding means that we can't decode it as a stream :( +#ifdef JOB_DEBUG + dbg() << "Content-Encoding " << value << " detected, cannot stream output :("; +#endif + m_can_stream_response = false; + } #ifdef JOB_DEBUG dbg() << "Job: [" << name << "] = '" << value << "'"; #endif @@ -252,7 +291,9 @@ void Job::on_socket_connected() } m_received_buffers.append(payload); + m_buffered_size += payload.size(); m_received_size += payload.size(); + flush_received_buffers(); if (m_current_chunk_remaining_size.has_value()) { auto size = m_current_chunk_remaining_size.value() - payload.size(); @@ -313,20 +354,37 @@ void Job::on_socket_connected() void Job::finish_up() { m_state = State::Finished; - auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size); - u8* flat_ptr = flattened_buffer.data(); - for (auto& received_buffer : m_received_buffers) { - memcpy(flat_ptr, received_buffer.data(), received_buffer.size()); - flat_ptr += received_buffer.size(); - } - m_received_buffers.clear(); + if (!m_can_stream_response) { + auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size); + u8* flat_ptr = flattened_buffer.data(); + for (auto& received_buffer : m_received_buffers) { + memcpy(flat_ptr, received_buffer.data(), received_buffer.size()); + flat_ptr += received_buffer.size(); + } + m_received_buffers.clear(); - auto content_encoding = m_headers.get("Content-Encoding"); - if (content_encoding.has_value()) { - flattened_buffer = handle_content_encoding(flattened_buffer, content_encoding.value()); + // For the time being, we cannot stream stuff with content-encoding set to _anything_. + auto content_encoding = m_headers.get("Content-Encoding"); + if (content_encoding.has_value()) { + flattened_buffer = handle_content_encoding(flattened_buffer, content_encoding.value()); + } + + m_buffered_size = flattened_buffer.size(); + m_received_buffers.append(move(flattened_buffer)); + m_can_stream_response = true; } - auto response = HttpResponse::create(m_code, move(m_headers), move(flattened_buffer)); + flush_received_buffers(); + if (m_buffered_size != 0) { + // FIXME: What do we do? ignore it? + // "Transmission failed" is not strictly correct, but let's roll with it for now. + deferred_invoke([this](auto&) { + did_fail(Error::TransmissionFailed); + }); + return; + } + + auto response = HttpResponse::create(m_code, move(m_headers)); deferred_invoke([this, response](auto&) { did_finish(move(response)); }); diff --git a/Libraries/LibHTTP/Job.h b/Libraries/LibHTTP/Job.h index fd410f892fc..b57b220396b 100644 --- a/Libraries/LibHTTP/Job.h +++ b/Libraries/LibHTTP/Job.h @@ -26,6 +26,7 @@ #pragma once +#include #include #include #include @@ -37,7 +38,7 @@ namespace HTTP { class Job : public Core::NetworkJob { public: - explicit Job(const HttpRequest&); + explicit Job(const HttpRequest&, OutputStream&); virtual ~Job() override; virtual void start() override = 0; @@ -49,6 +50,7 @@ public: protected: void finish_up(); void on_socket_connected(); + void flush_received_buffers(); virtual void register_on_ready_to_read(Function) = 0; virtual void register_on_ready_to_write(Function) = 0; virtual bool can_read_line() const = 0; @@ -56,7 +58,7 @@ protected: virtual bool can_read() const = 0; virtual ByteBuffer receive(size_t) = 0; virtual bool eof() const = 0; - virtual bool write(const ByteBuffer&) = 0; + virtual bool write(ReadonlyBytes) = 0; virtual bool is_established() const = 0; virtual bool should_fail_on_empty_payload() const { return true; } virtual void read_while_data_available(Function read) { read(); }; @@ -73,11 +75,13 @@ protected: State m_state { State::InStatus }; int m_code { -1 }; HashMap m_headers; - Vector m_received_buffers; + Vector m_received_buffers; + size_t m_buffered_size { 0 }; size_t m_received_size { 0 }; bool m_sent_data { 0 }; Optional m_current_chunk_remaining_size; Optional m_current_chunk_total_size; + bool m_can_stream_response { true }; }; } diff --git a/Libraries/LibProtocol/Client.cpp b/Libraries/LibProtocol/Client.cpp index 0138ed93497..7a452736eed 100644 --- a/Libraries/LibProtocol/Client.cpp +++ b/Libraries/LibProtocol/Client.cpp @@ -24,6 +24,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ +#include #include #include #include @@ -47,16 +48,20 @@ bool Client::is_supported_protocol(const String& protocol) return send_sync(protocol)->supported(); } -RefPtr Client::start_download(const String& method, const String& url, const HashMap& request_headers, const ByteBuffer& request_body) +template +RefPtr Client::start_download(const String& method, const String& url, const HashMap& request_headers, ReadonlyBytes request_body) { IPC::Dictionary header_dictionary; for (auto& it : request_headers) header_dictionary.add(it.key, it.value); - i32 download_id = send_sync(method, url, header_dictionary, String::copy(request_body))->download_id(); - if (download_id < 0) + auto response = send_sync(method, url, header_dictionary, ByteBuffer::copy(request_body)); + auto download_id = response->download_id(); + auto response_fd = response->response_fd().fd(); + if (download_id < 0 || response_fd < 0) return nullptr; auto download = Download::create_from_id({}, *this, download_id); + download->set_download_fd({}, response_fd); m_downloads.set(download_id, download); return download; } @@ -79,9 +84,8 @@ void Client::handle(const Messages::ProtocolClient::DownloadFinished& message) { RefPtr download; if ((download = m_downloads.get(message.download_id()).value_or(nullptr))) { - download->did_finish({}, message.success(), message.status_code(), message.total_size(), message.shbuf_id(), message.response_headers()); + download->did_finish({}, message.success(), message.total_size()); } - send_sync(message.shbuf_id()); m_downloads.remove(message.download_id()); } @@ -92,6 +96,15 @@ void Client::handle(const Messages::ProtocolClient::DownloadProgress& message) } } +void Client::handle(const Messages::ProtocolClient::HeadersBecameAvailable& message) +{ + if (auto download = const_cast(m_downloads.get(message.download_id()).value_or(nullptr))) { + HashMap headers; + message.response_headers().for_each_entry([&](auto& name, auto& value) { headers.set(name, value); }); + download->did_receive_headers({}, headers, message.status_code()); + } +} + OwnPtr Client::handle(const Messages::ProtocolClient::CertificateRequested& message) { if (auto download = const_cast(m_downloads.get(message.download_id()).value_or(nullptr))) { @@ -102,3 +115,6 @@ OwnPtr Client::handle(co } } + +template RefPtr Protocol::Client::start_download(const String& method, const String& url, const HashMap& request_headers, ReadonlyBytes request_body); +template RefPtr Protocol::Client::start_download(const String& method, const String& url, const HashMap& request_headers, ReadonlyBytes request_body); diff --git a/Libraries/LibProtocol/Client.h b/Libraries/LibProtocol/Client.h index ca9f93b3e66..0546edaf2d9 100644 --- a/Libraries/LibProtocol/Client.h +++ b/Libraries/LibProtocol/Client.h @@ -44,7 +44,8 @@ public: virtual void handshake() override; bool is_supported_protocol(const String&); - RefPtr start_download(const String& method, const String& url, const HashMap& request_headers = {}, const ByteBuffer& request_body = {}); + template> + RefPtr start_download(const String& method, const String& url, const HashMap& request_headers = {}, ReadonlyBytes request_body = {}); bool stop_download(Badge, Download&); bool set_certificate(Badge, Download&, String, String); @@ -55,6 +56,7 @@ private: virtual void handle(const Messages::ProtocolClient::DownloadProgress&) override; virtual void handle(const Messages::ProtocolClient::DownloadFinished&) override; virtual OwnPtr handle(const Messages::ProtocolClient::CertificateRequested&) override; + virtual void handle(const Messages::ProtocolClient::HeadersBecameAvailable&) override; HashMap> m_downloads; }; diff --git a/Libraries/LibProtocol/Download.cpp b/Libraries/LibProtocol/Download.cpp index b71127951b4..f5011d04a58 100644 --- a/Libraries/LibProtocol/Download.cpp +++ b/Libraries/LibProtocol/Download.cpp @@ -41,25 +41,81 @@ bool Download::stop() return m_client->stop_download({}, *this); } -void Download::did_finish(Badge, bool success, Optional status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers) +void Download::stream_into(OutputStream& stream) +{ + ASSERT(!m_internal_stream_data); + + auto notifier = Core::Notifier::construct(fd(), Core::Notifier::Read); + + m_internal_stream_data = make(fd()); + m_internal_stream_data->read_notifier = notifier; + + auto user_on_finish = move(on_finish); + on_finish = [this](auto success, auto total_size) { + m_internal_stream_data->success = success; + m_internal_stream_data->total_size = total_size; + m_internal_stream_data->download_done = true; + }; + + notifier->on_ready_to_read = [this, &stream, user_on_finish = move(user_on_finish)] { + constexpr size_t buffer_size = 1 * KiB; + static char buf[buffer_size]; + auto nread = m_internal_stream_data->read_stream.read({ buf, buffer_size }); + if (!stream.write_or_error({ buf, nread })) { + // FIXME: What do we do here? + TODO(); + } + + if (m_internal_stream_data->read_stream.eof() || (m_internal_stream_data->download_done && !m_internal_stream_data->success)) { + m_internal_stream_data->read_notifier->close(); + user_on_finish(m_internal_stream_data->success, m_internal_stream_data->total_size); + } else { + m_internal_stream_data->read_stream.handle_any_error(); + } + }; +} + +void Download::set_should_buffer_all_input(bool value) +{ + if (m_should_buffer_all_input == value) + return; + + if (m_internal_buffered_data && !value) { + m_internal_buffered_data = nullptr; + m_should_buffer_all_input = false; + return; + } + + ASSERT(!m_internal_stream_data); + ASSERT(!m_internal_buffered_data); + ASSERT(on_buffered_download_finish); // Not having this set makes no sense. + m_internal_buffered_data = make(fd()); + m_should_buffer_all_input = true; + + on_headers_received = [this](auto& headers, auto response_code) { + m_internal_buffered_data->response_headers = headers; + m_internal_buffered_data->response_code = move(response_code); + }; + + on_finish = [this](auto success, u32 total_size) { + auto output_buffer = m_internal_buffered_data->payload_stream.copy_into_contiguous_buffer(); + on_buffered_download_finish( + success, + total_size, + m_internal_buffered_data->response_headers, + m_internal_buffered_data->response_code, + output_buffer); + }; + + stream_into(m_internal_buffered_data->payload_stream); +} + +void Download::did_finish(Badge, bool success, u32 total_size) { if (!on_finish) return; - ReadonlyBytes payload; - RefPtr shared_buffer; - if (success && shbuf_id != -1) { - shared_buffer = SharedBuffer::create_from_shbuf_id(shbuf_id); - payload = { shared_buffer->data(), total_size }; - } - - // FIXME: It's a bit silly that we copy the response headers here just so we can move them into a HashMap with different traits. - HashMap caseless_response_headers; - response_headers.for_each_entry([&](auto& name, auto& value) { - caseless_response_headers.set(name, value); - }); - - on_finish(success, payload, move(shared_buffer), caseless_response_headers, status_code); + on_finish(success, total_size); } void Download::did_progress(Badge, Optional total_size, u32 downloaded_size) @@ -68,6 +124,12 @@ void Download::did_progress(Badge, Optional total_size, u32 downloa on_progress(total_size, downloaded_size); } +void Download::did_receive_headers(Badge, const HashMap& response_headers, Optional response_code) +{ + if (on_headers_received) + on_headers_received(response_headers, response_code); +} + void Download::did_request_certificates(Badge) { if (on_certificate_requested) { diff --git a/Libraries/LibProtocol/Download.h b/Libraries/LibProtocol/Download.h index c8058a98c87..42d6b21b04f 100644 --- a/Libraries/LibProtocol/Download.h +++ b/Libraries/LibProtocol/Download.h @@ -28,10 +28,13 @@ #include #include +#include #include +#include #include #include #include +#include #include namespace Protocol { @@ -51,20 +54,65 @@ public: } int id() const { return m_download_id; } + int fd() const { return m_fd; } bool stop(); - Function payload_storage, const HashMap& response_headers, Optional status_code)> on_finish; + void stream_into(OutputStream&); + + bool should_buffer_all_input() const { return m_should_buffer_all_input; } + /// Note: Will override `on_finish', and `on_headers_received', and expects `on_buffered_download_finish' to be set! + void set_should_buffer_all_input(bool); + + /// Note: Must be set before `set_should_buffer_all_input(true)`. + Function& response_headers, Optional response_code, ReadonlyBytes payload)> on_buffered_download_finish; + Function on_finish; Function total_size, u32 downloaded_size)> on_progress; + Function& response_headers, Optional response_code)> on_headers_received; Function on_certificate_requested; - void did_finish(Badge, bool success, Optional status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers); + void did_finish(Badge, bool success, u32 total_size); void did_progress(Badge, Optional total_size, u32 downloaded_size); + void did_receive_headers(Badge, const HashMap& response_headers, Optional response_code); void did_request_certificates(Badge); + RefPtr& write_notifier(Badge) { return m_write_notifier; } + void set_download_fd(Badge, int fd) { m_fd = fd; } + private: explicit Download(Client&, i32 download_id); WeakPtr m_client; int m_download_id { -1 }; + RefPtr m_write_notifier; + int m_fd { -1 }; + bool m_should_buffer_all_input { false }; + + struct InternalBufferedData { + InternalBufferedData(int fd) + : read_stream(fd) + { + } + + InputFileStream read_stream; + DuplexMemoryStream payload_stream; + HashMap response_headers; + Optional response_code; + }; + + struct InternalStreamData { + InternalStreamData(int fd) + : read_stream(fd) + { + } + + InputFileStream read_stream; + RefPtr read_notifier; + bool success; + u32 total_size { 0 }; + bool download_done { false }; + }; + + OwnPtr m_internal_buffered_data; + OwnPtr m_internal_stream_data; }; } diff --git a/Libraries/LibWeb/DOM/XMLHttpRequest.cpp b/Libraries/LibWeb/DOM/XMLHttpRequest.cpp index 092ad060c8c..f085ade4fd2 100644 --- a/Libraries/LibWeb/DOM/XMLHttpRequest.cpp +++ b/Libraries/LibWeb/DOM/XMLHttpRequest.cpp @@ -92,10 +92,10 @@ void XMLHttpRequest::send() // we need to make ResourceLoader give us more detailed updates than just "done" and "error". ResourceLoader::the().load( m_window->document().complete_url(m_url), - [weak_this = make_weak_ptr()](auto& data, auto&) { + [weak_this = make_weak_ptr()](auto data, auto&) { if (!weak_this) return; - const_cast(*weak_this).m_response = data; + const_cast(*weak_this).m_response = ByteBuffer::copy(data); const_cast(*weak_this).set_ready_state(ReadyState::Done); const_cast(*weak_this).dispatch_event(DOM::Event::create(HTML::EventNames::load)); }, diff --git a/Libraries/LibWeb/HTML/HTMLScriptElement.cpp b/Libraries/LibWeb/HTML/HTMLScriptElement.cpp index a71ecd9fbdf..c1b2f84e11c 100644 --- a/Libraries/LibWeb/HTML/HTMLScriptElement.cpp +++ b/Libraries/LibWeb/HTML/HTMLScriptElement.cpp @@ -128,7 +128,7 @@ void HTMLScriptElement::prepare_script(Badge) // FIXME: This load should be made asynchronous and the parser should spin an event loop etc. ResourceLoader::the().load_sync( url, - [this, url](auto& data, auto&) { + [this, url](auto data, auto&) { if (data.is_null()) { dbg() << "HTMLScriptElement: Failed to load " << url; return; diff --git a/Libraries/LibWeb/Loader/FrameLoader.cpp b/Libraries/LibWeb/Loader/FrameLoader.cpp index 72598378d11..561be8f104d 100644 --- a/Libraries/LibWeb/Loader/FrameLoader.cpp +++ b/Libraries/LibWeb/Loader/FrameLoader.cpp @@ -171,6 +171,7 @@ bool FrameLoader::load(const LoadRequest& request, Type type) return true; if (url.protocol() == "http" || url.protocol() == "https") { +#if 0 URL favicon_url; favicon_url.set_protocol(url.protocol()); favicon_url.set_host(url.host()); @@ -191,6 +192,7 @@ bool FrameLoader::load(const LoadRequest& request, Type type) if (auto* page = frame().page()) page->client().page_did_change_favicon(*bitmap); }); +#endif } return true; diff --git a/Libraries/LibWeb/Loader/Resource.cpp b/Libraries/LibWeb/Loader/Resource.cpp index f0a46f1a41c..539cb272e17 100644 --- a/Libraries/LibWeb/Loader/Resource.cpp +++ b/Libraries/LibWeb/Loader/Resource.cpp @@ -84,10 +84,10 @@ static String mime_type_from_content_type(const String& content_type) return content_type; } -void Resource::did_load(Badge, const ByteBuffer& data, const HashMap& headers) +void Resource::did_load(Badge, ReadonlyBytes data, const HashMap& headers) { ASSERT(!m_loaded); - m_encoded_data = data; + m_encoded_data = ByteBuffer::copy(data); m_response_headers = headers; m_loaded = true; diff --git a/Libraries/LibWeb/Loader/Resource.h b/Libraries/LibWeb/Loader/Resource.h index 61c883aab92..9838071c658 100644 --- a/Libraries/LibWeb/Loader/Resource.h +++ b/Libraries/LibWeb/Loader/Resource.h @@ -77,7 +77,7 @@ public: void for_each_client(Function); - void did_load(Badge, const ByteBuffer& data, const HashMap& headers); + void did_load(Badge, ReadonlyBytes data, const HashMap& headers); void did_fail(Badge, const String& error); protected: diff --git a/Libraries/LibWeb/Loader/ResourceLoader.cpp b/Libraries/LibWeb/Loader/ResourceLoader.cpp index 8e7f304426a..333237e7d9c 100644 --- a/Libraries/LibWeb/Loader/ResourceLoader.cpp +++ b/Libraries/LibWeb/Loader/ResourceLoader.cpp @@ -53,13 +53,13 @@ ResourceLoader::ResourceLoader() { } -void ResourceLoader::load_sync(const URL& url, Function& response_headers)> success_callback, Function error_callback) +void ResourceLoader::load_sync(const URL& url, Function& response_headers)> success_callback, Function error_callback) { Core::EventLoop loop; load( url, - [&](auto& data, auto& response_headers) { + [&](auto data, auto& response_headers) { success_callback(data, response_headers); loop.quit(0); }, @@ -97,7 +97,7 @@ RefPtr ResourceLoader::load_resource(Resource::Type type, const LoadRe load( request, - [=](auto& data, auto& headers) { + [=](auto data, auto& headers) { const_cast(*resource).did_load({}, data, headers); }, [=](auto& error) { @@ -107,7 +107,7 @@ RefPtr ResourceLoader::load_resource(Resource::Type type, const LoadRe return resource; } -void ResourceLoader::load(const LoadRequest& request, Function& response_headers)> success_callback, Function error_callback) +void ResourceLoader::load(const LoadRequest& request, Function& response_headers)> success_callback, Function error_callback) { auto& url = request.url(); if (is_port_blocked(url.port())) { @@ -170,7 +170,12 @@ void ResourceLoader::load(const LoadRequest& request, Functionon_finish = [this, success_callback = move(success_callback), error_callback = move(error_callback)](bool success, ReadonlyBytes payload, auto, auto& response_headers, auto status_code) { + download->on_buffered_download_finish = [this, success_callback = move(success_callback), error_callback = move(error_callback), download](bool success, auto, auto& response_headers, auto status_code, ReadonlyBytes payload) { + if (status_code.has_value() && status_code.value() >= 400 && status_code.value() <= 499) { + if (error_callback) + error_callback(String::format("HTTP error (%u)", status_code.value())); + return; + } --m_pending_loads; if (on_load_counter_change) on_load_counter_change(); @@ -179,13 +184,9 @@ void ResourceLoader::load(const LoadRequest& request, Function= 400 && status_code.value() <= 499) { - if (error_callback) - error_callback(String::format("HTTP error (%u)", status_code.value())); - return; - } - success_callback(ByteBuffer::copy(payload.data(), payload.size()), response_headers); + success_callback(payload, response_headers); }; + download->set_should_buffer_all_input(true); download->on_certificate_requested = []() -> Protocol::Download::CertificateAndKey { return {}; }; @@ -199,7 +200,7 @@ void ResourceLoader::load(const LoadRequest& request, Function& response_headers)> success_callback, Function error_callback) +void ResourceLoader::load(const URL& url, Function& response_headers)> success_callback, Function error_callback) { LoadRequest request; request.set_url(url); diff --git a/Libraries/LibWeb/Loader/ResourceLoader.h b/Libraries/LibWeb/Loader/ResourceLoader.h index d6cb8d1ede1..5fe23adc8a6 100644 --- a/Libraries/LibWeb/Loader/ResourceLoader.h +++ b/Libraries/LibWeb/Loader/ResourceLoader.h @@ -44,9 +44,9 @@ public: RefPtr load_resource(Resource::Type, const LoadRequest&); - void load(const LoadRequest&, Function& response_headers)> success_callback, Function error_callback = nullptr); - void load(const URL&, Function& response_headers)> success_callback, Function error_callback = nullptr); - void load_sync(const URL&, Function& response_headers)> success_callback, Function error_callback = nullptr); + void load(const LoadRequest&, Function& response_headers)> success_callback, Function error_callback = nullptr); + void load(const URL&, Function& response_headers)> success_callback, Function error_callback = nullptr); + void load_sync(const URL&, Function& response_headers)> success_callback, Function error_callback = nullptr); Function on_load_counter_change; diff --git a/Services/ProtocolServer/ClientConnection.cpp b/Services/ProtocolServer/ClientConnection.cpp index 93f02ae2e62..eee7277869f 100644 --- a/Services/ProtocolServer/ClientConnection.cpp +++ b/Services/ProtocolServer/ClientConnection.cpp @@ -62,16 +62,17 @@ OwnPtr ClientConnection::handle { URL url(message.url()); if (!url.is_valid()) - return make(-1); + return make(-1, -1); auto* protocol = Protocol::find_by_name(url.protocol()); if (!protocol) - return make(-1); - auto download = protocol->start_download(*this, message.method(), url, message.request_headers().entries(), message.request_body().to_byte_buffer()); + return make(-1, -1); + auto download = protocol->start_download(*this, message.method(), url, message.request_headers().entries(), message.request_body()); if (!download) - return make(-1); + return make(-1, -1); auto id = download->id(); + auto fd = download->download_fd(); m_downloads.set(id, move(download)); - return make(id); + return make(id, fd); } OwnPtr ClientConnection::handle(const Messages::ProtocolServer::StopDownload& message) @@ -86,22 +87,20 @@ OwnPtr ClientConnection::handle( return make(success); } -void ClientConnection::did_finish_download(Badge, Download& download, bool success) +void ClientConnection::did_receive_headers(Badge, Download& download) { - RefPtr buffer; - if (success && download.payload().size() > 0 && !download.payload().is_null()) { - buffer = SharedBuffer::create_with_size(download.payload().size()); - memcpy(buffer->data(), download.payload().data(), download.payload().size()); - buffer->seal(); - buffer->share_with(client_pid()); - m_shared_buffers.set(buffer->shbuf_id(), buffer); - } - ASSERT(download.total_size().has_value()); - IPC::Dictionary response_headers; for (auto& it : download.response_headers()) response_headers.add(it.key, it.value); - post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.status_code(), download.total_size().value(), buffer ? buffer->shbuf_id() : -1, response_headers)); + + post_message(Messages::ProtocolClient::HeadersBecameAvailable(download.id(), move(response_headers), download.status_code())); +} + +void ClientConnection::did_finish_download(Badge, Download& download, bool success) +{ + ASSERT(download.total_size().has_value()); + + post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.total_size().value())); m_downloads.remove(download.id()); } @@ -121,12 +120,6 @@ OwnPtr ClientConnection::handle(const M return make(client_id()); } -OwnPtr ClientConnection::handle(const Messages::ProtocolServer::DisownSharedBuffer& message) -{ - m_shared_buffers.remove(message.shbuf_id()); - return make(); -} - OwnPtr ClientConnection::handle(const Messages::ProtocolServer::SetCertificate& message) { auto* download = const_cast(m_downloads.get(message.download_id()).value_or(nullptr)); diff --git a/Services/ProtocolServer/ClientConnection.h b/Services/ProtocolServer/ClientConnection.h index 4439fdbd878..778f3eff813 100644 --- a/Services/ProtocolServer/ClientConnection.h +++ b/Services/ProtocolServer/ClientConnection.h @@ -45,6 +45,7 @@ public: virtual void die() override; + void did_receive_headers(Badge, Download&); void did_finish_download(Badge, Download&, bool success); void did_progress_download(Badge, Download&); void did_request_certificates(Badge, Download&); @@ -54,11 +55,9 @@ private: virtual OwnPtr handle(const Messages::ProtocolServer::IsSupportedProtocol&) override; virtual OwnPtr handle(const Messages::ProtocolServer::StartDownload&) override; virtual OwnPtr handle(const Messages::ProtocolServer::StopDownload&) override; - virtual OwnPtr handle(const Messages::ProtocolServer::DisownSharedBuffer&) override; - virtual OwnPtr handle(const Messages::ProtocolServer::SetCertificate&); + virtual OwnPtr handle(const Messages::ProtocolServer::SetCertificate&) override; HashMap> m_downloads; - HashMap> m_shared_buffers; }; } diff --git a/Services/ProtocolServer/Download.cpp b/Services/ProtocolServer/Download.cpp index d0d9aa2ab8c..11a7d289332 100644 --- a/Services/ProtocolServer/Download.cpp +++ b/Services/ProtocolServer/Download.cpp @@ -33,9 +33,10 @@ namespace ProtocolServer { // FIXME: What about rollover? static i32 s_next_id = 1; -Download::Download(ClientConnection& client) +Download::Download(ClientConnection& client, NonnullOwnPtr&& output_stream) : m_client(client) , m_id(s_next_id++) + , m_output_stream(move(output_stream)) { } @@ -48,15 +49,10 @@ void Download::stop() m_client.did_finish_download({}, *this, false); } -void Download::set_payload(const ByteBuffer& payload) -{ - m_payload = payload; - m_total_size = payload.size(); -} - void Download::set_response_headers(const HashMap& response_headers) { m_response_headers = response_headers; + m_client.did_receive_headers({}, *this); } void Download::set_certificate(String, String) diff --git a/Services/ProtocolServer/Download.h b/Services/ProtocolServer/Download.h index f0d0342006d..35c60269f4c 100644 --- a/Services/ProtocolServer/Download.h +++ b/Services/ProtocolServer/Download.h @@ -26,8 +26,9 @@ #pragma once -#include +#include #include +#include #include #include #include @@ -45,30 +46,35 @@ public: Optional status_code() const { return m_status_code; } Optional total_size() const { return m_total_size; } size_t downloaded_size() const { return m_downloaded_size; } - const ByteBuffer& payload() const { return m_payload; } const HashMap& response_headers() const { return m_response_headers; } void stop(); virtual void set_certificate(String, String); + // FIXME: Want Badge, but can't make one from HttpProtocol, etc. + void set_download_fd(int fd) { m_download_fd = fd; } + int download_fd() const { return m_download_fd; } + protected: - explicit Download(ClientConnection&); + explicit Download(ClientConnection&, NonnullOwnPtr&&); void did_finish(bool success); void did_progress(Optional total_size, u32 downloaded_size); void set_status_code(u32 status_code) { m_status_code = status_code; } void did_request_certificates(); - void set_payload(const ByteBuffer&); void set_response_headers(const HashMap&); + void set_downloaded_size(size_t size) { m_downloaded_size = size; } + const OutputFileStream& output_stream() const { return *m_output_stream; } private: ClientConnection& m_client; i32 m_id { 0 }; + int m_download_fd { -1 }; // Passed to client. URL m_url; Optional m_status_code; Optional m_total_size {}; size_t m_downloaded_size { 0 }; - ByteBuffer m_payload; + NonnullOwnPtr m_output_stream; HashMap m_response_headers; }; diff --git a/Services/ProtocolServer/GeminiDownload.cpp b/Services/ProtocolServer/GeminiDownload.cpp index a504aaca7f7..0bba75519d1 100644 --- a/Services/ProtocolServer/GeminiDownload.cpp +++ b/Services/ProtocolServer/GeminiDownload.cpp @@ -30,13 +30,13 @@ namespace ProtocolServer { -GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr job) - : Download(client) +GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr job, NonnullOwnPtr&& output_stream) + : Download(client, move(output_stream)) , m_job(job) { m_job->on_finish = [this](bool success) { if (auto* response = m_job->response()) { - set_payload(response->payload()); + set_downloaded_size(this->output_stream().size()); if (!response->meta().is_empty()) { HashMap headers; headers.set("meta", response->meta()); @@ -76,9 +76,9 @@ GeminiDownload::~GeminiDownload() m_job->shutdown(); } -NonnullOwnPtr GeminiDownload::create_with_job(Badge, ClientConnection& client, NonnullRefPtr job) +NonnullOwnPtr GeminiDownload::create_with_job(Badge, ClientConnection& client, NonnullRefPtr job, NonnullOwnPtr&& output_stream) { - return adopt_own(*new GeminiDownload(client, move(job))); + return adopt_own(*new GeminiDownload(client, move(job), move(output_stream))); } } diff --git a/Services/ProtocolServer/GeminiDownload.h b/Services/ProtocolServer/GeminiDownload.h index c429bac7b1c..fcd81c121f8 100644 --- a/Services/ProtocolServer/GeminiDownload.h +++ b/Services/ProtocolServer/GeminiDownload.h @@ -36,10 +36,10 @@ namespace ProtocolServer { class GeminiDownload final : public Download { public: virtual ~GeminiDownload() override; - static NonnullOwnPtr create_with_job(Badge, ClientConnection&, NonnullRefPtr); + static NonnullOwnPtr create_with_job(Badge, ClientConnection&, NonnullRefPtr, NonnullOwnPtr&&); private: - explicit GeminiDownload(ClientConnection&, NonnullRefPtr); + explicit GeminiDownload(ClientConnection&, NonnullRefPtr, NonnullOwnPtr&&); virtual void set_certificate(String certificate, String key) override; diff --git a/Services/ProtocolServer/GeminiProtocol.cpp b/Services/ProtocolServer/GeminiProtocol.cpp index f1167ecd61e..ff4380de7f4 100644 --- a/Services/ProtocolServer/GeminiProtocol.cpp +++ b/Services/ProtocolServer/GeminiProtocol.cpp @@ -40,12 +40,22 @@ GeminiProtocol::~GeminiProtocol() { } -OwnPtr GeminiProtocol::start_download(ClientConnection& client, const String&, const URL& url, const HashMap&, const ByteBuffer&) +OwnPtr GeminiProtocol::start_download(ClientConnection& client, const String&, const URL& url, const HashMap&, ReadonlyBytes) { Gemini::GeminiRequest request; request.set_url(url); - auto job = Gemini::GeminiJob::construct(request); - auto download = GeminiDownload::create_with_job({}, client, (Gemini::GeminiJob&)*job); + + int fd_pair[2] { 0 }; + if (pipe(fd_pair) != 0) { + auto saved_errno = errno; + dbgln("Protocol: pipe() failed: {}", strerror(saved_errno)); + return nullptr; + } + auto output_stream = make(fd_pair[1]); + output_stream->make_unbuffered(); + auto job = Gemini::GeminiJob::construct(request, *output_stream); + auto download = GeminiDownload::create_with_job({}, client, (Gemini::GeminiJob&)*job, move(output_stream)); + download->set_download_fd(fd_pair[0]); job->start(); return download; } diff --git a/Services/ProtocolServer/GeminiProtocol.h b/Services/ProtocolServer/GeminiProtocol.h index f9ed21cca30..23a4d7c7176 100644 --- a/Services/ProtocolServer/GeminiProtocol.h +++ b/Services/ProtocolServer/GeminiProtocol.h @@ -35,7 +35,7 @@ public: GeminiProtocol(); virtual ~GeminiProtocol() override; - virtual OwnPtr start_download(ClientConnection&, const String& method, const URL&, const HashMap&, const ByteBuffer& request_body) override; + virtual OwnPtr start_download(ClientConnection&, const String& method, const URL&, const HashMap&, ReadonlyBytes body) override; }; } diff --git a/Services/ProtocolServer/HttpDownload.cpp b/Services/ProtocolServer/HttpDownload.cpp index bfa22351d3d..8ba945d3357 100644 --- a/Services/ProtocolServer/HttpDownload.cpp +++ b/Services/ProtocolServer/HttpDownload.cpp @@ -30,15 +30,21 @@ namespace ProtocolServer { -HttpDownload::HttpDownload(ClientConnection& client, NonnullRefPtr job) - : Download(client) +HttpDownload::HttpDownload(ClientConnection& client, NonnullRefPtr job, NonnullOwnPtr&& output_stream) + : Download(client, move(output_stream)) , m_job(job) { + m_job->on_headers_received = [this](auto& headers, auto response_code) { + if (response_code.has_value()) + set_status_code(response_code.value()); + set_response_headers(headers); + }; + m_job->on_finish = [this](bool success) { if (auto* response = m_job->response()) { set_status_code(response->code()); - set_payload(response->payload()); set_response_headers(response->headers()); + set_downloaded_size(this->output_stream().size()); } // if we didn't know the total size, pretend that the download finished successfully @@ -60,9 +66,9 @@ HttpDownload::~HttpDownload() m_job->shutdown(); } -NonnullOwnPtr HttpDownload::create_with_job(Badge, ClientConnection& client, NonnullRefPtr job) +NonnullOwnPtr HttpDownload::create_with_job(Badge, ClientConnection& client, NonnullRefPtr job, NonnullOwnPtr&& output_stream) { - return adopt_own(*new HttpDownload(client, move(job))); + return adopt_own(*new HttpDownload(client, move(job), move(output_stream))); } } diff --git a/Services/ProtocolServer/HttpDownload.h b/Services/ProtocolServer/HttpDownload.h index d0d745ef0cb..50095bd0e54 100644 --- a/Services/ProtocolServer/HttpDownload.h +++ b/Services/ProtocolServer/HttpDownload.h @@ -36,10 +36,10 @@ namespace ProtocolServer { class HttpDownload final : public Download { public: virtual ~HttpDownload() override; - static NonnullOwnPtr create_with_job(Badge, ClientConnection&, NonnullRefPtr); + static NonnullOwnPtr create_with_job(Badge, ClientConnection&, NonnullRefPtr, NonnullOwnPtr&&); private: - explicit HttpDownload(ClientConnection&, NonnullRefPtr); + explicit HttpDownload(ClientConnection&, NonnullRefPtr, NonnullOwnPtr&&); NonnullRefPtr m_job; }; diff --git a/Services/ProtocolServer/HttpProtocol.cpp b/Services/ProtocolServer/HttpProtocol.cpp index b0e74e766af..e8c7a3203cc 100644 --- a/Services/ProtocolServer/HttpProtocol.cpp +++ b/Services/ProtocolServer/HttpProtocol.cpp @@ -28,6 +28,7 @@ #include #include #include +#include namespace ProtocolServer { @@ -40,7 +41,7 @@ HttpProtocol::~HttpProtocol() { } -OwnPtr HttpProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap& headers, const ByteBuffer& request_body) +OwnPtr HttpProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap& headers, ReadonlyBytes body) { HTTP::HttpRequest request; if (method.equals_ignoring_case("post")) @@ -49,9 +50,20 @@ OwnPtr HttpProtocol::start_download(ClientConnection& client, const St request.set_method(HTTP::HttpRequest::Method::GET); request.set_url(url); request.set_headers(headers); - request.set_body(request_body); - auto job = HTTP::HttpJob::construct(request); - auto download = HttpDownload::create_with_job({}, client, (HTTP::HttpJob&)*job); + request.set_body(body); + + int fd_pair[2] { 0 }; + if (pipe(fd_pair) != 0) { + auto saved_errno = errno; + dbgln("Protocol: pipe() failed: {}", strerror(saved_errno)); + return nullptr; + } + + auto output_stream = make(fd_pair[1]); + output_stream->make_unbuffered(); + auto job = HTTP::HttpJob::construct(request, *output_stream); + auto download = HttpDownload::create_with_job({}, client, (HTTP::HttpJob&)*job, move(output_stream)); + download->set_download_fd(fd_pair[0]); job->start(); return download; } diff --git a/Services/ProtocolServer/HttpProtocol.h b/Services/ProtocolServer/HttpProtocol.h index aa9601b8ce5..8c4a564f377 100644 --- a/Services/ProtocolServer/HttpProtocol.h +++ b/Services/ProtocolServer/HttpProtocol.h @@ -35,7 +35,7 @@ public: HttpProtocol(); virtual ~HttpProtocol() override; - virtual OwnPtr start_download(ClientConnection&, const String& method, const URL&, const HashMap& headers, const ByteBuffer& request_body) override; + virtual OwnPtr start_download(ClientConnection&, const String& method, const URL&, const HashMap& headers, ReadonlyBytes body) override; }; } diff --git a/Services/ProtocolServer/HttpsDownload.cpp b/Services/ProtocolServer/HttpsDownload.cpp index fe381d216ea..991dd730d88 100644 --- a/Services/ProtocolServer/HttpsDownload.cpp +++ b/Services/ProtocolServer/HttpsDownload.cpp @@ -30,15 +30,21 @@ namespace ProtocolServer { -HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr job) - : Download(client) +HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr job, NonnullOwnPtr&& output_stream) + : Download(client, move(output_stream)) , m_job(job) { + m_job->on_headers_received = [this](auto& headers, auto response_code) { + if (response_code.has_value()) + set_status_code(response_code.value()); + set_response_headers(headers); + }; + m_job->on_finish = [this](bool success) { if (auto* response = m_job->response()) { set_status_code(response->code()); - set_payload(response->payload()); set_response_headers(response->headers()); + set_downloaded_size(this->output_stream().size()); } // if we didn't know the total size, pretend that the download finished successfully @@ -68,9 +74,9 @@ HttpsDownload::~HttpsDownload() m_job->shutdown(); } -NonnullOwnPtr HttpsDownload::create_with_job(Badge, ClientConnection& client, NonnullRefPtr job) +NonnullOwnPtr HttpsDownload::create_with_job(Badge, ClientConnection& client, NonnullRefPtr job, NonnullOwnPtr&& output_stream) { - return adopt_own(*new HttpsDownload(client, move(job))); + return adopt_own(*new HttpsDownload(client, move(job), move(output_stream))); } } diff --git a/Services/ProtocolServer/HttpsDownload.h b/Services/ProtocolServer/HttpsDownload.h index 48f255b2fac..254172b3f72 100644 --- a/Services/ProtocolServer/HttpsDownload.h +++ b/Services/ProtocolServer/HttpsDownload.h @@ -36,10 +36,10 @@ namespace ProtocolServer { class HttpsDownload final : public Download { public: virtual ~HttpsDownload() override; - static NonnullOwnPtr create_with_job(Badge, ClientConnection&, NonnullRefPtr); + static NonnullOwnPtr create_with_job(Badge, ClientConnection&, NonnullRefPtr, NonnullOwnPtr&&); private: - explicit HttpsDownload(ClientConnection&, NonnullRefPtr); + explicit HttpsDownload(ClientConnection&, NonnullRefPtr, NonnullOwnPtr&&); virtual void set_certificate(String certificate, String key) override; diff --git a/Services/ProtocolServer/HttpsProtocol.cpp b/Services/ProtocolServer/HttpsProtocol.cpp index 3de9ca8e2b5..e34ff324223 100644 --- a/Services/ProtocolServer/HttpsProtocol.cpp +++ b/Services/ProtocolServer/HttpsProtocol.cpp @@ -40,7 +40,7 @@ HttpsProtocol::~HttpsProtocol() { } -OwnPtr HttpsProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap& headers, const ByteBuffer& request_body) +OwnPtr HttpsProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap& headers, ReadonlyBytes body) { HTTP::HttpRequest request; if (method.equals_ignoring_case("post")) @@ -49,9 +49,19 @@ OwnPtr HttpsProtocol::start_download(ClientConnection& client, const S request.set_method(HTTP::HttpRequest::Method::GET); request.set_url(url); request.set_headers(headers); - request.set_body(request_body); - auto job = HTTP::HttpsJob::construct(request); - auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job); + request.set_body(body); + + int fd_pair[2] { 0 }; + if (pipe(fd_pair) != 0) { + auto saved_errno = errno; + dbgln("Protocol: pipe() failed: {}", strerror(saved_errno)); + return nullptr; + } + auto output_stream = make(fd_pair[1]); + output_stream->make_unbuffered(); + auto job = HTTP::HttpsJob::construct(request, *output_stream); + auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job, move(output_stream)); + download->set_download_fd(fd_pair[0]); job->start(); return download; } diff --git a/Services/ProtocolServer/HttpsProtocol.h b/Services/ProtocolServer/HttpsProtocol.h index 9cb0ce190b2..40e59aa2712 100644 --- a/Services/ProtocolServer/HttpsProtocol.h +++ b/Services/ProtocolServer/HttpsProtocol.h @@ -35,7 +35,7 @@ public: HttpsProtocol(); virtual ~HttpsProtocol() override; - virtual OwnPtr start_download(ClientConnection&, const String& method, const URL&, const HashMap& headers, const ByteBuffer& request_body) override; + virtual OwnPtr start_download(ClientConnection&, const String& method, const URL&, const HashMap& headers, ReadonlyBytes body) override; }; } diff --git a/Services/ProtocolServer/Protocol.h b/Services/ProtocolServer/Protocol.h index 035b56cb6eb..609362f548a 100644 --- a/Services/ProtocolServer/Protocol.h +++ b/Services/ProtocolServer/Protocol.h @@ -37,7 +37,7 @@ public: virtual ~Protocol(); const String& name() const { return m_name; } - virtual OwnPtr start_download(ClientConnection&, const String& method, const URL&, const HashMap& headers, const ByteBuffer& request_body) = 0; + virtual OwnPtr start_download(ClientConnection&, const String& method, const URL&, const HashMap& headers, ReadonlyBytes body) = 0; static Protocol* find_by_name(const String&); diff --git a/Services/ProtocolServer/ProtocolClient.ipc b/Services/ProtocolServer/ProtocolClient.ipc index ef00d760ced..88f4cfc96d1 100644 --- a/Services/ProtocolServer/ProtocolClient.ipc +++ b/Services/ProtocolServer/ProtocolClient.ipc @@ -2,7 +2,8 @@ endpoint ProtocolClient = 13 { // Download notifications DownloadProgress(i32 download_id, Optional total_size, u32 downloaded_size) =| - DownloadFinished(i32 download_id, bool success, Optional status_code, u32 total_size, i32 shbuf_id, IPC::Dictionary response_headers) =| + DownloadFinished(i32 download_id, bool success, u32 total_size) =| + HeadersBecameAvailable(i32 download_id, IPC::Dictionary response_headers, Optional status_code) =| // Certificate requests CertificateRequested(i32 download_id) => () diff --git a/Services/ProtocolServer/ProtocolServer.ipc b/Services/ProtocolServer/ProtocolServer.ipc index 4cf1204520c..0707afb733f 100644 --- a/Services/ProtocolServer/ProtocolServer.ipc +++ b/Services/ProtocolServer/ProtocolServer.ipc @@ -3,14 +3,11 @@ endpoint ProtocolServer = 9 // Basic protocol Greet() => (i32 client_id) - // FIXME: It would be nice if the kernel provided a way to avoid this - DisownSharedBuffer(i32 shbuf_id) => () - // Test if a specific protocol is supported, e.g "http" IsSupportedProtocol(String protocol) => (bool supported) // Download API - StartDownload(String method, URL url, IPC::Dictionary request_headers, String request_body) => (i32 download_id) + StartDownload(String method, URL url, IPC::Dictionary request_headers, ByteBuffer request_body) => (i32 download_id, IPC::File response_fd) StopDownload(i32 download_id) => (bool success) SetCertificate(i32 download_id, String certificate, String key) => (bool success) } diff --git a/Services/ProtocolServer/main.cpp b/Services/ProtocolServer/main.cpp index 765dbe5056e..62fc908dd1b 100644 --- a/Services/ProtocolServer/main.cpp +++ b/Services/ProtocolServer/main.cpp @@ -35,7 +35,7 @@ int main(int, char**) { - if (pledge("stdio inet shared_buffer accept unix rpath cpath fattr", nullptr) < 0) { + if (pledge("stdio inet shared_buffer accept unix rpath cpath fattr sendfd recvfd", nullptr) < 0) { perror("pledge"); return 1; } @@ -45,7 +45,7 @@ int main(int, char**) Core::EventLoop event_loop; // FIXME: Establish a connection to LookupServer and then drop "unix"? - if (pledge("stdio inet shared_buffer accept unix", nullptr) < 0) { + if (pledge("stdio inet shared_buffer accept unix sendfd recvfd", nullptr) < 0) { perror("pledge"); return 1; } diff --git a/Userland/pro.cpp b/Userland/pro.cpp index 8cac03531af..3b350f142d2 100644 --- a/Userland/pro.cpp +++ b/Userland/pro.cpp @@ -24,6 +24,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ +#include #include #include #include @@ -116,29 +117,50 @@ private: bool m_might_be_wrong { false }; }; -static void do_write(ReadonlyBytes payload) -{ - size_t length_remaining = payload.size(); - size_t length_written = 0; - while (length_remaining > 0) { - auto nwritten = fwrite(payload.data() + length_written, sizeof(char), length_remaining, stdout); - if (nwritten > 0) { - length_remaining -= nwritten; - length_written += nwritten; - continue; - } +template +class ConditionalOutputFileStream final : public OutputFileStream { +public: + template + ConditionalOutputFileStream(ConditionT&& condition, Args... args) + : OutputFileStream(args...) + , m_condition(condition) + { + } - if (feof(stdout)) { - fprintf(stderr, "pro: unexpected eof while writing\n"); + ~ConditionalOutputFileStream() + { + if (!m_condition()) return; - } - if (ferror(stdout)) { - fprintf(stderr, "pro: error while writing\n"); - return; + if (!m_buffer.is_empty()) { + OutputFileStream::write(m_buffer); + m_buffer.clear(); } } -} + +private: + size_t write(ReadonlyBytes bytes) override + { + if (!m_condition()) { + write_to_buffer:; + m_buffer.append(bytes.data(), bytes.size()); + return bytes.size(); + } + + if (!m_buffer.is_empty()) { + auto size = OutputFileStream::write(m_buffer); + m_buffer = m_buffer.slice(size, m_buffer.size() - size); + } + + if (!m_buffer.is_empty()) + goto write_to_buffer; + + return OutputFileStream::write(bytes); + } + + ConditionT m_condition; + ByteBuffer m_buffer; +}; int main(int argc, char** argv) { @@ -195,6 +217,8 @@ int main(int argc, char** argv) timeval prev_time, current_time, time_diff; gettimeofday(&prev_time, nullptr); + bool received_actual_headers = false; + download->on_progress = [&](Optional maybe_total_size, u32 downloaded_size) { fprintf(stderr, "\r\033[2K"); if (maybe_total_size.has_value()) { @@ -215,10 +239,13 @@ int main(int argc, char** argv) previous_downloaded_size = downloaded_size; prev_time = current_time; }; - download->on_finish = [&](bool success, auto payload, auto, auto& response_headers, auto) { - fprintf(stderr, "\033]9;-1;\033\\"); - fprintf(stderr, "\n"); - if (success && save_at_provided_name) { + + if (save_at_provided_name) { + download->on_headers_received = [&](auto& response_headers, auto status_code) { + if (received_actual_headers) + return; + dbg() << "Received headers! response code = " << status_code.value_or(0); + received_actual_headers = true; // And not trailers! String output_name; if (auto content_disposition = response_headers.get("Content-Disposition"); content_disposition.has_value()) { auto& value = content_disposition.value(); @@ -245,17 +272,26 @@ int main(int argc, char** argv) if (freopen(output_name.characters(), "w", stdout) == nullptr) { perror("freopen"); - success = false; // oops! loop.quit(1); + return; } - } - if (success) - do_write(payload); - else + }; + } + download->on_finish = [&](bool success, auto) { + fprintf(stderr, "\033]9;-1;\033\\"); + fprintf(stderr, "\n"); + if (!success) fprintf(stderr, "Download failed :(\n"); loop.quit(0); }; + + auto output_stream = ConditionalOutputFileStream { [&] { return save_at_provided_name ? received_actual_headers : true; }, stdout }; + download->stream_into(output_stream); + dbgprintf("started download with id %d\n", download->id()); - return loop.exec(); + auto rc = loop.exec(); + // FIXME: This shouldn't be needed. + fclose(stdout); + return rc; }