Browse Source

RequestServer: Make WebSocket IPC APIs asynchronous

This fixes deadlocking when interacting with WebSockets while
RequestServer is trying to stream downloaded data to WebContent.
Andreas Kling 10 tháng trước cách đây
mục cha
commit
e205723b95

+ 31 - 15
Userland/Libraries/LibRequests/RequestClient.cpp

@@ -96,45 +96,61 @@ void RequestClient::certificate_requested(i32 request_id)
 
 RefPtr<WebSocket> RequestClient::websocket_connect(const URL::URL& url, ByteString const& origin, Vector<ByteString> const& protocols, Vector<ByteString> const& extensions, HTTP::HeaderMap const& request_headers)
 {
-    auto connection_id = IPCProxy::websocket_connect(url, origin, protocols, extensions, request_headers);
-    if (connection_id < 0)
-        return nullptr;
-    auto connection = WebSocket::create_from_id({}, *this, connection_id);
-    m_websockets.set(connection_id, connection);
+    auto websocket_id = m_next_websocket_id++;
+    IPCProxy::async_websocket_connect(websocket_id, url, origin, protocols, extensions, request_headers);
+    auto connection = WebSocket::create_from_id({}, *this, websocket_id);
+    m_websockets.set(websocket_id, connection);
     return connection;
 }
 
-void RequestClient::websocket_connected(i32 connection_id)
+void RequestClient::websocket_connected(i64 websocket_id)
 {
-    auto maybe_connection = m_websockets.get(connection_id);
+    auto maybe_connection = m_websockets.get(websocket_id);
     if (maybe_connection.has_value())
         maybe_connection.value()->did_open({});
 }
 
-void RequestClient::websocket_received(i32 connection_id, bool is_text, ByteBuffer const& data)
+void RequestClient::websocket_received(i64 websocket_id, bool is_text, ByteBuffer const& data)
 {
-    auto maybe_connection = m_websockets.get(connection_id);
+    auto maybe_connection = m_websockets.get(websocket_id);
     if (maybe_connection.has_value())
         maybe_connection.value()->did_receive({}, data, is_text);
 }
 
-void RequestClient::websocket_errored(i32 connection_id, i32 message)
+void RequestClient::websocket_errored(i64 websocket_id, i32 message)
 {
-    auto maybe_connection = m_websockets.get(connection_id);
+    auto maybe_connection = m_websockets.get(websocket_id);
     if (maybe_connection.has_value())
         maybe_connection.value()->did_error({}, message);
 }
 
-void RequestClient::websocket_closed(i32 connection_id, u16 code, ByteString const& reason, bool clean)
+void RequestClient::websocket_closed(i64 websocket_id, u16 code, ByteString const& reason, bool clean)
 {
-    auto maybe_connection = m_websockets.get(connection_id);
+    auto maybe_connection = m_websockets.get(websocket_id);
     if (maybe_connection.has_value())
         maybe_connection.value()->did_close({}, code, reason, clean);
 }
 
-void RequestClient::websocket_certificate_requested(i32 connection_id)
+void RequestClient::websocket_ready_state_changed(i64 websocket_id, u32 ready_state)
+{
+    auto maybe_connection = m_websockets.get(websocket_id);
+    if (maybe_connection.has_value()) {
+        VERIFY(ready_state <= static_cast<u32>(WebSocket::ReadyState::Closed));
+        maybe_connection.value()->set_ready_state(static_cast<WebSocket::ReadyState>(ready_state));
+    }
+}
+
+void RequestClient::websocket_subprotocol(i64 websocket_id, ByteString const& subprotocol)
+{
+    auto maybe_connection = m_websockets.get(websocket_id);
+    if (maybe_connection.has_value()) {
+        maybe_connection.value()->set_subprotocol_in_use(subprotocol);
+    }
+}
+
+void RequestClient::websocket_certificate_requested(i64 websocket_id)
 {
-    auto maybe_connection = m_websockets.get(connection_id);
+    auto maybe_connection = m_websockets.get(websocket_id);
     if (maybe_connection.has_value())
         maybe_connection.value()->did_request_certificates({});
 }

