From 131c3f50dec480e5bc73554d48e4e6531916129f Mon Sep 17 00:00:00 2001
From: Andreas Kling <kling@serenityos.org>
Date: Sat, 24 Sep 2022 11:56:43 +0200
Subject: [PATCH] LibJS: Add JS::SafeFunction, like Function but protects
 captures from GC

SafeFunction automatically registers its closure memory area in a place
where the JS garbage collector can find it.

This means that you can capture JS::Value and arbitrary pointers into
the GC heap in closures, as long as you're using a SafeFunction, and
the GC will not zap those values!

There's probably some performance impact from this, and there's a lot
of things that could be nicer/smarter about it, but let's build
something that ensures safety first, and we can worry about performance
later. :^)
---
 Meta/check-style.py                     |   1 +
 Userland/Libraries/LibJS/Heap/Heap.cpp  |  31 +++
 Userland/Libraries/LibJS/SafeFunction.h | 250 ++++++++++++++++++++++++
 3 files changed, 282 insertions(+)
 create mode 100644 Userland/Libraries/LibJS/SafeFunction.h

diff --git a/Meta/check-style.py b/Meta/check-style.py
index 5baf25a8578..72cf961bf5a 100755
--- a/Meta/check-style.py
+++ b/Meta/check-style.py
@@ -22,6 +22,7 @@ GOOD_LICENSE_HEADER_PATTERN = re.compile(
 LICENSE_HEADER_CHECK_EXCLUDES = {
     'AK/Checked.h',
     'AK/Function.h',
+    'Userland/Libraries/LibJS/SafeFunction.h',
     'Userland/Libraries/LibC/elf.h',
     'Userland/Libraries/LibCodeComprehension/Cpp/Tests/',
     'Userland/Libraries/LibCpp/Tests/parser/',
diff --git a/Userland/Libraries/LibJS/Heap/Heap.cpp b/Userland/Libraries/LibJS/Heap/Heap.cpp
index 97b416980aa..a6eee105a6e 100644
--- a/Userland/Libraries/LibJS/Heap/Heap.cpp
+++ b/Userland/Libraries/LibJS/Heap/Heap.cpp
@@ -17,6 +17,7 @@
 #include <LibJS/Interpreter.h>
 #include <LibJS/Runtime/Object.h>
 #include <LibJS/Runtime/WeakContainer.h>
+#include <LibJS/SafeFunction.h>
 #include <setjmp.h>
 
 #ifdef __serenity__
@@ -29,6 +30,9 @@ namespace JS {
 static int gc_perf_string_id;
 #endif
 
+// NOTE: We keep a per-thread list of custom ranges. This hinges on the assumption that there is one JS VM per thread.
+static __thread HashMap<FlatPtr*, size_t>* s_custom_ranges_for_conservative_scan = nullptr;
+
 Heap::Heap(VM& vm)
     : m_vm(vm)
 {
@@ -164,6 +168,16 @@ __attribute__((no_sanitize("address"))) void Heap::gather_conservative_roots(HashTable<Cell*>& roots)
         add_possible_value(data);
     }
 
+    // NOTE: If we have any custom ranges registered, scan those as well.
+    //       This is where JS::SafeFunction closures get marked.
+    if (s_custom_ranges_for_conservative_scan) {
+        for (auto& custom_range : *s_custom_ranges_for_conservative_scan) {
+            for (size_t i = 0; i < (custom_range.value / sizeof(FlatPtr)); ++i) {
+                add_possible_value(custom_range.key[i]);
+            }
+        }
+    }
+
     HashTable<HeapBlock*> all_live_heap_blocks;
     for_each_block([&](auto& block) {
         all_live_heap_blocks.set(&block);
@@ -349,4 +363,21 @@ void Heap::uproot_cell(Cell* cell)
     m_uprooted_cells.append(cell);
 }
 
+void register_safe_function_closure(void* base, size_t size)
+{
+    if (!s_custom_ranges_for_conservative_scan) {
+        // FIXME: This per-thread HashMap is currently leaked on thread exit.
+        s_custom_ranges_for_conservative_scan = new HashMap<FlatPtr*, size_t>;
+    }
+    auto result = s_custom_ranges_for_conservative_scan->set(reinterpret_cast<FlatPtr*>(base), size);
+    VERIFY(result == AK::HashSetResult::InsertedNewEntry);
+}
+
+void unregister_safe_function_closure(void* base, size_t)
+{
+    VERIFY(s_custom_ranges_for_conservative_scan);
+    bool did_remove = s_custom_ranges_for_conservative_scan->remove(reinterpret_cast<FlatPtr*>(base));
+    VERIFY(did_remove);
+}
+
 }
diff --git a/Userland/Libraries/LibJS/SafeFunction.h b/Userland/Libraries/LibJS/SafeFunction.h
new file mode 100644
index 00000000000..4dfb9599b11
--- /dev/null
+++ b/Userland/Libraries/LibJS/SafeFunction.h
@@ -0,0 +1,250 @@
+/*
+ * Copyright (c) 2016 Apple Inc. All rights reserved.
+ * Copyright (c) 2021, Gunnar Beutner <gbeutner@serenityos.org>
+ * Copyright (c) 2022, Andreas Kling <kling@serenityos.org>
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#pragma once
+
+#include <AK/Function.h>
+
+namespace JS {
+
+void register_safe_function_closure(void*, size_t);
+void unregister_safe_function_closure(void*, size_t);
+
+template<typename>
+class SafeFunction;
+
+template<typename Out, typename... In>
+class SafeFunction<Out(In...)> {
+    AK_MAKE_NONCOPYABLE(SafeFunction);
+
+public:
+    SafeFunction() = default;
+    SafeFunction(std::nullptr_t)
+    {
+    }
+
+    ~SafeFunction()
+    {
+        clear(false);
+    }
+
+    void register_closure()
+    {
+        if (auto* wrapper = callable_wrapper())
+            register_safe_function_closure(wrapper, m_size);
+    }
+
+    void unregister_closure()
+    {
+        if (auto* wrapper = callable_wrapper())
+            unregister_safe_function_closure(wrapper, m_size);
+    }
+
+    template<typename CallableType>
+    SafeFunction(CallableType&& callable) requires((AK::IsFunctionObject<CallableType> && IsCallableWithArguments<CallableType, In...> && !IsSame<RemoveCVReference<CallableType>, SafeFunction>))
+    {
+        init_with_callable(forward<CallableType>(callable));
+    }
+
+    template<typename FunctionType>
+    SafeFunction(FunctionType f) requires((AK::IsFunctionPointer<FunctionType> && IsCallableWithArguments<RemovePointer<FunctionType>, In...> && !IsSame<RemoveCVReference<FunctionType>, SafeFunction>))
+    {
+        init_with_callable(move(f));
+    }
+
+    SafeFunction(SafeFunction&& other)
+    {
+        move_from(move(other));
+    }
+
+    // Note: Despite this method being const, a mutable lambda _may_ modify its own captures.
+    Out operator()(In... in) const
+    {
+        auto* wrapper = callable_wrapper();
+        VERIFY(wrapper);
+        ++m_call_nesting_level;
+        ScopeGuard guard([this] {
+            if (--m_call_nesting_level == 0 && m_deferred_clear)
+                const_cast<SafeFunction*>(this)->clear(false);
+        });
+        return wrapper->call(forward<In>(in)...);
+    }
+
+    explicit operator bool() const { return !!callable_wrapper(); }
+
+    template<typename CallableType>
+    SafeFunction& operator=(CallableType&& callable) requires((AK::IsFunctionObject<CallableType> && IsCallableWithArguments<CallableType, In...>))
+    {
+        clear();
+        init_with_callable(forward<CallableType>(callable));
+        return *this;
+    }
+
+    template<typename FunctionType>
+    SafeFunction& operator=(FunctionType f) requires((AK::IsFunctionPointer<FunctionType> && IsCallableWithArguments<RemovePointer<FunctionType>, In...>))
+    {
+        clear();
+        if (f)
+            init_with_callable(move(f));
+        return *this;
+    }
+
+    SafeFunction& operator=(std::nullptr_t)
+    {
+        clear();
+        return *this;
+    }
+
+    SafeFunction& operator=(SafeFunction&& other)
+    {
+        if (this != &other) {
+            clear();
+            move_from(move(other));
+        }
+        return *this;
+    }
+
+private:
+    class CallableWrapperBase {
+    public:
+        virtual ~CallableWrapperBase() = default;
+        // Note: This is not const to allow storing mutable lambdas.
+        virtual Out call(In...) = 0;
+        virtual void destroy() = 0;
+        virtual void init_and_swap(u8*, size_t) = 0;
+    };
+
+    template<typename CallableType>
+    class CallableWrapper final : public CallableWrapperBase {
+        AK_MAKE_NONMOVABLE(CallableWrapper);
+        AK_MAKE_NONCOPYABLE(CallableWrapper);
+
+    public:
+        explicit CallableWrapper(CallableType&& callable)
+            : m_callable(move(callable))
+        {
+        }
+
+        Out call(In... in) final override
+        {
+            return m_callable(forward<In>(in)...);
+        }
+
+        void destroy() final override
+        {
+            delete this;
+        }
+
+        // NOLINTNEXTLINE(readability-non-const-parameter) False positive; destination is used in a placement new expression
+        void init_and_swap(u8* destination, size_t size) final override
+        {
+            VERIFY(size >= sizeof(CallableWrapper));
+            new (destination) CallableWrapper { move(m_callable) };
+        }
+
+    private:
+        CallableType m_callable;
+    };
+
+    enum class FunctionKind {
+        NullPointer,
+        Inline,
+        Outline,
+    };
+
+    CallableWrapperBase* callable_wrapper() const
+    {
+        switch (m_kind) {
+        case FunctionKind::NullPointer:
+            return nullptr;
+        case FunctionKind::Inline:
+            return bit_cast<CallableWrapperBase*>(&m_storage);
+        case FunctionKind::Outline:
+            return *bit_cast<CallableWrapperBase* const*>(&m_storage);
+        default:
+            VERIFY_NOT_REACHED();
+        }
+    }
+
+    void clear(bool may_defer = true)
+    {
+        bool called_from_inside_function = m_call_nesting_level > 0;
+        // NOTE: This VERIFY could fail because a Function is destroyed from within itself.
+        VERIFY(may_defer || !called_from_inside_function);
+        if (called_from_inside_function && may_defer) {
+            m_deferred_clear = true;
+            return;
+        }
+        m_deferred_clear = false;
+        auto* wrapper = callable_wrapper();
+        if (m_kind == FunctionKind::Inline) {
+            VERIFY(wrapper);
+            wrapper->~CallableWrapperBase();
+            unregister_closure();
+        } else if (m_kind == FunctionKind::Outline) {
+            VERIFY(wrapper);
+            wrapper->destroy();
+            unregister_closure();
+        }
+        m_kind = FunctionKind::NullPointer;
+    }
+
+    template<typename Callable>
+    void init_with_callable(Callable&& callable)
+    {
+        VERIFY(m_call_nesting_level == 0);
+        VERIFY(m_kind == FunctionKind::NullPointer);
+        using WrapperType = CallableWrapper<Callable>;
+        if constexpr (sizeof(WrapperType) > inline_capacity) {
+            *bit_cast<CallableWrapperBase**>(&m_storage) = new WrapperType(forward<Callable>(callable));
+            m_kind = FunctionKind::Outline;
+        } else {
+            new (m_storage) WrapperType(forward<Callable>(callable));
+            m_kind = FunctionKind::Inline;
+        }
+        m_size = sizeof(WrapperType);
+        register_closure();
+    }
+
+    void move_from(SafeFunction&& other)
+    {
+        VERIFY(m_call_nesting_level == 0);
+        VERIFY(other.m_call_nesting_level == 0);
+        VERIFY(m_kind == FunctionKind::NullPointer);
+        auto* other_wrapper = other.callable_wrapper();
+        m_size = other.m_size;
+        switch (other.m_kind) {
+        case FunctionKind::NullPointer:
+            break;
+        case FunctionKind::Inline:
+            other.unregister_closure();
+            other_wrapper->init_and_swap(m_storage, inline_capacity);
+            m_kind = FunctionKind::Inline;
+            register_closure();
+            break;
+        case FunctionKind::Outline:
+            *bit_cast<CallableWrapperBase**>(&m_storage) = other_wrapper;
+            m_kind = FunctionKind::Outline;
+            break;
+        default:
+            VERIFY_NOT_REACHED();
+        }
+        other.m_kind = FunctionKind::NullPointer;
+    }
+
+    FunctionKind m_kind { FunctionKind::NullPointer };
+    bool m_deferred_clear { false };
+    mutable Atomic<u16> m_call_nesting_level { 0 };
+    size_t m_size { 0 };
+
+    // Empirically determined to fit most lambdas and functions.
+    static constexpr size_t inline_capacity = 4 * sizeof(void*);
+    alignas(max(alignof(CallableWrapperBase), alignof(CallableWrapperBase*))) u8 m_storage[inline_capacity];
+};
+
+}
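
A minimal usage sketch of the pattern this enables (not part of the patch; `defer_value` is a hypothetical helper): capture a JS::Value in a SafeFunction closure, and the capture stays alive across GC because the closure storage is registered as a conservative root via register_safe_function_closure() and scanned in Heap::gather_conservative_roots().

    #include <LibJS/Runtime/Value.h>
    #include <LibJS/SafeFunction.h>

    // Hypothetical example: the returned SafeFunction owns the captured value.
    // With a plain AK::Function, `value` would only survive a GC cycle if it
    // also happened to be reachable from a stack or register somewhere; here
    // every pointer-sized word of the capture is treated as a possible root.
    JS::SafeFunction<JS::Value()> defer_value(JS::Value value)
    {
        return [value] { return value; };
    }

The same holds for raw Cell pointers captured by the closure, since the conservative scan does not care what type the captured words are.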