Jelajahi Sumber

RequestServer: Make WebSocket IPC APIs asynchronous

This fixes deadlocking when interacting with WebSockets while
RequestServer is trying to stream downloaded data to WebContent.
Andreas Kling 10 bulan lalu
induk
melakukan
e205723b95

+ 31 - 15
Userland/Libraries/LibRequests/RequestClient.cpp

@@ -96,45 +96,61 @@ void RequestClient::certificate_requested(i32 request_id)
 
 RefPtr<WebSocket> RequestClient::websocket_connect(const URL::URL& url, ByteString const& origin, Vector<ByteString> const& protocols, Vector<ByteString> const& extensions, HTTP::HeaderMap const& request_headers)
 {
-    auto connection_id = IPCProxy::websocket_connect(url, origin, protocols, extensions, request_headers);
-    if (connection_id < 0)
-        return nullptr;
-    auto connection = WebSocket::create_from_id({}, *this, connection_id);
-    m_websockets.set(connection_id, connection);
+    auto websocket_id = m_next_websocket_id++;
+    IPCProxy::async_websocket_connect(websocket_id, url, origin, protocols, extensions, request_headers);
+    auto connection = WebSocket::create_from_id({}, *this, websocket_id);
+    m_websockets.set(websocket_id, connection);
     return connection;
 }
 
-void RequestClient::websocket_connected(i32 connection_id)
+void RequestClient::websocket_connected(i64 websocket_id)
 {
-    auto maybe_connection = m_websockets.get(connection_id);
+    auto maybe_connection = m_websockets.get(websocket_id);
     if (maybe_connection.has_value())
         maybe_connection.value()->did_open({});
 }
 
-void RequestClient::websocket_received(i32 connection_id, bool is_text, ByteBuffer const& data)
+void RequestClient::websocket_received(i64 websocket_id, bool is_text, ByteBuffer const& data)
 {
-    auto maybe_connection = m_websockets.get(connection_id);
+    auto maybe_connection = m_websockets.get(websocket_id);
     if (maybe_connection.has_value())
         maybe_connection.value()->did_receive({}, data, is_text);
 }
 
-void RequestClient::websocket_errored(i32 connection_id, i32 message)
+void RequestClient::websocket_errored(i64 websocket_id, i32 message)
 {
-    auto maybe_connection = m_websockets.get(connection_id);
+    auto maybe_connection = m_websockets.get(websocket_id);
     if (maybe_connection.has_value())
         maybe_connection.value()->did_error({}, message);
 }
 
-void RequestClient::websocket_closed(i32 connection_id, u16 code, ByteString const& reason, bool clean)
+void RequestClient::websocket_closed(i64 websocket_id, u16 code, ByteString const& reason, bool clean)
 {
-    auto maybe_connection = m_websockets.get(connection_id);
+    auto maybe_connection = m_websockets.get(websocket_id);
     if (maybe_connection.has_value())
         maybe_connection.value()->did_close({}, code, reason, clean);
 }
 
-void RequestClient::websocket_certificate_requested(i32 connection_id)
+void RequestClient::websocket_ready_state_changed(i64 websocket_id, u32 ready_state)
+{
+    auto maybe_connection = m_websockets.get(websocket_id);
+    if (maybe_connection.has_value()) {
+        VERIFY(ready_state <= static_cast<u32>(WebSocket::ReadyState::Closed));
+        maybe_connection.value()->set_ready_state(static_cast<WebSocket::ReadyState>(ready_state));
+    }
+}
+
+void RequestClient::websocket_subprotocol(i64 websocket_id, ByteString const& subprotocol)
+{
+    auto maybe_connection = m_websockets.get(websocket_id);
+    if (maybe_connection.has_value()) {
+        maybe_connection.value()->set_subprotocol_in_use(subprotocol);
+    }
+}
+
+void RequestClient::websocket_certificate_requested(i64 websocket_id)
 {
-    auto maybe_connection = m_websockets.get(connection_id);
+    auto maybe_connection = m_websockets.get(websocket_id);
     if (maybe_connection.has_value())
         maybe_connection.value()->did_request_certificates({});
 }