+ 10 - 6
Userland/Libraries/LibRequests/RequestClient.h

@@ -44,14 +44,18 @@ private:
     virtual void certificate_requested(i32) override;
     virtual void headers_became_available(i32, HTTP::HeaderMap const&, Optional<u32> const&) override;
 
-    virtual void websocket_connected(i32) override;
-    virtual void websocket_received(i32, bool, ByteBuffer const&) override;
-    virtual void websocket_errored(i32, i32) override;
-    virtual void websocket_closed(i32, u16, ByteString const&, bool) override;
-    virtual void websocket_certificate_requested(i32) override;
+    virtual void websocket_connected(i64 websocket_id) override;
+    virtual void websocket_received(i64 websocket_id, bool, ByteBuffer const&) override;
+    virtual void websocket_errored(i64 websocket_id, i32) override;
+    virtual void websocket_closed(i64 websocket_id, u16, ByteString const&, bool) override;
+    virtual void websocket_ready_state_changed(i64 websocket_id, u32 ready_state) override;
+    virtual void websocket_subprotocol(i64 websocket_id, ByteString const& subprotocol) override;
+    virtual void websocket_certificate_requested(i64 websocket_id) override;
 
     HashMap<i32, RefPtr<Request>> m_requests;
-    HashMap<i32, NonnullRefPtr<WebSocket>> m_websockets;
+    HashMap<i64, NonnullRefPtr<WebSocket>> m_websockets;
+
+    i64 m_next_websocket_id { 0 };
 };
 
 }

+ 17 - 7
Userland/Libraries/LibRequests/WebSocket.cpp

@@ -9,25 +9,35 @@
 
 namespace Requests {
 
-WebSocket::WebSocket(RequestClient& client, i32 connection_id)
+WebSocket::WebSocket(RequestClient& client, i64 connection_id)
     : m_client(client)
-    , m_connection_id(connection_id)
+    , m_websocket_id(connection_id)
 {
 }
 
 WebSocket::ReadyState WebSocket::ready_state()
 {
-    return static_cast<WebSocket::ReadyState>(m_client->websocket_ready_state(m_connection_id));
+    return m_ready_state;
+}
+
+void WebSocket::set_ready_state(ReadyState ready_state)
+{
+    m_ready_state = ready_state;
 }
 
 ByteString WebSocket::subprotocol_in_use()
 {
-    return m_client->websocket_subprotocol_in_use(m_connection_id);
+    return m_subprotocol;
+}
+
+void WebSocket::set_subprotocol_in_use(ByteString subprotocol)
+{
+    m_subprotocol = move(subprotocol);
 }
 
 void WebSocket::send(ByteBuffer binary_or_text_message, bool is_text)
 {
-    m_client->async_websocket_send(m_connection_id, is_text, move(binary_or_text_message));
+    m_client->async_websocket_send(m_websocket_id, is_text, move(binary_or_text_message));
 }
 
 void WebSocket::send(StringView text_message)
@@ -37,7 +47,7 @@ void WebSocket::send(StringView text_message)
 
 void WebSocket::close(u16 code, ByteString reason)
 {
-    m_client->async_websocket_close(m_connection_id, code, move(reason));
+    m_client->async_websocket_close(m_websocket_id, code, move(reason));
 }
 
 void WebSocket::did_open(Badge<RequestClient>)
@@ -68,7 +78,7 @@ void WebSocket::did_request_certificates(Badge<RequestClient>)
 {
     if (on_certificate_requested) {
         auto result = on_certificate_requested();
-        if (!m_client->websocket_set_certificate(m_connection_id, result.certificate, result.key))
+        if (!m_client->websocket_set_certificate(m_websocket_id, result.certificate, result.key))
             dbgln("WebSocket: set_certificate failed");
     }
 }

+ 9 - 5
Userland/Libraries/LibRequests/WebSocket.h

@@ -44,16 +44,18 @@ public:
         Closed = 3,
     };
 
