WorkerThread.h 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. /*
  2. * Copyright (c) 2022, Gregory Bertilson <zaggy1024@gmail.com>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #pragma once
  7. #include <AK/Debug.h>
  8. #include <AK/Variant.h>
  9. #include <LibThreading/ConditionVariable.h>
  10. #include <LibThreading/Mutex.h>
  11. #include <LibThreading/Thread.h>
  12. namespace Threading {
  // WORKER_LOG(...) forwards its arguments to dbgln() only when WORKER_THREAD_DEBUG is enabled.
  // In every other build it expands to nothing, so a log line may mention fields (such as m_id)
  // that only exist in debug builds and still compile.
  #if !WORKER_THREAD_DEBUG
  #    define WORKER_LOG(...)
  #else
  #    define WORKER_LOG(...) ({ dbgln(__VA_ARGS__); })
  #endif
  19. template<typename ErrorType>
  20. class WorkerThread {
  21. enum class State {
  22. Idle,
  23. Working,
  24. Stopped,
  25. };
  26. using WorkerTask = Function<ErrorOr<void, ErrorType>()>;
  27. using WorkerState = Variant<State, WorkerTask, ErrorType>;
  28. public:
  29. static ErrorOr<NonnullOwnPtr<WorkerThread>> create(StringView name)
  30. {
  31. auto worker_thread = TRY(adopt_nonnull_own_or_enomem(new (nothrow) WorkerThread()));
  32. worker_thread->m_thread = TRY(Threading::Thread::try_create([&self = *worker_thread]() {
  33. WORKER_LOG("Starting worker loop {}", self.m_id);
  34. while (true) {
  35. self.m_mutex.lock();
  36. if (self.m_stop) {
  37. WORKER_LOG("Exiting {}", self.m_id);
  38. self.m_state = State::Stopped;
  39. self.m_condition.broadcast();
  40. self.m_mutex.unlock();
  41. return 0;
  42. }
  43. if (self.m_state.template has<WorkerTask>()) {
  44. auto task = move(self.m_state.template get<WorkerTask>());
  45. self.m_state = State::Working;
  46. self.m_mutex.unlock();
  47. WORKER_LOG("Starting task on {}", self.m_id);
  48. auto result = task();
  49. if (result.is_error()) {
  50. WORKER_LOG("Task finished on {} with error", self.m_id);
  51. self.m_mutex.lock();
  52. self.m_state = result.release_error();
  53. self.m_condition.broadcast();
  54. } else {
  55. WORKER_LOG("Task finished successfully on {}", self.m_id);
  56. self.m_mutex.lock();
  57. self.m_state = State::Idle;
  58. self.m_condition.broadcast();
  59. }
  60. }
  61. WORKER_LOG("Awaiting new task in {}...", self.m_id);
  62. self.m_condition.wait();
  63. WORKER_LOG("Worker thread awoken in {}", self.m_id);
  64. self.m_mutex.unlock();
  65. }
  66. return 0;
  67. },
  68. name));
  69. worker_thread->m_thread->start();
  70. return worker_thread;
  71. }
  72. ~WorkerThread()
  73. {
  74. m_mutex.lock();
  75. m_stop = true;
  76. m_condition.broadcast();
  77. while (!is_in_state(State::Stopped))
  78. m_condition.wait();
  79. m_mutex.unlock();
  80. (void)m_thread->join();
  81. WORKER_LOG("Worker thread {} joined successfully", m_id);
  82. }
  83. // Returns whether the task is starting.
  84. bool start_task(WorkerTask&& task)
  85. {
  86. m_mutex.lock();
  87. VERIFY(!is_in_state(State::Stopped));
  88. bool start_work = false;
  89. if (is_in_state(State::Idle)) {
  90. start_work = true;
  91. } else if (m_state.template has<ErrorType>()) {
  92. WORKER_LOG("Starting task and ignoring previous error: {}", m_state.template get<ErrorType>().string_literal());
  93. start_work = true;
  94. }
  95. if (start_work) {
  96. WORKER_LOG("Queuing task on {}", m_id);
  97. m_state = move(task);
  98. m_condition.broadcast();
  99. }
  100. m_mutex.unlock();
  101. return start_work;
  102. }
  103. ErrorOr<void, ErrorType> wait_until_task_is_finished()
  104. {
  105. WORKER_LOG("Waiting for task to finish on {}...", m_id);
  106. m_mutex.lock();
  107. while (true) {
  108. if (m_state.template has<WorkerTask>() || is_in_state(State::Working)) {
  109. m_condition.wait();
  110. } else if (m_state.template has<ErrorType>()) {
  111. auto error = move(m_state.template get<ErrorType>());
  112. m_state = State::Idle;
  113. m_mutex.unlock();
  114. WORKER_LOG("Finished waiting with error on {}: {}", m_id, error.string_literal());
  115. return error;
  116. } else {
  117. m_mutex.unlock();
  118. WORKER_LOG("Finished waiting on {}", m_id);
  119. return {};
  120. }
  121. }
  122. m_mutex.unlock();
  123. }
  124. private:
  125. #if WORKER_THREAD_DEBUG
  126. static inline size_t current_id = 0;
  127. #endif
  128. WorkerThread()
  129. : m_condition(m_mutex)
  130. #if WORKER_THREAD_DEBUG
  131. , m_id(current_id++)
  132. #endif
  133. {
  134. }
  135. WorkerThread(WorkerThread const&) = delete;
  136. WorkerThread(WorkerThread&&) = delete;
  137. // Must be called with the mutex locked.
  138. bool is_in_state(State state)
  139. {
  140. return m_state.template has<State>() && m_state.template get<State>() == state;
  141. }
  142. RefPtr<Threading::Thread> m_thread;
  143. Threading::Mutex m_mutex;
  144. Threading::ConditionVariable m_condition;
  145. WorkerState m_state { State::Idle };
  146. bool m_stop { false };
  147. #if WORKER_THREAD_DEBUG
  148. size_t m_id;
  149. #endif
  150. };
  151. #undef WORKER_LOG
  152. }