xref: /llvm-project/clang-tools-extra/clang-tidy/readability/ElseAfterReturnCheck.cpp (revision 3f9e2e179a52eb50a2bcff148a5f351a4eddcb37)
1 //===--- ElseAfterReturnCheck.cpp - clang-tidy-----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "ElseAfterReturnCheck.h"
#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Basic/CharInfo.h"
#include "clang/Lex/Lexer.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Tooling/FixIt.h"
#include "llvm/ADT/SmallVector.h"
16 
17 using namespace clang::ast_matchers;
18 
19 namespace clang::tidy::readability {
20 
21 namespace {
22 
23 class PPConditionalCollector : public PPCallbacks {
24 public:
PPConditionalCollector(ElseAfterReturnCheck::ConditionalBranchMap & Collections,const SourceManager & SM)25   PPConditionalCollector(
26       ElseAfterReturnCheck::ConditionalBranchMap &Collections,
27       const SourceManager &SM)
28       : Collections(Collections), SM(SM) {}
Endif(SourceLocation Loc,SourceLocation IfLoc)29   void Endif(SourceLocation Loc, SourceLocation IfLoc) override {
30     if (!SM.isWrittenInSameFile(Loc, IfLoc))
31       return;
32     SmallVectorImpl<SourceRange> &Collection = Collections[SM.getFileID(Loc)];
33     assert(Collection.empty() || Collection.back().getEnd() < Loc);
34     Collection.emplace_back(IfLoc, Loc);
35   }
36 
37 private:
38   ElseAfterReturnCheck::ConditionalBranchMap &Collections;
39   const SourceManager &SM;
40 };
41 
42 } // namespace
43 
// Bound-node name for the statement that interrupts control flow in the
// 'then' branch.
static constexpr char InterruptingStr[] = "interrupting";
// Diagnostic text; %0 is the interrupting keyword.
static constexpr char WarningMessage[] = "do not use 'else' after '%0'";
// Names of the check's configuration options.
static constexpr char WarnOnUnfixableStr[] = "WarnOnUnfixable";
static constexpr char WarnOnConditionVariablesStr[] =
    "WarnOnConditionVariables";
48 
findUsage(const Stmt * Node,int64_t DeclIdentifier)49 static const DeclRefExpr *findUsage(const Stmt *Node, int64_t DeclIdentifier) {
50   if (!Node)
51     return nullptr;
52   if (const auto *DeclRef = dyn_cast<DeclRefExpr>(Node)) {
53     if (DeclRef->getDecl()->getID() == DeclIdentifier)
54       return DeclRef;
55   } else {
56     for (const Stmt *ChildNode : Node->children()) {
57       if (const DeclRefExpr *Result = findUsage(ChildNode, DeclIdentifier))
58         return Result;
59     }
60   }
61   return nullptr;
62 }
63 
64 static const DeclRefExpr *
findUsageRange(const Stmt * Node,const llvm::ArrayRef<int64_t> & DeclIdentifiers)65 findUsageRange(const Stmt *Node,
66                const llvm::ArrayRef<int64_t> &DeclIdentifiers) {
67   if (!Node)
68     return nullptr;
69   if (const auto *DeclRef = dyn_cast<DeclRefExpr>(Node)) {
70     if (llvm::is_contained(DeclIdentifiers, DeclRef->getDecl()->getID()))
71       return DeclRef;
72   } else {
73     for (const Stmt *ChildNode : Node->children()) {
74       if (const DeclRefExpr *Result =
75               findUsageRange(ChildNode, DeclIdentifiers))
76         return Result;
77     }
78   }
79   return nullptr;
80 }
81 
checkInitDeclUsageInElse(const IfStmt * If)82 static const DeclRefExpr *checkInitDeclUsageInElse(const IfStmt *If) {
83   const auto *InitDeclStmt = dyn_cast_or_null<DeclStmt>(If->getInit());
84   if (!InitDeclStmt)
85     return nullptr;
86   if (InitDeclStmt->isSingleDecl()) {
87     const Decl *InitDecl = InitDeclStmt->getSingleDecl();
88     assert(isa<VarDecl>(InitDecl) && "SingleDecl must be a VarDecl");
89     return findUsage(If->getElse(), InitDecl->getID());
90   }
91   llvm::SmallVector<int64_t, 4> DeclIdentifiers;
92   for (const Decl *ChildDecl : InitDeclStmt->decls()) {
93     assert(isa<VarDecl>(ChildDecl) && "Init Decls must be a VarDecl");
94     DeclIdentifiers.push_back(ChildDecl->getID());
95   }
96   return findUsageRange(If->getElse(), DeclIdentifiers);
97 }
98 
checkConditionVarUsageInElse(const IfStmt * If)99 static const DeclRefExpr *checkConditionVarUsageInElse(const IfStmt *If) {
100   if (const VarDecl *CondVar = If->getConditionVariable())
101     return findUsage(If->getElse(), CondVar->getID());
102   return nullptr;
103 }
104 
containsDeclInScope(const Stmt * Node)105 static bool containsDeclInScope(const Stmt *Node) {
106   if (isa<DeclStmt>(Node))
107     return true;
108   if (const auto *Compound = dyn_cast<CompoundStmt>(Node))
109     return llvm::any_of(Compound->body(), [](const Stmt *SubNode) {
110       return isa<DeclStmt>(SubNode);
111     });
112   return false;
113 }
114 
removeElseAndBrackets(DiagnosticBuilder & Diag,ASTContext & Context,const Stmt * Else,SourceLocation ElseLoc)115 static void removeElseAndBrackets(DiagnosticBuilder &Diag, ASTContext &Context,
116                                   const Stmt *Else, SourceLocation ElseLoc) {
117   auto Remap = [&](SourceLocation Loc) {
118     return Context.getSourceManager().getExpansionLoc(Loc);
119   };
120   auto TokLen = [&](SourceLocation Loc) {
121     return Lexer::MeasureTokenLength(Loc, Context.getSourceManager(),
122                                      Context.getLangOpts());
123   };
124 
125   if (const auto *CS = dyn_cast<CompoundStmt>(Else)) {
126     Diag << tooling::fixit::createRemoval(ElseLoc);
127     SourceLocation LBrace = CS->getLBracLoc();
128     SourceLocation RBrace = CS->getRBracLoc();
129     SourceLocation RangeStart =
130         Remap(LBrace).getLocWithOffset(TokLen(LBrace) + 1);
131     SourceLocation RangeEnd = Remap(RBrace).getLocWithOffset(-1);
132 
133     llvm::StringRef Repl = Lexer::getSourceText(
134         CharSourceRange::getTokenRange(RangeStart, RangeEnd),
135         Context.getSourceManager(), Context.getLangOpts());
136     Diag << tooling::fixit::createReplacement(CS->getSourceRange(), Repl);
137   } else {
138     SourceLocation ElseExpandedLoc = Remap(ElseLoc);
139     SourceLocation EndLoc = Remap(Else->getEndLoc());
140 
141     llvm::StringRef Repl = Lexer::getSourceText(
142         CharSourceRange::getTokenRange(
143             ElseExpandedLoc.getLocWithOffset(TokLen(ElseLoc) + 1), EndLoc),
144         Context.getSourceManager(), Context.getLangOpts());
145     Diag << tooling::fixit::createReplacement(
146         SourceRange(ElseExpandedLoc, EndLoc), Repl);
147   }
148 }
149 
ElseAfterReturnCheck(StringRef Name,ClangTidyContext * Context)150 ElseAfterReturnCheck::ElseAfterReturnCheck(StringRef Name,
151                                            ClangTidyContext *Context)
152     : ClangTidyCheck(Name, Context),
153       WarnOnUnfixable(Options.get(WarnOnUnfixableStr, true)),
154       WarnOnConditionVariables(Options.get(WarnOnConditionVariablesStr, true)) {
155 }
156 
storeOptions(ClangTidyOptions::OptionMap & Opts)157 void ElseAfterReturnCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
158   Options.store(Opts, WarnOnUnfixableStr, WarnOnUnfixable);
159   Options.store(Opts, WarnOnConditionVariablesStr, WarnOnConditionVariables);
160 }
161 
registerPPCallbacks(const SourceManager & SM,Preprocessor * PP,Preprocessor * ModuleExpanderPP)162 void ElseAfterReturnCheck::registerPPCallbacks(const SourceManager &SM,
163                                                Preprocessor *PP,
164                                                Preprocessor *ModuleExpanderPP) {
165   PP->addPPCallbacks(
166       std::make_unique<PPConditionalCollector>(this->PPConditionals, SM));
167 }
168 
registerMatchers(MatchFinder * Finder)169 void ElseAfterReturnCheck::registerMatchers(MatchFinder *Finder) {
170   const auto InterruptsControlFlow = stmt(anyOf(
171       returnStmt().bind(InterruptingStr), continueStmt().bind(InterruptingStr),
172       breakStmt().bind(InterruptingStr), cxxThrowExpr().bind(InterruptingStr)));
173   Finder->addMatcher(
174       compoundStmt(
175           forEach(ifStmt(unless(isConstexpr()), unless(isConsteval()),
176                          hasThen(stmt(
177                              anyOf(InterruptsControlFlow,
178                                    compoundStmt(has(InterruptsControlFlow))))),
179                          hasElse(stmt().bind("else")))
180                       .bind("if")))
181           .bind("cs"),
182       this);
183 }
184 
hasPreprocessorBranchEndBetweenLocations(const ElseAfterReturnCheck::ConditionalBranchMap & ConditionalBranchMap,const SourceManager & SM,SourceLocation StartLoc,SourceLocation EndLoc)185 static bool hasPreprocessorBranchEndBetweenLocations(
186     const ElseAfterReturnCheck::ConditionalBranchMap &ConditionalBranchMap,
187     const SourceManager &SM, SourceLocation StartLoc, SourceLocation EndLoc) {
188 
189   SourceLocation ExpandedStartLoc = SM.getExpansionLoc(StartLoc);
190   SourceLocation ExpandedEndLoc = SM.getExpansionLoc(EndLoc);
191   if (!SM.isWrittenInSameFile(ExpandedStartLoc, ExpandedEndLoc))
192     return false;
193 
194   // StartLoc and EndLoc expand to the same macro.
195   if (ExpandedStartLoc == ExpandedEndLoc)
196     return false;
197 
198   assert(ExpandedStartLoc < ExpandedEndLoc);
199 
200   auto Iter = ConditionalBranchMap.find(SM.getFileID(ExpandedEndLoc));
201 
202   if (Iter == ConditionalBranchMap.end() || Iter->getSecond().empty())
203     return false;
204 
205   const SmallVectorImpl<SourceRange> &ConditionalBranches = Iter->getSecond();
206 
207   assert(llvm::is_sorted(ConditionalBranches,
208                          [](const SourceRange &LHS, const SourceRange &RHS) {
209                            return LHS.getEnd() < RHS.getEnd();
210                          }));
211 
212   // First conditional block that ends after ExpandedStartLoc.
213   const auto *Begin =
214       llvm::lower_bound(ConditionalBranches, ExpandedStartLoc,
215                         [](const SourceRange &LHS, const SourceLocation &RHS) {
216                           return LHS.getEnd() < RHS;
217                         });
218   const auto *End = ConditionalBranches.end();
219   for (; Begin != End && Begin->getEnd() < ExpandedEndLoc; ++Begin)
220     if (Begin->getBegin() < ExpandedStartLoc)
221       return true;
222   return false;
223 }
224 
getControlFlowString(const Stmt & Stmt)225 static StringRef getControlFlowString(const Stmt &Stmt) {
226   if (isa<ReturnStmt>(Stmt))
227     return "return";
228   if (isa<ContinueStmt>(Stmt))
229     return "continue";
230   if (isa<BreakStmt>(Stmt))
231     return "break";
232   if (isa<CXXThrowExpr>(Stmt))
233     return "throw";
234   llvm_unreachable("Unknown control flow interrupter");
235 }
236 
check(const MatchFinder::MatchResult & Result)237 void ElseAfterReturnCheck::check(const MatchFinder::MatchResult &Result) {
238   const auto *If = Result.Nodes.getNodeAs<IfStmt>("if");
239   const auto *Else = Result.Nodes.getNodeAs<Stmt>("else");
240   const auto *OuterScope = Result.Nodes.getNodeAs<CompoundStmt>("cs");
241   const auto *Interrupt = Result.Nodes.getNodeAs<Stmt>(InterruptingStr);
242   SourceLocation ElseLoc = If->getElseLoc();
243 
244   if (hasPreprocessorBranchEndBetweenLocations(
245           PPConditionals, *Result.SourceManager, Interrupt->getBeginLoc(),
246           ElseLoc))
247     return;
248 
249   bool IsLastInScope = OuterScope->body_back() == If;
250   const StringRef ControlFlowInterrupter = getControlFlowString(*Interrupt);
251 
252   if (!IsLastInScope && containsDeclInScope(Else)) {
253     if (WarnOnUnfixable) {
254       // Warn, but don't attempt an autofix.
255       diag(ElseLoc, WarningMessage) << ControlFlowInterrupter;
256     }
257     return;
258   }
259 
260   if (checkConditionVarUsageInElse(If) != nullptr) {
261     if (!WarnOnConditionVariables)
262       return;
263     if (IsLastInScope) {
264       // If the if statement is the last statement of its enclosing statements
265       // scope, we can pull the decl out of the if statement.
266       DiagnosticBuilder Diag = diag(ElseLoc, WarningMessage)
267                                << ControlFlowInterrupter
268                                << SourceRange(ElseLoc);
269       if (checkInitDeclUsageInElse(If) != nullptr) {
270         Diag << tooling::fixit::createReplacement(
271                     SourceRange(If->getIfLoc()),
272                     (tooling::fixit::getText(*If->getInit(), *Result.Context) +
273                      llvm::StringRef("\n"))
274                         .str())
275              << tooling::fixit::createRemoval(If->getInit()->getSourceRange());
276       }
277       const DeclStmt *VDeclStmt = If->getConditionVariableDeclStmt();
278       const VarDecl *VDecl = If->getConditionVariable();
279       std::string Repl =
280           (tooling::fixit::getText(*VDeclStmt, *Result.Context) +
281            llvm::StringRef(";\n") +
282            tooling::fixit::getText(If->getIfLoc(), *Result.Context))
283               .str();
284       Diag << tooling::fixit::createReplacement(SourceRange(If->getIfLoc()),
285                                                 Repl)
286            << tooling::fixit::createReplacement(VDeclStmt->getSourceRange(),
287                                                 VDecl->getName());
288       removeElseAndBrackets(Diag, *Result.Context, Else, ElseLoc);
289     } else if (WarnOnUnfixable) {
290       // Warn, but don't attempt an autofix.
291       diag(ElseLoc, WarningMessage) << ControlFlowInterrupter;
292     }
293     return;
294   }
295 
296   if (checkInitDeclUsageInElse(If) != nullptr) {
297     if (!WarnOnConditionVariables)
298       return;
299     if (IsLastInScope) {
300       // If the if statement is the last statement of its enclosing statements
301       // scope, we can pull the decl out of the if statement.
302       DiagnosticBuilder Diag = diag(ElseLoc, WarningMessage)
303                                << ControlFlowInterrupter
304                                << SourceRange(ElseLoc);
305       Diag << tooling::fixit::createReplacement(
306                   SourceRange(If->getIfLoc()),
307                   (tooling::fixit::getText(*If->getInit(), *Result.Context) +
308                    "\n" +
309                    tooling::fixit::getText(If->getIfLoc(), *Result.Context))
310                       .str())
311            << tooling::fixit::createRemoval(If->getInit()->getSourceRange());
312       removeElseAndBrackets(Diag, *Result.Context, Else, ElseLoc);
313     } else if (WarnOnUnfixable) {
314       // Warn, but don't attempt an autofix.
315       diag(ElseLoc, WarningMessage) << ControlFlowInterrupter;
316     }
317     return;
318   }
319 
320   DiagnosticBuilder Diag = diag(ElseLoc, WarningMessage)
321                            << ControlFlowInterrupter << SourceRange(ElseLoc);
322   removeElseAndBrackets(Diag, *Result.Context, Else, ElseLoc);
323 }
324 
325 } // namespace clang::tidy::readability
326