Quellcode durchsuchen

Kernel: Give each FileDescriptor a chance to co-open sockets.

Track how many fds are open for a socket's Accepted and Connected roles.
This allows fork() to clone a socket fd without a subsequent close() walking
all over the parent process's fd.
Andreas Kling vor 6 Jahren
Ursprung
Commit
d5f515cf6c
6 geänderte Dateien mit 51 neuen und 27 gelöschten Zeilen
  1. 15 2
      Kernel/FileDescriptor.cpp
  2. 1 1
      Kernel/FileDescriptor.h
  3. 29 17
      Kernel/LocalSocket.cpp
  4. 4 3
      Kernel/LocalSocket.h
  5. 0 1
      Kernel/Socket.cpp
  6. 2 3
      Kernel/Socket.h

+ 15 - 2
Kernel/FileDescriptor.cpp

@@ -49,14 +49,14 @@ FileDescriptor::FileDescriptor(RetainPtr<Device>&& device)
 
 FileDescriptor::FileDescriptor(RetainPtr<Socket>&& socket, SocketRole role)
     : m_socket(move(socket))
-    , m_socket_role(role)
 {
+    set_socket_role(role);
 }
 
 FileDescriptor::~FileDescriptor()
 {
     if (m_socket) {
-        m_socket->close(m_socket_role);
+        m_socket->detach_fd(m_socket_role);
         m_socket = nullptr;
     }
     if (m_device) {
@@ -70,6 +70,16 @@ FileDescriptor::~FileDescriptor()
     m_inode = nullptr;
 }
 
+void FileDescriptor::set_socket_role(SocketRole role)
+{
+    if (role == m_socket_role)
+        return;
+
+    ASSERT(m_socket);
+    m_socket_role = role;
+    m_socket->attach_fd(role);
+}
+
 RetainPtr<FileDescriptor> FileDescriptor::clone()
 {
     RetainPtr<FileDescriptor> descriptor;
@@ -81,6 +91,9 @@ RetainPtr<FileDescriptor> FileDescriptor::clone()
         if (m_device) {
             descriptor = FileDescriptor::create(m_device.copy_ref());
             descriptor->m_inode = m_inode.copy_ref();
+        } else if (m_socket) {
+            descriptor = FileDescriptor::create(m_socket.copy_ref(), m_socket_role);
+            descriptor->m_inode = m_inode.copy_ref();
         } else {
             descriptor = FileDescriptor::create(m_inode.copy_ref());
         }

+ 1 - 1
Kernel/FileDescriptor.h

@@ -86,7 +86,7 @@ public:
 
     void set_original_inode(Badge<VFS>, RetainPtr<Inode>&& inode) { m_inode = move(inode); }
 
-    void set_socket_role(SocketRole role) { m_socket_role = role; }
+    void set_socket_role(SocketRole);
 
 private:
     friend class VFS;

+ 29 - 17
Kernel/LocalSocket.cpp

@@ -106,23 +106,34 @@ bool LocalSocket::connect(const sockaddr* address, socklen_t address_size, int&
     return true;
 }
 
-void LocalSocket::close(SocketRole role)
+void LocalSocket::attach_fd(SocketRole role)
 {
-    if (role == SocketRole::Accepted)
-        m_server_closed = true;
-    else if (role == SocketRole::Connected)
-        m_client_closed = true;
+    if (role == SocketRole::Accepted) {
+        ++m_accepted_fds_open;
+    } else if (role == SocketRole::Connected) {
+        ++m_connected_fds_open;
+    }
+}
+
+void LocalSocket::detach_fd(SocketRole role)
+{
+    if (role == SocketRole::Accepted) {
+        ASSERT(m_accepted_fds_open);
+        --m_accepted_fds_open;
+    } else if (role == SocketRole::Connected) {
+        ASSERT(m_connected_fds_open);
+        --m_connected_fds_open;
+    }
 }
 
 bool LocalSocket::can_read(SocketRole role) const
 {
-    if (m_bound && is_listening())
+    if (role == SocketRole::Listener)
         return can_accept();
-
     if (role == SocketRole::Accepted)
-        return m_client_closed || !m_for_server.is_empty();
-    else if (role == SocketRole::Connected)
-        return m_server_closed || !m_for_client.is_empty();
+        return !m_connected_fds_open || !m_for_server.is_empty();
+    if (role == SocketRole::Connected)
+        return !m_accepted_fds_open || !m_for_client.is_empty();
     ASSERT_NOT_REACHED();
 }
 
@@ -130,7 +141,7 @@ ssize_t LocalSocket::read(SocketRole role, byte* buffer, size_t size)
 {
     if (role == SocketRole::Accepted)
         return m_for_server.read(buffer, size);
-    else if (role == SocketRole::Connected)
+    if (role == SocketRole::Connected)
         return m_for_client.read(buffer, size);
     ASSERT_NOT_REACHED();
 }
@@ -138,11 +149,12 @@ ssize_t LocalSocket::read(SocketRole role, byte* buffer, size_t size)
 ssize_t LocalSocket::write(SocketRole role, const byte* data, size_t size)
 {
     if (role == SocketRole::Accepted) {
-        if (m_client_closed)
+        if (!m_accepted_fds_open)
             return -EPIPE;
         return m_for_client.write(data, size);
-    } else if (role == SocketRole::Connected) {
-        if (m_client_closed)
+    }
+    if (role == SocketRole::Connected) {
+        if (!m_connected_fds_open)
             return -EPIPE;
         return m_for_server.write(data, size);
     }
@@ -152,8 +164,8 @@ ssize_t LocalSocket::write(SocketRole role, const byte* data, size_t size)
 bool LocalSocket::can_write(SocketRole role) const
 {
     if (role == SocketRole::Accepted)
-        return m_client_closed || m_for_client.bytes_in_write_buffer() < 4096;
-    else if (role == SocketRole::Connected)
-        return m_server_closed || m_for_server.bytes_in_write_buffer() < 4096;
+        return !m_connected_fds_open || m_for_client.bytes_in_write_buffer() < 4096;
+    if (role == SocketRole::Connected)
+        return !m_accepted_fds_open || m_for_server.bytes_in_write_buffer() < 4096;
     ASSERT_NOT_REACHED();
 }

+ 4 - 3
Kernel/LocalSocket.h

@@ -13,7 +13,8 @@ public:
     virtual bool bind(const sockaddr*, socklen_t, int& error) override;
     virtual bool connect(const sockaddr*, socklen_t, int& error) override;
     virtual bool get_address(sockaddr*, socklen_t*) override;
-    virtual void close(SocketRole) override;
+    virtual void attach_fd(SocketRole) override;
+    virtual void detach_fd(SocketRole) override;
     virtual bool can_read(SocketRole) const override;
     virtual ssize_t read(SocketRole, byte*, size_t) override;
     virtual ssize_t write(SocketRole, const byte*, size_t) override;
@@ -27,8 +28,8 @@ private:
     RetainPtr<LocalSocket> m_peer;
 
     bool m_bound { false };
-    bool m_server_closed { false };
-    bool m_client_closed { false };
+    int m_accepted_fds_open { 0 };
+    int m_connected_fds_open { 0 };
     sockaddr_un m_address;
 
     DoubleBuffer m_for_client;

+ 0 - 1
Kernel/Socket.cpp

@@ -36,7 +36,6 @@ bool Socket::listen(int backlog, int& error)
         return false;
     }
     m_backlog = backlog;
-    m_listening = true;
     kprintf("Socket{%p} listening with backlog=%d\n", this, m_backlog);
     return true;
 }

+ 2 - 3
Kernel/Socket.h

@@ -14,7 +14,6 @@ public:
     static RetainPtr<Socket> create(int domain, int type, int protocol, int& error);
     virtual ~Socket();
 
-    bool is_listening() const { return m_listening; }
     int domain() const { return m_domain; }
     int type() const { return m_type; }
     int protocol() const { return m_protocol; }
@@ -28,7 +27,8 @@ public:
     virtual bool connect(const sockaddr*, socklen_t, int& error) = 0;
     virtual bool get_address(sockaddr*, socklen_t*) = 0;
     virtual bool is_local() const { return false; }
-    virtual void close(SocketRole) = 0;
+    virtual void attach_fd(SocketRole) = 0;
+    virtual void detach_fd(SocketRole) = 0;
     virtual bool can_read(SocketRole) const = 0;
     virtual ssize_t read(SocketRole, byte*, size_t) = 0;
     virtual ssize_t write(SocketRole, const byte*, size_t) = 0;
@@ -48,7 +48,6 @@ private:
     int m_type { 0 };
     int m_protocol { 0 };
     int m_backlog { 0 };
-    bool m_listening { false };
     bool m_connected { false };
 
     Vector<RetainPtr<Socket>> m_pending;