Browse Source

ProtocolServer+LibProtocol: Introduce a server for handling downloads

This patch adds ProtocolServer, a server that handles network requests
on behalf of its clients. The first protocol implemented is HTTP.

The idea here is to use a plug-in architecture where any number of
protocols can be added and implemented without having to mess around
with each client program that wants to use the protocol.

A simple client API is provided through LibProtocol::Client. :^)
Andreas Kling 5 years ago
parent
commit
fd4349a9f2

+ 1 - 0
Kernel/build-root-filesystem.sh

@@ -107,6 +107,7 @@ cp ../Servers/WindowServer/WindowServer mnt/bin/WindowServer
 cp ../Servers/AudioServer/AudioServer mnt/bin/AudioServer
 cp ../Servers/AudioServer/AudioServer mnt/bin/AudioServer
 cp ../Servers/TTYServer/TTYServer mnt/bin/TTYServer
 cp ../Servers/TTYServer/TTYServer mnt/bin/TTYServer
 cp ../Servers/TelnetServer/TelnetServer mnt/bin/TelnetServer
 cp ../Servers/TelnetServer/TelnetServer mnt/bin/TelnetServer
+cp ../Servers/ProtocolServer/ProtocolServer mnt/bin/ProtocolServer
 cp ../Shell/Shell mnt/bin/Shell
 cp ../Shell/Shell mnt/bin/Shell
 echo "done"
 echo "done"
 
 

+ 2 - 0
Kernel/makeall.sh

@@ -31,6 +31,7 @@ build_targets="$build_targets ../Libraries/LibPthread"
 # Build IPC servers before their client code to ensure the IPC definitions are available.
 # Build IPC servers before their client code to ensure the IPC definitions are available.
 build_targets="$build_targets ../Servers/AudioServer"
 build_targets="$build_targets ../Servers/AudioServer"
 build_targets="$build_targets ../Servers/LookupServer"
 build_targets="$build_targets ../Servers/LookupServer"
+build_targets="$build_targets ../Servers/ProtocolServer"
 
 
 build_targets="$build_targets ../AK"
 build_targets="$build_targets ../AK"
 
 
@@ -42,6 +43,7 @@ build_targets="$build_targets ../Libraries/LibM"
 build_targets="$build_targets ../Libraries/LibPCIDB"
 build_targets="$build_targets ../Libraries/LibPCIDB"
 build_targets="$build_targets ../Libraries/LibVT"
 build_targets="$build_targets ../Libraries/LibVT"
 build_targets="$build_targets ../Libraries/LibMarkdown"
 build_targets="$build_targets ../Libraries/LibMarkdown"
+build_targets="$build_targets ../Libraries/LibProtocol"
 
 
 build_targets="$build_targets ../Applications/About"
 build_targets="$build_targets ../Applications/About"
 build_targets="$build_targets ../Applications/Calculator"
 build_targets="$build_targets ../Applications/Calculator"

+ 45 - 0
Libraries/LibProtocol/Client.cpp

@@ -0,0 +1,45 @@
+#include <LibProtocol/Client.h>
+#include <SharedBuffer.h>
+
+namespace LibProtocol {
+
+Client::Client()
+    : ConnectionNG(*this, "/tmp/psportal")
+{
+}
+
+void Client::handshake()
+{
+    auto response = send_sync<ProtocolServer::Greet>(getpid());
+    set_server_pid(response->server_pid());
+    set_my_client_id(response->client_id());
+}
+
+bool Client::is_supported_protocol(const String& protocol)
+{
+    return send_sync<ProtocolServer::IsSupportedProtocol>(protocol)->supported();
+}
+
+i32 Client::start_download(const String& url)
+{
+    return send_sync<ProtocolServer::StartDownload>(url)->download_id();
+}
+
+bool Client::stop_download(i32 download_id)
+{
+    return send_sync<ProtocolServer::StopDownload>(download_id)->success();
+}
+
+void Client::handle(const ProtocolClient::DownloadFinished& message)
+{
+    if (on_download_finish)
+        on_download_finish(message.download_id(), message.success());
+}
+
+void Client::handle(const ProtocolClient::DownloadProgress& message)
+{
+    if (on_download_progress)
+        on_download_progress(message.download_id(), message.total_size(), message.downloaded_size());
+}
+
+}