+ 10 - 6
Userland/Libraries/LibRequests/RequestClient.h

@@ -44,14 +44,18 @@ private:
     virtual void certificate_requested(i32) override;
     virtual void headers_became_available(i32, HTTP::HeaderMap const&, Optional<u32> const&) override;
 
-    virtual void websocket_connected(i32) override;
-    virtual void websocket_received(i32, bool, ByteBuffer const&) override;
-    virtual void websocket_errored(i32, i32) override;
-    virtual void websocket_closed(i32, u16, ByteString const&, bool) override;
-    virtual void websocket_certificate_requested(i32) override;
+    virtual void websocket_connected(i64 websocket_id) override;
+    virtual void websocket_received(i64 websocket_id, bool, ByteBuffer const&) override;
+    virtual void websocket_errored(i64 websocket_id, i32) override;
+    virtual void websocket_closed(i64 websocket_id, u16, ByteString const&, bool) override;
+    virtual void websocket_ready_state_changed(i64 websocket_id, u32 ready_state) override;
+    virtual void websocket_subprotocol(i64 websocket_id, ByteString const& subprotocol) override;
+    virtual void websocket_certificate_requested(i64 websocket_id) override;
 
     HashMap<i32, RefPtr<Request>> m_requests;
-    HashMap<i32, NonnullRefPtr<WebSocket>> m_websockets;
+    HashMap<i64, NonnullRefPtr<WebSocket>> m_websockets;
+
+    i64 m_next_websocket_id { 0 };
 };
 
 }

+ 17 - 7
Userland/Libraries/LibRequests/WebSocket.cpp

@@ -9,25 +9,35 @@
 
 namespace Requests {
 
-WebSocket::WebSocket(RequestClient& client, i32 connection_id)
+WebSocket::WebSocket(RequestClient& client, i64 connection_id)
     : m_client(client)
-    , m_connection_id(connection_id)
+    , m_websocket_id(connection_id)
 {
 }
 
 WebSocket::ReadyState WebSocket::ready_state()
 {
-    return static_cast<WebSocket::ReadyState>(m_client->websocket_ready_state(m_connection_id));
+    return m_ready_state;
+}
+
+void WebSocket::set_ready_state(ReadyState ready_state)
+{
+    m_ready_state = ready_state;
 }
 
 ByteString WebSocket::subprotocol_in_use()
 {
-    return m_client->websocket_subprotocol_in_use(m_connection_id);
+    return m_subprotocol;
+}
+
+void WebSocket::set_subprotocol_in_use(ByteString subprotocol)
+{
+    m_subprotocol = move(subprotocol);
 }
 
 void WebSocket::send(ByteBuffer binary_or_text_message, bool is_text)
 {
-    m_client->async_websocket_send(m_connection_id, is_text, move(binary_or_text_message));
+    m_client->async_websocket_send(m_websocket_id, is_text, move(binary_or_text_message));
 }
 
 void WebSocket::send(StringView text_message)
@@ -37,7 +47,7 @@ void WebSocket::send(StringView text_message)
 
 void WebSocket::close(u16 code, ByteString reason)
 {
-    m_client->async_websocket_close(m_connection_id, code, move(reason));
+    m_client->async_websocket_close(m_websocket_id, code, move(reason));
 }
 
 void WebSocket::did_open(Badge<RequestClient>)
@@ -68,7 +78,7 @@ void WebSocket::did_request_certificates(Badge<RequestClient>)
 {
     if (on_certificate_requested) {
         auto result = on_certificate_requested();
-        if (!m_client->websocket_set_certificate(m_connection_id, result.certificate, result.key))
+        if (!m_client->websocket_set_certificate(m_websocket_id, result.certificate, result.key))
             dbgln("WebSocket: set_certificate failed");
     }
 }

+ 9 - 5
Userland/Libraries/LibRequests/WebSocket.h

@@ -44,16 +44,18 @@ public:
         Closed = 3,
     };
 
