diff --git a/Ladybird/WebSocketClientManagerLadybird.cpp b/Ladybird/WebSocketClientManagerLadybird.cpp index 2cbbd1413b0..79edc4e2cd2 100644 --- a/Ladybird/WebSocketClientManagerLadybird.cpp +++ b/Ladybird/WebSocketClientManagerLadybird.cpp @@ -19,10 +19,11 @@ NonnullRefPtr WebSocketClientManagerLadybird::cr WebSocketClientManagerLadybird::WebSocketClientManagerLadybird() = default; WebSocketClientManagerLadybird::~WebSocketClientManagerLadybird() = default; -RefPtr WebSocketClientManagerLadybird::connect(AK::URL const& url, DeprecatedString const& origin) +RefPtr WebSocketClientManagerLadybird::connect(AK::URL const& url, DeprecatedString const& origin, Vector const& protocols) { WebSocket::ConnectionInfo connection_info(url); connection_info.set_origin(origin); + connection_info.set_protocols(protocols); auto impl = adopt_ref(*new WebSocketImplQt); auto web_socket = WebSocket::WebSocket::create(move(connection_info), move(impl)); diff --git a/Ladybird/WebSocketClientManagerLadybird.h b/Ladybird/WebSocketClientManagerLadybird.h index cad676bdabf..a0831d28b0c 100644 --- a/Ladybird/WebSocketClientManagerLadybird.h +++ b/Ladybird/WebSocketClientManagerLadybird.h @@ -19,7 +19,7 @@ public: static NonnullRefPtr create(); virtual ~WebSocketClientManagerLadybird() override; - virtual RefPtr connect(AK::URL const&, DeprecatedString const& origin) override; + virtual RefPtr connect(AK::URL const&, DeprecatedString const& origin, Vector const& protocols) override; private: WebSocketClientManagerLadybird(); diff --git a/Ladybird/WebSocketLadybird.cpp b/Ladybird/WebSocketLadybird.cpp index afadf2b4636..e11930def67 100644 --- a/Ladybird/WebSocketLadybird.cpp +++ b/Ladybird/WebSocketLadybird.cpp @@ -74,6 +74,11 @@ Web::WebSockets::WebSocket::ReadyState WebSocketLadybird::ready_state() VERIFY_NOT_REACHED(); } +DeprecatedString WebSocketLadybird::subprotocol_in_use() +{ + return m_websocket->subprotocol_in_use(); +} + void WebSocketLadybird::send(ByteBuffer binary_or_text_message, bool is_text) { m_websocket->send(WebSocket::Message(binary_or_text_message, is_text)); diff --git a/Ladybird/WebSocketLadybird.h b/Ladybird/WebSocketLadybird.h index 5e7eee75f0b..170d64b3e4a 100644 --- a/Ladybird/WebSocketLadybird.h +++ b/Ladybird/WebSocketLadybird.h @@ -21,6 +21,7 @@ public: virtual ~WebSocketLadybird() override; virtual Web::WebSockets::WebSocket::ReadyState ready_state() override; + virtual DeprecatedString subprotocol_in_use() override; virtual void send(ByteBuffer binary_or_text_message, bool is_text) override; virtual void send(StringView message) override; virtual void close(u16 code, DeprecatedString reason) override; diff --git a/Userland/Libraries/LibProtocol/WebSocket.cpp b/Userland/Libraries/LibProtocol/WebSocket.cpp index fd85b64add4..7d5b3cbe659 100644 --- a/Userland/Libraries/LibProtocol/WebSocket.cpp +++ b/Userland/Libraries/LibProtocol/WebSocket.cpp @@ -20,6 +20,11 @@ WebSocket::ReadyState WebSocket::ready_state() return (WebSocket::ReadyState)m_client->ready_state({}, *this); } +DeprecatedString WebSocket::subprotocol_in_use() +{ + return m_client->subprotocol_in_use({}, *this); +} + void WebSocket::send(ByteBuffer binary_or_text_message, bool is_text) { m_client->send({}, *this, move(binary_or_text_message), is_text); diff --git a/Userland/Libraries/LibProtocol/WebSocket.h b/Userland/Libraries/LibProtocol/WebSocket.h index d323a9edcf1..eb7d859cbaa 100644 --- a/Userland/Libraries/LibProtocol/WebSocket.h +++ b/Userland/Libraries/LibProtocol/WebSocket.h @@ -53,6 +53,8 @@ public: ReadyState ready_state(); + DeprecatedString subprotocol_in_use(); + void send(ByteBuffer binary_or_text_message, bool is_text); void send(StringView text_message); void close(u16 code = 1005, DeprecatedString reason = {}); diff --git a/Userland/Libraries/LibProtocol/WebSocketClient.cpp b/Userland/Libraries/LibProtocol/WebSocketClient.cpp index b96d04951be..03e65268a39 100644 --- a/Userland/Libraries/LibProtocol/WebSocketClient.cpp +++ b/Userland/Libraries/LibProtocol/WebSocketClient.cpp @@ -34,6 +34,13 @@ u32 WebSocketClient::ready_state(Badge, WebSocket& connection) return IPCProxy::ready_state(connection.id()); } +DeprecatedString WebSocketClient::subprotocol_in_use(Badge, WebSocket& connection) +{ + if (!m_connections.contains(connection.id())) + return DeprecatedString::empty(); + return IPCProxy::subprotocol_in_use(connection.id()); +} + void WebSocketClient::send(Badge, WebSocket& connection, ByteBuffer data, bool is_text) { if (!m_connections.contains(connection.id())) diff --git a/Userland/Libraries/LibProtocol/WebSocketClient.h b/Userland/Libraries/LibProtocol/WebSocketClient.h index 60281c93c07..85d026ea6ea 100644 --- a/Userland/Libraries/LibProtocol/WebSocketClient.h +++ b/Userland/Libraries/LibProtocol/WebSocketClient.h @@ -24,6 +24,7 @@ public: RefPtr connect(const URL&, DeprecatedString const& origin = {}, Vector const& protocols = {}, Vector const& extensions = {}, HashMap const& request_headers = {}); u32 ready_state(Badge, WebSocket&); + DeprecatedString subprotocol_in_use(Badge, WebSocket&); void send(Badge, WebSocket&, ByteBuffer, bool is_text); void close(Badge, WebSocket&, u16 code, DeprecatedString reason); bool set_certificate(Badge, WebSocket&, DeprecatedString, DeprecatedString); diff --git a/Userland/Libraries/LibWeb/WebSockets/WebSocket.cpp b/Userland/Libraries/LibWeb/WebSockets/WebSocket.cpp index 84fd063c284..e0fa3f6ad9c 100644 --- a/Userland/Libraries/LibWeb/WebSockets/WebSocket.cpp +++ b/Userland/Libraries/LibWeb/WebSockets/WebSocket.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include #include #include #include @@ -45,7 +46,7 @@ WebSocketClientSocket::~WebSocketClientSocket() = default; WebSocketClientManager::WebSocketClientManager() = default; // https://websockets.spec.whatwg.org/#dom-websocket-websocket -WebIDL::ExceptionOr> WebSocket::construct_impl(JS::Realm& realm, DeprecatedString const& url) +WebIDL::ExceptionOr> WebSocket::construct_impl(JS::Realm& realm, DeprecatedString const& url, Optional>> const& protocols) { auto& window = verify_cast(realm.global_object()); AK::URL url_record(url); @@ -55,18 +56,39 @@ WebIDL::ExceptionOr> WebSocket::construct_impl(JS::R return WebIDL::SyntaxError::create(realm, "Invalid protocol"); if (!url_record.fragment().is_empty()) return WebIDL::SyntaxError::create(realm, "Presence of URL fragment is invalid"); - // 5. If `protocols` is a string, set `protocols` to a sequence consisting of just that string - // 6. If any of the values in `protocols` occur more than once or otherwise fail to match the requirements, throw SyntaxError - return MUST_OR_THROW_OOM(realm.heap().allocate(realm, window, url_record)); + Vector protocols_sequence; + if (protocols.has_value()) { + // 5. If `protocols` is a string, set `protocols` to a sequence consisting of just that string + if (protocols.value().has()) + protocols_sequence = { protocols.value().get() }; + else + protocols_sequence = protocols.value().get>(); + // 6. If any of the values in `protocols` occur more than once or otherwise fail to match the requirements, throw SyntaxError + auto sorted_protocols = protocols_sequence; + quick_sort(sorted_protocols); + for (size_t i = 0; i < sorted_protocols.size(); i++) { + // https://datatracker.ietf.org/doc/html/rfc6455 + // The elements that comprise this value MUST be non-empty strings with characters in the range U+0021 to U+007E not including + // separator characters as defined in [RFC2616] and MUST all be unique strings. + auto protocol = sorted_protocols[i]; + if (i < sorted_protocols.size() - 1 && protocol == sorted_protocols[i + 1]) + return WebIDL::SyntaxError::create(realm, "Found a duplicate protocol name in the specified list"); + for (auto character : protocol) { + if (character < '\x21' || character > '\x7E') + return WebIDL::SyntaxError::create(realm, "Found invalid character in subprotocol name"); + } + } + } + return MUST_OR_THROW_OOM(realm.heap().allocate(realm, window, url_record, protocols_sequence)); } -WebSocket::WebSocket(HTML::Window& window, AK::URL& url) +WebSocket::WebSocket(HTML::Window& window, AK::URL& url, Vector const& protocols) : EventTarget(window.realm()) , m_window(window) { // FIXME: Integrate properly with FETCH as per https://fetch.spec.whatwg.org/#websocket-opening-handshake auto origin_string = m_window->associated_document().origin().serialize(); - m_websocket = WebSocketClientManager::the().connect(url, origin_string); + m_websocket = WebSocketClientManager::the().connect(url, origin_string, protocols); m_websocket->on_open = [weak_this = make_weak_ptr()] { if (!weak_this) return; @@ -132,9 +154,7 @@ DeprecatedString WebSocket::protocol() const { if (!m_websocket) return DeprecatedString::empty(); - // https://websockets.spec.whatwg.org/#feedback-from-the-protocol - // FIXME: Change the protocol attribute's value to the subprotocol in use, if it is not the null value. - return DeprecatedString::empty(); + return m_websocket->subprotocol_in_use(); } // https://websockets.spec.whatwg.org/#dom-websocket-close diff --git a/Userland/Libraries/LibWeb/WebSockets/WebSocket.h b/Userland/Libraries/LibWeb/WebSockets/WebSocket.h index 96b6fb55ea2..40b3d185a43 100644 --- a/Userland/Libraries/LibWeb/WebSockets/WebSocket.h +++ b/Userland/Libraries/LibWeb/WebSockets/WebSocket.h @@ -37,7 +37,7 @@ public: Closed = 3, }; - static WebIDL::ExceptionOr> construct_impl(JS::Realm&, DeprecatedString const& url); + static WebIDL::ExceptionOr> construct_impl(JS::Realm&, DeprecatedString const& url, Optional>> const& protocols); virtual ~WebSocket() override; @@ -66,7 +66,7 @@ private: void on_error(); void on_close(u16 code, DeprecatedString reason, bool was_clean); - WebSocket(HTML::Window&, AK::URL&); + WebSocket(HTML::Window&, AK::URL&, Vector const& protocols); virtual JS::ThrowCompletionOr initialize(JS::Realm&) override; virtual void visit_edges(Cell::Visitor&) override; @@ -99,6 +99,7 @@ public: }; virtual Web::WebSockets::WebSocket::ReadyState ready_state() = 0; + virtual DeprecatedString subprotocol_in_use() = 0; virtual void send(ByteBuffer binary_or_text_message, bool is_text) = 0; virtual void send(StringView text_message) = 0; @@ -120,7 +121,7 @@ public: static void initialize(RefPtr); static WebSocketClientManager& the(); - virtual RefPtr connect(AK::URL const&, DeprecatedString const& origin) = 0; + virtual RefPtr connect(AK::URL const&, DeprecatedString const& origin, Vector const& protocols) = 0; protected: explicit WebSocketClientManager(); diff --git a/Userland/Libraries/LibWeb/WebSockets/WebSocket.idl b/Userland/Libraries/LibWeb/WebSockets/WebSocket.idl index 7df127d9500..c857e077636 100644 --- a/Userland/Libraries/LibWeb/WebSockets/WebSocket.idl +++ b/Userland/Libraries/LibWeb/WebSockets/WebSocket.idl @@ -5,8 +5,7 @@ [Exposed=(Window,Worker)] interface WebSocket : EventTarget { - // FIXME: A second "protocols" argument should be added once supported - constructor(USVString url); + constructor(USVString url, optional (DOMString or sequence) protocols); readonly attribute USVString url; diff --git a/Userland/Libraries/LibWebSocket/WebSocket.cpp b/Userland/Libraries/LibWebSocket/WebSocket.cpp index f3072c7aa77..35ceb54f6a2 100644 --- a/Userland/Libraries/LibWebSocket/WebSocket.cpp +++ b/Userland/Libraries/LibWebSocket/WebSocket.cpp @@ -74,6 +74,11 @@ ReadyState WebSocket::ready_state() } } +DeprecatedString WebSocket::subprotocol_in_use() +{ + return m_subprotocol_in_use; +} + void WebSocket::send(Message const& message) { // Calling send on a socket that is not opened is not allowed @@ -356,22 +361,21 @@ void WebSocket::read_server_handshake() } if (header_name.equals_ignoring_case("Sec-WebSocket-Protocol"sv)) { - // 6. |Sec-WebSocket-Protocol| should not contain an extension that doesn't appear in m_connection->protocols() - auto server_protocols = parts[1].split(','); - for (auto const& protocol : server_protocols) { - auto trimmed_protocol = protocol.trim_whitespace(); - bool found_protocol = false; - for (auto const& supported_protocol : m_connection.protocols()) { - if (trimmed_protocol.equals_ignoring_case(supported_protocol)) { - found_protocol = true; - } - } - if (!found_protocol) { - dbgln("WebSocket: Server HTTP Handshake Header |Sec-WebSocket-Protocol| contains '{}', which is not supported by the client. Failing connection.", trimmed_protocol); - fatal_error(WebSocket::Error::ConnectionUpgradeFailed); - return; + // 6. If the response includes a |Sec-WebSocket-Protocol| header field and this header field indicates the use of a subprotocol that was not present in the client's handshake (the server has indicated a subprotocol not requested by the client), the client MUST _Fail the WebSocket Connection_. + // Additionally, Section 4.2.2 says this is "Either a single value representing the subprotocol the server is ready to use or null." + auto server_protocol = parts[1].trim_whitespace(); + bool found_protocol = false; + for (auto const& supported_protocol : m_connection.protocols()) { + if (server_protocol.equals_ignoring_case(supported_protocol)) { + found_protocol = true; } } + if (!found_protocol) { + dbgln("WebSocket: Server HTTP Handshake Header |Sec-WebSocket-Protocol| contains '{}', which is not supported by the client. Failing connection.", server_protocol); + fatal_error(WebSocket::Error::ConnectionUpgradeFailed); + return; + } + m_subprotocol_in_use = server_protocol; continue; } } diff --git a/Userland/Libraries/LibWebSocket/WebSocket.h b/Userland/Libraries/LibWebSocket/WebSocket.h index a47ee24b041..7d1ad27c268 100644 --- a/Userland/Libraries/LibWebSocket/WebSocket.h +++ b/Userland/Libraries/LibWebSocket/WebSocket.h @@ -32,6 +32,8 @@ public: ReadyState ready_state(); + DeprecatedString subprotocol_in_use(); + // Call this to start the WebSocket connection. void start(); @@ -95,6 +97,8 @@ private: InternalState m_state { InternalState::NotStarted }; + DeprecatedString m_subprotocol_in_use { DeprecatedString::empty() }; + DeprecatedString m_websocket_key; bool m_has_read_server_handshake_first_line { false }; bool m_has_read_server_handshake_upgrade { false }; diff --git a/Userland/Libraries/LibWebView/WebSocketClientAdapter.cpp b/Userland/Libraries/LibWebView/WebSocketClientAdapter.cpp index f9ae63e145b..763049d9f51 100644 --- a/Userland/Libraries/LibWebView/WebSocketClientAdapter.cpp +++ b/Userland/Libraries/LibWebView/WebSocketClientAdapter.cpp @@ -87,6 +87,11 @@ Web::WebSockets::WebSocket::ReadyState WebSocketClientSocketAdapter::ready_state VERIFY_NOT_REACHED(); } +DeprecatedString WebSocketClientSocketAdapter::subprotocol_in_use() +{ + return m_websocket->subprotocol_in_use(); +} + void WebSocketClientSocketAdapter::send(ByteBuffer binary_or_text_message, bool is_text) { m_websocket->send(binary_or_text_message, is_text); @@ -115,9 +120,9 @@ WebSocketClientManagerAdapter::WebSocketClientManagerAdapter(NonnullRefPtr WebSocketClientManagerAdapter::connect(const AK::URL& url, DeprecatedString const& origin) +RefPtr WebSocketClientManagerAdapter::connect(const AK::URL& url, DeprecatedString const& origin, Vector const& protocols) { - auto underlying_websocket = m_websocket_client->connect(url, origin); + auto underlying_websocket = m_websocket_client->connect(url, origin, protocols); if (!underlying_websocket) return {}; return WebSocketClientSocketAdapter::create(underlying_websocket.release_nonnull()); diff --git a/Userland/Libraries/LibWebView/WebSocketClientAdapter.h b/Userland/Libraries/LibWebView/WebSocketClientAdapter.h index 35095331b49..2a1004d71bd 100644 --- a/Userland/Libraries/LibWebView/WebSocketClientAdapter.h +++ b/Userland/Libraries/LibWebView/WebSocketClientAdapter.h @@ -26,6 +26,7 @@ public: virtual ~WebSocketClientSocketAdapter() override; virtual Web::WebSockets::WebSocket::ReadyState ready_state() override; + virtual DeprecatedString subprotocol_in_use() override; virtual void send(ByteBuffer binary_or_text_message, bool is_text) override; virtual void send(StringView text_message) override; @@ -43,7 +44,7 @@ public: virtual ~WebSocketClientManagerAdapter() override; - virtual RefPtr connect(const AK::URL&, DeprecatedString const& origin) override; + virtual RefPtr connect(const AK::URL&, DeprecatedString const& origin, Vector const& protocols) override; private: WebSocketClientManagerAdapter(NonnullRefPtr); diff --git a/Userland/Services/WebSocket/ConnectionFromClient.cpp b/Userland/Services/WebSocket/ConnectionFromClient.cpp index 032df8f0a1b..1a3ed845419 100644 --- a/Userland/Services/WebSocket/ConnectionFromClient.cpp +++ b/Userland/Services/WebSocket/ConnectionFromClient.cpp @@ -75,6 +75,15 @@ Messages::WebSocketServer::ReadyStateResponse ConnectionFromClient::ready_state( return (u32)ReadyState::Closed; } +Messages::WebSocketServer::SubprotocolInUseResponse ConnectionFromClient::subprotocol_in_use(i32 connection_id) +{ + RefPtr connection = m_connections.get(connection_id).value_or({}); + if (connection) { + return connection->subprotocol_in_use(); + } + return DeprecatedString::empty(); +} + void ConnectionFromClient::send(i32 connection_id, bool is_text, ByteBuffer const& data) { RefPtr connection = m_connections.get(connection_id).value_or({}); diff --git a/Userland/Services/WebSocket/ConnectionFromClient.h b/Userland/Services/WebSocket/ConnectionFromClient.h index d5bab0fb9aa..aec8185b87e 100644 --- a/Userland/Services/WebSocket/ConnectionFromClient.h +++ b/Userland/Services/WebSocket/ConnectionFromClient.h @@ -28,6 +28,7 @@ private: virtual Messages::WebSocketServer::ConnectResponse connect(URL const&, DeprecatedString const&, Vector const&, Vector const&, IPC::Dictionary const&) override; virtual Messages::WebSocketServer::ReadyStateResponse ready_state(i32) override; + virtual Messages::WebSocketServer::SubprotocolInUseResponse subprotocol_in_use(i32) override; virtual void send(i32, bool, ByteBuffer const&) override; virtual void close(i32, u16, DeprecatedString const&) override; virtual Messages::WebSocketServer::SetCertificateResponse set_certificate(i32, DeprecatedString const&, DeprecatedString const&) override; diff --git a/Userland/Services/WebSocket/WebSocketServer.ipc b/Userland/Services/WebSocket/WebSocketServer.ipc index d0dfcf6b25e..5b5bc7aa8bc 100644 --- a/Userland/Services/WebSocket/WebSocketServer.ipc +++ b/Userland/Services/WebSocket/WebSocketServer.ipc @@ -5,6 +5,7 @@ endpoint WebSocketServer // Connection API connect(URL url, DeprecatedString origin, Vector protocols, Vector extensions, IPC::Dictionary additional_request_headers) => (i32 connection_id) ready_state(i32 connection_id) => (u32 ready_state) + subprotocol_in_use(i32 connection_id) => (DeprecatedString subprotocol_in_use) send(i32 connection_id, bool is_text, ByteBuffer data) =| close(i32 connection_id, u16 code, DeprecatedString reason) =| diff --git a/Userland/Utilities/headless-browser.cpp b/Userland/Utilities/headless-browser.cpp index 551e921d3d6..2e339890e3f 100644 --- a/Userland/Utilities/headless-browser.cpp +++ b/Userland/Utilities/headless-browser.cpp @@ -592,6 +592,11 @@ public: VERIFY_NOT_REACHED(); } + virtual DeprecatedString subprotocol_in_use() override + { + return m_websocket->subprotocol_in_use(); + } + virtual void send(ByteBuffer binary_or_text_message, bool is_text) override { m_websocket->send(WebSocket::Message(binary_or_text_message, is_text)); @@ -661,10 +666,11 @@ public: virtual ~HeadlessWebSocketClientManager() override { } - virtual RefPtr connect(AK::URL const& url, DeprecatedString const& origin) override + virtual RefPtr connect(AK::URL const& url, DeprecatedString const& origin, Vector const& protocols) override { WebSocket::ConnectionInfo connection_info(url); connection_info.set_origin(origin); + connection_info.set_protocols(protocols); auto connection = HeadlessWebSocket::create(WebSocket::WebSocket::create(move(connection_info))); return connection;