LibCore+LibTLS: Add an API for connect()'ing 'with hostname

This just unifies the API for all three sockets (UDP, TCP and TLS)
This commit is contained in:
Ali Mohammad Pur 2024-11-01 23:51:21 +01:00 committed by Ali Mohammad Pur
parent b93d8ef875
commit d704b61066
Notes: github-actions[bot] 2024-11-20 20:44:42 +00:00
3 changed files with 31 additions and 10 deletions

View file

@ -159,6 +159,7 @@ class TCPSocket final : public Socket {
public: public:
static ErrorOr<NonnullOwnPtr<TCPSocket>> connect(ByteString const& host, u16 port); static ErrorOr<NonnullOwnPtr<TCPSocket>> connect(ByteString const& host, u16 port);
static ErrorOr<NonnullOwnPtr<TCPSocket>> connect(SocketAddress const& address); static ErrorOr<NonnullOwnPtr<TCPSocket>> connect(SocketAddress const& address);
static ErrorOr<NonnullOwnPtr<TCPSocket>> connect(SocketAddress const& address, ByteString const&) { return connect(address); }
static ErrorOr<NonnullOwnPtr<TCPSocket>> adopt_fd(int fd); static ErrorOr<NonnullOwnPtr<TCPSocket>> adopt_fd(int fd);
TCPSocket(TCPSocket&& other) TCPSocket(TCPSocket&& other)
@ -220,6 +221,7 @@ class UDPSocket final : public Socket {
public: public:
static ErrorOr<NonnullOwnPtr<UDPSocket>> connect(ByteString const& host, u16 port, Optional<AK::Duration> timeout = {}); static ErrorOr<NonnullOwnPtr<UDPSocket>> connect(ByteString const& host, u16 port, Optional<AK::Duration> timeout = {});
static ErrorOr<NonnullOwnPtr<UDPSocket>> connect(SocketAddress const& address, Optional<AK::Duration> timeout = {}); static ErrorOr<NonnullOwnPtr<UDPSocket>> connect(SocketAddress const& address, Optional<AK::Duration> timeout = {});
static ErrorOr<NonnullOwnPtr<UDPSocket>> connect(SocketAddress const& address, ByteString const&, Optional<AK::Duration> timeout = {}) { return connect(address, timeout); }
UDPSocket(UDPSocket&& other) UDPSocket(UDPSocket&& other)
: Socket(static_cast<Socket&&>(other)) : Socket(static_cast<Socket&&>(other))

View file

@ -73,6 +73,29 @@ ErrorOr<NonnullOwnPtr<TLSv12>> TLSv12::connect(ByteString const& host, u16 port,
return tls_socket; return tls_socket;
} }
ErrorOr<NonnullOwnPtr<TLSv12>> TLSv12::connect(Core::SocketAddress address, ByteString const& host, Options options)
{
auto promise = Core::Promise<Empty>::construct();
OwnPtr<Core::Socket> tcp_socket = TRY(Core::TCPSocket::connect(address));
TRY(tcp_socket->set_blocking(false));
auto tls_socket = make<TLSv12>(move(tcp_socket), move(options));
tls_socket->set_sni(host);
tls_socket->on_connected = [&] {
promise->resolve({});
};
tls_socket->on_tls_error = [&](auto alert) {
tls_socket->try_disambiguate_error();
promise->reject(AK::Error::from_string_view(enum_to_string(alert)));
};
TRY(promise->await());
tls_socket->on_tls_error = nullptr;
tls_socket->on_connected = nullptr;
tls_socket->m_context.should_expect_successful_read = true;
return tls_socket;
}
ErrorOr<NonnullOwnPtr<TLSv12>> TLSv12::connect(ByteString const& host, Core::Socket& underlying_stream, Options options) ErrorOr<NonnullOwnPtr<TLSv12>> TLSv12::connect(ByteString const& host, Core::Socket& underlying_stream, Options options)
{ {
auto promise = Core::Promise<Empty>::construct(); auto promise = Core::Promise<Empty>::construct();
@ -271,7 +294,8 @@ bool TLSv12::check_connection_state(bool read)
ErrorOr<bool> TLSv12::flush() ErrorOr<bool> TLSv12::flush()
{ {
auto out_bytes = m_context.tls_buffer.bytes(); ByteBuffer out = move(m_context.tls_buffer);
auto out_bytes = out.bytes();
if (out_bytes.is_empty()) if (out_bytes.is_empty())
return true; return true;
@ -298,17 +322,11 @@ ErrorOr<bool> TLSv12::flush()
out_bytes = out_bytes.slice(written); out_bytes = out_bytes.slice(written);
} while (!out_bytes.is_empty()); } while (!out_bytes.is_empty());
if (out_bytes.is_empty() && !error.has_value()) { if (out_bytes.is_empty() && !error.has_value())
m_context.tls_buffer.clear();
return true; return true;
}
if (m_context.send_retries++ == 10) { if (!out_bytes.is_empty())
// drop the records, we can't send dbgln("Dropping {} bytes worth of TLS records on the floor", out_bytes.size());
dbgln_if(TLS_DEBUG, "Dropping {} bytes worth of TLS records as max retries has been reached", m_context.tls_buffer.size());
m_context.tls_buffer.clear();
m_context.send_retries = 0;
}
return false; return false;
} }

View file

@ -357,6 +357,7 @@ public:
virtual void set_notifications_enabled(bool enabled) override { underlying_stream().set_notifications_enabled(enabled); } virtual void set_notifications_enabled(bool enabled) override { underlying_stream().set_notifications_enabled(enabled); }
static ErrorOr<NonnullOwnPtr<TLSv12>> connect(Core::SocketAddress, ByteString const& host, Options = {});
static ErrorOr<NonnullOwnPtr<TLSv12>> connect(ByteString const& host, u16 port, Options = {}); static ErrorOr<NonnullOwnPtr<TLSv12>> connect(ByteString const& host, u16 port, Options = {});
static ErrorOr<NonnullOwnPtr<TLSv12>> connect(ByteString const& host, Core::Socket& underlying_stream, Options = {}); static ErrorOr<NonnullOwnPtr<TLSv12>> connect(ByteString const& host, Core::Socket& underlying_stream, Options = {});