ソースを参照

LibWebSocket: Switch to using Core::Stream

As LibTLS now supports the Core::Stream APIs, we can get rid of the
split paths for TCP/TLS and significantly simplify the code as well.
Provided to you free of charge by the Core::Stream-ification team :^)
Ali Mohammad Pur 3 年 前
コミット
3f614a8fca

+ 1 - 3
Userland/Libraries/LibWebSocket/CMakeLists.txt

@@ -1,8 +1,6 @@
 set(SOURCES
     ConnectionInfo.cpp
-    Impl/AbstractWebSocketImpl.cpp
-    Impl/TCPWebSocketConnectionImpl.cpp
-    Impl/TLSv12WebSocketConnectionImpl.cpp
+    Impl/WebSocketImpl.cpp
     WebSocket.cpp
 )
 

+ 0 - 20
Userland/Libraries/LibWebSocket/Impl/AbstractWebSocketImpl.cpp

@@ -1,20 +0,0 @@
-/*
- * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com>
- *
- * SPDX-License-Identifier: BSD-2-Clause
- */
-
-#include <LibWebSocket/Impl/AbstractWebSocketImpl.h>
-
-namespace WebSocket {
-
-AbstractWebSocketImpl::AbstractWebSocketImpl(Core::Object* parent)
-    : Object(parent)
-{
-}
-
-AbstractWebSocketImpl::~AbstractWebSocketImpl()
-{
-}
-
-}

+ 0 - 43
Userland/Libraries/LibWebSocket/Impl/AbstractWebSocketImpl.h

@@ -1,43 +0,0 @@
-/*
- * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com>
- *
- * SPDX-License-Identifier: BSD-2-Clause
- */
-
-#pragma once
-
-#include <AK/ByteBuffer.h>
-#include <AK/Span.h>
-#include <AK/String.h>
-#include <LibCore/Object.h>
-#include <LibWebSocket/ConnectionInfo.h>
-
-namespace WebSocket {
-
-class AbstractWebSocketImpl : public Core::Object {
-    C_OBJECT_ABSTRACT(AbstractWebSocketImpl);
-
-public:
-    virtual ~AbstractWebSocketImpl() override;
-    explicit AbstractWebSocketImpl(Core::Object* parent = nullptr);
-
-    virtual void connect(ConnectionInfo const&) = 0;
-
-    virtual bool can_read_line() = 0;
-    virtual String read_line(size_t size) = 0;
-
-    virtual bool can_read() = 0;
-    virtual ByteBuffer read(int max_size) = 0;
-
-    virtual bool send(ReadonlyBytes) = 0;
-
-    virtual bool eof() = 0;
-
-    virtual void discard_connection() = 0;
-
-    Function<void()> on_connected;
-    Function<void()> on_connection_error;
-    Function<void()> on_ready_to_read;
-};
-
-}

+ 0 - 84
Userland/Libraries/LibWebSocket/Impl/TCPWebSocketConnectionImpl.cpp

@@ -1,84 +0,0 @@
-/*
- * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com>
- *
- * SPDX-License-Identifier: BSD-2-Clause
- */
-
-#include <LibWebSocket/Impl/TCPWebSocketConnectionImpl.h>
-
-namespace WebSocket {
-
-TCPWebSocketConnectionImpl::TCPWebSocketConnectionImpl(Core::Object* parent)
-    : AbstractWebSocketImpl(parent)
-{
-}
-
-TCPWebSocketConnectionImpl::~TCPWebSocketConnectionImpl()
-{
-    discard_connection();
-}
-
-void TCPWebSocketConnectionImpl::connect(ConnectionInfo const& connection)
-{
-    VERIFY(!m_socket);
-    VERIFY(on_connected);
-    VERIFY(on_connection_error);
-    VERIFY(on_ready_to_read);
-    m_socket = Core::TCPSocket::construct(this);
-
-    m_notifier = Core::Notifier::construct(m_socket->fd(), Core::Notifier::Read);
-    m_notifier->on_ready_to_read = [this] {
-        on_ready_to_read();
-    };
-
-    m_socket->on_connected = [this] {
-        on_connected();
-    };
-    bool success = m_socket->connect(connection.url().host(), connection.url().port_or_default());
-    if (!success) {
-        deferred_invoke([this] {
-            on_connection_error();
-        });
-    }
-}
-
-bool TCPWebSocketConnectionImpl::send(ReadonlyBytes data)
-{
-    return m_socket->write(data);
-}
-
-bool TCPWebSocketConnectionImpl::can_read_line()
-{
-    return m_socket->can_read_line();
-}
-
-String TCPWebSocketConnectionImpl::read_line(size_t size)
-{
-    return m_socket->read_line(size);
-}
-
-bool TCPWebSocketConnectionImpl::can_read()
-{
-    return m_socket->can_read();
-}
-
-ByteBuffer TCPWebSocketConnectionImpl::read(int max_size)
-{
-    return m_socket->read(max_size);
-}
-
-bool TCPWebSocketConnectionImpl::eof()
-{
-    return m_socket->eof();
-}
-
-void TCPWebSocketConnectionImpl::discard_connection()
-{
-    if (!m_socket)
-        return;
-    m_socket->on_ready_to_read = nullptr;
-    remove_child(*m_socket);
-    m_socket = nullptr;
-}
-
-}

