Browse Source

Kernel: Implement IP multicast support

An IP socket can now join a multicast group by using the
IP_ADD_MEMBERSHIP sockopt, which will cause it to start receiving
packets sent to the multicast address, even though this address does
not belong to this host.
Sergey Bugaev 4 năm trước cách đây
mục cha
commit
78459b92d5

+ 46 - 0
Kernel/Net/IPv4Socket.cpp

@@ -466,6 +466,42 @@ KResult IPv4Socket::setsockopt(int level, int option, Userspace<const void*> use
         m_ttl = value;
         m_ttl = value;
         return KSuccess;
         return KSuccess;
     }
     }
+    case IP_MULTICAST_LOOP: {
+        if (user_value_size != 1)
+            return EINVAL;
+        u8 value;
+        if (!copy_from_user(&value, static_ptr_cast<const u8*>(user_value)))
+            return EFAULT;
+        if (value != 0 && value != 1)
+            return EINVAL;
+        m_multicast_loop = value;
+        return KSuccess;
+    }
+    case IP_ADD_MEMBERSHIP: {
+        if (user_value_size != sizeof(ip_mreq))
+            return EINVAL;
+        ip_mreq mreq;
+        if (!copy_from_user(&mreq, static_ptr_cast<const ip_mreq*>(user_value)))
+            return EFAULT;
+        if (mreq.imr_interface.s_addr != INADDR_ANY)
+            return ENOTSUP;
+        IPv4Address address { (const u8*)&mreq.imr_multiaddr.s_addr };
+        if (!m_multicast_memberships.contains_slow(address))
+            m_multicast_memberships.append(address);
+        return KSuccess;
+    }
+    case IP_DROP_MEMBERSHIP: {
+        if (user_value_size != sizeof(ip_mreq))
+            return EINVAL;
+        ip_mreq mreq;
+        if (!copy_from_user(&mreq, static_ptr_cast<const ip_mreq*>(user_value)))
+            return EFAULT;
+        if (mreq.imr_interface.s_addr != INADDR_ANY)
+            return ENOTSUP;
+        IPv4Address address { (const u8*)&mreq.imr_multiaddr.s_addr };
+        m_multicast_memberships.remove_first_matching([&address](auto& a) { return a == address; });
+        return KSuccess;
+    }
     default:
     default:
         return ENOPROTOOPT;
         return ENOPROTOOPT;
     }
     }
@@ -490,6 +526,16 @@ KResult IPv4Socket::getsockopt(FileDescription& description, int level, int opti
         if (!copy_to_user(value_size, &size))
         if (!copy_to_user(value_size, &size))
             return EFAULT;
             return EFAULT;
         return KSuccess;
         return KSuccess;
+    case IP_MULTICAST_LOOP: {
+        if (size < 1)
+            return EINVAL;
+        if (!copy_to_user(static_ptr_cast<u8*>(value), (const u8*)&m_multicast_loop))
+            return EFAULT;
+        size = 1;
+        if (!copy_to_user(value_size, &size))
+            return EFAULT;
+        return KSuccess;
+    }
     default:
     default:
         return ENOPROTOOPT;
         return ENOPROTOOPT;
     }
     }

+ 5 - 0
Kernel/Net/IPv4Socket.h

@@ -54,6 +54,8 @@ public:
     u16 peer_port() const { return m_peer_port; }
     u16 peer_port() const { return m_peer_port; }
     void set_peer_port(u16 port) { m_peer_port = port; }
     void set_peer_port(u16 port) { m_peer_port = port; }
 
 
+    const Vector<IPv4Address>& multicast_memberships() const { return m_multicast_memberships; }
+
     IPv4SocketTuple tuple() const { return IPv4SocketTuple(m_local_address, m_local_port, m_peer_address, m_peer_port); }
     IPv4SocketTuple tuple() const { return IPv4SocketTuple(m_local_address, m_local_port, m_peer_address, m_peer_port); }
 
 
     String absolute_path(const FileDescription& description) const override;
     String absolute_path(const FileDescription& description) const override;
@@ -96,6 +98,9 @@ private:
     IPv4Address m_local_address;
     IPv4Address m_local_address;
     IPv4Address m_peer_address;
     IPv4Address m_peer_address;
 
 
+    Vector<IPv4Address> m_multicast_memberships;
+    bool m_multicast_loop { true };
+
     struct ReceivedPacket {
     struct ReceivedPacket {
         IPv4Address peer_address;
         IPv4Address peer_address;
         u16 peer_port;
         u16 peer_port;

+ 5 - 7
Kernel/Net/NetworkTask.cpp

@@ -258,12 +258,6 @@ void handle_udp(const IPv4Packet& ipv4_packet, const Time& packet_timestamp)
         return;
         return;
     }
     }
 
 
