Przeglądaj źródła

IPCCompiler+LibIPC: Generate message decoders with better TRY semantics

Instead of a bunch of manual error checking and returning a null OwnPtr,
we can propagate the errors up and return NonnullOwnPtr on success.
Timothy Flynn 2 lat temu
rodzic
commit
765c5b416f

+ 36 - 43
Meta/Lagom/Tools/CodeGenerators/IPCCompiler/main.cpp

@@ -293,6 +293,21 @@ DeprecatedString constructor_for_message(DeprecatedString const& name, Vector<Pa
     return builder.to_deprecated_string();
     return builder.to_deprecated_string();
 }
 }
 
 
+static void append_handle_stream_error(SourceGenerator& generator, StringView error_message)
+{
+    if constexpr (GENERATE_DEBUG) {
+        generator.set("error_message"sv, error_message);
+        generator.append(R"~~~(
+        if (stream.handle_any_error()) {
+            dbgln("@error_message@");
+            return Error::from_string_literal("@error_message@");
+        })~~~");
+    } else {
+        generator.append(R"~~~(
+        TRY(stream.try_handle_any_error());)~~~");
+    }
+}
+
 void do_message(SourceGenerator message_generator, DeprecatedString const& name, Vector<Parameter> const& parameters, DeprecatedString const& response_type = {})
 void do_message(SourceGenerator message_generator, DeprecatedString const& name, Vector<Parameter> const& parameters, DeprecatedString const& response_type = {})
 {
 {
     auto pascal_name = pascal_case(name);
     auto pascal_name = pascal_case(name);
@@ -338,7 +353,7 @@ public:)~~~");
     static i32 static_message_id() { return (int)MessageID::@message.pascal_name@; }
     static i32 static_message_id() { return (int)MessageID::@message.pascal_name@; }
     virtual const char* message_name() const override { return "@endpoint.name@::@message.pascal_name@"; }
     virtual const char* message_name() const override { return "@endpoint.name@::@message.pascal_name@"; }
 
 
-    static OwnPtr<@message.pascal_name@> decode(InputMemoryStream& stream, Core::Stream::LocalSocket& socket)
+    static ErrorOr<NonnullOwnPtr<@message.pascal_name@>> decode(InputMemoryStream& stream, Core::Stream::LocalSocket& socket)
     {
     {
         IPC::Decoder decoder { stream, socket };)~~~");
         IPC::Decoder decoder { stream, socket };)~~~");
 
 
@@ -355,13 +370,12 @@ public:)~~~");
 
 
         parameter_generator.appendln(R"~~~(
         parameter_generator.appendln(R"~~~(
         @parameter.type@ @parameter.name@ = @parameter.initial_value@;
         @parameter.type@ @parameter.name@ = @parameter.initial_value@;
-        if (decoder.decode(@parameter.name@).is_error())
-            return {};)~~~");
+        TRY(decoder.decode(@parameter.name@));)~~~");
 
 
         if (parameter.attributes.contains_slow("UTF8")) {
         if (parameter.attributes.contains_slow("UTF8")) {
             parameter_generator.appendln(R"~~~(
             parameter_generator.appendln(R"~~~(
         if (!Utf8View(@parameter.name@).validate())
         if (!Utf8View(@parameter.name@).validate())
-            return {};)~~~");
+            return Error::from_string_literal("Decoded @parameter.name@ is invalid UTF-8");)~~~");
         }
         }
     }
     }
 
 
@@ -374,7 +388,7 @@ public:)~~~");
     }
     }
 
 
     message_generator.set("message.constructor_call_parameters", builder.build());
     message_generator.set("message.constructor_call_parameters", builder.build());
-
+    append_handle_stream_error(message_generator, "Failed to read the message"sv);
     message_generator.appendln(R"~~~(
     message_generator.appendln(R"~~~(
         return make<@message.pascal_name@>(@message.constructor_call_parameters@);
         return make<@message.pascal_name@>(@message.constructor_call_parameters@);
     })~~~");
     })~~~");