-    static NonnullRefPtr<WebSocket> create_from_id(Badge<RequestClient>, RequestClient& client, i32 connection_id)
+    static NonnullRefPtr<WebSocket> create_from_id(Badge<RequestClient>, RequestClient& client, i64 websocket_id)
     {
-        return adopt_ref(*new WebSocket(client, connection_id));
+        return adopt_ref(*new WebSocket(client, websocket_id));
     }
 
-    int id() const { return m_connection_id; }
+    i64 id() const { return m_websocket_id; }
 
     ReadyState ready_state();
+    void set_ready_state(ReadyState);
 
     ByteString subprotocol_in_use();
+    void set_subprotocol_in_use(ByteString);
 
     void send(ByteBuffer binary_or_text_message, bool is_text);
     void send(StringView text_message);
@@ -72,9 +74,11 @@ public:
     void did_request_certificates(Badge<RequestClient>);
 
 private:
-    explicit WebSocket(RequestClient&, i32 connection_id);
+    explicit WebSocket(RequestClient&, i64 websocket_id);
     WeakPtr<RequestClient> m_client;
-    int m_connection_id { -1 };
+    ReadyState m_ready_state { ReadyState::Connecting };
+    ByteString m_subprotocol;
+    i64 m_websocket_id { -1 };
 };
 
 }

+ 23 - 10
Userland/Libraries/LibWebSocket/WebSocket.cpp

@@ -42,14 +42,14 @@ void WebSocket::start()
     m_impl->on_connected = [this] {
         if (m_state != WebSocket::InternalState::EstablishingProtocolConnection)
             return;
-        m_state = WebSocket::InternalState::SendingClientHandshake;
+        set_state(WebSocket::InternalState::SendingClientHandshake);
         send_client_handshake();
         drain_read();
     };
     m_impl->on_ready_to_read = [this] {
         drain_read();
     };
-    m_state = WebSocket::InternalState::EstablishingProtocolConnection;
+    set_state(WebSocket::InternalState::EstablishingProtocolConnection);
     m_impl->connect(m_connection);
 }
 
@@ -100,7 +100,7 @@ void WebSocket::close(u16 code, ByteString const& message)
     case InternalState::SendingClientHandshake:
     case InternalState::WaitingForServerHandshake:
         // FIXME: Fail the connection.
-        m_state = InternalState::Closing;
+        set_state(InternalState::Closing);
         break;
     case InternalState::Open: {
         auto message_bytes = message.bytes();
@@ -108,7 +108,7 @@ void WebSocket::close(u16 code, ByteString const& message)
         close_payload.overwrite(0, (u8*)&code, 2);
         close_payload.overwrite(2, message_bytes.data(), message_bytes.size());
         send_frame(WebSocket::OpCode::ConnectionClose, close_payload, true);
-        m_state = InternalState::Closing;
+        set_state(InternalState::Closing);
         break;
     }
     default:
@@ -120,7 +120,7 @@ void WebSocket::drain_read()
 {
     if (m_impl->eof()) {
         // The connection got closed by the server
-        m_state = WebSocket::InternalState::Closed;
+        set_state(WebSocket::InternalState::Closed);
         notify_close(m_last_close_code, m_last_close_message, true);
         discard_connection();
         return;
@@ -218,7 +218,7 @@ void WebSocket::send_client_handshake()
 
     builder.append("\r\n"sv);
 
-    m_state = WebSocket::InternalState::WaitingForServerHandshake;
+    set_state(WebSocket::InternalState::WaitingForServerHandshake);
     auto success = m_impl->send(builder.string_view().bytes());
     VERIFY(success);
 }
@@ -282,7 +282,7 @@ void WebSocket::read_server_handshake()
                 return;
             }
 
-            m_state = WebSocket::InternalState::Open;
+            set_state(WebSocket::InternalState::Open);
             notify_open();
             return;
         }
