Pārlūkot izejas kodu

LibCore: Optionally pass MSG_NOSIGNAL to socket read/writes

When creating a `Core::Stream::Socket`, you can now choose to prevent
SIGPIPE signals from firing and terminating your process. This is done
by passing MSG_NOSIGNAL to the `System::recv()` or `System::send()`
calls when you `read()` or `write()` to that Socket.
Sam Atkins 2 gadi atpakaļ
vecāks
revīzija
cb5f83606a

+ 6 - 6
Userland/Libraries/LibCore/Stream.cpp

@@ -435,13 +435,13 @@ ErrorOr<Bytes> PosixSocketHelper::read(Bytes buffer, int flags)
     return buffer.trim(nread);
 }
 
-ErrorOr<size_t> PosixSocketHelper::write(ReadonlyBytes buffer)
+ErrorOr<size_t> PosixSocketHelper::write(ReadonlyBytes buffer, int flags)
 {
     if (!is_open()) {
         return Error::from_errno(ENOTCONN);
     }
 
-    return TRY(System::send(m_fd, buffer.data(), buffer.size(), 0));
+    return TRY(System::send(m_fd, buffer.data(), buffer.size(), flags));
 }
 
 void PosixSocketHelper::close()
@@ -574,9 +574,9 @@ ErrorOr<NonnullOwnPtr<UDPSocket>> UDPSocket::connect(SocketAddress const& addres
     return socket;
 }
 
-ErrorOr<NonnullOwnPtr<LocalSocket>> LocalSocket::connect(String const& path)
+ErrorOr<NonnullOwnPtr<LocalSocket>> LocalSocket::connect(String const& path, PreventSIGPIPE prevent_sigpipe)
 {
-    auto socket = TRY(adopt_nonnull_own_or_enomem(new (nothrow) LocalSocket()));
+    auto socket = TRY(adopt_nonnull_own_or_enomem(new (nothrow) LocalSocket(prevent_sigpipe)));
 
     auto fd = TRY(create_fd(SocketDomain::Local, SocketType::Stream));
     socket->m_helper.set_fd(fd);
@@ -587,13 +587,13 @@ ErrorOr<NonnullOwnPtr<LocalSocket>> LocalSocket::connect(String const& path)
     return socket;
 }
 
-ErrorOr<NonnullOwnPtr<LocalSocket>> LocalSocket::adopt_fd(int fd)
+ErrorOr<NonnullOwnPtr<LocalSocket>> LocalSocket::adopt_fd(int fd, PreventSIGPIPE prevent_sigpipe)
 {
     if (fd < 0) {
         return Error::from_errno(EBADF);
     }
 
-    auto socket = TRY(adopt_nonnull_own_or_enomem(new (nothrow) LocalSocket()));
+    auto socket = TRY(adopt_nonnull_own_or_enomem(new (nothrow) LocalSocket(prevent_sigpipe)));
     socket->m_helper.set_fd(fd);
     socket->setup_notifier();
     return socket;

+ 41 - 15
Userland/Libraries/LibCore/Stream.h

@@ -106,6 +106,11 @@ public:
     virtual ErrorOr<void> discard(size_t discarded_bytes) override;
 };
 
+enum class PreventSIGPIPE {
+    No,
+    Yes,
+};
+
 /// The Socket class is the base class for all concrete BSD-style socket
 /// classes. Sockets are non-seekable streams which can be read byte-wise.
 class Socket : public Stream {
@@ -149,7 +154,8 @@ protected:
         Datagram,
     };
 
-    Socket()
+    Socket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::No)
+        : m_prevent_sigpipe(prevent_sigpipe == PreventSIGPIPE::Yes)
     {
     }
 
@@ -160,6 +166,17 @@ protected:
 
     static ErrorOr<void> connect_local(int fd, String const& path);
     static ErrorOr<void> connect_inet(int fd, SocketAddress const&);
+
+    int default_flags() const
+    {
+        int flags = 0;
+        if (m_prevent_sigpipe)
+            flags |= MSG_NOSIGNAL;
+        return flags;
+    }
+
+private:
+    bool m_prevent_sigpipe { false };
 };
 
 /// A reusable socket maintains state about being connected in addition to
