ProtocolServer: Stream the downloaded data if possible

This patchset makes ProtocolServer stream the downloads to its client
(LibProtocol), and as such changes the download API; a possible
download lifecycle could be as such:
notation = client->server:'>', server->client:'<', pipe activity:'*'
```
> StartDownload(GET, url, headers, {})
< Response(0, fd 8)
* {data, 1024b}
< HeadersBecameAvailable(0, response_headers, 200)
< DownloadProgress(0, 4K, 1024)
* {data, 1024b}
* {data, 1024b}
< DownloadProgress(0, 4K, 2048)
* {data, 1024b}
< DownloadProgress(0, 4K, 1024)
< DownloadFinished(0, true, 4K)
```

Since managing the received file descriptor is a pain, LibProtocol
implements `Download::stream_into(OutputStream)`, which can be used to
stream the download into any given output stream (be it a file, or
memory, or writing stuff with a delay, etc.).
Also, as some of the users of this API require all the downloaded data
upfront, LibProtocol also implements `set_should_buffer_all_input()`,
which causes the download instance to buffer all the data until the
download is complete, and to call the `on_buffered_download_finish`
hook.
This commit is contained in:
AnotherTest 2020-12-26 17:14:12 +03:30 committed by Andreas Kling
parent 36d642ee75
commit 4a2da10e38
Notes: sideshowbarker 2024-07-19 00:23:12 +09:00
55 changed files with 528 additions and 235 deletions

View file

@ -29,6 +29,7 @@
#include <AK/SharedBuffer.h>
#include <AK/StringBuilder.h>
#include <LibCore/File.h>
#include <LibCore/FileStream.h>
#include <LibCore/StandardPaths.h>
#include <LibDesktop/Launcher.h>
#include <LibGUI/BoxLayout.h>
@ -61,9 +62,19 @@ DownloadWidget::DownloadWidget(const URL& url)
m_download->on_progress = [this](Optional<u32> total_size, u32 downloaded_size) {
did_progress(total_size.value(), downloaded_size);
};
m_download->on_finish = [this](bool success, auto payload, auto payload_storage, auto& response_headers, auto) {
did_finish(success, payload, payload_storage, response_headers);
};
{
auto file_or_error = Core::File::open(m_destination_path, Core::IODevice::WriteOnly);
if (file_or_error.is_error()) {
GUI::MessageBox::show(window(), String::formatted("Cannot open {} for writing", m_destination_path), "Download failed", GUI::MessageBox::Type::Error);
window()->close();
return;
}
m_output_file_stream = make<Core::OutputFileStream>(*file_or_error.value());
}
m_download->on_finish = [this](bool success, auto) { did_finish(success); };
m_download->stream_into(*m_output_file_stream);
set_fill_with_background_color(true);
auto& layout = set_layout<GUI::VerticalBoxLayout>();
@ -149,7 +160,7 @@ void DownloadWidget::did_progress(Optional<u32> total_size, u32 downloaded_size)
}
}
void DownloadWidget::did_finish(bool success, [[maybe_unused]] ReadonlyBytes payload, [[maybe_unused]] RefPtr<SharedBuffer> payload_storage, [[maybe_unused]] const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)
void DownloadWidget::did_finish(bool success)
{
dbg() << "did_finish, success=" << success;
@ -166,17 +177,6 @@ void DownloadWidget::did_finish(bool success, [[maybe_unused]] ReadonlyBytes pay
window()->close();
return;
}
auto file_or_error = Core::File::open(m_destination_path, Core::IODevice::WriteOnly);
if (file_or_error.is_error()) {
GUI::MessageBox::show(window(), String::formatted("Cannot open {} for writing", m_destination_path), "Download failed", GUI::MessageBox::Type::Error);
window()->close();
return;
}
auto& file = *file_or_error.value();
bool write_success = file.write(payload.data(), payload.size());
ASSERT(write_success);
}
}

View file

@ -28,6 +28,7 @@
#include <AK/URL.h>
#include <LibCore/ElapsedTimer.h>
#include <LibCore/FileStream.h>
#include <LibGUI/ProgressBar.h>
#include <LibGUI/Widget.h>
#include <LibProtocol/Download.h>
@ -44,7 +45,7 @@ private:
explicit DownloadWidget(const URL&);
void did_progress(Optional<u32> total_size, u32 downloaded_size);
void did_finish(bool success, ReadonlyBytes payload, RefPtr<SharedBuffer> payload_storage, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers);
void did_finish(bool success);
URL m_url;
String m_destination_path;
@ -53,6 +54,7 @@ private:
RefPtr<GUI::Label> m_progress_label;
RefPtr<GUI::Button> m_cancel_button;
RefPtr<GUI::Button> m_close_button;
OwnPtr<Core::OutputFileStream> m_output_file_stream;
Core::ElapsedTimer m_elapsed_timer;
};

View file

@ -68,7 +68,7 @@ int main(int argc, char** argv)
return 1;
}
if (pledge("stdio shared_buffer accept unix cpath rpath wpath fattr", nullptr) < 0) {
if (pledge("stdio shared_buffer accept unix cpath rpath wpath fattr sendfd recvfd", nullptr) < 0) {
perror("pledge");
return 1;
}
@ -86,7 +86,7 @@ int main(int argc, char** argv)
Web::ResourceLoader::the();
// FIXME: Once there is a standalone Download Manager, we can drop the "unix" pledge.
if (pledge("stdio shared_buffer accept unix cpath rpath wpath", nullptr) < 0) {
if (pledge("stdio shared_buffer accept unix cpath rpath wpath sendfd recvfd", nullptr) < 0) {
perror("pledge");
return 1;
}

View file

@ -32,7 +32,8 @@
namespace Core {
NetworkJob::NetworkJob()
NetworkJob::NetworkJob(OutputStream& output_stream)
: m_output_stream(output_stream)
{
}

View file

@ -27,6 +27,7 @@
#pragma once
#include <AK/Function.h>
#include <AK/Stream.h>
#include <LibCore/Object.h>
namespace Core {
@ -43,6 +44,8 @@ public:
};
virtual ~NetworkJob() override;
// Could fire twice, after Headers and after Trailers!
Function<void(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code)> on_headers_received;
Function<void(bool success)> on_finish;
Function<void(Optional<u32>, u32)> on_progress;
@ -62,13 +65,16 @@ public:
}
protected:
NetworkJob();
NetworkJob(OutputStream&);
void did_finish(NonnullRefPtr<NetworkResponse>&&);
void did_fail(Error);
void did_progress(Optional<u32> total_size, u32 downloaded);
size_t do_write(ReadonlyBytes bytes) { return m_output_stream.write(bytes); }
private:
RefPtr<NetworkResponse> m_response;
OutputStream& m_output_stream;
Error m_error { Error::None };
};

View file

@ -28,8 +28,7 @@
namespace Core {
NetworkResponse::NetworkResponse(ByteBuffer&& payload)
: m_payload(payload)
NetworkResponse::NetworkResponse()
{
}

View file

@ -36,13 +36,11 @@ public:
virtual ~NetworkResponse();
bool is_error() const { return m_error; }
const ByteBuffer& payload() const { return m_payload; }
protected:
explicit NetworkResponse(ByteBuffer&&);
explicit NetworkResponse();
bool m_error { false };
ByteBuffer m_payload;
};
}

View file

@ -142,9 +142,9 @@ bool GeminiJob::eof() const
return m_socket->eof();
}
bool GeminiJob::write(const ByteBuffer& data)
bool GeminiJob::write(ReadonlyBytes bytes)
{
return m_socket->write(data);
return m_socket->write(bytes);
}
}