-    static NonnullRefPtr<WebSocket> create_from_id(Badge<RequestClient>, RequestClient& client, i32 connection_id)
+    static NonnullRefPtr<WebSocket> create_from_id(Badge<RequestClient>, RequestClient& client, i64 websocket_id)
     {
-        return adopt_ref(*new WebSocket(client, connection_id));
+        return adopt_ref(*new WebSocket(client, websocket_id));
     }
 
-    int id() const { return m_connection_id; }
+    i64 id() const { return m_websocket_id; }
 
     ReadyState ready_state();
+    void set_ready_state(ReadyState);
 
     ByteString subprotocol_in_use();
+    void set_subprotocol_in_use(ByteString);
 
     void send(ByteBuffer binary_or_text_message, bool is_text);
     void send(StringView text_message);
@@ -72,9 +74,11 @@ public:
     void did_request_certificates(Badge<RequestClient>);
 
 private:
-    explicit WebSocket(RequestClient&, i32 connection_id);
+    explicit WebSocket(RequestClient&, i64 websocket_id);
     WeakPtr<RequestClient> m_client;
-    int m_connection_id { -1 };
+    ReadyState m_ready_state { ReadyState::Connecting };
+    ByteString m_subprotocol;
+    i64 m_websocket_id { -1 };
 };
 
 }

+ 23 - 10
Userland/Libraries/LibWebSocket/WebSocket.cpp

@@ -42,14 +42,14 @@ void WebSocket::start()
     m_impl->on_connected = [this] {
         if (m_state != WebSocket::InternalState::EstablishingProtocolConnection)
             return;
-        m_state = WebSocket::InternalState::SendingClientHandshake;
+        set_state(WebSocket::InternalState::SendingClientHandshake);
         send_client_handshake();
         drain_read();
     };
     m_impl->on_ready_to_read = [this] {
         drain_read();
     };
-    m_state = WebSocket::InternalState::EstablishingProtocolConnection;
+    set_state(WebSocket::InternalState::EstablishingProtocolConnection);
     m_impl->connect(m_connection);
 }
 
@@ -100,7 +100,7 @@ void WebSocket::close(u16 code, ByteString const& message)
     case InternalState::SendingClientHandshake:
     case InternalState::WaitingForServerHandshake:
         // FIXME: Fail the connection.
-        m_state = InternalState::Closing;
+        set_state(InternalState::Closing);
         break;
     case InternalState::Open: {
         auto message_bytes = message.bytes();
@@ -108,7 +108,7 @@ void WebSocket::close(u16 code, ByteString const& message)
         close_payload.overwrite(0, (u8*)&code, 2);
         close_payload.overwrite(2, message_bytes.data(), message_bytes.size());
         send_frame(WebSocket::OpCode::ConnectionClose, close_payload, true);
-        m_state = InternalState::Closing;
+        set_state(InternalState::Closing);
         break;
     }
     default:
@@ -120,7 +120,7 @@ void WebSocket::drain_read()
 {
     if (m_impl->eof()) {
         // The connection got closed by the server
-        m_state = WebSocket::InternalState::Closed;
+        set_state(WebSocket::InternalState::Closed);
         notify_close(m_last_close_code, m_last_close_message, true);
         discard_connection();
         return;
@@ -218,7 +218,7 @@ void WebSocket::send_client_handshake()
 
     builder.append("\r\n"sv);
 
-    m_state = WebSocket::InternalState::WaitingForServerHandshake;
+    set_state(WebSocket::InternalState::WaitingForServerHandshake);
     auto success = m_impl->send(builder.string_view().bytes());
     VERIFY(success);
 }
@@ -282,7 +282,7 @@ void WebSocket::read_server_handshake()
                 return;
             }
 
-            m_state = WebSocket::InternalState::Open;
+            set_state(WebSocket::InternalState::Open);
             notify_open();
             return;
         }
