瀏覽代碼

ProtocolServer+LibTLS: Pipe certificate requests from LibTLS to clients

This makes gemini.circumlunar.space (and some more gemini pages) work
again :^)
AnotherTest 4 年之前
父節點
當前提交
97256ad977

+ 13 - 0
Libraries/LibGemini/GeminiJob.cpp

@@ -63,6 +63,10 @@ void GeminiJob::start()
     m_socket->on_tls_finished = [this] {
         finish_up();
     };
+    m_socket->on_tls_certificate_request = [this](auto&) {
+        if (on_certificate_requested)
+            on_certificate_requested(*this);
+    };
     bool success = ((TLS::TLSv12&)*m_socket).connect(m_request.url().host(), m_request.url().port());
     if (!success) {
         deferred_invoke([this](auto&) {
@@ -89,6 +93,15 @@ void GeminiJob::read_while_data_available(Function<IterationDecision()> read)
     }
 }
 
+void GeminiJob::set_certificate(String certificate, String private_key)
+{
+    if (!m_socket->add_client_key(ByteBuffer::wrap(certificate.characters(), certificate.length()), ByteBuffer::wrap(private_key.characters(), private_key.length()))) {
+        dbg() << "LibGemini: Failed to set a client certificate";
+        // FIXME: Do something about this failure
+        ASSERT_NOT_REACHED();
+    }
+}
+
 void GeminiJob::register_on_ready_to_read(Function<void()> callback)
 {
     m_socket->on_tls_ready_to_read = [callback = move(callback)](auto&) {

+ 3 - 0
Libraries/LibGemini/GeminiJob.h

@@ -48,6 +48,9 @@ public:
 
     virtual void start() override;
     virtual void shutdown() override;
+    void set_certificate(String certificate, String key);
+
+    Function<void(GeminiJob&)> on_certificate_requested;
 
 protected:
     virtual void register_on_ready_to_read(Function<void()>) override;

+ 13 - 0
Libraries/LibHTTP/HttpsJob.cpp

@@ -64,6 +64,10 @@ void HttpsJob::start()
     m_socket->on_tls_finished = [&] {
         finish_up();
     };
+    m_socket->on_tls_certificate_request = [this](auto&) {
+        if (on_certificate_requested)
+            on_certificate_requested(*this);
+    };
     bool success = ((TLS::TLSv12&)*m_socket).connect(m_request.url().host(), m_request.url().port());
     if (!success) {
         deferred_invoke([this](auto&) {
@@ -82,6 +86,15 @@ void HttpsJob::shutdown()
     m_socket = nullptr;
 }
 
+void HttpsJob::set_certificate(String certificate, String private_key)
+{
+    if (!m_socket->add_client_key(ByteBuffer::wrap(certificate.characters(), certificate.length()), ByteBuffer::wrap(private_key.characters(), private_key.length()))) {
+        dbg() << "LibHTTP: Failed to set a client certificate";
+        // FIXME: Do something about this failure
+        ASSERT_NOT_REACHED();
+    }
+}
+
 void HttpsJob::read_while_data_available(Function<IterationDecision()> read)
 {
     while (m_socket->can_read()) {

+ 3 - 0
Libraries/LibHTTP/HttpsJob.h

@@ -49,6 +49,9 @@ public:
 
     virtual void start() override;
     virtual void shutdown() override;
+    void set_certificate(String certificate, String key);
+
+    Function<void(HttpsJob&)> on_certificate_requested;
 
 protected:
     virtual void register_on_ready_to_read(Function<void()>) override;

+ 16 - 0
Libraries/LibProtocol/Client.cpp

@@ -68,6 +68,13 @@ bool Client::stop_download(Badge<Download>, Download& download)
     return send_sync<Messages::ProtocolServer::StopDownload>(download.id())->success();
 }
 
+bool Client::set_certificate(Badge<Download>, Download& download, String certificate, String key)
+{
+    if (!m_downloads.contains(download.id()))
+        return false;
+    return send_sync<Messages::ProtocolServer::SetCertificate>(download.id(), move(certificate), move(key))->success();
+}
+
 void Client::handle(const Messages::ProtocolClient::DownloadFinished& message)
 {
     RefPtr<Download> download;
@@ -85,4 +92,13 @@ void Client::handle(const Messages::ProtocolClient::DownloadProgress& message)
     }
 }
 
+OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> Client::handle(const Messages::ProtocolClient::CertificateRequested& message)
+{
+    if (auto download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr))) {
+        download->did_request_certificates({});
+    }
+
+    return make<Messages::ProtocolClient::CertificateRequestedResponse>();
+}
+
 }

+ 2 - 1
Libraries/LibProtocol/Client.h

@@ -46,14 +46,15 @@ public:
     bool is_supported_protocol(const String&);
     RefPtr<Download> start_download(const String& url, const HashMap<String, String>& request_headers = {});
 
-
     bool stop_download(Badge<Download>, Download&);
+    bool set_certificate(Badge<Download>, Download&, String, String);
 
 private:
     Client();
 
     virtual void handle(const Messages::ProtocolClient::DownloadProgress&) override;
     virtual void handle(const Messages::ProtocolClient::DownloadFinished&) override;
+    virtual OwnPtr<Messages::ProtocolClient::CertificateRequestedResponse> handle(const Messages::ProtocolClient::CertificateRequested&) override;
 
     HashMap<i32, RefPtr<Download>> m_downloads;
 };

+ 10 - 0
Libraries/LibProtocol/Download.cpp

@@ -67,4 +67,14 @@ void Download::did_progress(Badge<Client>, Optional<u32> total_size, u32 downloa
     if (on_progress)
         on_progress(total_size, downloaded_size);
 }
+
+void Download::did_request_certificates(Badge<Client>)
+{
+    if (on_certificate_requested) {
+        auto result = on_certificate_requested();
+        if (!m_client->set_certificate({}, *this, result.certificate, result.key)) {
+            dbg() << "Download: set_certificate failed";
+        }
+    }
+}
 }