@@ -400,7 +400,7 @@ void WebSocket::read_frame()
     auto head_bytes = get_buffered_bytes(2);
     if (head_bytes.is_null() || head_bytes.is_empty()) {
         // The connection got closed.
-        m_state = WebSocket::InternalState::Closed;
+        set_state(WebSocket::InternalState::Closed);
         notify_close(m_last_close_code, m_last_close_message, true);
         discard_connection();
         return;
@@ -487,7 +487,7 @@ void WebSocket::read_frame()
             m_last_close_code = (((u16)(payload[0] & 0xff) << 8) | ((u16)(payload[1] & 0xff)));
             m_last_close_message = ByteString(ReadonlyBytes(payload.offset_pointer(2), payload.size() - 2));
         }
-        m_state = WebSocket::InternalState::Closing;
+        set_state(WebSocket::InternalState::Closing);
         return;
     }
     if (op_code == WebSocket::OpCode::Ping) {
@@ -608,7 +608,7 @@ void WebSocket::send_frame(WebSocket::OpCode op_code, ReadonlyBytes payload, boo
 
 void WebSocket::fatal_error(WebSocket::Error error)
 {
-    m_state = WebSocket::InternalState::Errored;
+    set_state(WebSocket::InternalState::Errored);
     notify_error(error);
     discard_connection();
 }
@@ -653,4 +653,17 @@ void WebSocket::notify_message(Message message)
     on_message(move(message));
 }
 
+void WebSocket::set_state(InternalState state)
+{
+    if (m_state == state)
+        return;
+    auto old_ready_state = ready_state();
+    m_state = state;
+    auto new_ready_state = ready_state();
+    if (old_ready_state != new_ready_state) {
+        if (on_ready_state_change)
+            on_ready_state_change(ready_state());
+    }
+}
+
 }

+ 4 - 0
Userland/Libraries/LibWebSocket/WebSocket.h

@@ -46,6 +46,8 @@ public:
     Function<void()> on_open;
     Function<void(u16 code, ByteString reason, bool was_clean)> on_close;
     Function<void(Message message)> on_message;
