xref: /llvm-project/clang-tools-extra/clang-include-fixer/IncludeFixer.cpp (revision df9a14d7bbf1180e4f1474254c9d7ed6bcb4ce55)
1 //===-- IncludeFixer.cpp - Include inserter based on sema callbacks -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IncludeFixer.h"
10 #include "clang/Format/Format.h"
11 #include "clang/Frontend/CompilerInstance.h"
12 #include "clang/Lex/HeaderSearch.h"
13 #include "clang/Lex/Preprocessor.h"
14 #include "clang/Parse/ParseAST.h"
15 #include "clang/Sema/Sema.h"
16 #include "llvm/Support/Debug.h"
17 #include "llvm/Support/raw_ostream.h"
18 
19 #define DEBUG_TYPE "clang-include-fixer"
20 
21 using namespace clang;
22 
23 namespace clang {
24 namespace include_fixer {
25 namespace {
26 /// Manages the parse, gathers include suggestions.
27 class Action : public clang::ASTFrontendAction {
28 public:
29   explicit Action(SymbolIndexManager &SymbolIndexMgr, bool MinimizeIncludePaths)
30       : SemaSource(new IncludeFixerSemaSource(SymbolIndexMgr,
31                                               MinimizeIncludePaths,
32                                               /*GenerateDiagnostics=*/false)) {}
33 
34   std::unique_ptr<clang::ASTConsumer>
35   CreateASTConsumer(clang::CompilerInstance &Compiler,
36                     StringRef InFile) override {
37     SemaSource->setFilePath(InFile);
38     return std::make_unique<clang::ASTConsumer>();
39   }
40 
41   void ExecuteAction() override {
42     clang::CompilerInstance *Compiler = &getCompilerInstance();
43     assert(!Compiler->hasSema() && "CI already has Sema");
44 
45     // Set up our hooks into sema and parse the AST.
46     if (hasCodeCompletionSupport() &&
47         !Compiler->getFrontendOpts().CodeCompletionAt.FileName.empty())
48       Compiler->createCodeCompletionConsumer();
49 
50     clang::CodeCompleteConsumer *CompletionConsumer = nullptr;
51     if (Compiler->hasCodeCompletionConsumer())
52       CompletionConsumer = &Compiler->getCodeCompletionConsumer();
53 
54     Compiler->createSema(getTranslationUnitKind(), CompletionConsumer);
55     SemaSource->setCompilerInstance(Compiler);
56     Compiler->getSema().addExternalSource(SemaSource.get());
57 
58     clang::ParseAST(Compiler->getSema(), Compiler->getFrontendOpts().ShowStats,
59                     Compiler->getFrontendOpts().SkipFunctionBodies);
60   }
61 
62   IncludeFixerContext
63   getIncludeFixerContext(const clang::SourceManager &SourceManager,
64                          clang::HeaderSearch &HeaderSearch) const {
65     return SemaSource->getIncludeFixerContext(SourceManager, HeaderSearch,
66                                               SemaSource->getMatchedSymbols());
67   }
68 
69 private:
70   IntrusiveRefCntPtr<IncludeFixerSemaSource> SemaSource;
71 };
72 
73 } // namespace
74 
75 IncludeFixerActionFactory::IncludeFixerActionFactory(
76     SymbolIndexManager &SymbolIndexMgr,
77     std::vector<IncludeFixerContext> &Contexts, StringRef StyleName,
78     bool MinimizeIncludePaths)
79     : SymbolIndexMgr(SymbolIndexMgr), Contexts(Contexts),
80       MinimizeIncludePaths(MinimizeIncludePaths) {}
81 
82 IncludeFixerActionFactory::~IncludeFixerActionFactory() = default;
83 
84 bool IncludeFixerActionFactory::runInvocation(
85     std::shared_ptr<clang::CompilerInvocation> Invocation,
86     clang::FileManager *Files,
87     std::shared_ptr<clang::PCHContainerOperations> PCHContainerOps,
88     clang::DiagnosticConsumer *Diagnostics) {
89   assert(Invocation->getFrontendOpts().Inputs.size() == 1);
90 
91   // Set up Clang.
92   clang::CompilerInstance Compiler(PCHContainerOps);
93   Compiler.setInvocation(std::move(Invocation));
94   Compiler.setFileManager(Files);
95 
96   // Create the compiler's actual diagnostics engine. We want to drop all
97   // diagnostics here.
98   Compiler.createDiagnostics(Files->getVirtualFileSystem(),
99                              new clang::IgnoringDiagConsumer,
100                              /*ShouldOwnClient=*/true);
101   Compiler.createSourceManager(*Files);
102 
103   // We abort on fatal errors so don't let a large number of errors become
104   // fatal. A missing #include can cause thousands of errors.
105   Compiler.getDiagnostics().setErrorLimit(0);
106 
107   // Run the parser, gather missing includes.
108   auto ScopedToolAction =
109       std::make_unique<Action>(SymbolIndexMgr, MinimizeIncludePaths);
110   Compiler.ExecuteAction(*ScopedToolAction);
111 
112   Contexts.push_back(ScopedToolAction->getIncludeFixerContext(
113       Compiler.getSourceManager(),
114       Compiler.getPreprocessor().getHeaderSearchInfo()));
115 
116   // Technically this should only return true if we're sure that we have a
117   // parseable file. We don't know that though. Only inform users of fatal
118   // errors.
119   return !Compiler.getDiagnostics().hasFatalErrorOccurred();
120 }
121 
122 static bool addDiagnosticsForContext(TypoCorrection &Correction,
123                                      const IncludeFixerContext &Context,
124                                      StringRef Code, SourceLocation StartOfFile,
125                                      ASTContext &Ctx) {
126   auto Reps = createIncludeFixerReplacements(
127       Code, Context, format::getLLVMStyle(), /*AddQualifiers=*/false);
128   if (!Reps || Reps->size() != 1)
129     return false;
130 
131   unsigned DiagID = Ctx.getDiagnostics().getCustomDiagID(
132       DiagnosticsEngine::Note, "Add '#include %0' to provide the missing "
133                                "declaration [clang-include-fixer]");
134 
135   // FIXME: Currently we only generate a diagnostic for the first header. Give
136   // the user choices.
137   const tooling::Replacement &Placed = *Reps->begin();
138 
139   auto Begin = StartOfFile.getLocWithOffset(Placed.getOffset());
140   auto End = Begin.getLocWithOffset(std::max(0, (int)Placed.getLength() - 1));
141   PartialDiagnostic PD(DiagID, Ctx.getDiagAllocator());
142   PD << Context.getHeaderInfos().front().Header
143      << FixItHint::CreateReplacement(CharSourceRange::getCharRange(Begin, End),
144                                      Placed.getReplacementText());
145   Correction.addExtraDiagnostic(std::move(PD));
146   return true;
147 }
148 
149 /// Callback for incomplete types. If we encounter a forward declaration we
150 /// have the fully qualified name ready. Just query that.
151 bool IncludeFixerSemaSource::MaybeDiagnoseMissingCompleteType(
152     clang::SourceLocation Loc, clang::QualType T) {
153   // Ignore spurious callbacks from SFINAE contexts.
154   if (CI->getSema().isSFINAEContext())
155     return false;
156 
157   clang::ASTContext &context = CI->getASTContext();
158   std::string QueryString = QualType(T->getUnqualifiedDesugaredType(), 0)
159                                 .getAsString(context.getPrintingPolicy());
160   LLVM_DEBUG(llvm::dbgs() << "Query missing complete type '" << QueryString
161                           << "'");
162   // Pass an empty range here since we don't add qualifier in this case.
163   std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
164       query(QueryString, "", tooling::Range());
165 
166   if (!MatchedSymbols.empty() && GenerateDiagnostics) {
167     TypoCorrection Correction;
168     FileID FID = CI->getSourceManager().getFileID(Loc);
169     StringRef Code = CI->getSourceManager().getBufferData(FID);
170     SourceLocation StartOfFile =
171         CI->getSourceManager().getLocForStartOfFile(FID);
172     addDiagnosticsForContext(
173         Correction,
174         getIncludeFixerContext(CI->getSourceManager(),
175                                CI->getPreprocessor().getHeaderSearchInfo(),
176                                MatchedSymbols),
177         Code, StartOfFile, CI->getASTContext());
178     for (const PartialDiagnostic &PD : Correction.getExtraDiagnostics())
179       CI->getSema().Diag(Loc, PD);
180   }
181   return true;
182 }
183 
184 /// Callback for unknown identifiers. Try to piece together as much
185 /// qualification as we can get and do a query.
186 clang::TypoCorrection IncludeFixerSemaSource::CorrectTypo(
187     const DeclarationNameInfo &Typo, int LookupKind, Scope *S, CXXScopeSpec *SS,
188     CorrectionCandidateCallback &CCC, DeclContext *MemberContext,
189     bool EnteringContext, const ObjCObjectPointerType *OPT) {
190   // Ignore spurious callbacks from SFINAE contexts.
191   if (CI->getSema().isSFINAEContext())
192     return clang::TypoCorrection();
193 
194   // We currently ignore the unidentified symbol which is not from the
195   // main file.
196   //
197   // However, this is not always true due to templates in a non-self contained
198   // header, consider the case:
199   //
200   //   // header.h
201   //   template <typename T>
202   //   class Foo {
203   //     T t;
204   //   };
205   //
206   //   // test.cc
207   //   // We need to add <bar.h> in test.cc instead of header.h.
208   //   class Bar;
209   //   Foo<Bar> foo;
210   //
211   // FIXME: Add the missing header to the header file where the symbol comes
212   // from.
213   if (!CI->getSourceManager().isWrittenInMainFile(Typo.getLoc()))
214     return clang::TypoCorrection();
215 
216   std::string TypoScopeString;
217   if (S) {
218     // FIXME: Currently we only use namespace contexts. Use other context
219     // types for query.
220     for (const auto *Context = S->getEntity(); Context;
221          Context = Context->getParent()) {
222       if (const auto *ND = dyn_cast<NamespaceDecl>(Context)) {
223         if (!ND->getName().empty())
224           TypoScopeString = ND->getNameAsString() + "::" + TypoScopeString;
225       }
226     }
227   }
228 
229   auto ExtendNestedNameSpecifier = [this](CharSourceRange Range) {
230     StringRef Source =
231         Lexer::getSourceText(Range, CI->getSourceManager(), CI->getLangOpts());
232 
233     // Skip forward until we find a character that's neither identifier nor
234     // colon. This is a bit of a hack around the fact that we will only get a
235     // single callback for a long nested name if a part of the beginning is
236     // unknown. For example:
237     //
238     // llvm::sys::path::parent_path(...)
239     // ^~~~  ^~~
240     //    known
241     //            ^~~~
242     //      unknown, last callback
243     //                  ^~~~~~~~~~~
244     //                  no callback
245     //
246     // With the extension we get the full nested name specifier including
247     // parent_path.
248     // FIXME: Don't rely on source text.
249     const char *End = Source.end();
250     while (isAsciiIdentifierContinue(*End) || *End == ':')
251       ++End;
252 
253     return std::string(Source.begin(), End);
254   };
255 
256   /// If we have a scope specification, use that to get more precise results.
257   std::string QueryString;
258   tooling::Range SymbolRange;
259   const auto &SM = CI->getSourceManager();
260   auto CreateToolingRange = [&QueryString, &SM](SourceLocation BeginLoc) {
261     return tooling::Range(SM.getDecomposedLoc(BeginLoc).second,
262                           QueryString.size());
263   };
264   if (SS && SS->getRange().isValid()) {
265     auto Range = CharSourceRange::getTokenRange(SS->getRange().getBegin(),
266                                                 Typo.getLoc());
267 
268     QueryString = ExtendNestedNameSpecifier(Range);
269     SymbolRange = CreateToolingRange(Range.getBegin());
270   } else if (Typo.getName().isIdentifier() && !Typo.getLoc().isMacroID()) {
271     auto Range =
272         CharSourceRange::getTokenRange(Typo.getBeginLoc(), Typo.getEndLoc());
273 
274     QueryString = ExtendNestedNameSpecifier(Range);
275     SymbolRange = CreateToolingRange(Range.getBegin());
276   } else {
277     QueryString = Typo.getAsString();
278     SymbolRange = CreateToolingRange(Typo.getLoc());
279   }
280 
281   LLVM_DEBUG(llvm::dbgs() << "TypoScopeQualifiers: " << TypoScopeString
282                           << "\n");
283   std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
284       query(QueryString, TypoScopeString, SymbolRange);
285 
286   if (!MatchedSymbols.empty() && GenerateDiagnostics) {
287     TypoCorrection Correction(Typo.getName());
288     Correction.setCorrectionRange(SS, Typo);
289     FileID FID = SM.getFileID(Typo.getLoc());
290     StringRef Code = SM.getBufferData(FID);
291     SourceLocation StartOfFile = SM.getLocForStartOfFile(FID);
292     if (addDiagnosticsForContext(
293             Correction, getIncludeFixerContext(
294                             SM, CI->getPreprocessor().getHeaderSearchInfo(),
295                             MatchedSymbols),
296             Code, StartOfFile, CI->getASTContext()))
297       return Correction;
298   }
299   return TypoCorrection();
300 }
301 
302 /// Get the minimal include for a given path.
303 std::string IncludeFixerSemaSource::minimizeInclude(
304     StringRef Include, const clang::SourceManager &SourceManager,
305     clang::HeaderSearch &HeaderSearch) const {
306   if (!MinimizeIncludePaths)
307     return std::string(Include);
308 
309   // Get the FileEntry for the include.
310   StringRef StrippedInclude = Include.trim("\"<>");
311   auto Entry =
312       SourceManager.getFileManager().getOptionalFileRef(StrippedInclude);
313 
314   // If the file doesn't exist return the path from the database.
315   // FIXME: This should never happen.
316   if (!Entry)
317     return std::string(Include);
318 
319   bool IsAngled = false;
320   std::string Suggestion =
321       HeaderSearch.suggestPathToFileForDiagnostics(*Entry, "", &IsAngled);
322 
323   return IsAngled ? '<' + Suggestion + '>' : '"' + Suggestion + '"';
324 }
325 
326 /// Get the include fixer context for the queried symbol.
327 IncludeFixerContext IncludeFixerSemaSource::getIncludeFixerContext(
328     const clang::SourceManager &SourceManager,
329     clang::HeaderSearch &HeaderSearch,
330     ArrayRef<find_all_symbols::SymbolInfo> MatchedSymbols) const {
331   std::vector<find_all_symbols::SymbolInfo> SymbolCandidates;
332   for (const auto &Symbol : MatchedSymbols) {
333     std::string FilePath = Symbol.getFilePath().str();
334     std::string MinimizedFilePath = minimizeInclude(
335         ((FilePath[0] == '"' || FilePath[0] == '<') ? FilePath
336                                                     : "\"" + FilePath + "\""),
337         SourceManager, HeaderSearch);
338     SymbolCandidates.emplace_back(Symbol.getName(), Symbol.getSymbolKind(),
339                                   MinimizedFilePath, Symbol.getContexts());
340   }
341   return IncludeFixerContext(FilePath, QuerySymbolInfos, SymbolCandidates);
342 }
343 
344 std::vector<find_all_symbols::SymbolInfo>
345 IncludeFixerSemaSource::query(StringRef Query, StringRef ScopedQualifiers,
346                               tooling::Range Range) {
347   assert(!Query.empty() && "Empty query!");
348 
349   // Save all instances of an unidentified symbol.
350   //
351   // We use conservative behavior for detecting the same unidentified symbol
352   // here. The symbols which have the same ScopedQualifier and RawIdentifier
353   // are considered equal. So that clang-include-fixer avoids false positives,
354   // and always adds missing qualifiers to correct symbols.
355   if (!GenerateDiagnostics && !QuerySymbolInfos.empty()) {
356     if (ScopedQualifiers == QuerySymbolInfos.front().ScopedQualifiers &&
357         Query == QuerySymbolInfos.front().RawIdentifier) {
358       QuerySymbolInfos.push_back(
359           {Query.str(), std::string(ScopedQualifiers), Range});
360     }
361     return {};
362   }
363 
364   LLVM_DEBUG(llvm::dbgs() << "Looking up '" << Query << "' at ");
365   LLVM_DEBUG(CI->getSourceManager()
366                  .getLocForStartOfFile(CI->getSourceManager().getMainFileID())
367                  .getLocWithOffset(Range.getOffset())
368                  .print(llvm::dbgs(), CI->getSourceManager()));
369   LLVM_DEBUG(llvm::dbgs() << " ...");
370   llvm::StringRef FileName = CI->getSourceManager().getFilename(
371       CI->getSourceManager().getLocForStartOfFile(
372           CI->getSourceManager().getMainFileID()));
373 
374   QuerySymbolInfos.push_back(
375       {Query.str(), std::string(ScopedQualifiers), Range});
376 
377   // Query the symbol based on C++ name Lookup rules.
378   // Firstly, lookup the identifier with scoped namespace contexts;
379   // If that fails, falls back to look up the identifier directly.
380   //
381   // For example:
382   //
383   // namespace a {
384   // b::foo f;
385   // }
386   //
387   // 1. lookup a::b::foo.
388   // 2. lookup b::foo.
389   std::string QueryString = ScopedQualifiers.str() + Query.str();
390   // It's unsafe to do nested search for the identifier with scoped namespace
391   // context, it might treat the identifier as a nested class of the scoped
392   // namespace.
393   std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
394       SymbolIndexMgr.search(QueryString, /*IsNestedSearch=*/false, FileName);
395   if (MatchedSymbols.empty())
396     MatchedSymbols =
397         SymbolIndexMgr.search(Query, /*IsNestedSearch=*/true, FileName);
398   LLVM_DEBUG(llvm::dbgs() << "Having found " << MatchedSymbols.size()
399                           << " symbols\n");
400   // We store a copy of MatchedSymbols in a place where it's globally reachable.
401   // This is used by the standalone version of the tool.
402   this->MatchedSymbols = MatchedSymbols;
403   return MatchedSymbols;
404 }
405 
406 llvm::Expected<tooling::Replacements> createIncludeFixerReplacements(
407     StringRef Code, const IncludeFixerContext &Context,
408     const clang::format::FormatStyle &Style, bool AddQualifiers) {
409   if (Context.getHeaderInfos().empty())
410     return tooling::Replacements();
411   StringRef FilePath = Context.getFilePath();
412   std::string IncludeName =
413       "#include " + Context.getHeaderInfos().front().Header + "\n";
414   // Create replacements for the new header.
415   clang::tooling::Replacements Insertions;
416   auto Err =
417       Insertions.add(tooling::Replacement(FilePath, UINT_MAX, 0, IncludeName));
418   if (Err)
419     return std::move(Err);
420 
421   auto CleanReplaces = cleanupAroundReplacements(Code, Insertions, Style);
422   if (!CleanReplaces)
423     return CleanReplaces;
424 
425   auto Replaces = std::move(*CleanReplaces);
426   if (AddQualifiers) {
427     for (const auto &Info : Context.getQuerySymbolInfos()) {
428       // Ignore the empty range.
429       if (Info.Range.getLength() > 0) {
430         auto R = tooling::Replacement(
431             {FilePath, Info.Range.getOffset(), Info.Range.getLength(),
432              Context.getHeaderInfos().front().QualifiedName});
433         auto Err = Replaces.add(R);
434         if (Err) {
435           llvm::consumeError(std::move(Err));
436           R = tooling::Replacement(
437               R.getFilePath(), Replaces.getShiftedCodePosition(R.getOffset()),
438               R.getLength(), R.getReplacementText());
439           Replaces = Replaces.merge(tooling::Replacements(R));
440         }
441       }
442     }
443   }
444   return formatReplacements(Code, Replaces, Style);
445 }
446 
447 } // namespace include_fixer
448 } // namespace clang
449