xref: /openbsd-src/gnu/llvm/clang/lib/Tooling/Refactoring/Rename/USRFindingAction.cpp (revision 12c855180aad702bbcca06e0398d774beeafb155)
1e5dd7070Spatrick //===--- USRFindingAction.cpp - Clang refactoring library -----------------===//
2e5dd7070Spatrick //
3e5dd7070Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e5dd7070Spatrick // See https://llvm.org/LICENSE.txt for license information.
5e5dd7070Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e5dd7070Spatrick //
7e5dd7070Spatrick //===----------------------------------------------------------------------===//
8e5dd7070Spatrick ///
9e5dd7070Spatrick /// \file
10e5dd7070Spatrick /// Provides an action to find USR for the symbol at <offset>, as well as
11e5dd7070Spatrick /// all additional USRs.
12e5dd7070Spatrick ///
13e5dd7070Spatrick //===----------------------------------------------------------------------===//
14e5dd7070Spatrick 
15e5dd7070Spatrick #include "clang/Tooling/Refactoring/Rename/USRFindingAction.h"
16e5dd7070Spatrick #include "clang/AST/AST.h"
17e5dd7070Spatrick #include "clang/AST/ASTConsumer.h"
18e5dd7070Spatrick #include "clang/AST/ASTContext.h"
19e5dd7070Spatrick #include "clang/AST/Decl.h"
20e5dd7070Spatrick #include "clang/AST/RecursiveASTVisitor.h"
21e5dd7070Spatrick #include "clang/Basic/FileManager.h"
22e5dd7070Spatrick #include "clang/Frontend/CompilerInstance.h"
23e5dd7070Spatrick #include "clang/Frontend/FrontendAction.h"
24e5dd7070Spatrick #include "clang/Lex/Lexer.h"
25e5dd7070Spatrick #include "clang/Lex/Preprocessor.h"
26e5dd7070Spatrick #include "clang/Tooling/CommonOptionsParser.h"
27e5dd7070Spatrick #include "clang/Tooling/Refactoring.h"
28e5dd7070Spatrick #include "clang/Tooling/Refactoring/Rename/USRFinder.h"
29e5dd7070Spatrick #include "clang/Tooling/Tooling.h"
30e5dd7070Spatrick 
31e5dd7070Spatrick #include <algorithm>
32e5dd7070Spatrick #include <set>
33e5dd7070Spatrick #include <string>
34e5dd7070Spatrick #include <vector>
35e5dd7070Spatrick 
36e5dd7070Spatrick using namespace llvm;
37e5dd7070Spatrick 
38e5dd7070Spatrick namespace clang {
39e5dd7070Spatrick namespace tooling {
40e5dd7070Spatrick 
/// Maps a declaration found at the cursor (or by name) to the declaration
/// whose USRs drive the rename.
///
/// Constructors and destructors are canonicalized to their parent class, since
/// renaming either one really means renaming the class. Every other
/// declaration is returned unchanged, and a null input yields null.
const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) {
  if (FoundDecl == nullptr)
    return nullptr;

  if (const auto *Ctor = dyn_cast<CXXConstructorDecl>(FoundDecl))
    return Ctor->getParent();
  if (const auto *Dtor = dyn_cast<CXXDestructorDecl>(FoundDecl))
    return Dtor->getParent();

  // FIXME: (Alex L): Canonicalize implicit template instantiations the same
  // way the indexer does.

  // Keep the declaration's doc comment in sync whenever these
  // canonicalization rules change.
  return FoundDecl;
}
57e5dd7070Spatrick 
58e5dd7070Spatrick namespace {
59e5dd7070Spatrick // NamedDeclFindingConsumer should delegate finding USRs of given Decl to
60e5dd7070Spatrick // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
61e5dd7070Spatrick // Decl refers to class and adds USRs of all overridden methods if Decl refers
62e5dd7070Spatrick // to virtual method.
63e5dd7070Spatrick class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
64e5dd7070Spatrick public:
AdditionalUSRFinder(const Decl * FoundDecl,ASTContext & Context)65e5dd7070Spatrick   AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
66e5dd7070Spatrick       : FoundDecl(FoundDecl), Context(Context) {}
67e5dd7070Spatrick 
Find()68e5dd7070Spatrick   std::vector<std::string> Find() {
69e5dd7070Spatrick     // Fill OverriddenMethods and PartialSpecs storages.
70e5dd7070Spatrick     TraverseAST(Context);
71e5dd7070Spatrick     if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
72e5dd7070Spatrick       addUSRsOfOverridenFunctions(MethodDecl);
73e5dd7070Spatrick       for (const auto &OverriddenMethod : OverriddenMethods) {
74e5dd7070Spatrick         if (checkIfOverriddenFunctionAscends(OverriddenMethod))
75e5dd7070Spatrick           USRSet.insert(getUSRForDecl(OverriddenMethod));
76e5dd7070Spatrick       }
77e5dd7070Spatrick       addUSRsOfInstantiatedMethods(MethodDecl);
78e5dd7070Spatrick     } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
79e5dd7070Spatrick       handleCXXRecordDecl(RecordDecl);
80e5dd7070Spatrick     } else if (const auto *TemplateDecl =
81e5dd7070Spatrick                    dyn_cast<ClassTemplateDecl>(FoundDecl)) {
82e5dd7070Spatrick       handleClassTemplateDecl(TemplateDecl);
83a9ac8606Spatrick     } else if (const auto *FD = dyn_cast<FunctionDecl>(FoundDecl)) {
84a9ac8606Spatrick       USRSet.insert(getUSRForDecl(FD));
85a9ac8606Spatrick       if (const auto *FTD = FD->getPrimaryTemplate())
86a9ac8606Spatrick         handleFunctionTemplateDecl(FTD);
87a9ac8606Spatrick     } else if (const auto *FD = dyn_cast<FunctionTemplateDecl>(FoundDecl)) {
88a9ac8606Spatrick       handleFunctionTemplateDecl(FD);
89a9ac8606Spatrick     } else if (const auto *VTD = dyn_cast<VarTemplateDecl>(FoundDecl)) {
90a9ac8606Spatrick       handleVarTemplateDecl(VTD);
91a9ac8606Spatrick     } else if (const auto *VD =
92a9ac8606Spatrick                    dyn_cast<VarTemplateSpecializationDecl>(FoundDecl)) {
93a9ac8606Spatrick       // FIXME: figure out why FoundDecl can be a VarTemplateSpecializationDecl.
94a9ac8606Spatrick       handleVarTemplateDecl(VD->getSpecializedTemplate());
95a9ac8606Spatrick     } else if (const auto *VD = dyn_cast<VarDecl>(FoundDecl)) {
96a9ac8606Spatrick       USRSet.insert(getUSRForDecl(VD));
97a9ac8606Spatrick       if (const auto *VTD = VD->getDescribedVarTemplate())
98a9ac8606Spatrick         handleVarTemplateDecl(VTD);
99e5dd7070Spatrick     } else {
100e5dd7070Spatrick       USRSet.insert(getUSRForDecl(FoundDecl));
101e5dd7070Spatrick     }
102e5dd7070Spatrick     return std::vector<std::string>(USRSet.begin(), USRSet.end());
103e5dd7070Spatrick   }
104e5dd7070Spatrick 
shouldVisitTemplateInstantiations() const105e5dd7070Spatrick   bool shouldVisitTemplateInstantiations() const { return true; }
106e5dd7070Spatrick 
VisitCXXMethodDecl(const CXXMethodDecl * MethodDecl)107e5dd7070Spatrick   bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
108e5dd7070Spatrick     if (MethodDecl->isVirtual())
109e5dd7070Spatrick       OverriddenMethods.push_back(MethodDecl);
110e5dd7070Spatrick     if (MethodDecl->getInstantiatedFromMemberFunction())
111e5dd7070Spatrick       InstantiatedMethods.push_back(MethodDecl);
112e5dd7070Spatrick     return true;
113e5dd7070Spatrick   }
114e5dd7070Spatrick 
115e5dd7070Spatrick private:
handleCXXRecordDecl(const CXXRecordDecl * RecordDecl)116e5dd7070Spatrick   void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
117e5dd7070Spatrick     if (!RecordDecl->getDefinition()) {
118e5dd7070Spatrick       USRSet.insert(getUSRForDecl(RecordDecl));
119e5dd7070Spatrick       return;
120e5dd7070Spatrick     }
121e5dd7070Spatrick     RecordDecl = RecordDecl->getDefinition();
122e5dd7070Spatrick     if (const auto *ClassTemplateSpecDecl =
123e5dd7070Spatrick             dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl))
124e5dd7070Spatrick       handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
125e5dd7070Spatrick     addUSRsOfCtorDtors(RecordDecl);
126e5dd7070Spatrick   }
127e5dd7070Spatrick 
handleClassTemplateDecl(const ClassTemplateDecl * TemplateDecl)128e5dd7070Spatrick   void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
129e5dd7070Spatrick     for (const auto *Specialization : TemplateDecl->specializations())
130e5dd7070Spatrick       addUSRsOfCtorDtors(Specialization);
131a9ac8606Spatrick     SmallVector<ClassTemplatePartialSpecializationDecl *, 4> PartialSpecs;
132a9ac8606Spatrick     TemplateDecl->getPartialSpecializations(PartialSpecs);
133a9ac8606Spatrick     for (const auto *Spec : PartialSpecs)
134a9ac8606Spatrick       addUSRsOfCtorDtors(Spec);
135e5dd7070Spatrick     addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
136e5dd7070Spatrick   }
137e5dd7070Spatrick 
handleFunctionTemplateDecl(const FunctionTemplateDecl * FTD)138a9ac8606Spatrick   void handleFunctionTemplateDecl(const FunctionTemplateDecl *FTD) {
139a9ac8606Spatrick     USRSet.insert(getUSRForDecl(FTD));
140a9ac8606Spatrick     USRSet.insert(getUSRForDecl(FTD->getTemplatedDecl()));
141a9ac8606Spatrick     for (const auto *S : FTD->specializations())
142a9ac8606Spatrick       USRSet.insert(getUSRForDecl(S));
143a9ac8606Spatrick   }
144a9ac8606Spatrick 
handleVarTemplateDecl(const VarTemplateDecl * VTD)145a9ac8606Spatrick   void handleVarTemplateDecl(const VarTemplateDecl *VTD) {
146a9ac8606Spatrick     USRSet.insert(getUSRForDecl(VTD));
147a9ac8606Spatrick     USRSet.insert(getUSRForDecl(VTD->getTemplatedDecl()));
148*12c85518Srobert     for (const auto *Spec : VTD->specializations())
149a9ac8606Spatrick       USRSet.insert(getUSRForDecl(Spec));
150a9ac8606Spatrick     SmallVector<VarTemplatePartialSpecializationDecl *, 4> PartialSpecs;
151a9ac8606Spatrick     VTD->getPartialSpecializations(PartialSpecs);
152*12c85518Srobert     for (const auto *Spec : PartialSpecs)
153a9ac8606Spatrick       USRSet.insert(getUSRForDecl(Spec));
154a9ac8606Spatrick   }
155a9ac8606Spatrick 
addUSRsOfCtorDtors(const CXXRecordDecl * RD)156ec727ea7Spatrick   void addUSRsOfCtorDtors(const CXXRecordDecl *RD) {
157ec727ea7Spatrick     const auto* RecordDecl = RD->getDefinition();
158e5dd7070Spatrick 
159e5dd7070Spatrick     // Skip if the CXXRecordDecl doesn't have definition.
160ec727ea7Spatrick     if (!RecordDecl) {
161ec727ea7Spatrick       USRSet.insert(getUSRForDecl(RD));
162e5dd7070Spatrick       return;
163ec727ea7Spatrick     }
164e5dd7070Spatrick 
165e5dd7070Spatrick     for (const auto *CtorDecl : RecordDecl->ctors())
166e5dd7070Spatrick       USRSet.insert(getUSRForDecl(CtorDecl));
167ec727ea7Spatrick     // Add template constructor decls, they are not in ctors() unfortunately.
168ec727ea7Spatrick     if (RecordDecl->hasUserDeclaredConstructor())
169ec727ea7Spatrick       for (const auto *D : RecordDecl->decls())
170ec727ea7Spatrick         if (const auto *FTD = dyn_cast<FunctionTemplateDecl>(D))
171ec727ea7Spatrick           if (const auto *Ctor =
172ec727ea7Spatrick                   dyn_cast<CXXConstructorDecl>(FTD->getTemplatedDecl()))
173ec727ea7Spatrick             USRSet.insert(getUSRForDecl(Ctor));
174e5dd7070Spatrick 
175e5dd7070Spatrick     USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
176e5dd7070Spatrick     USRSet.insert(getUSRForDecl(RecordDecl));
177e5dd7070Spatrick   }
178e5dd7070Spatrick 
addUSRsOfOverridenFunctions(const CXXMethodDecl * MethodDecl)179e5dd7070Spatrick   void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
180e5dd7070Spatrick     USRSet.insert(getUSRForDecl(MethodDecl));
181e5dd7070Spatrick     // Recursively visit each OverridenMethod.
182e5dd7070Spatrick     for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
183e5dd7070Spatrick       addUSRsOfOverridenFunctions(OverriddenMethod);
184e5dd7070Spatrick   }
185e5dd7070Spatrick 
addUSRsOfInstantiatedMethods(const CXXMethodDecl * MethodDecl)186e5dd7070Spatrick   void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) {
187e5dd7070Spatrick     // For renaming a class template method, all references of the instantiated
188e5dd7070Spatrick     // member methods should be renamed too, so add USRs of the instantiated
189e5dd7070Spatrick     // methods to the USR set.
190e5dd7070Spatrick     USRSet.insert(getUSRForDecl(MethodDecl));
191e5dd7070Spatrick     if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction())
192e5dd7070Spatrick       USRSet.insert(getUSRForDecl(FT));
193e5dd7070Spatrick     for (const auto *Method : InstantiatedMethods) {
194e5dd7070Spatrick       if (USRSet.find(getUSRForDecl(
195e5dd7070Spatrick               Method->getInstantiatedFromMemberFunction())) != USRSet.end())
196e5dd7070Spatrick         USRSet.insert(getUSRForDecl(Method));
197e5dd7070Spatrick     }
198e5dd7070Spatrick   }
199e5dd7070Spatrick 
checkIfOverriddenFunctionAscends(const CXXMethodDecl * MethodDecl)200e5dd7070Spatrick   bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
201e5dd7070Spatrick     for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
202e5dd7070Spatrick       if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
203e5dd7070Spatrick         return true;
204e5dd7070Spatrick       return checkIfOverriddenFunctionAscends(OverriddenMethod);
205e5dd7070Spatrick     }
206e5dd7070Spatrick     return false;
207e5dd7070Spatrick   }
208e5dd7070Spatrick 
209e5dd7070Spatrick   const Decl *FoundDecl;
210e5dd7070Spatrick   ASTContext &Context;
211e5dd7070Spatrick   std::set<std::string> USRSet;
212e5dd7070Spatrick   std::vector<const CXXMethodDecl *> OverriddenMethods;
213e5dd7070Spatrick   std::vector<const CXXMethodDecl *> InstantiatedMethods;
214e5dd7070Spatrick };
215e5dd7070Spatrick } // namespace
216e5dd7070Spatrick 
getUSRsForDeclaration(const NamedDecl * ND,ASTContext & Context)217e5dd7070Spatrick std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND,
218e5dd7070Spatrick                                                ASTContext &Context) {
219e5dd7070Spatrick   AdditionalUSRFinder Finder(ND, Context);
220e5dd7070Spatrick   return Finder.Find();
221e5dd7070Spatrick }
222e5dd7070Spatrick 
223e5dd7070Spatrick class NamedDeclFindingConsumer : public ASTConsumer {
224e5dd7070Spatrick public:
NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,ArrayRef<std::string> QualifiedNames,std::vector<std::string> & SpellingNames,std::vector<std::vector<std::string>> & USRList,bool Force,bool & ErrorOccurred)225e5dd7070Spatrick   NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
226e5dd7070Spatrick                            ArrayRef<std::string> QualifiedNames,
227e5dd7070Spatrick                            std::vector<std::string> &SpellingNames,
228e5dd7070Spatrick                            std::vector<std::vector<std::string>> &USRList,
229e5dd7070Spatrick                            bool Force, bool &ErrorOccurred)
230e5dd7070Spatrick       : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
231e5dd7070Spatrick         SpellingNames(SpellingNames), USRList(USRList), Force(Force),
232e5dd7070Spatrick         ErrorOccurred(ErrorOccurred) {}
233e5dd7070Spatrick 
234e5dd7070Spatrick private:
FindSymbol(ASTContext & Context,const SourceManager & SourceMgr,unsigned SymbolOffset,const std::string & QualifiedName)235e5dd7070Spatrick   bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
236e5dd7070Spatrick                   unsigned SymbolOffset, const std::string &QualifiedName) {
237e5dd7070Spatrick     DiagnosticsEngine &Engine = Context.getDiagnostics();
238e5dd7070Spatrick     const FileID MainFileID = SourceMgr.getMainFileID();
239e5dd7070Spatrick 
240e5dd7070Spatrick     if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
241e5dd7070Spatrick       ErrorOccurred = true;
242e5dd7070Spatrick       unsigned InvalidOffset = Engine.getCustomDiagID(
243e5dd7070Spatrick           DiagnosticsEngine::Error,
244e5dd7070Spatrick           "SourceLocation in file %0 at offset %1 is invalid");
245e5dd7070Spatrick       Engine.Report(SourceLocation(), InvalidOffset)
246e5dd7070Spatrick           << SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset;
247e5dd7070Spatrick       return false;
248e5dd7070Spatrick     }
249e5dd7070Spatrick 
250e5dd7070Spatrick     const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
251e5dd7070Spatrick                                      .getLocWithOffset(SymbolOffset);
252e5dd7070Spatrick     const NamedDecl *FoundDecl = QualifiedName.empty()
253e5dd7070Spatrick                                      ? getNamedDeclAt(Context, Point)
254e5dd7070Spatrick                                      : getNamedDeclFor(Context, QualifiedName);
255e5dd7070Spatrick 
256e5dd7070Spatrick     if (FoundDecl == nullptr) {
257e5dd7070Spatrick       if (QualifiedName.empty()) {
258e5dd7070Spatrick         FullSourceLoc FullLoc(Point, SourceMgr);
259e5dd7070Spatrick         unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
260e5dd7070Spatrick             DiagnosticsEngine::Error,
261e5dd7070Spatrick             "clang-rename could not find symbol (offset %0)");
262e5dd7070Spatrick         Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
263e5dd7070Spatrick         ErrorOccurred = true;
264e5dd7070Spatrick         return false;
265e5dd7070Spatrick       }
266e5dd7070Spatrick 
267e5dd7070Spatrick       if (Force) {
268e5dd7070Spatrick         SpellingNames.push_back(std::string());
269e5dd7070Spatrick         USRList.push_back(std::vector<std::string>());
270e5dd7070Spatrick         return true;
271e5dd7070Spatrick       }
272e5dd7070Spatrick 
273e5dd7070Spatrick       unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
274e5dd7070Spatrick           DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
275e5dd7070Spatrick       Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
276e5dd7070Spatrick       ErrorOccurred = true;
277e5dd7070Spatrick       return false;
278e5dd7070Spatrick     }
279e5dd7070Spatrick 
280e5dd7070Spatrick     FoundDecl = getCanonicalSymbolDeclaration(FoundDecl);
281e5dd7070Spatrick     SpellingNames.push_back(FoundDecl->getNameAsString());
282e5dd7070Spatrick     AdditionalUSRFinder Finder(FoundDecl, Context);
283e5dd7070Spatrick     USRList.push_back(Finder.Find());
284e5dd7070Spatrick     return true;
285e5dd7070Spatrick   }
286e5dd7070Spatrick 
HandleTranslationUnit(ASTContext & Context)287e5dd7070Spatrick   void HandleTranslationUnit(ASTContext &Context) override {
288e5dd7070Spatrick     const SourceManager &SourceMgr = Context.getSourceManager();
289e5dd7070Spatrick     for (unsigned Offset : SymbolOffsets) {
290e5dd7070Spatrick       if (!FindSymbol(Context, SourceMgr, Offset, ""))
291e5dd7070Spatrick         return;
292e5dd7070Spatrick     }
293e5dd7070Spatrick     for (const std::string &QualifiedName : QualifiedNames) {
294e5dd7070Spatrick       if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
295e5dd7070Spatrick         return;
296e5dd7070Spatrick     }
297e5dd7070Spatrick   }
298e5dd7070Spatrick 
299e5dd7070Spatrick   ArrayRef<unsigned> SymbolOffsets;
300e5dd7070Spatrick   ArrayRef<std::string> QualifiedNames;
301e5dd7070Spatrick   std::vector<std::string> &SpellingNames;
302e5dd7070Spatrick   std::vector<std::vector<std::string>> &USRList;
303e5dd7070Spatrick   bool Force;
304e5dd7070Spatrick   bool &ErrorOccurred;
305e5dd7070Spatrick };
306e5dd7070Spatrick 
newASTConsumer()307e5dd7070Spatrick std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
308e5dd7070Spatrick   return std::make_unique<NamedDeclFindingConsumer>(
309e5dd7070Spatrick       SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force,
310e5dd7070Spatrick       ErrorOccurred);
311e5dd7070Spatrick }
312e5dd7070Spatrick 
313e5dd7070Spatrick } // end namespace tooling
314e5dd7070Spatrick } // end namespace clang
315