+ 7 - 0
Libraries/LibProtocol/Download.h

@@ -40,6 +40,11 @@ class Client;
 
 class Download : public RefCounted<Download> {
 public:
+    struct CertificateAndKey {
+        String certificate;
+        String key;
+    };
+
     static NonnullRefPtr<Download> create_from_id(Badge<Client>, Client& client, i32 download_id)
     {
         return adopt(*new Download(client, download_id));
@@ -50,9 +55,11 @@ public:
 
     Function<void(bool success, const ByteBuffer& payload, RefPtr<SharedBuffer> payload_storage, const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers, Optional<u32> status_code)> on_finish;
     Function<void(Optional<u32> total_size, u32 downloaded_size)> on_progress;
+    Function<CertificateAndKey()> on_certificate_requested;
 
     void did_finish(Badge<Client>, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, const IPC::Dictionary& response_headers);
     void did_progress(Badge<Client>, Optional<u32> total_size, u32 downloaded_size);
+    void did_request_certificates(Badge<Client>);
 
 private:
     explicit Download(Client&, i32 download_id);

+ 25 - 0
Libraries/LibTLS/TLSv12.cpp

@@ -27,6 +27,7 @@
 #include <LibCore/DateTime.h>
 #include <LibCore/Timer.h>
 #include <LibCrypto/ASN1/DER.h>
+#include <LibCrypto/ASN1/PEM.h>
 #include <LibCrypto/PK/Code/EMSA_PSS.h>
 #include <LibTLS/TLSv12.h>
 
@@ -721,4 +722,28 @@ TLSv12::TLSv12(Core::Object* parent, Version version)
     }
 }
 
+bool TLSv12::add_client_key(const ByteBuffer& certificate_pem_buffer, const ByteBuffer& rsa_key) // FIXME: This should not be bound to RSA
+{
+    if (certificate_pem_buffer.is_empty() || rsa_key.is_empty()) {
+        return true;
+    }
+    auto decoded_certificate = decode_pem(certificate_pem_buffer.span(), 0);
+    if (decoded_certificate.is_empty()) {
+        dbg() << "Certificate not PEM";
+        return false;
+    }
+
+    auto maybe_certificate = parse_asn1(decoded_certificate);
+    if (!maybe_certificate.has_value()) {
+        dbg() << "Invalid certificate";
+        return false;
+    }
+
+    Crypto::PK::RSA rsa(rsa_key);
+    auto certificate = maybe_certificate.value();
+    certificate.private_key = rsa.private_key();
+
+    return add_client_key(certificate);
+}
+
 }

+ 9 - 0
Libraries/LibTLS/TLSv12.h

