xref: /llvm-project/clang-tools-extra/clang-tidy/readability/UseStdMinMaxCheck.cpp (revision 32bcd41adcc664f6d690efc9b7cd209ac9c65f68)
1 //===--- UseStdMinMaxCheck.cpp - clang-tidy -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "UseStdMinMaxCheck.h"
10 #include "../utils/ASTUtils.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/ASTMatchers/ASTMatchFinder.h"
13 #include "clang/Lex/Preprocessor.h"
14 
15 using namespace clang::ast_matchers;
16 
17 namespace clang::tidy::readability {
18 
19 namespace {
20 
21 // Ignore if statements that are inside macros.
22 AST_MATCHER(IfStmt, isIfInMacro) {
23   return Node.getIfLoc().isMacroID() || Node.getEndLoc().isMacroID();
24 }
25 
26 } // namespace
27 
28 static const llvm::StringRef AlgorithmHeader("<algorithm>");
29 
30 static bool minCondition(const BinaryOperator::Opcode Op, const Expr *CondLhs,
31                          const Expr *CondRhs, const Expr *AssignLhs,
32                          const Expr *AssignRhs, const ASTContext &Context) {
33   if ((Op == BO_LT || Op == BO_LE) &&
34       (tidy::utils::areStatementsIdentical(CondLhs, AssignRhs, Context) &&
35        tidy::utils::areStatementsIdentical(CondRhs, AssignLhs, Context)))
36     return true;
37 
38   if ((Op == BO_GT || Op == BO_GE) &&
39       (tidy::utils::areStatementsIdentical(CondLhs, AssignLhs, Context) &&
40        tidy::utils::areStatementsIdentical(CondRhs, AssignRhs, Context)))
41     return true;
42 
43   return false;
44 }
45 
46 static bool maxCondition(const BinaryOperator::Opcode Op, const Expr *CondLhs,
47                          const Expr *CondRhs, const Expr *AssignLhs,
48                          const Expr *AssignRhs, const ASTContext &Context) {
49   if ((Op == BO_LT || Op == BO_LE) &&
50       (tidy::utils::areStatementsIdentical(CondLhs, AssignLhs, Context) &&
51        tidy::utils::areStatementsIdentical(CondRhs, AssignRhs, Context)))
52     return true;
53 
54   if ((Op == BO_GT || Op == BO_GE) &&
55       (tidy::utils::areStatementsIdentical(CondLhs, AssignRhs, Context) &&
56        tidy::utils::areStatementsIdentical(CondRhs, AssignLhs, Context)))
57     return true;
58 
59   return false;
60 }
61 
62 static QualType getNonTemplateAlias(QualType QT) {
63   while (true) {
64     // cast to a TypedefType
65     if (const TypedefType *TT = dyn_cast<TypedefType>(QT)) {
66       // check if the typedef is a template and if it is dependent
67       if (!TT->getDecl()->getDescribedTemplate() &&
68           !TT->getDecl()->getDeclContext()->isDependentContext())
69         return QT;
70       QT = TT->getDecl()->getUnderlyingType();
71     }
72     // cast to elaborated type
73     else if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(QT)) {
74       QT = ET->getNamedType();
75     } else {
76       break;
77     }
78   }
79   return QT;
80 }
81 
82 static QualType getReplacementCastType(const Expr *CondLhs, const Expr *CondRhs,
83                                        QualType ComparedType) {
84   QualType LhsType = CondLhs->getType();
85   QualType RhsType = CondRhs->getType();
86   QualType LhsCanonicalType =
87       LhsType.getCanonicalType().getNonReferenceType().getUnqualifiedType();
88   QualType RhsCanonicalType =
89       RhsType.getCanonicalType().getNonReferenceType().getUnqualifiedType();
90   QualType GlobalImplicitCastType;
91   if (LhsCanonicalType != RhsCanonicalType) {
92     if (llvm::isa<IntegerLiteral>(CondRhs)) {
93       GlobalImplicitCastType = getNonTemplateAlias(LhsType);
94     } else if (llvm::isa<IntegerLiteral>(CondLhs)) {
95       GlobalImplicitCastType = getNonTemplateAlias(RhsType);
96     } else {
97       GlobalImplicitCastType = getNonTemplateAlias(ComparedType);
98     }
99   }
100   return GlobalImplicitCastType;
101 }
102 
103 static std::string createReplacement(const Expr *CondLhs, const Expr *CondRhs,
104                                      const Expr *AssignLhs,
105                                      const SourceManager &Source,
106                                      const LangOptions &LO,
107                                      StringRef FunctionName,
108                                      const BinaryOperator *BO) {
109   const llvm::StringRef CondLhsStr = Lexer::getSourceText(
110       Source.getExpansionRange(CondLhs->getSourceRange()), Source, LO);
111   const llvm::StringRef CondRhsStr = Lexer::getSourceText(
112       Source.getExpansionRange(CondRhs->getSourceRange()), Source, LO);
113   const llvm::StringRef AssignLhsStr = Lexer::getSourceText(
114       Source.getExpansionRange(AssignLhs->getSourceRange()), Source, LO);
115 
116   QualType GlobalImplicitCastType =
117       getReplacementCastType(CondLhs, CondRhs, BO->getLHS()->getType());
118 
119   return (AssignLhsStr + " = " + FunctionName +
120           (!GlobalImplicitCastType.isNull()
121                ? "<" + GlobalImplicitCastType.getAsString() + ">("
122                : "(") +
123           CondLhsStr + ", " + CondRhsStr + ");")
124       .str();
125 }
126 
127 UseStdMinMaxCheck::UseStdMinMaxCheck(StringRef Name, ClangTidyContext *Context)
128     : ClangTidyCheck(Name, Context),
129       IncludeInserter(Options.getLocalOrGlobal("IncludeStyle",
130                                                utils::IncludeSorter::IS_LLVM),
131                       areDiagsSelfContained()) {}
132 
133 void UseStdMinMaxCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
134   Options.store(Opts, "IncludeStyle", IncludeInserter.getStyle());
135 }
136 
137 void UseStdMinMaxCheck::registerMatchers(MatchFinder *Finder) {
138   auto AssignOperator =
139       binaryOperator(hasOperatorName("="),
140                      hasLHS(expr(unless(isTypeDependent())).bind("AssignLhs")),
141                      hasRHS(expr(unless(isTypeDependent())).bind("AssignRhs")));
142   auto BinaryOperator =
143       binaryOperator(hasAnyOperatorName("<", ">", "<=", ">="),
144                      hasLHS(expr(unless(isTypeDependent())).bind("CondLhs")),
145                      hasRHS(expr(unless(isTypeDependent())).bind("CondRhs")))
146           .bind("binaryOp");
147   Finder->addMatcher(
148       ifStmt(stmt().bind("if"), unless(isIfInMacro()),
149              unless(hasElse(stmt())), // Ensure `if` has no `else`
150              hasCondition(BinaryOperator),
151              hasThen(
152                  anyOf(stmt(AssignOperator),
153                        compoundStmt(statementCountIs(1), has(AssignOperator)))),
154              hasParent(stmt(unless(ifStmt(hasElse(
155                  equalsBoundNode("if"))))))), // Ensure `if` has no `else if`
156       this);
157 }
158 
159 void UseStdMinMaxCheck::registerPPCallbacks(const SourceManager &SM,
160                                             Preprocessor *PP,
161                                             Preprocessor *ModuleExpanderPP) {
162   IncludeInserter.registerPreprocessor(PP);
163 }
164 
165 void UseStdMinMaxCheck::check(const MatchFinder::MatchResult &Result) {
166   const auto *If = Result.Nodes.getNodeAs<IfStmt>("if");
167   const clang::LangOptions &LO = Result.Context->getLangOpts();
168   const auto *CondLhs = Result.Nodes.getNodeAs<Expr>("CondLhs");
169   const auto *CondRhs = Result.Nodes.getNodeAs<Expr>("CondRhs");
170   const auto *AssignLhs = Result.Nodes.getNodeAs<Expr>("AssignLhs");
171   const auto *AssignRhs = Result.Nodes.getNodeAs<Expr>("AssignRhs");
172   const auto *BinaryOp = Result.Nodes.getNodeAs<BinaryOperator>("binaryOp");
173   const clang::BinaryOperatorKind BinaryOpcode = BinaryOp->getOpcode();
174   const SourceLocation IfLocation = If->getIfLoc();
175   const SourceLocation ThenLocation = If->getEndLoc();
176 
177   auto ReplaceAndDiagnose = [&](const llvm::StringRef FunctionName) {
178     const SourceManager &Source = *Result.SourceManager;
179     diag(IfLocation, "use `%0` instead of `%1`")
180         << FunctionName << BinaryOp->getOpcodeStr()
181         << FixItHint::CreateReplacement(
182                SourceRange(IfLocation, Lexer::getLocForEndOfToken(
183                                            ThenLocation, 0, Source, LO)),
184                createReplacement(CondLhs, CondRhs, AssignLhs, Source, LO,
185                                  FunctionName, BinaryOp))
186         << IncludeInserter.createIncludeInsertion(
187                Source.getFileID(If->getBeginLoc()), AlgorithmHeader);
188   };
189 
190   if (minCondition(BinaryOpcode, CondLhs, CondRhs, AssignLhs, AssignRhs,
191                    (*Result.Context))) {
192     ReplaceAndDiagnose("std::min");
193   } else if (maxCondition(BinaryOpcode, CondLhs, CondRhs, AssignLhs, AssignRhs,
194                           (*Result.Context))) {
195     ReplaceAndDiagnose("std::max");
196   }
197 }
198 
199 } // namespace clang::tidy::readability
200