View file

@ -37,8 +37,8 @@ namespace Gemini {
class GeminiJob final : public Job {
C_OBJECT(GeminiJob)
public:
explicit GeminiJob(const GeminiRequest& request, const Vector<Certificate>* override_certificates = nullptr)
: Job(request)
explicit GeminiJob(const GeminiRequest& request, OutputStream& output_stream, const Vector<Certificate>* override_certificates = nullptr)
: Job(request, output_stream)
, m_override_ca_certificates(override_certificates)
{
}
@ -61,7 +61,7 @@ protected:
virtual bool can_read() const override;
virtual ByteBuffer receive(size_t) override;
virtual bool eof() const override;
virtual bool write(const ByteBuffer&) override;
virtual bool write(ReadonlyBytes) override;
virtual bool is_established() const override { return m_socket->is_established(); }
virtual bool should_fail_on_empty_payload() const override { return false; }
virtual void read_while_data_available(Function<IterationDecision()>) override;

View file

@ -28,9 +28,8 @@
namespace Gemini {
GeminiResponse::GeminiResponse(int status, String meta, ByteBuffer&& payload)
: Core::NetworkResponse(move(payload))
, m_status(status)
GeminiResponse::GeminiResponse(int status, String meta)
: m_status(status)
, m_meta(meta)
{
}

View file

@ -34,16 +34,16 @@ namespace Gemini {
class GeminiResponse : public Core::NetworkResponse {
public:
virtual ~GeminiResponse() override;
static NonnullRefPtr<GeminiResponse> create(int status, String meta, ByteBuffer&& payload)
static NonnullRefPtr<GeminiResponse> create(int status, String meta)
{
return adopt(*new GeminiResponse(status, meta, move(payload)));
return adopt(*new GeminiResponse(status, meta));
}
int status() const { return m_status; }
String meta() const { return m_meta; }
private:
GeminiResponse(int status, String, ByteBuffer&&);
GeminiResponse(int status, String);
int m_status { 0 };
String m_meta;

View file

@ -33,8 +33,9 @@
namespace Gemini {
Job::Job(const GeminiRequest& request)
: m_request(request)
Job::Job(const GeminiRequest& request, OutputStream& output_stream)
: Core::NetworkJob(output_stream)
, m_request(request)
{
}
@ -42,6 +43,23 @@ Job::~Job()
{
}
void Job::flush_received_buffers()
{
for (size_t i = 0; i < m_received_buffers.size(); ++i) {
auto& payload = m_received_buffers[i];
auto written = do_write(payload);
m_received_size -= written;
if (written == payload.size()) {
// FIXME: Make this a take-first-friendly object?
m_received_buffers.take_first();
continue;
}
ASSERT(written < payload.size());
payload = payload.slice(written, payload.size() - written);
return;
}
}
void Job::on_socket_connected()
{
register_on_ready_to_write([this] {
@ -126,6 +144,7 @@ void Job::on_socket_connected()
m_received_buffers.append(payload);
m_received_size += payload.size();
flush_received_buffers();
deferred_invoke([this](auto&) { did_progress({}, m_received_size); });
@ -144,15 +163,17 @@ void Job::on_socket_connected()
void Job::finish_up()
{
m_state = State::Finished;
auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size);
u8* flat_ptr = flattened_buffer.data();
for (auto& received_buffer : m_received_buffers) {
memcpy(flat_ptr, received_buffer.data(), received_buffer.size());
flat_ptr += received_buffer.size();
flush_received_buffers();
if (m_received_size != 0) {
// FIXME: What do we do? ignore it?
// "Transmission failed" is not strictly correct, but let's roll with it for now.
deferred_invoke([this](auto&) {
did_fail(Error::TransmissionFailed);
});
return;
}
m_received_buffers.clear();
auto response = GeminiResponse::create(m_status, m_meta, move(flattened_buffer));
auto response = GeminiResponse::create(m_status, m_meta);
deferred_invoke([this, response](auto&) {
did_finish(move(response));
});

View file

@ -36,7 +36,7 @@ namespace Gemini {
class Job : public Core::NetworkJob {
public:
explicit Job(const GeminiRequest&);
explicit Job(const GeminiRequest&, OutputStream&);
virtual ~Job() override;
virtual void start() override = 0;
@ -48,6 +48,7 @@ public:
protected:
void finish_up();
void on_socket_connected();
void flush_received_buffers();
virtual void register_on_ready_to_read(Function<void()>) = 0;
virtual void register_on_ready_to_write(Function<void()>) = 0;
virtual bool can_read_line() const = 0;
@ -55,7 +56,7 @@ protected:
virtual bool can_read() const = 0;
virtual ByteBuffer receive(size_t) = 0;
virtual bool eof() const = 0;
virtual bool write(const ByteBuffer&) = 0;
virtual bool write(ReadonlyBytes) = 0;
virtual bool is_established() const = 0;
virtual bool should_fail_on_empty_payload() const { return false; }
virtual void read_while_data_available(Function<IterationDecision()> read) { read(); };
@ -70,7 +71,7 @@ protected:
State m_state { State::InStatus };
int m_status { -1 };
String m_meta;
Vector<ByteBuffer> m_received_buffers;
Vector<ByteBuffer, 2> m_received_buffers;
size_t m_received_size { 0 };
bool m_sent_data { false };
bool m_should_have_payload { false };

View file

@ -98,9 +98,9 @@ bool HttpJob::eof() const
return m_socket->eof();
}
bool HttpJob::write(const ByteBuffer& data)
bool HttpJob::write(ReadonlyBytes bytes)
{
return m_socket->write(data);
return m_socket->write(bytes);
}
}

View file

@ -38,8 +38,8 @@ namespace HTTP {
class HttpJob final : public Job {
C_OBJECT(HttpJob)
public:
explicit HttpJob(const HttpRequest& request)
: Job(request)
explicit HttpJob(const HttpRequest& request, OutputStream& output_stream)
: Job(request, output_stream)
{
}
@ -59,7 +59,7 @@ protected:
virtual bool can_read() const override;
virtual ByteBuffer receive(size_t) override;
virtual bool eof() const override;
virtual bool write(const ByteBuffer&) override;
virtual bool write(ReadonlyBytes) override;
virtual bool is_established() const override { return true; }
private:

View file

@ -71,11 +71,12 @@ ByteBuffer HttpRequest::to_raw_request() const
builder.append(header.value);
builder.append("\r\n");
}
builder.append("Connection: close\r\n\r\n");
builder.append("Connection: close\r\n");
if (!m_body.is_empty()) {
builder.appendff("Content-Length: {}\r\n\r\n", m_body.size());
builder.append((const char*)m_body.data(), m_body.size());
builder.append("\r\n");
}
builder.append("\r\n");
return builder.to_byte_buffer();
}

View file

@ -62,7 +62,8 @@ public:
void set_method(Method method) { m_method = method; }
const ByteBuffer& body() const { return m_body; }
void set_body(const ByteBuffer& body) { m_body = body; }
void set_body(ReadonlyBytes body) { m_body = ByteBuffer::copy(body); }
void set_body(ByteBuffer&& body) { m_body = move(body); }
String method_name() const;
ByteBuffer to_raw_request() const;

View file

@ -28,9 +28,8 @@
namespace HTTP {
HttpResponse::HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers, ByteBuffer&& payload)
: Core::NetworkResponse(move(payload))
, m_code(code)
HttpResponse::HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers)
: m_code(code)
, m_headers(move(headers))
{
}

View file

@ -35,16 +35,16 @@ namespace HTTP {
class HttpResponse : public Core::NetworkResponse {
public:
virtual ~HttpResponse() override;
static NonnullRefPtr<HttpResponse> create(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers, ByteBuffer&& payload)
static NonnullRefPtr<HttpResponse> create(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers)
{
return adopt(*new HttpResponse(code, move(headers), move(payload)));
return adopt(*new HttpResponse(code, move(headers)));
}
int code() const { return m_code; }
const HashMap<String, String, CaseInsensitiveStringTraits>& headers() const { return m_headers; }
private:
HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&&, ByteBuffer&&);
HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&&);
int m_code { 0 };
HashMap<String, String, CaseInsensitiveStringTraits> m_headers;

View file

@ -143,7 +143,7 @@ bool HttpsJob::eof() const
return m_socket->eof();
}
bool HttpsJob::write(const ByteBuffer& data)
bool HttpsJob::write(ReadonlyBytes data)
{
return m_socket->write(data);
}

View file

@ -38,8 +38,8 @@ namespace HTTP {
class HttpsJob final : public Job {
C_OBJECT(HttpsJob)
public:
explicit HttpsJob(const HttpRequest& request, const Vector<Certificate>* override_certs = nullptr)
: Job(request)
explicit HttpsJob(const HttpRequest& request, OutputStream& output_stream, const Vector<Certificate>* override_certs = nullptr)
: Job(request, output_stream)
, m_override_ca_certificates(override_certs)
{
}
@ -62,7 +62,7 @@ protected:
virtual bool can_read() const override;
virtual ByteBuffer receive(size_t) override;
virtual bool eof() const override;
virtual bool write(const ByteBuffer&) override;
virtual bool write(ReadonlyBytes) override;
virtual bool is_established() const override { return m_socket->is_established(); }
virtual bool should_fail_on_empty_payload() const override { return false; }
virtual void read_while_data_available(Function<IterationDecision()>) override;

View file

@ -68,8 +68,9 @@ static ByteBuffer handle_content_encoding(const ByteBuffer& buf, const String& c
return buf;
}
Job::Job(const HttpRequest& request)
: m_request(request)
Job::Job(const HttpRequest& request, OutputStream& output_stream)
: Core::NetworkJob(output_stream)
, m_request(request)
{
}
@ -77,6 +78,35 @@ Job::~Job()
{
}
void Job::flush_received_buffers()
{
if (!m_can_stream_response || m_buffered_size == 0)
return;
#ifdef JOB_DEBUG
dbg() << "Job: Flushing received buffers: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers";
#endif
for (size_t i = 0; i < m_received_buffers.size(); ++i) {
auto& payload = m_received_buffers[i];
auto written = do_write(payload);
m_buffered_size -= written;
if (written == payload.size()) {
// FIXME: Make this a take-first-friendly object?
m_received_buffers.take_first();
--i;
continue;
}
ASSERT(written < payload.size());
payload = payload.slice(written, payload.size() - written);
#ifdef JOB_DEBUG
dbg() << "Job: Flushing received buffers done: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers";
#endif
return;
}
#ifdef JOB_DEBUG
dbg() << "Job: Flushing received buffers done: have " << m_buffered_size << " bytes in " << m_received_buffers.size() << " buffers";
#endif
}
void Job::on_socket_connected()
{
register_on_ready_to_write([&] {
@ -135,6 +165,8 @@ void Job::on_socket_connected()
if (m_state == State::Trailers) {
return finish_up();
} else {
if (on_headers_received)
on_headers_received(m_headers, m_code > 0 ? m_code : Optional<u32> {});
m_state = State::InBody;
}
return;
@ -163,6 +195,13 @@ void Job::on_socket_connected()
}
auto value = line.substring(name.length() + 2, line.length() - name.length() - 2);
m_headers.set(name, value);
if (name.equals_ignoring_case("Content-Encoding")) {
// Assume that any content-encoding means that we can't decode it as a stream :(
#ifdef JOB_DEBUG
dbg() << "Content-Encoding " << value << " detected, cannot stream output :(";
#endif
m_can_stream_response = false;
}
#ifdef JOB_DEBUG
dbg() << "Job: [" << name << "] = '" << value << "'";
#endif
@ -252,7 +291,9 @@ void Job::on_socket_connected()
}
m_received_buffers.append(payload);
m_buffered_size += payload.size();
m_received_size += payload.size();
flush_received_buffers();
if (m_current_chunk_remaining_size.has_value()) {
auto size = m_current_chunk_remaining_size.value() - payload.size();
@ -313,20 +354,37 @@ void Job::on_socket_connected()
void Job::finish_up()
{
m_state = State::Finished;
auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size);
u8* flat_ptr = flattened_buffer.data();
for (auto& received_buffer : m_received_buffers) {
memcpy(flat_ptr, received_buffer.data(), received_buffer.size());
flat_ptr += received_buffer.size();
}
m_received_buffers.clear();
if (!m_can_stream_response) {
auto flattened_buffer = ByteBuffer::create_uninitialized(m_received_size);
u8* flat_ptr = flattened_buffer.data();
for (auto& received_buffer : m_received_buffers) {
memcpy(flat_ptr, received_buffer.data(), received_buffer.size());
flat_ptr += received_buffer.size();
}
m_received_buffers.clear();
auto content_encoding = m_headers.get("Content-Encoding");
if (content_encoding.has_value()) {
flattened_buffer = handle_content_encoding(flattened_buffer, content_encoding.value());
// For the time being, we cannot stream stuff with content-encoding set to _anything_.
auto content_encoding = m_headers.get("Content-Encoding");
if (content_encoding.has_value()) {
flattened_buffer = handle_content_encoding(flattened_buffer, content_encoding.value());
}
m_buffered_size = flattened_buffer.size();
m_received_buffers.append(move(flattened_buffer));
m_can_stream_response = true;
}
auto response = HttpResponse::create(m_code, move(m_headers), move(flattened_buffer));
flush_received_buffers();
if (m_buffered_size != 0) {
// FIXME: What do we do? ignore it?
// "Transmission failed" is not strictly correct, but let's roll with it for now.
deferred_invoke([this](auto&) {
did_fail(Error::TransmissionFailed);
});
return;
}
auto response = HttpResponse::create(m_code, move(m_headers));
deferred_invoke([this, response](auto&) {
did_finish(move(response));
});

View file

@ -26,6 +26,7 @@
#pragma once
#include <AK/FileStream.h>
#include <AK/HashMap.h>
#include <AK/Optional.h>
#include <LibCore/NetworkJob.h>
@ -37,7 +38,7 @@ namespace HTTP {
class Job : public Core::NetworkJob {
public:
explicit Job(const HttpRequest&);
explicit Job(const HttpRequest&, OutputStream&);
virtual ~Job() override;
virtual void start() override = 0;
@ -49,6 +50,7 @@ public:
protected:
void finish_up();
void on_socket_connected();
void flush_received_buffers();
virtual void register_on_ready_to_read(Function<void()>) = 0;
virtual void register_on_ready_to_write(Function<void()>) = 0;
virtual bool can_read_line() const = 0;
@ -56,7 +58,7 @@ protected:
virtual bool can_read() const = 0;
virtual ByteBuffer receive(size_t) = 0;
virtual bool eof() const = 0;
virtual bool write(const ByteBuffer&) = 0;
virtual bool write(ReadonlyBytes) = 0;
virtual bool is_established() const = 0;
virtual bool should_fail_on_empty_payload() const { return true; }
virtual void read_while_data_available(Function<IterationDecision()> read) { read(); };
@ -73,11 +75,13 @@ protected:
State m_state { State::InStatus };
int m_code { -1 };
HashMap<String, String, CaseInsensitiveStringTraits> m_headers;
Vector<ByteBuffer> m_received_buffers;
Vector<ByteBuffer, 2> m_received_buffers;
size_t m_buffered_size { 0 };
size_t m_received_size { 0 };
bool m_sent_data { 0 };
Optional<ssize_t> m_current_chunk_remaining_size;
Optional<size_t> m_current_chunk_total_size;
bool m_can_stream_response { true };
};
}

View file

@ -24,6 +24,7 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <AK/FileStream.h>
#include <AK/SharedBuffer.h>
#include <LibProtocol/Client.h>
#include <LibProtocol/Download.h>
@ -47,16 +48,20 @@ bool Client::is_supported_protocol(const String& protocol)
return send_sync<Messages::ProtocolServer::IsSupportedProtocol>(protocol)->supported();
}
RefPtr<Download> Client::start_download(const String& method, const String& url, const HashMap<String, String>& request_headers, const ByteBuffer& request_body)
template<typename RequestHashMapTraits>
RefPtr<Download> Client::start_download(const String& method, const String& url, const HashMap<String, String, RequestHashMapTraits>& request_headers, ReadonlyBytes request_body)
{
IPC::Dictionary header_dictionary;
for (auto& it : request_headers)
header_dictionary.add(it.key, it.value);
i32 download_id = send_sync<Messages::ProtocolServer::StartDownload>(method, url, header_dictionary, String::copy(request_body))->download_id();
if (download_id < 0)
auto response = send_sync<Messages::ProtocolServer::StartDownload>(method, url, header_dictionary, ByteBuffer::copy(request_body));
auto download_id = response->download_id();
auto response_fd = response->response_fd().fd();
if (download_id < 0 || response_fd < 0)
return nullptr;
auto download = Download::create_from_id({}, *this, download_id);
download->set_download_fd({}, response_fd);
m_downloads.set(download_id, download);
return download;
}
@ -79,9 +84,8 @@ void Client::handle(const Messages::ProtocolClient::DownloadFinished& message)
{
RefPtr<Download> download;
if ((download = m_downloads.get(message.download_id()).value_or(nullptr))) {
download->did_finish({}, message.success(), message.status_code(), message.total_size(), message.shbuf_id(), message.response_headers());
download->did_finish({}, message.success(), message.total_size());
}
send_sync<Messages::ProtocolServer::DisownSharedBuffer>(message.shbuf_id());
m_downloads.remove(message.download_id());
}
@ -92,6 +96,15 @@ void Client::handle(const Messages::ProtocolClient::DownloadProgress& message)
}
}
void Client::handle(const Messages::ProtocolClient::HeadersBecameAvailable& message)
{
if (auto download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr))) {
HashMap<String, String, CaseInsensitiveStringTraits> headers;
message.response_headers().for_each_entry([&](auto& name, auto& value) { headers.set(name, value); });
download->did_receive_headers({}, headers, message.status_code());
}
}
OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> Client::handle(const Messages::ProtocolClient::CertificateRequested& message)
{
if (auto download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr))) {
@ -102,3 +115,6 @@ OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> Client::handle(co
}
}
template RefPtr<Protocol::Download> Protocol::Client::start_download(const String& method, const String& url, const HashMap<String, String>& request_headers, ReadonlyBytes request_body);
template RefPtr<Protocol::Download> Protocol::Client::start_download(const String& method, const String& url, const HashMap<String, String, CaseInsensitiveStringTraits>& request_headers, ReadonlyBytes request_body);

View file

@ -44,7 +44,8 @@ public:
virtual void handshake() override;
bool is_supported_protocol(const String&);
RefPtr<Download> start_download(const String& method, const String& url, const HashMap<String, String>& request_headers = {}, const ByteBuffer& request_body = {});
template<typename RequestHashMapTraits = Traits<String>>
RefPtr<Download> start_download(const String& method, const String& url, const HashMap<String, String, RequestHashMapTraits>& request_headers = {}, ReadonlyBytes request_body = {});
bool stop_download(Badge<Download>, Download&);
bool set_certificate(Badge<Download>, Download&, String, String);
@ -55,6 +56,7 @@ private:
virtual void handle(const Messages::ProtocolClient::DownloadProgress&) override;
virtual void handle(const Messages::ProtocolClient::DownloadFinished&) override;
virtual OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> handle(const Messages::ProtocolClient::CertificateRequested&) override;
virtual void handle(const Messages::ProtocolClient::HeadersBecameAvailable&) override;
HashMap<i32, RefPtr<Download>> m_downloads;
};

View file

@ -41,25 +41,81 @@ bool Download::stop()
return m_client->stop_download({}, *this);
}
void Download::did_finish(Badge<Client>, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers)
void Download::stream_into(OutputStream& stream)
{
ASSERT(!m_internal_stream_data);
auto notifier = Core::Notifier::construct(fd(), Core::Notifier::Read);
m_internal_stream_data = make<InternalStreamData>(fd());
m_internal_stream_data->read_notifier = notifier;
auto user_on_finish = move(on_finish);
on_finish = [this](auto success, auto total_size) {
m_internal_stream_data->success = success;
m_internal_stream_data->total_size = total_size;
m_internal_stream_data->download_done = true;
};
notifier->on_ready_to_read = [this, &stream, user_on_finish = move(user_on_finish)] {
constexpr size_t buffer_size = 1 * KiB;
static char buf[buffer_size];
auto nread = m_internal_stream_data->read_stream.read({ buf, buffer_size });
if (!stream.write_or_error({ buf, nread })) {
// FIXME: What do we do here?
TODO();
}
if (m_internal_stream_data->read_stream.eof() || (m_internal_stream_data->download_done && !m_internal_stream_data->success)) {
m_internal_stream_data->read_notifier->close();
user_on_finish(m_internal_stream_data->success, m_internal_stream_data->total_size);
} else {
m_internal_stream_data->read_stream.handle_any_error();
}
};
}
void Download::set_should_buffer_all_input(bool value)
{
if (m_should_buffer_all_input == value)
return;
if (m_internal_buffered_data && !value) {
m_internal_buffered_data = nullptr;
m_should_buffer_all_input = false;
return;
}
ASSERT(!m_internal_stream_data);
ASSERT(!m_internal_buffered_data);
ASSERT(on_buffered_download_finish); // Not having this set makes no sense.
m_internal_buffered_data = make<InternalBufferedData>(fd());
m_should_buffer_all_input = true;
on_headers_received = [this](auto& headers, auto response_code) {
m_internal_buffered_data->response_headers = headers;
m_internal_buffered_data->response_code = move(response_code);
};
on_finish = [this](auto success, u32 total_size) {
auto output_buffer = m_internal_buffered_data->payload_stream.copy_into_contiguous_buffer();
on_buffered_download_finish(
success,
total_size,
m_internal_buffered_data->response_headers,
m_internal_buffered_data->response_code,
output_buffer);
};
stream_into(m_internal_buffered_data->payload_stream);
}
void Download::did_finish(Badge<Client>, bool success, u32 total_size)
{
if (!on_finish)
return;
ReadonlyBytes payload;
RefPtr<SharedBuffer> shared_buffer;
if (success && shbuf_id != -1) {
shared_buffer = SharedBuffer::create_from_shbuf_id(shbuf_id);
payload = { shared_buffer->data<void>(), total_size };
}
// FIXME: It's a bit silly that we copy the response headers here just so we can move them into a HashMap with different traits.
HashMap<String, String, CaseInsensitiveStringTraits> caseless_response_headers;
response_headers.for_each_entry([&](auto& name, auto& value) {
caseless_response_headers.set(name, value);
});
on_finish(success, payload, move(shared_buffer), caseless_response_headers, status_code);
on_finish(success, total_size);
}
void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size)
@ -68,6 +124,12 @@ void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloa
on_progress(total_size, downloaded_size);
}
void Download::did_receive_headers(Badge<Client>, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code)
{
if (on_headers_received)
on_headers_received(response_headers, response_code);
}
void Download::did_request_certificates(Badge<Client>)
{
if (on_certificate_requested) {

View file

@ -28,10 +28,13 @@
#include <AK/Badge.h>
#include <AK/ByteBuffer.h>
#include <AK/FileStream.h>
#include <AK/Function.h>
#include <AK/MemoryStream.h>
#include <AK/RefCounted.h>
#include <AK/String.h>
#include <AK/WeakPtr.h>
#include <LibCore/Notifier.h>
#include <LibIPC/Forward.h>
namespace Protocol {
@ -51,20 +54,65 @@ public:
}
int id() const { return m_download_id; }
int fd() const { return m_fd; }
bool stop();
Function<void(bool success, ReadonlyBytes payload, RefPtr<SharedBuffer> payload_storage, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> status_code)> on_finish;
void stream_into(OutputStream&);
bool should_buffer_all_input() const { return m_should_buffer_all_input; }
/// Note: Will override `on_finish', and `on_headers_received', and expects `on_buffered_download_finish' to be set!
void set_should_buffer_all_input(bool);
/// Note: Must be set before `set_should_buffer_all_input(true)`.
Function<void(bool success, u32 total_size, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code, ReadonlyBytes payload)> on_buffered_download_finish;
Function<void(bool success, u32 total_size)> on_finish;
Function<void(Optional<u32> total_size, u32 downloaded_size)> on_progress;
Function<void(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code)> on_headers_received;
Function<CertificateAndKey()> on_certificate_requested;
void did_finish(Badge<Client>, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers);
void did_finish(Badge<Client>, bool success, u32 total_size);
void did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size);
void did_receive_headers(Badge<Client>, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> response_code);
void did_request_certificates(Badge<Client>);
RefPtr<Core::Notifier>& write_notifier(Badge<Client>) { return m_write_notifier; }
void set_download_fd(Badge<Client>, int fd) { m_fd = fd; }
private:
explicit Download(Client&, i32 download_id);
WeakPtr<Client> m_client;
int m_download_id { -1 };
RefPtr<Core::Notifier> m_write_notifier;
int m_fd { -1 };
bool m_should_buffer_all_input { false };
struct InternalBufferedData {
InternalBufferedData(int fd)
: read_stream(fd)
{
}
InputFileStream read_stream;
DuplexMemoryStream payload_stream;
HashMap<String, String, CaseInsensitiveStringTraits> response_headers;
Optional<u32> response_code;
};
struct InternalStreamData {
InternalStreamData(int fd)
: read_stream(fd)
{
}
InputFileStream read_stream;
RefPtr<Core::Notifier> read_notifier;
bool success;
u32 total_size { 0 };
bool download_done { false };
};
OwnPtr<InternalBufferedData> m_internal_buffered_data;
OwnPtr<InternalStreamData> m_internal_stream_data;
};
}

View file

@ -92,10 +92,10 @@ void XMLHttpRequest::send()
// we need to make ResourceLoader give us more detailed updates than just "done" and "error".
ResourceLoader::the().load(
m_window->document().complete_url(m_url),
[weak_this = make_weak_ptr()](auto& data, auto&) {
[weak_this = make_weak_ptr()](auto data, auto&) {
if (!weak_this)
return;
const_cast<XMLHttpRequest&>(*weak_this).m_response = data;
const_cast<XMLHttpRequest&>(*weak_this).m_response = ByteBuffer::copy(data);
const_cast<XMLHttpRequest&>(*weak_this).set_ready_state(ReadyState::Done);
const_cast<XMLHttpRequest&>(*weak_this).dispatch_event(DOM::Event::create(HTML::EventNames::load));
},

View file

@ -128,7 +128,7 @@ void HTMLScriptElement::prepare_script(Badge<HTMLDocumentParser>)
// FIXME: This load should be made asynchronous and the parser should spin an event loop etc.
ResourceLoader::the().load_sync(
url,
[this, url](auto& data, auto&) {
[this, url](auto data, auto&) {
if (data.is_null()) {
dbg() << "HTMLScriptElement: Failed to load " << url;
return;

View file

@ -171,6 +171,7 @@ bool FrameLoader::load(const LoadRequest& request, Type type)
return true;
if (url.protocol() == "http" || url.protocol() == "https") {
#if 0
URL favicon_url;
favicon_url.set_protocol(url.protocol());
favicon_url.set_host(url.host());
@ -191,6 +192,7 @@ bool FrameLoader::load(const LoadRequest& request, Type type)
if (auto* page = frame().page())
page->client().page_did_change_favicon(*bitmap);
});
#endif
}
return true;

View file

@ -84,10 +84,10 @@ static String mime_type_from_content_type(const String& content_type)
return content_type;
}
void Resource::did_load(Badge<ResourceLoader>, const ByteBuffer& data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers)
void Resource::did_load(Badge<ResourceLoader>, ReadonlyBytes data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers)
{
ASSERT(!m_loaded);
m_encoded_data = data;
m_encoded_data = ByteBuffer::copy(data);
m_response_headers = headers;
m_loaded = true;

View file

@ -77,7 +77,7 @@ public:
void for_each_client(Function<void(ResourceClient&)>);
void did_load(Badge<ResourceLoader>, const ByteBuffer& data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers);
void did_load(Badge<ResourceLoader>, ReadonlyBytes data, const HashMap<String, String, CaseInsensitiveStringTraits>& headers);
void did_fail(Badge<ResourceLoader>, const String& error);
protected:

View file

@ -53,13 +53,13 @@ ResourceLoader::ResourceLoader()
{
}
void ResourceLoader::load_sync(const URL& url, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
void ResourceLoader::load_sync(const URL& url, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
{
Core::EventLoop loop;
load(
url,
[&](auto& data, auto& response_headers) {
[&](auto data, auto& response_headers) {
success_callback(data, response_headers);
loop.quit(0);
},
@ -97,7 +97,7 @@ RefPtr<Resource> ResourceLoader::load_resource(Resource::Type type, const LoadRe
load(
request,
[=](auto& data, auto& headers) {
[=](auto data, auto& headers) {
const_cast<Resource&>(*resource).did_load({}, data, headers);
},
[=](auto& error) {
@ -107,7 +107,7 @@ RefPtr<Resource> ResourceLoader::load_resource(Resource::Type type, const LoadRe
return resource;
}
void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
void ResourceLoader::load(const LoadRequest& request, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
{
auto& url = request.url();
if (is_port_blocked(url.port())) {
@ -170,7 +170,12 @@ void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBu
error_callback("Failed to initiate load");
return;
}
download->on_finish = [this, success_callback = move(success_callback), error_callback = move(error_callback)](bool success, ReadonlyBytes payload, auto, auto& response_headers, auto status_code) {
download->on_buffered_download_finish = [this, success_callback = move(success_callback), error_callback = move(error_callback), download](bool success, auto, auto& response_headers, auto status_code, ReadonlyBytes payload) {
if (status_code.has_value() && status_code.value() >= 400 && status_code.value() <= 499) {
if (error_callback)
error_callback(String::format("HTTP error (%u)", status_code.value()));
return;
}
--m_pending_loads;
if (on_load_counter_change)
on_load_counter_change();
@ -179,13 +184,9 @@ void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBu
error_callback("HTTP load failed");
return;
}
if (status_code.has_value() && status_code.value() >= 400 && status_code.value() <= 499) {
if (error_callback)
error_callback(String::format("HTTP error (%u)", status_code.value()));
return;
}
success_callback(ByteBuffer::copy(payload.data(), payload.size()), response_headers);
success_callback(payload, response_headers);
};
download->set_should_buffer_all_input(true);
download->on_certificate_requested = []() -> Protocol::Download::CertificateAndKey {
return {};
};
@ -199,7 +200,7 @@ void ResourceLoader::load(const LoadRequest& request, Function<void(const ByteBu
error_callback(String::format("Protocol not implemented: %s", url.protocol().characters()));
}
void ResourceLoader::load(const URL& url, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
void ResourceLoader::load(const URL& url, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback)
{
LoadRequest request;
request.set_url(url);

View file

@ -44,9 +44,9 @@ public:
RefPtr<Resource> load_resource(Resource::Type, const LoadRequest&);
void load(const LoadRequest&, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
void load(const URL&, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
void load_sync(const URL&, Function<void(const ByteBuffer&, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
void load(const LoadRequest&, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
void load(const URL&, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
void load_sync(const URL&, Function<void(ReadonlyBytes, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)> success_callback, Function<void(const String&)> error_callback = nullptr);
Function<void()> on_load_counter_change;

View file

@ -62,16 +62,17 @@ OwnPtr<Messages::ProtocolServer::StartDownloadResponse> ClientConnection::handle
{
URL url(message.url());
if (!url.is_valid())
return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1);
auto* protocol = Protocol::find_by_name(url.protocol());
if (!protocol)
return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
auto download = protocol->start_download(*this, message.method(), url, message.request_headers().entries(), message.request_body().to_byte_buffer());
return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1);
auto download = protocol->start_download(*this, message.method(), url, message.request_headers().entries(), message.request_body());
if (!download)
return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1);
auto id = download->id();
auto fd = download->download_fd();
m_downloads.set(id, move(download));
return make<Messages::ProtocolServer::StartDownloadResponse>(id);
return make<Messages::ProtocolServer::StartDownloadResponse>(id, fd);
}
OwnPtr<Messages::ProtocolServer::StopDownloadResponse> ClientConnection::handle(const Messages::ProtocolServer::StopDownload& message)
@ -86,22 +87,20 @@ OwnPtr<Messages::ProtocolServer::StopDownloadResponse> ClientConnection::handle(
return make<Messages::ProtocolServer::StopDownloadResponse>(success);
}
void ClientConnection::did_finish_download(Badge<Download>, Download& download, bool success)
void ClientConnection::did_receive_headers(Badge<Download>, Download& download)
{
RefPtr<SharedBuffer> buffer;
if (success && download.payload().size() > 0 && !download.payload().is_null()) {
buffer = SharedBuffer::create_with_size(download.payload().size());
memcpy(buffer->data<void>(), download.payload().data(), download.payload().size());
buffer->seal();
buffer->share_with(client_pid());
m_shared_buffers.set(buffer->shbuf_id(), buffer);
}
ASSERT(download.total_size().has_value());
IPC::Dictionary response_headers;
for (auto& it : download.response_headers())
response_headers.add(it.key, it.value);
post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.status_code(), download.total_size().value(), buffer ? buffer->shbuf_id() : -1, response_headers));
post_message(Messages::ProtocolClient::HeadersBecameAvailable(download.id(), move(response_headers), download.status_code()));
}
void ClientConnection::did_finish_download(Badge<Download>, Download& download, bool success)
{
ASSERT(download.total_size().has_value());
post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.total_size().value()));
m_downloads.remove(download.id());
}
@ -121,12 +120,6 @@ OwnPtr<Messages::ProtocolServer::GreetResponse> ClientConnection::handle(const M
return make<Messages::ProtocolServer::GreetResponse>(client_id());
}
OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> ClientConnection::handle(const Messages::ProtocolServer::DisownSharedBuffer& message)
{
m_shared_buffers.remove(message.shbuf_id());
return make<Messages::ProtocolServer::DisownSharedBufferResponse>();
}
OwnPtr<Messages::ProtocolServer::SetCertificateResponse> ClientConnection::handle(const Messages::ProtocolServer::SetCertificate& message)
{
auto* download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr));

View file

@ -45,6 +45,7 @@ public:
virtual void die() override;
void did_receive_headers(Badge<Download>, Download&);
void did_finish_download(Badge<Download>, Download&, bool success);
void did_progress_download(Badge<Download>, Download&);
void did_request_certificates(Badge<Download>, Download&);
@ -54,11 +55,9 @@ private:
virtual OwnPtr<Messages::ProtocolServer::IsSupportedProtocolResponse> handle(const Messages::ProtocolServer::IsSupportedProtocol&) override;
virtual OwnPtr<Messages::ProtocolServer::StartDownloadResponse> handle(const Messages::ProtocolServer::StartDownload&) override;
virtual OwnPtr<Messages::ProtocolServer::StopDownloadResponse> handle(const Messages::ProtocolServer::StopDownload&) override;
virtual OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> handle(const Messages::ProtocolServer::DisownSharedBuffer&) override;
virtual OwnPtr<Messages::ProtocolServer::SetCertificateResponse> handle(const Messages::ProtocolServer::SetCertificate&);
virtual OwnPtr<Messages::ProtocolServer::SetCertificateResponse> handle(const Messages::ProtocolServer::SetCertificate&) override;
HashMap<i32, OwnPtr<Download>> m_downloads;
HashMap<i32, RefPtr<AK::SharedBuffer>> m_shared_buffers;
};
}

View file

@ -33,9 +33,10 @@ namespace ProtocolServer {
// FIXME: What about rollover?
static i32 s_next_id = 1;
Download::Download(ClientConnection& client)
Download::Download(ClientConnection& client, NonnullOwnPtr<OutputFileStream>&& output_stream)
: m_client(client)
, m_id(s_next_id++)
, m_output_stream(move(output_stream))
{
}
@ -48,15 +49,10 @@ void Download::stop()
m_client.did_finish_download({}, *this, false);
}
void Download::set_payload(const ByteBuffer& payload)
{
m_payload = payload;
m_total_size = payload.size();
}
void Download::set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers)
{
m_response_headers = response_headers;
m_client.did_receive_headers({}, *this);
}
void Download::set_certificate(String, String)

View file

@ -26,8 +26,9 @@
#pragma once
#include <AK/ByteBuffer.h>
#include <AK/FileStream.h>
#include <AK/HashMap.h>
#include <AK/NonnullOwnPtr.h>
#include <AK/Optional.h>
#include <AK/RefCounted.h>
#include <AK/URL.h>
@ -45,30 +46,35 @@ public:
Optional<u32> status_code() const { return m_status_code; }
Optional<u32> total_size() const { return m_total_size; }
size_t downloaded_size() const { return m_downloaded_size; }
const ByteBuffer& payload() const { return m_payload; }
const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers() const { return m_response_headers; }
void stop();
virtual void set_certificate(String, String);
// FIXME: Want Badge<Protocol>, but can't make one from HttpProtocol, etc.
void set_download_fd(int fd) { m_download_fd = fd; }
int download_fd() const { return m_download_fd; }
protected:
explicit Download(ClientConnection&);
explicit Download(ClientConnection&, NonnullOwnPtr<OutputFileStream>&&);
void did_finish(bool success);
void did_progress(Optional<u32> total_size, u32 downloaded_size);
void set_status_code(u32 status_code) { m_status_code = status_code; }
void did_request_certificates();
void set_payload(const ByteBuffer&);
void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&);
void set_downloaded_size(size_t size) { m_downloaded_size = size; }
const OutputFileStream& output_stream() const { return *m_output_stream; }
private:
ClientConnection& m_client;
i32 m_id { 0 };
int m_download_fd { -1 }; // Passed to client.
URL m_url;
Optional<u32> m_status_code;
Optional<u32> m_total_size {};
size_t m_downloaded_size { 0 };
ByteBuffer m_payload;
NonnullOwnPtr<OutputFileStream> m_output_stream;
HashMap<String, String, CaseInsensitiveStringTraits> m_response_headers;
};

View file

@ -30,13 +30,13 @@
namespace ProtocolServer {
GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job)
: Download(client)
GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
: Download(client, move(output_stream))
, m_job(job)
{
m_job->on_finish = [this](bool success) {
if (auto* response = m_job->response()) {
set_payload(response->payload());
set_downloaded_size(this->output_stream().size());
if (!response->meta().is_empty()) {
HashMap<String, String, CaseInsensitiveStringTraits> headers;
headers.set("meta", response->meta());
@ -76,9 +76,9 @@ GeminiDownload::~GeminiDownload()
m_job->shutdown();
}
NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job)
NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
{
return adopt_own(*new GeminiDownload(client, move(job)));
return adopt_own(*new GeminiDownload(client, move(job), move(output_stream)));
}
}

View file

@ -36,10 +36,10 @@ namespace ProtocolServer {
class GeminiDownload final : public Download {
public:
virtual ~GeminiDownload() override;
static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>);
static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>, NonnullOwnPtr<OutputFileStream>&&);
private:
explicit GeminiDownload(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>);
explicit GeminiDownload(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>, NonnullOwnPtr<OutputFileStream>&&);
virtual void set_certificate(String certificate, String key) override;

View file

@ -40,12 +40,22 @@ GeminiProtocol::~GeminiProtocol()
{
}
OwnPtr<Download> GeminiProtocol::start_download(ClientConnection& client, const String&, const URL& url, const HashMap<String, String>&, const ByteBuffer&)
OwnPtr<Download> GeminiProtocol::start_download(ClientConnection& client, const String&, const URL& url, const HashMap<String, String>&, ReadonlyBytes)
{
Gemini::GeminiRequest request;
request.set_url(url);
auto job = Gemini::GeminiJob::construct(request);
auto download = GeminiDownload::create_with_job({}, client, (Gemini::GeminiJob&)*job);
int fd_pair[2] { 0 };
if (pipe(fd_pair) != 0) {
auto saved_errno = errno;
dbgln("Protocol: pipe() failed: {}", strerror(saved_errno));
return nullptr;
}
auto output_stream = make<OutputFileStream>(fd_pair[1]);
output_stream->make_unbuffered();
auto job = Gemini::GeminiJob::construct(request, *output_stream);
auto download = GeminiDownload::create_with_job({}, client, (Gemini::GeminiJob&)*job, move(output_stream));
download->set_download_fd(fd_pair[0]);
job->start();
return download;
}

View file

@ -35,7 +35,7 @@ public:
GeminiProtocol();
virtual ~GeminiProtocol() override;
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>&, const ByteBuffer& request_body) override;
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>&, ReadonlyBytes body) override;
};
}

View file

@ -30,15 +30,21 @@
namespace ProtocolServer {
HttpDownload::HttpDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job)
: Download(client)
HttpDownload::HttpDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
: Download(client, move(output_stream))
, m_job(job)
{
m_job->on_headers_received = [this](auto& headers, auto response_code) {
if (response_code.has_value())
set_status_code(response_code.value());
set_response_headers(headers);
};
m_job->on_finish = [this](bool success) {
if (auto* response = m_job->response()) {
set_status_code(response->code());
set_payload(response->payload());
set_response_headers(response->headers());
set_downloaded_size(this->output_stream().size());
}
// if we didn't know the total size, pretend that the download finished successfully
@ -60,9 +66,9 @@ HttpDownload::~HttpDownload()
m_job->shutdown();
}
NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job)
NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
{
return adopt_own(*new HttpDownload(client, move(job)));
return adopt_own(*new HttpDownload(client, move(job), move(output_stream)));
}
}

View file

@ -36,10 +36,10 @@ namespace ProtocolServer {
class HttpDownload final : public Download {
public:
virtual ~HttpDownload() override;
static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpJob>);
static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpJob>, NonnullOwnPtr<OutputFileStream>&&);
private:
explicit HttpDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpJob>);
explicit HttpDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpJob>, NonnullOwnPtr<OutputFileStream>&&);
NonnullRefPtr<HTTP::HttpJob> m_job;
};

View file

@ -28,6 +28,7 @@
#include <LibHTTP/HttpRequest.h>
#include <ProtocolServer/HttpDownload.h>
#include <ProtocolServer/HttpProtocol.h>
#include <fcntl.h>
namespace ProtocolServer {
@ -40,7 +41,7 @@ HttpProtocol::~HttpProtocol()
{
}
OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, const ByteBuffer& request_body)
OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, ReadonlyBytes body)
{
HTTP::HttpRequest request;
if (method.equals_ignoring_case("post"))
@ -49,9 +50,20 @@ OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const St
request.set_method(HTTP::HttpRequest::Method::GET);
request.set_url(url);
request.set_headers(headers);
request.set_body(request_body);
auto job = HTTP::HttpJob::construct(request);
auto download = HttpDownload::create_with_job({}, client, (HTTP::HttpJob&)*job);
request.set_body(body);
int fd_pair[2] { 0 };
if (pipe(fd_pair) != 0) {
auto saved_errno = errno;
dbgln("Protocol: pipe() failed: {}", strerror(saved_errno));
return nullptr;
}
auto output_stream = make<OutputFileStream>(fd_pair[1]);
output_stream->make_unbuffered();
auto job = HTTP::HttpJob::construct(request, *output_stream);
auto download = HttpDownload::create_with_job({}, client, (HTTP::HttpJob&)*job, move(output_stream));
download->set_download_fd(fd_pair[0]);
job->start();
return download;
}

View file

@ -35,7 +35,7 @@ public:
HttpProtocol();
virtual ~HttpProtocol() override;
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) override;
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) override;
};
}

View file

@ -30,15 +30,21 @@
namespace ProtocolServer {
HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job)
: Download(client)
HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
: Download(client, move(output_stream))
, m_job(job)
{
m_job->on_headers_received = [this](auto& headers, auto response_code) {
if (response_code.has_value())
set_status_code(response_code.value());
set_response_headers(headers);
};
m_job->on_finish = [this](bool success) {
if (auto* response = m_job->response()) {
set_status_code(response->code());
set_payload(response->payload());
set_response_headers(response->headers());
set_downloaded_size(this->output_stream().size());
}
// if we didn't know the total size, pretend that the download finished successfully
@ -68,9 +74,9 @@ HttpsDownload::~HttpsDownload()
m_job->shutdown();
}
NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job)
NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream)
{
return adopt_own(*new HttpsDownload(client, move(job)));
return adopt_own(*new HttpsDownload(client, move(job), move(output_stream)));
}
}