+ 29 - 0
Libraries/LibProtocol/Client.h

@@ -0,0 +1,29 @@
+#pragma once
+
+#include <LibCore/CoreIPCClient.h>
+#include <ProtocolServer/ProtocolClientEndpoint.h>
+#include <ProtocolServer/ProtocolServerEndpoint.h>
+
+namespace LibProtocol {
+
+class Client : public IPC::Client::ConnectionNG<ProtocolClientEndpoint, ProtocolServerEndpoint>
+    , public ProtocolClientEndpoint {
+    C_OBJECT(Client)
+public:
+    Client();
+
+    virtual void handshake() override;
+
+    bool is_supported_protocol(const String&);
+    i32 start_download(const String& url);
+    bool stop_download(i32 download_id);
+
+    Function<void(i32 download_id, bool success)> on_download_finish;
+    Function<void(i32 download_id, u64 total_size, u64 downloaded_size)> on_download_progress;
+
+private:
+    virtual void handle(const ProtocolClient::DownloadProgress&) override;
+    virtual void handle(const ProtocolClient::DownloadFinished&) override;
+};
+
+}

+ 20 - 0
Libraries/LibProtocol/Makefile

@@ -0,0 +1,20 @@
+include ../../Makefile.common
+
+OBJS = \
+    Client.o
+
+LIBRARY = libprotocol.a
+DEFINES += -DUSERLAND
+
+all: $(LIBRARY)
+
+$(LIBRARY): $(OBJS)
+	@echo "LIB $@"; $(AR) rcs $@ $(OBJS) $(LIBS)
+
+.cpp.o:
+	@echo "CXX $<"; $(CXX) $(CXXFLAGS) -o $@ -c $<
+
+-include $(OBJS:%.o=%.d)
+
+clean:
+	@echo "CLEAN"; rm -f $(LIBRARY) $(OBJS) *.d

+ 1 - 0
Makefile.common

@@ -28,6 +28,7 @@ LDFLAGS = \
     -L$(SERENITY_BASE_DIR)/Libraries/LibMarkdown \
     -L$(SERENITY_BASE_DIR)/Libraries/LibMarkdown \
     -L$(SERENITY_BASE_DIR)/Libraries/LibThread \
     -L$(SERENITY_BASE_DIR)/Libraries/LibThread \
     -L$(SERENITY_BASE_DIR)/Libraries/LibVT \
     -L$(SERENITY_BASE_DIR)/Libraries/LibVT \
+    -L$(SERENITY_BASE_DIR)/Libraries/LibProtocol \
     -L$(SERENITY_BASE_DIR)/Libraries/LibAudio
     -L$(SERENITY_BASE_DIR)/Libraries/LibAudio
 
 
 CLANG_FLAGS = -Wconsumed -m32 -ffreestanding -march=i686
 CLANG_FLAGS = -Wconsumed -m32 -ffreestanding -march=i686

+ 55 - 0
Servers/ProtocolServer/Download.cpp

@@ -0,0 +1,55 @@
+#include <ProtocolServer/Download.h>
+#include <ProtocolServer/PSClientConnection.h>
+
+// FIXME: What about rollover?
+static i32 s_next_id = 1;
+
+static HashMap<i32, RefPtr<Download>>& all_downloads()
+{
+    static HashMap<i32, RefPtr<Download>> map;
+    return map;
+}
+
+Download* Download::find_by_id(i32 id)
+{
+    return all_downloads().get(id).value_or(nullptr);
+}
+
+Download::Download(PSClientConnection& client)
+    : m_id(s_next_id++)
+    , m_client(client.make_weak_ptr())
+{
+    all_downloads().set(m_id, this);
+}
+
+Download::~Download()
+{
+}
+
+void Download::stop()
+{
+    all_downloads().remove(m_id);
+}
+
+void Download::did_finish(bool success)
+{
+    if (!m_client) {
+        dbg() << "Download::did_finish() after the client already disconnected.";
+        return;
+    }
+    m_client->did_finish_download({}, *this, success);
+    all_downloads().remove(m_id);
+}
+
+void Download::did_progress(size_t total_size, size_t downloaded_size)
+{
+    if (!m_client) {
+        // FIXME: We should also abort the download in this situation, I guess!
+        dbg() << "Download::did_progress() after the client already disconnected.";
+        return;
+    }
+    m_total_size = total_size;
+    m_downloaded_size = downloaded_size;
+    m_client->did_progress_download({}, *this);
+}
+