+ 0 - 47
Userland/Libraries/LibWebSocket/Impl/TCPWebSocketConnectionImpl.h

@@ -1,47 +0,0 @@
-/*
- * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com>
- *
- * SPDX-License-Identifier: BSD-2-Clause
- */
-
-#pragma once
-
-#include <AK/ByteBuffer.h>
-#include <AK/Span.h>
-#include <AK/String.h>
-#include <LibCore/Notifier.h>
-#include <LibCore/Object.h>
-#include <LibCore/TCPSocket.h>
-#include <LibWebSocket/ConnectionInfo.h>
-#include <LibWebSocket/Impl/AbstractWebSocketImpl.h>
-
-namespace WebSocket {
-
-class TCPWebSocketConnectionImpl final : public AbstractWebSocketImpl {
-    C_OBJECT(TCPWebSocketConnectionImpl);
-
-public:
-    virtual ~TCPWebSocketConnectionImpl() override;
-
-    virtual void connect(ConnectionInfo const& connection) override;
-
-    virtual bool can_read_line() override;
-    virtual String read_line(size_t size) override;
-
-    virtual bool can_read() override;
-    virtual ByteBuffer read(int max_size) override;
-
-    virtual bool send(ReadonlyBytes data) override;
-
-    virtual bool eof() override;
-
-    virtual void discard_connection() override;
-
-private:
-    explicit TCPWebSocketConnectionImpl(Core::Object* parent = nullptr);
-
-    RefPtr<Core::Notifier> m_notifier;
-    RefPtr<Core::TCPSocket> m_socket;
-};
-
-}

+ 0 - 87
Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp

@@ -1,87 +0,0 @@
-/*
- * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com>
- *
- * SPDX-License-Identifier: BSD-2-Clause
- */
-
-#include <LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.h>
-
-namespace WebSocket {
-
-TLSv12WebSocketConnectionImpl::TLSv12WebSocketConnectionImpl(Core::Object* parent)
-    : AbstractWebSocketImpl(parent)
-{
-}
-
-TLSv12WebSocketConnectionImpl::~TLSv12WebSocketConnectionImpl()
-{
-    discard_connection();
-}
-
-void TLSv12WebSocketConnectionImpl::connect(ConnectionInfo const& connection)
-{
-    VERIFY(!m_socket);
-    VERIFY(on_connected);
-    VERIFY(on_connection_error);
-    VERIFY(on_ready_to_read);
-    m_socket = TLS::TLSv12::connect(connection.url().host(), connection.url().port_or_default()).release_value_but_fixme_should_propagate_errors();
-
-    m_socket->on_tls_error = [this](TLS::AlertDescription) {
-        on_connection_error();
-    };
-    m_socket->on_ready_to_read = [this] {
-        on_ready_to_read();
-    };
-    m_socket->on_tls_finished = [this] {
-        on_connection_error();
-    };
-    m_socket->on_tls_certificate_request = [](auto&) {
-        // FIXME : Once we handle TLS certificate requests, handle it here as well.
-    };
-    on_connected();
-}
-
-bool TLSv12WebSocketConnectionImpl::send(ReadonlyBytes data)
-{
-    return m_socket->write_or_error(data);
-}
-
-bool TLSv12WebSocketConnectionImpl::can_read_line()
-{
-    return m_socket->can_read_line();
-}
-
-String TLSv12WebSocketConnectionImpl::read_line(size_t size)
-{
-    return m_socket->read_line(size);
-}
-
-bool TLSv12WebSocketConnectionImpl::can_read()
-{
-    return m_socket->can_read();
-}
-
-ByteBuffer TLSv12WebSocketConnectionImpl::read(int max_size)
-{
-    auto buffer = ByteBuffer::create_uninitialized(max_size).release_value_but_fixme_should_propagate_errors();
-    auto nread = m_socket->read(buffer).release_value_but_fixme_should_propagate_errors();
-    return buffer.slice(0, nread);
-}
-
-bool TLSv12WebSocketConnectionImpl::eof()
-{
-    return m_socket->is_eof();
-}
-
-void TLSv12WebSocketConnectionImpl::discard_connection()
-{
-    if (!m_socket)
-        return;
-    m_socket->on_tls_error = nullptr;
-    m_socket->on_tls_finished = nullptr;
-    m_socket->on_tls_certificate_request = nullptr;
-    m_socket->on_ready_to_read = nullptr;
-    m_socket = nullptr;
-}
-
-}