View file

@ -36,10 +36,10 @@ namespace ProtocolServer {
class HttpsDownload final : public Download {
public:
virtual ~HttpsDownload() override;
static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>);
static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<OutputFileStream>&&);
private:
explicit HttpsDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>);
explicit HttpsDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<OutputFileStream>&&);
virtual void set_certificate(String certificate, String key) override;

View file

@ -40,7 +40,7 @@ HttpsProtocol::~HttpsProtocol()
{
}
OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, const ByteBuffer& request_body)
OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, ReadonlyBytes body)
{
HTTP::HttpRequest request;
if (method.equals_ignoring_case("post"))
@ -49,9 +49,19 @@ OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const S
request.set_method(HTTP::HttpRequest::Method::GET);
request.set_url(url);
request.set_headers(headers);
request.set_body(request_body);
auto job = HTTP::HttpsJob::construct(request);
auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job);
request.set_body(body);
int fd_pair[2] { 0 };
if (pipe(fd_pair) != 0) {
auto saved_errno = errno;
dbgln("Protocol: pipe() failed: {}", strerror(saved_errno));
return nullptr;
}
auto output_stream = make<OutputFileStream>(fd_pair[1]);
output_stream->make_unbuffered();
auto job = HTTP::HttpsJob::construct(request, *output_stream);
auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job, move(output_stream));
download->set_download_fd(fd_pair[0]);
job->start();
return download;
}

