ladybird/Libraries/LibThreading/WorkerThread.h

172 lines
5.2 KiB
C++

/*
* Copyright (c) 2022, Gregory Bertilson <zaggy1024@gmail.com>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#pragma once
#include <AK/Debug.h>
#include <AK/Variant.h>
#include <LibThreading/ConditionVariable.h>
#include <LibThreading/Mutex.h>
#include <LibThreading/Thread.h>
namespace Threading {
// Macro to allow single-line logging prints with fields that only exist in debug mode.
// When WORKER_THREAD_DEBUG is enabled, WORKER_LOG(...) forwards its arguments to dbgln().
// When it is disabled, the macro expands to nothing, so its arguments are discarded by the
// preprocessor. That is what lets call sites reference debug-only members such as m_id,
// which is only declared under WORKER_THREAD_DEBUG.
// Note: `args...` (named variadic parameter) and `({ ... })` (statement expression) are GNU
// extensions.
#if WORKER_THREAD_DEBUG
# define WORKER_LOG(args...) ({ dbgln(args); })
#else
# define WORKER_LOG(args...)
#endif
// Runs tasks one at a time on a dedicated background thread.
// A task is queued with start_task(), and its outcome is collected with
// wait_until_task_is_finished(). Destroying the WorkerThread stops and joins the thread.
// All shared state (m_state, m_stop) is only accessed with m_mutex held.
template<typename ErrorType>
class WorkerThread {
    enum class State {
        Idle,
        Working,
        Stopped,
    };
    using WorkerTask = Function<ErrorOr<void, ErrorType>()>;
    // m_state holds one of three things:
    // - a plain State;
    // - a task that has been queued but not yet picked up by the worker;
    // - the error returned by the most recently finished task.
    using WorkerState = Variant<State, WorkerTask, ErrorType>;

public:
    // Allocates a WorkerThread and starts a background thread named `name` that runs queued
    // tasks one at a time. Returns an error if the allocation or the thread creation fails.
    static ErrorOr<NonnullOwnPtr<WorkerThread>> create(StringView name)
    {
        auto worker_thread = TRY(adopt_nonnull_own_or_enomem(new (nothrow) WorkerThread()));
        // Capturing `self` by reference is safe: the destructor joins this thread before the
        // WorkerThread is freed.
        worker_thread->m_thread = TRY(Threading::Thread::try_create([&self = *worker_thread]() {
            WORKER_LOG("Starting worker loop {}", self.m_id);
            // The mutex is held whenever the loop inspects m_stop or m_state. It is released
            // only while waiting on the condition variable and while a task is running.
            self.m_mutex.lock();
            while (true) {
                if (self.m_stop) {
                    WORKER_LOG("Exiting {}", self.m_id);
                    self.m_state = State::Stopped;
                    self.m_condition.broadcast();
                    self.m_mutex.unlock();
                    return 0;
                }
                if (!self.m_state.template has<WorkerTask>()) {
                    WORKER_LOG("Awaiting new task in {}...", self.m_id);
                    self.m_condition.wait();
                    WORKER_LOG("Worker thread awoken in {}", self.m_id);
                    // Re-check both conditions. The wakeup may be a stop request, a new task,
                    // or spurious.
                    continue;
                }
                auto task = move(self.m_state.template get<WorkerTask>());
                self.m_state = State::Working;
                self.m_mutex.unlock();
                WORKER_LOG("Starting task on {}", self.m_id);
                auto result = task();
                self.m_mutex.lock();
                if (result.is_error()) {
                    WORKER_LOG("Task finished on {} with error", self.m_id);
                    self.m_state = result.release_error();
                } else {
                    WORKER_LOG("Task finished successfully on {}", self.m_id);
                    self.m_state = State::Idle;
                }
                self.m_condition.broadcast();
                // Deliberately loop back instead of waiting here. A stop requested while the
                // task was running broadcast at a moment when nobody was waiting. If m_stop is
                // not re-checked, this thread and the destructor both wait forever.
            }
        },
            name));
        worker_thread->m_thread->start();
        return worker_thread;
    }

    // Asks the worker to stop, waits until it acknowledges, then joins the thread.
    // A task that is already running is allowed to finish. A task that is queued but not
    // yet started is discarded, because the worker checks m_stop before picking up a task.
    ~WorkerThread()
    {
        // If Thread::try_create() fails in create(), this object is destroyed without a
        // thread. Nobody would ever set State::Stopped, so waiting for it would hang forever.
        if (!m_thread)
            return;
        m_mutex.lock();
        m_stop = true;
        m_condition.broadcast();
        while (!is_in_state(State::Stopped))
            m_condition.wait();
        m_mutex.unlock();
        (void)m_thread->join();
        WORKER_LOG("Worker thread {} joined successfully", m_id);
    }

    // Queues `task` to run on the worker thread.
    // Returns true if the task was queued. That happens when the worker is idle, or when it
    // holds the error of a previous task; that error is then discarded.
    // Returns false, leaving `task` untouched, when a task is already queued or running.
    // Must not be called once the worker has stopped (VERIFY).
    bool start_task(WorkerTask&& task)
    {
        m_mutex.lock();
        VERIFY(!is_in_state(State::Stopped));
        bool start_work = false;
        if (is_in_state(State::Idle)) {
            start_work = true;
        } else if (m_state.template has<ErrorType>()) {
            WORKER_LOG("Starting task and ignoring previous error: {}", m_state.template get<ErrorType>().string_literal());
            start_work = true;
        }
        if (start_work) {
            WORKER_LOG("Queuing task on {}", m_id);
            m_state = move(task);
            m_condition.broadcast();
        }
        m_mutex.unlock();
        return start_work;
    }

    // Blocks while a task is queued or running.
    // If the finished task failed, returns its error and resets the worker to Idle.
    // Otherwise returns success.
    ErrorOr<void, ErrorType> wait_until_task_is_finished()
    {
        WORKER_LOG("Waiting for task to finish on {}...", m_id);
        m_mutex.lock();
        while (m_state.template has<WorkerTask>() || is_in_state(State::Working))
            m_condition.wait();

        if (m_state.template has<ErrorType>()) {
            auto error = move(m_state.template get<ErrorType>());
            m_state = State::Idle;
            m_mutex.unlock();
            WORKER_LOG("Finished waiting with error on {}: {}", m_id, error.string_literal());
            return error;
        }

        m_mutex.unlock();
        WORKER_LOG("Finished waiting on {}", m_id);
        return {};
    }

private:
#if WORKER_THREAD_DEBUG
    static inline size_t current_id = 0;
#endif

    WorkerThread()
        : m_condition(m_mutex)
#if WORKER_THREAD_DEBUG
        , m_id(current_id++)
#endif
    {
    }
    WorkerThread(WorkerThread const&) = delete;
    WorkerThread(WorkerThread&&) = delete;

    // Returns true if m_state holds the plain State `state`. Must be called with the mutex
    // locked.
    bool is_in_state(State state) const
    {
        return m_state.template has<State>() && m_state.template get<State>() == state;
    }

    RefPtr<Threading::Thread> m_thread;
    Threading::Mutex m_mutex;
    // Bound to m_mutex. Broadcast whenever m_state or m_stop changes.
    Threading::ConditionVariable m_condition;
    WorkerState m_state { State::Idle };
    // Set by the destructor to ask the worker loop to exit.
    bool m_stop { false };
#if WORKER_THREAD_DEBUG
    size_t m_id;
#endif
};
#undef WORKER_LOG
}