+ 35 - 0
Servers/ProtocolServer/Download.h

@@ -0,0 +1,35 @@
+#pragma once
+
+#include <AK/RefCounted.h>
+#include <AK/URL.h>
+#include <AK/WeakPtr.h>
+
+class PSClientConnection;
+
+class Download : public RefCounted<Download> {
+public:
+    virtual ~Download();
+
+    static Download* find_by_id(i32);
+
+    i32 id() const { return m_id; }
+    URL url() const { return m_url; }
+
+    size_t total_size() const { return m_total_size; }
+    size_t downloaded_size() const { return m_downloaded_size; }
+
+    void stop();
+
+protected:
+    explicit Download(PSClientConnection&);
+
+    void did_finish(bool success);
+    void did_progress(size_t total_size, size_t downloaded_size);
+
+private:
+    i32 m_id;
+    URL m_url;
+    size_t m_total_size { 0 };
+    size_t m_downloaded_size { 0 };
+    WeakPtr<PSClientConnection> m_client;
+};

+ 20 - 0
Servers/ProtocolServer/HttpDownload.cpp

@@ -0,0 +1,20 @@
+#include <LibCore/CHttpJob.h>
+#include <ProtocolServer/HttpDownload.h>
+
+HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<CHttpJob>&& job)
+    : Download(client)
+    , m_job(job)
+{
+    m_job->on_finish = [this](bool success) {
+        did_finish(success);
+    };
+}
+
+HttpDownload::~HttpDownload()
+{
+}
+
+NonnullRefPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, PSClientConnection& client, NonnullRefPtr<CHttpJob>&& job)
+{
+    return adopt(*new HttpDownload(client, move(job)));
+}

+ 18 - 0
Servers/ProtocolServer/HttpDownload.h

@@ -0,0 +1,18 @@
+#pragma once
+
+#include <AK/Badge.h>
+#include <ProtocolServer/Download.h>
+
+class CHttpJob;
+class HttpProtocol;
+
+class HttpDownload final : public Download {
+public:
+    virtual ~HttpDownload() override;
+    static NonnullRefPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, PSClientConnection&, NonnullRefPtr<CHttpJob>&&);
+
+private:
+    explicit HttpDownload(PSClientConnection&, NonnullRefPtr<CHttpJob>&&);
+
+    NonnullRefPtr<CHttpJob> m_job;
+};

+ 24 - 0
Servers/ProtocolServer/HttpProtocol.cpp

@@ -0,0 +1,24 @@
+#include <LibCore/CHttpJob.h>
+#include <LibCore/CHttpRequest.h>
+#include <ProtocolServer/HttpDownload.h>
+#include <ProtocolServer/HttpProtocol.h>
+
+HttpProtocol::HttpProtocol()
+    : Protocol("http")
+{
+}
+
+HttpProtocol::~HttpProtocol()
+{
+}
+
+RefPtr<Download> HttpProtocol::start_download(PSClientConnection& client, const URL& url)
+{
+    CHttpRequest request;
+    request.set_method(CHttpRequest::Method::GET);
+    request.set_url(url);
+    auto job = request.schedule();
+    if (!job)
+        return nullptr;
+    return HttpDownload::create_with_job({}, client, (CHttpJob&)*job);
+}

+ 11 - 0
Servers/ProtocolServer/HttpProtocol.h

@@ -0,0 +1,11 @@
+#pragma once
+
+#include <ProtocolServer/Protocol.h>
+
+class HttpProtocol final : public Protocol {
+public:
+    HttpProtocol();
+    virtual ~HttpProtocol() override;
+
+    virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) override;
+};

+ 35 - 0
Servers/ProtocolServer/Makefile