View file

@ -35,7 +35,7 @@ public:
HttpsProtocol();
virtual ~HttpsProtocol() override;
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) override;
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) override;
};
}

View file

@ -37,7 +37,7 @@ public:
virtual ~Protocol();
const String& name() const { return m_name; }
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) = 0;
virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) = 0;
static Protocol* find_by_name(const String&);

View file

@ -2,7 +2,8 @@ endpoint ProtocolClient = 13
{
// Download notifications
DownloadProgress(i32 download_id, Optional<u32> total_size, u32 downloaded_size) =|
DownloadFinished(i32 download_id, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, IPC::Dictionary response_headers) =|
DownloadFinished(i32 download_id, bool success, u32 total_size) =|
HeadersBecameAvailable(i32 download_id, IPC::Dictionary response_headers, Optional<u32> status_code) =|
// Certificate requests
CertificateRequested(i32 download_id) => ()

View file

@ -3,14 +3,11 @@ endpoint ProtocolServer = 9
// Basic protocol
Greet() => (i32 client_id)
// FIXME: It would be nice if the kernel provided a way to avoid this
DisownSharedBuffer(i32 shbuf_id) => ()
// Test if a specific protocol is supported, e.g "http"
IsSupportedProtocol(String protocol) => (bool supported)
// Download API
StartDownload(String method, URL url, IPC::Dictionary request_headers, String request_body) => (i32 download_id)
StartDownload(String method, URL url, IPC::Dictionary request_headers, ByteBuffer request_body) => (i32 download_id, IPC::File response_fd)
StopDownload(i32 download_id) => (bool success)
SetCertificate(i32 download_id, String certificate, String key) => (bool success)
}

