diff --git a/Libraries/LibCore/Socket.h b/Libraries/LibCore/Socket.h index 5159d3e1190..93061b20fc4 100644 --- a/Libraries/LibCore/Socket.h +++ b/Libraries/LibCore/Socket.h @@ -159,6 +159,7 @@ class TCPSocket final : public Socket { public: static ErrorOr> connect(ByteString const& host, u16 port); static ErrorOr> connect(SocketAddress const& address); + static ErrorOr> connect(SocketAddress const& address, ByteString const&) { return connect(address); } static ErrorOr> adopt_fd(int fd); TCPSocket(TCPSocket&& other) @@ -220,6 +221,7 @@ class UDPSocket final : public Socket { public: static ErrorOr> connect(ByteString const& host, u16 port, Optional timeout = {}); static ErrorOr> connect(SocketAddress const& address, Optional timeout = {}); + static ErrorOr> connect(SocketAddress const& address, ByteString const&, Optional timeout = {}) { return connect(address, timeout); } UDPSocket(UDPSocket&& other) : Socket(static_cast(other)) diff --git a/Libraries/LibTLS/Socket.cpp b/Libraries/LibTLS/Socket.cpp index 495d0e0d553..e7f8b5dee74 100644 --- a/Libraries/LibTLS/Socket.cpp +++ b/Libraries/LibTLS/Socket.cpp @@ -73,6 +73,29 @@ ErrorOr> TLSv12::connect(ByteString const& host, u16 port, return tls_socket; } +ErrorOr> TLSv12::connect(Core::SocketAddress address, ByteString const& host, Options options) +{ + auto promise = Core::Promise::construct(); + OwnPtr tcp_socket = TRY(Core::TCPSocket::connect(address)); + TRY(tcp_socket->set_blocking(false)); + auto tls_socket = make(move(tcp_socket), move(options)); + tls_socket->set_sni(host); + tls_socket->on_connected = [&] { + promise->resolve({}); + }; + tls_socket->on_tls_error = [&](auto alert) { + tls_socket->try_disambiguate_error(); + promise->reject(AK::Error::from_string_view(enum_to_string(alert))); + }; + + TRY(promise->await()); + + tls_socket->on_tls_error = nullptr; + tls_socket->on_connected = nullptr; + tls_socket->m_context.should_expect_successful_read = true; + return tls_socket; +} + ErrorOr> TLSv12::connect(ByteString const& host, Core::Socket& underlying_stream, Options options) { auto promise = Core::Promise::construct(); @@ -271,7 +294,8 @@ bool TLSv12::check_connection_state(bool read) ErrorOr TLSv12::flush() { - auto out_bytes = m_context.tls_buffer.bytes(); + ByteBuffer out = move(m_context.tls_buffer); + auto out_bytes = out.bytes(); if (out_bytes.is_empty()) return true; @@ -298,17 +322,11 @@ ErrorOr TLSv12::flush() out_bytes = out_bytes.slice(written); } while (!out_bytes.is_empty()); - if (out_bytes.is_empty() && !error.has_value()) { - m_context.tls_buffer.clear(); + if (out_bytes.is_empty() && !error.has_value()) return true; - } - if (m_context.send_retries++ == 10) { - // drop the records, we can't send - dbgln_if(TLS_DEBUG, "Dropping {} bytes worth of TLS records as max retries has been reached", m_context.tls_buffer.size()); - m_context.tls_buffer.clear(); - m_context.send_retries = 0; - } + if (!out_bytes.is_empty()) + dbgln("Dropping {} bytes worth of TLS records on the floor", out_bytes.size()); return false; } diff --git a/Libraries/LibTLS/TLSv12.h b/Libraries/LibTLS/TLSv12.h index 806fd9ff9b0..02b25fd004e 100644 --- a/Libraries/LibTLS/TLSv12.h +++ b/Libraries/LibTLS/TLSv12.h @@ -357,6 +357,7 @@ public: virtual void set_notifications_enabled(bool enabled) override { underlying_stream().set_notifications_enabled(enabled); } + static ErrorOr> connect(Core::SocketAddress, ByteString const& host, Options = {}); static ErrorOr> connect(ByteString const& host, u16 port, Options = {}); static ErrorOr> connect(ByteString const& host, Core::Socket& underlying_stream, Options = {});