Преглед изворни кода

Kernel: Prepare Socket for becoming a File.

Make the Socket functions take a FileDescriptor& rather than a socket role
throughout the code. Also change threads to block on a FileDescriptor,
rather than either an fd index or a Socket.
Andreas Kling пре 6 година
родитељ
комит
03da7046bd

+ 8 - 9
Kernel/FileSystem/FileDescriptor.cpp

@@ -47,7 +47,7 @@ FileDescriptor::FileDescriptor(RetainPtr<Socket>&& socket, SocketRole role)
 FileDescriptor::~FileDescriptor()
 FileDescriptor::~FileDescriptor()
 {
 {
     if (m_socket) {
     if (m_socket) {
-        m_socket->detach_fd(m_socket_role);
+        m_socket->detach(*this);
         m_socket = nullptr;
         m_socket = nullptr;
     }
     }
     if (is_fifo())
     if (is_fifo())
@@ -65,11 +65,10 @@ void FileDescriptor::set_socket_role(SocketRole role)
         return;
         return;
 
 
     ASSERT(m_socket);
     ASSERT(m_socket);
-    auto old_socket_role = m_socket_role;
+    if (m_socket_role != SocketRole::None)
+        m_socket->detach(*this);
     m_socket_role = role;
     m_socket_role = role;
-    m_socket->attach_fd(role);
-    if (old_socket_role != SocketRole::None)
-        m_socket->detach_fd(old_socket_role);
+    m_socket->attach(*this);
 }
 }
 
 
 Retained<FileDescriptor> FileDescriptor::clone()
 Retained<FileDescriptor> FileDescriptor::clone()
@@ -182,7 +181,7 @@ ssize_t FileDescriptor::read(byte* buffer, ssize_t count)
         return nread;
         return nread;
     }
     }
     if (m_socket)
     if (m_socket)
-        return m_socket->read(m_socket_role, buffer, count);
+        return m_socket->read(*this, buffer, count);
     ASSERT(inode());
     ASSERT(inode());
     ssize_t nread = inode()->read_bytes(m_current_offset, count, buffer, this);
     ssize_t nread = inode()->read_bytes(m_current_offset, count, buffer, this);
     m_current_offset += nread;
     m_current_offset += nread;
@@ -198,7 +197,7 @@ ssize_t FileDescriptor::write(const byte* data, ssize_t size)
         return nwritten;
         return nwritten;
     }
     }
     if (m_socket)
     if (m_socket)
-        return m_socket->write(m_socket_role, data, size);
+        return m_socket->write(*this, data, size);
     ASSERT(m_inode);
     ASSERT(m_inode);
     ssize_t nwritten = m_inode->write_bytes(m_current_offset, size, data, this);
     ssize_t nwritten = m_inode->write_bytes(m_current_offset, size, data, this);
     m_current_offset += nwritten;
     m_current_offset += nwritten;
@@ -210,7 +209,7 @@ bool FileDescriptor::can_write()
     if (m_file)
     if (m_file)
         return m_file->can_write(*this);
         return m_file->can_write(*this);
     if (m_socket)
     if (m_socket)
-        return m_socket->can_write(m_socket_role);
+        return m_socket->can_write(*this);
     return true;
     return true;
 }
 }
 
 
@@ -219,7 +218,7 @@ bool FileDescriptor::can_read()
     if (m_file)
     if (m_file)
         return m_file->can_read(*this);
         return m_file->can_read(*this);
     if (m_socket)
     if (m_socket)
-        return m_socket->can_read(m_socket_role);
+        return m_socket->can_read(*this);
     return true;
     return true;
 }
 }
 
 

+ 13 - 14
Kernel/Net/IPv4Socket.cpp

@@ -66,7 +66,7 @@ KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size)
     ASSERT_NOT_REACHED();
     ASSERT_NOT_REACHED();
 }
 }
 
 
-KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size, ShouldBlock should_block)
+KResult IPv4Socket::connect(FileDescriptor& descriptor, const sockaddr* address, socklen_t address_size, ShouldBlock should_block)
 {
 {
     ASSERT(!m_bound);
     ASSERT(!m_bound);
     if (address_size != sizeof(sockaddr_in))
     if (address_size != sizeof(sockaddr_in))
@@ -78,37 +78,37 @@ KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size, Sho
     m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr);
     m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr);
     m_destination_port = ntohs(ia.sin_port);
     m_destination_port = ntohs(ia.sin_port);
 
 
-    return protocol_connect(should_block);
+    return protocol_connect(descriptor, should_block);
 }
 }
 
 
-void IPv4Socket::attach_fd(SocketRole)
+void IPv4Socket::attach(FileDescriptor&)
 {
 {
     ++m_attached_fds;
     ++m_attached_fds;
 }
 }
 
 
-void IPv4Socket::detach_fd(SocketRole)
+void IPv4Socket::detach(FileDescriptor&)
 {
 {
     --m_attached_fds;
     --m_attached_fds;
 }
 }
 
 