+ 0 - 45
Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.h

@@ -1,45 +0,0 @@
-/*
- * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com>
- *
- * SPDX-License-Identifier: BSD-2-Clause
- */
-
-#pragma once
-
-#include <AK/ByteBuffer.h>
-#include <AK/Span.h>
-#include <AK/String.h>
-#include <LibCore/Object.h>
-#include <LibTLS/TLSv12.h>
-#include <LibWebSocket/ConnectionInfo.h>
-#include <LibWebSocket/Impl/AbstractWebSocketImpl.h>
-
-namespace WebSocket {
-
-class TLSv12WebSocketConnectionImpl final : public AbstractWebSocketImpl {
-    C_OBJECT(TLSv12WebSocketConnectionImpl);
-
-public:
-    virtual ~TLSv12WebSocketConnectionImpl() override;
-
-    void connect(ConnectionInfo const& connection) override;
-
-    virtual bool can_read_line() override;
-    virtual String read_line(size_t size) override;
-
-    virtual bool can_read() override;
-    virtual ByteBuffer read(int max_size) override;
-
-    virtual bool send(ReadonlyBytes data) override;
-
-    virtual bool eof() override;
-
-    virtual void discard_connection() override;
-
-private:
-    explicit TLSv12WebSocketConnectionImpl(Core::Object* parent = nullptr);
-
-    OwnPtr<TLS::TLSv12> m_socket;
-};
-
-}

+ 73 - 0
Userland/Libraries/LibWebSocket/Impl/WebSocketImpl.cpp

@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com>
+ * Copyright (c) 2022, Ali Mohammad Pur <mpfard@serenityos.org>
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#include <LibWebSocket/Impl/WebSocketImpl.h>
+
+namespace WebSocket {
+
+WebSocketImpl::WebSocketImpl(Core::Object* parent)
+    : Object(parent)
+{
+}
+
+WebSocketImpl::~WebSocketImpl()
+{
+}
+
+void WebSocketImpl::connect(ConnectionInfo const& connection_info)
+{
+    VERIFY(!m_socket);
+    VERIFY(on_connected);
+    VERIFY(on_connection_error);
+    VERIFY(on_ready_to_read);
+    auto socket_result = [&]() -> ErrorOr<NonnullOwnPtr<Core::Stream::BufferedSocketBase>> {
+        if (connection_info.is_secure()) {
+            TLS::Options options;
+            options.set_alert_handler([this](auto) {
+                on_connection_error();
+            });
+            return TRY(Core::Stream::BufferedSocket<TLS::TLSv12>::create(
+                TRY(TLS::TLSv12::connect(connection_info.url().host(), connection_info.url().port_or_default(), move(options)))));
+        }
+
+        return TRY(Core::Stream::BufferedTCPSocket::create(
+            TRY(Core::Stream::TCPSocket::connect(connection_info.url().host(), connection_info.url().port_or_default()))));
+    }();
+
+    if (socket_result.is_error()) {
+        deferred_invoke([this] {
+            on_connection_error();
+        });
+        return;
+    }
+
+    m_socket = socket_result.release_value();
+
+    m_socket->on_ready_to_read = [this] {
+        on_ready_to_read();
+    };
+
+    deferred_invoke([this] {
+        on_connected();
+    });
+}
+
+ErrorOr<ByteBuffer> WebSocketImpl::read(int max_size)
+{
+    auto buffer = TRY(ByteBuffer::create_uninitialized(max_size));
+    auto nread = TRY(m_socket->read(buffer));
+    return buffer.slice(0, nread);
+}
+
+ErrorOr<String> WebSocketImpl::read_line(size_t size)
+{
+    auto buffer = TRY(ByteBuffer::create_uninitialized(size));
+    auto nread = TRY(m_socket->read_line(buffer));
+    return String::copy(buffer.span().slice(0, nread));
+}
+
+}

