Sfoglia il codice sorgente

Kernel: Support non-blocking connect().

If connect() is called on a non-blocking socket, it will "fail" immediately
with -EINPROGRESS. After that, you select() on the socket and wait for it to
become writable.
Andreas Kling 6 anni fa
parent
commit
65d6318c33

+ 3 - 3
Kernel/Net/IPv4Socket.cpp

@@ -66,7 +66,7 @@ KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size)
     ASSERT_NOT_REACHED();
     ASSERT_NOT_REACHED();
 }
 }
 
 
-KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size)
+KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size, ShouldBlock should_block)
 {
 {
     ASSERT(!m_bound);
     ASSERT(!m_bound);
     if (address_size != sizeof(sockaddr_in))
     if (address_size != sizeof(sockaddr_in))
@@ -78,7 +78,7 @@ KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size)
     m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr);
     m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr);
     m_destination_port = ntohs(ia.sin_port);
     m_destination_port = ntohs(ia.sin_port);
 
 
-    return protocol_connect();
+    return protocol_connect(should_block);
 }
 }
 
 
 void IPv4Socket::attach_fd(SocketRole)
 void IPv4Socket::attach_fd(SocketRole)
@@ -110,7 +110,7 @@ ssize_t IPv4Socket::write(SocketRole, const byte* data, ssize_t size)
 
 
 bool IPv4Socket::can_write(SocketRole) const
 bool IPv4Socket::can_write(SocketRole) const
 {
 {
-    return true;
+    return is_connected();
 }
 }
 
 
 int IPv4Socket::allocate_source_port_if_needed()
 int IPv4Socket::allocate_source_port_if_needed()

+ 2 - 2
Kernel/Net/IPv4Socket.h

@@ -21,7 +21,7 @@ public:
     static Lockable<HashTable<IPv4Socket*>>& all_sockets();
     static Lockable<HashTable<IPv4Socket*>>& all_sockets();
 
 
     virtual KResult bind(const sockaddr*, socklen_t) override;
     virtual KResult bind(const sockaddr*, socklen_t) override;
-    virtual KResult connect(const sockaddr*, socklen_t) override;
+    virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
     virtual bool get_address(sockaddr*, socklen_t*) override;
     virtual bool get_address(sockaddr*, socklen_t*) override;
     virtual void attach_fd(SocketRole) override;
     virtual void attach_fd(SocketRole) override;
     virtual void detach_fd(SocketRole) override;
     virtual void detach_fd(SocketRole) override;
@@ -49,7 +49,7 @@ protected:
 
 
     virtual int protocol_receive(const ByteBuffer&, void*, size_t, int, sockaddr*, socklen_t*) { return -ENOTIMPL; }
     virtual int protocol_receive(const ByteBuffer&, void*, size_t, int, sockaddr*, socklen_t*) { return -ENOTIMPL; }
     virtual int protocol_send(const void*, int) { return -ENOTIMPL; }
     virtual int protocol_send(const void*, int) { return -ENOTIMPL; }
-    virtual KResult protocol_connect() { return KSuccess; }
+    virtual KResult protocol_connect(ShouldBlock) { return KSuccess; }
     virtual int protocol_allocate_source_port() { return 0; }
     virtual int protocol_allocate_source_port() { return 0; }
     virtual bool protocol_is_disconnected() const { return false; }
     virtual bool protocol_is_disconnected() const { return false; }
 
 

+ 1 - 1
Kernel/Net/LocalSocket.cpp

@@ -65,7 +65,7 @@ KResult LocalSocket::bind(const sockaddr* address, socklen_t address_size)
     return KSuccess;
     return KSuccess;
 }
 }
 
 
-KResult LocalSocket::connect(const sockaddr* address, socklen_t address_size)
+KResult LocalSocket::connect(const sockaddr* address, socklen_t address_size, ShouldBlock)
 {
 {
     ASSERT(!m_bound);
     ASSERT(!m_bound);
     if (address_size != sizeof(sockaddr_un))
     if (address_size != sizeof(sockaddr_un))

+ 1 - 1
Kernel/Net/LocalSocket.h

@@ -11,7 +11,7 @@ public:
     virtual ~LocalSocket() override;
     virtual ~LocalSocket() override;
 
 
     virtual KResult bind(const sockaddr*, socklen_t) override;
     virtual KResult bind(const sockaddr*, socklen_t) override;
-    virtual KResult connect(const sockaddr*, socklen_t) override;
+    virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
     virtual bool get_address(sockaddr*, socklen_t*) override;
     virtual bool get_address(sockaddr*, socklen_t*) override;
     virtual void attach_fd(SocketRole) override;
     virtual void attach_fd(SocketRole) override;
     virtual void detach_fd(SocketRole) override;
     virtual void detach_fd(SocketRole) override;

+ 1 - 0
Kernel/Net/NetworkTask.cpp

@@ -342,6 +342,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
         socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
         socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
         socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
         socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
         socket->set_state(TCPSocket::State::Disconnecting);
         socket->set_state(TCPSocket::State::Disconnecting);
+        socket->set_connected(false);
         return;
         return;
     }
     }
 
 

