LibJSGCPluginAction.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. /*
  2. * Copyright (c) 2024, Matthew Olsson <mattco@serenityos.org>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #include "LibJSGCPluginAction.h"
  7. #include <clang/ASTMatchers/ASTMatchFinder.h>
  8. #include <clang/ASTMatchers/ASTMatchers.h>
  9. #include <clang/Basic/SourceManager.h>
  10. #include <clang/Frontend/CompilerInstance.h>
  11. #include <clang/Frontend/FrontendPluginRegistry.h>
  12. #include <unordered_set>
  13. template<typename T>
  14. class SimpleCollectMatchesCallback : public clang::ast_matchers::MatchFinder::MatchCallback {
  15. public:
  16. explicit SimpleCollectMatchesCallback(std::string name)
  17. : m_name(std::move(name))
  18. {
  19. }
  20. void run(clang::ast_matchers::MatchFinder::MatchResult const& result) override
  21. {
  22. if (auto const* node = result.Nodes.getNodeAs<T>(m_name))
  23. m_matches.push_back(node);
  24. }
  25. auto const& matches() const { return m_matches; }
  26. private:
  27. std::string m_name;
  28. std::vector<T const*> m_matches;
  29. };
  30. bool record_inherits_from_cell(clang::CXXRecordDecl const& record)
  31. {
  32. if (!record.isCompleteDefinition())
  33. return false;
  34. bool inherits_from_cell = record.getQualifiedNameAsString() == "JS::Cell";
  35. record.forallBases([&](clang::CXXRecordDecl const* base) -> bool {
  36. if (base->getQualifiedNameAsString() == "JS::Cell") {
  37. inherits_from_cell = true;
  38. return false;
  39. }
  40. return true;
  41. });
  42. return inherits_from_cell;
  43. }
  44. std::vector<clang::QualType> get_all_qualified_types(clang::QualType const& type)
  45. {
  46. std::vector<clang::QualType> qualified_types;
  47. if (auto const* template_specialization = type->getAs<clang::TemplateSpecializationType>()) {
  48. auto specialization_name = template_specialization->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
  49. // Do not unwrap GCPtr/NonnullGCPtr/MarkedVector
  50. if (specialization_name == "JS::GCPtr" || specialization_name == "JS::NonnullGCPtr" || specialization_name == "JS::RawGCPtr" || specialization_name == "JS::MarkedVector") {
  51. qualified_types.push_back(type);
  52. } else {
  53. auto const template_arguments = template_specialization->template_arguments();
  54. for (size_t i = 0; i < template_arguments.size(); i++) {
  55. auto const& template_arg = template_arguments[i];
  56. if (template_arg.getKind() == clang::TemplateArgument::Type) {
  57. auto template_qualified_types = get_all_qualified_types(template_arg.getAsType());
  58. std::move(template_qualified_types.begin(), template_qualified_types.end(), std::back_inserter(qualified_types));
  59. }
  60. }
  61. }
  62. } else {
  63. qualified_types.push_back(type);
  64. }
  65. return qualified_types;
  66. }
  67. struct FieldValidationResult {
  68. bool is_valid { false };
  69. bool is_wrapped_in_gcptr { false };
  70. bool needs_visiting { false };
  71. };
  72. FieldValidationResult validate_field(clang::FieldDecl const* field_decl)
  73. {
  74. auto type = field_decl->getType();
  75. if (auto const* elaborated_type = llvm::dyn_cast<clang::ElaboratedType>(type.getTypePtr()))
  76. type = elaborated_type->desugar();
  77. FieldValidationResult result { .is_valid = true };
  78. for (auto const& qualified_type : get_all_qualified_types(type)) {
  79. if (auto const* pointer_decl = qualified_type->getAs<clang::PointerType>()) {
  80. if (auto const* pointee = pointer_decl->getPointeeCXXRecordDecl()) {
  81. if (record_inherits_from_cell(*pointee)) {
  82. result.is_valid = false;
  83. result.is_wrapped_in_gcptr = false;
  84. result.needs_visiting = true;
  85. return result;
  86. }
  87. }
  88. } else if (auto const* reference_decl = qualified_type->getAs<clang::ReferenceType>()) {
  89. if (auto const* pointee = reference_decl->getPointeeCXXRecordDecl()) {
  90. if (record_inherits_from_cell(*pointee)) {
  91. result.is_valid = false;
  92. result.is_wrapped_in_gcptr = false;
  93. result.needs_visiting = true;
  94. return result;
  95. }
  96. }
  97. } else if (auto const* specialization = qualified_type->getAs<clang::TemplateSpecializationType>()) {
  98. auto template_type_name = specialization->getTemplateName().getAsTemplateDecl()->getName();
  99. if (template_type_name != "GCPtr" && template_type_name != "NonnullGCPtr" && template_type_name != "RawGCPtr")
  100. return result;
  101. auto const template_args = specialization->template_arguments();
  102. if (template_args.size() != 1)
  103. return result; // Not really valid, but will produce a compilation error anyway
  104. auto const& type_arg = template_args[0];
  105. auto const* record_type = type_arg.getAsType()->getAs<clang::RecordType>();
  106. if (!record_type)
  107. return result;
  108. auto const* record_decl = record_type->getAsCXXRecordDecl();
  109. if (!record_decl->hasDefinition())
  110. return result;
  111. result.is_wrapped_in_gcptr = true;
  112. result.is_valid = record_inherits_from_cell(*record_decl);
  113. result.needs_visiting = template_type_name != "RawGCPtr";
  114. }
  115. }
  116. return result;
  117. }
  118. bool LibJSGCVisitor::VisitCXXRecordDecl(clang::CXXRecordDecl* record)
  119. {
  120. using namespace clang::ast_matchers;
  121. if (!record || !record->isCompleteDefinition() || (!record->isClass() && !record->isStruct()))
  122. return true;
  123. // Cell triggers a bunch of warnings for its empty visit_edges implementation, but
  124. // it doesn't have any members anyways so it's fine to just ignore.
  125. auto qualified_name = record->getQualifiedNameAsString();
  126. if (qualified_name == "JS::Cell")
  127. return true;
  128. auto& diag_engine = m_context.getDiagnostics();
  129. std::vector<clang::FieldDecl const*> fields_that_need_visiting;
  130. for (clang::FieldDecl const* field : record->fields()) {
  131. auto validation_results = validate_field(field);
  132. if (!validation_results.is_valid) {
  133. if (validation_results.is_wrapped_in_gcptr) {
  134. auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "Specialization type must inherit from JS::Cell");
  135. diag_engine.Report(field->getLocation(), diag_id);
  136. } else {
  137. auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "%0 to JS::Cell type should be wrapped in %1");
  138. auto builder = diag_engine.Report(field->getLocation(), diag_id);
  139. if (field->getType()->isReferenceType()) {
  140. builder << "reference"
  141. << "JS::NonnullGCPtr";
  142. } else {
  143. builder << "pointer"
  144. << "JS::GCPtr";
  145. }
  146. }
  147. } else if (validation_results.needs_visiting) {
  148. fields_that_need_visiting.push_back(field);
  149. }
  150. }
  151. if (!record_inherits_from_cell(*record))
  152. return true;
  153. clang::DeclarationName name = &m_context.Idents.get("visit_edges");
  154. auto const* visit_edges_method = record->lookup(name).find_first<clang::CXXMethodDecl>();
  155. if (!visit_edges_method && !fields_that_need_visiting.empty()) {
  156. auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "JS::Cell-inheriting class %0 contains a GC-allocated member %1 but has no visit_edges method");
  157. auto builder = diag_engine.Report(record->getLocation(), diag_id);
  158. builder << record->getName()
  159. << fields_that_need_visiting[0];
  160. }
  161. if (!visit_edges_method || !visit_edges_method->getBody())
  162. return true;
  163. // Search for a call to Base::visit_edges. Note that this also has the nice side effect of
  164. // ensuring the classes use JS_CELL/JS_OBJECT, as Base will not be defined if they do not.
  165. MatchFinder base_visit_edges_finder;
  166. SimpleCollectMatchesCallback<clang::MemberExpr> base_visit_edges_callback("member-call");
  167. auto base_visit_edges_matcher = cxxMethodDecl(
  168. ofClass(hasName(qualified_name)),
  169. functionDecl(hasName("visit_edges")),
  170. isOverride(),
  171. hasDescendant(memberExpr(member(hasName("visit_edges"))).bind("member-call")));
  172. base_visit_edges_finder.addMatcher(base_visit_edges_matcher, &base_visit_edges_callback);
  173. base_visit_edges_finder.matchAST(m_context);
  174. bool call_to_base_visit_edges_found = false;
  175. for (auto const* call_expr : base_visit_edges_callback.matches()) {
  176. // FIXME: Can we constrain the matcher above to avoid looking directly at the source code?
  177. auto const* source_chars = m_context.getSourceManager().getCharacterData(call_expr->getBeginLoc());
  178. if (strncmp(source_chars, "Base::", 6) == 0) {
  179. call_to_base_visit_edges_found = true;
  180. break;
  181. }
  182. }
  183. if (!call_to_base_visit_edges_found) {
  184. auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "Missing call to Base::visit_edges");
  185. diag_engine.Report(visit_edges_method->getBeginLoc(), diag_id);
  186. }
  187. // Search for uses of all fields that need visiting. We don't ensure they are _actually_ visited
  188. // with a call to visitor.visit(...), as that is too complex. Instead, we just assume that if the
  189. // field is accessed at all, then it is visited.
  190. if (fields_that_need_visiting.empty())
  191. return true;
  192. MatchFinder field_access_finder;
  193. SimpleCollectMatchesCallback<clang::MemberExpr> field_access_callback("member-expr");
  194. auto field_access_matcher = memberExpr(
  195. hasAncestor(cxxMethodDecl(hasName("visit_edges"))),
  196. hasObjectExpression(hasType(pointsTo(cxxRecordDecl(hasName(record->getName()))))))
  197. .bind("member-expr");
  198. field_access_finder.addMatcher(field_access_matcher, &field_access_callback);
  199. field_access_finder.matchAST(visit_edges_method->getASTContext());
  200. std::unordered_set<std::string> fields_that_are_visited;
  201. for (auto const* member_expr : field_access_callback.matches())
  202. fields_that_are_visited.insert(member_expr->getMemberNameInfo().getAsString());
  203. auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "GC-allocated member is not visited in %0::visit_edges");
  204. for (auto const* field : fields_that_need_visiting) {
  205. if (!fields_that_are_visited.contains(field->getNameAsString())) {
  206. auto builder = diag_engine.Report(field->getBeginLoc(), diag_id);
  207. builder << record->getName();
  208. }
  209. }
  210. return true;
  211. }
  212. void LibJSGCASTConsumer::HandleTranslationUnit(clang::ASTContext& context)
  213. {
  214. LibJSGCVisitor visitor { context };
  215. visitor.TraverseDecl(context.getTranslationUnitDecl());
  216. }
  217. static clang::FrontendPluginRegistry::Add<LibJSGCPluginAction> X("libjs_gc_scanner", "analyze LibJS GC usage");