+    Function<void(ReadyState)> on_ready_state_change;
+    Function<void(ByteString)> on_subprotocol;
 
     enum class Error {
         CouldNotEstablishConnection,
@@ -97,6 +99,8 @@ private:
 
     InternalState m_state { InternalState::NotStarted };
 
+    void set_state(InternalState);
+
     ByteString m_subprotocol_in_use { ByteString::empty() };
 
     ByteString m_websocket_key;

+ 20 - 34
Userland/Services/RequestServer/ConnectionFromClient.cpp

@@ -386,12 +386,11 @@ void ConnectionFromClient::ensure_connection(URL::URL const& url, ::RequestServe
     dbgln("FIXME: EnsureConnection: Pre-connect to {}", url);
 }
 
-static i32 s_next_websocket_id = 1;
-Messages::RequestServer::WebsocketConnectResponse ConnectionFromClient::websocket_connect(URL::URL const& url, ByteString const& origin, Vector<ByteString> const& protocols, Vector<ByteString> const& extensions, HTTP::HeaderMap const& additional_request_headers)
+void ConnectionFromClient::websocket_connect(i64 websocket_id, URL::URL const& url, ByteString const& origin, Vector<ByteString> const& protocols, Vector<ByteString> const& extensions, HTTP::HeaderMap const& additional_request_headers)
 {
     if (!url.is_valid()) {
         dbgln("WebSocket::Connect: Invalid URL requested: '{}'", url);
-        return -1;
+        return;
     }
 
     WebSocket::ConnectionInfo connection_info(url);
@@ -400,56 +399,43 @@ Messages::RequestServer::WebsocketConnectResponse ConnectionFromClient::websocke
     connection_info.set_extensions(extensions);
     connection_info.set_headers(additional_request_headers);
 
-    auto id = ++s_next_websocket_id;
     auto connection = WebSocket::WebSocket::create(move(connection_info));
-    connection->on_open = [this, id]() {
-        async_websocket_connected(id);
+    connection->on_open = [this, websocket_id]() {
+        async_websocket_connected(websocket_id);
+    };
+    connection->on_message = [this, websocket_id](auto message) {
+        async_websocket_received(websocket_id, message.is_text(), message.data());
     };
-    connection->on_message = [this, id](auto message) {
-        async_websocket_received(id, message.is_text(), message.data());
+    connection->on_error = [this, websocket_id](auto message) {
+        async_websocket_errored(websocket_id, (i32)message);
     };
-    connection->on_error = [this, id](auto message) {
-        async_websocket_errored(id, (i32)message);
+    connection->on_close = [this, websocket_id](u16 code, ByteString reason, bool was_clean) {
+        async_websocket_closed(websocket_id, code, move(reason), was_clean);
     };
-    connection->on_close = [this, id](u16 code, ByteString reason, bool was_clean) {
-        async_websocket_closed(id, code, move(reason), was_clean);
+    connection->on_ready_state_change = [this, websocket_id](auto state) {
+        async_websocket_ready_state_changed(websocket_id, (u32)state);
     };
 
     connection->start();
-    m_websockets.set(id, move(connection));
-    return id;
-}
-
-Messages::RequestServer::WebsocketReadyStateResponse ConnectionFromClient::websocket_ready_state(i32 connection_id)
-{
-    if (auto connection = m_websockets.get(connection_id).value_or({}))
-        return (u32)connection->ready_state();
-    return (u32)WebSocket::ReadyState::Closed;
-}
-
-Messages::RequestServer::WebsocketSubprotocolInUseResponse ConnectionFromClient::websocket_subprotocol_in_use(i32 connection_id)
-{
-    if (auto connection = m_websockets.get(connection_id).value_or({}))
-        return connection->subprotocol_in_use();
-    return ByteString::empty();
+    m_websockets.set(websocket_id, move(connection));
 }
 
-void ConnectionFromClient::websocket_send(i32 connection_id, bool is_text, ByteBuffer const& data)
+void ConnectionFromClient::websocket_send(i64 websocket_id, bool is_text, ByteBuffer const& data)
 {
-    if (auto connection = m_websockets.get(connection_id).value_or({}); connection && connection->ready_state() == WebSocket::ReadyState::Open)
+    if (auto connection = m_websockets.get(websocket_id).value_or({}); connection && connection->ready_state() == WebSocket::ReadyState::Open)
         connection->send(WebSocket::Message { data, is_text });
 }
 
-void ConnectionFromClient::websocket_close(i32 connection_id, u16 code, ByteString const& reason)
+void ConnectionFromClient::websocket_close(i64 websocket_id, u16 code, ByteString const& reason)
 {
-    if (auto connection = m_websockets.get(connection_id).value_or({}); connection && connection->ready_state() == WebSocket::ReadyState::Open)
+    if (auto connection = m_websockets.get(websocket_id).value_or({}); connection && connection->ready_state() == WebSocket::ReadyState::Open)
         connection->close(code, reason);
 }
 
-Messages::RequestServer::WebsocketSetCertificateResponse ConnectionFromClient::websocket_set_certificate(i32 connection_id, ByteString const&, ByteString const&)
+Messages::RequestServer::WebsocketSetCertificateResponse ConnectionFromClient::websocket_set_certificate(i64 websocket_id, ByteString const&, ByteString const&)
 {
     auto success = false;
-    if (auto connection = m_websockets.get(connection_id).value_or({}); connection) {
+    if (auto connection = m_websockets.get(websocket_id).value_or({}); connection) {
         // NO OP here
         // connection->set_certificate(certificate, key);
         success = true;

+ 4 - 6
Userland/Services/RequestServer/ConnectionFromClient.h

@@ -42,12 +42,10 @@ private:
     virtual Messages::RequestServer::SetCertificateResponse set_certificate(i32, ByteString const&, ByteString const&) override;
     virtual void ensure_connection(URL::URL const& url, ::RequestServer::CacheLevel const& cache_level) override;
 
-    virtual Messages::RequestServer::WebsocketConnectResponse websocket_connect(URL::URL const&, ByteString const&, Vector<ByteString> const&, Vector<ByteString> const&, HTTP::HeaderMap const&) override;
-    virtual Messages::RequestServer::WebsocketReadyStateResponse websocket_ready_state(i32) override;
-    virtual Messages::RequestServer::WebsocketSubprotocolInUseResponse websocket_subprotocol_in_use(i32) override;
-    virtual void websocket_send(i32, bool, ByteBuffer const&) override;
-    virtual void websocket_close(i32, u16, ByteString const&) override;
-    virtual Messages::RequestServer::WebsocketSetCertificateResponse websocket_set_certificate(i32, ByteString const&, ByteString const&) override;
+    virtual void websocket_connect(i64 websocket_id, URL::URL const&, ByteString const&, Vector<ByteString> const&, Vector<ByteString> const&, HTTP::HeaderMap const&) override;
+    virtual void websocket_send(i64 websocket_id, bool, ByteBuffer const&) override;
+    virtual void websocket_close(i64 websocket_id, u16, ByteString const&) override;
+    virtual Messages::RequestServer::WebsocketSetCertificateResponse websocket_set_certificate(i64, ByteString const&, ByteString const&) override;
 
     HashMap<i32, RefPtr<WebSocket::WebSocket>> m_websockets;
 

+ 7 - 5
Userland/Services/RequestServer/RequestClient.ipc

@@ -9,11 +9,13 @@ endpoint RequestClient
 
     // Websocket API
     // FIXME: See if this can be merged with the regular APIs
-    websocket_connected(i32 connection_id) =|
-    websocket_received(i32 connection_id, bool is_text, ByteBuffer data) =|
-    websocket_errored(i32 connection_id, i32 message) =|
-    websocket_closed(i32 connection_id, u16 code, ByteString reason, bool clean) =|
-    websocket_certificate_requested(i32 request_id) =|
+    websocket_connected(i64 websocket_id) =|
+    websocket_received(i64 websocket_id, bool is_text, ByteBuffer data) =|
+    websocket_errored(i64 websocket_id, i32 message) =|
+    websocket_closed(i64 websocket_id, u16 code, ByteString reason, bool clean) =|
+    websocket_ready_state_changed(i64 websocket_id, u32 ready_state) =|
+    websocket_subprotocol(i64 websocket_id, ByteString subprotocol) =|
+    websocket_certificate_requested(i64 websocket_id) =|
 
     // Certificate requests
     certificate_requested(i32 request_id) =|

+ 4 - 6
Userland/Services/RequestServer/RequestServer.ipc

@@ -16,11 +16,9 @@ endpoint RequestServer
     ensure_connection(URL::URL url, ::RequestServer::CacheLevel cache_level) =|
 
     // Websocket Connection API
-    websocket_connect(URL::URL url, ByteString origin, Vector<ByteString> protocols, Vector<ByteString> extensions, HTTP::HeaderMap additional_request_headers) => (i32 connection_id)
-    websocket_ready_state(i32 connection_id) => (u32 ready_state)
-    websocket_subprotocol_in_use(i32 connection_id) => (ByteString subprotocol_in_use)
-    websocket_send(i32 connection_id, bool is_text, ByteBuffer data) =|
-    websocket_close(i32 connection_id, u16 code, ByteString reason) =|
-    websocket_set_certificate(i32 request_id, ByteString certificate, ByteString key) => (bool success)
+    websocket_connect(i64 websocket_id, URL::URL url, ByteString origin, Vector<ByteString> protocols, Vector<ByteString> extensions, HTTP::HeaderMap additional_request_headers) =|
+    websocket_send(i64 websocket_id, bool is_text, ByteBuffer data) =|
+    websocket_close(i64 websocket_id, u16 code, ByteString reason) =|
+    websocket_set_certificate(i64 request_id, ByteString certificate, ByteString key) => (bool success)
 
 }