@@ -262,7 +279,9 @@ class PosixSocketHelper {
 
 public:
     template<typename T>
-    PosixSocketHelper(Badge<T>) requires(IsBaseOf<Socket, T>) { }
+    PosixSocketHelper(Badge<T>) requires(IsBaseOf<Socket, T>)
+    {
+    }
 
     PosixSocketHelper(PosixSocketHelper&& other)
     {
@@ -280,8 +299,8 @@ public:
     int fd() const { return m_fd; }
     void set_fd(int fd) { m_fd = fd; }
 
-    ErrorOr<Bytes> read(Bytes, int flags = 0);
-    ErrorOr<size_t> write(ReadonlyBytes);
+    ErrorOr<Bytes> read(Bytes, int flags);
+    ErrorOr<size_t> write(ReadonlyBytes, int flags);
 
     bool is_eof() const { return !is_open() || m_last_read_was_eof; }
     bool is_open() const { return m_fd != -1; }
@@ -329,8 +348,8 @@ public:
 
     virtual bool is_readable() const override { return is_open(); }
     virtual bool is_writable() const override { return is_open(); }
-    virtual ErrorOr<Bytes> read(Bytes buffer) override { return m_helper.read(buffer); }
-    virtual ErrorOr<size_t> write(ReadonlyBytes buffer) override { return m_helper.write(buffer); }
+    virtual ErrorOr<Bytes> read(Bytes buffer) override { return m_helper.read(buffer, default_flags()); }
+    virtual ErrorOr<size_t> write(ReadonlyBytes buffer) override { return m_helper.write(buffer, default_flags()); }
     virtual bool is_eof() const override { return m_helper.is_eof(); }
     virtual bool is_open() const override { return m_helper.is_open(); };
     virtual void close() override { m_helper.close(); };
@@ -347,7 +366,8 @@ public:
     virtual ~TCPSocket() override { close(); }
 
 private:
-    TCPSocket()
+    TCPSocket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::No)
+        : Socket(prevent_sigpipe)
     {
     }
 
@@ -400,12 +420,12 @@ public:
             return Error::from_errno(EMSGSIZE);
         }
 
-        return m_helper.read(buffer);
+        return m_helper.read(buffer, default_flags());
     }
 
     virtual bool is_readable() const override { return is_open(); }
     virtual bool is_writable() const override { return is_open(); }
-    virtual ErrorOr<size_t> write(ReadonlyBytes buffer) override { return m_helper.write(buffer); }
+    virtual ErrorOr<size_t> write(ReadonlyBytes buffer) override { return m_helper.write(buffer, default_flags()); }
     virtual bool is_eof() const override { return m_helper.is_eof(); }
     virtual bool is_open() const override { return m_helper.is_open(); }
     virtual void close() override { m_helper.close(); }
@@ -422,7 +442,10 @@ public:
     virtual ~UDPSocket() override { close(); }
 
 private:
-    UDPSocket() = default;
+    UDPSocket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::No)
+        : Socket(prevent_sigpipe)
+    {
+    }
 
     void setup_notifier()
     {
@@ -440,8 +463,8 @@ private:
 
 class LocalSocket final : public Socket {
 public:
-    static ErrorOr<NonnullOwnPtr<LocalSocket>> connect(String const& path);
-    static ErrorOr<NonnullOwnPtr<LocalSocket>> adopt_fd(int fd);
+    static ErrorOr<NonnullOwnPtr<LocalSocket>> connect(String const& path, PreventSIGPIPE = PreventSIGPIPE::No);
+    static ErrorOr<NonnullOwnPtr<LocalSocket>> adopt_fd(int fd, PreventSIGPIPE = PreventSIGPIPE::No);
 
     LocalSocket(LocalSocket&& other)
         : Socket(static_cast<Socket&&>(other))
@@ -463,8 +486,8 @@ public:
 
     virtual bool is_readable() const override { return is_open(); }
     virtual bool is_writable() const override { return is_open(); }
-    virtual ErrorOr<Bytes> read(Bytes buffer) override { return m_helper.read(buffer); }
-    virtual ErrorOr<size_t> write(ReadonlyBytes buffer) override { return m_helper.write(buffer); }
+    virtual ErrorOr<Bytes> read(Bytes buffer) override { return m_helper.read(buffer, default_flags()); }
+    virtual ErrorOr<size_t> write(ReadonlyBytes buffer) override { return m_helper.write(buffer, default_flags()); }
     virtual bool is_eof() const override { return m_helper.is_eof(); }
     virtual bool is_open() const override { return m_helper.is_open(); }
     virtual void close() override { m_helper.close(); }
@@ -495,7 +518,10 @@ public:
     virtual ~LocalSocket() { close(); }
 
 private:
-    LocalSocket() = default;
+    LocalSocket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::No)
+        : Socket(prevent_sigpipe)
+    {
+    }
 
     void setup_notifier()
     {