mirror of
https://github.com/LadybirdBrowser/ladybird.git
synced 2024-11-21 23:20:20 +00:00
ProtocolServer: Stream the downloaded data if possible
This patchset makes ProtocolServer stream the downloads to its client (LibProtocol), and as such changes the download API; a possible download lifecycle could be as such: notation = client->server:'>', server->client:'<', pipe activity:'*' ``` > StartDownload(GET, url, headers, {}) < Response(0, fd 8) * {data, 1024b} < HeadersBecameAvailable(0, response_headers, 200) < DownloadProgress(0, 4K, 1024) * {data, 1024b} * {data, 1024b} < DownloadProgress(0, 4K, 2048) * {data, 1024b} < DownloadProgress(0, 4K, 1024) < DownloadFinished(0, true, 4K) ``` Since managing the received file descriptor is a pain, LibProtocol implements `Download::stream_into(OutputStream)`, which can be used to stream the download into any given output stream (be it a file, or memory, or writing stuff with a delay, etc.). Also, as some of the users of this API require all the downloaded data upfront, LibProtocol also implements `set_should_buffer_all_input()`, which causes the download instance to buffer all the data until the download is complete, and to call the `on_buffered_download_finish` hook.
This commit is contained in:
parent
36d642ee75
commit
4a2da10e38
Notes:
sideshowbarker
2024-07-19 00:23:12 +09:00
Author: https://github.com/alimpfard Commit: https://github.com/SerenityOS/serenity/commit/4a2da10e38c Pull-request: https://github.com/SerenityOS/serenity/pull/4526
55 changed files with 528 additions and 235 deletions
|
@ -29,6 +29,7 @@
|
|||
#include <AK/SharedBuffer.h>
|
||||
#include <AK/StringBuilder.h>
|
||||
#include <LibCore/File.h>
|
||||
#include <LibCore/FileStream.h>
|
||||
#include <LibCore/StandardPaths.h>
|
||||
#include <LibDesktop/Launcher.h>
|
||||
#include <LibGUI/BoxLayout.h>
|
||||
|
@ -61,9 +62,19 @@ DownloadWidget::DownloadWidget(const URL& url)
|
|||
m_download->on_progress = [this](Optional<u32> total_size, u32 downloaded_size) {
|
||||
did_progress(total_size.value(), downloaded_size);
|
||||
};
|
||||
m_download->on_finish = [this](bool success, auto payload, auto payload_storage, auto& response_headers, auto) {
|
||||
did_finish(success, payload, payload_storage, response_headers);
|
||||
};
|
||||
|
||||
{
|
||||
auto file_or_error = Core::File::open(m_destination_path, Core::IODevice::WriteOnly);
|
||||
if (file_or_error.is_error()) {
|
||||
GUI::MessageBox::show(window(), String::formatted("Cannot open {} for writing", m_destination_path), "Download failed", GUI::MessageBox::Type::Error);
|
||||
window()->close();
|
||||
return;
|
||||
}
|
||||
m_output_file_stream = make<Core::OutputFileStream>(*file_or_error.value());
|
||||
}
|
||||
|
||||
m_download->on_finish = [this](bool success, auto) { did_finish(success); };
|
||||
m_download->stream_into(*m_output_file_stream);
|
||||
|
||||
set_fill_with_background_color(true);
|
||||
auto& layout = set_layout<GUI::VerticalBoxLayout>();
|
||||
|
@ -149,7 +160,7 @@ void DownloadWidget::did_progress(Optional<u32> total_size, u32 downloaded_size)
|
|||
}
|
||||
}
|
||||
|
||||
void DownloadWidget::did_finish(bool success, [[maybe_unused]] ReadonlyBytes payload, [[maybe_unused]] RefPtr<SharedBuffer> payload_storage, [[maybe_unused]] const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)
|
||||
void DownloadWidget::did_finish(bool success)
|
||||
{
|
||||
dbg() << "did_finish, success=" << success;
|
||||
|
||||
|
@ -166,17 +177,6 @@ void DownloadWidget::did_finish(bool success, [[maybe_unused]] ReadonlyBytes pay
|
|||
window()->close();
|
||||
return;
|
||||
}
|
||||
|
||||
auto file_or_error = Core::File::open(m_destination_path, Core::IODevice::WriteOnly);
|
||||
if (file_or_error.is_error()) {
|
||||
GUI::MessageBox::show(window(), String::formatted("Cannot open {} for writing", m_destination_path), "Download failed", GUI::MessageBox::Type::Error);
|
||||
window()->close();
|
||||
return;
|
||||
}
|
||||
|
||||
auto& file = *file_or_error.value();
|
||||
bool write_success = file.write(payload.data(), payload.size());
|
||||
ASSERT(write_success);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
|
||||
#include <AK/URL.h>
|
||||
#include <LibCore/ElapsedTimer.h>
|
||||
#include <LibCore/FileStream.h>
|
||||
#include <LibGUI/ProgressBar.h>
|
||||
#include <LibGUI/Widget.h>
|
||||
#include <LibProtocol/Download.h>
|
||||
|
@ -44,7 +45,7 @@ private:
|
|||
explicit DownloadWidget(const URL&);
|
||||
|
||||
void did_progress(Optional<u32> total_size, u32 downloaded_size);
|
||||
void did_finish(bool success, ReadonlyBytes payload, RefPtr<SharedBuffer> payload_storage, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers);
|
||||
void did_finish(bool success);
|
||||
|
||||
URL m_url;
|
||||
String m_destination_path;
|
||||
|
@ -53,6 +54,7 @@ private:
|
|||
RefPtr<GUI::Label> m_progress_label;
|
||||
RefPtr<GUI::Button> m_cancel_button;
|
||||
RefPtr<GUI::Button> m_close_button;
|
||||
OwnPtr<Core::OutputFileStream> m_output_file_stream;
|
||||
Core::ElapsedTimer m_elapsed_timer;
|
||||
};
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ int main(int argc, char** argv)
|
|||
return 1;
|
||||
}
|
||||
|
||||
if (pledge("stdio shared_buffer accept unix cpath rpath wpath fattr", nullptr) < 0) {
|
||||
if (pledge("stdio shared_buffer accept unix cpath rpath wpath fattr sendfd recvfd", nullptr) < 0) {
|
||||
perror("pledge");
|
||||
return 1;
|
||||
}
|
||||
|
@ -86,7 +86,7 @@ int main(int argc, char** argv)
|
|||
Web::ResourceLoader::the();
|
||||
|
||||
// FIXME: Once there is a standalone Download Manager, we can drop the "unix" pledge.
|
||||
if (pledge("stdio shared_buffer accept unix cpath rpath wpath", nullptr) < 0) {
|
||||
if (pledge("stdio shared_buffer accept unix cpath rpath wpath sendfd recvfd", nullptr) < 0) {
|
||||
perror("pledge");
|
||||
return 1;
|
||||
}
|
||||
|
|
|
@ -32,7 +32,8 @@
|
|||
|
||||
namespace Core {
|
||||
|
||||
NetworkJob::NetworkJob()
|
||||
NetworkJob::NetworkJob(OutputStream& output_stream)
|
||||
: m_output_stream(output_stream)
|
||||
{
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <AK/Function.h>
|
||||
#include <AK/Stream.h>
|
||||
#include <LibCore/Object.h>
|
||||
|
||||
namespace Core {
|
||||
|
@ -43,6 +44,8 @@ public:
|
|||
};
|
||||
virtual ~NetworkJob() override;
|
||||
|
||||
// Could fire twice, after Headers and after Trailers!
|
||||
Function<void(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code)> on_headers_received;
|
||||
Function<void(bool success)> on_finish;
|
||||
Function<void(Optional<u32>, u32)> on_progress;
|
||||
|
||||
|
@ -62,13 +65,16 @@ public:
|
|||
}
|
||||
|
||||
protected:
|
||||
NetworkJob();
|
||||
NetworkJob(OutputStream&);
|
||||
void did_finish(NonnullRefPtr<NetworkResponse>&&);
|
||||
void did_fail(Error);
|
||||
void did_progress(Optional<u32> total_size, u32 downloaded);
|
||||
|
||||
size_t do_write(ReadonlyBytes bytes) { return m_output_stream.write(bytes); }
|
||||
|
||||
private:
|
||||
RefPtr<NetworkResponse> m_response;
|
||||
OutputStream& m_output_stream;
|
||||
Error m_error { Error::None };
|
||||
};
|
||||
|
||||
|
|
|
@ -28,8 +28,7 @@
|
|||
|
||||
namespace Core {
|
||||
|
||||
NetworkResponse::NetworkResponse(ByteBuffer&& payload)
|
||||
: m_payload(payload)
|
||||
NetworkResponse::NetworkResponse()
|
||||
{
|
||||
}
|
||||
|
||||
|
|
|
@ -36,13 +36,11 @@ public:
|
|||
virtual ~NetworkResponse();
|
||||
|
||||
bool is_error() const { return m_error; }
|
||||
const ByteBuffer& payload() const { return m_payload; }
|
||||
|
||||
protected:
|
||||
explicit NetworkResponse(ByteBuffer&&);
|
||||
explicit NetworkResponse();
|
||||
|
||||
bool m_error { false };
|
||||
ByteBuffer m_payload;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -142,9 +142,9 @@ bool GeminiJob::eof() const
|
|||
return m_socket->eof();
|
||||
}
|
||||
|
||||
bool GeminiJob::write(const ByteBuffer& data)
|
||||
bool GeminiJob::write(ReadonlyBytes bytes)
|
||||
{
|
||||
return m_socket->write(data);
|
||||
return m_socket->write(bytes);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -37,8 +37,8 @@ namespace Gemini {
|
|||
class GeminiJob final : public Job {
|
||||
C_OBJECT(GeminiJob)
|
||||
public:
|
||||
explicit GeminiJob(const GeminiRequest& request, const Vector<Certificate>* override_certificates = nullptr)
|
||||
: Job(request)
|
||||
explicit GeminiJob(const GeminiRequest& request, OutputStream& output_stream, const Vector<Certificate>* override_certificates = nullptr)
|
||||
: Job(request, output_stream)
|
||||
, m_override_ca_certificates(override_certificates)
|
||||
{
|
||||
}
|
||||
|
@ -61,7 +61,7 @@ protected:
|
|||
virtual bool can_read() const override;
|
||||
virtual ByteBuffer receive(size_t) override;
|
||||
virtual bool eof() const override;
|
||||
virtual bool write(const ByteBuffer&) override;
|
||||
virtual bool write(ReadonlyBytes) override;
|
||||
virtual bool is_established() const override { return m_socket->is_established(); }
|
||||
virtual bool should_fail_on_empty_payload() const override { return false; }
|
||||
virtual void read_while_data_available(Function<IterationDecision()>) override;
|
||||
|
|
|
@ -28,9 +28,8 @@
|
|||
|
||||
namespace Gemini {
|
||||
|
||||
GeminiResponse::GeminiResponse(int status, String meta, ByteBuffer&& payload)
|
||||
: Core::NetworkResponse(move(payload))
|
||||
, m_status(status)
|
||||
GeminiResponse::GeminiResponse(int status, String meta)
|
||||
: m_status(status)
|
||||
, m_meta(meta)
|
||||
{
|
||||
}
|
||||
|
|
|
@ -34,16 +34,16 @@ namespace Gemini {
|
|||
class GeminiResponse : public Core::NetworkResponse {
|
||||
public:
|
||||
virtual ~GeminiResponse() override;
|
||||
static NonnullRefPtr<GeminiResponse> create(int status, String meta, ByteBuffer&& payload)
|
||||
static NonnullRefPtr<GeminiResponse> create(int status, String meta)
|
||||
{
|
||||
return adopt(*new GeminiResponse(status, meta, move(payload)));
|
||||
return adopt(*new GeminiResponse(status, meta));
|
||||
}
|
||||
|
||||
int status() const { return m_status; }
|
||||
String meta() const { return m_meta; }
|
||||
|
||||
private:
|
||||
GeminiResponse(int status, String, ByteBuffer&&);
|
||||
GeminiResponse(int status, String);
|
||||
|
||||
int m_status { 0 };
|
||||
String m_meta;
|
||||
|
|
|
@ -33,8 +33,9 @@
|
|||
|
||||
namespace Gemini {
|
||||
|
||||
Job::Job(const GeminiRequest& request)
|
||||
: m_request(request)
|
||||
Job::Job(const GeminiRequest& request, OutputStream& output_stream)
|
||||
: Core::NetworkJob(output_stream)
|
||||
, m_request(request)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -42,6 +43,23 @@ Job::~Job()
|
|||
{
|
||||
}
|
||||
|
||||
void Job::flush_received_buffers()
|
||||
{
|
||||
for (size_t i = 0; i < m_received_buffers.size(); ++i) {
|
||||
auto& payload = m_received_buffers[i];
|
||||
auto written = do_write(payload);
|
||||
m_received_size -= written;
|
||||
if (written == payload.size()) {
|
||||
// FIXME: Make this a take-first-friendly object?
|
||||
m_received_buffers.take_first();
|
||||
continue;
|
||||
}
|
||||
ASSERT(written < payload.size());
|
||||
payload = payload.slice(written, payload.size() - written);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void Job::on_socket_connected()
|
||||
{
|
||||
register_on_ready_to_write([this] {
|
||||
|
@ -126,6 +144,7 @@ void Job::on_socket_connected()
|
|||
|
||||
m_received_buffers.append(payload);
|
||||
m_received_size += payload.size();
|
||||
flush_received_buffers();
|
||||
|
||||
deferred_invoke([this](auto&) { did_progress({}, m_received_size); });
|
||||
|
||||
|
@ -144,15 +163,17 @@ void Job::on_socket_connected()
|
|||
void Job::finish_up()
|
||||
{
|
||||
m_state = State::Finished;
|
||||
auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size);
|
||||
u8* flat_ptr = flattened_buffer.data();
|
||||
for (auto& received_buffer : m_received_buffers) {
|
||||
memcpy(flat_ptr, received_buffer.data(), received_buffer.size());
|
||||
flat_ptr += received_buffer.size();
|
||||
flush_received_buffers();
|
||||
if (m_received_size != 0) {
|
||||
// FIXME: What do we do? ignore it?
|
||||
// "Transmission failed" is not strictly correct, but let's roll with it for now.
|
||||
deferred_invoke([this](auto&) {
|
||||
did_fail(Error::TransmissionFailed);
|
||||
});
|
||||
return;
|
||||
}
|
||||
m_received_buffers.clear();
|
||||
|
||||
auto response = GeminiResponse::create(m_status, m_meta, move(flattened_buffer));
|
||||
auto response = GeminiResponse::create(m_status, m_meta);
|
||||
deferred_invoke([this, response](auto&) {
|
||||
did_finish(move(response));
|
||||
});
|
||||
|
|
|
@ -36,7 +36,7 @@ namespace Gemini {
|
|||
|
||||
class Job : public Core::NetworkJob {
|
||||
public:
|
||||
explicit Job(const GeminiRequest&);
|
||||
explicit Job(const GeminiRequest&, OutputStream&);
|
||||
virtual ~Job() override;
|
||||
|
||||
virtual void start() override = 0;
|
||||
|
@ -48,6 +48,7 @@ public:
|
|||
protected:
|
||||
void finish_up();
|
||||
void on_socket_connected();
|
||||
void flush_received_buffers();
|
||||
virtual void register_on_ready_to_read(Function<void()>) = 0;
|
||||
virtual void register_on_ready_to_write(Function<void()>) = 0;
|
||||
virtual bool can_read_line() const = 0;
|
||||
|
@ -55,7 +56,7 @@ protected:
|
|||
virtual bool can_read() const = 0;
|
||||
virtual ByteBuffer receive(size_t) = 0;
|
||||
virtual bool eof() const = 0;
|
||||
virtual bool write(const ByteBuffer&) = 0;
|
||||
virtual bool write(ReadonlyBytes) = 0;
|
||||
virtual bool is_established() const = 0;
|
||||
virtual bool should_fail_on_empty_payload() const { return false; }
|
||||
virtual void read_while_data_available(Function<IterationDecision()> read) { read(); };
|
||||
|
@ -70,7 +71,7 @@ protected:
|
|||
State m_state { State::InStatus };
|
||||
int m_status { -1 };
|
||||
String m_meta;
|
||||
Vector<ByteBuffer> m_received_buffers;
|
||||
Vector<ByteBuffer, 2> m_received_buffers;
|
||||
size_t m_received_size { 0 };
|
||||
bool m_sent_data { false };
|
||||
bool m_should_have_payload { false };
|
||||
|
|
|
@ -98,9 +98,9 @@ bool HttpJob::eof() const
|
|||
return m_socket->eof();
|
||||
}
|
||||
|
||||
bool HttpJob::write(const ByteBuffer& data)
|
||||
bool HttpJob::write(ReadonlyBytes bytes)
|
||||
{
|
||||
return m_socket->write(data);
|
||||
return m_socket->write(bytes);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -38,8 +38,8 @@ namespace HTTP {
|
|||
class HttpJob final : public Job {
|
||||
C_OBJECT(HttpJob)
|
||||
public:
|
||||
explicit HttpJob(const HttpRequest& request)
|
||||
: Job(request)
|
||||
explicit HttpJob(const HttpRequest& request, OutputStream& output_stream)
|
||||
: Job(request, output_stream)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -59,7 +59,7 @@ protected:
|
|||
virtual bool can_read() const override;
|
||||
virtual ByteBuffer receive(size_t) override;
|
||||
virtual bool eof() const override;
|
||||
virtual bool write(const ByteBuffer&) override;
|
||||
virtual bool write(ReadonlyBytes) override;
|
||||
virtual bool is_established() const override { return true; }
|
||||
|
||||
private:
|
||||
|
|
|
@ -71,11 +71,12 @@ ByteBuffer HttpRequest::to_raw_request() const
|
|||
builder.append(header.value);
|
||||
builder.append("\r\n");
|
||||
}
|
||||
builder.append("Connection: close\r\n\r\n");
|
||||
builder.append("Connection: close\r\n");
|
||||
if (!m_body.is_empty()) {
|
||||
builder.appendff("Content-Length: {}\r\n\r\n", m_body.size());
|
||||
builder.append((const char*)m_body.data(), m_body.size());
|
||||
builder.append("\r\n");
|
||||
}
|
||||
builder.append("\r\n");
|
||||
return builder.to_byte_buffer();
|
||||
}
|
||||
|
||||
|
|
|
@ -62,7 +62,8 @@ public:
|
|||
void set_method(Method method) { m_method = method; }
|
||||
|
||||
const ByteBuffer& body() const { return m_body; }
|
||||
void set_body(const ByteBuffer& body) { m_body = body; }
|
||||
void set_body(ReadonlyBytes body) { m_body = ByteBuffer::copy(body); }
|
||||
void set_body(ByteBuffer&& body) { m_body = move(body); }
|
||||
|
||||
String method_name() const;
|
||||
ByteBuffer to_raw_request() const;
|
||||
|
|
|
@ -28,9 +28,8 @@
|
|||
|
||||
namespace HTTP {
|
||||
|
||||
HttpResponse::HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers, ByteBuffer&& payload)
|
||||
: Core::NetworkResponse(move(payload))
|
||||
, m_code(code)
|
||||
HttpResponse::HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers)
|
||||
: m_code(code)
|
||||
, m_headers(move(headers))
|
||||
{
|
||||
}
|
||||
|
|
|
@ -35,16 +35,16 @@ namespace HTTP {
|
|||
class HttpResponse : public Core::NetworkResponse {
|
||||
public:
|
||||
virtual ~HttpResponse() override;
|
||||
static NonnullRefPtr<HttpResponse> create(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers, ByteBuffer&& payload)
|
||||
static NonnullRefPtr<HttpResponse> create(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers)
|
||||
{
|
||||
return adopt(*new HttpResponse(code, move(headers), move(payload)));
|
||||
return adopt(*new HttpResponse(code, move(headers)));
|
||||
}
|
||||
|
||||
int code() const { return m_code; }
|
||||
const HashMap<String, String, CaseInsensitiveStringTraits>& headers() const { return m_headers; }
|
||||
|
||||
private:
|
||||
HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&&, ByteBuffer&&);
|
||||
HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&&);
|
||||
|
||||
int m_code { 0 };
|
||||
HashMap<String, String, CaseInsensitiveStringTraits> m_headers;
|
||||
|
|
|
@ -143,7 +143,7 @@ bool HttpsJob::eof() const
|
|||
return m_socket->eof();
|
||||
}
|
||||
|
||||
bool HttpsJob::write(const ByteBuffer& data)
|
||||
bool HttpsJob::write(ReadonlyBytes data)
|
||||
{
|
||||
return m_socket->write(data);
|
||||
}
|
||||
|
|
|
@ -38,8 +38,8 @@ namespace HTTP {
|
|||
class HttpsJob final : public Job {
|
||||
C_OBJECT(HttpsJob)
|
||||
public:
|
||||
explicit HttpsJob(const HttpRequest& request, const Vector<Certificate>* override_certs = nullptr)
|
||||
: Job(request)
|
||||
explicit HttpsJob(const HttpRequest& request, OutputStream& output_stream, const Vector<Certificate>* override_certs = nullptr)
|
||||
: Job(request, output_stream)
|
||||
, m_override_ca_certificates(override_certs)
|
||||
{
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ protected:
|
|||
virtual bool can_read() const override;
|
||||
virtual ByteBuffer receive(size_t) override;
|
||||
virtual bool eof() const override;
|
||||
virtual bool write(const ByteBuffer&) override;
|
||||
virtual bool write(ReadonlyBytes) override;
|
||||
virtual bool is_established() const override { return m_socket->is_established(); }
|
||||
virtual bool should_fail_on_empty_payload() const override { return false; }
|
||||
virtual void read_while_data_available(Function<IterationDecision()>) override;
|
||||
|
|
|
@ -68,8 +68,9 @@ static ByteBuffer handle_content_encoding(const ByteBuffer& buf, const String& c
|
|||
return buf;
|
||||
}
|
||||
|
||||
Job::Job(const HttpRequest& request)
|
||||
: m_request(request)
|
||||
Job::Job(const HttpRequest& request, OutputStream& output_stream)
|
||||
: Core::NetworkJob(output_stream)
|
||||
, m_request(request)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -77,6 +78,35 @@ Job::~Job()
|
|||
{
|
||||
}
|
||||
|
||||
void Job::flush_received_buffers()
|
||||
{
|
||||
if (!m_can_stream_response || m_buffered_size == 0)
|
||||
return;
|
||||
#ifdef JOB_DEBUG
|
||||
dbg() << "Job: Flushing received buffers: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers";
|
||||
#endif
|
||||
for (size_t i = 0; i < m_received_buffers.size(); ++i) {
|
||||
auto& payload = m_received_buffers[i];
|
||||
auto written = do_write(payload);
|
||||
m_buffered_size -= written;
|
||||
if (written == payload.size()) {
|
||||
// FIXME: Make this a take-first-friendly object?
|
||||
m_received_buffers.take_first();
|
||||
--i;
|
||||
continue;
|
||||
}
|
||||
ASSERT(written < payload.size());
|
||||
payload = payload.slice(written, payload.size() - written);
|
||||
#ifdef JOB_DEBUG
|
||||
dbg() << "Job: Flushing received buffers done: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers";
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
#ifdef JOB_DEBUG
|
||||
dbg() << "Job: Flushing received buffers done: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers";
|
||||
#endif
|
||||
}
|
||||
|
||||
void Job::on_socket_connected()
|
||||
{
|
||||
register_on_ready_to_write([&] {
|
||||
|
@ -135,6 +165,8 @@ void Job::on_socket_connected()
|
|||
if (m_state == State::Trailers) {
|
||||
return finish_up();
|
||||
} else {
|
||||
if (on_headers_received)
|
||||
on_headers_received(m_headers, m_code > 0 ? m_code : Optional<u32> {});
|
||||
m_state = State::InBody;
|
||||
}
|
||||
return;
|
||||
|
@ -163,6 +195,13 @@ void Job::on_socket_connected()
|
|||
}
|
||||
auto value = line.substring(name.length() + 2, line.length() - name.length() - 2);
|
||||
m_headers.set(name, value);
|
||||
if (name.equals_ignoring_case("Content-Encoding")) {
|
||||
// Assume that any content-encoding means that we can't decode it as a stream :(
|
||||
#ifdef JOB_DEBUG
|
||||
dbg() << "Content-Encoding " << value << " detected, cannot stream output :(";
|
||||
#endif
|
||||
m_can_stream_response = false;
|
||||
}
|
||||
#ifdef JOB_DEBUG
|
||||
dbg() << "Job: [" << name << "] = '" << value << "'";
|
||||
#endif
|
||||
|
@ -252,7 +291,9 @@ void Job::on_socket_connected()
|
|||
}
|
||||
|
||||
m_received_buffers.append(payload);
|
||||
m_buffered_size += payload.size();
|
||||
m_received_size += payload.size();
|
||||
flush_received_buffers();
|
||||
|
||||
if (m_current_chunk_remaining_size.has_value()) {
|
||||
auto size = m_current_chunk_remaining_size.value() - payload.size();
|
||||
|
@ -313,20 +354,37 @@ void Job::on_socket_connected()
|
|||
void Job::finish_up()
|
||||
{
|
||||
m_state = State::Finished;
|
||||
auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size);
|
||||
u8* flat_ptr = flattened_buffer.data();
|
||||
for (auto& received_buffer : m_received_buffers) {
|
||||
memcpy(flat_ptr, received_buffer.data(), received_buffer.size());
|
||||
flat_ptr += received_buffer.size();
|
||||
}
|
||||
m_received_buffers.clear();
|
||||
if (!m_can_stream_response) {
|
||||
auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size);
|
||||
u8* flat_ptr = flattened_buffer.data();
|
||||
for (auto& received_buffer : m_received_buffers) {
|
||||
memcpy(flat_ptr, received_buffer.data(), received_buffer.size());
|
||||
flat_ptr += received_buffer.size();
|
||||
}
|
||||
m_received_buffers.clear();
|
||||
|
||||
auto content_encoding = m_headers.get("Content-Encoding");
|
||||
if (content_encoding.has_value()) {
|
||||
flattened_buffer = handle_content_encoding(flattened_buffer, content_encoding.value());
|
||||
// For the time being, we cannot stream stuff with content-encoding set to _anything_.
|
||||
auto content_encoding = m_headers.get("Content-Encoding");
|
||||
if (content_encoding.has_value()) {
|
||||
flattened_buffer = handle_content_encoding(flattened_buffer, content_encoding.value());
|
||||
}
|
||||
|
||||
m_buffered_size = flattened_buffer.size();
|
||||
m_received_buffers.append(move(flattened_buffer));
|
||||
m_can_stream_response = true;
|
||||
}
|
||||
|
||||
auto response = HttpResponse::create(m_code, move(m_headers), move(flattened_buffer));
|
||||
flush_received_buffers();
|
||||
if (m_buffered_size != 0) {
|
||||
// FIXME: What do we do? ignore it?
|
||||
// "Transmission failed" is not strictly correct, but let's roll with it for now.
|
||||
deferred_invoke([this](auto&) {
|
||||
did_fail(Error::TransmissionFailed);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
auto response = HttpResponse::create(m_code, move(m_headers));
|
||||
deferred_invoke([this, response](auto&) {
|
||||
did_finish(move(response));
|
||||
});
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <AK/FileStream.h>
|
||||
#include <AK/HashMap.h>
|
||||
#include <AK/Optional.h>
|
||||
#include <LibCore/NetworkJob.h>
|
||||
|
@ -37,7 +38,7 @@ namespace HTTP {
|
|||
|
||||
class Job : public Core::NetworkJob {
|
||||
public:
|
||||
explicit Job(const HttpRequest&);
|
||||
explicit Job(const HttpRequest&, OutputStream&);
|
||||
virtual ~Job() override;
|
||||
|
||||
virtual void start() override = 0;
|
||||
|
@ -49,6 +50,7 @@ public:
|
|||
protected:
|
||||
void finish_up();
|
||||
void on_socket_connected();
|
||||
void flush_received_buffers();
|
||||
virtual void register_on_ready_to_read(Function<void()>) = 0;
|
||||
virtual void register_on_ready_to_write(Function<void()>) = 0;
|
||||
virtual bool can_read_line() const = 0;
|
||||
|
@ -56,7 +58,7 @@ protected:
|
|||
virtual bool can_read() const = 0;
|
||||
virtual ByteBuffer receive(size_t) = 0;
|
||||
virtual bool eof() const = 0;
|
||||
virtual bool write(const ByteBuffer&) = 0;
|
||||
virtual bool write(ReadonlyBytes) = 0;
|
||||
virtual bool is_established() const = 0;
|
||||
virtual bool should_fail_on_empty_payload() const { return true; }
|
||||
virtual void read_while_data_available(Function<IterationDecision()> read) { read(); };
|
||||
|
@ -73,11 +75,13 @@ protected:
|
|||
State m_state { State::InStatus };
|
||||
int m_code { -1 };
|
||||
HashMap<String, String, CaseInsensitiveStringTraits> m_headers;
|
||||
Vector<ByteBuffer> m_received_buffers;
|
||||
Vector<ByteBuffer, 2> m_received_buffers;
|
||||
size_t m_buffered_size { 0 };
|
||||
size_t m_received_size { 0 };
|
||||
bool m_sent_data { 0 };
|
||||
Optional<ssize_t> m_current_chunk_remaining_size;
|
||||
Optional<size_t> m_current_chunk_total_size;
|
||||
bool m_can_stream_response { true };
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
|
||||
#include <AK/FileStream.h>
|
||||
#include <AK/SharedBuffer.h>
|
||||
#include <LibProtocol/Client.h>
|
||||
#include <LibProtocol/Download.h>
|
||||
|
@ -47,16 +48,20 @@ bool Client::is_supported_protocol(const String& protocol)
|
|||
return send_sync<Messages::ProtocolServer::IsSupportedProtocol>(protocol)->supported();
|
||||
}
|
||||
|
||||
RefPtr<Download> Client::start_download(const String& method, const String& url, const HashMap<String, String>& request_headers, const ByteBuffer& request_body)
|
||||
template<typename RequestHashMapTraits>
|
||||
RefPtr<Download> Client::start_download(const String& method, const String& url, const HashMap<String, String, RequestHashMapTraits>& request_headers, ReadonlyBytes request_body)
|
||||
{
|
||||
IPC::Dictionary header_dictionary;
|
||||
for (auto& it : request_headers)
|
||||
header_dictionary.add(it.key, it.value);
|
||||
|
||||
i32 download_id = send_sync<Messages::ProtocolServer::StartDownload>(method, url, header_dictionary, String::copy(request_body))->download_id();
|
||||
if (download_id < 0)
|
||||
auto response = send_sync<Messages::ProtocolServer::StartDownload>(method, url, header_dictionary, ByteBuffer::copy(request_body));
|
||||
auto download_id = response->download_id();
|
||||
auto response_fd = response->response_fd().fd();
|
||||
if (download_id < 0 || response_fd < 0)
|
||||
return nullptr;
|
||||
auto download = Download::create_from_id({}, *this, download_id);
|
||||
download->set_download_fd({}, response_fd);
|
||||
m_downloads.set(download_id, download);
|
||||
return download;
|
||||
}
|
||||
|
@ -79,9 +84,8 @@ void Client::handle(const Messages::ProtocolClient::DownloadFinished& message)
|
|||
{
|
||||
RefPtr<Download> download;
|
||||
if ((download = m_downloads.get(message.download_id()).value_or(nullptr))) {
|
||||
download->did_finish({}, message.success(), message.status_code(), message.total_size(), message.shbuf_id(), message.response_headers());
|
||||
download->did_finish({}, message.success(), message.total_size());
|
||||
}
|
||||
send_sync<Messages::ProtocolServer::DisownSharedBuffer>(message.shbuf_id());
|
||||
m_downloads.remove(message.download_id());
|
||||
}
|
||||
|
||||
|
@ -92,6 +96,15 @@ void Client::handle(const Messages::ProtocolClient::DownloadProgress& message)
|
|||
}
|
||||
}
|
||||
|
||||
void Client::handle(const Messages::ProtocolClient::HeadersBecameAvailable& message)
|
||||
{
|
||||
if (auto download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr))) {
|
||||
HashMap<String, String, CaseInsensitiveStringTraits> headers;
|
||||
message.response_headers().for_each_entry([&](auto& name, auto& value) { headers.set(name, value); });
|
||||
download->did_receive_headers({}, headers, message.status_code());
|
||||
}
|
||||
}
|
||||
|
||||
OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> Client::handle(const Messages::ProtocolClient::CertificateRequested& message)
|
||||
{
|
||||
if (auto download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr))) {
|
||||
|
@ -102,3 +115,6 @@ OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> Client::handle(co
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
template RefPtr<Protocol::Download> Protocol::Client::start_download(const String& method, const String& url, const HashMap<String, String>& request_headers, ReadonlyBytes request_body);
|
||||
template RefPtr<Protocol::Download> Protocol::Client::start_download(const String& method, const String& url, const HashMap<String, String, CaseInsensitiveStringTraits>& request_headers, ReadonlyBytes request_body);
|
||||
|
|
|
@ -44,7 +44,8 @@ public:
|
|||
virtual void handshake() override;
|
||||
|
||||
bool is_supported_protocol(const String&);
|
||||
RefPtr<Download> start_download(const String& method, const String& url, const HashMap<String, String>& request_headers = {}, const ByteBuffer& request_body = {});
|
||||
template<typename RequestHashMapTraits = Traits<String>>
|
||||
RefPtr<Download> start_download(const String& method, const String& url, const HashMap<String, String, RequestHashMapTraits>& request_headers = {}, ReadonlyBytes request_body = {});
|
||||
|
||||
bool stop_download(Badge<Download>, Download&);
|
||||
bool set_certificate(Badge<Download>, Download&, String, String);
|
||||
|
@ -55,6 +56,7 @@ private:
|
|||
virtual void handle(const Messages::ProtocolClient::DownloadProgress&) override;
|
||||
virtual void handle(const Messages::ProtocolClient::DownloadFinished&) override;
|
||||
virtual OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> handle(const Messages::ProtocolClient::CertificateRequested&) override;
|
||||
virtual void handle(const Messages::ProtocolClient::HeadersBecameAvailable&) override;
|
||||
|
||||
HashMap<i32, RefPtr<Download>> m_downloads;
|
||||
};
|
||||
|
|
|
@ -41,25 +41,81 @@ bool Download::stop()
|
|||
return m_client->stop_download({}, *this);
|
||||
}
|
||||
|
||||
void Download::did_finish(Badge<Client>, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers)
|
||||
void Download::stream_into(OutputStream& stream)
|
||||
{
|
||||
ASSERT(!m_internal_stream_data);
|
||||
|
||||
auto notifier = Core::Notifier::construct(fd(), Core::Notifier::Read);
|
||||
|
||||
m_internal_stream_data = make<InternalStreamData>(fd());
|
||||
m_internal_stream_data->read_notifier = notifier;
|
||||
|
||||
auto user_on_finish = move(on_finish);
|
||||
on_finish = [this](auto success, auto total_size) {
|
||||
m_internal_stream_data->success = success;
|
||||
m_internal_stream_data->total_size = total_size;
|
||||
m_internal_stream_data->download_done = true;
|
||||
};
|
||||
|
||||
notifier->on_ready_to_read = [this, &stream, user_on_finish = move(user_on_finish)] {
|
||||
constexpr size_t buffer_size = 1 * KiB;
|
||||
static char buf[buffer_size];
|
||||
auto nread = m_internal_stream_data->read_stream.read({ buf, buffer_size });
|
||||
if (!stream.write_or_error({ buf, nread })) {
|
||||
// FIXME: What do we do here?
|
||||
TODO();
|
||||
}
|
||||
|
||||
if (m_internal_stream_data->read_stream.eof() || (m_internal_stream_data->download_done && !m_internal_stream_data->success)) {
|
||||
m_internal_stream_data->read_notifier->close();
|
||||
user_on_finish(m_internal_stream_data->success, m_internal_stream_data->total_size);
|
||||
} else {
|
||||
m_internal_stream_data->read_stream.handle_any_error();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
void Download::set_should_buffer_all_input(bool value)
|
||||
{
|
||||
if (m_should_buffer_all_input == value)
|
||||
return;
|
||||
|
||||
if (m_internal_buffered_data && !value) {
|
||||
m_internal_buffered_data = nullptr;
|
||||
m_should_buffer_all_input = false;
|
||||
return;
|
||||
}
|
||||
|
||||
ASSERT(!m_internal_stream_data);
|
||||
ASSERT(!m_internal_buffered_data);
|
||||
ASSERT(on_buffered_download_finish); // Not having this set makes no sense.
|
||||
m_internal_buffered_data = make<InternalBufferedData>(fd());
|
||||
m_should_buffer_all_input = true;
|
||||
|
||||
on_headers_received = [this](auto& headers, auto response_code) {
|
||||
m_internal_buffered_data->response_headers = headers;
|
||||
m_internal_buffered_data->response_code = move(response_code);
|
||||
};
|
||||
|
||||
on_finish = [this](auto success, u32 total_size) {
|
||||
auto output_buffer = m_internal_buffered_data->payload_stream.copy_into_contiguous_buffer();
|
||||
on_buffered_download_finish(
|
||||
success,
|
||||
total_size,
|
||||
m_internal_buffered_data->response_headers,
|
||||
m_internal_buffered_data->response_code,
|
||||
output_buffer);
|
||||
};
|
||||
|
||||
stream_into(m_internal_buffered_data->payload_stream);
|
||||
}
|
||||
|
||||
void Download::did_finish(Badge<Client>, bool success, u32 total_size)
|
||||
{
|
||||
if (!on_finish)
|
||||
return;
|
||||
|
||||
ReadonlyBytes payload;
|
||||
RefPtr<SharedBuffer> shared_buffer;
|
||||
if (success && shbuf_id != -1) {
|
||||
shared_buffer = SharedBuffer::create_from_shbuf_id(shbuf_id);
|
||||
payload = { shared_buffer->data<void>(), total_size };
|
||||
}
|
||||
|
||||
// FIXME: It's a bit silly that we copy the response headers here just so we can move them into a HashMap with different traits.
|
||||
HashMap<String, String, CaseInsensitiveStringTraits> caseless_response_headers;
|
||||
response_headers.for_each_entry([&](auto& name, auto& value) {
|
||||
caseless_response_headers.set(name, value);
|
||||
});
|
||||
|
||||
on_finish(success, payload, move(shared_buffer), caseless_response_headers, status_code);
|
||||
on_finish(success, total_size);
|
||||
}
|
||||
|
||||
void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size)
|
||||
|
@ -68,6 +124,12 @@ void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloa
|
|||
on_progress(total_size, downloaded_size);
|
||||
}
|
||||
|
||||
void Download::did_receive_headers(Badge<Client>, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code)
|
||||
{
|
||||
if (on_headers_received)
|
||||
on_headers_received(response_headers, response_code);
|
||||
}
|
||||
|
||||
void Download::did_request_certificates(Badge<Client>)
|
||||
{
|
||||
if (on_certificate_requested) {
|
||||
|
|
|
@ -28,10 +28,13 @@
|
|||
|
||||
#include <AK/Badge.h>
|
||||
#include <AK/ByteBuffer.h>
|
||||
#include <AK/FileStream.h>
|
||||
#include <AK/Function.h>
|
||||
#include <AK/MemoryStream.h>
|
||||
#include <AK/RefCounted.h>
|
||||
#include <AK/String.h>
|
||||
#include <AK/WeakPtr.h>
|
||||
#include <LibCore/Notifier.h>
|
||||
#include <LibIPC/Forward.h>
|
||||
|
||||
namespace Protocol {
|
||||
|
@ -51,20 +54,65 @@ public:
|
|||
}
|
||||
|
||||
int id() const { return m_download_id; }
|
||||
int fd() const { return m_fd; }
|
||||
bool stop();
|
||||
|
||||
Function<void(bool success, ReadonlyBytes payload, RefPtr<SharedBuffer> payload_storage, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> status_code)> on_finish;
|
||||
void stream_into(OutputStream&);
|
||||
|
||||
bool should_buffer_all_input() const { return m_should_buffer_all_input; }
|
||||
/// Note: Will override `on_finish', and `on_headers_received', and expects `on_buffered_download_finish' to be set!
|
||||
void set_should_buffer_all_input(bool);
|
||||
|
||||
/// Note: Must be set before `set_should_buffer_all_input(true)`.
|
||||
Function<void(bool success, u32 total_size, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code, ReadonlyBytes payload)> on_buffered_download_finish;
|
||||
Function<void(bool success, u32 total_size)> on_finish;
|
||||
Function<void(Optional<u32> total_size, u32 downloaded_size)> on_progress;
|
||||
Function<void(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code)> on_headers_received;
|
||||
Function<CertificateAndKey()> on_certificate_requested;
|
||||
|
||||
void did_finish(Badge<Client>, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers);
|
||||
void did_finish(Badge<Client>, bool success, u32 total_size);
|
||||
void did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size);
|
||||
void did_receive_headers(Badge<Client>, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code);
|
||||
void did_request_certificates(Badge<Client>);
|
||||
|
||||
RefPtr<Core::Notifier>& write_notifier(Badge<Client>) { return m_write_notifier; }
|
||||
void set_download_fd(Badge<Client>, int fd) { m_fd = fd; }
|
||||
|
||||
private:
|
||||
explicit Download(Client&, i32 download_id);
|
||||
WeakPtr<Client> m_client;
|
||||
int m_download_id { -1 };
|
||||
RefPtr<Core::Notifier> m_write_notifier;
|
||||
int m_fd { -1 };
|
||||
bool m_should_buffer_all_input { false };
|
||||
|
||||
struct InternalBufferedData {
|
||||
InternalBufferedData(int fd)
|
||||
: read_stream(fd)
|
||||
{
|
||||
}
|
||||
|
||||
InputFileStream read_stream;
|
||||
DuplexMemoryStream payload_stream;
|
||||
HashMap<String, String, CaseInsensitiveStringTraits> response_headers;
|
||||
Optional<u32> response_code;
|
||||
};
|
||||
|
||||
struct InternalStreamData {
|
||||
InternalStreamData(int fd)
|
||||
: read_stream(fd)
|
||||
{
|
||||
}
|
||||
|
||||
InputFileStream read_stream;
|
||||
RefPtr<Core::Notifier> read_notifier;
|
||||
bool success;
|
||||
u32 total_size { 0 };
|
||||
bool download_done { false };
|
||||
};
|
||||
|
||||
OwnPtr<InternalBufferedData> m_internal_buffered_data;
|
||||
OwnPtr<InternalStreamData> m_internal_stream_data;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -92,10 +92,10 @@ void XMLHttpRequest::send()
|
|||
// we need to make ResourceLoader give us more detailed updates than just "done" and "error".
|
||||
ResourceLoader::the().load(
|
||||
m_window->document().complete_url(m_url),
|
||||
[weak_this = make_weak_ptr()](auto& data, auto&) {
|
||||
[weak_this = make_weak_ptr()](auto data, auto&) {
|
||||
if (!weak_this)
|
||||
return;
|
||||
const_cast<XMLHttpRequest&>(*weak_this).m_response = data;
|
||||
const_cast<XMLHttpRequest&>(*weak_this).m_response = ByteBuffer::copy(data);
|
||||
const_cast<XMLHttpRequest&>(*weak_this).set_ready_state(ReadyState::Done);
|
||||
const_cast<XMLHttpRequest&>(*weak_this).dispatch_event(DOM::Event::create(HTML::EventNames::load));
|
||||
},
|
||||
|
|
|
@ -128,7 +128,7 @@ void HTMLScriptElement::prepare_script(Badge<HTMLDocumentParser>)
|
|||
// FIXME: This load should be made asynchronous and the parser should spin an event loop etc.
|
||||
ResourceLoader::the().load_sync(
|
||||
url,
|
||||
[this, url](auto& data, auto&) {
|
||||
[this, url](auto data, auto&) {
|
||||
if (data.is_null()) {
|
||||
dbg() << "HTMLScriptElement: Failed to load " << url;
|
||||
return;
|
||||
|
|
|
@ -171,6 +171,7 @@ bool FrameLoader::load(const LoadRequest& request, Type type)
|
|||
return true;
|
||||
|
||||
if (url.protocol() == "http" || url.protocol() == "https") {
|
||||
#if 0
|
||||
URL favicon_url;
|
||||
favicon_url.set_protocol(url.protocol());
|
||||
favicon_url.set_host(url.host());
|
||||
|
@ -191,6 +192,7 @@ bool FrameLoader::load(const LoadRequest& request, Type type)
|
|||
if (auto* page = frame().page())
|
||||
page->client().page_did_change_favicon(*bitmap);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
return true;
|
||||
|
|
|
@ -84,10 +84,10 @@ static String mime_type_from_content_type(const String& content_type)
|
|||
return content_type;
|
||||
}
|
||||
|
||||
void Resource::did_load(Badge<ResourceLoader>, const ByteBuffer& data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers)
|
||||
void Resource::did_load(Badge<ResourceLoader>, ReadonlyBytes data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers)
|
||||
{
|
||||
ASSERT(!m_loaded);
|
||||
m_encoded_data = data;
|
||||
m_encoded_data = ByteBuffer::copy(data);
|
||||
m_response_headers = headers;
|
||||
m_loaded = true;
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ public:
|
|||
|
||||
void for_each_client(Function<void(ResourceClient&)>);
|
||||
|
||||
void did_load(Badge<ResourceLoader>, const ByteBuffer& data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers);
|
||||
void did_load(Badge<ResourceLoader>, ReadonlyBytes data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers);
|
||||
void did_fail(Badge<ResourceLoader>, const String& error);
|
||||
|
||||
protected:
|
||||
|
|
|
@ -53,13 +53,13 @@ ResourceLoader::ResourceLoader()
|
|||
{
|
||||
}
|
||||
|
||||
void ResourceLoader::load_sync(const URL& url, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
|
||||
void ResourceLoader::load_sync(const URL& url, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
|
||||
{
|
||||
Core::EventLoop loop;
|
||||
|
||||
load(
|
||||
url,
|
||||
[&](auto& data, auto& response_headers) {
|
||||
[&](auto data, auto& response_headers) {
|
||||
success_callback(data, response_headers);
|
||||
loop.quit(0);
|
||||
},
|
||||
|
@ -97,7 +97,7 @@ RefPtr<Resource> ResourceLoader::load_resource(Resource::Type type, const LoadRe
|
|||
|
||||
load(
|
||||
request,
|
||||
[=](auto& data, auto& headers) {
|
||||
[=](auto data, auto& headers) {
|
||||
const_cast<Resource&>(*resource).did_load({}, data, headers);
|
||||
},
|
||||
[=](auto& error) {
|
||||
|
@ -107,7 +107,7 @@ RefPtr<Resource> ResourceLoader::load_resource(Resource::Type type, const LoadRe
|
|||
return resource;
|
||||
}
|
||||
|
||||
void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
|
||||
void ResourceLoader::load(const LoadRequest& request, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
|
||||
{
|
||||
auto& url = request.url();
|
||||
if (is_port_blocked(url.port())) {
|
||||
|
@ -170,7 +170,12 @@ void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBu
|
|||
error_callback("Failed to initiate load");
|
||||
return;
|
||||
}
|
||||
download->on_finish = [this, success_callback = move(success_callback), error_callback = move(error_callback)](bool success, ReadonlyBytes payload, auto, auto& response_headers, auto status_code) {
|
||||
download->on_buffered_download_finish = [this, success_callback = move(success_callback), error_callback = move(error_callback), download](bool success, auto, auto& response_headers, auto status_code, ReadonlyBytes payload) {
|
||||
if (status_code.has_value() && status_code.value() >= 400 && status_code.value() <= 499) {
|
||||
if (error_callback)
|
||||
error_callback(String::format("HTTP error (%u)", status_code.value()));
|
||||
return;
|
||||
}
|
||||
--m_pending_loads;
|
||||
if (on_load_counter_change)
|
||||
on_load_counter_change();
|
||||
|
@ -179,13 +184,9 @@ void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBu
|
|||
error_callback("HTTP load failed");
|
||||
return;
|
||||
}
|
||||
if (status_code.has_value() && status_code.value() >= 400 && status_code.value() <= 499) {
|
||||
if (error_callback)
|
||||
error_callback(String::format("HTTP error (%u)", status_code.value()));
|
||||
return;
|
||||
}
|
||||
success_callback(ByteBuffer::copy(payload.data(), payload.size()), response_headers);
|
||||
success_callback(payload, response_headers);
|
||||
};
|
||||
download->set_should_buffer_all_input(true);
|
||||
download->on_certificate_requested = []() -> Protocol::Download::CertificateAndKey {
|
||||
return {};
|
||||
};
|
||||
|
@ -199,7 +200,7 @@ void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBu
|
|||
error_callback(String::format("Protocol not implemented: %s", url.protocol().characters()));
|
||||
}
|
||||
|
||||
void ResourceLoader::load(const URL& url, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
|
||||
void ResourceLoader::load(const URL& url, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
|
||||
{
|
||||
LoadRequest request;
|
||||
request.set_url(url);
|
||||
|
|
|
@ -44,9 +44,9 @@ public:
|
|||
|
||||
RefPtr<Resource> load_resource(Resource::Type, const LoadRequest&);
|
||||
|
||||
void load(const LoadRequest&, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
|
||||
void load(const URL&, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
|
||||
void load_sync(const URL&, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
|
||||
void load(const LoadRequest&, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
|
||||
void load(const URL&, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
|
||||
void load_sync(const URL&, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
|
||||
|
||||
Function<void()> on_load_counter_change;
|
||||
|
||||
|
|
|
@ -62,16 +62,17 @@ OwnPtr<Messages::ProtocolServer::StartDownloadResponse> ClientConnection::handle
|
|||
{
|
||||
URL url(message.url());
|
||||
if (!url.is_valid())
|
||||
return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
|
||||
return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1);
|
||||
auto* protocol = Protocol::find_by_name(url.protocol());
|
||||
if (!protocol)
|
||||
return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
|
||||
auto download = protocol->start_download(*this, message.method(), url, message.request_headers().entries(), message.request_body().to_byte_buffer());
|
||||
return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1);
|
||||
auto download = protocol->start_download(*this, message.method(), url, message.request_headers().entries(), message.request_body());
|
||||
if (!download)
|
||||
return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
|
||||
return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1);
|
||||
auto id = download->id();
|
||||
auto fd = download->download_fd();
|
||||
m_downloads.set(id, move(download));
|
||||
return make<Messages::ProtocolServer::StartDownloadResponse>(id);
|
||||
return make<Messages::ProtocolServer::StartDownloadResponse>(id, fd);
|
||||
}
|
||||
|
||||
OwnPtr<Messages::ProtocolServer::StopDownloadResponse> ClientConnection::handle(const Messages::ProtocolServer::StopDownload& message)
|
||||
|
@ -86,22 +87,20 @@ OwnPtr<Messages::ProtocolServer::StopDownloadResponse> ClientConnection::handle(
|
|||
return make<Messages::ProtocolServer::StopDownloadResponse>(success);
|
||||
}
|
||||
|
||||
void ClientConnection::did_finish_download(Badge<Download>, Download& download, bool success)
|
||||
void ClientConnection::did_receive_headers(Badge<Download>, Download& download)
|
||||
{
|
||||
RefPtr<SharedBuffer> buffer;
|
||||
if (success && download.payload().size() > 0 && !download.payload().is_null()) {
|
||||
buffer = SharedBuffer::create_with_size(download.payload().size());
|
||||
memcpy(buffer->data<void>(), download.payload().data(), download.payload().size());
|
||||
buffer->seal();
|
||||
buffer->share_with(client_pid());
|
||||
m_shared_buffers.set(buffer->shbuf_id(), buffer);
|
||||
}
|
||||
ASSERT(download.total_size().has_value());
|
||||
|
||||
IPC::Dictionary response_headers;
|
||||
for (auto& it : download.response_headers())
|
||||
response_headers.add(it.key, it.value);
|
||||
post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.status_code(), download.total_size().value(), buffer ? buffer->shbuf_id() : -1, response_headers));
|
||||
|
||||
post_message(Messages::ProtocolClient::HeadersBecameAvailable(download.id(), move(response_headers), download.status_code()));
|
||||
}
|
||||
|
||||
void ClientConnection::did_finish_download(Badge<Download>, Download& download, bool success)
|
||||
{
|
||||
ASSERT(download.total_size().has_value());
|
||||
|
||||
post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.total_size().value()));
|
||||
|
||||
m_downloads.remove(download.id());
|
||||
}
|
||||
|
@ -121,12 +120,6 @@ OwnPtr<Messages::ProtocolServer::GreetResponse> ClientConnection::handle(const M
|
|||
return make<Messages::ProtocolServer::GreetResponse>(client_id());
|
||||
}
|
||||
|
||||
OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> ClientConnection::handle(const Messages::ProtocolServer::DisownSharedBuffer& message)
|
||||
{
|
||||
m_shared_buffers.remove(message.shbuf_id());
|
||||
return make<Messages::ProtocolServer::DisownSharedBufferResponse>();
|
||||
}
|
||||
|
||||
OwnPtr<Messages::ProtocolServer::SetCertificateResponse> ClientConnection::handle(const Messages::ProtocolServer::SetCertificate& message)
|
||||
{
|
||||
auto* download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr));
|
||||
|
|
|
@ -45,6 +45,7 @@ public:
|
|||
|
||||
virtual void die() override;
|
||||
|
||||
void did_receive_headers(Badge<Download>, Download&);
|
||||
void did_finish_download(Badge<Download>, Download&, bool success);
|
||||
void did_progress_download(Badge<Download>, Download&);
|
||||
void did_request_certificates(Badge<Download>, Download&);
|
||||
|
@ -54,11 +55,9 @@ private:
|
|||
virtual OwnPtr<Messages::ProtocolServer::IsSupportedProtocolResponse> handle(const Messages::ProtocolServer::IsSupportedProtocol&) override;
|
||||
virtual OwnPtr<Messages::ProtocolServer::StartDownloadResponse> handle(const Messages::ProtocolServer::StartDownload&) override;
|
||||
virtual OwnPtr<Messages::ProtocolServer::StopDownloadResponse> handle(const Messages::ProtocolServer::StopDownload&) override;
|
||||
virtual OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> handle(const Messages::ProtocolServer::DisownSharedBuffer&) override;
|
||||
virtual OwnPtr<Messages::ProtocolServer::SetCertificateResponse> handle(const Messages::ProtocolServer::SetCertificate&);
|
||||
virtual OwnPtr<Messages::ProtocolServer::SetCertificateResponse> handle(const Messages::ProtocolServer::SetCertificate&) override;
|
||||
|
||||
HashMap<i32, OwnPtr<Download>> m_downloads;
|
||||
HashMap<i32, RefPtr<AK::SharedBuffer>> m_shared_buffers;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -33,9 +33,10 @@ namespace ProtocolServer {
|
|||
// FIXME: What about rollover?
|
||||
static i32 s_next_id = 1;
|
||||
|
||||
Download::Download(ClientConnection& client)
|
||||
Download::Download(ClientConnection& client, NonnullOwnPtr<OutputFileStream>&& output_stream)
|
||||
: m_client(client)
|
||||
, m_id(s_next_id++)
|
||||
, m_output_stream(move(output_stream))
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -48,15 +49,10 @@ void Download::stop()
|
|||
m_client.did_finish_download({}, *this, false);
|
||||
}
|
||||
|
||||
void Download::set_payload(const ByteBuffer& payload)
|
||||
{
|
||||
m_payload = payload;
|
||||
m_total_size = payload.size();
|
||||
}
|
||||
|
||||
void Download::set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)
|
||||
{
|
||||
m_response_headers = response_headers;
|
||||
m_client.did_receive_headers({}, *this);
|
||||
}
|
||||
|
||||
void Download::set_certificate(String, String)
|
||||
|
|
|
@ -26,8 +26,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <AK/ByteBuffer.h>
|
||||
#include <AK/FileStream.h>
|
||||
#include <AK/HashMap.h>
|
||||
#include <AK/NonnullOwnPtr.h>
|
||||
#include <AK/Optional.h>
|
||||
#include <AK/RefCounted.h>
|
||||
#include <AK/URL.h>
|
||||
|
@ -45,30 +46,35 @@ public:
|
|||
Optional<u32> status_code() const { return m_status_code; }
|
||||
Optional<u32> total_size() const { return m_total_size; }
|
||||
size_t downloaded_size() const { return m_downloaded_size; }
|
||||
const ByteBuffer& payload() const { return m_payload; }
|
||||
const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers() const { return m_response_headers; }
|
||||
|
||||
void stop();
|
||||
virtual void set_certificate(String, String);
|
||||
|
||||
// FIXME: Want Badge<Protocol>, but can't make one from HttpProtocol, etc.
|
||||
void set_download_fd(int fd) { m_download_fd = fd; }
|
||||
int download_fd() const { return m_download_fd; }
|
||||
|
||||
protected:
|
||||
explicit Download(ClientConnection&);
|
||||
explicit Download(ClientConnection&, NonnullOwnPtr<OutputFileStream>&&);
|
||||
|
||||
void did_finish(bool success);
|
||||
void did_progress(Optional<u32> total_size, u32 downloaded_size);
|
||||
void set_status_code(u32 status_code) { m_status_code = status_code; }
|
||||
void did_request_certificates();
|
||||
void set_payload(const ByteBuffer&);
|
||||
void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&);
|
||||
void set_downloaded_size(size_t size) { m_downloaded_size = size; }
|
||||
const OutputFileStream& output_stream() const { return *m_output_stream; }
|
||||
|
||||
private:
|
||||
ClientConnection& m_client;
|
||||
i32 m_id { 0 };
|
||||
int m_download_fd { -1 }; // Passed to client.
|
||||
URL m_url;
|
||||
Optional<u32> m_status_code;
|
||||
Optional<u32> m_total_size {};
|
||||
size_t m_downloaded_size { 0 };
|
||||
ByteBuffer m_payload;
|
||||
NonnullOwnPtr<OutputFileStream> m_output_stream;
|
||||
HashMap<String, String, CaseInsensitiveStringTraits> m_response_headers;
|
||||
};
|
||||
|
||||
|
|
|
@ -30,13 +30,13 @@
|
|||
|
||||
namespace ProtocolServer {
|
||||
|
||||
GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job)
|
||||
: Download(client)
|
||||
GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
|
||||
: Download(client, move(output_stream))
|
||||
, m_job(job)
|
||||
{
|
||||
m_job->on_finish = [this](bool success) {
|
||||
if (auto* response = m_job->response()) {
|
||||
set_payload(response->payload());
|
||||
set_downloaded_size(this->output_stream().size());
|
||||
if (!response->meta().is_empty()) {
|
||||
HashMap<String, String, CaseInsensitiveStringTraits> headers;
|
||||
headers.set("meta", response->meta());
|
||||
|
@ -76,9 +76,9 @@ GeminiDownload::~GeminiDownload()
|
|||
m_job->shutdown();
|
||||
}
|
||||
|
||||
NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job)
|
||||
NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
|
||||
{
|
||||
return adopt_own(*new GeminiDownload(client, move(job)));
|
||||
return adopt_own(*new GeminiDownload(client, move(job), move(output_stream)));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -36,10 +36,10 @@ namespace ProtocolServer {
|
|||
class GeminiDownload final : public Download {
|
||||
public:
|
||||
virtual ~GeminiDownload() override;
|
||||
static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>);
|
||||
static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>, NonnullOwnPtr<OutputFileStream>&&);
|
||||
|
||||
private:
|
||||
explicit GeminiDownload(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>);
|
||||
explicit GeminiDownload(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>, NonnullOwnPtr<OutputFileStream>&&);
|
||||
|
||||
virtual void set_certificate(String certificate, String key) override;
|
||||
|
||||
|
|
|
@ -40,12 +40,22 @@ GeminiProtocol::~GeminiProtocol()
|
|||
{
|
||||
}
|
||||
|
||||
OwnPtr<Download> GeminiProtocol::start_download(ClientConnection& client, const String&, const URL& url, const HashMap<String, String>&, const ByteBuffer&)
|
||||
OwnPtr<Download> GeminiProtocol::start_download(ClientConnection& client, const String&, const URL& url, const HashMap<String, String>&, ReadonlyBytes)
|
||||
{
|
||||
Gemini::GeminiRequest request;
|
||||
request.set_url(url);
|
||||
auto job = Gemini::GeminiJob::construct(request);
|
||||
auto download = GeminiDownload::create_with_job({}, client, (Gemini::GeminiJob&)*job);
|
||||
|
||||
int fd_pair[2] { 0 };
|
||||
if (pipe(fd_pair) != 0) {
|
||||
auto saved_errno = errno;
|
||||
dbgln("Protocol: pipe() failed: {}", strerror(saved_errno));
|
||||
return nullptr;
|
||||
}
|
||||
auto output_stream = make<OutputFileStream>(fd_pair[1]);
|
||||
output_stream->make_unbuffered();
|
||||
auto job = Gemini::GeminiJob::construct(request, *output_stream);
|
||||
auto download = GeminiDownload::create_with_job({}, client, (Gemini::GeminiJob&)*job, move(output_stream));
|
||||
download->set_download_fd(fd_pair[0]);
|
||||
job->start();
|
||||
return download;
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ public:
|
|||
GeminiProtocol();
|
||||
virtual ~GeminiProtocol() override;
|
||||
|
||||
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>&, const ByteBuffer& request_body) override;
|
||||
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>&, ReadonlyBytes body) override;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -30,15 +30,21 @@
|
|||
|
||||
namespace ProtocolServer {
|
||||
|
||||
HttpDownload::HttpDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job)
|
||||
: Download(client)
|
||||
HttpDownload::HttpDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
|
||||
: Download(client, move(output_stream))
|
||||
, m_job(job)
|
||||
{
|
||||
m_job->on_headers_received = [this](auto& headers, auto response_code) {
|
||||
if (response_code.has_value())
|
||||
set_status_code(response_code.value());
|
||||
set_response_headers(headers);
|
||||
};
|
||||
|
||||
m_job->on_finish = [this](bool success) {
|
||||
if (auto* response = m_job->response()) {
|
||||
set_status_code(response->code());
|
||||
set_payload(response->payload());
|
||||
set_response_headers(response->headers());
|
||||
set_downloaded_size(this->output_stream().size());
|
||||
}
|
||||
|
||||
// if we didn't know the total size, pretend that the download finished successfully
|
||||
|
@ -60,9 +66,9 @@ HttpDownload::~HttpDownload()
|
|||
m_job->shutdown();
|
||||
}
|
||||
|
||||
NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job)
|
||||
NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
|
||||
{
|
||||
return adopt_own(*new HttpDownload(client, move(job)));
|
||||
return adopt_own(*new HttpDownload(client, move(job), move(output_stream)));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -36,10 +36,10 @@ namespace ProtocolServer {
|
|||
class HttpDownload final : public Download {
|
||||
public:
|
||||
virtual ~HttpDownload() override;
|
||||
static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpJob>);
|
||||
static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpJob>, NonnullOwnPtr<OutputFileStream>&&);
|
||||
|
||||
private:
|
||||
explicit HttpDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpJob>);
|
||||
explicit HttpDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpJob>, NonnullOwnPtr<OutputFileStream>&&);
|
||||
|
||||
NonnullRefPtr<HTTP::HttpJob> m_job;
|
||||
};
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include <LibHTTP/HttpRequest.h>
|
||||
#include <ProtocolServer/HttpDownload.h>
|
||||
#include <ProtocolServer/HttpProtocol.h>
|
||||
#include <fcntl.h>
|
||||
|
||||
namespace ProtocolServer {
|
||||
|
||||
|
@ -40,7 +41,7 @@ HttpProtocol::~HttpProtocol()
|
|||
{
|
||||
}
|
||||
|
||||
OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, const ByteBuffer& request_body)
|
||||
OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, ReadonlyBytes body)
|
||||
{
|
||||
HTTP::HttpRequest request;
|
||||
if (method.equals_ignoring_case("post"))
|
||||
|
@ -49,9 +50,20 @@ OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const St
|
|||
request.set_method(HTTP::HttpRequest::Method::GET);
|
||||
request.set_url(url);
|
||||
request.set_headers(headers);
|
||||
request.set_body(request_body);
|
||||
auto job = HTTP::HttpJob::construct(request);
|
||||
auto download = HttpDownload::create_with_job({}, client, (HTTP::HttpJob&)*job);
|
||||
request.set_body(body);
|
||||
|
||||
int fd_pair[2] { 0 };
|
||||
if (pipe(fd_pair) != 0) {
|
||||
auto saved_errno = errno;
|
||||
dbgln("Protocol: pipe() failed: {}", strerror(saved_errno));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto output_stream = make<OutputFileStream>(fd_pair[1]);
|
||||
output_stream->make_unbuffered();
|
||||
auto job = HTTP::HttpJob::construct(request, *output_stream);
|
||||
auto download = HttpDownload::create_with_job({}, client, (HTTP::HttpJob&)*job, move(output_stream));
|
||||
download->set_download_fd(fd_pair[0]);
|
||||
job->start();
|
||||
return download;
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ public:
|
|||
HttpProtocol();
|
||||
virtual ~HttpProtocol() override;
|
||||
|
||||
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) override;
|
||||
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) override;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -30,15 +30,21 @@
|
|||
|
||||
namespace ProtocolServer {
|
||||
|
||||
HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job)
|
||||
: Download(client)
|
||||
HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
|
||||
: Download(client, move(output_stream))
|
||||
, m_job(job)
|
||||
{
|
||||
m_job->on_headers_received = [this](auto& headers, auto response_code) {
|
||||
if (response_code.has_value())
|
||||
set_status_code(response_code.value());
|
||||
set_response_headers(headers);
|
||||
};
|
||||
|
||||
m_job->on_finish = [this](bool success) {
|
||||
if (auto* response = m_job->response()) {
|
||||
set_status_code(response->code());
|
||||
set_payload(response->payload());
|
||||
set_response_headers(response->headers());
|
||||
set_downloaded_size(this->output_stream().size());
|
||||
}
|
||||
|
||||
// if we didn't know the total size, pretend that the download finished successfully
|
||||
|
@ -68,9 +74,9 @@ HttpsDownload::~HttpsDownload()
|
|||
m_job->shutdown();
|
||||
}
|
||||
|
||||
NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job)
|
||||
NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
|
||||
{
|
||||
return adopt_own(*new HttpsDownload(client, move(job)));
|
||||
return adopt_own(*new HttpsDownload(client, move(job), move(output_stream)));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -36,10 +36,10 @@ namespace ProtocolServer {
|
|||
class HttpsDownload final : public Download {
|
||||
public:
|
||||
virtual ~HttpsDownload() override;
|
||||
static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>);
|
||||
static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<OutputFileStream>&&);
|
||||
|
||||
private:
|
||||
explicit HttpsDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>);
|
||||
explicit HttpsDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<OutputFileStream>&&);
|
||||
|
||||
virtual void set_certificate(String certificate, String key) override;
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ HttpsProtocol::~HttpsProtocol()
|
|||
{
|
||||
}
|
||||
|
||||
OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, const ByteBuffer& request_body)
|
||||
OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, ReadonlyBytes body)
|
||||
{
|
||||
HTTP::HttpRequest request;
|
||||
if (method.equals_ignoring_case("post"))
|
||||
|
@ -49,9 +49,19 @@ OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const S
|
|||
request.set_method(HTTP::HttpRequest::Method::GET);
|
||||
request.set_url(url);
|
||||
request.set_headers(headers);
|
||||
request.set_body(request_body);
|
||||
auto job = HTTP::HttpsJob::construct(request);
|
||||
auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job);
|
||||
request.set_body(body);
|
||||
|
||||
int fd_pair[2] { 0 };
|
||||
if (pipe(fd_pair) != 0) {
|
||||
auto saved_errno = errno;
|
||||
dbgln("Protocol: pipe() failed: {}", strerror(saved_errno));
|
||||
return nullptr;
|
||||
}
|
||||
auto output_stream = make<OutputFileStream>(fd_pair[1]);
|
||||
output_stream->make_unbuffered();
|
||||
auto job = HTTP::HttpsJob::construct(request, *output_stream);
|
||||
auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job, move(output_stream));
|
||||
download->set_download_fd(fd_pair[0]);
|
||||
job->start();
|
||||
return download;
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ public:
|
|||
HttpsProtocol();
|
||||
virtual ~HttpsProtocol() override;
|
||||
|
||||
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) override;
|
||||
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) override;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ public:
|
|||
virtual ~Protocol();
|
||||
|
||||
const String& name() const { return m_name; }
|
||||
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) = 0;
|
||||
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) = 0;
|
||||
|
||||
static Protocol* find_by_name(const String&);
|
||||
|
||||
|
|
|
@ -2,7 +2,8 @@ endpoint ProtocolClient = 13
|
|||
{
|
||||
// Download notifications
|
||||
DownloadProgress(i32 download_id, Optional<u32> total_size, u32 downloaded_size) =|
|
||||
DownloadFinished(i32 download_id, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, IPC::Dictionary response_headers) =|
|
||||
DownloadFinished(i32 download_id, bool success, u32 total_size) =|
|
||||
HeadersBecameAvailable(i32 download_id, IPC::Dictionary response_headers, Optional<u32> status_code) =|
|
||||
|
||||
// Certificate requests
|
||||
CertificateRequested(i32 download_id) => ()
|
||||
|
|
|
@ -3,14 +3,11 @@ endpoint ProtocolServer = 9
|
|||
// Basic protocol
|
||||
Greet() => (i32 client_id)
|
||||
|
||||
// FIXME: It would be nice if the kernel provided a way to avoid this
|
||||
DisownSharedBuffer(i32 shbuf_id) => ()
|
||||
|
||||
// Test if a specific protocol is supported, e.g "http"
|
||||
IsSupportedProtocol(String protocol) => (bool supported)
|
||||
|
||||
// Download API
|
||||
StartDownload(String method, URL url, IPC::Dictionary request_headers, String request_body) => (i32 download_id)
|
||||
StartDownload(String method, URL url, IPC::Dictionary request_headers, ByteBuffer request_body) => (i32 download_id, IPC::File response_fd)
|
||||
StopDownload(i32 download_id) => (bool success)
|
||||
SetCertificate(i32 download_id, String certificate, String key) => (bool success)
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@
|
|||
|
||||
int main(int, char**)
|
||||
{
|
||||
if (pledge("stdio inet shared_buffer accept unix rpath cpath fattr", nullptr) < 0) {
|
||||
if (pledge("stdio inet shared_buffer accept unix rpath cpath fattr sendfd recvfd", nullptr) < 0) {
|
||||
perror("pledge");
|
||||
return 1;
|
||||
}
|
||||
|
@ -45,7 +45,7 @@ int main(int, char**)
|
|||
|
||||
Core::EventLoop event_loop;
|
||||
// FIXME: Establish a connection to LookupServer and then drop "unix"?
|
||||
if (pledge("stdio inet shared_buffer accept unix", nullptr) < 0) {
|
||||
if (pledge("stdio inet shared_buffer accept unix sendfd recvfd", nullptr) < 0) {
|
||||
perror("pledge");
|
||||
return 1;
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
|
||||
#include <AK/FileStream.h>
|
||||
#include <AK/GenericLexer.h>
|
||||
#include <AK/LexicalPath.h>
|
||||
#include <AK/NumberFormat.h>
|
||||
|
@ -116,29 +117,50 @@ private:
|
|||
bool m_might_be_wrong { false };
|
||||
};
|
||||
|
||||
static void do_write(ReadonlyBytes payload)
|
||||
{
|
||||
size_t length_remaining = payload.size();
|
||||
size_t length_written = 0;
|
||||
while (length_remaining > 0) {
|
||||
auto nwritten = fwrite(payload.data() + length_written, sizeof(char), length_remaining, stdout);
|
||||
if (nwritten > 0) {
|
||||
length_remaining -= nwritten;
|
||||
length_written += nwritten;
|
||||
continue;
|
||||
}
|
||||
template<typename ConditionT>
|
||||
class ConditionalOutputFileStream final : public OutputFileStream {
|
||||
public:
|
||||
template<typename... Args>
|
||||
ConditionalOutputFileStream(ConditionT&& condition, Args... args)
|
||||
: OutputFileStream(args...)
|
||||
, m_condition(condition)
|
||||
{
|
||||
}
|
||||
|
||||
if (feof(stdout)) {
|
||||
fprintf(stderr, "pro: unexpected eof while writing\n");
|
||||
~ConditionalOutputFileStream()
|
||||
{
|
||||
if (!m_condition())
|
||||
return;
|
||||
}
|
||||
|
||||
if (ferror(stdout)) {
|
||||
fprintf(stderr, "pro: error while writing\n");
|
||||
return;
|
||||
if (!m_buffer.is_empty()) {
|
||||
OutputFileStream::write(m_buffer);
|
||||
m_buffer.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
size_t write(ReadonlyBytes bytes) override
|
||||
{
|
||||
if (!m_condition()) {
|
||||
write_to_buffer:;
|
||||
m_buffer.append(bytes.data(), bytes.size());
|
||||
return bytes.size();
|
||||
}
|
||||
|
||||
if (!m_buffer.is_empty()) {
|
||||
auto size = OutputFileStream::write(m_buffer);
|
||||
m_buffer = m_buffer.slice(size, m_buffer.size() - size);
|
||||
}
|
||||
|
||||
if (!m_buffer.is_empty())
|
||||
goto write_to_buffer;
|
||||
|
||||
return OutputFileStream::write(bytes);
|
||||
}
|
||||
|
||||
ConditionT m_condition;
|
||||
ByteBuffer m_buffer;
|
||||
};
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
|
@ -195,6 +217,8 @@ int main(int argc, char** argv)
|
|||
timeval prev_time, current_time, time_diff;
|
||||
gettimeofday(&prev_time, nullptr);
|
||||
|
||||
bool received_actual_headers = false;
|
||||
|
||||
download->on_progress = [&](Optional<u32> maybe_total_size, u32 downloaded_size) {
|
||||
fprintf(stderr, "\r\033[2K");
|
||||
if (maybe_total_size.has_value()) {
|
||||
|
@ -215,10 +239,13 @@ int main(int argc, char** argv)
|
|||
previous_downloaded_size = downloaded_size;
|
||||
prev_time = current_time;
|
||||
};
|
||||
download->on_finish = [&](bool success, auto payload, auto, auto& response_headers, auto) {
|
||||
fprintf(stderr, "\033]9;-1;\033\\");
|
||||
fprintf(stderr, "\n");
|
||||
if (success && save_at_provided_name) {
|
||||
|
||||
if (save_at_provided_name) {
|
||||
download->on_headers_received = [&](auto& response_headers, auto status_code) {
|
||||
if (received_actual_headers)
|
||||
return;
|
||||
dbg() << "Received headers! response code = " << status_code.value_or(0);
|
||||
received_actual_headers = true; // And not trailers!
|
||||
String output_name;
|
||||
if (auto content_disposition = response_headers.get("Content-Disposition"); content_disposition.has_value()) {
|
||||
auto& value = content_disposition.value();
|
||||
|
@ -245,17 +272,26 @@ int main(int argc, char** argv)
|
|||
|
||||
if (freopen(output_name.characters(), "w", stdout) == nullptr) {
|
||||
perror("freopen");
|
||||
success = false; // oops!
|
||||
loop.quit(1);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (success)
|
||||
do_write(payload);
|
||||
else
|
||||
};
|
||||
}
|
||||
download->on_finish = [&](bool success, auto) {
|
||||
fprintf(stderr, "\033]9;-1;\033\\");
|
||||
fprintf(stderr, "\n");
|
||||
if (!success)
|
||||
fprintf(stderr, "Download failed :(\n");
|
||||
loop.quit(0);
|
||||
};
|
||||
|
||||
auto output_stream = ConditionalOutputFileStream { [&] { return save_at_provided_name ? received_actual_headers : true; }, stdout };
|
||||
download->stream_into(output_stream);
|
||||
|
||||
dbgprintf("started download with id %d\n", download->id());
|
||||
|
||||
return loop.exec();
|
||||
auto rc = loop.exec();
|
||||
// FIXME: This shouldn't be needed.
|
||||
fclose(stdout);
|
||||
return rc;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue