SafeFunction.h 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. /*
  2. * Copyright (c) 2016 Apple Inc. All rights reserved.
  3. * Copyright (c) 2021, Gunnar Beutner <gbeutner@serenityos.org>
  4. * Copyright (c) 2022, Andreas Kling <kling@serenityos.org>
  5. *
  6. * SPDX-License-Identifier: BSD-2-Clause
  7. */
  8. #pragma once
  9. #include <AK/Function.h>
  10. namespace JS {
  11. void register_safe_function_closure(void*, size_t);
  12. void unregister_safe_function_closure(void*, size_t);
  13. template<typename>
  14. class SafeFunction;
  15. template<typename Out, typename... In>
  16. class SafeFunction<Out(In...)> {
  17. AK_MAKE_NONCOPYABLE(SafeFunction);
  18. public:
  19. SafeFunction() = default;
  20. SafeFunction(nullptr_t)
  21. {
  22. }
  23. ~SafeFunction()
  24. {
  25. clear(false);
  26. }
  27. void register_closure()
  28. {
  29. if (!m_size)
  30. return;
  31. if (auto* wrapper = callable_wrapper())
  32. register_safe_function_closure(wrapper, m_size);
  33. }
  34. void unregister_closure()
  35. {
  36. if (!m_size)
  37. return;
  38. if (auto* wrapper = callable_wrapper())
  39. unregister_safe_function_closure(wrapper, m_size);
  40. }
  41. template<typename CallableType>
  42. SafeFunction(CallableType&& callable)
  43. requires((AK::IsFunctionObject<CallableType> && IsCallableWithArguments<CallableType, Out, In...> && !IsSame<RemoveCVReference<CallableType>, SafeFunction>))
  44. {
  45. init_with_callable(forward<CallableType>(callable), CallableKind::FunctionObject);
  46. }
  47. template<typename FunctionType>
  48. SafeFunction(FunctionType f)
  49. requires((AK::IsFunctionPointer<FunctionType> && IsCallableWithArguments<RemovePointer<FunctionType>, Out, In...> && !IsSame<RemoveCVReference<FunctionType>, SafeFunction>))
  50. {
  51. init_with_callable(move(f), CallableKind::FunctionPointer);
  52. }
  53. SafeFunction(SafeFunction&& other)
  54. {
  55. move_from(move(other));
  56. }
  57. // Note: Despite this method being const, a mutable lambda _may_ modify its own captures.
  58. Out operator()(In... in) const
  59. {
  60. auto* wrapper = callable_wrapper();
  61. VERIFY(wrapper);
  62. ++m_call_nesting_level;
  63. ScopeGuard guard([this] {
  64. if (--m_call_nesting_level == 0 && m_deferred_clear)
  65. const_cast<SafeFunction*>(this)->clear(false);
  66. });
  67. return wrapper->call(forward<In>(in)...);
  68. }
  69. explicit operator bool() const { return !!callable_wrapper(); }
  70. template<typename CallableType>
  71. SafeFunction& operator=(CallableType&& callable)
  72. requires((AK::IsFunctionObject<CallableType> && IsCallableWithArguments<CallableType, Out, In...>))
  73. {
  74. clear();
  75. init_with_callable(forward<CallableType>(callable), CallableKind::FunctionObject);
  76. return *this;
  77. }
  78. template<typename FunctionType>
  79. SafeFunction& operator=(FunctionType f)
  80. requires((AK::IsFunctionPointer<FunctionType> && IsCallableWithArguments<RemovePointer<FunctionType>, Out, In...>))
  81. {
  82. clear();
  83. if (f)
  84. init_with_callable(move(f), CallableKind::FunctionPointer);
  85. return *this;
  86. }
  87. SafeFunction& operator=(nullptr_t)
  88. {
  89. clear();
  90. return *this;
  91. }
  92. SafeFunction& operator=(SafeFunction&& other)
  93. {
  94. if (this != &other) {
  95. clear();
  96. move_from(move(other));
  97. }
  98. return *this;
  99. }
  100. private:
  101. enum class CallableKind {
  102. FunctionPointer,
  103. FunctionObject,
  104. };
  105. class CallableWrapperBase {
  106. public:
  107. virtual ~CallableWrapperBase() = default;
  108. // Note: This is not const to allow storing mutable lambdas.
  109. virtual Out call(In...) = 0;
  110. virtual void destroy() = 0;
  111. virtual void init_and_swap(u8*, size_t) = 0;
  112. };
  113. template<typename CallableType>
  114. class CallableWrapper final : public CallableWrapperBase {
  115. AK_MAKE_NONMOVABLE(CallableWrapper);
  116. AK_MAKE_NONCOPYABLE(CallableWrapper);
  117. public:
  118. explicit CallableWrapper(CallableType&& callable)
  119. : m_callable(move(callable))
  120. {
  121. }
  122. Out call(In... in) final override
  123. {
  124. return m_callable(forward<In>(in)...);
  125. }
  126. void destroy() final override
  127. {
  128. delete this;
  129. }
  130. // NOLINTNEXTLINE(readability-non-const-parameter) False positive; destination is used in a placement new expression
  131. void init_and_swap(u8* destination, size_t size) final override
  132. {
  133. VERIFY(size >= sizeof(CallableWrapper));
  134. new (destination) CallableWrapper { move(m_callable) };
  135. }
  136. private:
  137. CallableType m_callable;
  138. };
  139. enum class FunctionKind {
  140. NullPointer,
  141. Inline,
  142. Outline,
  143. };
  144. CallableWrapperBase* callable_wrapper() const
  145. {
  146. switch (m_kind) {
  147. case FunctionKind::NullPointer:
  148. return nullptr;
  149. case FunctionKind::Inline:
  150. return bit_cast<CallableWrapperBase*>(&m_storage);
  151. case FunctionKind::Outline:
  152. return *bit_cast<CallableWrapperBase**>(&m_storage);
  153. default:
  154. VERIFY_NOT_REACHED();
  155. }
  156. }
  157. void clear(bool may_defer = true)
  158. {
  159. bool called_from_inside_function = m_call_nesting_level > 0;
  160. // NOTE: This VERIFY could fail because a Function is destroyed from within itself.
  161. VERIFY(may_defer || !called_from_inside_function);
  162. if (called_from_inside_function && may_defer) {
  163. m_deferred_clear = true;
  164. return;
  165. }
  166. m_deferred_clear = false;
  167. auto* wrapper = callable_wrapper();
  168. if (m_kind == FunctionKind::Inline) {
  169. VERIFY(wrapper);
  170. wrapper->~CallableWrapperBase();
  171. unregister_closure();
  172. } else if (m_kind == FunctionKind::Outline) {
  173. VERIFY(wrapper);
  174. wrapper->destroy();
  175. unregister_closure();
  176. }
  177. m_kind = FunctionKind::NullPointer;
  178. }
  179. template<typename Callable>
  180. void init_with_callable(Callable&& callable, CallableKind kind)
  181. {
  182. VERIFY(m_call_nesting_level == 0);
  183. VERIFY(m_kind == FunctionKind::NullPointer);
  184. using WrapperType = CallableWrapper<Callable>;
  185. if constexpr (sizeof(WrapperType) > inline_capacity) {
  186. *bit_cast<CallableWrapperBase**>(&m_storage) = new WrapperType(forward<Callable>(callable));
  187. m_kind = FunctionKind::Outline;
  188. } else {
  189. new (m_storage) WrapperType(forward<Callable>(callable));
  190. m_kind = FunctionKind::Inline;
  191. }
  192. if (kind == CallableKind::FunctionObject)
  193. m_size = sizeof(WrapperType);
  194. else
  195. m_size = 0;
  196. register_closure();
  197. }
  198. void move_from(SafeFunction&& other)
  199. {
  200. VERIFY(m_call_nesting_level == 0);
  201. VERIFY(other.m_call_nesting_level == 0);
  202. VERIFY(m_kind == FunctionKind::NullPointer);
  203. auto* other_wrapper = other.callable_wrapper();
  204. m_size = other.m_size;
  205. switch (other.m_kind) {
  206. case FunctionKind::NullPointer:
  207. break;
  208. case FunctionKind::Inline:
  209. other.unregister_closure();
  210. other_wrapper->init_and_swap(m_storage, inline_capacity);
  211. m_kind = FunctionKind::Inline;
  212. register_closure();
  213. break;
  214. case FunctionKind::Outline:
  215. *bit_cast<CallableWrapperBase**>(&m_storage) = other_wrapper;
  216. m_kind = FunctionKind::Outline;
  217. break;
  218. default:
  219. VERIFY_NOT_REACHED();
  220. }
  221. other.m_kind = FunctionKind::NullPointer;
  222. }
  223. FunctionKind m_kind { FunctionKind::NullPointer };
  224. bool m_deferred_clear { false };
  225. mutable Atomic<u16> m_call_nesting_level { 0 };
  226. size_t m_size { 0 };
  227. // Empirically determined to fit most lambdas and functions.
  228. static constexpr size_t inline_capacity = 4 * sizeof(void*);
  229. alignas(max(alignof(CallableWrapperBase), alignof(CallableWrapperBase*))) u8 m_storage[inline_capacity];
  230. };
  231. }