xref: /llvm-project/clang-tools-extra/clang-tidy/readability/ContainerContainsCheck.cpp (revision 3605d9a456185f4af78c01a2684b822b57bca9b0)
1 //===--- ContainerContainsCheck.cpp - clang-tidy --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ContainerContainsCheck.h"
10 #include "clang/AST/ASTContext.h"
11 #include "clang/ASTMatchers/ASTMatchFinder.h"
12 
13 using namespace clang::ast_matchers;
14 
15 namespace clang::tidy::readability {
16 void ContainerContainsCheck::registerMatchers(MatchFinder *Finder) {
17   const auto HasContainsMatchingParamType = hasMethod(
18       cxxMethodDecl(isConst(), parameterCountIs(1), returns(booleanType()),
19                     hasName("contains"), unless(isDeleted()), isPublic(),
20                     hasParameter(0, hasType(hasUnqualifiedDesugaredType(
21                                         equalsBoundNode("parameterType"))))));
22 
23   const auto CountCall =
24       cxxMemberCallExpr(
25           argumentCountIs(1),
26           callee(cxxMethodDecl(
27               hasName("count"),
28               hasParameter(0, hasType(hasUnqualifiedDesugaredType(
29                                   type().bind("parameterType")))),
30               ofClass(cxxRecordDecl(HasContainsMatchingParamType)))))
31           .bind("call");
32 
33   const auto FindCall =
34       cxxMemberCallExpr(
35           argumentCountIs(1),
36           callee(cxxMethodDecl(
37               hasName("find"),
38               hasParameter(0, hasType(hasUnqualifiedDesugaredType(
39                                   type().bind("parameterType")))),
40               ofClass(cxxRecordDecl(HasContainsMatchingParamType)))))
41           .bind("call");
42 
43   const auto EndCall = cxxMemberCallExpr(
44       argumentCountIs(0),
45       callee(
46           cxxMethodDecl(hasName("end"),
47                         // In the matchers below, FindCall should always appear
48                         // before EndCall so 'parameterType' is properly bound.
49                         ofClass(cxxRecordDecl(HasContainsMatchingParamType)))));
50 
51   const auto Literal0 = integerLiteral(equals(0));
52   const auto Literal1 = integerLiteral(equals(1));
53 
54   auto AddSimpleMatcher = [&](auto Matcher) {
55     Finder->addMatcher(
56         traverse(TK_IgnoreUnlessSpelledInSource, std::move(Matcher)), this);
57   };
58 
59   // Find membership tests which use `count()`.
60   Finder->addMatcher(implicitCastExpr(hasImplicitDestinationType(booleanType()),
61                                       hasSourceExpression(CountCall))
62                          .bind("positiveComparison"),
63                      this);
64   AddSimpleMatcher(
65       binaryOperation(hasOperatorName("!="), hasOperands(CountCall, Literal0))
66           .bind("positiveComparison"));
67   AddSimpleMatcher(
68       binaryOperation(hasLHS(CountCall), hasOperatorName(">"), hasRHS(Literal0))
69           .bind("positiveComparison"));
70   AddSimpleMatcher(
71       binaryOperation(hasLHS(Literal0), hasOperatorName("<"), hasRHS(CountCall))
72           .bind("positiveComparison"));
73   AddSimpleMatcher(binaryOperation(hasLHS(CountCall), hasOperatorName(">="),
74                                    hasRHS(Literal1))
75                        .bind("positiveComparison"));
76   AddSimpleMatcher(binaryOperation(hasLHS(Literal1), hasOperatorName("<="),
77                                    hasRHS(CountCall))
78                        .bind("positiveComparison"));
79 
80   // Find inverted membership tests which use `count()`.
81   AddSimpleMatcher(
82       binaryOperation(hasOperatorName("=="), hasOperands(CountCall, Literal0))
83           .bind("negativeComparison"));
84   AddSimpleMatcher(binaryOperation(hasLHS(CountCall), hasOperatorName("<="),
85                                    hasRHS(Literal0))
86                        .bind("negativeComparison"));
87   AddSimpleMatcher(binaryOperation(hasLHS(Literal0), hasOperatorName(">="),
88                                    hasRHS(CountCall))
89                        .bind("negativeComparison"));
90   AddSimpleMatcher(
91       binaryOperation(hasLHS(CountCall), hasOperatorName("<"), hasRHS(Literal1))
92           .bind("negativeComparison"));
93   AddSimpleMatcher(
94       binaryOperation(hasLHS(Literal1), hasOperatorName(">"), hasRHS(CountCall))
95           .bind("negativeComparison"));
96 
97   // Find membership tests based on `find() == end()`.
98   AddSimpleMatcher(
99       binaryOperation(hasOperatorName("!="), hasOperands(FindCall, EndCall))
100           .bind("positiveComparison"));
101   AddSimpleMatcher(
102       binaryOperation(hasOperatorName("=="), hasOperands(FindCall, EndCall))
103           .bind("negativeComparison"));
104 }
105 
106 void ContainerContainsCheck::check(const MatchFinder::MatchResult &Result) {
107   // Extract the information about the match
108   const auto *Call = Result.Nodes.getNodeAs<CXXMemberCallExpr>("call");
109   const auto *PositiveComparison =
110       Result.Nodes.getNodeAs<Expr>("positiveComparison");
111   const auto *NegativeComparison =
112       Result.Nodes.getNodeAs<Expr>("negativeComparison");
113   assert((!PositiveComparison || !NegativeComparison) &&
114          "only one of PositiveComparison or NegativeComparison should be set");
115   bool Negated = NegativeComparison != nullptr;
116   const auto *Comparison = Negated ? NegativeComparison : PositiveComparison;
117 
118   // Diagnose the issue.
119   auto Diag =
120       diag(Call->getExprLoc(), "use 'contains' to check for membership");
121 
122   // Don't fix it if it's in a macro invocation. Leave fixing it to the user.
123   SourceLocation FuncCallLoc = Comparison->getEndLoc();
124   if (!FuncCallLoc.isValid() || FuncCallLoc.isMacroID())
125     return;
126 
127   // Create the fix it.
128   const auto *Member = cast<MemberExpr>(Call->getCallee());
129   Diag << FixItHint::CreateReplacement(
130       Member->getMemberNameInfo().getSourceRange(), "contains");
131   SourceLocation ComparisonBegin = Comparison->getSourceRange().getBegin();
132   SourceLocation ComparisonEnd = Comparison->getSourceRange().getEnd();
133   SourceLocation CallBegin = Call->getSourceRange().getBegin();
134   SourceLocation CallEnd = Call->getSourceRange().getEnd();
135   Diag << FixItHint::CreateReplacement(
136       CharSourceRange::getCharRange(ComparisonBegin, CallBegin),
137       Negated ? "!" : "");
138   Diag << FixItHint::CreateRemoval(CharSourceRange::getTokenRange(
139       CallEnd.getLocWithOffset(1), ComparisonEnd));
140 }
141 
142 } // namespace clang::tidy::readability
143