Browse Source

ClangPlugins: Check for strong root fields in GC allocated objects

GC-allocated objects should never have JS::SafeFunction/JS::Handle
fields.

For now the plugin only emits warnings here, as there are many cases
of this occurring in the codebase that aren't trivial to fix. It is also
behind a CMake flag since it is a _very_ loud warning.
Matthew Olsson 1 year ago
parent
commit
5740f93ef4

+ 1 - 0
Meta/CMake/common_options.cmake

@@ -33,3 +33,4 @@ serenity_option(SERENITY_CACHE_DIR "${PROJECT_BINARY_DIR}/../caches" CACHE PATH
 serenity_option(ENABLE_NETWORK_DOWNLOADS ON CACHE BOOL "Allow downloads of required files. If OFF, required files must already be present in SERENITY_CACHE_DIR")
 
 serenity_option(ENABLE_CLANG_PLUGINS OFF CACHE BOOL "Enable building with the Clang plugins")
+serenity_option(ENABLE_CLANG_PLUGINS_INVALID_FUNCTION_MEMBERS OFF CACHE BOOL "Enable detecting invalid function types as members of GC-allocated objects")

+ 3 - 0
Meta/Lagom/ClangPlugins/CMakeLists.txt

@@ -19,6 +19,9 @@ function(depend_on_clang_plugin target_name plugin_name)
         add_dependencies(${target_name} ${plugin_name}Target)
     endif()
     target_compile_options(${target_name} INTERFACE -fplugin=$<TARGET_FILE:Lagom::${plugin_name}>)
+    if (${ENABLE_CLANG_PLUGINS_INVALID_FUNCTION_MEMBERS})
+        target_compile_options(${target_name} INTERFACE -fplugin-arg-libjs_gc_scanner-detect-invalid-function-members)
+    endif()
 endfunction()
 
 clang_plugin(LambdaCaptureClangPlugin SOURCES LambdaCapturePluginAction.cpp)

+ 96 - 61
Meta/Lagom/ClangPlugins/LibJSGCPluginAction.cpp

@@ -57,7 +57,16 @@ std::vector<clang::QualType> get_all_qualified_types(clang::QualType const& type
     if (auto const* template_specialization = type->getAs<clang::TemplateSpecializationType>()) {
         auto specialization_name = template_specialization->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
         // Do not unwrap GCPtr/NonnullGCPtr/MarkedVector
-        if (specialization_name == "JS::GCPtr" || specialization_name == "JS::NonnullGCPtr" || specialization_name == "JS::RawGCPtr" || specialization_name == "JS::MarkedVector") {
+        static std::unordered_set<std::string> gc_relevant_type_names {
+            "JS::GCPtr",
+            "JS::NonnullGCPtr",
+            "JS::RawGCPtr",
+            "JS::MarkedVector",
+            "JS::Handle",
+            "JS::SafeFunction",
+        };
+
+        if (gc_relevant_type_names.contains(specialization_name)) {
             qualified_types.push_back(type);
         } else {
             auto const template_arguments = template_specialization->template_arguments();
@@ -75,65 +84,75 @@ std::vector<clang::QualType> get_all_qualified_types(clang::QualType const& type
 
     return qualified_types;
 }
+enum class OuterType {
+    GCPtr,
+    RawGCPtr,
+    Handle,
+    SafeFunction,
+    Ptr,
+    Ref,
+};
 
-struct FieldValidationResult {
-    bool is_valid { false };
-    bool is_wrapped_in_gcptr { false };
-    bool needs_visiting { false };
+struct QualTypeGCInfo {
+    std::optional<OuterType> outer_type { {} };
+    bool base_type_inherits_from_cell { false };
 };
 
-FieldValidationResult validate_field(clang::FieldDecl const* field_decl)
+std::optional<QualTypeGCInfo> validate_qualified_type(clang::QualType const& type)
+{
+    if (auto const* pointer_decl = type->getAs<clang::PointerType>()) {
+        if (auto const* pointee = pointer_decl->getPointeeCXXRecordDecl())
+            return QualTypeGCInfo { OuterType::Ptr, record_inherits_from_cell(*pointee) };
+    } else if (auto const* reference_decl = type->getAs<clang::ReferenceType>()) {
+        if (auto const* pointee = reference_decl->getPointeeCXXRecordDecl())
+            return QualTypeGCInfo { OuterType::Ref, record_inherits_from_cell(*pointee) };
+    } else if (auto const* specialization = type->getAs<clang::TemplateSpecializationType>()) {
+        auto template_type_name = specialization->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
+
+        OuterType outer_type;
+        if (template_type_name == "JS::GCPtr" || template_type_name == "JS::NonnullGCPtr") {
+            outer_type = OuterType::GCPtr;
+        } else if (template_type_name == "JS::RawGCPtr") {
+            outer_type = OuterType::RawGCPtr;
+        } else if (template_type_name == "JS::Handle") {
+            outer_type = OuterType::Handle;
+        } else if (template_type_name == "JS::SafeFunction") {
+            return QualTypeGCInfo { OuterType::SafeFunction, false };
+        } else {
+            return {};
+        }
+
+        auto template_args = specialization->template_arguments();
+        if (template_args.size() != 1)
+            return {}; // Not really valid, but will produce a compilation error anyway
+
+        auto const& type_arg = template_args[0];
+        auto const* record_type = type_arg.getAsType()->getAs<clang::RecordType>();
+        if (!record_type)
+            return {};
+
+        auto const* record_decl = record_type->getAsCXXRecordDecl();
+        if (!record_decl->hasDefinition())
+            return {};
+
+        return QualTypeGCInfo { outer_type, record_inherits_from_cell(*record_decl) };
+    }
+
+    return {};
+}
+
+std::optional<QualTypeGCInfo> validate_field_qualified_type(clang::FieldDecl const* field_decl)
 {
     auto type = field_decl->getType();
     if (auto const* elaborated_type = llvm::dyn_cast<clang::ElaboratedType>(type.getTypePtr()))
         type = elaborated_type->desugar();
 
-    FieldValidationResult result { .is_valid = true };
-
     for (auto const& qualified_type : get_all_qualified_types(type)) {
-        if (auto const* pointer_decl = qualified_type->getAs<clang::PointerType>()) {
-            if (auto const* pointee = pointer_decl->getPointeeCXXRecordDecl()) {
-                if (record_inherits_from_cell(*pointee)) {
-                    result.is_valid = false;
-                    result.is_wrapped_in_gcptr = false;
-                    result.needs_visiting = true;
-                    return result;
-                }
-            }
-        } else if (auto const* reference_decl = qualified_type->getAs<clang::ReferenceType>()) {
-            if (auto const* pointee = reference_decl->getPointeeCXXRecordDecl()) {
-                if (record_inherits_from_cell(*pointee)) {
-                    result.is_valid = false;
-                    result.is_wrapped_in_gcptr = false;
-                    result.needs_visiting = true;
-                    return result;
-                }
-            }
-        } else if (auto const* specialization = qualified_type->getAs<clang::TemplateSpecializationType>()) {
-            auto template_type_name = specialization->getTemplateName().getAsTemplateDecl()->getName();
-            if (template_type_name != "GCPtr" && template_type_name != "NonnullGCPtr" && template_type_name != "RawGCPtr")
-                return result;
-
-            auto const template_args = specialization->template_arguments();
-            if (template_args.size() != 1)
-                return result; // Not really valid, but will produce a compilation error anyway
-
-            auto const& type_arg = template_args[0];
-            auto const* record_type = type_arg.getAsType()->getAs<clang::RecordType>();
-            if (!record_type)
-                return result;
-
-            auto const* record_decl = record_type->getAsCXXRecordDecl();
-            if (!record_decl->hasDefinition())
-                return result;
-
-            result.is_wrapped_in_gcptr = true;
-            result.is_valid = record_inherits_from_cell(*record_decl);
-            result.needs_visiting = template_type_name != "RawGCPtr";
-        }
+        if (auto error = validate_qualified_type(qualified_type))
+            return error;
     }
 
-    return result;
+    return {};
 }
 
 bool LibJSGCVisitor::VisitCXXRecordDecl(clang::CXXRecordDecl* record)
@@ -151,17 +170,20 @@ bool LibJSGCVisitor::VisitCXXRecordDecl(clang::CXXRecordDecl* record)
 
     auto& diag_engine = m_context.getDiagnostics();
     std::vector<clang::FieldDecl const*> fields_that_need_visiting;
+    auto record_is_cell = record_inherits_from_cell(*record);
 
     for (clang::FieldDecl const* field : record->fields()) {
-        auto validation_results = validate_field(field);
-        if (!validation_results.is_valid) {
-            if (validation_results.is_wrapped_in_gcptr) {
-                auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "Specialization type must inherit from JS::Cell");
-                diag_engine.Report(field->getLocation(), diag_id);
-            } else {
+        auto validation_results = validate_field_qualified_type(field);
+        if (!validation_results)
+            continue;
+
+        auto [outer_type, base_type_inherits_from_cell] = *validation_results;
+
+        if (outer_type == OuterType::Ptr || outer_type == OuterType::Ref) {
+            if (base_type_inherits_from_cell) {
                 auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "%0 to JS::Cell type should be wrapped in %1");
                 auto builder = diag_engine.Report(field->getLocation(), diag_id);
-                if (field->getType()->isReferenceType()) {
+                if (outer_type == OuterType::Ref) {
                     builder << "reference"
                             << "JS::NonnullGCPtr";
                 } else {
@@ -169,12 +191,24 @@ bool LibJSGCVisitor::VisitCXXRecordDecl(clang::CXXRecordDecl* record)
                             << "JS::GCPtr";
                 }
             }
-        } else if (validation_results.needs_visiting) {
-            fields_that_need_visiting.push_back(field);
+        } else if (outer_type == OuterType::GCPtr || outer_type == OuterType::RawGCPtr) {
+            if (!base_type_inherits_from_cell) {
+                auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "Specialization type must inherit from JS::Cell");
+                diag_engine.Report(field->getLocation(), diag_id);
+            } else if (outer_type == OuterType::GCPtr) {
+                fields_that_need_visiting.push_back(field);
+            }
+        } else if (outer_type == OuterType::Handle || outer_type == OuterType::SafeFunction) {
+            if (record_is_cell && m_detect_invalid_function_members) {
+                // FIXME: Change this to an Error when all of the use cases get addressed and remove the plugin argument
+                auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Warning, "Types inheriting from JS::Cell should not have %0 fields");
+                auto builder = diag_engine.Report(field->getLocation(), diag_id);
+                builder << (outer_type == OuterType::Handle ? "JS::Handle" : "JS::SafeFunction");
+            }
         }
     }
 
-    if (!record_inherits_from_cell(*record))
+    if (!record_is_cell)
         return true;
 
     validate_record_macros(*record);
@@ -406,8 +440,9 @@ void LibJSGCVisitor::validate_record_macros(clang::CXXRecordDecl const& record)
         report_missing_macro();
 }
 
-LibJSGCASTConsumer::LibJSGCASTConsumer(clang::CompilerInstance& compiler)
+LibJSGCASTConsumer::LibJSGCASTConsumer(clang::CompilerInstance& compiler, bool detect_invalid_function_members)
     : m_compiler(compiler)
+    , m_detect_invalid_function_members(detect_invalid_function_members)
 {
     auto& preprocessor = compiler.getPreprocessor();
     preprocessor.addPPCallbacks(std::make_unique<LibJSPPCallbacks>(preprocessor, m_macro_map));
@@ -415,7 +450,7 @@ LibJSGCASTConsumer::LibJSGCASTConsumer(clang::CompilerInstance& compiler)
 
 void LibJSGCASTConsumer::HandleTranslationUnit(clang::ASTContext& context)
 {
-    LibJSGCVisitor visitor { context, m_macro_map };
+    LibJSGCVisitor visitor { context, m_macro_map, m_detect_invalid_function_members };
     visitor.TraverseDecl(context.getTranslationUnitDecl());
 }
 

+ 11 - 4
Meta/Lagom/ClangPlugins/LibJSGCPluginAction.h

@@ -53,9 +53,10 @@ private:
 
 class LibJSGCVisitor : public clang::RecursiveASTVisitor<LibJSGCVisitor> {
 public:
-    explicit LibJSGCVisitor(clang::ASTContext& context, LibJSCellMacroMap const& macro_map)
+    explicit LibJSGCVisitor(clang::ASTContext& context, LibJSCellMacroMap const& macro_map, bool detect_invalid_function_members)
         : m_context(context)
         , m_macro_map(macro_map)
+        , m_detect_invalid_function_members(detect_invalid_function_members)
     {
     }
 
@@ -72,33 +73,39 @@ private:
 
     clang::ASTContext& m_context;
     LibJSCellMacroMap const& m_macro_map;
+    bool m_detect_invalid_function_members;
 };
 
 class LibJSGCASTConsumer : public clang::ASTConsumer {
 public:
-    explicit LibJSGCASTConsumer(clang::CompilerInstance&);
+    LibJSGCASTConsumer(clang::CompilerInstance&, bool detect_invalid_function_members);
 
 private:
     virtual void HandleTranslationUnit(clang::ASTContext& context) override;
 
     clang::CompilerInstance& m_compiler;
     LibJSCellMacroMap m_macro_map;
+    bool m_detect_invalid_function_members;
 };
 
 class LibJSGCPluginAction : public clang::PluginASTAction {
 public:
-    virtual bool ParseArgs(clang::CompilerInstance const&, std::vector<std::string> const&) override
+    virtual bool ParseArgs(clang::CompilerInstance const&, std::vector<std::string> const& args) override
     {
+        m_detect_invalid_function_members = std::find(args.begin(), args.end(), "detect-invalid-function-members") != args.end();
         return true;
     }
 
     virtual std::unique_ptr<clang::ASTConsumer> CreateASTConsumer(clang::CompilerInstance& compiler, llvm::StringRef) override
     {
-        return std::make_unique<LibJSGCASTConsumer>(compiler);
+        return std::make_unique<LibJSGCASTConsumer>(compiler, m_detect_invalid_function_members);
     }
 
     ActionType getActionType() override
     {
         return AddAfterMainAction;
     }
+
+private:
+    bool m_detect_invalid_function_members { false };
 };

+ 3 - 0
Tests/ClangPlugins/CMakeLists.txt

@@ -6,6 +6,9 @@ find_package(Python3 REQUIRED COMPONENTS Interpreter)
 get_property(CLANG_PLUGINS_COMPILE_OPTIONS_FOR_TESTS GLOBAL PROPERTY CLANG_PLUGINS_COMPILE_OPTIONS_FOR_TESTS)
 list(APPEND CLANG_PLUGINS_COMPILE_OPTIONS_FOR_TESTS -std=c++23 -Wno-user-defined-literals -Wno-literal-range)
 
+# Ensure we always check for invalid function field types regardless of the value of ENABLE_CLANG_PLUGINS_INVALID_FUNCTION_MEMBERS
+list(APPEND CLANG_PLUGINS_COMPILE_OPTIONS_FOR_TESTS -fplugin-arg-libjs_gc_scanner-detect-invalid-function-members)
+
 get_property(CLANG_PLUGINS_INCLUDE_DIRECTORIES TARGET AK PROPERTY INCLUDE_DIRECTORIES)
 list(APPEND CLANG_PLUGINS_INCLUDE_DIRECTORIES ${CMAKE_CXX_IMPLICIT_INCLUDE_DIRECTORIES})
 

+ 26 - 0
Tests/ClangPlugins/LibJSGCTests/strong_root_fields_in_gc_allocated_types.cpp

@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2024, Matthew Olsson <mattco@serenityos.org>
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+// RUN: %clang++ -cc1 -verify %plugin_opts% %s 2>&1
+
+#include <LibJS/Heap/Cell.h>
+#include <LibJS/Heap/Handle.h>
+#include <LibJS/SafeFunction.h>
+
+class CellClass : JS::Cell {
+    JS_CELL(CellClass, JS::Cell);
+
+    // expected-warning@+1 {{Types inheriting from JS::Cell should not have JS::SafeFunction fields}}
+    JS::SafeFunction<void()> m_func;
+
+    // expected-warning@+1 {{Types inheriting from JS::Cell should not have JS::Handle fields}}
+    JS::Handle<JS::Cell> m_handle;
+};
+
+class NonCellClass {
+    JS::SafeFunction<void()> m_func;
+    JS::Handle<JS::Cell> m_handle;
+};