Explorar o código

ProtocolServer: Support request headers

You can now pass a dictionary of request headers when starting a new
download in ProtocolServer.

The HTTP and HTTPS protocol will include the headers in their requests.
Andreas Kling %!s(int64=5) %!d(string=hai) anos
pai
achega
897998017a

+ 14 - 1
Libraries/LibHTTP/HttpRequest.cpp

@@ -71,7 +71,14 @@ ByteBuffer HttpRequest::to_raw_request() const
     }
     builder.append(" HTTP/1.1\r\nHost: ");
     builder.append(m_url.host());
-    builder.append("\r\nConnection: close\r\n\r\n");
+    builder.append("\r\n");
+    for (auto& header : m_headers) {
+        builder.append(header.name);
+        builder.append(": ");
+        builder.append(header.value);
+        builder.append("\r\n");
+    }
+    builder.append("Connection: close\r\n\r\n");
     return builder.to_byte_buffer();
 }
 
@@ -181,4 +188,10 @@ Optional<HttpRequest> HttpRequest::from_raw_request(const ByteBuffer& raw_reques
     return request;
 }
 
+void HttpRequest::set_headers(const HashMap<String,String>& headers)
+{
+    for (auto& it : headers)
+        m_headers.append({ it.key, it.value });
+}
+
 }

+ 2 - 0
Libraries/LibHTTP/HttpRequest.h

@@ -65,6 +65,8 @@ public:
 
     RefPtr<Core::NetworkJob> schedule();
 
+    void set_headers(const HashMap<String, String>&);
+
     static Optional<HttpRequest> from_raw_request(const ByteBuffer&);
 
 private:

+ 6 - 2
Libraries/LibProtocol/Client.cpp

@@ -47,9 +47,13 @@ bool Client::is_supported_protocol(const String& protocol)
     return send_sync<Messages::ProtocolServer::IsSupportedProtocol>(protocol)->supported();
 }
 
-RefPtr<Download> Client::start_download(const String& url)
+RefPtr<Download> Client::start_download(const String& url, const HashMap<String, String>& request_headers)
 {
-    i32 download_id = send_sync<Messages::ProtocolServer::StartDownload>(url)->download_id();
+    IPC::Dictionary header_dictionary;
+    for (auto& it : request_headers)
+        header_dictionary.add(it.key, it.value);
+
+    i32 download_id = send_sync<Messages::ProtocolServer::StartDownload>(url, header_dictionary)->download_id();
     if (download_id < 0)
         return nullptr;
     auto download = Download::create_from_id({}, *this, download_id);

+ 2 - 1
Libraries/LibProtocol/Client.h

@@ -44,7 +44,8 @@ public:
     virtual void handshake() override;
 
     bool is_supported_protocol(const String&);
-    RefPtr<Download> start_download(const String& url);
+    RefPtr<Download> start_download(const String& url, const HashMap<String, String>& request_headers = {});
+
 
     bool stop_download(Badge<Download>, Download&);
 

+ 1 - 1
Services/ProtocolServer/ClientConnection.cpp

@@ -64,7 +64,7 @@ OwnPtr<Messages::ProtocolServer::StartDownloadResponse> ClientConnection::handle
     auto* protocol = Protocol::find_by_name(url.protocol());
     if (!protocol)
         return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
-    auto download = protocol->start_download(*this, url);
+    auto download = protocol->start_download(*this, url, message.request_headers().entries());
     if (!download)
         return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
     auto id = download->id();

+ 1 - 1
Services/ProtocolServer/GeminiProtocol.cpp

@@ -40,7 +40,7 @@ GeminiProtocol::~GeminiProtocol()
 {
 }
 
-OwnPtr<Download> GeminiProtocol::start_download(ClientConnection& client, const URL& url)
+OwnPtr<Download> GeminiProtocol::start_download(ClientConnection& client, const URL& url, const HashMap<String, String>&)
 {
     Gemini::GeminiRequest request;
     request.set_url(url);

+ 1 - 1
Services/ProtocolServer/GeminiProtocol.h

@@ -35,7 +35,7 @@ public:
     GeminiProtocol();
     virtual ~GeminiProtocol() override;
 
-    virtual OwnPtr<Download> start_download(ClientConnection&, const URL&) override;
+    virtual OwnPtr<Download> start_download(ClientConnection&, const URL&, const HashMap<String, String>&) override;
 };
 
 }

+ 2 - 1
Services/ProtocolServer/HttpProtocol.cpp

@@ -40,11 +40,12 @@ HttpProtocol::~HttpProtocol()
 {
 }
 
-OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const URL& url)
+OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const URL& url, const HashMap<String, String>& headers)
 {
     HTTP::HttpRequest request;
     request.set_method(HTTP::HttpRequest::Method::GET);
     request.set_url(url);
+    request.set_headers(headers);
     auto job = request.schedule();
     if (!job)
         return nullptr;

+ 1 - 1
Services/ProtocolServer/HttpProtocol.h

@@ -35,7 +35,7 @@ public:
     HttpProtocol();
     virtual ~HttpProtocol() override;
 
-    virtual OwnPtr<Download> start_download(ClientConnection&, const URL&) override;
+    virtual OwnPtr<Download> start_download(ClientConnection&, const URL&, const HashMap<String, String>& headers) override;
 };
 
 }

+ 2 - 1
Services/ProtocolServer/HttpsProtocol.cpp

@@ -40,11 +40,12 @@ HttpsProtocol::~HttpsProtocol()
 {
 }
 
-OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const URL& url)
+OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const URL& url, const HashMap<String, String>& headers)
 {
     HTTP::HttpRequest request;
     request.set_method(HTTP::HttpRequest::Method::GET);
     request.set_url(url);
+    request.set_headers(headers);
     auto job = HTTP::HttpsJob::construct(request);
     auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job);
     job->start();

+ 1 - 1
Services/ProtocolServer/HttpsProtocol.h

@@ -35,7 +35,7 @@ public:
     HttpsProtocol();
     virtual ~HttpsProtocol() override;
 
-    virtual OwnPtr<Download> start_download(ClientConnection&, const URL&) override;
+    virtual OwnPtr<Download> start_download(ClientConnection&, const URL&, const HashMap<String, String>& headers) override;
 };
 
 }

+ 1 - 1
Services/ProtocolServer/Protocol.h

@@ -37,7 +37,7 @@ public:
     virtual ~Protocol();
 
     const String& name() const { return m_name; }
-    virtual OwnPtr<Download> start_download(ClientConnection&, const URL&) = 0;
+    virtual OwnPtr<Download> start_download(ClientConnection&, const URL&, const HashMap<String, String>& headers) = 0;
 
     static Protocol* find_by_name(const String&);
 

+ 1 - 1
Services/ProtocolServer/ProtocolServer.ipc

@@ -10,6 +10,6 @@ endpoint ProtocolServer = 9
     IsSupportedProtocol(String protocol) => (bool supported)
 
     // Download API
-    StartDownload(String url) => (i32 download_id)
+    StartDownload(String url, IPC::Dictionary request_headers) => (i32 download_id)
     StopDownload(i32 download_id) => (bool success)
 }