+ 50 - 0
Userland/Libraries/LibWebSocket/Impl/WebSocketImpl.h

@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com>
+ * Copyright (c) 2022, Ali Mohammad Pur <mpfard@serenityos.org>
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#pragma once
+
+#include <AK/ByteBuffer.h>
+#include <AK/Span.h>
+#include <AK/String.h>
+#include <LibCore/Object.h>
+#include <LibWebSocket/ConnectionInfo.h>
+
+namespace WebSocket {
+
+class WebSocketImpl : public Core::Object {
+    C_OBJECT(WebSocketImpl);
+
+public:
+    virtual ~WebSocketImpl() override;
+    explicit WebSocketImpl(Core::Object* parent = nullptr);
+
+    void connect(ConnectionInfo const&);
+
+    bool can_read_line() { return MUST(m_socket->can_read_line()); }
+    ErrorOr<String> read_line(size_t size);
+
+    bool can_read() { return MUST(m_socket->can_read_without_blocking()); }
+    ErrorOr<ByteBuffer> read(int max_size);
+
+    bool send(ReadonlyBytes bytes) { return m_socket->write_or_error(bytes); }
+
+    bool eof() { return m_socket->is_eof(); }
+
+    void discard_connection()
+    {
+        m_socket.clear();
+    }
+
+    Function<void()> on_connected;
+    Function<void()> on_connection_error;
+    Function<void()> on_ready_to_read;
+
+private:
+    OwnPtr<Core::Stream::BufferedSocketBase> m_socket;
+};
+
+}

+ 16 - 17
Userland/Libraries/LibWebSocket/WebSocket.cpp

@@ -7,8 +7,6 @@
 #include <AK/Base64.h>
 #include <AK/Random.h>
 #include <LibCrypto/Hash/HashManager.h>
-#include <LibWebSocket/Impl/TCPWebSocketConnectionImpl.h>
-#include <LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.h>
 #include <LibWebSocket/WebSocket.h>
 #include <unistd.h>
 
