diff --git a/Meta/Lagom/CMakeLists.txt b/Meta/Lagom/CMakeLists.txt index 4c8725a4cf6..3f3f144362e 100644 --- a/Meta/Lagom/CMakeLists.txt +++ b/Meta/Lagom/CMakeLists.txt @@ -662,6 +662,7 @@ if (BUILD_LAGOM) # LibCore if ((LINUX OR APPLE) AND NOT EMSCRIPTEN) lagom_test(../../Tests/LibCore/TestLibCoreFileWatcher.cpp) + lagom_test(../../Tests/LibCore/TestLibCorePromise.cpp LIBS LibThreading) endif() # RegexLibC test POSIX and contains many Serenity extensions diff --git a/Tests/LibCore/CMakeLists.txt b/Tests/LibCore/CMakeLists.txt index afbc6a364ea..d0455d2c043 100644 --- a/Tests/LibCore/CMakeLists.txt +++ b/Tests/LibCore/CMakeLists.txt @@ -12,6 +12,7 @@ foreach(source IN LISTS TEST_SOURCES) serenity_test("${source}" LibCore) endforeach() +target_link_libraries(TestLibCorePromise PRIVATE LibThreading) # NOTE: Required because of the LocalServer tests target_link_libraries(TestLibCoreStream PRIVATE LibThreading) target_link_libraries(TestLibCoreSharedSingleProducerCircularQueue PRIVATE LibThreading) diff --git a/Tests/LibCore/TestLibCorePromise.cpp b/Tests/LibCore/TestLibCorePromise.cpp index 6e8f969a88c..ac691474e2d 100644 --- a/Tests/LibCore/TestLibCorePromise.cpp +++ b/Tests/LibCore/TestLibCorePromise.cpp @@ -6,7 +6,10 @@ #include #include +#include #include +#include +#include TEST_CASE(promise_await_async_event) { @@ -57,3 +60,108 @@ TEST_CASE(promise_chain_handlers) EXPECT(resolved); EXPECT(!rejected); } + +TEST_CASE(threaded_promise_instantly_resolved) +{ + Core::EventLoop loop; + + bool resolved = false; + bool rejected = true; + Optional thread_id; + + auto promise = Core::ThreadedPromise::create(); + + auto thread = Threading::Thread::construct([&, promise] { + thread_id = pthread_self(); + promise->resolve(42); + return 0; + }); + thread->start(); + + promise + ->when_resolved([&](int result) { + EXPECT(thread_id.has_value()); + EXPECT(pthread_equal(thread_id.value(), pthread_self())); + resolved = true; + rejected = false; + EXPECT_EQ(result, 42); + }) + .when_rejected([](Error&&) { + VERIFY_NOT_REACHED(); + }); + + promise->await(); + EXPECT(promise->has_completed()); + EXPECT(resolved); + EXPECT(!rejected); + MUST(thread->join()); +} + +TEST_CASE(threaded_promise_resolved_later) +{ + Core::EventLoop loop; + + bool unblock_thread = false; + bool resolved = false; + bool rejected = true; + Optional thread_id; + + auto promise = Core::ThreadedPromise::create(); + + auto thread = Threading::Thread::construct([&, promise] { + thread_id = pthread_self(); + while (!unblock_thread) + usleep(500); + promise->resolve(42); + return 0; + }); + thread->start(); + + promise + ->when_resolved([&]() { + EXPECT(thread_id.has_value()); + EXPECT(pthread_equal(thread_id.value(), pthread_self())); + EXPECT(unblock_thread); + resolved = true; + rejected = false; + }) + .when_rejected([](Error&&) { + VERIFY_NOT_REACHED(); + }); + + Core::EventLoop::current().deferred_invoke([&]() { unblock_thread = true; }); + + promise->await(); + EXPECT(promise->has_completed()); + EXPECT(unblock_thread); + EXPECT(resolved); + EXPECT(!rejected); + MUST(thread->join()); +} + +TEST_CASE(threaded_promise_synchronously_resolved) +{ + Core::EventLoop loop; + + bool resolved = false; + bool rejected = true; + auto thread_id = pthread_self(); + + auto promise = Core::ThreadedPromise::create(); + promise->resolve(1337); + + promise + ->when_resolved([&]() { + EXPECT(pthread_equal(thread_id, pthread_self())); + resolved = true; + rejected = false; + }) + .when_rejected([](Error&&) { + VERIFY_NOT_REACHED(); + }); + + promise->await(); + EXPECT(promise->has_completed()); + EXPECT(resolved); + EXPECT(!rejected); +} diff --git a/Userland/Libraries/LibCore/EventLoop.cpp b/Userland/Libraries/LibCore/EventLoop.cpp index 7d721341eee..ee889a98cf9 100644 --- a/Userland/Libraries/LibCore/EventLoop.cpp +++ b/Userland/Libraries/LibCore/EventLoop.cpp @@ -17,12 +17,17 @@ namespace Core { namespace { -Vector& event_loop_stack() +OwnPtr>& event_loop_stack_uninitialized() { thread_local OwnPtr> s_event_loop_stack = nullptr; - if (s_event_loop_stack == nullptr) - s_event_loop_stack = make>(); - return *s_event_loop_stack; + return s_event_loop_stack; +} +Vector& event_loop_stack() +{ + auto& the_stack = event_loop_stack_uninitialized(); + if (the_stack == nullptr) + the_stack = make>(); + return *the_stack; } } @@ -41,6 +46,12 @@ EventLoop::~EventLoop() } } +bool EventLoop::is_running() +{ + auto& stack = event_loop_stack_uninitialized(); + return stack != nullptr && !stack->is_empty(); +} + EventLoop& EventLoop::current() { return event_loop_stack().last(); diff --git a/Userland/Libraries/LibCore/EventLoop.h b/Userland/Libraries/LibCore/EventLoop.h index 4d093400910..e43a22838bb 100644 --- a/Userland/Libraries/LibCore/EventLoop.h +++ b/Userland/Libraries/LibCore/EventLoop.h @@ -92,6 +92,7 @@ public: }; static void notify_forked(ForkEvent); + static bool is_running(); static EventLoop& current(); EventLoopImplementation& impl() { return *m_impl; } diff --git a/Userland/Libraries/LibCore/Forward.h b/Userland/Libraries/LibCore/Forward.h index 9272a542d7e..f7c36ebd04e 100644 --- a/Userland/Libraries/LibCore/Forward.h +++ b/Userland/Libraries/LibCore/Forward.h @@ -36,6 +36,8 @@ class ProcessStatisticsReader; class Socket; template class Promise; +template +class ThreadedPromise; class SocketAddress; class TCPServer; class TCPSocket; diff --git a/Userland/Libraries/LibCore/ThreadedPromise.h b/Userland/Libraries/LibCore/ThreadedPromise.h new file mode 100644 index 00000000000..fe2dc2a26f4 --- /dev/null +++ b/Userland/Libraries/LibCore/ThreadedPromise.h @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2021, Kyle Pereira + * Copyright (c) 2022, kleines Filmröllchen + * Copyright (c) 2021-2023, Ali Mohammad Pur + * Copyright (c) 2023, Gregory Bertilson + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace Core { + +template +class ThreadedPromise + : public AtomicRefCounted> { +public: + static NonnullRefPtr> create() + { + return adopt_ref(*new ThreadedPromise()); + } + + using ResultType = Conditional, Empty, TResult>; + using ErrorType = TError; + + void resolve(ResultType&& result) + { + when_error_handler_is_ready([self = NonnullRefPtr(*this), result = move(result)]() mutable { + if (self->m_resolution_handler) { + auto handler_result = self->m_resolution_handler(forward(result)); + if (handler_result.is_error()) + self->m_rejection_handler(handler_result.release_error()); + self->m_has_completed = true; + } + }); + } + void resolve() + requires IsSame + { + resolve(Empty()); + } + + void reject(ErrorType&& error) + { + when_error_handler_is_ready([this, error = move(error)]() mutable { + m_rejection_handler(forward(error)); + m_has_completed = true; + }); + } + void reject(ErrorType const& error) + requires IsTriviallyCopyable + { + reject(ErrorType(error)); + } + + bool has_completed() + { + Threading::MutexLocker locker { m_mutex }; + return m_has_completed; + } + + void await() + { + while (!has_completed()) + Core::EventLoop::current().pump(EventLoop::WaitMode::PollForEvents); + } + + // Set the callback to be called when the promise is resolved. A rejection callback + // must also be provided before any callback will be called. + template, ResultType&&> ResolvedHandler> + ThreadedPromise& when_resolved(ResolvedHandler handler) + { + Threading::MutexLocker locker { m_mutex }; + VERIFY(!m_resolution_handler); + m_resolution_handler = move(handler); + return *this; + } + + template ResolvedHandler> + ThreadedPromise& when_resolved(ResolvedHandler handler) + { + return when_resolved([handler = move(handler)](ResultType&& result) -> ErrorOr { + handler(forward(result)); + return {}; + }); + } + + template> ResolvedHandler> + ThreadedPromise& when_resolved(ResolvedHandler handler) + { + return when_resolved([handler = move(handler)](ResultType&&) -> ErrorOr { + return handler(); + }); + } + + template ResolvedHandler> + ThreadedPromise& when_resolved(ResolvedHandler handler) + { + return when_resolved([handler = move(handler)](ResultType&&) -> ErrorOr { + handler(); + return {}; + }); + } + + // Set the callback to be called when the promise is rejected. Setting this callback + // will cause the promise fulfillment to be ready to be handled. + template RejectedHandler> + ThreadedPromise& when_rejected(RejectedHandler when_rejected = [](ErrorType&) {}) + { + Threading::MutexLocker locker { m_mutex }; + VERIFY(!m_rejection_handler); + m_rejection_handler = move(when_rejected); + return *this; + } + + template>, ResultType&&> ChainedResolution> + NonnullRefPtr> chain_promise(ChainedResolution chained_resolution) + { + auto new_promise = ThreadedPromise::create(); + when_resolved([=, chained_resolution = move(chained_resolution)](ResultType&& result) mutable -> ErrorOr { + chained_resolution(forward(result)) + ->when_resolved([=](auto&& new_result) { new_promise->resolve(move(new_result)); }) + .when_rejected([=](ErrorType&& error) { new_promise->reject(move(error)); }); + return {}; + }); + when_rejected([=](ErrorType&& error) { new_promise->reject(move(error)); }); + return new_promise; + } + + template, ResultType&&> MappingFunction> + NonnullRefPtr> map(MappingFunction mapping_function) + { + auto new_promise = ThreadedPromise::create(); + when_resolved([=, mapping_function = move(mapping_function)](ResultType&& result) -> ErrorOr { + new_promise->resolve(TRY(mapping_function(forward(result)))); + return {}; + }); + when_rejected([=](ErrorType&& error) { new_promise->reject(move(error)); }); + return new_promise; + } + +private: + template + static void deferred_handler_check(NonnullRefPtr self, F&& function) + { + Threading::MutexLocker locker { self->m_mutex }; + if (self->m_rejection_handler) { + function(); + return; + } + EventLoop::current().deferred_invoke([self, function = forward(function)]() mutable { + deferred_handler_check(self, move(function)); + }); + } + + template + void when_error_handler_is_ready(F function) + { + if (EventLoop::is_running()) { + deferred_handler_check(NonnullRefPtr(*this), move(function)); + } else { + // NOTE: Handlers should always be set almost immediately, so we can expect this + // to spin extremely briefly. Therefore, sleeping the thread should not be + // necessary. + while (true) { + Threading::MutexLocker locker { m_mutex }; + if (m_rejection_handler) + break; + } + VERIFY(m_rejection_handler); + function(); + } + } + + ThreadedPromise() = default; + ThreadedPromise(Object* parent) + : Object(parent) + { + } + + Function(ResultType&&)> m_resolution_handler; + Function m_rejection_handler; + Threading::Mutex m_mutex; + bool m_has_completed; +}; + +}