Browse Source

LibCore: Support IPv6 for TCP and UDP connection

Salem Yaslem 1 year ago
parent
commit
ab82fc8993

+ 2 - 0
AK/Forward.h

@@ -35,6 +35,7 @@ class Error;
 class FlyString;
 class GenericLexer;
 class IPv4Address;
+class IPv6Address;
 class JsonArray;
 class JsonObject;
 class JsonValue;
@@ -167,6 +168,7 @@ using AK::GenericLexer;
 using AK::HashMap;
 using AK::HashTable;
 using AK::IPv4Address;
+using AK::IPv6Address;
 using AK::JsonArray;
 using AK::JsonObject;
 using AK::JsonValue;

+ 34 - 8
Userland/Libraries/LibCore/Socket.cpp

@@ -19,6 +19,9 @@ ErrorOr<int> Socket::create_fd(SocketDomain domain, SocketType type)
     case SocketDomain::Inet:
         socket_domain = AF_INET;
         break;
+    case SocketDomain::Inet6:
+        socket_domain = AF_INET6;
+        break;
     case SocketDomain::Local:
         socket_domain = AF_LOCAL;
         break;
@@ -48,7 +51,7 @@ ErrorOr<int> Socket::create_fd(SocketDomain domain, SocketType type)
 #endif
 }
 
