Prechádzať zdrojové kódy

Net: Add a basic sys$shutdown() implementation

Calling shutdown prevents further reads and/or writes on a socket.
We should do a few more things based on the type of socket, but this
initial implementation just puts the basic mechanism in place.

Work towards #428.
Andreas Kling 5 rokov pred
rodič
commit
2b0b7cc5a4

+ 13 - 0
Kernel/Net/Socket.cpp

@@ -149,10 +149,23 @@ KResult Socket::getsockopt(FileDescription&, int level, int option, void* value,
 
 ssize_t Socket::read(FileDescription& description, u8* buffer, ssize_t size)
 {
+    if (is_shut_down_for_reading())
+        return 0;
     return recvfrom(description, buffer, size, 0, nullptr, 0);
 }
 
 ssize_t Socket::write(FileDescription& description, const u8* data, ssize_t size)
 {
+    if (is_shut_down_for_writing())
+        return -EPIPE;
     return sendto(description, data, size, 0, nullptr, 0);
 }
+
+KResult Socket::shutdown(int how)
+{
+    if (type() == SOCK_STREAM && !is_connected())
+        return KResult(-ENOTCONN);
+    m_shut_down_for_reading |= how & SHUT_RD;
+    m_shut_down_for_writing |= how & SHUT_WR;
+    return KSuccess;
+}

+ 7 - 0
Kernel/Net/Socket.h

@@ -51,6 +51,9 @@ public:
     int type() const { return m_type; }
     int protocol() const { return m_protocol; }
 
+    bool is_shut_down_for_writing() const { return m_shut_down_for_writing; }
+    bool is_shut_down_for_reading() const { return m_shut_down_for_reading; }
+
     enum class SetupState {
         Unstarted,  // we haven't tried to set the socket up yet
         InProgress, // we're in the process of setting things up - for TCP maybe we've sent a SYN packet
@@ -90,6 +93,8 @@ public:
     bool can_accept() const { return !m_pending.is_empty(); }
     RefPtr<Socket> accept();
 
+    KResult shutdown(int how);
+
     virtual KResult bind(const sockaddr*, socklen_t) = 0;
     virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock) = 0;
     virtual KResult listen(int) = 0;
@@ -153,6 +158,8 @@ private:
     int m_backlog { 0 };
     SetupState m_setup_state { SetupState::Unstarted };
     bool m_connected { false };
+    bool m_shut_down_for_reading { false };
+    bool m_shut_down_for_writing { false };
 
     timeval m_receive_timeout { 0, 0 };
     timeval m_send_timeout { 0, 0 };

+ 22 - 1
Kernel/Process.cpp

@@ -3221,6 +3221,22 @@ int Process::sys$connect(int sockfd, const sockaddr* address, socklen_t address_
     return socket.connect(*description, address, address_size, description->is_blocking() ? ShouldBlock::Yes : ShouldBlock::No);
 }
 
+int Process::sys$shutdown(int sockfd, int how)
+{
+    REQUIRE_PROMISE(stdio);
+    if (how & ~SHUT_RDWR)
+        return -EINVAL;
+    auto description = file_description(sockfd);
+    if (!description)
+        return -EBADF;
+    if (!description->is_socket())
+        return -ENOTSOCK;
+
+    auto& socket = *description->socket();
+    REQUIRE_PROMISE_FOR_SOCKET_DOMAIN(socket.domain());
+    return socket.shutdown(how);
+}
+
 ssize_t Process::sys$sendto(const Syscall::SC_sendto_params* user_params)
 {
     REQUIRE_PROMISE(stdio);
@@ -3241,8 +3257,10 @@ ssize_t Process::sys$sendto(const Syscall::SC_sendto_params* user_params)
         return -EBADF;
     if (!description->is_socket())
         return -ENOTSOCK;
-    SmapDisabler disabler;
     auto& socket = *description->socket();
+    if (socket.is_shut_down_for_writing())
+        return -EPIPE;
+    SmapDisabler disabler;
     return socket.sendto(*description, params.data.data, params.data.size, flags, addr, addr_length);
 }
 
@@ -3276,6 +3294,9 @@ ssize_t Process::sys$recvfrom(const Syscall::SC_recvfrom_params* user_params)
         return -ENOTSOCK;
     auto& socket = *description->socket();
 
+    if (socket.is_shut_down_for_reading())
+        return 0;
+
     bool original_blocking = description->is_blocking();
     if (flags & MSG_DONTWAIT)
         description->set_blocking(false);

+ 1 - 0
Kernel/Process.h

@@ -260,6 +260,7 @@ public:
     int sys$listen(int sockfd, int backlog);
     int sys$accept(int sockfd, sockaddr*, socklen_t*);
     int sys$connect(int sockfd, const sockaddr*, socklen_t);
+    int sys$shutdown(int sockfd, int how);
     ssize_t sys$sendto(const Syscall::SC_sendto_params*);
     ssize_t sys$recvfrom(const Syscall::SC_recvfrom_params*);
     int sys$getsockopt(const Syscall::SC_getsockopt_params*);

+ 2 - 1
Kernel/Syscall.h

@@ -177,7 +177,8 @@ typedef u32 socklen_t;
     __ENUMERATE_SYSCALL(chroot)                     \
     __ENUMERATE_SYSCALL(pledge)                     \
     __ENUMERATE_SYSCALL(unveil)                     \
-    __ENUMERATE_SYSCALL(perf_event)
+    __ENUMERATE_SYSCALL(perf_event)                 \
+    __ENUMERATE_SYSCALL(shutdown)
 
 namespace Syscall {
 

+ 4 - 0
Kernel/UnixTypes.h

@@ -387,6 +387,10 @@ struct pollfd {
 #define SOCK_NONBLOCK 04000
 #define SOCK_CLOEXEC 02000000
 
+#define SHUT_RD 1
+#define SHUT_WR 2
+#define SHUT_RDWR 3
+
 #define MSG_DONTWAIT 0x40
 
 #define SOL_SOCKET 1

+ 6 - 0
Libraries/LibC/sys/socket.cpp

@@ -62,6 +62,12 @@ int connect(int sockfd, const sockaddr* addr, socklen_t addrlen)
     __RETURN_WITH_ERRNO(rc, rc, -1);
 }
 
+int shutdown(int sockfd, int how)
+{
+    int rc = syscall(SC_shutdown, sockfd, how);
+    __RETURN_WITH_ERRNO(rc, rc, -1);
+}
+
 ssize_t sendto(int sockfd, const void* data, size_t data_length, int flags, const struct sockaddr* addr, socklen_t addr_length)
 {
     Syscall::SC_sendto_params params { sockfd, { data, data_length }, flags, addr, addr_length };

+ 5 - 0
Libraries/LibC/sys/socket.h

@@ -49,6 +49,10 @@ __BEGIN_DECLS
 #define SOCK_NONBLOCK 04000
 #define SOCK_CLOEXEC 02000000
 
+#define SHUT_RD 1
+#define SHUT_WR 2
+#define SHUT_RDWR 3
+
 #define IPPROTO_IP 0
 #define IPPROTO_ICMP 1
 #define IPPROTO_TCP 6
@@ -81,6 +85,7 @@ int bind(int sockfd, const struct sockaddr* addr, socklen_t);
 int listen(int sockfd, int backlog);
 int accept(int sockfd, struct sockaddr*, socklen_t*);
 int connect(int sockfd, const struct sockaddr*, socklen_t);
+int shutdown(int sockfd, int how);
 ssize_t send(int sockfd, const void*, size_t, int flags);
 ssize_t sendto(int sockfd, const void*, size_t, int flags, const struct sockaddr*, socklen_t);
 ssize_t recv(int sockfd, void*, size_t, int flags);