View file

@ -35,7 +35,7 @@
int main(int, char**)
{
if (pledge("stdio inet shared_buffer accept unix rpath cpath fattr", nullptr) < 0) {
if (pledge("stdio inet shared_buffer accept unix rpath cpath fattr sendfd recvfd", nullptr) < 0) {
perror("pledge");
return 1;
}
@ -45,7 +45,7 @@ int main(int, char**)
Core::EventLoop event_loop;
// FIXME: Establish a connection to LookupServer and then drop "unix"?
if (pledge("stdio inet shared_buffer accept unix", nullptr) < 0) {
if (pledge("stdio inet shared_buffer accept unix sendfd recvfd", nullptr) < 0) {
perror("pledge");
return 1;
}

View file

@ -24,6 +24,7 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <AK/FileStream.h>
#include <AK/GenericLexer.h>
#include <AK/LexicalPath.h>
#include <AK/NumberFormat.h>
@ -116,29 +117,50 @@ private:
bool m_might_be_wrong { false };
};
static void do_write(ReadonlyBytes payload)
{
size_t length_remaining = payload.size();
size_t length_written = 0;
while (length_remaining > 0) {
auto nwritten = fwrite(payload.data() + length_written, sizeof(char), length_remaining, stdout);
if (nwritten > 0) {
length_remaining -= nwritten;
length_written += nwritten;
continue;
}
template<typename ConditionT>
class ConditionalOutputFileStream final : public OutputFileStream {
public:
template<typename... Args>
ConditionalOutputFileStream(ConditionT&& condition, Args... args)
: OutputFileStream(args...)
, m_condition(condition)
{
}
if (feof(stdout)) {
fprintf(stderr, "pro: unexpected eof while writing\n");
~ConditionalOutputFileStream()
{
if (!m_condition())
return;
}
if (ferror(stdout)) {
fprintf(stderr, "pro: error while writing\n");
return;
if (!m_buffer.is_empty()) {
OutputFileStream::write(m_buffer);
m_buffer.clear();
}
}
}
private:
size_t write(ReadonlyBytes bytes) override
{
if (!m_condition()) {
write_to_buffer:;
m_buffer.append(bytes.data(), bytes.size());
return bytes.size();
}
if (!m_buffer.is_empty()) {
auto size = OutputFileStream::write(m_buffer);
m_buffer = m_buffer.slice(size, m_buffer.size() - size);
}
if (!m_buffer.is_empty())
goto write_to_buffer;
return OutputFileStream::write(bytes);
}
ConditionT m_condition;
ByteBuffer m_buffer;
};
int main(int argc, char** argv)
{
@ -195,6 +217,8 @@ int main(int argc, char** argv)
timeval prev_time, current_time, time_diff;
gettimeofday(&prev_time, nullptr);
bool received_actual_headers = false;
download->on_progress = [&](Optional<u32> maybe_total_size, u32 downloaded_size) {
fprintf(stderr, "\r\033[2K");
if (maybe_total_size.has_value()) {
@ -215,10 +239,13 @@ int main(int argc, char** argv)
previous_downloaded_size = downloaded_size;
prev_time = current_time;
};
download->on_finish = [&](bool success, auto payload, auto, auto& response_headers, auto) {
fprintf(stderr, "\033]9;-1;\033\\");
fprintf(stderr, "\n");
if (success && save_at_provided_name) {
if (save_at_provided_name) {
download->on_headers_received = [&](auto& response_headers, auto status_code) {
if (received_actual_headers)
return;
dbg() << "Received headers! response code = " << status_code.value_or(0);
received_actual_headers = true; // And not trailers!
String output_name;
if (auto content_disposition = response_headers.get("Content-Disposition"); content_disposition.has_value()) {
auto& value = content_disposition.value();
@ -245,17 +272,26 @@ int main(int argc, char** argv)
if (freopen(output_name.characters(), "w", stdout) == nullptr) {
perror("freopen");
success = false; // oops!
loop.quit(1);
return;
}
}
if (success)
do_write(payload);
else
};
}
download->on_finish = [&](bool success, auto) {
fprintf(stderr, "\033]9;-1;\033\\");
fprintf(stderr, "\n");
if (!success)
fprintf(stderr, "Download failed :(\n");
loop.quit(0);
};
auto output_stream = ConditionalOutputFileStream { [&] { return save_at_provided_name ? received_actual_headers : true; }, stdout };
download->stream_into(output_stream);
dbgprintf("started download with id %d\n", download->id());
return loop.exec();
auto rc = loop.exec();
// FIXME: This shouldn't be needed.
fclose(stdout);
return rc;
}