-ErrorOr<IPv4Address> Socket::resolve_host(ByteString const& host, SocketType type)
+ErrorOr<Variant<IPv4Address, IPv6Address>> Socket::resolve_host(ByteString const& host, SocketType type)
 {
     int socket_type;
     switch (type) {
@@ -71,6 +74,13 @@ ErrorOr<IPv4Address> Socket::resolve_host(ByteString const& host, SocketType typ
     auto const results = TRY(Core::System::getaddrinfo(host.characters(), nullptr, hints));
 
     for (auto const& result : results.addresses()) {
+        if (result.ai_family == AF_INET6) {
+            auto* socket_address = bit_cast<struct sockaddr_in6*>(result.ai_addr);
+            auto address = IPv6Address { socket_address->sin6_addr.s6_addr };
+
+            return address;
+        }
+
         if (result.ai_family == AF_INET) {
             auto* socket_address = bit_cast<struct sockaddr_in*>(result.ai_addr);
             NetworkOrdered<u32> const network_ordered_address { socket_address->sin_addr.s_addr };
@@ -78,7 +88,7 @@ ErrorOr<IPv4Address> Socket::resolve_host(ByteString const& host, SocketType typ
         }
     }
 
-    return Error::from_string_literal("Could not resolve to IPv4 address");
+    return Error::from_string_literal("Could not resolve to IPv4 or IPv6 address");
 }
 
 ErrorOr<void> Socket::connect_local(int fd, ByteString const& path)
@@ -96,8 +106,13 @@ ErrorOr<void> Socket::connect_local(int fd, ByteString const& path)
 
 ErrorOr<void> Socket::connect_inet(int fd, SocketAddress const& address)
 {
-    auto addr = address.to_sockaddr_in();
-    return System::connect(fd, bit_cast<struct sockaddr*>(&addr), sizeof(addr));
+    if (address.type() == SocketAddress::Type::IPv6) {
+        auto addr = address.to_sockaddr_in6();
+        return System::connect(fd, bit_cast<struct sockaddr*>(&addr), sizeof(addr));
+    } else {
+        auto addr = address.to_sockaddr_in();
+        return System::connect(fd, bit_cast<struct sockaddr*>(&addr), sizeof(addr));
+    }
 }
 
 ErrorOr<Bytes> PosixSocketHelper::read(Bytes buffer, int flags)
@@ -200,14 +215,19 @@ void PosixSocketHelper::setup_notifier()
 ErrorOr<NonnullOwnPtr<TCPSocket>> TCPSocket::connect(ByteString const& host, u16 port)
 {
     auto ip_address = TRY(resolve_host(host, SocketType::Stream));
-    return connect(SocketAddress { ip_address, port });
+
+    return ip_address.visit([port](auto address) { return connect(SocketAddress { address, port }); });
 }
 
 ErrorOr<NonnullOwnPtr<TCPSocket>> TCPSocket::connect(SocketAddress const& address)
 {
     auto socket = TRY(adopt_nonnull_own_or_enomem(new (nothrow) TCPSocket()));
 
-    auto fd = TRY(create_fd(SocketDomain::Inet, SocketType::Stream));
+    auto socket_domain = SocketDomain::Inet6;
+    if (address.type() == SocketAddress::Type::IPv4)
+        socket_domain = SocketDomain::Inet;
+
+    auto fd = TRY(create_fd(socket_domain, SocketType::Stream));
     socket->m_helper.set_fd(fd);
 
     TRY(connect_inet(fd, address));
@@ -242,14 +262,19 @@ ErrorOr<size_t> PosixSocketHelper::pending_bytes() const
 ErrorOr<NonnullOwnPtr<UDPSocket>> UDPSocket::connect(ByteString const& host, u16 port, Optional<Duration> timeout)
 {
     auto ip_address = TRY(resolve_host(host, SocketType::Datagram));
-    return connect(SocketAddress { ip_address, port }, timeout);
+
+    return ip_address.visit([port, timeout](auto address) { return connect(SocketAddress { address, port }, timeout); });
 }
 
 ErrorOr<NonnullOwnPtr<UDPSocket>> UDPSocket::connect(SocketAddress const& address, Optional<Duration> timeout)
 {
     auto socket = TRY(adopt_nonnull_own_or_enomem(new (nothrow) UDPSocket()));
 
-    auto fd = TRY(create_fd(SocketDomain::Inet, SocketType::Datagram));
+    auto socket_domain = SocketDomain::Inet6;
+    if (address.type() == SocketAddress::Type::IPv4)
+        socket_domain = SocketDomain::Inet;
+
+    auto fd = TRY(create_fd(socket_domain, SocketType::Datagram));
     socket->m_helper.set_fd(fd);
     if (timeout.has_value()) {
         TRY(socket->m_helper.set_receive_timeout(timeout.value()));
@@ -258,6 +283,7 @@ ErrorOr<NonnullOwnPtr<UDPSocket>> UDPSocket::connect(SocketAddress const& addres
     TRY(connect_inet(fd, address));
 
     socket->setup_notifier();
+
     return socket;
 }
 

+ 2 - 1
Userland/Libraries/LibCore/Socket.h

@@ -58,7 +58,7 @@ public:
 
     // FIXME: This will need to be updated when IPv6 socket arrives. Perhaps a
     //        base class for all address types is appropriate.
-    static ErrorOr<IPv4Address> resolve_host(ByteString const&, SocketType);
+    static ErrorOr<Variant<IPv4Address, IPv6Address>> resolve_host(ByteString const&, SocketType);
 
     Function<void()> on_ready_to_read;
 
@@ -66,6 +66,7 @@ protected:
     enum class SocketDomain {
         Local,
         Inet,
+        Inet6,
     };
 
     explicit Socket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::Yes)

+ 40 - 7
Userland/Libraries/LibCore/SocketAddress.h

@@ -8,6 +8,7 @@
 #pragma once
 
 #include <AK/IPv4Address.h>
+#include <AK/IPv6Address.h>
 #include <arpa/inet.h>
 #include <netinet/in.h>
 #include <string.h>
@@ -21,19 +22,33 @@ public:
     enum class Type {
         Invalid,
         IPv4,
+        IPv6,
         Local
     };
 
     SocketAddress() = default;
     SocketAddress(IPv4Address const& address)
         : m_type(Type::IPv4)
-        , m_ipv4_address(address)
+        , m_ip_address { address }
+    {
+    }
+
+    SocketAddress(IPv6Address const& address)
+        : m_type(Type::IPv6)
+        , m_ip_address { address }
     {
     }
 
     SocketAddress(IPv4Address const& address, u16 port)
         : m_type(Type::IPv4)
-        , m_ipv4_address(address)
+        , m_ip_address { address }
+        , m_port(port)
+    {
+    }
+
+    SocketAddress(IPv6Address const& address, u16 port)
+        : m_type(Type::IPv6)
+        , m_ip_address { address }
         , m_port(port)
     {
     }
@@ -48,14 +63,18 @@ public:
 
     Type type() const { return m_type; }
     bool is_valid() const { return m_type != Type::Invalid; }
-    IPv4Address ipv4_address() const { return m_ipv4_address; }
+
+    IPv4Address ipv4_address() const { return m_ip_address.get<IPv4Address>(); }
+    IPv6Address ipv6_address() const { return m_ip_address.get<IPv6Address>(); }
     u16 port() const { return m_port; }
 
     ByteString to_byte_string() const
     {
         switch (m_type) {
         case Type::IPv4:
-            return ByteString::formatted("{}:{}", m_ipv4_address, m_port);
+            return ByteString::formatted("{}:{}", m_ip_address.get<IPv4Address>(), m_port);
+        case Type::IPv6:
+            return ByteString::formatted("[{}]:{}", m_ip_address.get<IPv6Address>(), m_port);
         case Type::Local:
             return m_local_address;
         default:
@@ -74,13 +93,25 @@ public:
         return address;
     }
 
+    sockaddr_in6 to_sockaddr_in6() const
+    {
+        VERIFY(type() == Type::IPv6);
+        sockaddr_in6 address {};
+        memset(&address, 0, sizeof(address));
+        address.sin6_family = AF_INET6;
+        address.sin6_port = htons(port());
+        auto ipv6_addr = ipv6_address();
+        memcpy(&address.sin6_addr, &ipv6_addr.to_in6_addr_t(), sizeof(address.sin6_addr));
+        return address;
+    }
+
     sockaddr_in to_sockaddr_in() const
     {
         VERIFY(type() == Type::IPv4);
         sockaddr_in address {};
         address.sin_family = AF_INET;
-        address.sin_addr.s_addr = m_ipv4_address.to_in_addr_t();
-        address.sin_port = htons(m_port);
+        address.sin_port = htons(port());
+        address.sin_addr.s_addr = ipv4_address().to_in_addr_t();
         return address;
     }
 
@@ -89,7 +120,9 @@ public:
 
 private:
     Type m_type { Type::Invalid };
-    IPv4Address m_ipv4_address;
+
+    Variant<IPv4Address, IPv6Address> m_ip_address = IPv4Address();
+
     u16 m_port { 0 };
     ByteString m_local_address;
 };