@@ -574,7 +588,7 @@ private:
     IPC::Connection<LocalEndpoint, PeerEndpoint>& m_connection;
     IPC::Connection<LocalEndpoint, PeerEndpoint>& m_connection;
 };)~~~");
 };)~~~");
 
 
-    generator.appendln(R"~~~(
+    generator.append(R"~~~(
 template<typename LocalEndpoint, typename PeerEndpoint>
 template<typename LocalEndpoint, typename PeerEndpoint>
 class @endpoint.name@Proxy;
 class @endpoint.name@Proxy;
 class @endpoint.name@Stub;
 class @endpoint.name@Stub;
@@ -587,41 +601,28 @@ public:
 
 
     static u32 static_magic() { return @endpoint.magic@; }
     static u32 static_magic() { return @endpoint.magic@; }
 
 
-    static OwnPtr<IPC::Message> decode_message(ReadonlyBytes buffer, [[maybe_unused]] Core::Stream::LocalSocket& socket)
+    static ErrorOr<NonnullOwnPtr<IPC::Message>> decode_message(ReadonlyBytes buffer, [[maybe_unused]] Core::Stream::LocalSocket& socket)
     {
     {
         InputMemoryStream stream { buffer };
         InputMemoryStream stream { buffer };
         u32 message_endpoint_magic = 0;
         u32 message_endpoint_magic = 0;
-        stream >> message_endpoint_magic;
-        if (stream.handle_any_error()) {)~~~");
-    if constexpr (GENERATE_DEBUG) {
-        generator.appendln(R"~~~(
-            dbgln("Failed to read message endpoint magic");)~~~");
-    }
-    generator.appendln(R"~~~(
-            return {};
-        }
+        stream >> message_endpoint_magic;)~~~");
+    append_handle_stream_error(generator, "Failed to read message endpoint magic"sv);
+    generator.append(R"~~~(
 
 
         if (message_endpoint_magic != @endpoint.magic@) {)~~~");
         if (message_endpoint_magic != @endpoint.magic@) {)~~~");
     if constexpr (GENERATE_DEBUG) {
     if constexpr (GENERATE_DEBUG) {
-        generator.appendln(R"~~~(
+        generator.append(R"~~~(
             dbgln("@endpoint.name@: Endpoint magic number message_endpoint_magic != @endpoint.magic@, not my message! (the other endpoint may have handled it)");)~~~");
             dbgln("@endpoint.name@: Endpoint magic number message_endpoint_magic != @endpoint.magic@, not my message! (the other endpoint may have handled it)");)~~~");
     }
     }
     generator.appendln(R"~~~(
     generator.appendln(R"~~~(
-            return {};
+            return Error::from_string_literal("Endpoint magic number mismatch, not my message!");
         }
         }
 
 
         i32 message_id = 0;
         i32 message_id = 0;
-        stream >> message_id;
-        if (stream.handle_any_error()) {)~~~");
-    if constexpr (GENERATE_DEBUG) {
-        generator.appendln(R"~~~(
-            dbgln("Failed to read message ID");)~~~");
-    }
+        stream >> message_id;)~~~");
+    append_handle_stream_error(generator, "Failed to read message ID"sv);
     generator.appendln(R"~~~(
     generator.appendln(R"~~~(
-            return {};
-        }
 
 
-        OwnPtr<IPC::Message> message;
         switch (message_id) {)~~~");
         switch (message_id) {)~~~");
 
 
     for (auto const& message : endpoint.messages) {
     for (auto const& message : endpoint.messages) {
@@ -631,10 +632,9 @@ public:
             message_generator.set("message.name", name);
             message_generator.set("message.name", name);
             message_generator.set("message.pascal_name", pascal_case(name));
             message_generator.set("message.pascal_name", pascal_case(name));
 
 
-            message_generator.appendln(R"~~~(
+            message_generator.append(R"~~~(
         case (int)Messages::@endpoint.name@::MessageID::@message.pascal_name@:
         case (int)Messages::@endpoint.name@::MessageID::@message.pascal_name@:
-            message = Messages::@endpoint.name@::@message.pascal_name@::decode(stream, socket);
-            break;)~~~");
+            return TRY(Messages::@endpoint.name@::@message.pascal_name@::decode(stream, socket));)~~~");
         };
         };
 
 
         do_decode_message(message.name);
         do_decode_message(message.name);
@@ -642,26 +642,18 @@ public:
             do_decode_message(message.response_name());
             do_decode_message(message.response_name());
     }
     }
 
 
-    generator.appendln(R"~~~(
+    generator.append(R"~~~(
         default:)~~~");
         default:)~~~");
     if constexpr (GENERATE_DEBUG) {
     if constexpr (GENERATE_DEBUG) {
-        generator.appendln(R"~~~(
+        generator.append(R"~~~(
             dbgln("Failed to decode @endpoint.name@.({})", message_id);)~~~");
             dbgln("Failed to decode @endpoint.name@.({})", message_id);)~~~");
     }
     }
     generator.appendln(R"~~~(
     generator.appendln(R"~~~(
-            return {};
-        }
+            return Error::from_string_literal("Failed to decode @endpoint.name@ message");
+        })~~~");
 
 
-        if (stream.handle_any_error()) {)~~~");
-    if constexpr (GENERATE_DEBUG) {
-        generator.appendln(R"~~~(
-            dbgln("Failed to read the message");)~~~");
-    }
     generator.appendln(R"~~~(
     generator.appendln(R"~~~(
-            return {};
-        }
-
-        return message;
+        VERIFY_NOT_REACHED();
     }
     }
 
 
 };
 };