-    auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination());
-    if (!adapter && ipv4_packet.destination() != IPv4Address(255, 255, 255, 255)) {
-        dbgln_if(UDP_DEBUG, "handle_udp: this packet is not for me, it's for {}", ipv4_packet.destination());
-        return;
-    }
-
     auto& udp_packet = *static_cast<const UDPPacket*>(ipv4_packet.payload());
     auto& udp_packet = *static_cast<const UDPPacket*>(ipv4_packet.payload());
     dbgln_if(UDP_DEBUG, "handle_udp: source={}:{}, destination={}:{}, length={}",
     dbgln_if(UDP_DEBUG, "handle_udp: source={}:{}, destination={}:{}, length={}",
         ipv4_packet.source(), udp_packet.source_port(),
         ipv4_packet.source(), udp_packet.source_port(),
@@ -278,7 +272,11 @@ void handle_udp(const IPv4Packet& ipv4_packet, const Time& packet_timestamp)
 
 
     VERIFY(socket->type() == SOCK_DGRAM);
     VERIFY(socket->type() == SOCK_DGRAM);
     VERIFY(socket->local_port() == udp_packet.destination_port());
     VERIFY(socket->local_port() == udp_packet.destination_port());
-    socket->did_receive(ipv4_packet.source(), udp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()), packet_timestamp);
+
+    auto& destination = ipv4_packet.destination();
+
+    if (destination == IPv4Address(255, 255, 255, 255) || NetworkAdapter::from_ipv4_address(destination) || socket->multicast_memberships().contains_slow(destination))
+        socket->did_receive(ipv4_packet.source(), udp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()), packet_timestamp);
 }
 }
 
 
 void handle_tcp(const IPv4Packet& ipv4_packet, const Time& packet_timestamp)
 void handle_tcp(const IPv4Packet& ipv4_packet, const Time& packet_timestamp)

+ 13 - 0
Kernel/UnixTypes.h

@@ -525,6 +525,9 @@ enum {
 #define IPPROTO_UDP 17
 #define IPPROTO_UDP 17
 
 
 #define IP_TTL 2
 #define IP_TTL 2
+#define IP_MULTICAST_LOOP 3
+#define IP_ADD_MEMBERSHIP 4
+#define IP_DROP_MEMBERSHIP 5
 
 
 struct ucred {
 struct ucred {
     pid_t pid;
     pid_t pid;
@@ -548,6 +551,7 @@ struct sockaddr_un {
 struct in_addr {
 struct in_addr {
     uint32_t s_addr;
     uint32_t s_addr;
 };
 };
+typedef uint32_t in_addr_t;
 
 
 struct sockaddr_in {
 struct sockaddr_in {
     int16_t sin_family;
     int16_t sin_family;
@@ -556,6 +560,15 @@ struct sockaddr_in {
     char sin_zero[8];
     char sin_zero[8];
 };
 };
 
 
+struct ip_mreq {
+    struct in_addr imr_multiaddr;
+    struct in_addr imr_interface;
+};
+
+#define INADDR_ANY ((in_addr_t)0)
+#define INADDR_NONE ((in_addr_t)-1)
+#define INADDR_LOOPBACK 0x7f000001
+
 typedef u32 __u32;
 typedef u32 __u32;
 typedef u16 __u16;
 typedef u16 __u16;
 typedef u8 __u8;
 typedef u8 __u8;

+ 8 - 0
Userland/Libraries/LibC/netinet/in.h

@@ -22,6 +22,9 @@ in_addr_t inet_addr(const char*);
 #define IN_LOOPBACKNET 127
 #define IN_LOOPBACKNET 127
 
 
 #define IP_TTL 2
 #define IP_TTL 2
+#define IP_MULTICAST_LOOP 3
+#define IP_ADD_MEMBERSHIP 4
+#define IP_DROP_MEMBERSHIP 5
 
 
 #define IPPORT_RESERVED 1024
 #define IPPORT_RESERVED 1024
 #define IPPORT_USERRESERVED 5000
 #define IPPORT_USERRESERVED 5000
@@ -39,6 +42,11 @@ struct sockaddr_in {
     char sin_zero[8];
     char sin_zero[8];
 };
 };
 
 
+struct ip_mreq {
+    struct in_addr imr_multiaddr;
+    struct in_addr imr_interface;
+};
+
 struct in6_addr {
 struct in6_addr {
     uint8_t s6_addr[16];
     uint8_t s6_addr[16];
 };
 };