IPCCompiler+LibIPC: Generate message decoders with better TRY semantics

Instead of a bunch of manual error checking and returning a null OwnPtr,
we can propagate the errors up and return NonnullOwnPtr on success.
This commit is contained in:
Timothy Flynn 2022-12-22 13:48:44 -05:00 committed by Andreas Kling
parent dc77ec733f
commit 765c5b416f
Notes: sideshowbarker 2024-07-18 00:41:35 +09:00
2 changed files with 52 additions and 50 deletions

View file

@ -293,6 +293,21 @@ DeprecatedString constructor_for_message(DeprecatedString const& name, Vector<Pa
return builder.to_deprecated_string();
}
static void append_handle_stream_error(SourceGenerator& generator, StringView error_message)
{
if constexpr (GENERATE_DEBUG) {
generator.set("error_message"sv, error_message);
generator.append(R"~~~(
if (stream.handle_any_error()) {
dbgln("@error_message@");
return Error::from_string_literal("@error_message@");
})~~~");
} else {
generator.append(R"~~~(
TRY(stream.try_handle_any_error());)~~~");
}
}
void do_message(SourceGenerator message_generator, DeprecatedString const& name, Vector<Parameter> const& parameters, DeprecatedString const& response_type = {})
{
auto pascal_name = pascal_case(name);
@ -338,7 +353,7 @@ public:)~~~");
static i32 static_message_id() { return (int)MessageID::@message.pascal_name@; }
virtual const char* message_name() const override { return "@endpoint.name@::@message.pascal_name@"; }
static OwnPtr<@message.pascal_name@> decode(InputMemoryStream& stream, Core::Stream::LocalSocket& socket)
static ErrorOr<NonnullOwnPtr<@message.pascal_name@>> decode(InputMemoryStream& stream, Core::Stream::LocalSocket& socket)
{
IPC::Decoder decoder { stream, socket };)~~~");
@ -355,13 +370,12 @@ public:)~~~");
parameter_generator.appendln(R"~~~(
@parameter.type@ @parameter.name@ = @parameter.initial_value@;
if (decoder.decode(@parameter.name@).is_error())
return {};)~~~");
TRY(decoder.decode(@parameter.name@));)~~~");
if (parameter.attributes.contains_slow("UTF8")) {
parameter_generator.appendln(R"~~~(
if (!Utf8View(@parameter.name@).validate())
return {};)~~~");
return Error::from_string_literal("Decoded @parameter.name@ is invalid UTF-8");)~~~");
}
}
@ -374,7 +388,7 @@ public:)~~~");
}
message_generator.set("message.constructor_call_parameters", builder.build());
append_handle_stream_error(message_generator, "Failed to read the message"sv);
message_generator.appendln(R"~~~(
return make<@message.pascal_name@>(@message.constructor_call_parameters@);
})~~~");
@ -574,7 +588,7 @@ private:
IPC::Connection<LocalEndpoint, PeerEndpoint>& m_connection;
};)~~~");
generator.appendln(R"~~~(
generator.append(R"~~~(
template<typename LocalEndpoint, typename PeerEndpoint>
class @endpoint.name@Proxy;
class @endpoint.name@Stub;
@ -587,41 +601,28 @@ public:
static u32 static_magic() { return @endpoint.magic@; }
static OwnPtr<IPC::Message> decode_message(ReadonlyBytes buffer, [[maybe_unused]] Core::Stream::LocalSocket& socket)
static ErrorOr<NonnullOwnPtr<IPC::Message>> decode_message(ReadonlyBytes buffer, [[maybe_unused]] Core::Stream::LocalSocket& socket)
{
InputMemoryStream stream { buffer };
u32 message_endpoint_magic = 0;
stream >> message_endpoint_magic;
if (stream.handle_any_error()) {)~~~");
if constexpr (GENERATE_DEBUG) {
generator.appendln(R"~~~(
dbgln("Failed to read message endpoint magic");)~~~");
}
generator.appendln(R"~~~(
return {};
}
stream >> message_endpoint_magic;)~~~");
append_handle_stream_error(generator, "Failed to read message endpoint magic"sv);
generator.append(R"~~~(
if (message_endpoint_magic != @endpoint.magic@) {)~~~");
if constexpr (GENERATE_DEBUG) {
generator.appendln(R"~~~(
generator.append(R"~~~(
dbgln("@endpoint.name@: Endpoint magic number message_endpoint_magic != @endpoint.magic@, not my message! (the other endpoint may have handled it)");)~~~");
}
generator.appendln(R"~~~(
return {};
return Error::from_string_literal("Endpoint magic number mismatch, not my message!");
}
i32 message_id = 0;
stream >> message_id;
if (stream.handle_any_error()) {)~~~");
if constexpr (GENERATE_DEBUG) {
generator.appendln(R"~~~(
dbgln("Failed to read message ID");)~~~");
}
stream >> message_id;)~~~");
append_handle_stream_error(generator, "Failed to read message ID"sv);
generator.appendln(R"~~~(
return {};
}
OwnPtr<IPC::Message> message;
switch (message_id) {)~~~");
for (auto const& message : endpoint.messages) {
@ -631,10 +632,9 @@ public:
message_generator.set("message.name", name);
message_generator.set("message.pascal_name", pascal_case(name));
message_generator.appendln(R"~~~(
message_generator.append(R"~~~(
case (int)Messages::@endpoint.name@::MessageID::@message.pascal_name@:
message = Messages::@endpoint.name@::@message.pascal_name@::decode(stream, socket);
break;)~~~");
return TRY(Messages::@endpoint.name@::@message.pascal_name@::decode(stream, socket));)~~~");
};
do_decode_message(message.name);
@ -642,26 +642,18 @@ public:
do_decode_message(message.response_name());
}
generator.appendln(R"~~~(
generator.append(R"~~~(
default:)~~~");
if constexpr (GENERATE_DEBUG) {
generator.appendln(R"~~~(
generator.append(R"~~~(
dbgln("Failed to decode @endpoint.name@.({})", message_id);)~~~");
}
generator.appendln(R"~~~(
return {};
}
return Error::from_string_literal("Failed to decode @endpoint.name@ message");
})~~~");
if (stream.handle_any_error()) {)~~~");
if constexpr (GENERATE_DEBUG) {
generator.appendln(R"~~~(
dbgln("Failed to read the message");)~~~");
}
generator.appendln(R"~~~(
return {};
}
return message;
VERIFY_NOT_REACHED();
}
};
@ -797,6 +789,7 @@ void build(StringBuilder& builder, Vector<Endpoint> const& endpoints)
}
generator.appendln(R"~~~(#include <AK/MemoryStream.h>
#include <AK/Error.h>
#include <AK/OwnPtr.h>
#include <AK/Result.h>
#include <AK/Utf8View.h>

View file

@ -139,14 +139,23 @@ protected:
break;
index += sizeof(message_size);
auto remaining_bytes = ReadonlyBytes { bytes.data() + index, message_size };
if (auto message = LocalEndpoint::decode_message(remaining_bytes, fd_passing_socket())) {
m_unprocessed_messages.append(message.release_nonnull());
} else if (auto message = PeerEndpoint::decode_message(remaining_bytes, fd_passing_socket())) {
m_unprocessed_messages.append(message.release_nonnull());
} else {
dbgln("Failed to parse a message");
break;
auto local_message = LocalEndpoint::decode_message(remaining_bytes, fd_passing_socket());
if (!local_message.is_error()) {
m_unprocessed_messages.append(local_message.release_value());
continue;
}
auto peer_message = PeerEndpoint::decode_message(remaining_bytes, fd_passing_socket());
if (!peer_message.is_error()) {
m_unprocessed_messages.append(peer_message.release_value());
continue;
}
dbgln("Failed to parse a message");
dbgln("Local endpoint error: {}", local_message.error());
dbgln("Peer endpoint error: {}", peer_message.error());
break;
}
}
};