xref: /llvm-project/clang-tools-extra/clang-tidy/modernize/MinMaxUseInitializerListCheck.cpp (revision 605a9adb4340b347f480a95a6eef3c9045e8416f)
1 //===--- MinMaxUseInitializerListCheck.cpp - clang-tidy -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MinMaxUseInitializerListCheck.h"
10 #include "../utils/ASTUtils.h"
11 #include "../utils/LexerUtils.h"
12 #include "clang/ASTMatchers/ASTMatchFinder.h"
13 #include "clang/Frontend/CompilerInstance.h"
14 #include "clang/Lex/Lexer.h"
15 
16 using namespace clang;
17 
18 namespace {
19 
20 struct FindArgsResult {
21   const Expr *First;
22   const Expr *Last;
23   const Expr *Compare;
24   SmallVector<const clang::Expr *, 2> Args;
25 };
26 
27 } // anonymous namespace
28 
29 using namespace clang::ast_matchers;
30 
31 namespace clang::tidy::modernize {
32 
33 static FindArgsResult findArgs(const CallExpr *Call) {
34   FindArgsResult Result;
35   Result.First = nullptr;
36   Result.Last = nullptr;
37   Result.Compare = nullptr;
38 
39   //   check if the function has initializer list argument
40   if (Call->getNumArgs() < 3) {
41     auto ArgIterator = Call->arguments().begin();
42 
43     const auto *InitListExpr =
44         dyn_cast<CXXStdInitializerListExpr>(*ArgIterator);
45     const auto *InitList =
46         InitListExpr != nullptr
47             ? dyn_cast<clang::InitListExpr>(
48                   InitListExpr->getSubExpr()->IgnoreImplicit())
49             : nullptr;
50 
51     if (InitList) {
52       Result.Args.append(InitList->inits().begin(), InitList->inits().end());
53       Result.First = *ArgIterator;
54       Result.Last = *ArgIterator;
55 
56       // check if there is a comparison argument
57       std::advance(ArgIterator, 1);
58       if (ArgIterator != Call->arguments().end())
59         Result.Compare = *ArgIterator;
60 
61       return Result;
62     }
63     Result.Args = SmallVector<const Expr *>(Call->arguments());
64   } else {
65     // if it has 3 arguments then the last will be the comparison
66     Result.Compare = *(std::next(Call->arguments().begin(), 2));
67     Result.Args = SmallVector<const Expr *>(llvm::drop_end(Call->arguments()));
68   }
69   Result.First = Result.Args.front();
70   Result.Last = Result.Args.back();
71 
72   return Result;
73 }
74 
75 // Returns `true` as `first` only if a nested call to `std::min` or
76 // `std::max` was found. Checking if `FixItHint`s were generated is not enough,
77 // as the explicit casts that the check introduces may be generated without a
78 // nested `std::min` or `std::max` call.
79 static std::pair<bool, SmallVector<FixItHint>>
80 generateReplacements(const MatchFinder::MatchResult &Match,
81                      const CallExpr *TopCall, const FindArgsResult &Result,
82                      const bool IgnoreNonTrivialTypes,
83                      const std::uint64_t IgnoreTrivialTypesOfSizeAbove) {
84   SmallVector<FixItHint> FixItHints;
85   const SourceManager &SourceMngr = *Match.SourceManager;
86   const LangOptions &LanguageOpts = Match.Context->getLangOpts();
87 
88   const QualType ResultType = TopCall->getDirectCallee()
89                                   ->getReturnType()
90                                   .getCanonicalType()
91                                   .getNonReferenceType()
92                                   .getUnqualifiedType();
93 
94   // check if the type is trivial
95   const bool IsResultTypeTrivial = ResultType.isTrivialType(*Match.Context);
96 
97   if ((!IsResultTypeTrivial && IgnoreNonTrivialTypes))
98     return {false, FixItHints};
99 
100   if (IsResultTypeTrivial &&
101       static_cast<std::uint64_t>(
102           Match.Context->getTypeSizeInChars(ResultType).getQuantity()) >
103           IgnoreTrivialTypesOfSizeAbove)
104     return {false, FixItHints};
105 
106   bool FoundNestedCall = false;
107 
108   for (const Expr *Arg : Result.Args) {
109     const auto *InnerCall = dyn_cast<CallExpr>(Arg->IgnoreParenImpCasts());
110 
111     // If the argument is not a nested call
112     if (!InnerCall) {
113       // check if typecast is required
114       const QualType ArgType = Arg->IgnoreParenImpCasts()
115                                    ->getType()
116                                    .getCanonicalType()
117                                    .getUnqualifiedType();
118 
119       if (ArgType == ResultType)
120         continue;
121 
122       const StringRef ArgText = Lexer::getSourceText(
123           CharSourceRange::getTokenRange(Arg->getSourceRange()), SourceMngr,
124           LanguageOpts);
125 
126       const auto Replacement = Twine("static_cast<")
127                                    .concat(ResultType.getAsString(LanguageOpts))
128                                    .concat(">(")
129                                    .concat(ArgText)
130                                    .concat(")")
131                                    .str();
132 
133       FixItHints.push_back(
134           FixItHint::CreateReplacement(Arg->getSourceRange(), Replacement));
135       continue;
136     }
137 
138     // if the nested call is not the same as the top call
139     if (InnerCall->getDirectCallee()->getQualifiedNameAsString() !=
140         TopCall->getDirectCallee()->getQualifiedNameAsString())
141       continue;
142 
143     const FindArgsResult InnerResult = findArgs(InnerCall);
144 
145     // if the nested call doesn't have arguments skip it
146     if (!InnerResult.First || !InnerResult.Last)
147       continue;
148 
149     // if the nested call doesn't have the same compare function
150     if ((Result.Compare || InnerResult.Compare) &&
151         !utils::areStatementsIdentical(Result.Compare, InnerResult.Compare,
152                                        *Match.Context))
153       continue;
154 
155     // We have found a nested call
156     FoundNestedCall = true;
157 
158     // remove the function call
159     FixItHints.push_back(
160         FixItHint::CreateRemoval(InnerCall->getCallee()->getSourceRange()));
161 
162     // remove the parentheses
163     const auto LParen = utils::lexer::findNextTokenSkippingComments(
164         InnerCall->getCallee()->getEndLoc(), SourceMngr, LanguageOpts);
165     if (LParen.has_value() && LParen->is(tok::l_paren))
166       FixItHints.push_back(
167           FixItHint::CreateRemoval(SourceRange(LParen->getLocation())));
168     FixItHints.push_back(
169         FixItHint::CreateRemoval(SourceRange(InnerCall->getRParenLoc())));
170 
171     // if the inner call has an initializer list arg
172     if (InnerResult.First == InnerResult.Last) {
173       // remove the initializer list braces
174       FixItHints.push_back(FixItHint::CreateRemoval(
175           CharSourceRange::getTokenRange(InnerResult.First->getBeginLoc())));
176       FixItHints.push_back(FixItHint::CreateRemoval(
177           CharSourceRange::getTokenRange(InnerResult.First->getEndLoc())));
178     }
179 
180     const auto [_, InnerReplacements] = generateReplacements(
181         Match, InnerCall, InnerResult, IgnoreNonTrivialTypes,
182         IgnoreTrivialTypesOfSizeAbove);
183 
184     FixItHints.append(InnerReplacements);
185 
186     if (InnerResult.Compare) {
187       // find the comma after the value arguments
188       const auto Comma = utils::lexer::findNextTokenSkippingComments(
189           InnerResult.Last->getEndLoc(), SourceMngr, LanguageOpts);
190 
191       // remove the comma and the comparison
192       if (Comma.has_value() && Comma->is(tok::comma))
193         FixItHints.push_back(
194             FixItHint::CreateRemoval(SourceRange(Comma->getLocation())));
195 
196       FixItHints.push_back(
197           FixItHint::CreateRemoval(InnerResult.Compare->getSourceRange()));
198     }
199   }
200 
201   return {FoundNestedCall, FixItHints};
202 }
203 
// Reads the check's options:
//  * `IgnoreNonTrivialTypes` (default: true): calls whose result type is not
//    trivial are neither diagnosed nor rewritten.
//  * `IgnoreTrivialTypesOfSizeAbove` (default: 32): calls whose trivial
//    result type is larger than this many chars (bytes, as measured by
//    `getTypeSizeInChars`) are neither diagnosed nor rewritten.
//  * `IncludeStyle` (local or global option, default: LLVM): the include
//    style the inserter uses when adding `#include <algorithm>`.
MinMaxUseInitializerListCheck::MinMaxUseInitializerListCheck(
    StringRef Name, ClangTidyContext *Context)
    : ClangTidyCheck(Name, Context),
      IgnoreNonTrivialTypes(Options.get("IgnoreNonTrivialTypes", true)),
      IgnoreTrivialTypesOfSizeAbove(
          Options.get("IgnoreTrivialTypesOfSizeAbove", 32L)),
      Inserter(Options.getLocalOrGlobal("IncludeStyle",
                                        utils::IncludeSorter::IS_LLVM),
               areDiagsSelfContained()) {}
213 
// Writes the current option values into `Opts`, under the same keys the
// constructor reads them from.
void MinMaxUseInitializerListCheck::storeOptions(
    ClangTidyOptions::OptionMap &Opts) {
  Options.store(Opts, "IgnoreNonTrivialTypes", IgnoreNonTrivialTypes);
  Options.store(Opts, "IgnoreTrivialTypesOfSizeAbove",
                IgnoreTrivialTypesOfSizeAbove);
  Options.store(Opts, "IncludeStyle", Inserter.getStyle());
}
221 
222 void MinMaxUseInitializerListCheck::registerMatchers(MatchFinder *Finder) {
223   auto CreateMatcher = [](const StringRef FunctionName) {
224     auto FuncDecl = functionDecl(hasName(FunctionName));
225     auto Expression = callExpr(callee(FuncDecl));
226 
227     return callExpr(callee(FuncDecl),
228                     anyOf(hasArgument(0, Expression),
229                           hasArgument(1, Expression),
230                           hasArgument(0, cxxStdInitializerListExpr())),
231                     unless(hasParent(Expression)))
232         .bind("topCall");
233   };
234 
235   Finder->addMatcher(CreateMatcher("::std::max"), this);
236   Finder->addMatcher(CreateMatcher("::std::min"), this);
237 }
238 
// Registers the preprocessor with the include inserter; `check` relies on the
// inserter to attach an `#include <algorithm>` insertion to its diagnostic.
// `SM` and `ModuleExpanderPP` are not used.
void MinMaxUseInitializerListCheck::registerPPCallbacks(
    const SourceManager &SM, Preprocessor *PP, Preprocessor *ModuleExpanderPP) {
  Inserter.registerPreprocessor(PP);
}
243 
244 void MinMaxUseInitializerListCheck::check(
245     const MatchFinder::MatchResult &Match) {
246 
247   const auto *TopCall = Match.Nodes.getNodeAs<CallExpr>("topCall");
248 
249   const FindArgsResult Result = findArgs(TopCall);
250   const auto [FoundNestedCall, Replacements] =
251       generateReplacements(Match, TopCall, Result, IgnoreNonTrivialTypes,
252                            IgnoreTrivialTypesOfSizeAbove);
253 
254   if (!FoundNestedCall)
255     return;
256 
257   const DiagnosticBuilder Diagnostic =
258       diag(TopCall->getBeginLoc(),
259            "do not use nested 'std::%0' calls, use an initializer list instead")
260       << TopCall->getDirectCallee()->getName()
261       << Inserter.createIncludeInsertion(
262              Match.SourceManager->getFileID(TopCall->getBeginLoc()),
263              "<algorithm>");
264 
265   // if the top call doesn't have an initializer list argument
266   if (Result.First != Result.Last) {
267     // add { and } insertions
268     Diagnostic << FixItHint::CreateInsertion(Result.First->getBeginLoc(), "{");
269 
270     Diagnostic << FixItHint::CreateInsertion(
271         Lexer::getLocForEndOfToken(Result.Last->getEndLoc(), 0,
272                                    *Match.SourceManager,
273                                    Match.Context->getLangOpts()),
274         "}");
275   }
276 
277   Diagnostic << Replacements;
278 }
279 
280 } // namespace clang::tidy::modernize
281