@@ -400,7 +400,7 @@ void WebSocket::read_frame()
     auto head_bytes = get_buffered_bytes(2);
     if (head_bytes.is_null() || head_bytes.is_empty()) {
         // The connection got closed.
-        m_state = WebSocket::InternalState::Closed;
+        set_state(WebSocket::InternalState::Closed);
         notify_close(m_last_close_code, m_last_close_message, true);
         discard_connection();
         return;
@@ -487,7 +487,7 @@ void WebSocket::read_frame()
             m_last_close_code = (((u16)(payload[0] & 0xff) << 8) | ((u16)(payload[1] & 0xff)));
             m_last_close_message = ByteString(ReadonlyBytes(payload.offset_pointer(2), payload.size() - 2));
         }
-        m_state = WebSocket::InternalState::Closing;
+        set_state(WebSocket::InternalState::Closing);
         return;
     }
     if (op_code == WebSocket::OpCode::Ping) {
@@ -608,7 +608,7 @@ void WebSocket::send_frame(WebSocket::OpCode op_code, ReadonlyBytes payload, boo
 
 void WebSocket::fatal_error(WebSocket::Error error)
 {
-    m_state = WebSocket::InternalState::Errored;
+    set_state(WebSocket::InternalState::Errored);
     notify_error(error);
     discard_connection();
 }
@@ -653,4 +653,17 @@ void WebSocket::notify_message(Message message)
     on_message(move(message));
 }
 
+void WebSocket::set_state(InternalState state)
+{
+    if (m_state == state)
+        return;
+    auto old_ready_state = ready_state();
+    m_state = state;
+    auto new_ready_state = ready_state();
+    if (old_ready_state != new_ready_state) {
+        if (on_ready_state_change)
+            on_ready_state_change(ready_state());
+    }
+}
+
 }

+ 4 - 0
Userland/Libraries/LibWebSocket/WebSocket.h

@@ -46,6 +46,8 @@ public:
     Function<void()> on_open;
     Function<void(u16 code, ByteString reason, bool was_clean)> on_close;
     Function<void(Message message)> on_message;