@@ -35,10 +33,7 @@ void WebSocket::start()
 {
     VERIFY(m_state == WebSocket::InternalState::NotStarted);
     VERIFY(!m_impl);
-    if (m_connection.is_secure())
-        m_impl = TLSv12WebSocketConnectionImpl::construct();
-    else
-        m_impl = TCPWebSocketConnectionImpl::construct();
+    m_impl = WebSocketImpl::construct();
 
     m_impl->on_connection_error = [this] {
         dbgln("WebSocket: Connection error (underlying socket)");
@@ -117,7 +112,8 @@ void WebSocket::drain_read()
     case InternalState::EstablishingProtocolConnection:
     case InternalState::SendingClientHandshake: {
         auto initializing_bytes = m_impl->read(1024);
-        dbgln("drain_read() was called on a websocket that isn't opened yet. Read {} bytes from the socket.", initializing_bytes.size());
+        if (!initializing_bytes.is_error())
+            dbgln("drain_read() was called on a websocket that isn't opened yet. Read {} bytes from the socket.", initializing_bytes.value().size());
     } break;
     case InternalState::WaitingForServerHandshake: {
         read_server_handshake();
@@ -129,7 +125,8 @@ void WebSocket::drain_read()
     case InternalState::Closed:
     case InternalState::Errored: {
         auto closed_bytes = m_impl->read(1024);
-        dbgln("drain_read() was called on a closed websocket. Read {} bytes from the socket.", closed_bytes.size());
+        if (!closed_bytes.is_error())
+            dbgln("drain_read() was called on a closed websocket. Read {} bytes from the socket.", closed_bytes.value().size());
     } break;
     default:
         VERIFY_NOT_REACHED();
@@ -209,7 +206,7 @@ void WebSocket::read_server_handshake()
         return;
 
     if (!m_has_read_server_handshake_first_line) {
-        auto header = m_impl->read_line(PAGE_SIZE);
+        auto header = m_impl->read_line(PAGE_SIZE).release_value_but_fixme_should_propagate_errors();
         auto parts = header.split(' ');
         if (parts.size() < 2) {
             dbgln("WebSocket: Server HTTP Handshake contained HTTP header was malformed");
@@ -235,7 +232,7 @@ void WebSocket::read_server_handshake()
 
     // Read the rest of the reply until we find an empty line
     while (m_impl->can_read_line()) {
-        auto line = m_impl->read_line(PAGE_SIZE);
+        auto line = m_impl->read_line(PAGE_SIZE).release_value_but_fixme_should_propagate_errors();
         if (line.is_whitespace()) {
             // We're done with the HTTP headers.
             // Fail the connection if we're missing any of the following:
@@ -364,14 +361,15 @@ void WebSocket::read_frame()
     VERIFY(m_impl);
     VERIFY(m_state == WebSocket::InternalState::Open || m_state == WebSocket::InternalState::Closing);
 
-    auto head_bytes = m_impl->read(2);
-    if (head_bytes.size() == 0) {
+    auto head_bytes_result = m_impl->read(2);
+    if (head_bytes_result.is_error() || head_bytes_result.value().is_empty()) {
         // The connection got closed.
         m_state = WebSocket::InternalState::Closed;
         notify_close(m_last_close_code, m_last_close_message, true);
         discard_connection();
         return;
     }
+    auto head_bytes = head_bytes_result.release_value();
     VERIFY(head_bytes.size() == 2);
 
     bool is_final_frame = head_bytes[0] & 0x80;
@@ -388,7 +386,7 @@ void WebSocket::read_frame()
     auto payload_length_bits = head_bytes[1] & 0x7f;
     if (payload_length_bits == 127) {
         // A code of 127 means that the next 8 bytes contains the payload length
-        auto actual_bytes = m_impl->read(8);
+        auto actual_bytes = MUST(m_impl->read(8));
         VERIFY(actual_bytes.size() == 8);
         u64 full_payload_length = (u64)((u64)(actual_bytes[0] & 0xff) << 56)
             | (u64)((u64)(actual_bytes[1] & 0xff) << 48)
@@ -402,7 +400,7 @@ void WebSocket::read_frame()
         payload_length = (size_t)full_payload_length;
     } else if (payload_length_bits == 126) {
         // A code of 126 means that the next 2 bytes contains the payload length
-        auto actual_bytes = m_impl->read(2);
+        auto actual_bytes = MUST(m_impl->read(2));
         VERIFY(actual_bytes.size() == 2);
         payload_length = (size_t)((size_t)(actual_bytes[0] & 0xff) << 8)
             | (size_t)((size_t)(actual_bytes[1] & 0xff) << 0);
@@ -418,7 +416,7 @@ void WebSocket::read_frame()
     // But because it doesn't cost much, we can support receiving masked frames anyways.
     u8 masking_key[4];
     if (is_masked) {
-        auto masking_key_data = m_impl->read(4);
+        auto masking_key_data = MUST(m_impl->read(4));
         VERIFY(masking_key_data.size() == 4);
         masking_key[0] = masking_key_data[0];
         masking_key[1] = masking_key_data[1];
@@ -429,13 +427,14 @@ void WebSocket::read_frame()
     auto payload = ByteBuffer::create_uninitialized(payload_length).release_value_but_fixme_should_propagate_errors(); // FIXME: Handle possible OOM situation.
     u64 read_length = 0;
     while (read_length < payload_length) {
-        auto payload_part = m_impl->read(payload_length - read_length);
-        if (payload_part.size() == 0) {
+        auto payload_part_result = m_impl->read(payload_length - read_length);
+        if (payload_part_result.is_error() || payload_part_result.value().is_empty()) {
             // We got disconnected, somehow.
             dbgln("Websocket: Server disconnected while sending payload ({} bytes read out of {})", read_length, payload_length);
             fatal_error(WebSocket::Error::ServerClosedSocket);
             return;
         }
+        auto payload_part = payload_part_result.release_value();
         // We read at most "actual_length - read" bytes, so this is safe to do.
         payload.overwrite(read_length, payload_part.data(), payload_part.size());
         read_length -= payload_part.size();

+ 2 - 2
Userland/Libraries/LibWebSocket/WebSocket.h

@@ -9,7 +9,7 @@
 #include <AK/Span.h>
 #include <LibCore/Object.h>
 #include <LibWebSocket/ConnectionInfo.h>
-#include <LibWebSocket/Impl/AbstractWebSocketImpl.h>
+#include <LibWebSocket/Impl/WebSocketImpl.h>
 #include <LibWebSocket/Message.h>
 
 namespace WebSocket {
@@ -104,7 +104,7 @@ private:
     String m_last_close_message;
 
     ConnectionInfo m_connection;
-    RefPtr<AbstractWebSocketImpl> m_impl;
+    RefPtr<WebSocketImpl> m_impl;
 };
 
 }