Browse Source

Kernel: Make sys$recvfrom() with MSG_DONTWAIT not so racy

Instead of temporary changing the open file description's "blocking"
flag while doing a non-waiting recvfrom, we instead plumb the currently
wanted blocking behavior all the way through to the underlying socket.
Andreas Kling 2 years ago
parent
commit
42435ce5e4

+ 7 - 7
Kernel/Net/IPv4Socket.cpp

@@ -246,7 +246,7 @@ ErrorOr<size_t> IPv4Socket::sendto(OpenFileDescription&, UserOrKernelBuffer cons
     return nsent_or_error;
     return nsent_or_error;
 }
 }
 
 
-ErrorOr<size_t> IPv4Socket::receive_byte_buffered(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>)
+ErrorOr<size_t> IPv4Socket::receive_byte_buffered(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>, bool blocking)
 {
 {
     MutexLocker locker(mutex());
     MutexLocker locker(mutex());
 
 
@@ -255,7 +255,7 @@ ErrorOr<size_t> IPv4Socket::receive_byte_buffered(OpenFileDescription& descripti
     if (m_receive_buffer->is_empty()) {
     if (m_receive_buffer->is_empty()) {
         if (protocol_is_disconnected())
         if (protocol_is_disconnected())
             return 0;
             return 0;
-        if (!description.is_blocking())
+        if (!blocking)
             return set_so_error(EAGAIN);
             return set_so_error(EAGAIN);
 
 
         locker.unlock();
         locker.unlock();
@@ -285,7 +285,7 @@ ErrorOr<size_t> IPv4Socket::receive_byte_buffered(OpenFileDescription& descripti
     return nreceived_or_error;
     return nreceived_or_error;
 }
 }
 
 
-ErrorOr<size_t> IPv4Socket::receive_packet_buffered(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*> addr, Userspace<socklen_t*> addr_length, Time& packet_timestamp)
+ErrorOr<size_t> IPv4Socket::receive_packet_buffered(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*> addr, Userspace<socklen_t*> addr_length, Time& packet_timestamp, bool blocking)
 {
 {
     MutexLocker locker(mutex());
     MutexLocker locker(mutex());
     ReceivedPacket taken_packet;
     ReceivedPacket taken_packet;
@@ -296,7 +296,7 @@ ErrorOr<size_t> IPv4Socket::receive_packet_buffered(OpenFileDescription& descrip
             //        But if so, we still need to deliver at least one EOF read to userspace.. right?
             //        But if so, we still need to deliver at least one EOF read to userspace.. right?
             if (protocol_is_disconnected())
             if (protocol_is_disconnected())
                 return 0;
                 return 0;
-            if (!description.is_blocking())
+            if (!blocking)
                 return set_so_error(EAGAIN);
                 return set_so_error(EAGAIN);
         }
         }
 
 
@@ -380,7 +380,7 @@ ErrorOr<size_t> IPv4Socket::receive_packet_buffered(OpenFileDescription& descrip
     return protocol_receive(packet->data->bytes(), buffer, buffer_length, flags);
     return protocol_receive(packet->data->bytes(), buffer, buffer_length, flags);
 }
 }
 
 
-ErrorOr<size_t> IPv4Socket::recvfrom(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*> user_addr, Userspace<socklen_t*> user_addr_length, Time& packet_timestamp)
+ErrorOr<size_t> IPv4Socket::recvfrom(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*> user_addr, Userspace<socklen_t*> user_addr_length, Time& packet_timestamp, bool blocking)
 {
 {
     if (user_addr_length) {
     if (user_addr_length) {
         socklen_t addr_length;
         socklen_t addr_length;
@@ -398,9 +398,9 @@ ErrorOr<size_t> IPv4Socket::recvfrom(OpenFileDescription& description, UserOrKer
 
 
         ErrorOr<size_t> nreceived = 0;
         ErrorOr<size_t> nreceived = 0;
         if (buffer_mode() == BufferMode::Bytes)
         if (buffer_mode() == BufferMode::Bytes)
-            nreceived = receive_byte_buffered(description, offset_buffer, offset_buffer_length, flags, user_addr, user_addr_length);
+            nreceived = receive_byte_buffered(description, offset_buffer, offset_buffer_length, flags, user_addr, user_addr_length, blocking);
         else
         else
-            nreceived = receive_packet_buffered(description, offset_buffer, offset_buffer_length, flags, user_addr, user_addr_length, packet_timestamp);
+            nreceived = receive_packet_buffered(description, offset_buffer, offset_buffer_length, flags, user_addr, user_addr_length, packet_timestamp, blocking);
 
 
         if (nreceived.is_error())
         if (nreceived.is_error())
             total_nreceived = nreceived;
             total_nreceived = nreceived;

+ 3 - 3
Kernel/Net/IPv4Socket.h

@@ -40,7 +40,7 @@ public:
     virtual bool can_read(OpenFileDescription const&, u64) const override;
     virtual bool can_read(OpenFileDescription const&, u64) const override;
     virtual bool can_write(OpenFileDescription const&, u64) const override;
     virtual bool can_write(OpenFileDescription const&, u64) const override;
     virtual ErrorOr<size_t> sendto(OpenFileDescription&, UserOrKernelBuffer const&, size_t, int, Userspace<sockaddr const*>, socklen_t) override;
     virtual ErrorOr<size_t> sendto(OpenFileDescription&, UserOrKernelBuffer const&, size_t, int, Userspace<sockaddr const*>, socklen_t) override;
-    virtual ErrorOr<size_t> recvfrom(OpenFileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>, Time&) override;
+    virtual ErrorOr<size_t> recvfrom(OpenFileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>, Time&, bool blocking) override;
     virtual ErrorOr<void> setsockopt(int level, int option, Userspace<void const*>, socklen_t) override;
     virtual ErrorOr<void> setsockopt(int level, int option, Userspace<void const*>, socklen_t) override;
     virtual ErrorOr<void> getsockopt(OpenFileDescription&, int level, int option, Userspace<void*>, Userspace<socklen_t*>) override;
     virtual ErrorOr<void> getsockopt(OpenFileDescription&, int level, int option, Userspace<void*>, Userspace<socklen_t*>) override;
 
 
@@ -98,8 +98,8 @@ protected:
 private:
 private:
     virtual bool is_ipv4() const override { return true; }
     virtual bool is_ipv4() const override { return true; }
 
 
-    ErrorOr<size_t> receive_byte_buffered(OpenFileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>);
-    ErrorOr<size_t> receive_packet_buffered(OpenFileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>, Time&);
+    ErrorOr<size_t> receive_byte_buffered(OpenFileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>, bool blocking);
+    ErrorOr<size_t> receive_packet_buffered(OpenFileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>, Time&, bool blocking);
 
 
     void set_can_read(bool);
     void set_can_read(bool);
 
 

+ 2 - 2
Kernel/Net/LocalSocket.cpp

@@ -334,12 +334,12 @@ DoubleBuffer* LocalSocket::send_buffer_for(OpenFileDescription& description)
     return nullptr;
     return nullptr;
 }
 }
 
 
-ErrorOr<size_t> LocalSocket::recvfrom(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_size, int, Userspace<sockaddr*>, Userspace<socklen_t*>, Time&)
+ErrorOr<size_t> LocalSocket::recvfrom(OpenFileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_size, int, Userspace<sockaddr*>, Userspace<socklen_t*>, Time&, bool blocking)
 {
 {
     auto* socket_buffer = receive_buffer_for(description);
     auto* socket_buffer = receive_buffer_for(description);
     if (!socket_buffer)
     if (!socket_buffer)
         return set_so_error(EINVAL);
         return set_so_error(EINVAL);
-    if (!description.is_blocking()) {
+    if (!blocking) {
         if (socket_buffer->is_empty()) {
         if (socket_buffer->is_empty()) {
             if (!has_attached_peer(description))
             if (!has_attached_peer(description))
                 return 0;
                 return 0;

+ 1 - 1
Kernel/Net/LocalSocket.h

@@ -46,7 +46,7 @@ public:
     virtual bool can_read(OpenFileDescription const&, u64) const override;
     virtual bool can_read(OpenFileDescription const&, u64) const override;
     virtual bool can_write(OpenFileDescription const&, u64) const override;
     virtual bool can_write(OpenFileDescription const&, u64) const override;
     virtual ErrorOr<size_t> sendto(OpenFileDescription&, UserOrKernelBuffer const&, size_t, int, Userspace<sockaddr const*>, socklen_t) override;
     virtual ErrorOr<size_t> sendto(OpenFileDescription&, UserOrKernelBuffer const&, size_t, int, Userspace<sockaddr const*>, socklen_t) override;
-    virtual ErrorOr<size_t> recvfrom(OpenFileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>, Time&) override;
+    virtual ErrorOr<size_t> recvfrom(OpenFileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>, Time&, bool blocking) override;
     virtual ErrorOr<void> getsockopt(OpenFileDescription&, int level, int option, Userspace<void*>, Userspace<socklen_t*>) override;
     virtual ErrorOr<void> getsockopt(OpenFileDescription&, int level, int option, Userspace<void*>, Userspace<socklen_t*>) override;
     virtual ErrorOr<void> ioctl(OpenFileDescription&, unsigned request, Userspace<void*> arg) override;
     virtual ErrorOr<void> ioctl(OpenFileDescription&, unsigned request, Userspace<void*> arg) override;
     virtual ErrorOr<void> chown(Credentials const&, OpenFileDescription&, UserID, GroupID) override;
     virtual ErrorOr<void> chown(Credentials const&, OpenFileDescription&, UserID, GroupID) override;

+ 1 - 1
Kernel/Net/Socket.cpp

@@ -242,7 +242,7 @@ ErrorOr<size_t> Socket::read(OpenFileDescription& description, u64, UserOrKernel
     if (is_shut_down_for_reading())
     if (is_shut_down_for_reading())
         return 0;
         return 0;
     Time t {};
     Time t {};
-    return recvfrom(description, buffer, size, 0, {}, 0, t);
+    return recvfrom(description, buffer, size, 0, {}, 0, t, description.is_blocking());
 }
 }
 
 
 ErrorOr<size_t> Socket::write(OpenFileDescription& description, u64, UserOrKernelBuffer const& data, size_t size)
 ErrorOr<size_t> Socket::write(OpenFileDescription& description, u64, UserOrKernelBuffer const& data, size_t size)

+ 1 - 1
Kernel/Net/Socket.h

@@ -80,7 +80,7 @@ public:
     virtual bool is_local() const { return false; }
     virtual bool is_local() const { return false; }
     virtual bool is_ipv4() const { return false; }
     virtual bool is_ipv4() const { return false; }
     virtual ErrorOr<size_t> sendto(OpenFileDescription&, UserOrKernelBuffer const&, size_t, int flags, Userspace<sockaddr const*>, socklen_t) = 0;
     virtual ErrorOr<size_t> sendto(OpenFileDescription&, UserOrKernelBuffer const&, size_t, int flags, Userspace<sockaddr const*>, socklen_t) = 0;
-    virtual ErrorOr<size_t> recvfrom(OpenFileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>, Time&) = 0;
+    virtual ErrorOr<size_t> recvfrom(OpenFileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>, Time&, bool blocking) = 0;
 
 
     virtual ErrorOr<void> setsockopt(int level, int option, Userspace<void const*>, socklen_t);
     virtual ErrorOr<void> setsockopt(int level, int option, Userspace<void const*>, socklen_t);
     virtual ErrorOr<void> getsockopt(OpenFileDescription&, int level, int option, Userspace<void*>, Userspace<socklen_t*>);
     virtual ErrorOr<void> getsockopt(OpenFileDescription&, int level, int option, Userspace<void*>, Userspace<socklen_t*>);

+ 2 - 7
Kernel/Syscalls/socket.cpp

@@ -241,15 +241,10 @@ ErrorOr<FlatPtr> Process::sys$recvmsg(int sockfd, Userspace<struct msghdr*> user
     if (socket.is_shut_down_for_reading())
     if (socket.is_shut_down_for_reading())
         return 0;
         return 0;
 
 
-    bool original_blocking = description->is_blocking();
-    if (flags & MSG_DONTWAIT)
-        description->set_blocking(false);
-
     auto data_buffer = TRY(UserOrKernelBuffer::for_user_buffer((u8*)iovs[0].iov_base, iovs[0].iov_len));
     auto data_buffer = TRY(UserOrKernelBuffer::for_user_buffer((u8*)iovs[0].iov_base, iovs[0].iov_len));
     Time timestamp {};
     Time timestamp {};
-    auto result = socket.recvfrom(*description, data_buffer, iovs[0].iov_len, flags, user_addr, user_addr_length, timestamp);
-    if (flags & MSG_DONTWAIT)
-        description->set_blocking(original_blocking);
+    bool blocking = (flags & MSG_DONTWAIT) ? false : description->is_blocking();
+    auto result = socket.recvfrom(*description, data_buffer, iovs[0].iov_len, flags, user_addr, user_addr_length, timestamp, blocking);
 
 
     if (result.is_error())
     if (result.is_error())
         return result.release_error();
         return result.release_error();