@@ -797,6 +789,7 @@ void build(StringBuilder& builder, Vector<Endpoint> const& endpoints)
     }
     }
 
 
     generator.appendln(R"~~~(#include <AK/MemoryStream.h>
     generator.appendln(R"~~~(#include <AK/MemoryStream.h>
+#include <AK/Error.h>
 #include <AK/OwnPtr.h>
 #include <AK/OwnPtr.h>
 #include <AK/Result.h>
 #include <AK/Result.h>
 #include <AK/Utf8View.h>
 #include <AK/Utf8View.h>

+ 16 - 7
Userland/Libraries/LibIPC/Connection.h

@@ -139,14 +139,23 @@ protected:
                 break;
                 break;
             index += sizeof(message_size);
             index += sizeof(message_size);
             auto remaining_bytes = ReadonlyBytes { bytes.data() + index, message_size };
             auto remaining_bytes = ReadonlyBytes { bytes.data() + index, message_size };
-            if (auto message = LocalEndpoint::decode_message(remaining_bytes, fd_passing_socket())) {
-                m_unprocessed_messages.append(message.release_nonnull());
-            } else if (auto message = PeerEndpoint::decode_message(remaining_bytes, fd_passing_socket())) {
-                m_unprocessed_messages.append(message.release_nonnull());
-            } else {
-                dbgln("Failed to parse a message");
-                break;
+
+            auto local_message = LocalEndpoint::decode_message(remaining_bytes, fd_passing_socket());
+            if (!local_message.is_error()) {
+                m_unprocessed_messages.append(local_message.release_value());
+                continue;
             }
             }
+
+            auto peer_message = PeerEndpoint::decode_message(remaining_bytes, fd_passing_socket());
+            if (!peer_message.is_error()) {
+                m_unprocessed_messages.append(peer_message.release_value());
+                continue;
+            }
+
+            dbgln("Failed to parse a message");
+            dbgln("Local endpoint error: {}", local_message.error());
+            dbgln("Peer endpoint error: {}", peer_message.error());
+            break;
         }
         }
     }
     }
 };
 };