Parcourir la source

Kernel: More sockets work. Fleshing out accept().

Andreas Kling il y a 6 ans
Parent
commit
54b1d6f57f
7 fichiers modifiés avec 103 ajouts et 9 suppressions
  1. 15 2
      Kernel/LocalSocket.cpp
  2. 8 0
      Kernel/LocalSocket.h
  3. 35 4
      Kernel/Process.cpp
  4. 1 1
      Kernel/Process.h
  5. 22 1
      Kernel/Socket.cpp
  6. 16 1
      Kernel/Socket.h
  7. 6 0
      Kernel/UnixTypes.h

+ 15 - 2
Kernel/LocalSocket.cpp

@@ -19,6 +19,16 @@ LocalSocket::~LocalSocket()
 {
 }
 
+bool LocalSocket::get_address(sockaddr* address, socklen_t* address_size)
+{
+    // FIXME: Look into what fallback behavior we should have here.
+    if (*address_size != sizeof(sockaddr_un))
+        return false;
+    memcpy(address, &m_address, sizeof(sockaddr_un));
+    *address_size = sizeof(sockaddr_un);
+    return true;
+}
+
 bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& error)
 {
     if (address_size != sizeof(sockaddr_un)) {
@@ -37,11 +47,14 @@ bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& err
 
     kprintf("%s(%u) LocalSocket{%p} bind(%s)\n", current->name().characters(), current->pid(), safe_address);
 
-    auto descriptor = VFS::the().open(safe_address, error, O_CREAT | O_EXCL, S_IFSOCK | 0666, *current->cwd_inode());
-    if (!descriptor) {
+    m_file = VFS::the().open(safe_address, error, O_CREAT | O_EXCL, S_IFSOCK | 0666, *current->cwd_inode());
+    if (!m_file) {
         if (error == -EEXIST)
             error = -EADDRINUSE;
         return error;
     }
+
+    m_address = local_address;
+    m_bound = true;
     return true;
 }

+ 8 - 0
Kernel/LocalSocket.h

@@ -3,16 +3,24 @@
 #include <Kernel/Socket.h>
 #include <Kernel/DoubleBuffer.h>
 
+class FileDescriptor;
+
 class LocalSocket final : public Socket {
 public:
     static RetainPtr<LocalSocket> create(int type);
     virtual ~LocalSocket() override;
 
     virtual bool bind(const sockaddr*, socklen_t, int& error) override;
+    virtual bool get_address(sockaddr*, socklen_t*) override;
 
 private:
     explicit LocalSocket(int type);
 
+    RetainPtr<FileDescriptor> m_file;
+
+    bool m_bound { false };
+    sockaddr_un m_address;
+
     DoubleBuffer m_for_client;
     DoubleBuffer m_for_server;
 };

+ 35 - 4
Kernel/Process.cpp

@@ -2258,7 +2258,12 @@ int Process::sys$socket(int domain, int type, int protocol)
     if (!socket)
         return error;
     auto descriptor = FileDescriptor::create(move(socket));
-    m_fds[fd].set(move(descriptor));
+    unsigned flags = 0;
+    if (type & SOCK_CLOEXEC)
+        flags |= O_CLOEXEC;
+    if (type & SOCK_NONBLOCK)
+        descriptor->set_blocking(false);
+    m_fds[fd].set(move(descriptor), flags);
     return fd;
 }
 
@@ -2280,12 +2285,38 @@ int Process::sys$bind(int sockfd, const sockaddr* address, socklen_t address_len
 
 int Process::sys$listen(int sockfd, int backlog)
 {
-    return -ENOTIMPL;
+    auto* descriptor = file_descriptor(sockfd);
+    if (!descriptor)
+        return -EBADF;
+    if (!descriptor->is_socket())
+        return -ENOTSOCK;
+    auto& socket = *descriptor->socket();
+    int error;
+    if (!socket.listen(backlog, error))
+        return error;
+    return 0;
 }
 
-int Process::sys$accept(int sockfd, sockaddr*, socklen_t)
+int Process::sys$accept(int sockfd, sockaddr* address, socklen_t* address_size)
 {
-    return -ENOTIMPL;
+    if (!validate_write_typed(address_size))
+        return -EFAULT;
+    if (!validate_write(address, *address_size))
+        return -EFAULT;
+    auto* descriptor = file_descriptor(sockfd);
+    if (!descriptor)
+        return -EBADF;
+    if (!descriptor->is_socket())
+        return -ENOTSOCK;
+    auto& socket = *descriptor->socket();
+    if (!socket.can_accept()) {
+        ASSERT(!descriptor->is_blocking());
+        return -EAGAIN;
+    }
+    auto client = socket.accept();
+    ASSERT(client);
+    client->get_address(address, address_size);
+    return 0;
 }
 
 int Process::sys$connect(int sockfd, const sockaddr*, socklen_t)

+ 1 - 1
Kernel/Process.h

@@ -221,7 +221,7 @@ public:
     int sys$socket(int domain, int type, int protocol);
     int sys$bind(int sockfd, const sockaddr* addr, socklen_t);
     int sys$listen(int sockfd, int backlog);
-    int sys$accept(int sockfd, sockaddr*, socklen_t);
+    int sys$accept(int sockfd, sockaddr*, socklen_t*);
     int sys$connect(int sockfd, const sockaddr*, socklen_t);
 
     DisplayInfo set_video_resolution(int width, int height);

+ 22 - 1
Kernel/Socket.cpp

@@ -8,7 +8,7 @@ RetainPtr<Socket> Socket::create(int domain, int type, int protocol, int& error)
     (void)protocol;
     switch (domain) {
     case AF_LOCAL:
-        return LocalSocket::create(type);
+        return LocalSocket::create(type & SOCK_TYPE_MASK);
     default:
         error = EAFNOSUPPORT;
         return nullptr;
@@ -26,4 +26,25 @@ Socket::~Socket()
 {
 }
 
+bool Socket::listen(int backlog, int& error)
+{
+    LOCKER(m_lock);
+    if (m_type != SOCK_STREAM) {
+        error = -EOPNOTSUPP;
+        return false;
+    }
+    m_backlog = backlog;
+    m_listening = true;
+    kprintf("Socket{%p} listening with backlog=%d\n", m_backlog);
+    return true;
+}
 
+RetainPtr<Socket> Socket::accept()
+{
+    LOCKER(m_lock);
+    if (m_pending.is_empty())
+        return nullptr;
+    auto client = m_pending.take_first();
+    m_clients.append(client.copy_ref());
+    return client;
+}

+ 16 - 1
Kernel/Socket.h

@@ -1,7 +1,10 @@
 #pragma once
 
+#include <AK/Lock.h>
 #include <AK/Retainable.h>
 #include <AK/RetainPtr.h>
+#include <AK/HashTable.h>
+#include <AK/Vector.h>
 #include <Kernel/UnixTypes.h>
 
 class Socket : public Retainable<Socket> {
@@ -9,18 +12,30 @@ public:
     static RetainPtr<Socket> create(int domain, int type, int protocol, int& error);
     virtual ~Socket();
 
+    bool is_listening() const { return m_listening; }
     int domain() const { return m_domain; }
     int type() const { return m_type; }
     int protocol() const { return m_protocol; }
 
+    bool can_accept() const { return m_pending.is_empty(); }
+    RetainPtr<Socket> accept();
+
+    bool listen(int backlog, int& error);
+
     virtual bool bind(const sockaddr*, socklen_t, int& error) = 0;
+    virtual bool get_address(sockaddr*, socklen_t*) = 0;
 
 protected:
     Socket(int domain, int type, int protocol);
 
 private:
+    Lock m_lock;
     int m_domain { 0 };
     int m_type { 0 };
     int m_protocol { 0 };
-};
+    int m_backlog { 0 };
+    bool m_listening { false };
 
+    Vector<RetainPtr<Socket>> m_pending;
+    Vector<RetainPtr<Socket>> m_clients;
+};

+ 6 - 0
Kernel/UnixTypes.h

@@ -306,9 +306,15 @@ struct pollfd {
     short revents;
 };
 
+#define AF_MASK 0xff
 #define AF_UNSPEC 0
 #define AF_LOCAL 1
 
+#define SOCK_TYPE_MASK 0xff
+#define SOCK_STREAM 1
+#define SOCK_NONBLOCK 04000
+#define SOCK_CLOEXEC 02000000
+
 struct sockaddr {
     word sa_family;
     char sa_data[14];