Pārlūkot izejas kodu

Kernel: Add support for the MSG_WAITALL sys$recvmsg flag

Idan Horowitz 3 gadi atpakaļ
vecāks
revīzija
e521ffd156

+ 1 - 0
Kernel/API/POSIX/sys/socket.h

@@ -53,6 +53,7 @@ extern "C" {
 #define MSG_PEEK 0x4
 #define MSG_OOB 0x8
 #define MSG_DONTROUTE 0x10
+#define MSG_WAITALL 0x20
 #define MSG_DONTWAIT 0x40
 
 typedef uint16_t sa_family_t;

+ 19 - 8
Kernel/Net/IPv4Socket.cpp

@@ -391,15 +391,26 @@ ErrorOr<size_t> IPv4Socket::recvfrom(OpenFileDescription& description, UserOrKer
 
     dbgln_if(IPV4_SOCKET_DEBUG, "recvfrom: type={}, local_port={}", type(), local_port());
 
-    ErrorOr<size_t> nreceived = 0;
-    if (buffer_mode() == BufferMode::Bytes)
-        nreceived = receive_byte_buffered(description, buffer, buffer_length, flags, user_addr, user_addr_length);
-    else
-        nreceived = receive_packet_buffered(description, buffer, buffer_length, flags, user_addr, user_addr_length, packet_timestamp);
+    ErrorOr<size_t> total_nreceived = 0;
+    do {
+        auto offset_buffer = buffer.offset(total_nreceived.value());
+        auto offset_buffer_length = buffer_length - total_nreceived.value();
+
+        ErrorOr<size_t> nreceived = 0;
+        if (buffer_mode() == BufferMode::Bytes)
+            nreceived = receive_byte_buffered(description, offset_buffer, offset_buffer_length, flags, user_addr, user_addr_length);
+        else
+            nreceived = receive_packet_buffered(description, offset_buffer, offset_buffer_length, flags, user_addr, user_addr_length, packet_timestamp);
+
+        if (nreceived.is_error())
+            total_nreceived = nreceived;
+        else
+            total_nreceived.value() += nreceived.value();
+    } while ((flags & MSG_WAITALL) && !total_nreceived.is_error() && total_nreceived.value() < buffer_length);
 
-    if (!nreceived.is_error())
-        Thread::current()->did_ipv4_socket_read(nreceived.value());
-    return nreceived;
+    if (!total_nreceived.is_error())
+        Thread::current()->did_ipv4_socket_read(total_nreceived.value());
+    return total_nreceived;
 }
 
 bool IPv4Socket::did_receive(const IPv4Address& source_address, u16 source_port, ReadonlyBytes packet, const Time& packet_timestamp)

+ 2 - 2
Userland/Utilities/strace.cpp

@@ -607,8 +607,8 @@ static void format_connect(FormattedSyscallBuilder& builder, int socket, const s
 struct MsgOptions : BitflagBase {
     static constexpr auto options = {
         BITFLAG(MSG_TRUNC), BITFLAG(MSG_CTRUNC), BITFLAG(MSG_PEEK),
-        BITFLAG(MSG_OOB), BITFLAG(MSG_DONTROUTE), BITFLAG(MSG_DONTWAIT)
-        // TODO: add MSG_WAITALL once its definition is added
+        BITFLAG(MSG_OOB), BITFLAG(MSG_DONTROUTE), BITFLAG(MSG_WAITALL),
+        BITFLAG(MSG_DONTWAIT)
     };
 };