LibJSGCPluginAction.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. /*
  2. * Copyright (c) 2024, Matthew Olsson <mattco@serenityos.org>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #include "LibJSGCPluginAction.h"
  7. #include <clang/ASTMatchers/ASTMatchFinder.h>
  8. #include <clang/ASTMatchers/ASTMatchers.h>
  9. #include <clang/Basic/SourceManager.h>
  10. #include <clang/Frontend/CompilerInstance.h>
  11. #include <clang/Frontend/FrontendPluginRegistry.h>
  12. #include <clang/Lex/MacroArgs.h>
  13. #include <unordered_set>
  14. template<typename T>
  15. class SimpleCollectMatchesCallback : public clang::ast_matchers::MatchFinder::MatchCallback {
  16. public:
  17. explicit SimpleCollectMatchesCallback(std::string name)
  18. : m_name(std::move(name))
  19. {
  20. }
  21. void run(clang::ast_matchers::MatchFinder::MatchResult const& result) override
  22. {
  23. if (auto const* node = result.Nodes.getNodeAs<T>(m_name))
  24. m_matches.push_back(node);
  25. }
  26. auto const& matches() const { return m_matches; }
  27. private:
  28. std::string m_name;
  29. std::vector<T const*> m_matches;
  30. };
  31. bool record_inherits_from_cell(clang::CXXRecordDecl const& record)
  32. {
  33. if (!record.isCompleteDefinition())
  34. return false;
  35. bool inherits_from_cell = record.getQualifiedNameAsString() == "JS::Cell";
  36. record.forallBases([&](clang::CXXRecordDecl const* base) -> bool {
  37. if (base->getQualifiedNameAsString() == "JS::Cell") {
  38. inherits_from_cell = true;
  39. return false;
  40. }
  41. return true;
  42. });
  43. return inherits_from_cell;
  44. }
  45. std::vector<clang::QualType> get_all_qualified_types(clang::QualType const& type)
  46. {
  47. std::vector<clang::QualType> qualified_types;
  48. if (auto const* template_specialization = type->getAs<clang::TemplateSpecializationType>()) {
  49. auto specialization_name = template_specialization->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
  50. // Do not unwrap GCPtr/NonnullGCPtr/MarkedVector
  51. if (specialization_name == "JS::GCPtr" || specialization_name == "JS::NonnullGCPtr" || specialization_name == "JS::RawGCPtr" || specialization_name == "JS::MarkedVector") {
  52. qualified_types.push_back(type);
  53. } else {
  54. auto const template_arguments = template_specialization->template_arguments();
  55. for (size_t i = 0; i < template_arguments.size(); i++) {
  56. auto const& template_arg = template_arguments[i];
  57. if (template_arg.getKind() == clang::TemplateArgument::Type) {
  58. auto template_qualified_types = get_all_qualified_types(template_arg.getAsType());
  59. std::move(template_qualified_types.begin(), template_qualified_types.end(), std::back_inserter(qualified_types));
  60. }
  61. }
  62. }
  63. } else {
  64. qualified_types.push_back(type);
  65. }
  66. return qualified_types;
  67. }
  // Outcome of checking one field's type for GC-related mistakes.
  struct FieldValidationResult {
      // Cleared when the field holds a raw pointer/reference to a JS::Cell type, or a GC pointer
      // template wrapping a type that does not inherit from JS::Cell.
      bool is_valid = false;
      // Set when the checked type was a GCPtr/NonnullGCPtr/RawGCPtr specialization.
      bool is_wrapped_in_gcptr = false;
      // Set when the field refers to a GC-allocated object that visit_edges should visit.
      bool needs_visiting = false;
  };
  73. FieldValidationResult validate_field(clang::FieldDecl const* field_decl)
  74. {
  75. auto type = field_decl->getType();
  76. if (auto const* elaborated_type = llvm::dyn_cast<clang::ElaboratedType>(type.getTypePtr()))
  77. type = elaborated_type->desugar();
  78. FieldValidationResult result { .is_valid = true };
  79. for (auto const& qualified_type : get_all_qualified_types(type)) {
  80. if (auto const* pointer_decl = qualified_type->getAs<clang::PointerType>()) {
  81. if (auto const* pointee = pointer_decl->getPointeeCXXRecordDecl()) {
  82. if (record_inherits_from_cell(*pointee)) {
  83. result.is_valid = false;
  84. result.is_wrapped_in_gcptr = false;
  85. result.needs_visiting = true;
  86. return result;
  87. }
  88. }
  89. } else if (auto const* reference_decl = qualified_type->getAs<clang::ReferenceType>()) {
  90. if (auto const* pointee = reference_decl->getPointeeCXXRecordDecl()) {
  91. if (record_inherits_from_cell(*pointee)) {
  92. result.is_valid = false;
  93. result.is_wrapped_in_gcptr = false;
  94. result.needs_visiting = true;
  95. return result;
  96. }
  97. }
  98. } else if (auto const* specialization = qualified_type->getAs<clang::TemplateSpecializationType>()) {
  99. auto template_type_name = specialization->getTemplateName().getAsTemplateDecl()->getName();
  100. if (template_type_name != "GCPtr" && template_type_name != "NonnullGCPtr" && template_type_name != "RawGCPtr")
  101. return result;
  102. auto const template_args = specialization->template_arguments();
  103. if (template_args.size() != 1)
  104. return result; // Not really valid, but will produce a compilation error anyway
  105. auto const& type_arg = template_args[0];
  106. auto const* record_type = type_arg.getAsType()->getAs<clang::RecordType>();
  107. if (!record_type)
  108. return result;
  109. auto const* record_decl = record_type->getAsCXXRecordDecl();
  110. if (!record_decl->hasDefinition())
  111. return result;
  112. result.is_wrapped_in_gcptr = true;
  113. result.is_valid = record_inherits_from_cell(*record_decl);
  114. result.needs_visiting = template_type_name != "RawGCPtr";
  115. }
  116. }
  117. return result;
  118. }
  119. bool LibJSGCVisitor::VisitCXXRecordDecl(clang::CXXRecordDecl* record)
  120. {
  121. using namespace clang::ast_matchers;
  122. if (!record || !record->isCompleteDefinition() || (!record->isClass() && !record->isStruct()))
  123. return true;
  124. // Cell triggers a bunch of warnings for its empty visit_edges implementation, but
  125. // it doesn't have any members anyways so it's fine to just ignore.
  126. auto qualified_name = record->getQualifiedNameAsString();
  127. if (qualified_name == "JS::Cell")
  128. return true;
  129. auto& diag_engine = m_context.getDiagnostics();
  130. std::vector<clang::FieldDecl const*> fields_that_need_visiting;
  131. for (clang::FieldDecl const* field : record->fields()) {
  132. auto validation_results = validate_field(field);
  133. if (!validation_results.is_valid) {
  134. if (validation_results.is_wrapped_in_gcptr) {
  135. auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "Specialization type must inherit from JS::Cell");
  136. diag_engine.Report(field->getLocation(), diag_id);
  137. } else {
  138. auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "%0 to JS::Cell type should be wrapped in %1");
  139. auto builder = diag_engine.Report(field->getLocation(), diag_id);
  140. if (field->getType()->isReferenceType()) {
  141. builder << "reference"
  142. << "JS::NonnullGCPtr";
  143. } else {
  144. builder << "pointer"
  145. << "JS::GCPtr";
  146. }
  147. }
  148. } else if (validation_results.needs_visiting) {
  149. fields_that_need_visiting.push_back(field);
  150. }
  151. }
  152. if (!record_inherits_from_cell(*record))
  153. return true;
  154. clang::DeclarationName name = &m_context.Idents.get("visit_edges");
  155. auto const* visit_edges_method = record->lookup(name).find_first<clang::CXXMethodDecl>();
  156. if (!visit_edges_method && !fields_that_need_visiting.empty()) {
  157. auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "JS::Cell-inheriting class %0 contains a GC-allocated member %1 but has no visit_edges method");
  158. auto builder = diag_engine.Report(record->getLocation(), diag_id);
  159. builder << record->getName()
  160. << fields_that_need_visiting[0];
  161. }
  162. if (!visit_edges_method || !visit_edges_method->getBody())
  163. return true;
  164. // Search for a call to Base::visit_edges. Note that this also has the nice side effect of
  165. // ensuring the classes use JS_CELL/JS_OBJECT, as Base will not be defined if they do not.
  166. MatchFinder base_visit_edges_finder;
  167. SimpleCollectMatchesCallback<clang::MemberExpr> base_visit_edges_callback("member-call");
  168. auto base_visit_edges_matcher = cxxMethodDecl(
  169. ofClass(hasName(qualified_name)),
  170. functionDecl(hasName("visit_edges")),
  171. isOverride(),
  172. hasDescendant(memberExpr(member(hasName("visit_edges"))).bind("member-call")));
  173. base_visit_edges_finder.addMatcher(base_visit_edges_matcher, &base_visit_edges_callback);
  174. base_visit_edges_finder.matchAST(m_context);
  175. bool call_to_base_visit_edges_found = false;
  176. for (auto const* call_expr : base_visit_edges_callback.matches()) {
  177. // FIXME: Can we constrain the matcher above to avoid looking directly at the source code?
  178. auto const* source_chars = m_context.getSourceManager().getCharacterData(call_expr->getBeginLoc());
  179. if (strncmp(source_chars, "Base::", 6) == 0) {
  180. call_to_base_visit_edges_found = true;
  181. break;
  182. }
  183. }
  184. if (!call_to_base_visit_edges_found) {
  185. auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "Missing call to Base::visit_edges");
  186. diag_engine.Report(visit_edges_method->getBeginLoc(), diag_id);
  187. }
  188. // Search for uses of all fields that need visiting. We don't ensure they are _actually_ visited
  189. // with a call to visitor.visit(...), as that is too complex. Instead, we just assume that if the
  190. // field is accessed at all, then it is visited.
  191. if (fields_that_need_visiting.empty())
  192. return true;
  193. MatchFinder field_access_finder;
  194. SimpleCollectMatchesCallback<clang::MemberExpr> field_access_callback("member-expr");
  195. auto field_access_matcher = memberExpr(
  196. hasAncestor(cxxMethodDecl(hasName("visit_edges"))),
  197. hasObjectExpression(hasType(pointsTo(cxxRecordDecl(hasName(record->getName()))))))
  198. .bind("member-expr");
  199. field_access_finder.addMatcher(field_access_matcher, &field_access_callback);
  200. field_access_finder.matchAST(visit_edges_method->getASTContext());
  201. std::unordered_set<std::string> fields_that_are_visited;
  202. for (auto const* member_expr : field_access_callback.matches())
  203. fields_that_are_visited.insert(member_expr->getMemberNameInfo().getAsString());
  204. auto diag_id = diag_engine.getCustomDiagID(clang::DiagnosticsEngine::Error, "GC-allocated member is not visited in %0::visit_edges");
  205. for (auto const* field : fields_that_need_visiting) {
  206. if (!fields_that_are_visited.contains(field->getNameAsString())) {
  207. auto builder = diag_engine.Report(field->getBeginLoc(), diag_id);
  208. builder << record->getName();
  209. }
  210. }
  211. return true;
  212. }
  213. void LibJSGCASTConsumer::HandleTranslationUnit(clang::ASTContext& context)
  214. {
  215. LibJSGCVisitor visitor { context };
  216. visitor.TraverseDecl(context.getTranslationUnitDecl());
  217. }
  218. char const* LibJSCellMacro::type_name(Type type)
  219. {
  220. switch (type) {
  221. case Type::JSCell:
  222. return "JS_CELL";
  223. case Type::JSObject:
  224. return "JS_OBJECT";
  225. case Type::JSEnvironment:
  226. return "JS_ENVIRONMENT";
  227. case Type::JSPrototypeObject:
  228. return "JS_PROTOTYPE_OBJECT";
  229. case Type::WebPlatformObject:
  230. return "WEB_PLATFORM_OBJECT";
  231. default:
  232. __builtin_unreachable();
  233. }
  234. }
  235. void LibJSPPCallbacks::LexedFileChanged(clang::FileID curr_fid, LexedFileChangeReason reason, clang::SrcMgr::CharacteristicKind, clang::FileID, clang::SourceLocation)
  236. {
  237. if (reason == LexedFileChangeReason::EnterFile) {
  238. m_curr_fid_hash_stack.push_back(curr_fid.getHashValue());
  239. } else {
  240. assert(!m_curr_fid_hash_stack.empty());
  241. m_curr_fid_hash_stack.pop_back();
  242. }
  243. }
  244. void LibJSPPCallbacks::MacroExpands(clang::Token const& name_token, clang::MacroDefinition const&, clang::SourceRange range, clang::MacroArgs const* args)
  245. {
  246. if (auto* ident_info = name_token.getIdentifierInfo()) {
  247. static llvm::StringMap<LibJSCellMacro::Type> libjs_macro_types {
  248. { "JS_CELL", LibJSCellMacro::Type::JSCell },
  249. { "JS_OBJECT", LibJSCellMacro::Type::JSObject },
  250. { "JS_ENVIRONMENT", LibJSCellMacro::Type::JSEnvironment },
  251. { "JS_PROTOTYPE_OBJECT", LibJSCellMacro::Type::JSPrototypeObject },
  252. { "WEB_PLATFORM_OBJECT", LibJSCellMacro::Type::WebPlatformObject },
  253. };
  254. auto name = ident_info->getName();
  255. if (auto it = libjs_macro_types.find(name); it != libjs_macro_types.end()) {
  256. LibJSCellMacro macro { range, it->second, {} };
  257. for (size_t arg_index = 0; arg_index < args->getNumMacroArguments(); arg_index++) {
  258. auto const* first_token = args->getUnexpArgument(arg_index);
  259. auto stringified_token = clang::MacroArgs::StringifyArgument(first_token, m_preprocessor, false, range.getBegin(), range.getEnd());
  260. // The token includes leading and trailing quotes
  261. auto len = strlen(stringified_token.getLiteralData());
  262. std::string arg_text { stringified_token.getLiteralData() + 1, len - 2 };
  263. macro.args.push_back({ arg_text, first_token->getLocation() });
  264. }
  265. assert(!m_curr_fid_hash_stack.empty());
  266. auto curr_fid_hash = m_curr_fid_hash_stack.back();
  267. if (m_macro_map.find(curr_fid_hash) == m_macro_map.end())
  268. m_macro_map[curr_fid_hash] = {};
  269. m_macro_map[curr_fid_hash].push_back(macro);
  270. }
  271. }
  272. }
  273. static clang::FrontendPluginRegistry::Add<LibJSGCPluginAction> X("libjs_gc_scanner", "analyze LibJS GC usage");