@@ -206,6 +206,7 @@ struct Certificate {
     CertificateKeyAlgorithm ec_algorithm;
     ByteBuffer exponent;
     Crypto::PK::RSAPublicKey<Crypto::UnsignedBigInteger> public_key;
+    Crypto::PK::RSAPrivateKey<Crypto::UnsignedBigInteger> private_key;
     String issuer_country;
     String issuer_state;
     String issuer_location;
@@ -318,6 +319,13 @@ public:
     bool load_certificates(const ByteBuffer& pem_buffer);
     bool load_private_key(const ByteBuffer& pem_buffer);
 
+    bool add_client_key(const ByteBuffer& certificate_pem_buffer, const ByteBuffer& key_pem_buffer);
+    bool add_client_key(Certificate certificate)
+    {
+        m_context.client_certificates.append(move(certificate));
+        return true;
+    }
+
     ByteBuffer finish_build();
 
     const StringView& alpn() const { return m_context.negotiated_alpn; }
@@ -349,6 +357,7 @@ public:
     Function<void(AlertDescription)> on_tls_error;
     Function<void()> on_tls_connected;
     Function<void()> on_tls_finished;
+    Function<void(TLSv12&)> on_tls_certificate_request;
 
 private:
     explicit TLSv12(Core::Object* parent, Version version = Version::V12);

+ 3 - 0
Libraries/LibWeb/Loader/ResourceLoader.cpp

@@ -179,6 +179,9 @@ void ResourceLoader::load(const URL& url, Function<void(const ByteBuffer&, const
             }
             success_callback(ByteBuffer::copy(payload.data(), payload.size()), response_headers);
         };
+        download->on_certificate_requested = []() -> Protocol::Download::CertificateAndKey {
+            return {};
+        };
         ++m_pending_loads;
         if (on_load_counter_change)
             on_load_counter_change();

+ 16 - 0
Services/ProtocolServer/ClientConnection.cpp

@@ -111,6 +111,11 @@ void ClientConnection::did_progress_download(Badge<Download>, Download& download
     post_message(Messages::ProtocolClient::DownloadProgress(download.id(), download.total_size(), download.downloaded_size()));
 }
 
+void ClientConnection::did_request_certificates(Badge<Download>, Download& download)
+{
+    post_message(Messages::ProtocolClient::CertificateRequested(download.id()));
+}
+
 OwnPtr<Messages::ProtocolServer::GreetResponse> ClientConnection::handle(const Messages::ProtocolServer::Greet&)
 {
     return make<Messages::ProtocolServer::GreetResponse>(client_id());
@@ -122,4 +127,15 @@ OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> ClientConnection::h
     return make<Messages::ProtocolServer::DisownSharedBufferResponse>();
 }
 
+OwnPtr<Messages::ProtocolServer::SetCertificateResponse> ClientConnection::handle(const Messages::ProtocolServer::SetCertificate& message)
+{
+    auto* download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr));
+    bool success = false;
+    if (download) {
+        download->set_certificate(message.certificate(), message.key());
+        success = true;
+    }
+    return make<Messages::ProtocolServer::SetCertificateResponse>(success);
+}
+
 }

+ 3 - 1
Services/ProtocolServer/ClientConnection.h

@@ -28,8 +28,8 @@
 
 #include <AK/HashMap.h>
 #include <LibIPC/ClientConnection.h>
-#include <ProtocolServer/ProtocolServerEndpoint.h>
 #include <ProtocolServer/Forward.h>