+ 2 - 1
Kernel/Net/Socket.h

@@ -9,6 +9,7 @@
 #include <Kernel/KResult.h>
 #include <Kernel/KResult.h>
 
 
 enum class SocketRole { None, Listener, Accepted, Connected, Connecting };
 enum class SocketRole { None, Listener, Accepted, Connected, Connecting };
+enum class ShouldBlock { No = 0, Yes = 1 };
 
 
 class Socket : public Retainable<Socket> {
 class Socket : public Retainable<Socket> {
 public:
 public:
@@ -25,7 +26,7 @@ public:
     KResult listen(int backlog);
     KResult listen(int backlog);
 
 
     virtual KResult bind(const sockaddr*, socklen_t) = 0;
     virtual KResult bind(const sockaddr*, socklen_t) = 0;
-    virtual KResult connect(const sockaddr*, socklen_t) = 0;
+    virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock) = 0;
     virtual bool get_address(sockaddr*, socklen_t*) = 0;
     virtual bool get_address(sockaddr*, socklen_t*) = 0;
     virtual bool is_local() const { return false; }
     virtual bool is_local() const { return false; }
     virtual bool is_ipv4() const { return false; }
     virtual bool is_ipv4() const { return false; }

+ 8 - 5
Kernel/Net/TCPSocket.cpp

@@ -152,7 +152,7 @@ NetworkOrdered<word> TCPSocket::compute_tcp_checksum(const IPv4Address& source,
     return ~(checksum & 0xffff);
     return ~(checksum & 0xffff);
 }
 }
 
 
-KResult TCPSocket::protocol_connect()
+KResult TCPSocket::protocol_connect(ShouldBlock should_block)
 {
 {
     auto* adapter = adapter_for_route_to(destination_address());
     auto* adapter = adapter_for_route_to(destination_address());
     if (!adapter)
     if (!adapter)
@@ -166,11 +166,14 @@ KResult TCPSocket::protocol_connect()
     send_tcp_packet(TCPFlags::SYN);
     send_tcp_packet(TCPFlags::SYN);
     m_state = State::Connecting;
     m_state = State::Connecting;
 
 
-    current->set_blocked_socket(this);
-    current->block(Thread::BlockedConnect);
+    if (should_block == ShouldBlock::Yes) {
+        current->set_blocked_socket(this);
+        current->block(Thread::BlockedConnect);
+        ASSERT(is_connected());
+        return KSuccess;
+    }
 
 
-    ASSERT(is_connected());
-    return KSuccess;
+    return KResult(-EINPROGRESS);
 }
 }
 
 
 int TCPSocket::protocol_allocate_source_port()
 int TCPSocket::protocol_allocate_source_port()

+ 1 - 1
Kernel/Net/TCPSocket.h

@@ -34,7 +34,7 @@ private:
 
 
     virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override;
     virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override;
     virtual int protocol_send(const void*, int) override;
     virtual int protocol_send(const void*, int) override;
-    virtual KResult protocol_connect() override;
+    virtual KResult protocol_connect(ShouldBlock) override;
     virtual int protocol_allocate_source_port() override;
     virtual int protocol_allocate_source_port() override;
     virtual bool protocol_is_disconnected() const override;
     virtual bool protocol_is_disconnected() const override;
 
 

+ 1 - 1
Kernel/Net/UDPSocket.cpp

@@ -81,7 +81,7 @@ int UDPSocket::protocol_send(const void* data, int data_length)
     return data_length;
     return data_length;
 }
 }
 
 
-KResult UDPSocket::protocol_connect()
+KResult UDPSocket::protocol_connect(ShouldBlock)
 {
 {
     return KSuccess;
     return KSuccess;
 }
 }

+ 1 - 1
Kernel/Net/UDPSocket.h

@@ -17,7 +17,7 @@ private:
 
 
     virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override;
     virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override;
     virtual int protocol_send(const void*, int) override;
     virtual int protocol_send(const void*, int) override;
-    virtual KResult protocol_connect() override;
+    virtual KResult protocol_connect(ShouldBlock) override;
     virtual int protocol_allocate_source_port() override;
     virtual int protocol_allocate_source_port() override;
 };
 };
 
 

+ 1 - 1
Kernel/Process.cpp

@@ -2038,7 +2038,7 @@ int Process::sys$connect(int sockfd, const sockaddr* address, socklen_t address_
         return -EISCONN;
         return -EISCONN;
     auto& socket = *descriptor->socket();
     auto& socket = *descriptor->socket();
     descriptor->set_socket_role(SocketRole::Connecting);
     descriptor->set_socket_role(SocketRole::Connecting);
-    auto result = socket.connect(address, address_size);
+    auto result = socket.connect(address, address_size, descriptor->is_blocking() ? ShouldBlock::Yes : ShouldBlock::No);
     if (result.is_error()) {
     if (result.is_error()) {
         descriptor->set_socket_role(SocketRole::None);
         descriptor->set_socket_role(SocketRole::None);
         return result;
         return result;