-bool IPv4Socket::can_read(SocketRole) const
+bool IPv4Socket::can_read(FileDescriptor&) const
 {
 {
     if (protocol_is_disconnected())
     if (protocol_is_disconnected())
         return true;
         return true;
     return m_can_read;
     return m_can_read;
 }
 }
 
 
-ssize_t IPv4Socket::read(SocketRole, byte* buffer, ssize_t size)
+ssize_t IPv4Socket::read(FileDescriptor& descriptor, byte* buffer, ssize_t size)
 {
 {
-    return recvfrom(buffer, size, 0, nullptr, 0);
+    return recvfrom(descriptor, buffer, size, 0, nullptr, 0);
 }
 }
 
 
-ssize_t IPv4Socket::write(SocketRole, const byte* data, ssize_t size)
+ssize_t IPv4Socket::write(FileDescriptor& descriptor, const byte* data, ssize_t size)
 {
 {
-    return sendto(data, size, 0, nullptr, 0);
+    return sendto(descriptor, data, size, 0, nullptr, 0);
 }
 }
 
 
-bool IPv4Socket::can_write(SocketRole) const
+bool IPv4Socket::can_write(FileDescriptor&) const
 {
 {
     return is_connected();
     return is_connected();
 }
 }
@@ -124,7 +124,7 @@ int IPv4Socket::allocate_source_port_if_needed()
     return port;
     return port;
 }
 }
 
 
-ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, const sockaddr* addr, socklen_t addr_length)
+ssize_t IPv4Socket::sendto(FileDescriptor&, const void* data, size_t data_length, int flags, const sockaddr* addr, socklen_t addr_length)
 {
 {
     (void)flags;
     (void)flags;
     if (addr && addr_length != sizeof(sockaddr_in))
     if (addr && addr_length != sizeof(sockaddr_in))
@@ -159,7 +159,7 @@ ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, cons
     return protocol_send(data, data_length);
     return protocol_send(data, data_length);
 }
 }
 
 
-ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, sockaddr* addr, socklen_t* addr_length)
+ssize_t IPv4Socket::recvfrom(FileDescriptor& descriptor, void* buffer, size_t buffer_length, int flags, sockaddr* addr, socklen_t* addr_length)
 {
 {
     (void)flags;
     (void)flags;
     if (addr_length && *addr_length < sizeof(sockaddr_in))
     if (addr_length && *addr_length < sizeof(sockaddr_in))
@@ -186,9 +186,8 @@ ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, sock
             return 0;
             return 0;
         }
         }
 
 
-        current->set_blocked_socket(this);
         load_receive_deadline();
         load_receive_deadline();
