Pārlūkot izejas kodu

LibHTTP: Respect the 'Connection: close' header on keep-alive jobs

If the server responds with this header, we _must_ close the connection,
as the server is allowed to ignore the socket and not respond to
anything past that response.
Fixes some RequestServer spins.
Ali Mohammad Pur 3 gadi atpakaļ
vecāks
revīzija
b0a9c5673e

+ 3 - 3
Userland/Libraries/LibCore/NetworkJob.cpp

@@ -23,7 +23,7 @@ void NetworkJob::start(NonnullRefPtr<Core::Socket>)
 {
 }
 
-void NetworkJob::shutdown()
+void NetworkJob::shutdown(ShutdownMode)
 {
 }
 
@@ -40,7 +40,7 @@ void NetworkJob::did_finish(NonnullRefPtr<NetworkResponse>&& response)
     dbgln_if(CNETWORKJOB_DEBUG, "{} job did_finish", *this);
     VERIFY(on_finish);
     on_finish(true);
-    shutdown();
+    shutdown(ShutdownMode::DetachFromSocket);
 }
 
 void NetworkJob::did_fail(Error error)
@@ -56,7 +56,7 @@ void NetworkJob::did_fail(Error error)
     dbgln_if(CNETWORKJOB_DEBUG, "{}{{{:p}}} job did_fail! error: {} ({})", class_name(), this, (unsigned)error, to_string(error));
     VERIFY(on_finish);
     on_finish(false);
-    shutdown();
+    shutdown(ShutdownMode::DetachFromSocket);
 }
 
 void NetworkJob::did_progress(Optional<u32> total_size, u32 downloaded)

+ 6 - 2
Userland/Libraries/LibCore/NetworkJob.h

@@ -35,12 +35,16 @@ public:
     NetworkResponse* response() { return m_response.ptr(); }
     const NetworkResponse* response() const { return m_response.ptr(); }
 
+    enum class ShutdownMode {
+        DetachFromSocket,
+        CloseSocket,
+    };
     virtual void start(NonnullRefPtr<Core::Socket>) = 0;
-    virtual void shutdown() = 0;
+    virtual void shutdown(ShutdownMode) = 0;
 
     void cancel()
     {
-        shutdown();
+        shutdown(ShutdownMode::DetachFromSocket);
         m_error = Error::Cancelled;
     }
 

+ 10 - 0
Userland/Libraries/LibCore/Socket.cpp

@@ -208,6 +208,16 @@ void Socket::did_update_fd(int fd)
     }
 }
 
+bool Socket::close()
+{
+    m_connected = false;
+    if (m_notifier)
+        m_notifier->close();
+    if (m_read_notifier)
+        m_read_notifier->close();
+    return IODevice::close();
+}
+
 void Socket::ensure_read_notifier()
 {
     VERIFY(m_connected);

+ 2 - 0
Userland/Libraries/LibCore/Socket.h

@@ -42,6 +42,8 @@ public:
     SocketAddress destination_address() const { return m_destination_address; }
     int destination_port() const { return m_destination_port; }
 
+    virtual bool close() override;
+
     Function<void()> on_connected;
     Function<void()> on_error;
     Function<void()> on_ready_to_read;

+ 8 - 4
Userland/Libraries/LibGemini/GeminiJob.cpp

@@ -58,13 +58,17 @@ void GeminiJob::start(NonnullRefPtr<Core::Socket> socket)
     }
 }
 
-void GeminiJob::shutdown()
+void GeminiJob::shutdown(ShutdownMode mode)
 {
     if (!m_socket)
         return;
-    m_socket->on_tls_ready_to_read = nullptr;
-    m_socket->on_tls_connected = nullptr;
-    m_socket = nullptr;
+    if (mode == ShutdownMode::CloseSocket) {
+        m_socket->close();
+    } else {
+        m_socket->on_tls_ready_to_read = nullptr;
+        m_socket->on_tls_connected = nullptr;
+        m_socket = nullptr;
+    }
 }
 
 void GeminiJob::read_while_data_available(Function<IterationDecision()> read)

+ 1 - 1
Userland/Libraries/LibGemini/GeminiJob.h

@@ -28,7 +28,7 @@ public:
     }
 
     virtual void start(NonnullRefPtr<Core::Socket>) override;
-    virtual void shutdown() override;
+    virtual void shutdown(ShutdownMode) override;
     void set_certificate(String certificate, String key);
 
     Core::Socket const* socket() const { return m_socket; }

