LibIPC+AudioServer: Allow unsolicited server-to-client IPC messages

Client-side connection objects must now provide both client and server
endpoint types. When a message is received from the server side, we try
to decode it using both endpoint types and then send it to the right
place for handling.

This now makes it possible for AudioServer to send unsolicited messages
to its clients. This opens up a ton of possibilities :^)
This commit is contained in:
Andreas Kling 2019-11-23 16:43:21 +01:00
parent 06ee24263c
commit 630d5b3ffd
Notes: sideshowbarker 2024-07-19 11:06:36 +09:00
13 changed files with 95 additions and 42 deletions

View file

@ -27,6 +27,7 @@ struct Message {
struct Endpoint { struct Endpoint {
String name; String name;
int magic;
Vector<Message> messages; Vector<Message> messages;
}; };
@ -177,6 +178,13 @@ int main(int argc, char** argv)
consume_whitespace(); consume_whitespace();
endpoints.last().name = extract_while([](char ch) { return !isspace(ch); }); endpoints.last().name = extract_while([](char ch) { return !isspace(ch); });
consume_whitespace(); consume_whitespace();
consume_specific('=');
consume_whitespace();
auto magic_string = extract_while([](char ch) { return !isspace(ch) && ch != '{'; });
bool ok;
endpoints.last().magic = magic_string.to_int(ok);
ASSERT(ok);
consume_whitespace();
consume_specific('{'); consume_specific('{');
parse_messages(); parse_messages();
consume_specific('}'); consume_specific('}');
@ -244,17 +252,20 @@ int main(int argc, char** argv)
return builder.to_string(); return builder.to_string();
}; };
auto do_message = [&](const String& name, const Vector<Parameter>& parameters, String response_type = {}) { auto do_message = [&](const String& name, const Vector<Parameter>& parameters, const String& response_type = {}) {
dbg() << "class " << name << " final : public IMessage {"; dbg() << "class " << name << " final : public IMessage {";
dbg() << "public:"; dbg() << "public:";
if (!response_type.is_null()) if (!response_type.is_null())
dbg() << " typedef class " << response_type << " ResponseType;"; dbg() << " typedef class " << response_type << " ResponseType;";
dbg() << " " << constructor_for_message(name, parameters); dbg() << " " << constructor_for_message(name, parameters);
dbg() << " virtual ~" << name << "() override {}"; dbg() << " virtual ~" << name << "() override {}";
dbg() << " virtual i32 endpoint_magic() const override { return " << endpoint.magic << "; }";
dbg() << " static i32 static_endpoint_magic() { return " << endpoint.magic << "; }";
dbg() << " virtual i32 id() const override { return (int)MessageID::" << name << "; }"; dbg() << " virtual i32 id() const override { return (int)MessageID::" << name << "; }";
dbg() << " static i32 static_message_id() { return (int)MessageID::" << name << "; }"; dbg() << " static i32 static_message_id() { return (int)MessageID::" << name << "; }";
dbg() << " virtual String name() const override { return \"" << endpoint.name << "::" << name << "\"; }"; dbg() << " virtual String name() const override { return \"" << endpoint.name << "::" << name << "\"; }";
dbg() << " static OwnPtr<" << name << "> decode(BufferStream& stream)"; dbg() << " static String static_name() { return \"" << endpoint.name << "::" << name << "\"; }";
dbg() << " static OwnPtr<" << name << "> decode(BufferStream& stream, size_t& size_in_bytes)";
dbg() << " {"; dbg() << " {";
if (parameters.is_empty()) if (parameters.is_empty())
@ -278,6 +289,7 @@ int main(int argc, char** argv)
if (i != parameters.size() - 1) if (i != parameters.size() - 1)
builder.append(", "); builder.append(", ");
} }
dbg() << " size_in_bytes = stream.offset();";
dbg() << " return make<" << name << ">(" << builder.to_string() << ");"; dbg() << " return make<" << name << ">(" << builder.to_string() << ");";
dbg() << " }"; dbg() << " }";
dbg() << " virtual ByteBuffer encode() const override"; dbg() << " virtual ByteBuffer encode() const override";
@ -285,6 +297,7 @@ int main(int argc, char** argv)
// FIXME: Support longer messages: // FIXME: Support longer messages:
dbg() << " auto buffer = ByteBuffer::create_uninitialized(1024);"; dbg() << " auto buffer = ByteBuffer::create_uninitialized(1024);";
dbg() << " BufferStream stream(buffer);"; dbg() << " BufferStream stream(buffer);";
dbg() << " stream << endpoint_magic();";
dbg() << " stream << (int)MessageID::" << name << ";"; dbg() << " stream << (int)MessageID::" << name << ";";
for (auto& parameter : parameters) { for (auto& parameter : parameters) {
dbg() << " stream << m_" << parameter.name << ";"; dbg() << " stream << m_" << parameter.name << ";";
@ -317,17 +330,24 @@ int main(int argc, char** argv)
dbg() << "public:"; dbg() << "public:";
dbg() << " " << endpoint.name << "Endpoint() {}"; dbg() << " " << endpoint.name << "Endpoint() {}";
dbg() << " virtual ~" << endpoint.name << "Endpoint() override {}"; dbg() << " virtual ~" << endpoint.name << "Endpoint() override {}";
dbg() << " static int static_magic() { return " << endpoint.magic << "; }";
dbg() << " virtual int magic() const override { return " << endpoint.magic << "; }";
dbg() << " static String static_name() { return \"" << endpoint.name << "\"; };";
dbg() << " virtual String name() const override { return \"" << endpoint.name << "\"; };"; dbg() << " virtual String name() const override { return \"" << endpoint.name << "\"; };";
dbg() << " static OwnPtr<IMessage> decode_message(const ByteBuffer& buffer)"; dbg() << " static OwnPtr<IMessage> decode_message(const ByteBuffer& buffer, size_t& size_in_bytes)";
dbg() << " {"; dbg() << " {";
dbg() << " BufferStream stream(const_cast<ByteBuffer&>(buffer));"; dbg() << " BufferStream stream(const_cast<ByteBuffer&>(buffer));";
dbg() << " i32 message_endpoint_magic = 0;";
dbg() << " stream >> message_endpoint_magic;";
dbg() << " if (message_endpoint_magic != " << endpoint.magic << ")";
dbg() << " return nullptr;";
dbg() << " i32 message_id = 0;"; dbg() << " i32 message_id = 0;";
dbg() << " stream >> message_id;"; dbg() << " stream >> message_id;";
dbg() << " switch (message_id) {"; dbg() << " switch (message_id) {";
for (auto& message : endpoint.messages) { for (auto& message : endpoint.messages) {
auto do_decode_message = [&](const String& name) { auto do_decode_message = [&](const String& name) {
dbg() << " case (int)" << endpoint.name << "::MessageID::" << name << ":"; dbg() << " case (int)" << endpoint.name << "::MessageID::" << name << ":";
dbg() << " return " << endpoint.name << "::" << name << "::decode(stream);"; dbg() << " return " << endpoint.name << "::" << name << "::decode(stream, size_in_bytes);";
}; };
do_decode_message(message.name); do_decode_message(message.name);
if (message.is_synchronous) if (message.is_synchronous)
@ -383,7 +403,7 @@ int main(int argc, char** argv)
#ifdef DEBUG #ifdef DEBUG
for (auto& endpoint : endpoints) { for (auto& endpoint : endpoints) {
dbg() << "Endpoint: '" << endpoint.name << "'"; dbg() << "Endpoint: '" << endpoint.name << "' (magic: " << endpoint.magic << ")";
for (auto& message : endpoint.messages) { for (auto& message : endpoint.messages) {
dbg() << " Message: '" << message.name << "'"; dbg() << " Message: '" << message.name << "'";
dbg() << " Sync: " << message.is_synchronous; dbg() << " Sync: " << message.is_synchronous;

View file

@ -3,7 +3,7 @@
#include <SharedBuffer.h> #include <SharedBuffer.h>
AClientConnection::AClientConnection() AClientConnection::AClientConnection()
: ConnectionNG("/tmp/asportal") : ConnectionNG(*this, "/tmp/asportal")
{ {
} }
@ -76,3 +76,7 @@ int AClientConnection::get_playing_buffer()
{ {
return send_sync<AudioServer::GetPlayingBuffer>()->buffer_id(); return send_sync<AudioServer::GetPlayingBuffer>()->buffer_id();
} }
void AClientConnection::handle(const AudioClient::FinishedPlayingBuffer&)
{
}

View file

@ -1,11 +1,13 @@
#pragma once #pragma once
#include <AudioServer/AudioClientEndpoint.h>
#include <AudioServer/AudioServerEndpoint.h> #include <AudioServer/AudioServerEndpoint.h>
#include <LibCore/CoreIPCClient.h> #include <LibCore/CoreIPCClient.h>
class ABuffer; class ABuffer;
class AClientConnection : public IPC::Client::ConnectionNG<AudioServerEndpoint> { class AClientConnection : public IPC::Client::ConnectionNG<AudioClientEndpoint, AudioServerEndpoint>
, public AudioClientEndpoint {
C_OBJECT(AClientConnection) C_OBJECT(AClientConnection)
public: public:
AClientConnection(); AClientConnection();
@ -26,4 +28,7 @@ public:
void set_paused(bool paused); void set_paused(bool paused);
void clear_buffer(bool paused = false); void clear_buffer(bool paused = false);
private:
virtual void handle(const AudioClient::FinishedPlayingBuffer&) override;
}; };

View file

@ -246,11 +246,12 @@ namespace Client {
int m_my_client_id { -1 }; int m_my_client_id { -1 };
}; };
template<typename Endpoint> template<typename LocalEndpoint, typename PeerEndpoint>
class ConnectionNG : public CObject { class ConnectionNG : public CObject {
public: public:
ConnectionNG(const StringView& address) ConnectionNG(LocalEndpoint& local_endpoint, const StringView& address)
: m_connection(CLocalSocket::construct(this)) : m_local_endpoint(local_endpoint)
, m_connection(CLocalSocket::construct(this))
, m_notifier(CNotifier::construct(m_connection->fd(), CNotifier::Read, this)) , m_notifier(CNotifier::construct(m_connection->fd(), CNotifier::Read, this))
{ {
// We want to rate-limit our clients // We want to rate-limit our clients
@ -312,8 +313,7 @@ namespace Client {
} }
ASSERT(rc > 0); ASSERT(rc > 0);
ASSERT(FD_ISSET(m_connection->fd(), &rfds)); ASSERT(FD_ISSET(m_connection->fd(), &rfds));
bool success = drain_messages_from_server(); if (!drain_messages_from_server())
if (!success)
return nullptr; return nullptr;
for (ssize_t i = 0; i < m_unprocessed_messages.size(); ++i) { for (ssize_t i = 0; i < m_unprocessed_messages.size(); ++i) {
if (m_unprocessed_messages[i]->id() == MessageType::static_message_id()) { if (m_unprocessed_messages[i]->id() == MessageType::static_message_id()) {
@ -358,30 +358,42 @@ namespace Client {
private: private:
bool drain_messages_from_server() bool drain_messages_from_server()
{ {
Vector<u8> bytes;
for (;;) { for (;;) {
u8 buffer[4096]; u8 buffer[4096];
ssize_t nread = recv(m_connection->fd(), buffer, sizeof(buffer), MSG_DONTWAIT); ssize_t nread = recv(m_connection->fd(), buffer, sizeof(buffer), MSG_DONTWAIT);
if (nread < 0) { if (nread < 0) {
if (errno == EAGAIN) { if (errno == EAGAIN)
return true; break;
}
perror("read"); perror("read");
exit(1); exit(1);
return false; return false;
} }
if (nread == 0) { if (nread == 0) {
dbg() << "EOF on IPC fd"; dbg() << "EOF on IPC fd";
// FIXME: Dying is definitely not always appropriate!
exit(1); exit(1);
return false; return false;
} }
bytes.append(buffer, nread);
auto message = Endpoint::decode_message(ByteBuffer::wrap(buffer, sizeof(buffer)));
ASSERT(message);
m_unprocessed_messages.append(move(message));
} }
size_t decoded_bytes = 0;
for (size_t index = 0; index < (size_t)bytes.size(); index += decoded_bytes) {
auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index);
if (auto message = LocalEndpoint::decode_message(remaining_bytes, decoded_bytes)) {
m_local_endpoint.handle(*message);
} else if (auto message = PeerEndpoint::decode_message(remaining_bytes, decoded_bytes)) {
m_unprocessed_messages.append(move(message));
} else {
ASSERT_NOT_REACHED();
}
ASSERT(decoded_bytes);
}
return true;
} }
LocalEndpoint& m_local_endpoint;
RefPtr<CLocalSocket> m_connection; RefPtr<CLocalSocket> m_connection;
RefPtr<CNotifier> m_notifier; RefPtr<CNotifier> m_notifier;
Vector<OwnPtr<IMessage>> m_unprocessed_messages; Vector<OwnPtr<IMessage>> m_unprocessed_messages;

View file

@ -256,7 +256,7 @@ namespace Server {
, m_client_id(client_id) , m_client_id(client_id)
{ {
add_child(socket); add_child(socket);
m_socket->on_ready_to_read = [this] { drain_client(); }; m_socket->on_ready_to_read = [this] { drain_messages_from_client(); };
} }
virtual ~ConnectionNG() override virtual ~ConnectionNG() override
@ -287,15 +287,16 @@ namespace Server {
ASSERT(nwritten == buffer.size()); ASSERT(nwritten == buffer.size());
} }
void drain_client() void drain_messages_from_client()
{ {
unsigned messages_received = 0; Vector<u8> bytes;
for (;;) { for (;;) {
u8 buffer[4096]; u8 buffer[4096];
ssize_t nread = recv(m_socket->fd(), buffer, sizeof(buffer), MSG_DONTWAIT); ssize_t nread = recv(m_socket->fd(), buffer, sizeof(buffer), MSG_DONTWAIT);
if (nread == 0 || (nread == -1 && errno == EAGAIN)) { if (nread == 0 || (nread == -1 && errno == EAGAIN)) {
if (!messages_received) { if (bytes.is_empty()) {
CEventLoop::current().post_event(*this, make<DisconnectedEvent>(client_id())); CEventLoop::current().post_event(*this, make<DisconnectedEvent>(client_id()));
return;
} }
break; break;
} }
@ -303,17 +304,21 @@ namespace Server {
perror("recv"); perror("recv");
ASSERT_NOT_REACHED(); ASSERT_NOT_REACHED();
} }
auto message = m_endpoint.decode_message(ByteBuffer::wrap(buffer, nread)); bytes.append(buffer, nread);
}
size_t decoded_bytes = 0;
for (size_t index = 0; index < (size_t)bytes.size(); index += decoded_bytes) {
auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index);
auto message = Endpoint::decode_message(remaining_bytes, decoded_bytes);
if (!message) { if (!message) {
dbg() << "drain_client: Endpoint didn't recognize message"; dbg() << "drain_messages_from_client: Endpoint didn't recognize message";
did_misbehave(); did_misbehave();
return; return;
} }
++messages_received; if (auto response = m_endpoint.handle(*message))
auto response = m_endpoint.handle(*message);
if (response)
post_message(*response); post_message(*response);
ASSERT(decoded_bytes);
} }
} }

View file

@ -13,6 +13,7 @@ class IEndpoint {
public: public:
virtual ~IEndpoint(); virtual ~IEndpoint();
virtual int magic() const = 0;
virtual String name() const = 0; virtual String name() const = 0;
virtual OwnPtr<IMessage> handle(const IMessage&) = 0; virtual OwnPtr<IMessage> handle(const IMessage&) = 0;

View file

@ -7,6 +7,7 @@ class IMessage {
public: public:
virtual ~IMessage(); virtual ~IMessage();
virtual int endpoint_magic() const = 0;
virtual int id() const = 0; virtual int id() const = 0;
virtual String name() const = 0; virtual String name() const = 0;
virtual ByteBuffer encode() const = 0; virtual ByteBuffer encode() const = 0;

View file

@ -1,10 +1,9 @@
#include "ASClientConnection.h" #include "ASClientConnection.h"
#include "ASMixer.h" #include "ASMixer.h"
#include "AudioClientEndpoint.h"
#include <LibAudio/ABuffer.h> #include <LibAudio/ABuffer.h>
#include <LibCore/CEventLoop.h> #include <LibCore/CEventLoop.h>
#include <SharedBuffer.h> #include <SharedBuffer.h>
#include <errno.h> #include <errno.h>
#include <stdio.h> #include <stdio.h>
#include <sys/socket.h> #include <sys/socket.h>
@ -30,10 +29,9 @@ void ASClientConnection::die()
s_connections.remove(client_id()); s_connections.remove(client_id());
} }
void ASClientConnection::did_finish_playing_buffer(Badge<ASMixer>, int buffer_id) void ASClientConnection::did_finish_playing_buffer(Badge<ASBufferQueue>, int buffer_id)
{ {
(void)buffer_id; post_message(AudioClient::FinishedPlayingBuffer(buffer_id));
//post_message(AudioClient::FinishedPlayingBuffer(buffer_id));
} }
OwnPtr<AudioServer::GreetResponse> ASClientConnection::handle(const AudioServer::Greet& message) OwnPtr<AudioServer::GreetResponse> ASClientConnection::handle(const AudioServer::Greet& message)

View file

@ -13,7 +13,7 @@ class ASClientConnection final : public IPC::Server::ConnectionNG<AudioServerEnd
public: public:
explicit ASClientConnection(CLocalSocket&, int client_id, ASMixer& mixer); explicit ASClientConnection(CLocalSocket&, int client_id, ASMixer& mixer);
~ASClientConnection() override; ~ASClientConnection() override;
void did_finish_playing_buffer(Badge<ASMixer>, int buffer_id); void did_finish_playing_buffer(Badge<ASBufferQueue>, int buffer_id);
virtual void die() override; virtual void die() override;

View file

@ -1,5 +1,6 @@
#pragma once #pragma once
#include "ASClientConnection.h"
#include <AK/ByteBuffer.h> #include <AK/ByteBuffer.h>
#include <AK/NonnullRefPtrVector.h> #include <AK/NonnullRefPtrVector.h>
#include <AK/Queue.h> #include <AK/Queue.h>
@ -36,6 +37,7 @@ public:
++m_played_samples; ++m_played_samples;
if (m_position >= m_current->sample_count()) { if (m_position >= m_current->sample_count()) {
m_client->did_finish_playing_buffer({}, m_current->shared_buffer_id());
m_current = nullptr; m_current = nullptr;
m_position = 0; m_position = 0;
} }
@ -61,8 +63,10 @@ public:
int get_remaining_samples() const { return m_remaining_samples; } int get_remaining_samples() const { return m_remaining_samples; }
int get_played_samples() const { return m_played_samples; } int get_played_samples() const { return m_played_samples; }
int get_playing_buffer() const { int get_playing_buffer() const
if(m_current) return m_current->shared_buffer_id(); {
if (m_current)
return m_current->shared_buffer_id();
return -1; return -1;
} }

View file

@ -1,4 +1,4 @@
endpoint AudioClient endpoint AudioClient = 82
{ {
FinishedPlayingBuffer(i32 buffer_id) =| FinishedPlayingBuffer(i32 buffer_id) =|
} }

View file

@ -1,4 +1,4 @@
endpoint AudioServer endpoint AudioServer = 85
{ {
// Basic protocol // Basic protocol
Greet(i32 client_pid) => (i32 server_pid, i32 client_id) Greet(i32 client_pid) => (i32 server_pid, i32 client_id)

View file

@ -13,11 +13,14 @@ DEFINES += -DUSERLAND
all: $(APP) all: $(APP)
*.cpp: AudioServerEndpoint.h *.cpp: AudioServerEndpoint.h AudioClientEndpoint.h
AudioServerEndpoint.h: AudioServer.ipc AudioServerEndpoint.h: AudioServer.ipc
@echo "IPC $<"; $(IPCCOMPILER) $< > $@ @echo "IPC $<"; $(IPCCOMPILER) $< > $@
AudioClientEndpoint.h: AudioClient.ipc
@echo "IPC $<"; $(IPCCOMPILER) $< > $@
$(APP): $(OBJS) $(APP): $(OBJS)
$(LD) -o $(APP) $(LDFLAGS) $(OBJS) -lc -lcore -lipc -lthread -lpthread $(LD) -o $(APP) $(LDFLAGS) $(OBJS) -lc -lcore -lipc -lthread -lpthread
@ -27,5 +30,5 @@ $(APP): $(OBJS)
-include $(OBJS:%.o=%.d) -include $(OBJS:%.o=%.d)
clean: clean:
@echo "CLEAN"; rm -f $(APP) $(OBJS) *.d AudioServerEndpoint.h @echo "CLEAN"; rm -f $(APP) $(OBJS) *.d AudioServerEndpoint.h AudioClientEndpoint.h