+#include <ProtocolServer/ProtocolServerEndpoint.h>
 
 namespace ProtocolServer {
 
@@ -46,6 +46,7 @@ public:
 
     void did_finish_download(Badge<Download>, Download&, bool success);
     void did_progress_download(Badge<Download>, Download&);
+    void did_request_certificates(Badge<Download>, Download&);
 
 private:
     virtual OwnPtr<Messages::ProtocolServer::GreetResponse> handle(const Messages::ProtocolServer::Greet&) override;
@@ -53,6 +54,7 @@ private:
     virtual OwnPtr<Messages::ProtocolServer::StartDownloadResponse> handle(const Messages::ProtocolServer::StartDownload&) override;
     virtual OwnPtr<Messages::ProtocolServer::StopDownloadResponse> handle(const Messages::ProtocolServer::StopDownload&) override;
     virtual OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> handle(const Messages::ProtocolServer::DisownSharedBuffer&) override;
+    virtual OwnPtr<Messages::ProtocolServer::SetCertificateResponse> handle(const Messages::ProtocolServer::SetCertificate&);
 
     HashMap<i32, OwnPtr<Download>> m_downloads;
     HashMap<i32, RefPtr<AK::SharedBuffer>> m_shared_buffers;

+ 10 - 1
Services/ProtocolServer/Download.cpp

@@ -25,8 +25,8 @@
  */
 
 #include <AK/Badge.h>
-#include <ProtocolServer/Download.h>
 #include <ProtocolServer/ClientConnection.h>
+#include <ProtocolServer/Download.h>
 
 namespace ProtocolServer {
 
@@ -59,6 +59,10 @@ void Download::set_response_headers(const HashMap<String, String, CaseInsensitiv
     m_response_headers = response_headers;
 }
 
+void Download::set_certificate(String, String)
+{
+}
+
 void Download::did_finish(bool success)
 {
     m_client.did_finish_download({}, *this, success);
@@ -71,4 +75,9 @@ void Download::did_progress(Optional<u32> total_size, u32 downloaded_size)
     m_client.did_progress_download({}, *this);
 }
 
+void Download::did_request_certificates()
+{
+    m_client.did_request_certificates({}, *this);
+}
+
 }

+ 2 - 0
Services/ProtocolServer/Download.h

@@ -49,6 +49,7 @@ public:
     const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers() const { return m_response_headers; }
 
     void stop();
+    virtual void set_certificate(String, String);
 
 protected:
     explicit Download(ClientConnection&);
@@ -56,6 +57,7 @@ protected:
     void did_finish(bool success);
     void did_progress(Optional<u32> total_size, u32 downloaded_size);
     void set_status_code(u32 status_code) { m_status_code = status_code; }
+    void did_request_certificates();
     void set_payload(const ByteBuffer&);
     void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&);
 

+ 2 - 0
Services/ProtocolServer/Forward.h

@@ -31,7 +31,9 @@ namespace ProtocolServer {
 class ClientConnection;
 class Download;
 class GeminiProtocol;
+class HttpDownload;
 class HttpProtocol;
+class HttpsDownload;
 class HttpsProtocol;
 class Protocol;
 

+ 8 - 0
Services/ProtocolServer/GeminiDownload.cpp

@@ -59,6 +59,14 @@ GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr<Gemini::G
     m_job->on_progress = [this](Optional<u32> total, u32 current) {
         did_progress(total, current);
     };
+    m_job->on_certificate_requested = [this](auto&) {
+        did_request_certificates();
+    };
+}
+
+void GeminiDownload::set_certificate(String certificate, String key)
+{
+    m_job->set_certificate(move(certificate), move(key));
 }
 
 GeminiDownload::~GeminiDownload()

+ 2 - 0
Services/ProtocolServer/GeminiDownload.h

@@ -41,6 +41,8 @@ public:
 private:
     explicit GeminiDownload(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>);
 
+    virtual void set_certificate(String certificate, String key) override;
+
     NonnullRefPtr<Gemini::GeminiJob> m_job;
 };
 

+ 8 - 0
Services/ProtocolServer/HttpsDownload.cpp

@@ -51,6 +51,14 @@ HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr<HTTP::Https
     m_job->on_progress = [this](Optional<u32> total, u32 current) {
         did_progress(total, current);
     };
+    m_job->on_certificate_requested = [this](auto&) {
+        did_request_certificates();
+    };
+}
+
+void HttpsDownload::set_certificate(String certificate, String key)
+{
+    m_job->set_certificate(move(certificate), move(key));
 }
 
 HttpsDownload::~HttpsDownload()

+ 2 - 0
Services/ProtocolServer/HttpsDownload.h

@@ -41,6 +41,8 @@ public:
 private:
     explicit HttpsDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>);
 
+    virtual void set_certificate(String certificate, String key) override;
+
     NonnullRefPtr<HTTP::HttpsJob> m_job;
 };
 

+ 3 - 0
Services/ProtocolServer/ProtocolClient.ipc

@@ -3,4 +3,7 @@ endpoint ProtocolClient = 13
     // Download notifications
     DownloadProgress(i32 download_id, Optional<u32> total_size, u32 downloaded_size) =|
     DownloadFinished(i32 download_id, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, IPC::Dictionary response_headers) =|
+
+    // Certificate requests
+    CertificateRequested(i32 download_id) => ()
 }

+ 1 - 0
Services/ProtocolServer/ProtocolServer.ipc

@@ -12,4 +12,5 @@ endpoint ProtocolServer = 9
     // Download API
     StartDownload(URL url, IPC::Dictionary request_headers) => (i32 download_id)
     StopDownload(i32 download_id) => (bool success)
+    SetCertificate(i32 download_id, String certificate, String key) => (bool success)
 }