@@ -0,0 +1,35 @@
+include ../../Makefile.common
+
+OBJS = \
+    PSClientConnection.o \
+    Protocol.o \
+    Download.o \
+    HttpProtocol.o \
+    HttpDownload.o \
+    main.o
+
+APP = ProtocolServer
+
+DEFINES += -DUSERLAND
+
+all: $(APP)
+
+*.cpp: ProtocolServerEndpoint.h ProtocolClientEndpoint.h
+
+ProtocolServerEndpoint.h: ProtocolServer.ipc
+	@echo "IPC $<"; $(IPCCOMPILER) $< > $@
+
+ProtocolClientEndpoint.h: ProtocolClient.ipc
+	@echo "IPC $<"; $(IPCCOMPILER) $< > $@
+
+$(APP): $(OBJS)
+	$(LD) -o $(APP) $(LDFLAGS) $(OBJS) -lc -lcore -lipc -ldraw
+
+.cpp.o:
+	@echo "CXX $<"; $(CXX) $(CXXFLAGS) -o $@ -c $<
+
+-include $(OBJS:%.o=%.d)
+
+clean:
+	@echo "CLEAN"; rm -f $(APP) $(OBJS) *.d ProtocolClientEndpoint.h ProtocolServerEndpoint.h
+

+ 63 - 0
Servers/ProtocolServer/PSClientConnection.cpp

@@ -0,0 +1,63 @@
+#include <ProtocolServer/Download.h>
+#include <ProtocolServer/PSClientConnection.h>
+#include <ProtocolServer/Protocol.h>
+#include <ProtocolServer/ProtocolClientEndpoint.h>
+
+static HashMap<int, RefPtr<PSClientConnection>> s_connections;
+
+PSClientConnection::PSClientConnection(CLocalSocket& socket, int client_id)
+    : ConnectionNG(*this, socket, client_id)
+{
+    s_connections.set(client_id, *this);
+}
+
+PSClientConnection::~PSClientConnection()
+{
+}
+
+void PSClientConnection::die()
+{
+    s_connections.remove(client_id());
+}
+
+OwnPtr<ProtocolServer::IsSupportedProtocolResponse> PSClientConnection::handle(const ProtocolServer::IsSupportedProtocol& message)
+{
+    bool supported = Protocol::find_by_name(message.protocol().to_lowercase());
+    return make<ProtocolServer::IsSupportedProtocolResponse>(supported);
+}
+
+OwnPtr<ProtocolServer::StartDownloadResponse> PSClientConnection::handle(const ProtocolServer::StartDownload& message)
+{
+    URL url(message.url());
+    ASSERT(url.is_valid());
+    auto* protocol = Protocol::find_by_name(url.protocol());
+    ASSERT(protocol);
+    auto download = protocol->start_download(*this, url);
+    return make<ProtocolServer::StartDownloadResponse>(download->id());
+}
+
+OwnPtr<ProtocolServer::StopDownloadResponse> PSClientConnection::handle(const ProtocolServer::StopDownload& message)
+{
+    auto* download = Download::find_by_id(message.download_id());
+    bool success = false;
+    if (download) {
+        download->stop();
+    }
+    return make<ProtocolServer::StopDownloadResponse>(success);
+}
+
+void PSClientConnection::did_finish_download(Badge<Download>, Download& download, bool success)
+{
+    post_message(ProtocolClient::DownloadFinished(download.id(), success));
+}
+
+void PSClientConnection::did_progress_download(Badge<Download>, Download& download)
+{
+    post_message(ProtocolClient::DownloadProgress(download.id(), download.total_size(), download.downloaded_size()));
+}
+
+OwnPtr<ProtocolServer::GreetResponse> PSClientConnection::handle(const ProtocolServer::Greet& message)
+{
+    set_client_pid(message.client_pid());
+    return make<ProtocolServer::GreetResponse>(getpid(), client_id());
+}

+ 26 - 0
Servers/ProtocolServer/PSClientConnection.h

@@ -0,0 +1,26 @@
+#pragma once
+
+#include <AK/Badge.h>
+#include <LibCore/CoreIPCServer.h>
+#include <ProtocolServer/ProtocolServerEndpoint.h>
+
+class Download;
+
+class PSClientConnection final : public IPC::Server::ConnectionNG<ProtocolServerEndpoint>
+    , public ProtocolServerEndpoint {
+    C_OBJECT(PSClientConnection)
+public:
+    explicit PSClientConnection(CLocalSocket&, int client_id);
+    ~PSClientConnection() override;
+
+    virtual void die() override;
+
+    void did_finish_download(Badge<Download>, Download&, bool success);
+    void did_progress_download(Badge<Download>, Download&);
+
+private:
+    virtual OwnPtr<ProtocolServer::GreetResponse> handle(const ProtocolServer::Greet&) override;
+    virtual OwnPtr<ProtocolServer::IsSupportedProtocolResponse> handle(const ProtocolServer::IsSupportedProtocol&) override;
+    virtual OwnPtr<ProtocolServer::StartDownloadResponse> handle(const ProtocolServer::StartDownload&) override;
+    virtual OwnPtr<ProtocolServer::StopDownloadResponse> handle(const ProtocolServer::StopDownload&) override;
+};