-        current->block(Thread::BlockedReceive);
+        current->block(Thread::BlockedReceive, descriptor);
 
 
         LOCKER(lock());
         LOCKER(lock());
         if (!m_can_read) {
         if (!m_can_read) {

+ 10 - 10
Kernel/Net/IPv4Socket.h

@@ -21,16 +21,16 @@ public:
     static Lockable<HashTable<IPv4Socket*>>& all_sockets();
     static Lockable<HashTable<IPv4Socket*>>& all_sockets();
 
 
     virtual KResult bind(const sockaddr*, socklen_t) override;
     virtual KResult bind(const sockaddr*, socklen_t) override;
-    virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
+    virtual KResult connect(FileDescriptor&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
     virtual bool get_address(sockaddr*, socklen_t*) override;
     virtual bool get_address(sockaddr*, socklen_t*) override;
-    virtual void attach_fd(SocketRole) override;
-    virtual void detach_fd(SocketRole) override;
-    virtual bool can_read(SocketRole) const override;
-    virtual ssize_t read(SocketRole, byte*, ssize_t) override;
-    virtual ssize_t write(SocketRole, const byte*, ssize_t) override;
-    virtual bool can_write(SocketRole) const override;
-    virtual ssize_t sendto(const void*, size_t, int, const sockaddr*, socklen_t) override;
-    virtual ssize_t recvfrom(void*, size_t, int flags, sockaddr*, socklen_t*) override;
+    virtual void attach(FileDescriptor&) override;
+    virtual void detach(FileDescriptor&) override;
+    virtual bool can_read(FileDescriptor&) const override;
+    virtual ssize_t read(FileDescriptor&, byte*, ssize_t) override;
+    virtual ssize_t write(FileDescriptor&, const byte*, ssize_t) override;
+    virtual bool can_write(FileDescriptor&) const override;
+    virtual ssize_t sendto(FileDescriptor&, const void*, size_t, int, const sockaddr*, socklen_t) override;
+    virtual ssize_t recvfrom(FileDescriptor&, void*, size_t, int flags, sockaddr*, socklen_t*) override;
 
 
     void did_receive(ByteBuffer&&);
     void did_receive(ByteBuffer&&);
 
 
@@ -49,7 +49,7 @@ protected:
 
 
     virtual int protocol_receive(const ByteBuffer&, void*, size_t, int, sockaddr*, socklen_t*) { return -ENOTIMPL; }
     virtual int protocol_receive(const ByteBuffer&, void*, size_t, int, sockaddr*, socklen_t*) { return -ENOTIMPL; }
     virtual int protocol_send(const void*, int) { return -ENOTIMPL; }
     virtual int protocol_send(const void*, int) { return -ENOTIMPL; }
-    virtual KResult protocol_connect(ShouldBlock) { return KSuccess; }
+    virtual KResult protocol_connect(FileDescriptor&, ShouldBlock) { return KSuccess; }
     virtual int protocol_allocate_source_port() { return 0; }
     virtual int protocol_allocate_source_port() { return 0; }
     virtual bool protocol_is_disconnected() const { return false; }
     virtual bool protocol_is_disconnected() const { return false; }
 
 

+ 30 - 20
Kernel/Net/LocalSocket.cpp

@@ -66,7 +66,7 @@ KResult LocalSocket::bind(const sockaddr* address, socklen_t address_size)
     return KSuccess;
     return KSuccess;
 }
 }
 
 
-KResult LocalSocket::connect(const sockaddr* address, socklen_t address_size, ShouldBlock)
+KResult LocalSocket::connect(FileDescriptor& descriptor, const sockaddr* address, socklen_t address_size, ShouldBlock)
 {
 {
     ASSERT(!m_bound);
     ASSERT(!m_bound);
     if (address_size != sizeof(sockaddr_un))
     if (address_size != sizeof(sockaddr_un))
@@ -98,36 +98,45 @@ KResult LocalSocket::connect(const sockaddr* address, socklen_t address_size, Sh
     if (result.is_error())
     if (result.is_error())
         return result;
         return result;
 
 
-    return current->wait_for_connect(*this);
+    return current->wait_for_connect(descriptor);
 }
 }
 
 
-void LocalSocket::attach_fd(SocketRole role)
+void LocalSocket::attach(FileDescriptor& descriptor)
 {
 {
-    if (role == SocketRole::Accepted) {
+    switch (descriptor.socket_role()) {
+    case SocketRole::Accepted:
         ++m_accepted_fds_open;
         ++m_accepted_fds_open;
-    } else if (role == SocketRole::Connected) {
+        break;
+    case SocketRole::Connected:
         ++m_connected_fds_open;
         ++m_connected_fds_open;
-    } else if (role == SocketRole::Connecting) {
+        break;
+    case SocketRole::Connecting:
         ++m_connecting_fds_open;
         ++m_connecting_fds_open;
+        break;
     }
     }
 }
 }
 
 
-void LocalSocket::detach_fd(SocketRole role)
+void LocalSocket::detach(FileDescriptor& descriptor)
 {
 {
-    if (role == SocketRole::Accepted) {
+    switch (descriptor.socket_role()) {
+    case SocketRole::Accepted:
         ASSERT(m_accepted_fds_open);
         ASSERT(m_accepted_fds_open);
         --m_accepted_fds_open;
         --m_accepted_fds_open;
-    } else if (role == SocketRole::Connected) {
+        break;
+    case SocketRole::Connected:
         ASSERT(m_connected_fds_open);
         ASSERT(m_connected_fds_open);
         --m_connected_fds_open;
         --m_connected_fds_open;
-    } else if (role == SocketRole::Connecting) {
+        break;
+    case SocketRole::Connecting:
         ASSERT(m_connecting_fds_open);
         ASSERT(m_connecting_fds_open);
         --m_connecting_fds_open;
         --m_connecting_fds_open;
+        break;
     }
     }
 }
 }
 
 
-bool LocalSocket::can_read(SocketRole role) const
+bool LocalSocket::can_read(FileDescriptor& descriptor) const
 {
 {
+    auto role = descriptor.socket_role();
     if (role == SocketRole::Listener)
     if (role == SocketRole::Listener)
         return can_accept();
         return can_accept();
     if (role == SocketRole::Accepted)
     if (role == SocketRole::Accepted)
@@ -137,8 +146,9 @@ bool LocalSocket::can_read(SocketRole role) const
     ASSERT_NOT_REACHED();
     ASSERT_NOT_REACHED();
 }
 }
 
 
-ssize_t LocalSocket::read(SocketRole role, byte* buffer, ssize_t size)
+ssize_t LocalSocket::read(FileDescriptor& descriptor, byte* buffer, ssize_t size)
 {
 {
+    auto role = descriptor.socket_role();
     if (role == SocketRole::Accepted)
     if (role == SocketRole::Accepted)
         return m_for_server.read(buffer, size);
         return m_for_server.read(buffer, size);
     if (role == SocketRole::Connected)
     if (role == SocketRole::Connected)
@@ -146,14 +156,14 @@ ssize_t LocalSocket::read(SocketRole role, byte* buffer, ssize_t size)
     ASSERT_NOT_REACHED();
     ASSERT_NOT_REACHED();
 }
 }
 
 
-ssize_t LocalSocket::write(SocketRole role, const byte* data, ssize_t size)
+ssize_t LocalSocket::write(FileDescriptor& descriptor, const byte* data, ssize_t size)
 {
 {
-    if (role == SocketRole::Accepted) {
+    if (descriptor.socket_role() == SocketRole::Accepted) {
         if (!m_accepted_fds_open)
         if (!m_accepted_fds_open)
             return -EPIPE;
             return -EPIPE;
         return m_for_client.write(data, size);
         return m_for_client.write(data, size);
     }
     }
-    if (role == SocketRole::Connected) {
+    if (descriptor.socket_role() == SocketRole::Connected) {
         if (!m_connected_fds_open && !m_connecting_fds_open)
         if (!m_connected_fds_open && !m_connecting_fds_open)
             return -EPIPE;
             return -EPIPE;
         return m_for_server.write(data, size);
         return m_for_server.write(data, size);
@@ -161,21 +171,21 @@ ssize_t LocalSocket::write(SocketRole role, const byte* data, ssize_t size)
     ASSERT_NOT_REACHED();
     ASSERT_NOT_REACHED();
 }
 }
 
 
-bool LocalSocket::can_write(SocketRole role) const
+bool LocalSocket::can_write(FileDescriptor& descriptor) const
 {
 {
-    if (role == SocketRole::Accepted)
+    if (descriptor.socket_role() == SocketRole::Accepted)
         return (!m_connected_fds_open && !m_connecting_fds_open) || m_for_client.bytes_in_write_buffer() < 4096;
         return (!m_connected_fds_open && !m_connecting_fds_open) || m_for_client.bytes_in_write_buffer() < 4096;
-    if (role == SocketRole::Connected)
+    if (descriptor.socket_role() == SocketRole::Connected)
         return !m_accepted_fds_open || m_for_server.bytes_in_write_buffer() < 4096;
         return !m_accepted_fds_open || m_for_server.bytes_in_write_buffer() < 4096;
     ASSERT_NOT_REACHED();
     ASSERT_NOT_REACHED();
 }
 }
 
 
-ssize_t LocalSocket::sendto(const void*, size_t, int, const sockaddr*, socklen_t)
+ssize_t LocalSocket::sendto(FileDescriptor&, const void*, size_t, int, const sockaddr*, socklen_t)
 {
 {
     ASSERT_NOT_REACHED();
     ASSERT_NOT_REACHED();
 }
 }
 
 
-ssize_t LocalSocket::recvfrom(void*, size_t, int flags, sockaddr*, socklen_t*)
+ssize_t LocalSocket::recvfrom(FileDescriptor&, void*, size_t, int flags, sockaddr*, socklen_t*)
 {
 {
     ASSERT_NOT_REACHED();
     ASSERT_NOT_REACHED();
 }
 }

+ 9 - 9
Kernel/Net/LocalSocket.h

@@ -11,16 +11,16 @@ public:
     virtual ~LocalSocket() override;
     virtual ~LocalSocket() override;
 
 
     virtual KResult bind(const sockaddr*, socklen_t) override;
     virtual KResult bind(const sockaddr*, socklen_t) override;
-    virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
+    virtual KResult connect(FileDescriptor&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
     virtual bool get_address(sockaddr*, socklen_t*) override;
     virtual bool get_address(sockaddr*, socklen_t*) override;
-    virtual void attach_fd(SocketRole) override;
-    virtual void detach_fd(SocketRole) override;
-    virtual bool can_read(SocketRole) const override;
-    virtual ssize_t read(SocketRole, byte*, ssize_t) override;
-    virtual ssize_t write(SocketRole, const byte*, ssize_t) override;
-    virtual bool can_write(SocketRole) const override;
-    virtual ssize_t sendto(const void*, size_t, int, const sockaddr*, socklen_t) override;
-    virtual ssize_t recvfrom(void*, size_t, int flags, sockaddr*, socklen_t*) override;
+    virtual void attach(FileDescriptor&) override;
+    virtual void detach(FileDescriptor&) override;
+    virtual bool can_read(FileDescriptor&) const override;
+    virtual ssize_t read(FileDescriptor&, byte*, ssize_t) override;
+    virtual ssize_t write(FileDescriptor&, const byte*, ssize_t) override;
+    virtual bool can_write(FileDescriptor&) const override;
+    virtual ssize_t sendto(FileDescriptor&, const void*, size_t, int, const sockaddr*, socklen_t) override;
+    virtual ssize_t recvfrom(FileDescriptor&, void*, size_t, int flags, sockaddr*, socklen_t*) override;
 
 
 private:
 private:
     explicit LocalSocket(int type);
     explicit LocalSocket(int type);

+ 11 - 9
Kernel/Net/Socket.h

@@ -11,6 +11,8 @@
 enum class SocketRole { None, Listener, Accepted, Connected, Connecting };
 enum class SocketRole { None, Listener, Accepted, Connected, Connecting };
 enum class ShouldBlock { No = 0, Yes = 1 };
 enum class ShouldBlock { No = 0, Yes = 1 };
 
 
+class FileDescriptor;
+
 class Socket : public Retainable<Socket> {
 class Socket : public Retainable<Socket> {
 public:
 public:
     static KResultOr<Retained<Socket>> create(int domain, int type, int protocol);
     static KResultOr<Retained<Socket>> create(int domain, int type, int protocol);
@@ -26,18 +28,18 @@ public:
     KResult listen(int backlog);
     KResult listen(int backlog);
 
 
     virtual KResult bind(const sockaddr*, socklen_t) = 0;
     virtual KResult bind(const sockaddr*, socklen_t) = 0;
-    virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock) = 0;
+    virtual KResult connect(FileDescriptor&, const sockaddr*, socklen_t, ShouldBlock) = 0;
     virtual bool get_address(sockaddr*, socklen_t*) = 0;
     virtual bool get_address(sockaddr*, socklen_t*) = 0;
     virtual bool is_local() const { return false; }
     virtual bool is_local() const { return false; }
     virtual bool is_ipv4() const { return false; }
     virtual bool is_ipv4() const { return false; }
-    virtual void attach_fd(SocketRole) = 0;
-    virtual void detach_fd(SocketRole) = 0;
-    virtual bool can_read(SocketRole) const = 0;
-    virtual ssize_t read(SocketRole, byte*, ssize_t) = 0;
-    virtual ssize_t write(SocketRole, const byte*, ssize_t) = 0;
-    virtual bool can_write(SocketRole) const = 0;
-    virtual ssize_t sendto(const void*, size_t, int flags, const sockaddr*, socklen_t) = 0;
-    virtual ssize_t recvfrom(void*, size_t, int flags, sockaddr*, socklen_t*) = 0;
+    virtual void attach(FileDescriptor&) = 0;
+    virtual void detach(FileDescriptor&) = 0;
+    virtual bool can_read(FileDescriptor&) const = 0;
+    virtual ssize_t read(FileDescriptor&, byte*, ssize_t) = 0;
+    virtual ssize_t write(FileDescriptor&, const byte*, ssize_t) = 0;
+    virtual bool can_write(FileDescriptor&) const = 0;
+    virtual ssize_t sendto(FileDescriptor&, const void*, size_t, int flags, const sockaddr*, socklen_t) = 0;
+    virtual ssize_t recvfrom(FileDescriptor&, void*, size_t, int flags, sockaddr*, socklen_t*) = 0;
 
 
     KResult setsockopt(int level, int option, const void*, socklen_t);
     KResult setsockopt(int level, int option, const void*, socklen_t);
     KResult getsockopt(int level, int option, void*, socklen_t*);
     KResult getsockopt(int level, int option, void*, socklen_t*);

+ 2 - 3
Kernel/Net/TCPSocket.cpp

@@ -152,7 +152,7 @@ NetworkOrdered<word> TCPSocket::compute_tcp_checksum(const IPv4Address& source,
     return ~(checksum & 0xffff);
     return ~(checksum & 0xffff);
 }
 }
 
 
-KResult TCPSocket::protocol_connect(ShouldBlock should_block)
+KResult TCPSocket::protocol_connect(FileDescriptor& descriptor, ShouldBlock should_block)
 {
 {
     auto* adapter = adapter_for_route_to(destination_address());
     auto* adapter = adapter_for_route_to(destination_address());
     if (!adapter)
     if (!adapter)
@@ -167,8 +167,7 @@ KResult TCPSocket::protocol_connect(ShouldBlock should_block)
     m_state = State::Connecting;
     m_state = State::Connecting;
 
 
     if (should_block == ShouldBlock::Yes) {
     if (should_block == ShouldBlock::Yes) {
-        current->set_blocked_socket(this);
-        current->block(Thread::BlockedConnect);
+        current->block(Thread::BlockedConnect, descriptor);
         ASSERT(is_connected());
         ASSERT(is_connected());
         return KSuccess;
         return KSuccess;
     }
     }

+ 1 - 1
Kernel/Net/TCPSocket.h

@@ -34,7 +34,7 @@ private:
 
 
     virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override;
     virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override;
     virtual int protocol_send(const void*, int) override;
     virtual int protocol_send(const void*, int) override;
-    virtual KResult protocol_connect(ShouldBlock) override;
+    virtual KResult protocol_connect(FileDescriptor&, ShouldBlock) override;
     virtual int protocol_allocate_source_port() override;
     virtual int protocol_allocate_source_port() override;
     virtual bool protocol_is_disconnected() const override;
     virtual bool protocol_is_disconnected() const override;
 
 

+ 0 - 5
Kernel/Net/UDPSocket.cpp

@@ -81,11 +81,6 @@ int UDPSocket::protocol_send(const void* data, int data_length)
     return data_length;
     return data_length;
 }
 }
 
 
-KResult UDPSocket::protocol_connect(ShouldBlock)
-{
-    return KSuccess;
-}
-
 int UDPSocket::protocol_allocate_source_port()
 int UDPSocket::protocol_allocate_source_port()
 {
 {
     static const word first_ephemeral_port = 32768;
     static const word first_ephemeral_port = 32768;

+ 1 - 1
Kernel/Net/UDPSocket.h

@@ -17,7 +17,7 @@ private:
 
 
     virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override;
     virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override;
     virtual int protocol_send(const void*, int) override;
     virtual int protocol_send(const void*, int) override;
-    virtual KResult protocol_connect(ShouldBlock) override;
+    virtual KResult protocol_connect(FileDescriptor&, ShouldBlock) override { return KSuccess; }
     virtual int protocol_allocate_source_port() override;
     virtual int protocol_allocate_source_port() override;
 };
 };
 
 

+ 5 - 7
Kernel/Process.cpp

@@ -840,8 +840,7 @@ ssize_t Process::sys$write(int fd, const byte* data, ssize_t size)
 #ifdef IO_DEBUG
 #ifdef IO_DEBUG
                 dbgprintf("block write on %d\n", fd);
                 dbgprintf("block write on %d\n", fd);
 #endif
 #endif
-                current->m_blocked_fd = fd;
-                current->block(Thread::State::BlockedWrite);
+                current->block(Thread::State::BlockedWrite, *descriptor);
             }
             }
             ssize_t rc = descriptor->write((const byte*)data + nwritten, size - nwritten);
             ssize_t rc = descriptor->write((const byte*)data + nwritten, size - nwritten);
 #ifdef IO_DEBUG
 #ifdef IO_DEBUG
@@ -888,8 +887,7 @@ ssize_t Process::sys$read(int fd, byte* buffer, ssize_t size)
         return -EBADF;
         return -EBADF;
     if (descriptor->is_blocking()) {
     if (descriptor->is_blocking()) {
         if (!descriptor->can_read()) {
         if (!descriptor->can_read()) {
-            current->m_blocked_fd = fd;
-            current->block(Thread::State::BlockedRead);
+            current->block(Thread::State::BlockedRead, *descriptor);
             if (current->m_was_interrupted_while_blocked)
             if (current->m_was_interrupted_while_blocked)
                 return -EINTR;
                 return -EINTR;
         }
         }
@@ -2057,7 +2055,7 @@ int Process::sys$connect(int sockfd, const sockaddr* address, socklen_t address_
         return -EISCONN;
         return -EISCONN;
     auto& socket = *descriptor->socket();
     auto& socket = *descriptor->socket();
     descriptor->set_socket_role(SocketRole::Connecting);
     descriptor->set_socket_role(SocketRole::Connecting);
-    auto result = socket.connect(address, address_size, descriptor->is_blocking() ? ShouldBlock::Yes : ShouldBlock::No);
+    auto result = socket.connect(*descriptor, address, address_size, descriptor->is_blocking() ? ShouldBlock::Yes : ShouldBlock::No);
     if (result.is_error()) {
     if (result.is_error()) {
         descriptor->set_socket_role(SocketRole::None);
         descriptor->set_socket_role(SocketRole::None);
         return result;
         return result;
@@ -2089,7 +2087,7 @@ ssize_t Process::sys$sendto(const Syscall::SC_sendto_params* params)
         return -ENOTSOCK;
         return -ENOTSOCK;
     auto& socket = *descriptor->socket();
     auto& socket = *descriptor->socket();
     kprintf("sendto %p (%u), flags=%u, addr: %p (%u)\n", data, data_length, flags, addr, addr_length);
     kprintf("sendto %p (%u), flags=%u, addr: %p (%u)\n", data, data_length, flags, addr, addr_length);
-    return socket.sendto(data, data_length, flags, addr, addr_length);
+    return socket.sendto(*descriptor, data, data_length, flags, addr, addr_length);
 }
 }
 
 
 ssize_t Process::sys$recvfrom(const Syscall::SC_recvfrom_params* params)
 ssize_t Process::sys$recvfrom(const Syscall::SC_recvfrom_params* params)
@@ -2121,7 +2119,7 @@ ssize_t Process::sys$recvfrom(const Syscall::SC_recvfrom_params* params)
         return -ENOTSOCK;
         return -ENOTSOCK;
     auto& socket = *descriptor->socket();
     auto& socket = *descriptor->socket();
     kprintf("recvfrom %p (%u), flags=%u, addr: %p (%p)\n", buffer, buffer_length, flags, addr, addr_length);
     kprintf("recvfrom %p (%u), flags=%u, addr: %p (%p)\n", buffer, buffer_length, flags, addr, addr_length);
-    return socket.recvfrom(buffer, buffer_length, flags, addr, addr_length);
+    return socket.recvfrom(*descriptor, buffer, buffer_length, flags, addr, addr_length);
 }
 }
 
 
 int Process::sys$getsockopt(const Syscall::SC_getsockopt_params* params)
 int Process::sys$getsockopt(const Syscall::SC_getsockopt_params* params)

+ 10 - 10
Kernel/Scheduler.cpp

@@ -88,35 +88,35 @@ bool Scheduler::pick_next()
         }
         }
 
 
         if (thread.state() == Thread::BlockedRead) {
         if (thread.state() == Thread::BlockedRead) {
-            ASSERT(thread.m_blocked_fd != -1);
+            ASSERT(thread.m_blocked_descriptor);
             // FIXME: Block until the amount of data wanted is available.
             // FIXME: Block until the amount of data wanted is available.
-            if (process.m_fds[thread.m_blocked_fd].descriptor->can_read())
+            if (thread.m_blocked_descriptor->can_read())
                 thread.unblock();
                 thread.unblock();
             return IterationDecision::Continue;
             return IterationDecision::Continue;
         }
         }
 
 
         if (thread.state() == Thread::BlockedWrite) {
         if (thread.state() == Thread::BlockedWrite) {
-            ASSERT(thread.m_blocked_fd != -1);
-            if (process.m_fds[thread.m_blocked_fd].descriptor->can_write())
+            ASSERT(thread.m_blocked_descriptor != -1);
+            if (thread.m_blocked_descriptor->can_write())
                 thread.unblock();
                 thread.unblock();
             return IterationDecision::Continue;
             return IterationDecision::Continue;
         }
         }
 
 
         if (thread.state() == Thread::BlockedConnect) {
         if (thread.state() == Thread::BlockedConnect) {
-            ASSERT(thread.m_blocked_socket);
-            if (thread.m_blocked_socket->is_connected())
+            auto& descriptor = *thread.m_blocked_descriptor;
+            auto& socket = *descriptor.socket();
+            if (socket.is_connected())
                 thread.unblock();
                 thread.unblock();
             return IterationDecision::Continue;
             return IterationDecision::Continue;
         }
         }
 
 
         if (thread.state() == Thread::BlockedReceive) {
         if (thread.state() == Thread::BlockedReceive) {
-            ASSERT(thread.m_blocked_socket);
-            auto& socket = *thread.m_blocked_socket;
+            auto& descriptor = *thread.m_blocked_descriptor;
+            auto& socket = *descriptor.socket();
             // FIXME: Block until the amount of data wanted is available.
             // FIXME: Block until the amount of data wanted is available.
             bool timed_out = now_sec > socket.receive_deadline().tv_sec || (now_sec == socket.receive_deadline().tv_sec && now_usec >= socket.receive_deadline().tv_usec);
             bool timed_out = now_sec > socket.receive_deadline().tv_sec || (now_sec == socket.receive_deadline().tv_sec && now_usec >= socket.receive_deadline().tv_usec);
-            if (timed_out || socket.can_read(SocketRole::None)) {
+            if (timed_out || descriptor.can_read()) {
                 thread.unblock();
                 thread.unblock();
-                thread.m_blocked_socket = nullptr;
                 return IterationDecision::Continue;
                 return IterationDecision::Continue;
             }
             }
             return IterationDecision::Continue;
             return IterationDecision::Continue;

+ 14 - 11
Kernel/Thread.cpp

@@ -1,7 +1,7 @@
 #include <Kernel/Thread.h>
 #include <Kernel/Thread.h>
 #include <Kernel/Scheduler.h>
 #include <Kernel/Scheduler.h>
 #include <Kernel/Process.h>
 #include <Kernel/Process.h>
-#include <Kernel/Net/Socket.h>
+#include <Kernel/FileSystem/FileDescriptor.h>
 #include <Kernel/VM/MemoryManager.h>
 #include <Kernel/VM/MemoryManager.h>
 #include <LibC/signal_numbers.h>
 #include <LibC/signal_numbers.h>
 
 
@@ -91,6 +91,7 @@ Thread::~Thread()
 
 
 void Thread::unblock()
 void Thread::unblock()
 {
 {
+    m_blocked_descriptor = nullptr;
     if (current == this) {
     if (current == this) {
         m_state = Thread::Running;
         m_state = Thread::Running;
         return;
         return;
@@ -120,6 +121,12 @@ void Thread::block(Thread::State new_state)
         process().big_lock().lock();
         process().big_lock().lock();
 }
 }
 
 
+void Thread::block(Thread::State new_state, FileDescriptor& descriptor)
+{
+    m_blocked_descriptor = &descriptor;
+    block(new_state);
+}
+
 void Thread::sleep(dword ticks)
 void Thread::sleep(dword ticks)
 {
 {
     ASSERT(state() == Thread::Running);
     ASSERT(state() == Thread::Running);
@@ -157,9 +164,10 @@ const char* to_string(Thread::State state)
 void Thread::finalize()
 void Thread::finalize()
 {
 {
     dbgprintf("Finalizing Thread %u in %s(%u)\n", tid(), m_process.name().characters(), pid());
     dbgprintf("Finalizing Thread %u in %s(%u)\n", tid(), m_process.name().characters(), pid());
-    m_blocked_socket = nullptr;
     set_state(Thread::State::Dead);
     set_state(Thread::State::Dead);
 
 
+    m_blocked_descriptor = nullptr;
+
     if (this == &m_process.main_thread())
     if (this == &m_process.main_thread())
         m_process.finalize();
         m_process.finalize();
 }
 }
@@ -496,14 +504,14 @@ Thread* Thread::clone(Process& process)
     return clone;
     return clone;
 }
 }
 
 
-KResult Thread::wait_for_connect(Socket& socket)
+KResult Thread::wait_for_connect(FileDescriptor& descriptor)
 {
 {
+    ASSERT(descriptor.is_socket());
+    auto& socket = *descriptor.socket();
     if (socket.is_connected())
     if (socket.is_connected())
         return KSuccess;
         return KSuccess;
-    m_blocked_socket = socket;
-    block(Thread::State::BlockedConnect);
+    block(Thread::State::BlockedConnect, descriptor);
     Scheduler::yield();
     Scheduler::yield();
-    m_blocked_socket = nullptr;
     if (!socket.is_connected())
     if (!socket.is_connected())
         return KResult(-ECONNREFUSED);
         return KResult(-ECONNREFUSED);
     return KSuccess;
     return KSuccess;
@@ -533,8 +541,3 @@ bool Thread::is_thread(void* ptr)
     }
     }
     return false;
     return false;
 }
 }
-
-void Thread::set_blocked_socket(Socket* socket)
-{
-    m_blocked_socket = socket;
-}

+ 4 - 6
Kernel/Thread.h

@@ -11,9 +11,9 @@
 #include <AK/Vector.h>
 #include <AK/Vector.h>
 
 
 class Alarm;
 class Alarm;
+class FileDescriptor;
 class Process;
 class Process;
 class Region;
 class Region;
-class Socket;
 
 
 enum class ShouldUnblockThread { No = 0, Yes };
 enum class ShouldUnblockThread { No = 0, Yes };
 
 
@@ -86,12 +86,13 @@ public:
 
 
     void sleep(dword ticks);
     void sleep(dword ticks);
     void block(Thread::State);
     void block(Thread::State);
+    void block(Thread::State, FileDescriptor&);
     void unblock();
     void unblock();
 
 
     void set_wakeup_time(qword t) { m_wakeup_time = t; }
     void set_wakeup_time(qword t) { m_wakeup_time = t; }
     qword wakeup_time() const { return m_wakeup_time; }
     qword wakeup_time() const { return m_wakeup_time; }
     void snooze_until(Alarm&);
     void snooze_until(Alarm&);
-    KResult wait_for_connect(Socket&);
+    KResult wait_for_connect(FileDescriptor&);
 
 
     const FarPtr& far_ptr() const { return m_far_ptr; }
     const FarPtr& far_ptr() const { return m_far_ptr; }
 
 
@@ -116,8 +117,6 @@ public:
     bool has_used_fpu() const { return m_has_used_fpu; }
     bool has_used_fpu() const { return m_has_used_fpu; }
     void set_has_used_fpu(bool b) { m_has_used_fpu = b; }
     void set_has_used_fpu(bool b) { m_has_used_fpu = b; }
 
 
-    void set_blocked_socket(Socket*);
-
     void set_default_signal_dispositions();
     void set_default_signal_dispositions();
     void push_value_on_stack(dword);
     void push_value_on_stack(dword);
     void make_userspace_stack_for_main_thread(Vector<String> arguments, Vector<String> environment);
     void make_userspace_stack_for_main_thread(Vector<String> arguments, Vector<String> environment);
@@ -148,10 +147,9 @@ private:
     void* m_kernel_stack { nullptr };
     void* m_kernel_stack { nullptr };
     void* m_kernel_stack_for_signal_handler { nullptr };
     void* m_kernel_stack_for_signal_handler { nullptr };
     pid_t m_waitee_pid { -1 };
     pid_t m_waitee_pid { -1 };
-    int m_blocked_fd { -1 };
+    RetainPtr<FileDescriptor> m_blocked_descriptor;
     timeval m_select_timeout;
     timeval m_select_timeout;
     SignalActionData m_signal_action_data[32];
     SignalActionData m_signal_action_data[32];
-    RetainPtr<Socket> m_blocked_socket;
     Region* m_signal_stack_user_region { nullptr };
     Region* m_signal_stack_user_region { nullptr };
     Alarm* m_snoozing_alarm { nullptr };
     Alarm* m_snoozing_alarm { nullptr };
     Vector<int> m_select_read_fds;
     Vector<int> m_select_read_fds;