Browse Source

Kernel: Use WeakPtr<NetworkAdapter> instead of NetworkAdapter* in net code

Conrad Pankoff 5 years ago
parent
commit
54ceabd48d

+ 1 - 1
Kernel/Net/IPv4Socket.cpp

@@ -169,7 +169,7 @@ ssize_t IPv4Socket::sendto(FileDescription&, const void* data, size_t data_lengt
         m_peer_port = ntohs(ia.sin_port);
     }
 
-    auto* adapter = adapter_for_route_to(m_peer_address);
+    auto adapter = adapter_for_route_to(m_peer_address);
     if (!adapter)
         return -EHOSTUNREACH;
 

+ 2 - 2
Kernel/Net/NetworkAdapter.cpp

@@ -22,12 +22,12 @@ void NetworkAdapter::for_each(Function<void(NetworkAdapter&)> callback)
         callback(*it);
 }
 
-NetworkAdapter* NetworkAdapter::from_ipv4_address(const IPv4Address& address)
+WeakPtr<NetworkAdapter> NetworkAdapter::from_ipv4_address(const IPv4Address& address)
 {
     LOCKER(all_adapters().lock());
     for (auto* adapter : all_adapters().resource()) {
         if (adapter->ipv4_address() == address)
-            return adapter;
+            return adapter->make_weak_ptr();
     }
     return nullptr;
 }

+ 4 - 2
Kernel/Net/NetworkAdapter.h

@@ -4,6 +4,8 @@
 #include <AK/Function.h>
 #include <AK/SinglyLinkedList.h>
 #include <AK/Types.h>
+#include <AK/Weakable.h>
+#include <AK/WeakPtr.h>
 #include <Kernel/KBuffer.h>
 #include <Kernel/Net/ARP.h>
 #include <Kernel/Net/ICMP.h>
@@ -12,10 +14,10 @@
 
 class NetworkAdapter;
 
-class NetworkAdapter {
+class NetworkAdapter : public Weakable<NetworkAdapter> {
 public:
     static void for_each(Function<void(NetworkAdapter&)>);
-    static NetworkAdapter* from_ipv4_address(const IPv4Address&);
+    static WeakPtr<NetworkAdapter> from_ipv4_address(const IPv4Address&);
     virtual ~NetworkAdapter();
 
     virtual const char* class_name() const = 0;

+ 5 - 5
Kernel/Net/NetworkTask.cpp

@@ -38,7 +38,7 @@ void NetworkTask_main()
 {
     LoopbackAdapter::the();
 
-    auto* adapter = E1000NetworkAdapter::the();
+    auto adapter = E1000NetworkAdapter::the();
     if (!adapter)
         dbgprintf("E1000 network card not found!\n");
 
@@ -150,7 +150,7 @@ void handle_arp(const EthernetFrameHeader& eth, int frame_size)
 
     if (packet.operation() == ARPOperation::Request) {
         // Who has this IP address?
-        if (auto* adapter = NetworkAdapter::from_ipv4_address(packet.target_protocol_address())) {
+        if (auto adapter = NetworkAdapter::from_ipv4_address(packet.target_protocol_address())) {
             // We do!
             kprintf("handle_arp: Responding to ARP request for my IPv4 address (%s)\n",
                 adapter->ipv4_address().to_string().characters());
@@ -231,7 +231,7 @@ void handle_icmp(const EthernetFrameHeader& eth, int frame_size)
         }
     }
 
-    auto* adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination());
+    auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination());
     if (!adapter)
         return;
 
@@ -260,7 +260,7 @@ void handle_udp(const EthernetFrameHeader& eth, int frame_size)
     (void)frame_size;
     auto& ipv4_packet = *static_cast<const IPv4Packet*>(eth.payload());
 
-    auto* adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination());
+    auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination());
     if (!adapter) {
         kprintf("handle_udp: this packet is not for me, it's for %s\n", ipv4_packet.destination().to_string().characters());
         return;
@@ -292,7 +292,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
     (void)frame_size;
     auto& ipv4_packet = *static_cast<const IPv4Packet*>(eth.payload());
 
-    auto* adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination());
+    auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination());
     if (!adapter) {
         kprintf("handle_tcp: this packet is not for me, it's for %s\n", ipv4_packet.destination().to_string().characters());
         return;

+ 2 - 2
Kernel/Net/Routing.cpp

@@ -1,10 +1,10 @@
 #include <Kernel/Net/LoopbackAdapter.h>
 #include <Kernel/Net/Routing.h>
 
-NetworkAdapter* adapter_for_route_to(const IPv4Address& ipv4_address)
+WeakPtr<NetworkAdapter> adapter_for_route_to(const IPv4Address& ipv4_address)
 {
     // FIXME: Have an actual routing table.
     if (ipv4_address == IPv4Address(127, 0, 0, 1))
-        return &LoopbackAdapter::the();
+        return LoopbackAdapter::the().make_weak_ptr();
     return NetworkAdapter::from_ipv4_address(IPv4Address(192, 168, 5, 2));
 }

+ 1 - 1
Kernel/Net/Routing.h

@@ -2,4 +2,4 @@
 
 #include <Kernel/Net/NetworkAdapter.h>
 
-NetworkAdapter* adapter_for_route_to(const IPv4Address&);
+WeakPtr<NetworkAdapter> adapter_for_route_to(const IPv4Address&);

+ 10 - 1
Kernel/Net/TCPSocket.cpp

@@ -80,7 +80,16 @@ int TCPSocket::protocol_send(const void* data, int data_length)
 
 void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size)
 {
-    ASSERT(m_adapter);
+    if (!m_adapter) {
+        if (has_specific_local_address()) {
+            m_adapter = NetworkAdapter::from_ipv4_address(local_address());
+        } else {
+            m_adapter = adapter_for_route_to(peer_address());
+            if (m_adapter)
+                set_local_address(m_adapter->ipv4_address());
+        }
+    }
+    ASSERT(!!m_adapter);
 
     auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size);
     auto& tcp_packet = *(TCPPacket*)(buffer.pointer());

+ 2 - 1
Kernel/Net/TCPSocket.h

@@ -1,6 +1,7 @@
 #pragma once
 
 #include <AK/Function.h>
+#include <AK/WeakPtr.h>
 #include <Kernel/Net/IPv4Socket.h>
 
 class TCPSocket final : public IPv4Socket {
@@ -86,7 +87,7 @@ private:
     virtual KResult protocol_bind() override;
     virtual KResult protocol_listen() override;
 
-    NetworkAdapter* m_adapter { nullptr };
+    WeakPtr<NetworkAdapter> m_adapter;
     u32 m_sequence_number { 0 };
     u32 m_ack_number { 0 };
     State m_state { State::Closed };

+ 1 - 1
Kernel/Net/UDPSocket.cpp

@@ -56,7 +56,7 @@ int UDPSocket::protocol_receive(const KBuffer& packet_buffer, void* buffer, size
 
 int UDPSocket::protocol_send(const void* data, int data_length)
 {
-    auto* adapter = adapter_for_route_to(peer_address());
+    auto adapter = adapter_for_route_to(peer_address());
     if (!adapter)
         return -EHOSTUNREACH;
     auto buffer = ByteBuffer::create_zeroed(sizeof(UDPPacket) + data_length);