+ 23 - 0
Servers/ProtocolServer/Protocol.cpp

@@ -0,0 +1,23 @@
+#include <AK/HashMap.h>
+#include <ProtocolServer/Protocol.h>
+
+static HashMap<String, Protocol*>& all_protocols()
+{
+    static HashMap<String, Protocol*> map;
+    return map;
+}
+
+Protocol* Protocol::find_by_name(const String& name)
+{
+    return all_protocols().get(name).value_or(nullptr);
+}
+
+Protocol::Protocol(const String& name)
+{
+    all_protocols().set(name, this);
+}
+
+Protocol::~Protocol()
+{
+    ASSERT_NOT_REACHED();
+}

+ 23 - 0
Servers/ProtocolServer/Protocol.h

@@ -0,0 +1,23 @@
+#pragma once
+
+#include <AK/RefPtr.h>
+#include <AK/URL.h>
+
+class Download;
+class PSClientConnection;
+
+class Protocol {
+public:
+    virtual ~Protocol();
+
+    const String& name() const { return m_name; }
+    virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) = 0;
+
+    static Protocol* find_by_name(const String&);
+
+protected:
+    explicit Protocol(const String& name);
+
+private:
+    String m_name;
+};

+ 6 - 0
Servers/ProtocolServer/ProtocolClient.ipc

@@ -0,0 +1,6 @@
+endpoint ProtocolClient = 13
+{
+    // Download notifications
+    DownloadProgress(i32 download_id, u32 total_size, u32 downloaded_size) =|
+    DownloadFinished(i32 download_id, bool success) =|
+}

+ 12 - 0
Servers/ProtocolServer/ProtocolServer.ipc

@@ -0,0 +1,12 @@
+endpoint ProtocolServer = 9
+{
+    // Basic protocol
+    Greet(i32 client_pid) => (i32 server_pid, i32 client_id)
+
+    // Test if a specific protocol is supported, e.g "http"
+    IsSupportedProtocol(String protocol) => (bool supported)
+
+    // Download API
+    StartDownload(String url) => (i32 download_id)
+    StopDownload(i32 download_id) => (bool success)
+}

+ 25 - 0
Servers/ProtocolServer/main.cpp

@@ -0,0 +1,25 @@
+#include <LibCore/CEventLoop.h>
+#include <LibCore/CLocalServer.h>
+#include <LibCore/CoreIPCServer.h>
+#include <ProtocolServer/HttpProtocol.h>
+#include <ProtocolServer/PSClientConnection.h>
+
+int main(int, char**)
+{
+    CEventLoop event_loop;
+    (void)*new HttpProtocol;
+    auto server = CLocalServer::construct();
+    unlink("/tmp/psportal");
+    server->listen("/tmp/psportal");
+    server->on_ready_to_accept = [&] {
+        auto client_socket = server->accept();
+        if (!client_socket) {
+            dbg() << "ProtocolServer: accept failed.";
+            return;
+        }
+        static int s_next_client_id = 0;
+        int client_id = ++s_next_client_id;
+        IPC::Server::new_connection_ng_for_client<PSClientConnection>(*client_socket, client_id);
+    };
+    return event_loop.exec();
+}

+ 1 - 0
Servers/SystemServer/main.cpp

@@ -106,6 +106,7 @@ int main(int, char**)
 
 
     signal(SIGCHLD, sigchld_handler);
     signal(SIGCHLD, sigchld_handler);
 
 
+    start_process("/bin/ProtocolServer", {}, lowest_prio);
     start_process("/bin/LookupServer", {}, lowest_prio);
     start_process("/bin/LookupServer", {}, lowest_prio);
     start_process("/bin/WindowServer", {}, highest_prio);
     start_process("/bin/WindowServer", {}, highest_prio);
     start_process("/bin/AudioServer", {}, highest_prio);
     start_process("/bin/AudioServer", {}, highest_prio);