SafeFunction.h 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. /*
  2. * Copyright (c) 2016 Apple Inc. All rights reserved.
  3. * Copyright (c) 2021, Gunnar Beutner <gbeutner@serenityos.org>
  4. * Copyright (c) 2022, Andreas Kling <kling@serenityos.org>
  5. *
  6. * SPDX-License-Identifier: BSD-2-Clause
  7. */
  8. #pragma once
  9. #include <AK/Function.h>
  10. namespace JS {
  11. void register_safe_function_closure(void*, size_t);
  12. void unregister_safe_function_closure(void*, size_t);
  13. template<typename>
  14. class SafeFunction;
  15. template<typename Out, typename... In>
  16. class SafeFunction<Out(In...)> {
  17. AK_MAKE_NONCOPYABLE(SafeFunction);
  18. public:
  19. SafeFunction() = default;
  20. SafeFunction(std::nullptr_t)
  21. {
  22. }
  23. ~SafeFunction()
  24. {
  25. clear(false);
  26. }
  27. void register_closure()
  28. {
  29. if (!m_size)
  30. return;
  31. if (auto* wrapper = callable_wrapper())
  32. register_safe_function_closure(wrapper, m_size);
  33. }
  34. void unregister_closure()
  35. {
  36. if (!m_size)
  37. return;
  38. if (auto* wrapper = callable_wrapper())
  39. unregister_safe_function_closure(wrapper, m_size);
  40. }
  41. template<typename CallableType>
  42. SafeFunction(CallableType&& callable) requires((AK::IsFunctionObject<CallableType> && IsCallableWithArguments<CallableType, In...> && !IsSame<RemoveCVReference<CallableType>, SafeFunction>))
  43. {
  44. init_with_callable(forward<CallableType>(callable), CallableKind::FunctionObject);
  45. }
  46. template<typename FunctionType>
  47. SafeFunction(FunctionType f) requires((AK::IsFunctionPointer<FunctionType> && IsCallableWithArguments<RemovePointer<FunctionType>, In...> && !IsSame<RemoveCVReference<FunctionType>, SafeFunction>))
  48. {
  49. init_with_callable(move(f), CallableKind::FunctionPointer);
  50. }
  51. SafeFunction(SafeFunction&& other)
  52. {
  53. move_from(move(other));
  54. }
  55. // Note: Despite this method being const, a mutable lambda _may_ modify its own captures.
  56. Out operator()(In... in) const
  57. {
  58. auto* wrapper = callable_wrapper();
  59. VERIFY(wrapper);
  60. ++m_call_nesting_level;
  61. ScopeGuard guard([this] {
  62. if (--m_call_nesting_level == 0 && m_deferred_clear)
  63. const_cast<SafeFunction*>(this)->clear(false);
  64. });
  65. return wrapper->call(forward<In>(in)...);
  66. }
  67. explicit operator bool() const { return !!callable_wrapper(); }
  68. template<typename CallableType>
  69. SafeFunction& operator=(CallableType&& callable) requires((AK::IsFunctionObject<CallableType> && IsCallableWithArguments<CallableType, In...>))
  70. {
  71. clear();
  72. init_with_callable(forward<CallableType>(callable));
  73. return *this;
  74. }
  75. template<typename FunctionType>
  76. SafeFunction& operator=(FunctionType f) requires((AK::IsFunctionPointer<FunctionType> && IsCallableWithArguments<RemovePointer<FunctionType>, In...>))
  77. {
  78. clear();
  79. if (f)
  80. init_with_callable(move(f));
  81. return *this;
  82. }
  83. SafeFunction& operator=(std::nullptr_t)
  84. {
  85. clear();
  86. return *this;
  87. }
  88. SafeFunction& operator=(SafeFunction&& other)
  89. {
  90. if (this != &other) {
  91. clear();
  92. move_from(move(other));
  93. }
  94. return *this;
  95. }
  96. private:
  97. enum class CallableKind {
  98. FunctionPointer,
  99. FunctionObject,
  100. };
  101. class CallableWrapperBase {
  102. public:
  103. virtual ~CallableWrapperBase() = default;
  104. // Note: This is not const to allow storing mutable lambdas.
  105. virtual Out call(In...) = 0;
  106. virtual void destroy() = 0;
  107. virtual void init_and_swap(u8*, size_t) = 0;
  108. };
  109. template<typename CallableType>
  110. class CallableWrapper final : public CallableWrapperBase {
  111. AK_MAKE_NONMOVABLE(CallableWrapper);
  112. AK_MAKE_NONCOPYABLE(CallableWrapper);
  113. public:
  114. explicit CallableWrapper(CallableType&& callable)
  115. : m_callable(move(callable))
  116. {
  117. }
  118. Out call(In... in) final override
  119. {
  120. return m_callable(forward<In>(in)...);
  121. }
  122. void destroy() final override
  123. {
  124. delete this;
  125. }
  126. // NOLINTNEXTLINE(readability-non-const-parameter) False positive; destination is used in a placement new expression
  127. void init_and_swap(u8* destination, size_t size) final override
  128. {
  129. VERIFY(size >= sizeof(CallableWrapper));
  130. new (destination) CallableWrapper { move(m_callable) };
  131. }
  132. private:
  133. CallableType m_callable;
  134. };
  135. enum class FunctionKind {
  136. NullPointer,
  137. Inline,
  138. Outline,
  139. };
  140. CallableWrapperBase* callable_wrapper() const
  141. {
  142. switch (m_kind) {
  143. case FunctionKind::NullPointer:
  144. return nullptr;
  145. case FunctionKind::Inline:
  146. return bit_cast<CallableWrapperBase*>(&m_storage);
  147. case FunctionKind::Outline:
  148. return *bit_cast<CallableWrapperBase**>(&m_storage);
  149. default:
  150. VERIFY_NOT_REACHED();
  151. }
  152. }
  153. void clear(bool may_defer = true)
  154. {
  155. bool called_from_inside_function = m_call_nesting_level > 0;
  156. // NOTE: This VERIFY could fail because a Function is destroyed from within itself.
  157. VERIFY(may_defer || !called_from_inside_function);
  158. if (called_from_inside_function && may_defer) {
  159. m_deferred_clear = true;
  160. return;
  161. }
  162. m_deferred_clear = false;
  163. auto* wrapper = callable_wrapper();
  164. if (m_kind == FunctionKind::Inline) {
  165. VERIFY(wrapper);
  166. wrapper->~CallableWrapperBase();
  167. unregister_closure();
  168. } else if (m_kind == FunctionKind::Outline) {
  169. VERIFY(wrapper);
  170. wrapper->destroy();
  171. unregister_closure();
  172. }
  173. m_kind = FunctionKind::NullPointer;
  174. }
  175. template<typename Callable>
  176. void init_with_callable(Callable&& callable, CallableKind kind)
  177. {
  178. VERIFY(m_call_nesting_level == 0);
  179. VERIFY(m_kind == FunctionKind::NullPointer);
  180. using WrapperType = CallableWrapper<Callable>;
  181. if constexpr (sizeof(WrapperType) > inline_capacity) {
  182. *bit_cast<CallableWrapperBase**>(&m_storage) = new WrapperType(forward<Callable>(callable));
  183. m_kind = FunctionKind::Outline;
  184. } else {
  185. new (m_storage) WrapperType(forward<Callable>(callable));
  186. m_kind = FunctionKind::Inline;
  187. }
  188. if (kind == CallableKind::FunctionObject)
  189. m_size = sizeof(WrapperType);
  190. else
  191. m_size = 0;
  192. register_closure();
  193. }
  194. void move_from(SafeFunction&& other)
  195. {
  196. VERIFY(m_call_nesting_level == 0);
  197. VERIFY(other.m_call_nesting_level == 0);
  198. VERIFY(m_kind == FunctionKind::NullPointer);
  199. auto* other_wrapper = other.callable_wrapper();
  200. m_size = other.m_size;
  201. switch (other.m_kind) {
  202. case FunctionKind::NullPointer:
  203. break;
  204. case FunctionKind::Inline:
  205. other.unregister_closure();
  206. other_wrapper->init_and_swap(m_storage, inline_capacity);
  207. m_kind = FunctionKind::Inline;
  208. register_closure();
  209. break;
  210. case FunctionKind::Outline:
  211. *bit_cast<CallableWrapperBase**>(&m_storage) = other_wrapper;
  212. m_kind = FunctionKind::Outline;
  213. break;
  214. default:
  215. VERIFY_NOT_REACHED();
  216. }
  217. other.m_kind = FunctionKind::NullPointer;
  218. }
  219. FunctionKind m_kind { FunctionKind::NullPointer };
  220. bool m_deferred_clear { false };
  221. mutable Atomic<u16> m_call_nesting_level { 0 };
  222. size_t m_size { 0 };
  223. // Empirically determined to fit most lambdas and functions.
  224. static constexpr size_t inline_capacity = 4 * sizeof(void*);
  225. alignas(max(alignof(CallableWrapperBase), alignof(CallableWrapperBase*))) u8 m_storage[inline_capacity];
  226. };
  227. }