Parcourir la source

RequestServer: Make pre-connection job refcounted

Fixes #22582

Previously, the job and the (cache of them) would lead to a UAF, as
after `.start()` was called on the job it'd be immediately destroyed.

Example of previous bug:

```
// Note due to the cache &jobA == &jobB
auto& jobA = Job::ensure("https://r.bing.com/");
auto& jobB = Job::ensure("https://r.bing.com/");
// Previously, the first .start() free'd the job
jobA.start();
// So the second .start() was a UAF
jobB.start();
```
MacDue il y a 1 an
Parent
commit
a1d669fe63

+ 19 - 19
Userland/Services/RequestServer/ConnectionCache.h

@@ -61,21 +61,21 @@ struct Connection {
         Function<Vector<TLS::Certificate>()> provide_client_certificates {};
 
         template<typename T>
-        static JobData create(T& job)
+        static JobData create(NonnullRefPtr<T> job)
         {
             // Clang-format _really_ messes up formatting this, so just format it manually.
             // clang-format off
             return JobData {
-                .start = [&job](auto& socket) {
-                    job.start(socket);
+                .start = [job](auto& socket) {
+                    job->start(socket);
                 },
-                .fail = [&job](auto error) {
-                    job.fail(error);
+                .fail = [job](auto error) {
+                    job->fail(error);
                 },
-                .provide_client_certificates = [&job] {
-                    if constexpr (requires { job.on_certificate_requested; }) {
-                        if (job.on_certificate_requested)
-                            return job.on_certificate_requested();
+                .provide_client_certificates = [job] {
+                    if constexpr (requires { job->on_certificate_requested; }) {
+                        if (job->on_certificate_requested)
+                            return job->on_certificate_requested();
                     } else {
                         // "use" `job`, otherwise clang gets sad.
                         (void)job;
@@ -170,7 +170,7 @@ ErrorOr<void> recreate_socket_if_needed(T& connection, URL const& url)
     return {};
 }
 
-decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto& job, Core::ProxyData proxy_data = {})
+decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto job, Core::ProxyData proxy_data = {})
 {
     using CacheEntryType = RemoveCVReference<decltype(*cache.begin()->value)>;
     auto& sockets_for_url = *cache.ensure({ url.serialized_host().release_value_but_fixme_should_propagate_errors().to_byte_string(), url.port_or_default(), proxy_data }, [] { return make<CacheEntryType>(); });
@@ -186,16 +186,16 @@ decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto& job,
         auto connection_result = proxy.tunnel<typename ConnectionType::SocketType, typename ConnectionType::StorageType>(url);
         if (connection_result.is_error()) {
             dbgln("ConnectionCache: Connection to {} failed: {}", url, connection_result.error());
-            Core::deferred_invoke([&job] {
-                job.fail(Core::NetworkJob::Error::ConnectionFailed);
+            Core::deferred_invoke([job] {
+                job->fail(Core::NetworkJob::Error::ConnectionFailed);
             });
             return ReturnType { nullptr };
         }
         auto socket_result = Core::BufferedSocket<typename ConnectionType::StorageType>::create(connection_result.release_value());
         if (socket_result.is_error()) {
             dbgln("ConnectionCache: Failed to make a buffered socket for {}: {}", url, socket_result.error());
-            Core::deferred_invoke([&job] {
-                job.fail(Core::NetworkJob::Error::ConnectionFailed);
+            Core::deferred_invoke([job] {
+                job->fail(Core::NetworkJob::Error::ConnectionFailed);
             });
             return ReturnType { nullptr };
         }
@@ -225,8 +225,8 @@ decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto& job,
         index = it.index();
     }
     if (sockets_for_url.is_empty()) {
-        Core::deferred_invoke([&job] {
-            job.fail(Core::NetworkJob::Error::ConnectionFailed);
+        Core::deferred_invoke([job] {
+            job->fail(Core::NetworkJob::Error::ConnectionFailed);
         });
         return ReturnType { nullptr };
     }
@@ -235,13 +235,13 @@ decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto& job,
     if (!connection.has_started) {
         if (auto result = recreate_socket_if_needed(connection, url); result.is_error()) {
             dbgln("ConnectionCache: request failed to start, failed to make a socket: {}", result.error());
-            Core::deferred_invoke([&job] {
-                job.fail(Core::NetworkJob::Error::ConnectionFailed);
+            Core::deferred_invoke([job] {
+                job->fail(Core::NetworkJob::Error::ConnectionFailed);
             });
             return ReturnType { nullptr };
         }
         dbgln_if(REQUESTSERVER_DEBUG, "Immediately start request for url {} in {} - {}", url, &connection, connection.socket);
-        Core::deferred_invoke([&connection, url, &job] {
+        Core::deferred_invoke([&connection, url, job] {
             connection.has_started = true;
             connection.removal_timer->stop();
             connection.timer.start();

+ 26 - 16
Userland/Services/RequestServer/ConnectionFromClient.cpp

@@ -6,6 +6,8 @@
 
 #include <AK/Badge.h>
 #include <AK/NonnullOwnPtr.h>
+#include <AK/RefCounted.h>
+#include <AK/Weakable.h>
 #include <LibCore/Proxy.h>
 #include <RequestServer/ConnectionFromClient.h>
 #include <RequestServer/Protocol.h>
@@ -105,19 +107,19 @@ Messages::RequestServer::SetCertificateResponse ConnectionFromClient::set_certif
     return success;
 }
 
-struct Job {
-    explicit Job(URL url)
-        : m_url(move(url))
-    {
-    }
-
-    static Job& ensure(URL const& url)
+class Job : public RefCounted<Job>
+    , public Weakable<Job> {
+public:
+    static NonnullRefPtr<Job> ensure(URL const& url)
     {
-        if (auto it = s_jobs.find(url); it == s_jobs.end()) {
-            auto job = make<Job>(url);
-            s_jobs.set(url, move(job));
+        RefPtr<Job> job;
+        if (auto it = s_jobs.find(url); it != s_jobs.end())
+            job = it->value.strong_ref();
+        if (job == nullptr) {
+            job = adopt_ref(*new Job(url));
+            s_jobs.set(url, job);
         }
-        return *s_jobs.find(url)->value;
+        return *job;
     }
 
     void start(Core::Socket& socket)
@@ -125,19 +127,27 @@ struct Job {
         auto is_connected = socket.is_open();
         VERIFY(is_connected);
         ConnectionCache::request_did_finish(m_url, &socket);
-        s_jobs.remove(m_url);
     }
+
     void fail(Core::NetworkJob::Error error)
     {
         dbgln("Pre-connect to {} failed: {}", m_url, Core::to_string(error));
+    }
+
+    void will_be_destroyed() const
+    {
         s_jobs.remove(m_url);
     }
-    URL m_url;
 
 private:
-    static HashMap<URL, NonnullOwnPtr<Job>> s_jobs;
+    explicit Job(URL url)
+        : m_url(move(url))
+    {
+    }
+
+    URL m_url;
+    inline static HashMap<URL, WeakPtr<Job>> s_jobs {};
 };
-HashMap<URL, NonnullOwnPtr<Job>> Job::s_jobs {};
 
 void ConnectionFromClient::ensure_connection(URL const& url, ::RequestServer::CacheLevel const& cache_level)
 {
@@ -153,7 +163,7 @@ void ConnectionFromClient::ensure_connection(URL const& url, ::RequestServer::Ca
         });
     }
 
-    auto& job = Job::ensure(url);
+    auto job = Job::ensure(url);
     dbgln("EnsureConnection: Pre-connect to {}", url);
     auto do_preconnect = [&](auto& cache) {
         auto serialized_host = url.serialized_host().release_value_but_fixme_should_propagate_errors().to_byte_string();

+ 1 - 1
Userland/Services/RequestServer/GeminiProtocol.cpp

@@ -31,7 +31,7 @@ OwnPtr<Request> GeminiProtocol::start_request(ConnectionFromClient& client, Byte
     auto protocol_request = GeminiRequest::create_with_job({}, client, *job, move(output_stream));
     protocol_request->set_request_fd(pipe_result.value().read_fd);
 
-    ConnectionCache::get_or_create_connection(ConnectionCache::g_tls_connection_cache, url, *job, proxy_data);
+    ConnectionCache::get_or_create_connection(ConnectionCache::g_tls_connection_cache, url, job, proxy_data);
 
     return protocol_request;
 }

+ 2 - 2
Userland/Services/RequestServer/HttpCommon.h

@@ -103,9 +103,9 @@ OwnPtr<Request> start_request(TBadgedProtocol&& protocol, ConnectionFromClient&
     protocol_request->set_request_fd(pipe_result.value().read_fd);
 
     if constexpr (IsSame<typename TBadgedProtocol::Type, HttpsProtocol>)
-        ConnectionCache::get_or_create_connection(ConnectionCache::g_tls_connection_cache, url, *job, proxy_data);
+        ConnectionCache::get_or_create_connection(ConnectionCache::g_tls_connection_cache, url, job, proxy_data);
     else
-        ConnectionCache::get_or_create_connection(ConnectionCache::g_tcp_connection_cache, url, *job, proxy_data);
+        ConnectionCache::get_or_create_connection(ConnectionCache::g_tcp_connection_cache, url, job, proxy_data);
 
     return protocol_request;
 }