Browse Source

Kernel: More work on sockets. Fleshing out connect().

Andreas Kling 6 years ago
parent
commit
b20a7aca61
7 changed files with 90 additions and 4 deletions
  1. 15 0
      Kernel/FileSystem.cpp
  2. 7 0
      Kernel/FileSystem.h
  3. 38 2
      Kernel/LocalSocket.cpp
  4. 4 0
      Kernel/LocalSocket.h
  5. 23 2
      Kernel/Process.cpp
  6. 2 0
      Kernel/Socket.h
  7. 1 0
      LibC/errno_numbers.h

+ 15 - 0
Kernel/FileSystem.cpp

@@ -4,6 +4,7 @@
 #include <LibC/errno_numbers.h>
 #include "FileSystem.h"
 #include "MemoryManager.h"
+#include <Kernel/LocalSocket.h>
 
 static dword s_lastFileSystemID;
 static HashMap<dword, FS*>* s_fs_map;
@@ -152,3 +153,17 @@ void Inode::set_vmo(VMObject& vmo)
 {
     m_vmo = vmo.make_weak_ptr();
 }
+
+bool Inode::bind_socket(LocalSocket& socket)
+{
+    ASSERT(!m_socket);
+    m_socket = socket;
+    return true;
+}
+
+bool Inode::unbind_socket()
+{
+    ASSERT(m_socket);
+    m_socket = nullptr;
+    return true;
+}

+ 7 - 0
Kernel/FileSystem.h

@@ -20,6 +20,7 @@ static const dword mepoch = 476763780;
 
 class Inode;
 class FileDescriptor;
+class LocalSocket;
 class VMObject;
 
 class FS : public Retainable<FS> {
@@ -92,6 +93,11 @@ public:
     virtual size_t directory_entry_count() const = 0;
     virtual bool chmod(mode_t, int& error) = 0;
 
+    LocalSocket* socket() { return m_socket.ptr(); }
+    const LocalSocket* socket() const { return m_socket.ptr(); }
+    bool bind_socket(LocalSocket&);
+    bool unbind_socket();
+
     bool is_metadata_dirty() const { return m_metadata_dirty; }
 
     virtual int set_atime(time_t);
@@ -120,6 +126,7 @@ private:
     FS& m_fs;
     unsigned m_index { 0 };
     WeakPtr<VMObject> m_vmo;
+    RetainPtr<LocalSocket> m_socket;
     bool m_metadata_dirty { false };
 };
 

+ 38 - 2
Kernel/LocalSocket.cpp

@@ -31,11 +31,11 @@ bool LocalSocket::get_address(sockaddr* address, socklen_t* address_size)
 
 bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& error)
 {
+    ASSERT(!m_connected);
     if (address_size != sizeof(sockaddr_un)) {
         error = -EINVAL;
         return false;
     }
-
     if (address->sa_family != AF_LOCAL) {
         error = -EINVAL;
         return false;
@@ -51,10 +51,46 @@ bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& err
     if (!m_file) {
         if (error == -EEXIST)
             error = -EADDRINUSE;
-        return error;
+        return false;
     }
 
+    ASSERT(m_file->inode());
+    m_file->inode()->bind_socket(*this);
+
     m_address = local_address;
     m_bound = true;
     return true;
 }
+
+RetainPtr<Socket> LocalSocket::connect(const sockaddr* address, socklen_t address_size, int& error)
+{
+    ASSERT(!m_bound);
+    if (address_size != sizeof(sockaddr_un)) {
+        error = -EINVAL;
+        return nullptr;
+    }
+    if (address->sa_family != AF_LOCAL) {
+        error = -EINVAL;
+        return nullptr;
+    }
+
+    const sockaddr_un& local_address = *reinterpret_cast<const sockaddr_un*>(address);
+    char safe_address[sizeof(local_address.sun_path) + 1];
+    memcpy(safe_address, local_address.sun_path, sizeof(local_address.sun_path));
+
+    kprintf("%s(%u) LocalSocket{%p} connect(%s)\n", current->name().characters(), current->pid(), safe_address);
+
+    m_file = VFS::the().open(safe_address, error, 0, 0, *current->cwd_inode());
+    if (!m_file) {
+        error = -ECONNREFUSED;
+        return nullptr;
+    }
+
+    ASSERT(m_file->inode());
+    ASSERT(m_file->inode()->socket());
+
+    m_peer = m_file->inode()->socket();
+    m_address = local_address;
+    m_connected = true;
+    return m_peer;
+}

+ 4 - 0
Kernel/LocalSocket.h

@@ -11,14 +11,18 @@ public:
     virtual ~LocalSocket() override;
 
     virtual bool bind(const sockaddr*, socklen_t, int& error) override;
+    virtual RetainPtr<Socket> connect(const sockaddr*, socklen_t, int& error) override;
     virtual bool get_address(sockaddr*, socklen_t*) override;
 
 private:
     explicit LocalSocket(int type);
+    virtual bool is_local() const override { return true; }
 
     RetainPtr<FileDescriptor> m_file;
+    RetainPtr<LocalSocket> m_peer;
 
     bool m_bound { false };
+    bool m_connected { false };
     sockaddr_un m_address;
 
     DoubleBuffer m_for_client;

+ 23 - 2
Kernel/Process.cpp

@@ -2329,7 +2329,28 @@ int Process::sys$accept(int sockfd, sockaddr* address, socklen_t* address_size)
     return fd;
 }
 
-int Process::sys$connect(int sockfd, const sockaddr*, socklen_t)
+int Process::sys$connect(int sockfd, const sockaddr* address, socklen_t address_size)
 {
-    return -ENOTIMPL;
+    if (!validate_read(address, address_size))
+        return -EFAULT;
+    if (number_of_open_file_descriptors() >= m_max_open_file_descriptors)
+        return -EMFILE;
+    int fd = 0;
+    for (; fd < (int)m_max_open_file_descriptors; ++fd) {
+        if (!m_fds[fd])
+            break;
+    }
+    auto* descriptor = file_descriptor(sockfd);
+    if (!descriptor)
+        return -EBADF;
+    if (!descriptor->is_socket())
+        return -ENOTSOCK;
+    auto& socket = *descriptor->socket();
+    int error;
+    auto server = socket.connect(address, address_size, error);
+    if (!server)
+        return error;
+    auto server_descriptor = FileDescriptor::create(move(server), SocketRole::Connected);
+    m_fds[fd].set(move(server_descriptor));
+    return fd;
 }

+ 2 - 0
Kernel/Socket.h

@@ -23,7 +23,9 @@ public:
     bool listen(int backlog, int& error);
 
     virtual bool bind(const sockaddr*, socklen_t, int& error) = 0;
+    virtual RetainPtr<Socket> connect(const sockaddr*, socklen_t, int& error) = 0;
     virtual bool get_address(sockaddr*, socklen_t*) = 0;
+    virtual bool is_local() const { return false; }
 
 protected:
     Socket(int domain, int type, int protocol);

+ 1 - 0
LibC/errno_numbers.h

@@ -46,6 +46,7 @@
     __ERROR(EADDRINUSE,     "Address in use") \
     __ERROR(EWHYTHO,        "Failed without setting an error code (Bug!)") \
     __ERROR(ENOTEMPTY,      "Directory not empty") \
+    __ERROR(ECONNREFUSED,   "Connection refused") \
 
 
 enum __errno_values {