+    Function<void(ReadyState)> on_ready_state_change;
+    Function<void(ByteString)> on_subprotocol;
 
     enum class Error {
         CouldNotEstablishConnection,
@@ -97,6 +99,8 @@ private:
 
     InternalState m_state { InternalState::NotStarted };
 
+    void set_state(InternalState);
+
     ByteString m_subprotocol_in_use { ByteString::empty() };
 
     ByteString m_websocket_key;

+ 20 - 34
Userland/Services/RequestServer/ConnectionFromClient.cpp

@@ -386,12 +386,11 @@ void ConnectionFromClient::ensure_connection(URL::URL const& url, ::RequestServe
     dbgln("FIXME: EnsureConnection: Pre-connect to {}", url);
 }
 
-static i32 s_next_websocket_id = 1;
-Messages::RequestServer::WebsocketConnectResponse ConnectionFromClient::websocket_connect(URL::URL const& url, ByteString const& origin, Vector<ByteString> const& protocols, Vector<ByteString> const& extensions, HTTP::HeaderMap const& additional_request_headers)
+void ConnectionFromClient::websocket_connect(i64 websocket_id, URL::URL const& url, ByteString const& origin, Vector<ByteString> const& protocols, Vector<ByteString> const& extensions, HTTP::HeaderMap const& additional_request_headers)
 {
     if (!url.is_valid()) {
         dbgln("WebSocket::Connect: Invalid URL requested: '{}'", url);
-        return -1;
+        return;
     }
 
     WebSocket::ConnectionInfo connection_info(url);
@@ -400,56 +399,43 @@ Messages::RequestServer::WebsocketConnectResponse ConnectionFromClient::websocke
     connection_info.set_extensions(extensions);
     connection_info.set_headers(additional_request_headers);
 
-    auto id = ++s_next_websocket_id;
     auto connection = WebSocket::WebSocket::create(move(connection_info));
-    connection->on_open = [this, id]() {
-        async_websocket_connected(id);
+    connection->on_open = [this, websocket_id]() {
+        async_websocket_connected(websocket_id);
+    };
+    connection->on_message = [this, websocket_id](auto message) {
+        async_websocket_received(websocket_id, message.is_text(), message.data());
     };
-    connection->on_message = [this, id](auto message) {
-        async_websocket_received(id, message.is_text(), message.data());
+    connection->on_error = [this, websocket_id](auto message) {
+        async_websocket_errored(websocket_id, (i32)message);
     };
-    connection->on_error = [this, id](auto message) {
-        async_websocket_errored(id, (i32)message);
+    connection->on_close = [this, websocket_id](u16 code, ByteString reason, bool was_clean) {
+        async_websocket_closed(websocket_id, code, move(reason), was_clean);
     };
-    connection->on_close = [this, id](u16 code, ByteString reason, bool was_clean) {
-        async_websocket_closed(id, code, move(reason), was_clean);
+    connection->on_ready_state_change = [this, websocket_id](auto state) {
+        async_websocket_ready_state_changed(websocket_id, (u32)state);
     };
 
     connection->start();
-    m_websockets.set(id, move(connection));
-    return id;
-}
-
-Messages::RequestServer::WebsocketReadyStateResponse ConnectionFromClient::websocket_ready_state(i32 connection_id)
-{
-    if (auto connection = m_websockets.get(connection_id).value_or({}))
-        return (u32)connection->ready_state();
-    return (u32)WebSocket::ReadyState::Closed;
-}
-
-Messages::RequestServer::WebsocketSubprotocolInUseResponse ConnectionFromClient::websocket_subprotocol_in_use(i32 connection_id)
-{
-    if (auto connection = m_websockets.get(connection_id).value_or({}))
-        return connection->subprotocol_in_use();
-    return ByteString::empty();
+    m_websockets.set(websocket_id, move(connection));
 }
 
-void ConnectionFromClient::websocket_send(i32 connection_id, bool is_text, ByteBuffer const& data)
+void ConnectionFromClient::websocket_send(i64 websocket_id, bool is_text, ByteBuffer const& data)
 {
-    if (auto connection = m_websockets.get(connection_id).value_or({}); connection && connection->ready_state() == WebSocket::ReadyState::Open)
+    if (auto connection = m_websockets.get(websocket_id).value_or({}); connection && connection->ready_state() == WebSocket::ReadyState::Open)
         connection->send(WebSocket::Message { data, is_text });
 }
 
-void ConnectionFromClient::websocket_close(i32 connection_id, u16 code, ByteString const& reason)
+void ConnectionFromClient::websocket_close(i64 websocket_id, u16 code, ByteString const& reason)
 {
-    if (auto connection = m_websockets.get(connection_id).value_or({}); connection && connection->ready_state() == WebSocket::ReadyState::Open)
+    if (auto connection = m_websockets.get(websocket_id).value_or({}); connection && connection->ready_state() == WebSocket::ReadyState::Open)
         connection->close(code, reason);
 }
 
-Messages::RequestServer::WebsocketSetCertificateResponse ConnectionFromClient::websocket_set_certificate(i32 connection_id, ByteString const&, ByteString const&)
+Messages::RequestServer::WebsocketSetCertificateResponse ConnectionFromClient::websocket_set_certificate(i64 websocket_id, ByteString const&, ByteString const&)
 {
     auto success = false;
-    if (auto connection = m_websockets.get(connection_id).value_or({}); connection) {
+    if (auto connection = m_websockets.get(websocket_id).value_or({}); connection) {
         // NO OP here
         // connection->set_certificate(certificate, key);
         success = true;

+ 4 - 6
Userland/Services/RequestServer/ConnectionFromClient.h

@@ -42,12 +42,10 @@ private:
     virtual Messages::RequestServer::SetCertificateResponse set_certificate(i32, ByteString const&, ByteString const&) override;
     virtual void ensure_connection(URL::URL const& url, ::RequestServer::CacheLevel const& cache_level) override;
 
-    virtual Messages::RequestServer::WebsocketConnectResponse websocket_connect(URL::URL const&, ByteString const&, Vector<ByteString> const&, Vector<ByteString> const&, HTTP::HeaderMap const&) override;
-    virtual Messages::RequestServer::WebsocketReadyStateResponse websocket_ready_state(i32) override;
-    virtual Messages::RequestServer::WebsocketSubprotocolInUseResponse websocket_subprotocol_in_use(i32) override;
-    virtual void websocket_send(i32, bool, ByteBuffer const&) override;
-    virtual void websocket_close(i32, u16, ByteString const&) override;
-    virtual Messages::RequestServer::WebsocketSetCertificateResponse websocket_set_certificate(i32, ByteString const&, ByteString const&) override;
+    virtual void websocket_connect(i64 websocket_id, URL::URL const&, ByteString const&, Vector<ByteString> const&, Vector<ByteString> const&, HTTP::HeaderMap const&) override;
+    virtual void websocket_send(i64 websocket_id, bool, ByteBuffer const&) override;
+    virtual void websocket_close(i64 websocket_id, u16, ByteString const&) override;
+    virtual Messages::RequestServer::WebsocketSetCertificateResponse websocket_set_certificate(i64, ByteString const&, ByteString const&) override;
 
     HashMap<i32, RefPtr<WebSocket::WebSocket>> m_websockets;
 

+ 7 - 5
Userland/Services/RequestServer/RequestClient.ipc

@@ -9,11 +9,13 @@ endpoint RequestClient
 
     // Websocket API
     // FIXME: See if this can be merged with the regular APIs
-    websocket_connected(i32 connection_id) =|
-    websocket_received(i32 connection_id, bool is_text, ByteBuffer data) =|
-    websocket_errored(i32 connection_id, i32 message) =|
-    websocket_closed(i32 connection_id, u16 code, ByteString reason, bool clean) =|
-    websocket_certificate_requested(i32 request_id) =|
+    websocket_connected(i64 websocket_id) =|
+    websocket_received(i64 websocket_id, bool is_text, ByteBuffer data) =|
+    websocket_errored(i64 websocket_id, i32 message) =|
+    websocket_closed(i64 websocket_id, u16 code, ByteString reason, bool clean) =|
+    websocket_ready_state_changed(i64 websocket_id, u32 ready_state) =|
+    websocket_subprotocol(i64 websocket_id, ByteString subprotocol) =|
+    websocket_certificate_requested(i64 websocket_id) =|
 
     // Certificate requests
     certificate_requested(i32 request_id) =|

+ 4 - 6
Userland/Services/RequestServer/RequestServer.ipc

@@ -16,11 +16,9 @@ endpoint RequestServer
     ensure_connection(URL::URL url, ::RequestServer::CacheLevel cache_level) =|
 
     // Websocket Connection API
-    websocket_connect(URL::URL url, ByteString origin, Vector<ByteString> protocols, Vector<ByteString> extensions, HTTP::HeaderMap additional_request_headers) => (i32 connection_id)
-    websocket_ready_state(i32 connection_id) => (u32 ready_state)
-    websocket_subprotocol_in_use(i32 connection_id) => (ByteString subprotocol_in_use)
-    websocket_send(i32 connection_id, bool is_text, ByteBuffer data) =|
-    websocket_close(i32 connection_id, u16 code, ByteString reason) =|
-    websocket_set_certificate(i32 request_id, ByteString certificate, ByteString key) => (bool success)
+    websocket_connect(i64 websocket_id, URL::URL url, ByteString origin, Vector<ByteString> protocols, Vector<ByteString> extensions, HTTP::HeaderMap additional_request_headers) =|
+    websocket_send(i64 websocket_id, bool is_text, ByteBuffer data) =|
+    websocket_close(i64 websocket_id, u16 code, ByteString reason) =|
+    websocket_set_certificate(i64 request_id, ByteString certificate, ByteString key) => (bool success)
 
 }