xref: /llvm-project/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp (revision e69794323338c5279011e6f01c12254da8dff10d)
1 //===-- SimplifyBooleanExprCheck.cpp - clang-tidy -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "SimplifyBooleanExprCheck.h"
10 #include "clang/AST/Expr.h"
11 #include "clang/AST/RecursiveASTVisitor.h"
12 #include "clang/Basic/DiagnosticIDs.h"
13 #include "clang/Lex/Lexer.h"
14 #include "llvm/Support/SaveAndRestore.h"
15 
16 #include <optional>
17 #include <string>
18 #include <utility>
19 
20 using namespace clang::ast_matchers;
21 
22 namespace clang::tidy::readability {
23 
24 namespace {
25 
getText(const ASTContext & Context,SourceRange Range)26 StringRef getText(const ASTContext &Context, SourceRange Range) {
27   return Lexer::getSourceText(CharSourceRange::getTokenRange(Range),
28                               Context.getSourceManager(),
29                               Context.getLangOpts());
30 }
31 
getText(const ASTContext & Context,T & Node)32 template <typename T> StringRef getText(const ASTContext &Context, T &Node) {
33   return getText(Context, Node.getSourceRange());
34 }
35 
36 } // namespace
37 
// Warning emitted when `&&`, `||`, `==` or `!=` has a boolean literal operand.
static constexpr char SimplifyOperatorDiagnostic[] =
    "redundant boolean literal supplied to boolean operator";

// Warning emitted for `if (true)` / `if (false)`.
static constexpr char SimplifyConditionDiagnostic[] =
    "redundant boolean literal in if statement condition";

// Warning emitted for `if (C) return true; [else] return false;` patterns.
static constexpr char SimplifyConditionalReturnDiagnostic[] =
    "redundant boolean literal in conditional return statement";
44 
needsParensAfterUnaryNegation(const Expr * E)45 static bool needsParensAfterUnaryNegation(const Expr *E) {
46   E = E->IgnoreImpCasts();
47   if (isa<BinaryOperator>(E) || isa<ConditionalOperator>(E))
48     return true;
49 
50   if (const auto *Op = dyn_cast<CXXOperatorCallExpr>(E))
51     return Op->getNumArgs() == 2 && Op->getOperator() != OO_Call &&
52            Op->getOperator() != OO_Subscript;
53 
54   return false;
55 }
56 
57 static std::pair<BinaryOperatorKind, BinaryOperatorKind> Opposites[] = {
58     {BO_LT, BO_GE}, {BO_GT, BO_LE}, {BO_EQ, BO_NE}};
59 
negatedOperator(const BinaryOperator * BinOp)60 static StringRef negatedOperator(const BinaryOperator *BinOp) {
61   const BinaryOperatorKind Opcode = BinOp->getOpcode();
62   for (auto NegatableOp : Opposites) {
63     if (Opcode == NegatableOp.first)
64       return BinaryOperator::getOpcodeStr(NegatableOp.second);
65     if (Opcode == NegatableOp.second)
66       return BinaryOperator::getOpcodeStr(NegatableOp.first);
67   }
68   return {};
69 }
70 
71 static std::pair<OverloadedOperatorKind, StringRef> OperatorNames[] = {
72     {OO_EqualEqual, "=="},   {OO_ExclaimEqual, "!="}, {OO_Less, "<"},
73     {OO_GreaterEqual, ">="}, {OO_Greater, ">"},       {OO_LessEqual, "<="}};
74 
getOperatorName(OverloadedOperatorKind OpKind)75 static StringRef getOperatorName(OverloadedOperatorKind OpKind) {
76   for (auto Name : OperatorNames) {
77     if (Name.first == OpKind)
78       return Name.second;
79   }
80 
81   return {};
82 }
83 
84 static std::pair<OverloadedOperatorKind, OverloadedOperatorKind>
85     OppositeOverloads[] = {{OO_EqualEqual, OO_ExclaimEqual},
86                            {OO_Less, OO_GreaterEqual},
87                            {OO_Greater, OO_LessEqual}};
88 
negatedOperator(const CXXOperatorCallExpr * OpCall)89 static StringRef negatedOperator(const CXXOperatorCallExpr *OpCall) {
90   const OverloadedOperatorKind Opcode = OpCall->getOperator();
91   for (auto NegatableOp : OppositeOverloads) {
92     if (Opcode == NegatableOp.first)
93       return getOperatorName(NegatableOp.second);
94     if (Opcode == NegatableOp.second)
95       return getOperatorName(NegatableOp.first);
96   }
97   return {};
98 }
99 
asBool(StringRef Text,bool NeedsStaticCast)100 static std::string asBool(StringRef Text, bool NeedsStaticCast) {
101   if (NeedsStaticCast)
102     return ("static_cast<bool>(" + Text + ")").str();
103 
104   return std::string(Text);
105 }
106 
needsNullPtrComparison(const Expr * E)107 static bool needsNullPtrComparison(const Expr *E) {
108   if (const auto *ImpCast = dyn_cast<ImplicitCastExpr>(E))
109     return ImpCast->getCastKind() == CK_PointerToBoolean ||
110            ImpCast->getCastKind() == CK_MemberPointerToBoolean;
111 
112   return false;
113 }
114 
needsZeroComparison(const Expr * E)115 static bool needsZeroComparison(const Expr *E) {
116   if (const auto *ImpCast = dyn_cast<ImplicitCastExpr>(E))
117     return ImpCast->getCastKind() == CK_IntegralToBoolean;
118 
119   return false;
120 }
121 
needsStaticCast(const Expr * E)122 static bool needsStaticCast(const Expr *E) {
123   if (const auto *ImpCast = dyn_cast<ImplicitCastExpr>(E)) {
124     if (ImpCast->getCastKind() == CK_UserDefinedConversion &&
125         ImpCast->getSubExpr()->getType()->isBooleanType()) {
126       if (const auto *MemCall =
127               dyn_cast<CXXMemberCallExpr>(ImpCast->getSubExpr())) {
128         if (const auto *MemDecl =
129                 dyn_cast<CXXConversionDecl>(MemCall->getMethodDecl())) {
130           if (MemDecl->isExplicit())
131             return true;
132         }
133       }
134     }
135   }
136 
137   E = E->IgnoreImpCasts();
138   return !E->getType()->isBooleanType();
139 }
140 
compareExpressionToConstant(const ASTContext & Context,const Expr * E,bool Negated,const char * Constant)141 static std::string compareExpressionToConstant(const ASTContext &Context,
142                                                const Expr *E, bool Negated,
143                                                const char *Constant) {
144   E = E->IgnoreImpCasts();
145   const std::string ExprText =
146       (isa<BinaryOperator>(E) ? ("(" + getText(Context, *E) + ")")
147                               : getText(Context, *E))
148           .str();
149   return ExprText + " " + (Negated ? "!=" : "==") + " " + Constant;
150 }
151 
compareExpressionToNullPtr(const ASTContext & Context,const Expr * E,bool Negated)152 static std::string compareExpressionToNullPtr(const ASTContext &Context,
153                                               const Expr *E, bool Negated) {
154   const char *NullPtr = Context.getLangOpts().CPlusPlus11 ? "nullptr" : "NULL";
155   return compareExpressionToConstant(Context, E, Negated, NullPtr);
156 }
157 
compareExpressionToZero(const ASTContext & Context,const Expr * E,bool Negated)158 static std::string compareExpressionToZero(const ASTContext &Context,
159                                            const Expr *E, bool Negated) {
160   return compareExpressionToConstant(Context, E, Negated, "0");
161 }
162 
replacementExpression(const ASTContext & Context,bool Negated,const Expr * E)163 static std::string replacementExpression(const ASTContext &Context,
164                                          bool Negated, const Expr *E) {
165   E = E->IgnoreParenBaseCasts();
166   if (const auto *EC = dyn_cast<ExprWithCleanups>(E))
167     E = EC->getSubExpr();
168 
169   const bool NeedsStaticCast = needsStaticCast(E);
170   if (Negated) {
171     if (const auto *UnOp = dyn_cast<UnaryOperator>(E)) {
172       if (UnOp->getOpcode() == UO_LNot) {
173         if (needsNullPtrComparison(UnOp->getSubExpr()))
174           return compareExpressionToNullPtr(Context, UnOp->getSubExpr(), true);
175 
176         if (needsZeroComparison(UnOp->getSubExpr()))
177           return compareExpressionToZero(Context, UnOp->getSubExpr(), true);
178 
179         return replacementExpression(Context, false, UnOp->getSubExpr());
180       }
181     }
182 
183     if (needsNullPtrComparison(E))
184       return compareExpressionToNullPtr(Context, E, false);
185 
186     if (needsZeroComparison(E))
187       return compareExpressionToZero(Context, E, false);
188 
189     StringRef NegatedOperator;
190     const Expr *LHS = nullptr;
191     const Expr *RHS = nullptr;
192     if (const auto *BinOp = dyn_cast<BinaryOperator>(E)) {
193       NegatedOperator = negatedOperator(BinOp);
194       LHS = BinOp->getLHS();
195       RHS = BinOp->getRHS();
196     } else if (const auto *OpExpr = dyn_cast<CXXOperatorCallExpr>(E)) {
197       if (OpExpr->getNumArgs() == 2) {
198         NegatedOperator = negatedOperator(OpExpr);
199         LHS = OpExpr->getArg(0);
200         RHS = OpExpr->getArg(1);
201       }
202     }
203     if (!NegatedOperator.empty() && LHS && RHS)
204       return (asBool((getText(Context, *LHS) + " " + NegatedOperator + " " +
205                       getText(Context, *RHS))
206                          .str(),
207                      NeedsStaticCast));
208 
209     StringRef Text = getText(Context, *E);
210     if (!NeedsStaticCast && needsParensAfterUnaryNegation(E))
211       return ("!(" + Text + ")").str();
212 
213     if (needsNullPtrComparison(E))
214       return compareExpressionToNullPtr(Context, E, false);
215 
216     if (needsZeroComparison(E))
217       return compareExpressionToZero(Context, E, false);
218 
219     return ("!" + asBool(Text, NeedsStaticCast));
220   }
221 
222   if (const auto *UnOp = dyn_cast<UnaryOperator>(E)) {
223     if (UnOp->getOpcode() == UO_LNot) {
224       if (needsNullPtrComparison(UnOp->getSubExpr()))
225         return compareExpressionToNullPtr(Context, UnOp->getSubExpr(), false);
226 
227       if (needsZeroComparison(UnOp->getSubExpr()))
228         return compareExpressionToZero(Context, UnOp->getSubExpr(), false);
229     }
230   }
231 
232   if (needsNullPtrComparison(E))
233     return compareExpressionToNullPtr(Context, E, true);
234 
235   if (needsZeroComparison(E))
236     return compareExpressionToZero(Context, E, true);
237 
238   return asBool(getText(Context, *E), NeedsStaticCast);
239 }
240 
containsDiscardedTokens(const ASTContext & Context,CharSourceRange CharRange)241 static bool containsDiscardedTokens(const ASTContext &Context,
242                                     CharSourceRange CharRange) {
243   std::string ReplacementText =
244       Lexer::getSourceText(CharRange, Context.getSourceManager(),
245                            Context.getLangOpts())
246           .str();
247   Lexer Lex(CharRange.getBegin(), Context.getLangOpts(), ReplacementText.data(),
248             ReplacementText.data(),
249             ReplacementText.data() + ReplacementText.size());
250   Lex.SetCommentRetentionState(true);
251 
252   Token Tok;
253   while (!Lex.LexFromRawLexer(Tok)) {
254     if (Tok.is(tok::TokenKind::comment) || Tok.is(tok::TokenKind::hash))
255       return true;
256   }
257 
258   return false;
259 }
260 
261 class SimplifyBooleanExprCheck::Visitor : public RecursiveASTVisitor<Visitor> {
262   using Base = RecursiveASTVisitor<Visitor>;
263 
264 public:
Visitor(SimplifyBooleanExprCheck * Check,ASTContext & Context)265   Visitor(SimplifyBooleanExprCheck *Check, ASTContext &Context)
266       : Check(Check), Context(Context) {}
267 
traverse()268   bool traverse() { return TraverseAST(Context); }
269 
shouldIgnore(Stmt * S)270   static bool shouldIgnore(Stmt *S) {
271     switch (S->getStmtClass()) {
272     case Stmt::ImplicitCastExprClass:
273     case Stmt::MaterializeTemporaryExprClass:
274     case Stmt::CXXBindTemporaryExprClass:
275       return true;
276     default:
277       return false;
278     }
279   }
280 
dataTraverseStmtPre(Stmt * S)281   bool dataTraverseStmtPre(Stmt *S) {
282     if (!S) {
283       return true;
284     }
285     if (Check->canBeBypassed(S))
286       return false;
287     if (!shouldIgnore(S))
288       StmtStack.push_back(S);
289     return true;
290   }
291 
dataTraverseStmtPost(Stmt * S)292   bool dataTraverseStmtPost(Stmt *S) {
293     if (S && !shouldIgnore(S)) {
294       assert(StmtStack.back() == S);
295       StmtStack.pop_back();
296     }
297     return true;
298   }
299 
VisitBinaryOperator(const BinaryOperator * Op) const300   bool VisitBinaryOperator(const BinaryOperator *Op) const {
301     Check->reportBinOp(Context, Op);
302     return true;
303   }
304 
305   // Extracts a bool if an expression is (true|false|!true|!false);
getAsBoolLiteral(const Expr * E,bool FilterMacro)306   static std::optional<bool> getAsBoolLiteral(const Expr *E, bool FilterMacro) {
307     if (const auto *Bool = dyn_cast<CXXBoolLiteralExpr>(E)) {
308       if (FilterMacro && Bool->getBeginLoc().isMacroID())
309         return std::nullopt;
310       return Bool->getValue();
311     }
312     if (const auto *UnaryOp = dyn_cast<UnaryOperator>(E)) {
313       if (FilterMacro && UnaryOp->getBeginLoc().isMacroID())
314         return std::nullopt;
315       if (UnaryOp->getOpcode() == UO_LNot)
316         if (std::optional<bool> Res = getAsBoolLiteral(
317                 UnaryOp->getSubExpr()->IgnoreImplicit(), FilterMacro))
318           return !*Res;
319     }
320     return std::nullopt;
321   }
322 
323   template <typename Node> struct NodeAndBool {
324     const Node *Item = nullptr;
325     bool Bool = false;
326 
operator boolclang::tidy::readability::SimplifyBooleanExprCheck::Visitor::NodeAndBool327     operator bool() const { return Item != nullptr; }
328   };
329 
330   using ExprAndBool = NodeAndBool<Expr>;
331   using DeclAndBool = NodeAndBool<Decl>;
332 
333   /// Detect's return (true|false|!true|!false);
parseReturnLiteralBool(const Stmt * S)334   static ExprAndBool parseReturnLiteralBool(const Stmt *S) {
335     const auto *RS = dyn_cast<ReturnStmt>(S);
336     if (!RS || !RS->getRetValue())
337       return {};
338     if (std::optional<bool> Ret =
339             getAsBoolLiteral(RS->getRetValue()->IgnoreImplicit(), false)) {
340       return {RS->getRetValue(), *Ret};
341     }
342     return {};
343   }
344 
345   /// If \p S is not a \c CompoundStmt, applies F on \p S, otherwise if there is
346   /// only 1 statement in the \c CompoundStmt, applies F on that single
347   /// statement.
348   template <typename Functor>
checkSingleStatement(Stmt * S,Functor F)349   static auto checkSingleStatement(Stmt *S, Functor F) -> decltype(F(S)) {
350     if (auto *CS = dyn_cast<CompoundStmt>(S)) {
351       if (CS->size() == 1)
352         return F(CS->body_front());
353       return {};
354     }
355     return F(S);
356   }
357 
parent() const358   Stmt *parent() const {
359     return StmtStack.size() < 2 ? nullptr : StmtStack[StmtStack.size() - 2];
360   }
361 
VisitIfStmt(IfStmt * If)362   bool VisitIfStmt(IfStmt *If) {
363     // Skip any if's that have a condition var or an init statement, or are
364     // "if consteval" statements.
365     if (If->hasInitStorage() || If->hasVarStorage() || If->isConsteval())
366       return true;
367     /*
368      * if (true) ThenStmt(); -> ThenStmt();
369      * if (false) ThenStmt(); -> <Empty>;
370      * if (false) ThenStmt(); else ElseStmt() -> ElseStmt();
371      */
372     Expr *Cond = If->getCond()->IgnoreImplicit();
373     if (std::optional<bool> Bool = getAsBoolLiteral(Cond, true)) {
374       if (*Bool)
375         Check->replaceWithThenStatement(Context, If, Cond);
376       else
377         Check->replaceWithElseStatement(Context, If, Cond);
378     }
379 
380     if (If->getElse()) {
381       /*
382        * if (Cond) return true; else return false; -> return Cond;
383        * if (Cond) return false; else return true; -> return !Cond;
384        */
385       if (ExprAndBool ThenReturnBool =
386               checkSingleStatement(If->getThen(), parseReturnLiteralBool)) {
387         ExprAndBool ElseReturnBool =
388             checkSingleStatement(If->getElse(), parseReturnLiteralBool);
389         if (ElseReturnBool && ThenReturnBool.Bool != ElseReturnBool.Bool) {
390           if (Check->ChainedConditionalReturn ||
391               !isa_and_nonnull<IfStmt>(parent())) {
392             Check->replaceWithReturnCondition(Context, If, ThenReturnBool.Item,
393                                               ElseReturnBool.Bool);
394           }
395         }
396       } else {
397         /*
398          * if (Cond) A = true; else A = false; -> A = Cond;
399          * if (Cond) A = false; else A = true; -> A = !Cond;
400          */
401         Expr *Var = nullptr;
402         SourceLocation Loc;
403         auto VarBoolAssignmentMatcher = [&Var,
404                                          &Loc](const Stmt *S) -> DeclAndBool {
405           const auto *BO = dyn_cast<BinaryOperator>(S);
406           if (!BO || BO->getOpcode() != BO_Assign)
407             return {};
408           std::optional<bool> RightasBool =
409               getAsBoolLiteral(BO->getRHS()->IgnoreImplicit(), false);
410           if (!RightasBool)
411             return {};
412           Expr *IgnImp = BO->getLHS()->IgnoreImplicit();
413           if (!Var) {
414             // We only need to track these for the Then branch.
415             Loc = BO->getRHS()->getBeginLoc();
416             Var = IgnImp;
417           }
418           if (auto *DRE = dyn_cast<DeclRefExpr>(IgnImp))
419             return {DRE->getDecl(), *RightasBool};
420           if (auto *ME = dyn_cast<MemberExpr>(IgnImp))
421             return {ME->getMemberDecl(), *RightasBool};
422           return {};
423         };
424         if (DeclAndBool ThenAssignment =
425                 checkSingleStatement(If->getThen(), VarBoolAssignmentMatcher)) {
426           DeclAndBool ElseAssignment =
427               checkSingleStatement(If->getElse(), VarBoolAssignmentMatcher);
428           if (ElseAssignment.Item == ThenAssignment.Item &&
429               ElseAssignment.Bool != ThenAssignment.Bool) {
430             if (Check->ChainedConditionalAssignment ||
431                 !isa_and_nonnull<IfStmt>(parent())) {
432               Check->replaceWithAssignment(Context, If, Var, Loc,
433                                            ElseAssignment.Bool);
434             }
435           }
436         }
437       }
438     }
439     return true;
440   }
441 
VisitConditionalOperator(ConditionalOperator * Cond)442   bool VisitConditionalOperator(ConditionalOperator *Cond) {
443     /*
444      * Condition ? true : false; -> Condition
445      * Condition ? false : true; -> !Condition;
446      */
447     if (std::optional<bool> Then =
448             getAsBoolLiteral(Cond->getTrueExpr()->IgnoreImplicit(), false)) {
449       if (std::optional<bool> Else =
450               getAsBoolLiteral(Cond->getFalseExpr()->IgnoreImplicit(), false)) {
451         if (*Then != *Else)
452           Check->replaceWithCondition(Context, Cond, *Else);
453       }
454     }
455     return true;
456   }
457 
VisitCompoundStmt(CompoundStmt * CS)458   bool VisitCompoundStmt(CompoundStmt *CS) {
459     if (CS->size() < 2)
460       return true;
461     bool CurIf = false, PrevIf = false;
462     for (auto First = CS->body_begin(), Second = std::next(First),
463               End = CS->body_end();
464          Second != End; ++Second, ++First) {
465       PrevIf = CurIf;
466       CurIf = isa<IfStmt>(*First);
467       ExprAndBool TrailingReturnBool = parseReturnLiteralBool(*Second);
468       if (!TrailingReturnBool)
469         continue;
470 
471       if (CurIf) {
472         /*
473          * if (Cond) return true; return false; -> return Cond;
474          * if (Cond) return false; return true; -> return !Cond;
475          */
476         auto *If = cast<IfStmt>(*First);
477         if (!If->hasInitStorage() && !If->hasVarStorage() &&
478             !If->isConsteval()) {
479           ExprAndBool ThenReturnBool =
480               checkSingleStatement(If->getThen(), parseReturnLiteralBool);
481           if (ThenReturnBool &&
482               ThenReturnBool.Bool != TrailingReturnBool.Bool) {
483             if ((Check->ChainedConditionalReturn || !PrevIf) &&
484                 If->getElse() == nullptr) {
485               Check->replaceCompoundReturnWithCondition(
486                   Context, cast<ReturnStmt>(*Second), TrailingReturnBool.Bool,
487                   If, ThenReturnBool.Item);
488             }
489           }
490         }
491       } else if (isa<LabelStmt, CaseStmt, DefaultStmt>(*First)) {
492         /*
493          * (case X|label_X|default): if (Cond) return BoolLiteral;
494          *                           return !BoolLiteral
495          */
496         Stmt *SubStmt =
497             isa<LabelStmt>(*First)  ? cast<LabelStmt>(*First)->getSubStmt()
498             : isa<CaseStmt>(*First) ? cast<CaseStmt>(*First)->getSubStmt()
499                                     : cast<DefaultStmt>(*First)->getSubStmt();
500         auto *SubIf = dyn_cast<IfStmt>(SubStmt);
501         if (SubIf && !SubIf->getElse() && !SubIf->hasInitStorage() &&
502             !SubIf->hasVarStorage() && !SubIf->isConsteval()) {
503           ExprAndBool ThenReturnBool =
504               checkSingleStatement(SubIf->getThen(), parseReturnLiteralBool);
505           if (ThenReturnBool &&
506               ThenReturnBool.Bool != TrailingReturnBool.Bool) {
507             Check->replaceCompoundReturnWithCondition(
508                 Context, cast<ReturnStmt>(*Second), TrailingReturnBool.Bool,
509                 SubIf, ThenReturnBool.Item);
510           }
511         }
512       }
513     }
514     return true;
515   }
516 
isExpectedUnaryLNot(const Expr * E)517   bool isExpectedUnaryLNot(const Expr *E) {
518     return !Check->canBeBypassed(E) && isa<UnaryOperator>(E) &&
519            cast<UnaryOperator>(E)->getOpcode() == UO_LNot;
520   }
521 
isExpectedBinaryOp(const Expr * E)522   bool isExpectedBinaryOp(const Expr *E) {
523     const auto *BinaryOp = dyn_cast<BinaryOperator>(E);
524     return !Check->canBeBypassed(E) && BinaryOp && BinaryOp->isLogicalOp() &&
525            BinaryOp->getType()->isBooleanType();
526   }
527 
528   template <typename Functor>
checkEitherSide(const BinaryOperator * BO,Functor Func)529   static bool checkEitherSide(const BinaryOperator *BO, Functor Func) {
530     return Func(BO->getLHS()) || Func(BO->getRHS());
531   }
532 
nestedDemorgan(const Expr * E,unsigned NestingLevel)533   bool nestedDemorgan(const Expr *E, unsigned NestingLevel) {
534     const auto *BO = dyn_cast<BinaryOperator>(E->IgnoreUnlessSpelledInSource());
535     if (!BO)
536       return false;
537     if (!BO->getType()->isBooleanType())
538       return false;
539     switch (BO->getOpcode()) {
540     case BO_LT:
541     case BO_GT:
542     case BO_LE:
543     case BO_GE:
544     case BO_EQ:
545     case BO_NE:
546       return true;
547     case BO_LAnd:
548     case BO_LOr:
549       return checkEitherSide(
550                  BO,
551                  [this](const Expr *E) { return isExpectedUnaryLNot(E); }) ||
552              (NestingLevel &&
553               checkEitherSide(BO, [this, NestingLevel](const Expr *E) {
554                 return nestedDemorgan(E, NestingLevel - 1);
555               }));
556     default:
557       return false;
558     }
559   }
560 
TraverseUnaryOperator(UnaryOperator * Op)561   bool TraverseUnaryOperator(UnaryOperator *Op) {
562     if (!Check->SimplifyDeMorgan || Op->getOpcode() != UO_LNot)
563       return Base::TraverseUnaryOperator(Op);
564     const Expr *SubImp = Op->getSubExpr()->IgnoreImplicit();
565     const auto *Parens = dyn_cast<ParenExpr>(SubImp);
566     const Expr *SubExpr =
567         Parens ? Parens->getSubExpr()->IgnoreImplicit() : SubImp;
568     if (!isExpectedBinaryOp(SubExpr))
569       return Base::TraverseUnaryOperator(Op);
570     const auto *BinaryOp = cast<BinaryOperator>(SubExpr);
571     if (Check->SimplifyDeMorganRelaxed ||
572         checkEitherSide(
573             BinaryOp,
574             [this](const Expr *E) { return isExpectedUnaryLNot(E); }) ||
575         checkEitherSide(
576             BinaryOp, [this](const Expr *E) { return nestedDemorgan(E, 1); })) {
577       if (Check->reportDeMorgan(Context, Op, BinaryOp, !IsProcessing, parent(),
578                                 Parens) &&
579           !Check->areDiagsSelfContained()) {
580         llvm::SaveAndRestore RAII(IsProcessing, true);
581         return Base::TraverseUnaryOperator(Op);
582       }
583     }
584     return Base::TraverseUnaryOperator(Op);
585   }
586 
587 private:
588   bool IsProcessing = false;
589   SimplifyBooleanExprCheck *Check;
590   SmallVector<Stmt *, 32> StmtStack;
591   ASTContext &Context;
592 };
593 
SimplifyBooleanExprCheck(StringRef Name,ClangTidyContext * Context)594 SimplifyBooleanExprCheck::SimplifyBooleanExprCheck(StringRef Name,
595                                                    ClangTidyContext *Context)
596     : ClangTidyCheck(Name, Context),
597       IgnoreMacros(Options.get("IgnoreMacros", false)),
598       ChainedConditionalReturn(Options.get("ChainedConditionalReturn", false)),
599       ChainedConditionalAssignment(
600           Options.get("ChainedConditionalAssignment", false)),
601       SimplifyDeMorgan(Options.get("SimplifyDeMorgan", true)),
602       SimplifyDeMorganRelaxed(Options.get("SimplifyDeMorganRelaxed", false)) {
603   if (SimplifyDeMorganRelaxed && !SimplifyDeMorgan)
604     configurationDiag("%0: 'SimplifyDeMorganRelaxed' cannot be enabled "
605                       "without 'SimplifyDeMorgan' enabled")
606         << Name;
607 }
608 
containsBoolLiteral(const Expr * E)609 static bool containsBoolLiteral(const Expr *E) {
610   if (!E)
611     return false;
612   E = E->IgnoreParenImpCasts();
613   if (isa<CXXBoolLiteralExpr>(E))
614     return true;
615   if (const auto *BinOp = dyn_cast<BinaryOperator>(E))
616     return containsBoolLiteral(BinOp->getLHS()) ||
617            containsBoolLiteral(BinOp->getRHS());
618   if (const auto *UnaryOp = dyn_cast<UnaryOperator>(E))
619     return containsBoolLiteral(UnaryOp->getSubExpr());
620   return false;
621 }
622 
reportBinOp(const ASTContext & Context,const BinaryOperator * Op)623 void SimplifyBooleanExprCheck::reportBinOp(const ASTContext &Context,
624                                            const BinaryOperator *Op) {
625   const auto *LHS = Op->getLHS()->IgnoreParenImpCasts();
626   const auto *RHS = Op->getRHS()->IgnoreParenImpCasts();
627 
628   const CXXBoolLiteralExpr *Bool = nullptr;
629   const Expr *Other = nullptr;
630   if ((Bool = dyn_cast<CXXBoolLiteralExpr>(LHS)) != nullptr)
631     Other = RHS;
632   else if ((Bool = dyn_cast<CXXBoolLiteralExpr>(RHS)) != nullptr)
633     Other = LHS;
634   else
635     return;
636 
637   if (Bool->getBeginLoc().isMacroID())
638     return;
639 
640   // FIXME: why do we need this?
641   if (!isa<CXXBoolLiteralExpr>(Other) && containsBoolLiteral(Other))
642     return;
643 
644   bool BoolValue = Bool->getValue();
645 
646   auto ReplaceWithExpression = [this, &Context, LHS, RHS,
647                                 Bool](const Expr *ReplaceWith, bool Negated) {
648     std::string Replacement =
649         replacementExpression(Context, Negated, ReplaceWith);
650     SourceRange Range(LHS->getBeginLoc(), RHS->getEndLoc());
651     issueDiag(Context, Bool->getBeginLoc(), SimplifyOperatorDiagnostic, Range,
652               Replacement);
653   };
654 
655   switch (Op->getOpcode()) {
656   case BO_LAnd:
657     if (BoolValue)
658       // expr && true -> expr
659       ReplaceWithExpression(Other, /*Negated=*/false);
660     else
661       // expr && false -> false
662       ReplaceWithExpression(Bool, /*Negated=*/false);
663     break;
664   case BO_LOr:
665     if (BoolValue)
666       // expr || true -> true
667       ReplaceWithExpression(Bool, /*Negated=*/false);
668     else
669       // expr || false -> expr
670       ReplaceWithExpression(Other, /*Negated=*/false);
671     break;
672   case BO_EQ:
673     // expr == true -> expr, expr == false -> !expr
674     ReplaceWithExpression(Other, /*Negated=*/!BoolValue);
675     break;
676   case BO_NE:
677     // expr != true -> !expr, expr != false -> expr
678     ReplaceWithExpression(Other, /*Negated=*/BoolValue);
679     break;
680   default:
681     break;
682   }
683 }
684 
storeOptions(ClangTidyOptions::OptionMap & Opts)685 void SimplifyBooleanExprCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
686   Options.store(Opts, "IgnoreMacros", IgnoreMacros);
687   Options.store(Opts, "ChainedConditionalReturn", ChainedConditionalReturn);
688   Options.store(Opts, "ChainedConditionalAssignment",
689                 ChainedConditionalAssignment);
690   Options.store(Opts, "SimplifyDeMorgan", SimplifyDeMorgan);
691   Options.store(Opts, "SimplifyDeMorganRelaxed", SimplifyDeMorganRelaxed);
692 }
693 
registerMatchers(MatchFinder * Finder)694 void SimplifyBooleanExprCheck::registerMatchers(MatchFinder *Finder) {
695   Finder->addMatcher(translationUnitDecl(), this);
696 }
697 
check(const MatchFinder::MatchResult & Result)698 void SimplifyBooleanExprCheck::check(const MatchFinder::MatchResult &Result) {
699   Visitor(this, *Result.Context).traverse();
700 }
701 
canBeBypassed(const Stmt * S) const702 bool SimplifyBooleanExprCheck::canBeBypassed(const Stmt *S) const {
703   return IgnoreMacros && S->getBeginLoc().isMacroID();
704 }
705 
706 /// @brief return true when replacement created.
issueDiag(const ASTContext & Context,SourceLocation Loc,StringRef Description,SourceRange ReplacementRange,StringRef Replacement)707 bool SimplifyBooleanExprCheck::issueDiag(const ASTContext &Context,
708                                          SourceLocation Loc,
709                                          StringRef Description,
710                                          SourceRange ReplacementRange,
711                                          StringRef Replacement) {
712   CharSourceRange CharRange =
713       Lexer::makeFileCharRange(CharSourceRange::getTokenRange(ReplacementRange),
714                                Context.getSourceManager(), getLangOpts());
715 
716   DiagnosticBuilder Diag = diag(Loc, Description);
717   const bool HasReplacement = !containsDiscardedTokens(Context, CharRange);
718   if (HasReplacement)
719     Diag << FixItHint::CreateReplacement(CharRange, Replacement);
720   return HasReplacement;
721 }
722 
replaceWithThenStatement(const ASTContext & Context,const IfStmt * IfStatement,const Expr * BoolLiteral)723 void SimplifyBooleanExprCheck::replaceWithThenStatement(
724     const ASTContext &Context, const IfStmt *IfStatement,
725     const Expr *BoolLiteral) {
726   issueDiag(Context, BoolLiteral->getBeginLoc(), SimplifyConditionDiagnostic,
727             IfStatement->getSourceRange(),
728             getText(Context, *IfStatement->getThen()));
729 }
730 
replaceWithElseStatement(const ASTContext & Context,const IfStmt * IfStatement,const Expr * BoolLiteral)731 void SimplifyBooleanExprCheck::replaceWithElseStatement(
732     const ASTContext &Context, const IfStmt *IfStatement,
733     const Expr *BoolLiteral) {
734   const Stmt *ElseStatement = IfStatement->getElse();
735   issueDiag(Context, BoolLiteral->getBeginLoc(), SimplifyConditionDiagnostic,
736             IfStatement->getSourceRange(),
737             ElseStatement ? getText(Context, *ElseStatement) : "");
738 }
739 
replaceWithCondition(const ASTContext & Context,const ConditionalOperator * Ternary,bool Negated)740 void SimplifyBooleanExprCheck::replaceWithCondition(
741     const ASTContext &Context, const ConditionalOperator *Ternary,
742     bool Negated) {
743   std::string Replacement =
744       replacementExpression(Context, Negated, Ternary->getCond());
745   issueDiag(Context, Ternary->getTrueExpr()->getBeginLoc(),
746             "redundant boolean literal in ternary expression result",
747             Ternary->getSourceRange(), Replacement);
748 }
749 
replaceWithReturnCondition(const ASTContext & Context,const IfStmt * If,const Expr * BoolLiteral,bool Negated)750 void SimplifyBooleanExprCheck::replaceWithReturnCondition(
751     const ASTContext &Context, const IfStmt *If, const Expr *BoolLiteral,
752     bool Negated) {
753   StringRef Terminator = isa<CompoundStmt>(If->getElse()) ? ";" : "";
754   std::string Condition =
755       replacementExpression(Context, Negated, If->getCond());
756   std::string Replacement = ("return " + Condition + Terminator).str();
757   SourceLocation Start = BoolLiteral->getBeginLoc();
758 
759   const bool HasReplacement =
760       issueDiag(Context, Start, SimplifyConditionalReturnDiagnostic,
761                 If->getSourceRange(), Replacement);
762 
763   if (!HasReplacement) {
764     const SourceRange ConditionRange = If->getCond()->getSourceRange();
765     if (ConditionRange.isValid())
766       diag(ConditionRange.getBegin(), "conditions that can be simplified",
767            DiagnosticIDs::Note)
768           << ConditionRange;
769   }
770 }
771 
replaceCompoundReturnWithCondition(const ASTContext & Context,const ReturnStmt * Ret,bool Negated,const IfStmt * If,const Expr * ThenReturn)772 void SimplifyBooleanExprCheck::replaceCompoundReturnWithCondition(
773     const ASTContext &Context, const ReturnStmt *Ret, bool Negated,
774     const IfStmt *If, const Expr *ThenReturn) {
775   const std::string Replacement =
776       "return " + replacementExpression(Context, Negated, If->getCond());
777 
778   const bool HasReplacement = issueDiag(
779       Context, ThenReturn->getBeginLoc(), SimplifyConditionalReturnDiagnostic,
780       SourceRange(If->getBeginLoc(), Ret->getEndLoc()), Replacement);
781 
782   if (!HasReplacement) {
783     const SourceRange ConditionRange = If->getCond()->getSourceRange();
784     if (ConditionRange.isValid())
785       diag(ConditionRange.getBegin(), "conditions that can be simplified",
786            DiagnosticIDs::Note)
787           << ConditionRange;
788     const SourceRange ReturnRange = Ret->getSourceRange();
789     if (ReturnRange.isValid())
790       diag(ReturnRange.getBegin(), "return statement that can be simplified",
791            DiagnosticIDs::Note)
792           << ReturnRange;
793   }
794 }
795 
replaceWithAssignment(const ASTContext & Context,const IfStmt * IfAssign,const Expr * Var,SourceLocation Loc,bool Negated)796 void SimplifyBooleanExprCheck::replaceWithAssignment(const ASTContext &Context,
797                                                      const IfStmt *IfAssign,
798                                                      const Expr *Var,
799                                                      SourceLocation Loc,
800                                                      bool Negated) {
801   SourceRange Range = IfAssign->getSourceRange();
802   StringRef VariableName = getText(Context, *Var);
803   StringRef Terminator = isa<CompoundStmt>(IfAssign->getElse()) ? ";" : "";
804   std::string Condition =
805       replacementExpression(Context, Negated, IfAssign->getCond());
806   std::string Replacement =
807       (VariableName + " = " + Condition + Terminator).str();
808   issueDiag(Context, Loc, "redundant boolean literal in conditional assignment",
809             Range, Replacement);
810 }
811 
812 /// Swaps a \c BinaryOperator opcode from `&&` to `||` or vice-versa.
flipDemorganOperator(llvm::SmallVectorImpl<FixItHint> & Output,const BinaryOperator * BO)813 static bool flipDemorganOperator(llvm::SmallVectorImpl<FixItHint> &Output,
814                                  const BinaryOperator *BO) {
815   assert(BO->isLogicalOp());
816   if (BO->getOperatorLoc().isMacroID())
817     return true;
818   Output.push_back(FixItHint::CreateReplacement(
819       BO->getOperatorLoc(), BO->getOpcode() == BO_LAnd ? "||" : "&&"));
820   return false;
821 }
822 
getDemorganFlippedOperator(BinaryOperatorKind BO)823 static BinaryOperatorKind getDemorganFlippedOperator(BinaryOperatorKind BO) {
824   assert(BinaryOperator::isLogicalOp(BO));
825   return BO == BO_LAnd ? BO_LOr : BO_LAnd;
826 }
827 
828 static bool flipDemorganSide(SmallVectorImpl<FixItHint> &Fixes,
829                              const ASTContext &Ctx, const Expr *E,
830                              std::optional<BinaryOperatorKind> OuterBO);
831 
832 /// Inverts \p BinOp, Removing \p Parens if they exist and are safe to remove.
833 /// returns \c true if there is any issue building the Fixes, \c false
834 /// otherwise.
835 static bool
flipDemorganBinaryOperator(SmallVectorImpl<FixItHint> & Fixes,const ASTContext & Ctx,const BinaryOperator * BinOp,std::optional<BinaryOperatorKind> OuterBO,const ParenExpr * Parens=nullptr)836 flipDemorganBinaryOperator(SmallVectorImpl<FixItHint> &Fixes,
837                            const ASTContext &Ctx, const BinaryOperator *BinOp,
838                            std::optional<BinaryOperatorKind> OuterBO,
839                            const ParenExpr *Parens = nullptr) {
840   switch (BinOp->getOpcode()) {
841   case BO_LAnd:
842   case BO_LOr: {
843     // if we have 'a && b' or 'a || b', use demorgan to flip it to '!a || !b'
844     // or '!a && !b'.
845     if (flipDemorganOperator(Fixes, BinOp))
846       return true;
847     auto NewOp = getDemorganFlippedOperator(BinOp->getOpcode());
848     if (OuterBO) {
849       // The inner parens are technically needed in a fix for
850       // `!(!A1 && !(A2 || A3)) -> (A1 || (A2 && A3))`,
851       // however this would trip the LogicalOpParentheses warning.
852       // FIXME: Make this user configurable or detect if that warning is
853       // enabled.
854       constexpr bool LogicalOpParentheses = true;
855       if (((*OuterBO == NewOp) || (!LogicalOpParentheses &&
856                                    (*OuterBO == BO_LOr && NewOp == BO_LAnd))) &&
857           Parens) {
858         if (!Parens->getLParen().isMacroID() &&
859             !Parens->getRParen().isMacroID()) {
860           Fixes.push_back(FixItHint::CreateRemoval(Parens->getLParen()));
861           Fixes.push_back(FixItHint::CreateRemoval(Parens->getRParen()));
862         }
863       }
864       if (*OuterBO == BO_LAnd && NewOp == BO_LOr && !Parens) {
865         Fixes.push_back(FixItHint::CreateInsertion(BinOp->getBeginLoc(), "("));
866         Fixes.push_back(FixItHint::CreateInsertion(
867             Lexer::getLocForEndOfToken(BinOp->getEndLoc(), 0,
868                                        Ctx.getSourceManager(),
869                                        Ctx.getLangOpts()),
870             ")"));
871       }
872     }
873     if (flipDemorganSide(Fixes, Ctx, BinOp->getLHS(), NewOp) ||
874         flipDemorganSide(Fixes, Ctx, BinOp->getRHS(), NewOp))
875       return true;
876     return false;
877   };
878   case BO_LT:
879   case BO_GT:
880   case BO_LE:
881   case BO_GE:
882   case BO_EQ:
883   case BO_NE:
884     // For comparison operators, just negate the comparison.
885     if (BinOp->getOperatorLoc().isMacroID())
886       return true;
887     Fixes.push_back(FixItHint::CreateReplacement(
888         BinOp->getOperatorLoc(),
889         BinaryOperator::getOpcodeStr(
890             BinaryOperator::negateComparisonOp(BinOp->getOpcode()))));
891     return false;
892   default:
893     // for any other binary operator, just use logical not and wrap in
894     // parens.
895     if (Parens) {
896       if (Parens->getBeginLoc().isMacroID())
897         return true;
898       Fixes.push_back(FixItHint::CreateInsertion(Parens->getBeginLoc(), "!"));
899     } else {
900       if (BinOp->getBeginLoc().isMacroID() || BinOp->getEndLoc().isMacroID())
901         return true;
902       Fixes.append({FixItHint::CreateInsertion(BinOp->getBeginLoc(), "!("),
903                     FixItHint::CreateInsertion(
904                         Lexer::getLocForEndOfToken(BinOp->getEndLoc(), 0,
905                                                    Ctx.getSourceManager(),
906                                                    Ctx.getLangOpts()),
907                         ")")});
908     }
909     break;
910   }
911   return false;
912 }
913 
flipDemorganSide(SmallVectorImpl<FixItHint> & Fixes,const ASTContext & Ctx,const Expr * E,std::optional<BinaryOperatorKind> OuterBO)914 static bool flipDemorganSide(SmallVectorImpl<FixItHint> &Fixes,
915                              const ASTContext &Ctx, const Expr *E,
916                              std::optional<BinaryOperatorKind> OuterBO) {
917   if (isa<UnaryOperator>(E) && cast<UnaryOperator>(E)->getOpcode() == UO_LNot) {
918     //  if we have a not operator, '!a', just remove the '!'.
919     if (cast<UnaryOperator>(E)->getOperatorLoc().isMacroID())
920       return true;
921     Fixes.push_back(
922         FixItHint::CreateRemoval(cast<UnaryOperator>(E)->getOperatorLoc()));
923     return false;
924   }
925   if (const auto *BinOp = dyn_cast<BinaryOperator>(E)) {
926     return flipDemorganBinaryOperator(Fixes, Ctx, BinOp, OuterBO);
927   }
928   if (const auto *Paren = dyn_cast<ParenExpr>(E)) {
929     if (const auto *BinOp = dyn_cast<BinaryOperator>(Paren->getSubExpr())) {
930       return flipDemorganBinaryOperator(Fixes, Ctx, BinOp, OuterBO, Paren);
931     }
932   }
933   // Fallback case just insert a logical not operator.
934   if (E->getBeginLoc().isMacroID())
935     return true;
936   Fixes.push_back(FixItHint::CreateInsertion(E->getBeginLoc(), "!"));
937   return false;
938 }
939 
shouldRemoveParens(const Stmt * Parent,BinaryOperatorKind NewOuterBinary,const ParenExpr * Parens)940 static bool shouldRemoveParens(const Stmt *Parent,
941                                BinaryOperatorKind NewOuterBinary,
942                                const ParenExpr *Parens) {
943   if (!Parens)
944     return false;
945   if (!Parent)
946     return true;
947   switch (Parent->getStmtClass()) {
948   case Stmt::BinaryOperatorClass: {
949     const auto *BO = cast<BinaryOperator>(Parent);
950     if (BO->isAssignmentOp())
951       return true;
952     if (BO->isCommaOp())
953       return true;
954     if (BO->getOpcode() == NewOuterBinary)
955       return true;
956     return false;
957   }
958   case Stmt::UnaryOperatorClass:
959   case Stmt::CXXRewrittenBinaryOperatorClass:
960     return false;
961   default:
962     return true;
963   }
964 }
965 
reportDeMorgan(const ASTContext & Context,const UnaryOperator * Outer,const BinaryOperator * Inner,bool TryOfferFix,const Stmt * Parent,const ParenExpr * Parens)966 bool SimplifyBooleanExprCheck::reportDeMorgan(const ASTContext &Context,
967                                               const UnaryOperator *Outer,
968                                               const BinaryOperator *Inner,
969                                               bool TryOfferFix,
970                                               const Stmt *Parent,
971                                               const ParenExpr *Parens) {
972   assert(Outer);
973   assert(Inner);
974   assert(Inner->isLogicalOp());
975 
976   auto Diag =
977       diag(Outer->getBeginLoc(),
978            "boolean expression can be simplified by DeMorgan's theorem");
979   Diag << Outer->getSourceRange();
980   // If we have already fixed this with a previous fix, don't attempt any fixes
981   if (!TryOfferFix)
982     return false;
983   if (Outer->getOperatorLoc().isMacroID())
984     return false;
985   SmallVector<FixItHint> Fixes;
986   auto NewOpcode = getDemorganFlippedOperator(Inner->getOpcode());
987   if (shouldRemoveParens(Parent, NewOpcode, Parens)) {
988     Fixes.push_back(FixItHint::CreateRemoval(
989         SourceRange(Outer->getOperatorLoc(), Parens->getLParen())));
990     Fixes.push_back(FixItHint::CreateRemoval(Parens->getRParen()));
991   } else {
992     Fixes.push_back(FixItHint::CreateRemoval(Outer->getOperatorLoc()));
993   }
994   if (flipDemorganOperator(Fixes, Inner))
995     return false;
996   if (flipDemorganSide(Fixes, Context, Inner->getLHS(), NewOpcode) ||
997       flipDemorganSide(Fixes, Context, Inner->getRHS(), NewOpcode))
998     return false;
999   Diag << Fixes;
1000   return true;
1001 }
1002 } // namespace clang::tidy::readability
1003