xref: /llvm-project/clang-tools-extra/clang-tidy/utils/UseRangesCheck.cpp (revision 0762db6533eda3453158c7b9b0631542c47093a8)
1 //===--- UseRangesCheck.cpp - clang-tidy ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "UseRangesCheck.h"
10 #include "Matchers.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/Decl.h"
13 #include "clang/AST/Expr.h"
14 #include "clang/ASTMatchers/ASTMatchFinder.h"
15 #include "clang/ASTMatchers/ASTMatchers.h"
16 #include "clang/ASTMatchers/ASTMatchersInternal.h"
17 #include "clang/Basic/Diagnostic.h"
18 #include "clang/Basic/LLVM.h"
19 #include "clang/Basic/SourceLocation.h"
20 #include "clang/Basic/SourceManager.h"
21 #include "clang/Lex/Lexer.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/ADT/Twine.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <cassert>
31 #include <optional>
32 #include <string>
33 
34 using namespace clang::ast_matchers;
35 
36 static constexpr const char BoundCall[] = "CallExpr";
37 static constexpr const char FuncDecl[] = "FuncDecl";
38 static constexpr const char ArgName[] = "ArgName";
39 
40 namespace clang::tidy::utils {
41 
42 static std::string getFullPrefix(ArrayRef<UseRangesCheck::Indexes> Signature) {
43   std::string Output;
44   llvm::raw_string_ostream OS(Output);
45   for (const UseRangesCheck::Indexes &Item : Signature)
46     OS << Item.BeginArg << ":" << Item.EndArg << ":"
47        << (Item.ReplaceArg == Item.First ? '0' : '1');
48   return Output;
49 }
50 
51 namespace {
52 
53 AST_MATCHER(Expr, hasSideEffects) {
54   return Node.HasSideEffects(Finder->getASTContext());
55 }
56 } // namespace
57 
58 static auto
59 makeExprMatcher(ast_matchers::internal::Matcher<Expr> ArgumentMatcher,
60                 ArrayRef<StringRef> MethodNames,
61                 ArrayRef<StringRef> FreeNames) {
62   return expr(
63       anyOf(cxxMemberCallExpr(argumentCountIs(0),
64                               callee(cxxMethodDecl(hasAnyName(MethodNames))),
65                               on(ArgumentMatcher)),
66             callExpr(argumentCountIs(1), hasArgument(0, ArgumentMatcher),
67                      hasDeclaration(functionDecl(hasAnyName(FreeNames))))));
68 }
69 
70 static ast_matchers::internal::Matcher<CallExpr>
71 makeMatcherPair(StringRef State, const UseRangesCheck::Indexes &Indexes,
72                 ArrayRef<StringRef> BeginFreeNames,
73                 ArrayRef<StringRef> EndFreeNames,
74                 const std::optional<UseRangesCheck::ReverseIteratorDescriptor>
75                     &ReverseDescriptor) {
76   std::string ArgBound = (ArgName + llvm::Twine(Indexes.BeginArg)).str();
77   SmallString<64> ID = {BoundCall, State};
78   ast_matchers::internal::Matcher<CallExpr> ArgumentMatcher = allOf(
79       hasArgument(Indexes.BeginArg,
80                   makeExprMatcher(expr(unless(hasSideEffects())).bind(ArgBound),
81                                   {"begin", "cbegin"}, BeginFreeNames)),
82       hasArgument(Indexes.EndArg,
83                   makeExprMatcher(
84                       expr(matchers::isStatementIdenticalToBoundNode(ArgBound)),
85                       {"end", "cend"}, EndFreeNames)));
86   if (ReverseDescriptor) {
87     ArgBound.push_back('R');
88     SmallVector<StringRef> RBegin{
89         llvm::make_first_range(ReverseDescriptor->FreeReverseNames)};
90     SmallVector<StringRef> REnd{
91         llvm::make_second_range(ReverseDescriptor->FreeReverseNames)};
92     ArgumentMatcher = anyOf(
93         ArgumentMatcher,
94         allOf(hasArgument(
95                   Indexes.BeginArg,
96                   makeExprMatcher(expr(unless(hasSideEffects())).bind(ArgBound),
97                                   {"rbegin", "crbegin"}, RBegin)),
98               hasArgument(
99                   Indexes.EndArg,
100                   makeExprMatcher(
101                       expr(matchers::isStatementIdenticalToBoundNode(ArgBound)),
102                       {"rend", "crend"}, REnd))));
103   }
104   return callExpr(argumentCountAtLeast(
105                       std::max(Indexes.BeginArg, Indexes.EndArg) + 1),
106                   ArgumentMatcher)
107       .bind(ID);
108 }
109 
110 void UseRangesCheck::registerMatchers(MatchFinder *Finder) {
111   auto Replaces = getReplacerMap();
112   ReverseDescriptor = getReverseDescriptor();
113   auto BeginEndNames = getFreeBeginEndMethods();
114   llvm::SmallVector<StringRef, 4> BeginNames{
115       llvm::make_first_range(BeginEndNames)};
116   llvm::SmallVector<StringRef, 4> EndNames{
117       llvm::make_second_range(BeginEndNames)};
118   Replacers.clear();
119   llvm::DenseSet<Replacer *> SeenRepl;
120   for (auto I = Replaces.begin(), E = Replaces.end(); I != E; ++I) {
121     auto Replacer = I->getValue();
122     if (!SeenRepl.insert(Replacer.get()).second)
123       continue;
124     Replacers.push_back(Replacer);
125     assert(!Replacer->getReplacementSignatures().empty() &&
126            llvm::all_of(Replacer->getReplacementSignatures(),
127                         [](auto Index) { return !Index.empty(); }));
128     std::vector<StringRef> Names(1, I->getKey());
129     for (auto J = std::next(I); J != E; ++J)
130       if (J->getValue() == Replacer)
131         Names.push_back(J->getKey());
132 
133     std::vector<ast_matchers::internal::DynTypedMatcher> TotalMatchers;
134     // As we match on the first matched signature, we need to sort the
135     // signatures in order of length(longest to shortest). This way any
136     // signature that is a subset of another signature will be matched after the
137     // other.
138     SmallVector<Signature> SigVec(Replacer->getReplacementSignatures());
139     llvm::sort(SigVec, [](auto &L, auto &R) { return R.size() < L.size(); });
140     for (const auto &Signature : SigVec) {
141       std::vector<ast_matchers::internal::DynTypedMatcher> Matchers;
142       for (const auto &ArgPair : Signature)
143         Matchers.push_back(makeMatcherPair(getFullPrefix(Signature), ArgPair,
144                                            BeginNames, EndNames,
145                                            ReverseDescriptor));
146       TotalMatchers.push_back(
147           ast_matchers::internal::DynTypedMatcher::constructVariadic(
148               ast_matchers::internal::DynTypedMatcher::VO_AllOf,
149               ASTNodeKind::getFromNodeKind<CallExpr>(), std::move(Matchers)));
150     }
151     Finder->addMatcher(
152         callExpr(
153             callee(functionDecl(hasAnyName(std::move(Names)))
154                        .bind((FuncDecl + Twine(Replacers.size() - 1).str()))),
155             ast_matchers::internal::DynTypedMatcher::constructVariadic(
156                 ast_matchers::internal::DynTypedMatcher::VO_AnyOf,
157                 ASTNodeKind::getFromNodeKind<CallExpr>(),
158                 std::move(TotalMatchers))
159                 .convertTo<CallExpr>()),
160         this);
161   }
162 }
163 
164 static void removeFunctionArgs(DiagnosticBuilder &Diag, const CallExpr &Call,
165                                ArrayRef<unsigned> Indexes,
166                                const ASTContext &Ctx) {
167   llvm::SmallVector<unsigned> Sorted(Indexes);
168   llvm::sort(Sorted);
169   // Keep track of commas removed
170   llvm::SmallBitVector Commas(Call.getNumArgs());
171   // The first comma is actually the '(' which we can't remove
172   Commas[0] = true;
173   for (unsigned Index : Sorted) {
174     const Expr *Arg = Call.getArg(Index);
175     if (Commas[Index]) {
176       if (Index >= Commas.size()) {
177         Diag << FixItHint::CreateRemoval(Arg->getSourceRange());
178       } else {
179         // Remove the next comma
180         Commas[Index + 1] = true;
181         Diag << FixItHint::CreateRemoval(CharSourceRange::getTokenRange(
182             {Arg->getBeginLoc(),
183              Lexer::getLocForEndOfToken(
184                  Arg->getEndLoc(), 0, Ctx.getSourceManager(), Ctx.getLangOpts())
185                  .getLocWithOffset(1)}));
186       }
187     } else {
188       Diag << FixItHint::CreateRemoval(CharSourceRange::getTokenRange(
189           Arg->getBeginLoc().getLocWithOffset(-1), Arg->getEndLoc()));
190       Commas[Index] = true;
191     }
192   }
193 }
194 
195 void UseRangesCheck::check(const MatchFinder::MatchResult &Result) {
196   Replacer *Replacer = nullptr;
197   const FunctionDecl *Function = nullptr;
198   for (auto [Node, Value] : Result.Nodes.getMap()) {
199     StringRef NodeStr(Node);
200     if (!NodeStr.consume_front(FuncDecl))
201       continue;
202     Function = Value.get<FunctionDecl>();
203     size_t Index;
204     if (NodeStr.getAsInteger(10, Index)) {
205       llvm_unreachable("Unable to extract replacer index");
206     }
207     assert(Index < Replacers.size());
208     Replacer = Replacers[Index].get();
209     break;
210   }
211   assert(Replacer && Function);
212   SmallString<64> Buffer;
213   for (const Signature &Sig : Replacer->getReplacementSignatures()) {
214     Buffer.assign({BoundCall, getFullPrefix(Sig)});
215     const auto *Call = Result.Nodes.getNodeAs<CallExpr>(Buffer);
216     if (!Call)
217       continue;
218     auto Diag = createDiag(*Call);
219     if (auto ReplaceName = Replacer->getReplaceName(*Function))
220       Diag << FixItHint::CreateReplacement(Call->getCallee()->getSourceRange(),
221                                            *ReplaceName);
222     if (auto Include = Replacer->getHeaderInclusion(*Function))
223       Diag << Inserter.createIncludeInsertion(
224           Result.SourceManager->getFileID(Call->getBeginLoc()), *Include);
225     llvm::SmallVector<unsigned, 3> ToRemove;
226     for (const auto &[First, Second, Replace] : Sig) {
227       auto ArgNode = ArgName + std::to_string(First);
228       if (const auto *ArgExpr = Result.Nodes.getNodeAs<Expr>(ArgNode)) {
229         Diag << FixItHint::CreateReplacement(
230             Call->getArg(Replace == Indexes::Second ? Second : First)
231                 ->getSourceRange(),
232             Lexer::getSourceText(
233                 CharSourceRange::getTokenRange(ArgExpr->getSourceRange()),
234                 Result.Context->getSourceManager(),
235                 Result.Context->getLangOpts()));
236       } else {
237         assert(ReverseDescriptor && "Couldn't find forward argument");
238         ArgNode.push_back('R');
239         ArgExpr = Result.Nodes.getNodeAs<Expr>(ArgNode);
240         assert(ArgExpr && "Couldn't find forward or reverse argument");
241         if (ReverseDescriptor->ReverseHeader)
242           Diag << Inserter.createIncludeInsertion(
243               Result.SourceManager->getFileID(Call->getBeginLoc()),
244               *ReverseDescriptor->ReverseHeader);
245         StringRef ArgText = Lexer::getSourceText(
246             CharSourceRange::getTokenRange(ArgExpr->getSourceRange()),
247             Result.Context->getSourceManager(), Result.Context->getLangOpts());
248         SmallString<128> ReplaceText;
249         if (ReverseDescriptor->IsPipeSyntax)
250           ReplaceText.assign(
251               {ArgText, " | ", ReverseDescriptor->ReverseAdaptorName});
252         else
253           ReplaceText.assign(
254               {ReverseDescriptor->ReverseAdaptorName, "(", ArgText, ")"});
255         Diag << FixItHint::CreateReplacement(
256             Call->getArg(Replace == Indexes::Second ? Second : First)
257                 ->getSourceRange(),
258             ReplaceText);
259       }
260       ToRemove.push_back(Replace == Indexes::Second ? First : Second);
261     }
262     removeFunctionArgs(Diag, *Call, ToRemove, *Result.Context);
263     return;
264   }
265   llvm_unreachable("No valid signature found");
266 }
267 
268 bool UseRangesCheck::isLanguageVersionSupported(
269     const LangOptions &LangOpts) const {
270   return LangOpts.CPlusPlus11;
271 }
272 
273 UseRangesCheck::UseRangesCheck(StringRef Name, ClangTidyContext *Context)
274     : ClangTidyCheck(Name, Context),
275       Inserter(Options.getLocalOrGlobal("IncludeStyle",
276                                         utils::IncludeSorter::IS_LLVM),
277                areDiagsSelfContained()) {}
278 
279 void UseRangesCheck::registerPPCallbacks(const SourceManager &,
280                                          Preprocessor *PP, Preprocessor *) {
281   Inserter.registerPreprocessor(PP);
282 }
283 
284 void UseRangesCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
285   Options.store(Opts, "IncludeStyle", Inserter.getStyle());
286 }
287 
288 std::optional<std::string>
289 UseRangesCheck::Replacer::getHeaderInclusion(const NamedDecl &) const {
290   return std::nullopt;
291 }
292 
293 DiagnosticBuilder UseRangesCheck::createDiag(const CallExpr &Call) {
294   return diag(Call.getBeginLoc(), "use a ranges version of this algorithm");
295 }
296 
297 std::optional<UseRangesCheck::ReverseIteratorDescriptor>
298 UseRangesCheck::getReverseDescriptor() const {
299   return std::nullopt;
300 }
301 
302 ArrayRef<std::pair<StringRef, StringRef>>
303 UseRangesCheck::getFreeBeginEndMethods() const {
304   return {};
305 }
306 
307 std::optional<TraversalKind> UseRangesCheck::getCheckTraversalKind() const {
308   return TK_IgnoreUnlessSpelledInSource;
309 }
310 } // namespace clang::tidy::utils
311