+ 1 - 1
Userland/Libraries/LibGemini/Job.h

@@ -20,7 +20,7 @@ public:
     virtual ~Job() override;
 
     virtual void start(NonnullRefPtr<Core::Socket>) override = 0;
-    virtual void shutdown() override = 0;
+    virtual void shutdown(ShutdownMode) override = 0;
 
     GeminiResponse* response() { return static_cast<GeminiResponse*>(Core::NetworkJob::response()); }
     const GeminiResponse* response() const { return static_cast<const GeminiResponse*>(Core::NetworkJob::response()); }

+ 8 - 4
Userland/Libraries/LibHTTP/HttpJob.cpp

@@ -43,13 +43,17 @@ void HttpJob::start(NonnullRefPtr<Core::Socket> socket)
     };
 }
 
-void HttpJob::shutdown()
+void HttpJob::shutdown(ShutdownMode mode)
 {
     if (!m_socket)
         return;
-    m_socket->on_ready_to_read = nullptr;
-    m_socket->on_connected = nullptr;
-    m_socket = nullptr;
+    if (mode == ShutdownMode::CloseSocket) {
+        m_socket->close();
+    } else {
+        m_socket->on_ready_to_read = nullptr;
+        m_socket->on_connected = nullptr;
+        m_socket = nullptr;
+    }
 }
 
 void HttpJob::register_on_ready_to_read(Function<void()> callback)

+ 1 - 1
Userland/Libraries/LibHTTP/HttpJob.h

@@ -28,7 +28,7 @@ public:
     }
 
     virtual void start(NonnullRefPtr<Core::Socket>) override;
-    virtual void shutdown() override;
+    virtual void shutdown(ShutdownMode) override;
 
     Core::Socket const* socket() const { return m_socket; }
     URL url() const { return m_request.url(); }

+ 9 - 5
Userland/Libraries/LibHTTP/HttpsJob.cpp

@@ -62,14 +62,18 @@ void HttpsJob::start(NonnullRefPtr<Core::Socket> socket)
     }
 }
 
-void HttpsJob::shutdown()
+void HttpsJob::shutdown(ShutdownMode mode)
 {
     if (!m_socket)
         return;
-    m_socket->on_tls_ready_to_read = nullptr;
-    m_socket->on_tls_connected = nullptr;
-    m_socket->set_on_tls_ready_to_write(nullptr);
-    m_socket = nullptr;
+    if (mode == ShutdownMode::CloseSocket) {
+        m_socket->close();
+    } else {
+        m_socket->on_tls_ready_to_read = nullptr;
+        m_socket->on_tls_connected = nullptr;
+        m_socket->set_on_tls_ready_to_write(nullptr);
+        m_socket = nullptr;
+    }
 }
 
 void HttpsJob::set_certificate(String certificate, String private_key)

+ 1 - 1
Userland/Libraries/LibHTTP/HttpsJob.h

@@ -29,7 +29,7 @@ public:
     }
 
     virtual void start(NonnullRefPtr<Core::Socket>) override;
-    virtual void shutdown() override;
+    virtual void shutdown(ShutdownMode) override;
     void set_certificate(String certificate, String key);
 
     Core::Socket const* socket() const { return m_socket; }

+ 4 - 0
Userland/Libraries/LibHTTP/Job.cpp

@@ -412,6 +412,10 @@ void Job::finish_up()
     m_has_scheduled_finish = true;
     auto response = HttpResponse::create(m_code, move(m_headers));
     deferred_invoke([this, response = move(response)] {
+        // If the server responded with "Connection: close", close the connection
+        // as the server may or may not want to close the socket.
+        if (auto result = response->headers().get("Connection"sv); result.has_value() && result.value().equals_ignoring_case("close"sv))
+            shutdown(ShutdownMode::CloseSocket);
         did_finish(response);
     });
 }

+ 1 - 1
Userland/Libraries/LibHTTP/Job.h

@@ -22,7 +22,7 @@ public:
     virtual ~Job() override;
 
     virtual void start(NonnullRefPtr<Core::Socket>) override = 0;
-    virtual void shutdown() override = 0;
+    virtual void shutdown(ShutdownMode) override = 0;
 
     HttpResponse* response() { return static_cast<HttpResponse*>(Core::NetworkJob::response()); }
     const HttpResponse* response() const { return static_cast<const HttpResponse*>(Core::NetworkJob::response()); }