xref: /llvm-project/clang/lib/Analysis/Consumed.cpp (revision 32ff209b87a84890a1487b4e0bbb4a7645d31645)
1 //===- Consumed.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// An intra-procedural analysis for checking consumed properties.  This is based,
10 // in part, on research on linear types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "clang/Analysis/Analyses/Consumed.h"
15 #include "clang/AST/Attr.h"
16 #include "clang/AST/Decl.h"
17 #include "clang/AST/DeclCXX.h"
18 #include "clang/AST/Expr.h"
19 #include "clang/AST/ExprCXX.h"
20 #include "clang/AST/Stmt.h"
21 #include "clang/AST/StmtVisitor.h"
22 #include "clang/AST/Type.h"
23 #include "clang/Analysis/Analyses/PostOrderCFGView.h"
24 #include "clang/Analysis/AnalysisDeclContext.h"
25 #include "clang/Analysis/CFG.h"
26 #include "clang/Basic/LLVM.h"
27 #include "clang/Basic/OperatorKinds.h"
28 #include "clang/Basic/SourceLocation.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/Support/Casting.h"
33 #include "llvm/Support/ErrorHandling.h"
34 #include <cassert>
35 #include <memory>
36 #include <optional>
37 #include <utility>
38 
39 // TODO: Adjust states of args to constructors in the same way that arguments to
40 //       function calls are handled.
41 // TODO: Use information from tests in for- and while-loop conditional.
42 // TODO: Add notes about the actual and expected state for
43 // TODO: Correctly identify unreachable blocks when chaining boolean operators.
44 // TODO: Adjust the parser and AttributesList class to support lists of
45 //       identifiers.
46 // TODO: Warn about unreachable code.
47 // TODO: Switch to using a bitmap to track unreachable blocks.
48 // TODO: Handle variable definitions, e.g. bool valid = x.isValid();
49 //       if (valid) ...; (Deferred)
50 // TODO: Take notes on state transitions to provide better warning messages.
51 //       (Deferred)
// TODO: Test nested conditionals: 1) Checking the same value multiple times,
53 //       and 2) Checking different values. (Deferred)
54 
55 using namespace clang;
56 using namespace consumed;
57 
58 // Key method definition
59 ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() = default;
60 
61 static SourceLocation getFirstStmtLoc(const CFGBlock *Block) {
62   // Find the source location of the first statement in the block, if the block
63   // is not empty.
64   for (const auto &B : *Block)
65     if (std::optional<CFGStmt> CS = B.getAs<CFGStmt>())
66       return CS->getStmt()->getBeginLoc();
67 
68   // Block is empty.
69   // If we have one successor, return the first statement in that block
70   if (Block->succ_size() == 1 && *Block->succ_begin())
71     return getFirstStmtLoc(*Block->succ_begin());
72 
73   return {};
74 }
75 
76 static SourceLocation getLastStmtLoc(const CFGBlock *Block) {
77   // Find the source location of the last statement in the block, if the block
78   // is not empty.
79   if (const Stmt *StmtNode = Block->getTerminatorStmt()) {
80     return StmtNode->getBeginLoc();
81   } else {
82     for (CFGBlock::const_reverse_iterator BI = Block->rbegin(),
83          BE = Block->rend(); BI != BE; ++BI) {
84       if (std::optional<CFGStmt> CS = BI->getAs<CFGStmt>())
85         return CS->getStmt()->getBeginLoc();
86     }
87   }
88 
89   // If we have one successor, return the first statement in that block
90   SourceLocation Loc;
91   if (Block->succ_size() == 1 && *Block->succ_begin())
92     Loc = getFirstStmtLoc(*Block->succ_begin());
93   if (Loc.isValid())
94     return Loc;
95 
96   // If we have one predecessor, return the last statement in that block
97   if (Block->pred_size() == 1 && *Block->pred_begin())
98     return getLastStmtLoc(*Block->pred_begin());
99 
100   return Loc;
101 }
102 
103 static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
104   switch (State) {
105   case CS_Unconsumed:
106     return CS_Consumed;
107   case CS_Consumed:
108     return CS_Unconsumed;
109   case CS_None:
110     return CS_None;
111   case CS_Unknown:
112     return CS_Unknown;
113   }
114   llvm_unreachable("invalid enum");
115 }
116 
117 static bool isCallableInState(const CallableWhenAttr *CWAttr,
118                               ConsumedState State) {
119   for (const auto &S : CWAttr->callableStates()) {
120     ConsumedState MappedAttrState = CS_None;
121 
122     switch (S) {
123     case CallableWhenAttr::Unknown:
124       MappedAttrState = CS_Unknown;
125       break;
126 
127     case CallableWhenAttr::Unconsumed:
128       MappedAttrState = CS_Unconsumed;
129       break;
130 
131     case CallableWhenAttr::Consumed:
132       MappedAttrState = CS_Consumed;
133       break;
134     }
135 
136     if (MappedAttrState == State)
137       return true;
138   }
139 
140   return false;
141 }
142 
143 static bool isConsumableType(const QualType &QT) {
144   if (QT->isPointerOrReferenceType())
145     return false;
146 
147   if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
148     return RD->hasAttr<ConsumableAttr>();
149 
150   return false;
151 }
152 
153 static bool isAutoCastType(const QualType &QT) {
154   if (QT->isPointerOrReferenceType())
155     return false;
156 
157   if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
158     return RD->hasAttr<ConsumableAutoCastAttr>();
159 
160   return false;
161 }
162 
163 static bool isSetOnReadPtrType(const QualType &QT) {
164   if (const CXXRecordDecl *RD = QT->getPointeeCXXRecordDecl())
165     return RD->hasAttr<ConsumableSetOnReadAttr>();
166   return false;
167 }
168 
169 static bool isKnownState(ConsumedState State) {
170   switch (State) {
171   case CS_Unconsumed:
172   case CS_Consumed:
173     return true;
174   case CS_None:
175   case CS_Unknown:
176     return false;
177   }
178   llvm_unreachable("invalid enum");
179 }
180 
181 static bool isRValueRef(QualType ParamType) {
182   return ParamType->isRValueReferenceType();
183 }
184 
185 static bool isTestingFunction(const FunctionDecl *FunDecl) {
186   return FunDecl->hasAttr<TestTypestateAttr>();
187 }
188 
189 static ConsumedState mapConsumableAttrState(const QualType QT) {
190   assert(isConsumableType(QT));
191 
192   const ConsumableAttr *CAttr =
193       QT->getAsCXXRecordDecl()->getAttr<ConsumableAttr>();
194 
195   switch (CAttr->getDefaultState()) {
196   case ConsumableAttr::Unknown:
197     return CS_Unknown;
198   case ConsumableAttr::Unconsumed:
199     return CS_Unconsumed;
200   case ConsumableAttr::Consumed:
201     return CS_Consumed;
202   }
203   llvm_unreachable("invalid enum");
204 }
205 
206 static ConsumedState
207 mapParamTypestateAttrState(const ParamTypestateAttr *PTAttr) {
208   switch (PTAttr->getParamState()) {
209   case ParamTypestateAttr::Unknown:
210     return CS_Unknown;
211   case ParamTypestateAttr::Unconsumed:
212     return CS_Unconsumed;
213   case ParamTypestateAttr::Consumed:
214     return CS_Consumed;
215   }
216   llvm_unreachable("invalid_enum");
217 }
218 
219 static ConsumedState
220 mapReturnTypestateAttrState(const ReturnTypestateAttr *RTSAttr) {
221   switch (RTSAttr->getState()) {
222   case ReturnTypestateAttr::Unknown:
223     return CS_Unknown;
224   case ReturnTypestateAttr::Unconsumed:
225     return CS_Unconsumed;
226   case ReturnTypestateAttr::Consumed:
227     return CS_Consumed;
228   }
229   llvm_unreachable("invalid enum");
230 }
231 
232 static ConsumedState mapSetTypestateAttrState(const SetTypestateAttr *STAttr) {
233   switch (STAttr->getNewState()) {
234   case SetTypestateAttr::Unknown:
235     return CS_Unknown;
236   case SetTypestateAttr::Unconsumed:
237     return CS_Unconsumed;
238   case SetTypestateAttr::Consumed:
239     return CS_Consumed;
240   }
241   llvm_unreachable("invalid_enum");
242 }
243 
244 static StringRef stateToString(ConsumedState State) {
245   switch (State) {
246   case consumed::CS_None:
247     return "none";
248 
249   case consumed::CS_Unknown:
250     return "unknown";
251 
252   case consumed::CS_Unconsumed:
253     return "unconsumed";
254 
255   case consumed::CS_Consumed:
256     return "consumed";
257   }
258   llvm_unreachable("invalid enum");
259 }
260 
261 static ConsumedState testsFor(const FunctionDecl *FunDecl) {
262   assert(isTestingFunction(FunDecl));
263   switch (FunDecl->getAttr<TestTypestateAttr>()->getTestState()) {
264   case TestTypestateAttr::Unconsumed:
265     return CS_Unconsumed;
266   case TestTypestateAttr::Consumed:
267     return CS_Consumed;
268   }
269   llvm_unreachable("invalid enum");
270 }
271 
namespace {

/// Result of testing one variable's typestate: the tested variable and the
/// state the test checks for (the state assumed on the test's "true" path).
/// Var may be null to mean "no variable test".
///
/// This struct is also a member of PropagationInfo's union. Keep it trivial
/// (for example, no default member initializers); otherwise that class's
/// defaulted default constructor would become deleted.
struct VarTestResult {
  const VarDecl *Var;
  ConsumedState TestsFor;
};

} // namespace
280 
namespace clang {
namespace consumed {

/// How the two variable tests of a binary test are combined.
enum EffectiveOp {
  EO_And,
  EO_Or
};

/// Tagged union describing what the consumed analysis knows about an
/// expression. InfoType selects the active union member:
///   - IT_None:    nothing is known (the default-constructed value).
///   - IT_State:   a concrete ConsumedState.
///   - IT_VarTest: the result of testing a single variable's state.
///   - IT_BinTest: two variable tests joined by an EffectiveOp.
///   - IT_Var:     the expression refers to a tracked variable.
///   - IT_Tmp:     the expression refers to a tracked temporary.
/// Each accessor asserts that the matching kind is active.
class PropagationInfo {
  enum {
    IT_None,
    IT_State,
    IT_VarTest,
    IT_BinTest,
    IT_Var,
    IT_Tmp
  } InfoType = IT_None;

  // Payload for IT_BinTest.
  struct BinTestTy {
    const BinaryOperator *Source;
    EffectiveOp EOp;
    VarTestResult LTest;
    VarTestResult RTest;
  };

  // Only the member selected by InfoType is meaningful. All members are
  // trivial types, so the implicit copy operations copy the active one.
  union {
    ConsumedState State;
    VarTestResult VarTest;
    const VarDecl *Var;
    const CXXBindTemporaryExpr *Tmp;
    BinTestTy BinTest;
  };

public:
  PropagationInfo() = default;
  PropagationInfo(const VarTestResult &VarTest)
      : InfoType(IT_VarTest), VarTest(VarTest) {}

  PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
      : InfoType(IT_VarTest) {
    VarTest.Var      = Var;
    VarTest.TestsFor = TestsFor;
  }

  PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
                  const VarTestResult &LTest, const VarTestResult &RTest)
      : InfoType(IT_BinTest) {
    BinTest.Source  = Source;
    BinTest.EOp     = EOp;
    BinTest.LTest   = LTest;
    BinTest.RTest   = RTest;
  }

  PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
                  const VarDecl *LVar, ConsumedState LTestsFor,
                  const VarDecl *RVar, ConsumedState RTestsFor)
      : InfoType(IT_BinTest) {
    BinTest.Source         = Source;
    BinTest.EOp            = EOp;
    BinTest.LTest.Var      = LVar;
    BinTest.LTest.TestsFor = LTestsFor;
    BinTest.RTest.Var      = RVar;
    BinTest.RTest.TestsFor = RTestsFor;
  }

  PropagationInfo(ConsumedState State)
      : InfoType(IT_State), State(State) {}
  PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
  PropagationInfo(const CXXBindTemporaryExpr *Tmp)
      : InfoType(IT_Tmp), Tmp(Tmp) {}

  const ConsumedState &getState() const {
    assert(InfoType == IT_State);
    return State;
  }

  const VarTestResult &getVarTest() const {
    assert(InfoType == IT_VarTest);
    return VarTest;
  }

  const VarTestResult &getLTest() const {
    assert(InfoType == IT_BinTest);
    return BinTest.LTest;
  }

  const VarTestResult &getRTest() const {
    assert(InfoType == IT_BinTest);
    return BinTest.RTest;
  }

  const VarDecl *getVar() const {
    assert(InfoType == IT_Var);
    return Var;
  }

  const CXXBindTemporaryExpr *getTmp() const {
    assert(InfoType == IT_Tmp);
    return Tmp;
  }

  /// Resolves this info to a ConsumedState. Var and Tmp infos are looked up
  /// in \p StateMap, and State infos return their stored value. Other kinds
  /// trip the assertion; with assertions disabled they yield CS_None.
  ConsumedState getAsState(const ConsumedStateMap *StateMap) const {
    assert(isVar() || isTmp() || isState());

    if (isVar())
      return StateMap->getState(Var);
    else if (isTmp())
      return StateMap->getState(Tmp);
    else if (isState())
      return State;
    else
      return CS_None;
  }

  EffectiveOp testEffectiveOp() const {
    assert(InfoType == IT_BinTest);
    return BinTest.EOp;
  }

  const BinaryOperator * testSourceNode() const {
    assert(InfoType == IT_BinTest);
    return BinTest.Source;
  }

  bool isValid() const { return InfoType != IT_None; }
  bool isState() const { return InfoType == IT_State; }
  bool isVarTest() const { return InfoType == IT_VarTest; }
  bool isBinTest() const { return InfoType == IT_BinTest; }
  bool isVar() const { return InfoType == IT_Var; }
  bool isTmp() const { return InfoType == IT_Tmp; }

  bool isTest() const {
    return InfoType == IT_VarTest || InfoType == IT_BinTest;
  }

  /// True when the info refers to a tracked object (a variable or temporary)
  /// whose state lives in a ConsumedStateMap.
  bool isPointerToValue() const {
    return InfoType == IT_Var || InfoType == IT_Tmp;
  }

  /// Returns the logical negation of this test. A variable test gets its
  /// tested state inverted. A binary test applies De Morgan's law: the
  /// operator is flipped (EO_And <-> EO_Or) and both operand tests are
  /// inverted.
  PropagationInfo invertTest() const {
    assert(InfoType == IT_VarTest || InfoType == IT_BinTest);

    if (InfoType == IT_VarTest) {
      return PropagationInfo(VarTest.Var,
                             invertConsumedUnconsumed(VarTest.TestsFor));

    } else if (InfoType == IT_BinTest) {
      return PropagationInfo(BinTest.Source,
        BinTest.EOp == EO_And ? EO_Or : EO_And,
        BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
        BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
    } else {
      // Only reachable when assertions are disabled.
      return {};
    }
  }
};

} // namespace consumed
} // namespace clang
440 
441 static void
442 setStateForVarOrTmp(ConsumedStateMap *StateMap, const PropagationInfo &PInfo,
443                     ConsumedState State) {
444   assert(PInfo.isVar() || PInfo.isTmp());
445 
446   if (PInfo.isVar())
447     StateMap->setState(PInfo.getVar(), State);
448   else
449     StateMap->setState(PInfo.getTmp(), State);
450 }
451 
452 namespace clang {
453 namespace consumed {
454 
455 class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
456   using MapType = llvm::DenseMap<const Stmt *, PropagationInfo>;
457   using PairType= std::pair<const Stmt *, PropagationInfo>;
458   using InfoEntry = MapType::iterator;
459   using ConstInfoEntry = MapType::const_iterator;
460 
461   ConsumedAnalyzer &Analyzer;
462   ConsumedStateMap *StateMap;
463   MapType PropagationMap;
464 
465   InfoEntry findInfo(const Expr *E) {
466     if (const auto Cleanups = dyn_cast<ExprWithCleanups>(E))
467       if (!Cleanups->cleanupsHaveSideEffects())
468         E = Cleanups->getSubExpr();
469     return PropagationMap.find(E->IgnoreParens());
470   }
471 
472   ConstInfoEntry findInfo(const Expr *E) const {
473     if (const auto Cleanups = dyn_cast<ExprWithCleanups>(E))
474       if (!Cleanups->cleanupsHaveSideEffects())
475         E = Cleanups->getSubExpr();
476     return PropagationMap.find(E->IgnoreParens());
477   }
478 
479   void insertInfo(const Expr *E, const PropagationInfo &PI) {
480     PropagationMap.insert(PairType(E->IgnoreParens(), PI));
481   }
482 
483   void forwardInfo(const Expr *From, const Expr *To);
484   void copyInfo(const Expr *From, const Expr *To, ConsumedState CS);
485   ConsumedState getInfo(const Expr *From);
486   void setInfo(const Expr *To, ConsumedState NS);
487   void propagateReturnType(const Expr *Call, const FunctionDecl *Fun);
488 
489 public:
490   void checkCallability(const PropagationInfo &PInfo,
491                         const FunctionDecl *FunDecl,
492                         SourceLocation BlameLoc);
493   bool handleCall(const CallExpr *Call, const Expr *ObjArg,
494                   const FunctionDecl *FunD);
495 
496   void VisitBinaryOperator(const BinaryOperator *BinOp);
497   void VisitCallExpr(const CallExpr *Call);
498   void VisitCastExpr(const CastExpr *Cast);
499   void VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr *Temp);
500   void VisitCXXConstructExpr(const CXXConstructExpr *Call);
501   void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
502   void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
503   void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
504   void VisitDeclStmt(const DeclStmt *DelcS);
505   void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
506   void VisitMemberExpr(const MemberExpr *MExpr);
507   void VisitParmVarDecl(const ParmVarDecl *Param);
508   void VisitReturnStmt(const ReturnStmt *Ret);
509   void VisitUnaryOperator(const UnaryOperator *UOp);
510   void VisitVarDecl(const VarDecl *Var);
511 
512   ConsumedStmtVisitor(ConsumedAnalyzer &Analyzer, ConsumedStateMap *StateMap)
513       : Analyzer(Analyzer), StateMap(StateMap) {}
514 
515   PropagationInfo getInfo(const Expr *StmtNode) const {
516     ConstInfoEntry Entry = findInfo(StmtNode);
517 
518     if (Entry != PropagationMap.end())
519       return Entry->second;
520     else
521       return {};
522   }
523 
524   void reset(ConsumedStateMap *NewStateMap) {
525     StateMap = NewStateMap;
526   }
527 };
528 
529 } // namespace consumed
530 } // namespace clang
531 
532 void ConsumedStmtVisitor::forwardInfo(const Expr *From, const Expr *To) {
533   InfoEntry Entry = findInfo(From);
534   if (Entry != PropagationMap.end())
535     insertInfo(To, Entry->second);
536 }
537 
538 // Create a new state for To, which is initialized to the state of From.
539 // If NS is not CS_None, sets the state of From to NS.
540 void ConsumedStmtVisitor::copyInfo(const Expr *From, const Expr *To,
541                                    ConsumedState NS) {
542   InfoEntry Entry = findInfo(From);
543   if (Entry != PropagationMap.end()) {
544     PropagationInfo& PInfo = Entry->second;
545     ConsumedState CS = PInfo.getAsState(StateMap);
546     if (CS != CS_None)
547       insertInfo(To, PropagationInfo(CS));
548     if (NS != CS_None && PInfo.isPointerToValue())
549       setStateForVarOrTmp(StateMap, PInfo, NS);
550   }
551 }
552 
553 // Get the ConsumedState for From
554 ConsumedState ConsumedStmtVisitor::getInfo(const Expr *From) {
555   InfoEntry Entry = findInfo(From);
556   if (Entry != PropagationMap.end()) {
557     PropagationInfo& PInfo = Entry->second;
558     return PInfo.getAsState(StateMap);
559   }
560   return CS_None;
561 }
562 
563 // If we already have info for To then update it, otherwise create a new entry.
564 void ConsumedStmtVisitor::setInfo(const Expr *To, ConsumedState NS) {
565   InfoEntry Entry = findInfo(To);
566   if (Entry != PropagationMap.end()) {
567     PropagationInfo& PInfo = Entry->second;
568     if (PInfo.isPointerToValue())
569       setStateForVarOrTmp(StateMap, PInfo, NS);
570   } else if (NS != CS_None) {
571      insertInfo(To, PropagationInfo(NS));
572   }
573 }
574 
575 void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
576                                            const FunctionDecl *FunDecl,
577                                            SourceLocation BlameLoc) {
578   assert(!PInfo.isTest());
579 
580   const CallableWhenAttr *CWAttr = FunDecl->getAttr<CallableWhenAttr>();
581   if (!CWAttr)
582     return;
583 
584   if (PInfo.isVar()) {
585     ConsumedState VarState = StateMap->getState(PInfo.getVar());
586 
587     if (VarState == CS_None || isCallableInState(CWAttr, VarState))
588       return;
589 
590     Analyzer.WarningsHandler.warnUseInInvalidState(
591       FunDecl->getNameAsString(), PInfo.getVar()->getNameAsString(),
592       stateToString(VarState), BlameLoc);
593   } else {
594     ConsumedState TmpState = PInfo.getAsState(StateMap);
595 
596     if (TmpState == CS_None || isCallableInState(CWAttr, TmpState))
597       return;
598 
599     Analyzer.WarningsHandler.warnUseOfTempInInvalidState(
600       FunDecl->getNameAsString(), stateToString(TmpState), BlameLoc);
601   }
602 }
603 
604 // Factors out common behavior for function, method, and operator calls.
605 // Check parameters and set parameter state if necessary.
606 // Returns true if the state of ObjArg is set, or false otherwise.
607 bool ConsumedStmtVisitor::handleCall(const CallExpr *Call, const Expr *ObjArg,
608                                      const FunctionDecl *FunD) {
609   unsigned Offset = 0;
610   if (isa<CXXOperatorCallExpr>(Call) && isa<CXXMethodDecl>(FunD))
611     Offset = 1;  // first argument is 'this'
612 
613   // check explicit parameters
614   for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
615     // Skip variable argument lists.
616     if (Index - Offset >= FunD->getNumParams())
617       break;
618 
619     const ParmVarDecl *Param = FunD->getParamDecl(Index - Offset);
620     QualType ParamType = Param->getType();
621 
622     InfoEntry Entry = findInfo(Call->getArg(Index));
623 
624     if (Entry == PropagationMap.end() || Entry->second.isTest())
625       continue;
626     PropagationInfo PInfo = Entry->second;
627 
628     // Check that the parameter is in the correct state.
629     if (ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>()) {
630       ConsumedState ParamState = PInfo.getAsState(StateMap);
631       ConsumedState ExpectedState = mapParamTypestateAttrState(PTA);
632 
633       if (ParamState != ExpectedState)
634         Analyzer.WarningsHandler.warnParamTypestateMismatch(
635           Call->getArg(Index)->getExprLoc(),
636           stateToString(ExpectedState), stateToString(ParamState));
637     }
638 
639     if (!(Entry->second.isVar() || Entry->second.isTmp()))
640       continue;
641 
642     // Adjust state on the caller side.
643     if (ReturnTypestateAttr *RT = Param->getAttr<ReturnTypestateAttr>())
644       setStateForVarOrTmp(StateMap, PInfo, mapReturnTypestateAttrState(RT));
645     else if (isRValueRef(ParamType) || isConsumableType(ParamType))
646       setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Consumed);
647     else if (ParamType->isPointerOrReferenceType() &&
648              (!ParamType->getPointeeType().isConstQualified() ||
649               isSetOnReadPtrType(ParamType)))
650       setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Unknown);
651   }
652 
653   if (!ObjArg)
654     return false;
655 
656   // check implicit 'self' parameter, if present
657   InfoEntry Entry = findInfo(ObjArg);
658   if (Entry != PropagationMap.end()) {
659     PropagationInfo PInfo = Entry->second;
660     checkCallability(PInfo, FunD, Call->getExprLoc());
661 
662     if (SetTypestateAttr *STA = FunD->getAttr<SetTypestateAttr>()) {
663       if (PInfo.isVar()) {
664         StateMap->setState(PInfo.getVar(), mapSetTypestateAttrState(STA));
665         return true;
666       }
667       else if (PInfo.isTmp()) {
668         StateMap->setState(PInfo.getTmp(), mapSetTypestateAttrState(STA));
669         return true;
670       }
671     }
672     else if (isTestingFunction(FunD) && PInfo.isVar()) {
673       PropagationMap.insert(PairType(Call,
674         PropagationInfo(PInfo.getVar(), testsFor(FunD))));
675     }
676   }
677   return false;
678 }
679 
680 void ConsumedStmtVisitor::propagateReturnType(const Expr *Call,
681                                               const FunctionDecl *Fun) {
682   QualType RetType = Fun->getCallResultType();
683   if (RetType->isReferenceType())
684     RetType = RetType->getPointeeType();
685 
686   if (isConsumableType(RetType)) {
687     ConsumedState ReturnState;
688     if (ReturnTypestateAttr *RTA = Fun->getAttr<ReturnTypestateAttr>())
689       ReturnState = mapReturnTypestateAttrState(RTA);
690     else
691       ReturnState = mapConsumableAttrState(RetType);
692 
693     PropagationMap.insert(PairType(Call, PropagationInfo(ReturnState)));
694   }
695 }
696 
697 void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
698   switch (BinOp->getOpcode()) {
699   case BO_LAnd:
700   case BO_LOr : {
701     InfoEntry LEntry = findInfo(BinOp->getLHS()),
702               REntry = findInfo(BinOp->getRHS());
703 
704     VarTestResult LTest, RTest;
705 
706     if (LEntry != PropagationMap.end() && LEntry->second.isVarTest()) {
707       LTest = LEntry->second.getVarTest();
708     } else {
709       LTest.Var      = nullptr;
710       LTest.TestsFor = CS_None;
711     }
712 
713     if (REntry != PropagationMap.end() && REntry->second.isVarTest()) {
714       RTest = REntry->second.getVarTest();
715     } else {
716       RTest.Var      = nullptr;
717       RTest.TestsFor = CS_None;
718     }
719 
720     if (!(LTest.Var == nullptr && RTest.Var == nullptr))
721       PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
722         static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
723     break;
724   }
725 
726   case BO_PtrMemD:
727   case BO_PtrMemI:
728     forwardInfo(BinOp->getLHS(), BinOp);
729     break;
730 
731   default:
732     break;
733   }
734 }
735 
736 void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
737   const FunctionDecl *FunDecl = Call->getDirectCallee();
738   if (!FunDecl)
739     return;
740 
741   // Special case for the std::move function.
742   // TODO: Make this more specific. (Deferred)
743   if (Call->isCallToStdMove()) {
744     copyInfo(Call->getArg(0), Call, CS_Consumed);
745     return;
746   }
747 
748   handleCall(Call, nullptr, FunDecl);
749   propagateReturnType(Call, FunDecl);
750 }
751 
752 void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
753   forwardInfo(Cast->getSubExpr(), Cast);
754 }
755 
756 void ConsumedStmtVisitor::VisitCXXBindTemporaryExpr(
757   const CXXBindTemporaryExpr *Temp) {
758 
759   InfoEntry Entry = findInfo(Temp->getSubExpr());
760 
761   if (Entry != PropagationMap.end() && !Entry->second.isTest()) {
762     StateMap->setState(Temp, Entry->second.getAsState(StateMap));
763     PropagationMap.insert(PairType(Temp, PropagationInfo(Temp)));
764   }
765 }
766 
767 void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
768   CXXConstructorDecl *Constructor = Call->getConstructor();
769 
770   QualType ThisType = Constructor->getFunctionObjectParameterType();
771 
772   if (!isConsumableType(ThisType))
773     return;
774 
775   // FIXME: What should happen if someone annotates the move constructor?
776   if (ReturnTypestateAttr *RTA = Constructor->getAttr<ReturnTypestateAttr>()) {
777     // TODO: Adjust state of args appropriately.
778     ConsumedState RetState = mapReturnTypestateAttrState(RTA);
779     PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));
780   } else if (Constructor->isDefaultConstructor()) {
781     PropagationMap.insert(PairType(Call,
782       PropagationInfo(consumed::CS_Consumed)));
783   } else if (Constructor->isMoveConstructor()) {
784     copyInfo(Call->getArg(0), Call, CS_Consumed);
785   } else if (Constructor->isCopyConstructor()) {
786     // Copy state from arg.  If setStateOnRead then set arg to CS_Unknown.
787     ConsumedState NS =
788       isSetOnReadPtrType(Constructor->getThisType()) ?
789       CS_Unknown : CS_None;
790     copyInfo(Call->getArg(0), Call, NS);
791   } else {
792     // TODO: Adjust state of args appropriately.
793     ConsumedState RetState = mapConsumableAttrState(ThisType);
794     PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));
795   }
796 }
797 
798 void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
799     const CXXMemberCallExpr *Call) {
800   CXXMethodDecl* MD = Call->getMethodDecl();
801   if (!MD)
802     return;
803 
804   handleCall(Call, Call->getImplicitObjectArgument(), MD);
805   propagateReturnType(Call, MD);
806 }
807 
808 void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
809     const CXXOperatorCallExpr *Call) {
810   const auto *FunDecl = dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
811   if (!FunDecl) return;
812 
813   if (Call->getOperator() == OO_Equal) {
814     ConsumedState CS = getInfo(Call->getArg(1));
815     if (!handleCall(Call, Call->getArg(0), FunDecl))
816       setInfo(Call->getArg(0), CS);
817     return;
818   }
819 
820   if (const auto *MCall = dyn_cast<CXXMemberCallExpr>(Call))
821     handleCall(MCall, MCall->getImplicitObjectArgument(), FunDecl);
822   else
823     handleCall(Call, Call->getArg(0), FunDecl);
824 
825   propagateReturnType(Call, FunDecl);
826 }
827 
828 void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
829   if (const auto *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
830     if (StateMap->getState(Var) != consumed::CS_None)
831       PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
832 }
833 
834 void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
835   for (const auto *DI : DeclS->decls())
836     if (isa<VarDecl>(DI))
837       VisitVarDecl(cast<VarDecl>(DI));
838 
839   if (DeclS->isSingleDecl())
840     if (const auto *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
841       PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
842 }
843 
844 void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
845   const MaterializeTemporaryExpr *Temp) {
846   forwardInfo(Temp->getSubExpr(), Temp);
847 }
848 
849 void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
850   forwardInfo(MExpr->getBase(), MExpr);
851 }
852 
853 void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
854   QualType ParamType = Param->getType();
855   ConsumedState ParamState = consumed::CS_None;
856 
857   if (const ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>())
858     ParamState = mapParamTypestateAttrState(PTA);
859   else if (isConsumableType(ParamType))
860     ParamState = mapConsumableAttrState(ParamType);
861   else if (isRValueRef(ParamType) &&
862            isConsumableType(ParamType->getPointeeType()))
863     ParamState = mapConsumableAttrState(ParamType->getPointeeType());
864   else if (ParamType->isReferenceType() &&
865            isConsumableType(ParamType->getPointeeType()))
866     ParamState = consumed::CS_Unknown;
867 
868   if (ParamState != CS_None)
869     StateMap->setState(Param, ParamState);
870 }
871 
872 void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
873   ConsumedState ExpectedState = Analyzer.getExpectedReturnState();
874 
875   if (ExpectedState != CS_None) {
876     InfoEntry Entry = findInfo(Ret->getRetValue());
877 
878     if (Entry != PropagationMap.end()) {
879       ConsumedState RetState = Entry->second.getAsState(StateMap);
880 
881       if (RetState != ExpectedState)
882         Analyzer.WarningsHandler.warnReturnTypestateMismatch(
883           Ret->getReturnLoc(), stateToString(ExpectedState),
884           stateToString(RetState));
885     }
886   }
887 
888   StateMap->checkParamsForReturnTypestate(Ret->getBeginLoc(),
889                                           Analyzer.WarningsHandler);
890 }
891 
892 void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
893   InfoEntry Entry = findInfo(UOp->getSubExpr());
894   if (Entry == PropagationMap.end()) return;
895 
896   switch (UOp->getOpcode()) {
897   case UO_AddrOf:
898     PropagationMap.insert(PairType(UOp, Entry->second));
899     break;
900 
901   case UO_LNot:
902     if (Entry->second.isTest())
903       PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
904     break;
905 
906   default:
907     break;
908   }
909 }
910 
911 // TODO: See if I need to check for reference types here.
912 void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
913   if (isConsumableType(Var->getType())) {
914     if (Var->hasInit()) {
915       MapType::iterator VIT = findInfo(Var->getInit()->IgnoreImplicit());
916       if (VIT != PropagationMap.end()) {
917         PropagationInfo PInfo = VIT->second;
918         ConsumedState St = PInfo.getAsState(StateMap);
919 
920         if (St != consumed::CS_None) {
921           StateMap->setState(Var, St);
922           return;
923         }
924       }
925     }
926     // Otherwise
927     StateMap->setState(Var, consumed::CS_Unknown);
928   }
929 }
930 
931 static void splitVarStateForIf(const IfStmt *IfNode, const VarTestResult &Test,
932                                ConsumedStateMap *ThenStates,
933                                ConsumedStateMap *ElseStates) {
934   ConsumedState VarState = ThenStates->getState(Test.Var);
935 
936   if (VarState == CS_Unknown) {
937     ThenStates->setState(Test.Var, Test.TestsFor);
938     ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
939   } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
940     ThenStates->markUnreachable();
941   } else if (VarState == Test.TestsFor) {
942     ElseStates->markUnreachable();
943   }
944 }
945 
// Refines the then/else states of an `if` whose condition combines two
// variable tests (PInfo is a binary test with an LHS and an RHS test).
//
// Both variable states are read from ThenStates before anything is modified,
// so the RHS logic below sees the pre-split state even if the LHS logic has
// already updated or cleared ThenStates/ElseStates. A test whose Var is null
// contributes nothing, and its state is treated as CS_None.
static void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
                                    ConsumedStateMap *ThenStates,
                                    ConsumedStateMap *ElseStates) {
  const VarTestResult &LTest = PInfo.getLTest(),
                      &RTest = PInfo.getRTest();

  ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
                RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;

  if (LTest.Var) {
    if (PInfo.testEffectiveOp() == EO_And) {
      // "L && R": the then-branch implies L held.
      if (LState == CS_Unknown) {
        ThenStates->setState(LTest.Var, LTest.TestsFor);
      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
        // L is known to fail, so the whole conjunction cannot be true.
        ThenStates->markUnreachable();
      } else if (LState == LTest.TestsFor && isKnownState(RState)) {
        // L is known to hold, so R alone decides; if R's state is also known
        // (isKnownState - presumably consumed/unconsumed), one branch is dead.
        if (RState == RTest.TestsFor)
          ElseStates->markUnreachable();
        else
          ThenStates->markUnreachable();
      }
    } else {
      // Any other effective operator (presumably EO_Or): "L || R", where the
      // else-branch implies L failed.
      if (LState == CS_Unknown) {
        ElseStates->setState(LTest.Var,
                             invertConsumedUnconsumed(LTest.TestsFor));
      } else if (LState == LTest.TestsFor) {
        // L is known to hold, so the disjunction cannot be false.
        ElseStates->markUnreachable();
      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
                 isKnownState(RState)) {
        // L is known to fail, so R alone decides.
        if (RState == RTest.TestsFor)
          ElseStates->markUnreachable();
        else
          ThenStates->markUnreachable();
      }
    }
  }

  if (RTest.Var) {
    if (PInfo.testEffectiveOp() == EO_And) {
      // Then-branch of "L && R" implies R held as well.
      if (RState == CS_Unknown)
        ThenStates->setState(RTest.Var, RTest.TestsFor);
      else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
        ThenStates->markUnreachable();
    } else {
      // Else-branch of "L || R" implies R failed as well.
      if (RState == CS_Unknown)
        ElseStates->setState(RTest.Var,
                             invertConsumedUnconsumed(RTest.TestsFor));
      else if (RState == RTest.TestsFor)
        ElseStates->markUnreachable();
    }
  }
}
998 
999 bool ConsumedBlockInfo::allBackEdgesVisited(const CFGBlock *CurrBlock,
1000                                             const CFGBlock *TargetBlock) {
1001   assert(CurrBlock && "Block pointer must not be NULL");
1002   assert(TargetBlock && "TargetBlock pointer must not be NULL");
1003 
1004   unsigned int CurrBlockOrder = VisitOrder[CurrBlock->getBlockID()];
1005   for (CFGBlock::const_pred_iterator PI = TargetBlock->pred_begin(),
1006        PE = TargetBlock->pred_end(); PI != PE; ++PI) {
1007     if (*PI && CurrBlockOrder < VisitOrder[(*PI)->getBlockID()] )
1008       return false;
1009   }
1010   return true;
1011 }
1012 
1013 void ConsumedBlockInfo::addInfo(
1014     const CFGBlock *Block, ConsumedStateMap *StateMap,
1015     std::unique_ptr<ConsumedStateMap> &OwnedStateMap) {
1016   assert(Block && "Block pointer must not be NULL");
1017 
1018   auto &Entry = StateMapsArray[Block->getBlockID()];
1019 
1020   if (Entry) {
1021     Entry->intersect(*StateMap);
1022   } else if (OwnedStateMap)
1023     Entry = std::move(OwnedStateMap);
1024   else
1025     Entry = std::make_unique<ConsumedStateMap>(*StateMap);
1026 }
1027 
1028 void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
1029                                 std::unique_ptr<ConsumedStateMap> StateMap) {
1030   assert(Block && "Block pointer must not be NULL");
1031 
1032   auto &Entry = StateMapsArray[Block->getBlockID()];
1033 
1034   if (Entry) {
1035     Entry->intersect(*StateMap);
1036   } else {
1037     Entry = std::move(StateMap);
1038   }
1039 }
1040 
1041 ConsumedStateMap* ConsumedBlockInfo::borrowInfo(const CFGBlock *Block) {
1042   assert(Block && "Block pointer must not be NULL");
1043   assert(StateMapsArray[Block->getBlockID()] && "Block has no block info");
1044 
1045   return StateMapsArray[Block->getBlockID()].get();
1046 }
1047 
1048 void ConsumedBlockInfo::discardInfo(const CFGBlock *Block) {
1049   StateMapsArray[Block->getBlockID()] = nullptr;
1050 }
1051 
1052 std::unique_ptr<ConsumedStateMap>
1053 ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
1054   assert(Block && "Block pointer must not be NULL");
1055 
1056   auto &Entry = StateMapsArray[Block->getBlockID()];
1057   return isBackEdgeTarget(Block) ? std::make_unique<ConsumedStateMap>(*Entry)
1058                                  : std::move(Entry);
1059 }
1060 
1061 bool ConsumedBlockInfo::isBackEdge(const CFGBlock *From, const CFGBlock *To) {
1062   assert(From && "From block must not be NULL");
1063   assert(To   && "From block must not be NULL");
1064 
1065   return VisitOrder[From->getBlockID()] > VisitOrder[To->getBlockID()];
1066 }
1067 
1068 bool ConsumedBlockInfo::isBackEdgeTarget(const CFGBlock *Block) {
1069   assert(Block && "Block pointer must not be NULL");
1070 
1071   // Anything with less than two predecessors can't be the target of a back
1072   // edge.
1073   if (Block->pred_size() < 2)
1074     return false;
1075 
1076   unsigned int BlockVisitOrder = VisitOrder[Block->getBlockID()];
1077   for (CFGBlock::const_pred_iterator PI = Block->pred_begin(),
1078        PE = Block->pred_end(); PI != PE; ++PI) {
1079     if (*PI && BlockVisitOrder < VisitOrder[(*PI)->getBlockID()])
1080       return true;
1081   }
1082   return false;
1083 }
1084 
1085 void ConsumedStateMap::checkParamsForReturnTypestate(SourceLocation BlameLoc,
1086   ConsumedWarningsHandlerBase &WarningsHandler) const {
1087 
1088   for (const auto &DM : VarMap) {
1089     if (isa<ParmVarDecl>(DM.first)) {
1090       const auto *Param = cast<ParmVarDecl>(DM.first);
1091       const ReturnTypestateAttr *RTA = Param->getAttr<ReturnTypestateAttr>();
1092 
1093       if (!RTA)
1094         continue;
1095 
1096       ConsumedState ExpectedState = mapReturnTypestateAttrState(RTA);
1097       if (DM.second != ExpectedState)
1098         WarningsHandler.warnParamReturnTypestateMismatch(BlameLoc,
1099           Param->getNameAsString(), stateToString(ExpectedState),
1100           stateToString(DM.second));
1101     }
1102   }
1103 }
1104 
1105 void ConsumedStateMap::clearTemporaries() {
1106   TmpMap.clear();
1107 }
1108 
1109 ConsumedState ConsumedStateMap::getState(const VarDecl *Var) const {
1110   VarMapType::const_iterator Entry = VarMap.find(Var);
1111 
1112   if (Entry != VarMap.end())
1113     return Entry->second;
1114 
1115   return CS_None;
1116 }
1117 
1118 ConsumedState
1119 ConsumedStateMap::getState(const CXXBindTemporaryExpr *Tmp) const {
1120   TmpMapType::const_iterator Entry = TmpMap.find(Tmp);
1121 
1122   if (Entry != TmpMap.end())
1123     return Entry->second;
1124 
1125   return CS_None;
1126 }
1127 
1128 void ConsumedStateMap::intersect(const ConsumedStateMap &Other) {
1129   ConsumedState LocalState;
1130 
1131   if (this->From && this->From == Other.From && !Other.Reachable) {
1132     this->markUnreachable();
1133     return;
1134   }
1135 
1136   for (const auto &DM : Other.VarMap) {
1137     LocalState = this->getState(DM.first);
1138 
1139     if (LocalState == CS_None)
1140       continue;
1141 
1142     if (LocalState != DM.second)
1143      VarMap[DM.first] = CS_Unknown;
1144   }
1145 }
1146 
1147 void ConsumedStateMap::intersectAtLoopHead(const CFGBlock *LoopHead,
1148   const CFGBlock *LoopBack, const ConsumedStateMap *LoopBackStates,
1149   ConsumedWarningsHandlerBase &WarningsHandler) {
1150 
1151   ConsumedState LocalState;
1152   SourceLocation BlameLoc = getLastStmtLoc(LoopBack);
1153 
1154   for (const auto &DM : LoopBackStates->VarMap) {
1155     LocalState = this->getState(DM.first);
1156 
1157     if (LocalState == CS_None)
1158       continue;
1159 
1160     if (LocalState != DM.second) {
1161       VarMap[DM.first] = CS_Unknown;
1162       WarningsHandler.warnLoopStateMismatch(BlameLoc,
1163                                             DM.first->getNameAsString());
1164     }
1165   }
1166 }
1167 
1168 void ConsumedStateMap::markUnreachable() {
1169   this->Reachable = false;
1170   VarMap.clear();
1171   TmpMap.clear();
1172 }
1173 
1174 void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
1175   VarMap[Var] = State;
1176 }
1177 
1178 void ConsumedStateMap::setState(const CXXBindTemporaryExpr *Tmp,
1179                                 ConsumedState State) {
1180   TmpMap[Tmp] = State;
1181 }
1182 
1183 void ConsumedStateMap::remove(const CXXBindTemporaryExpr *Tmp) {
1184   TmpMap.erase(Tmp);
1185 }
1186 
1187 bool ConsumedStateMap::operator!=(const ConsumedStateMap *Other) const {
1188   for (const auto &DM : Other->VarMap)
1189     if (this->getState(DM.first) != DM.second)
1190       return true;
1191   return false;
1192 }
1193 
1194 void ConsumedAnalyzer::determineExpectedReturnState(AnalysisDeclContext &AC,
1195                                                     const FunctionDecl *D) {
1196   QualType ReturnType;
1197   if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
1198     ReturnType = Constructor->getFunctionObjectParameterType();
1199   } else
1200     ReturnType = D->getCallResultType();
1201 
1202   if (const ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>()) {
1203     const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
1204     if (!RD || !RD->hasAttr<ConsumableAttr>()) {
1205       // FIXME: This should be removed when template instantiation propagates
1206       //        attributes at template specialization definition, not
1207       //        declaration. When it is removed the test needs to be enabled
1208       //        in SemaDeclAttr.cpp.
1209       WarningsHandler.warnReturnTypestateForUnconsumableType(
1210           RTSAttr->getLocation(), ReturnType.getAsString());
1211       ExpectedReturnState = CS_None;
1212     } else
1213       ExpectedReturnState = mapReturnTypestateAttrState(RTSAttr);
1214   } else if (isConsumableType(ReturnType)) {
1215     if (isAutoCastType(ReturnType))   // We can auto-cast the state to the
1216       ExpectedReturnState = CS_None;  // expected state.
1217     else
1218       ExpectedReturnState = mapConsumableAttrState(ReturnType);
1219   }
1220   else
1221     ExpectedReturnState = CS_None;
1222 }
1223 
1224 bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
1225                                   const ConsumedStmtVisitor &Visitor) {
1226   std::unique_ptr<ConsumedStateMap> FalseStates(
1227       new ConsumedStateMap(*CurrStates));
1228   PropagationInfo PInfo;
1229 
1230   if (const auto *IfNode =
1231           dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
1232     if (IfNode->isConsteval())
1233       return false;
1234 
1235     const Expr *Cond = IfNode->getCond();
1236 
1237     PInfo = Visitor.getInfo(Cond);
1238     if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
1239       PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
1240 
1241     if (PInfo.isVarTest()) {
1242       CurrStates->setSource(Cond);
1243       FalseStates->setSource(Cond);
1244       splitVarStateForIf(IfNode, PInfo.getVarTest(), CurrStates.get(),
1245                          FalseStates.get());
1246     } else if (PInfo.isBinTest()) {
1247       CurrStates->setSource(PInfo.testSourceNode());
1248       FalseStates->setSource(PInfo.testSourceNode());
1249       splitVarStateForIfBinOp(PInfo, CurrStates.get(), FalseStates.get());
1250     } else {
1251       return false;
1252     }
1253   } else if (const auto *BinOp =
1254        dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {
1255     PInfo = Visitor.getInfo(BinOp->getLHS());
1256     if (!PInfo.isVarTest()) {
1257       if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
1258         PInfo = Visitor.getInfo(BinOp->getRHS());
1259 
1260         if (!PInfo.isVarTest())
1261           return false;
1262       } else {
1263         return false;
1264       }
1265     }
1266 
1267     CurrStates->setSource(BinOp);
1268     FalseStates->setSource(BinOp);
1269 
1270     const VarTestResult &Test = PInfo.getVarTest();
1271     ConsumedState VarState = CurrStates->getState(Test.Var);
1272 
1273     if (BinOp->getOpcode() == BO_LAnd) {
1274       if (VarState == CS_Unknown)
1275         CurrStates->setState(Test.Var, Test.TestsFor);
1276       else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
1277         CurrStates->markUnreachable();
1278 
1279     } else if (BinOp->getOpcode() == BO_LOr) {
1280       if (VarState == CS_Unknown)
1281         FalseStates->setState(Test.Var,
1282                               invertConsumedUnconsumed(Test.TestsFor));
1283       else if (VarState == Test.TestsFor)
1284         FalseStates->markUnreachable();
1285     }
1286   } else {
1287     return false;
1288   }
1289 
1290   CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();
1291 
1292   if (*SI)
1293     BlockInfo.addInfo(*SI, std::move(CurrStates));
1294   else
1295     CurrStates = nullptr;
1296 
1297   if (*++SI)
1298     BlockInfo.addInfo(*SI, std::move(FalseStates));
1299 
1300   return true;
1301 }
1302 
// Runs the consumed analysis over the function in AC and emits diagnostics.
//
// Blocks are visited in the order given by PostOrderCFGView. Each block is
// analyzed with a ConsumedStateMap (CurrStates). That map is either carried
// over directly from the previously visited block, or taken from BlockInfo,
// where predecessors deposited their outgoing states. Non-function
// declarations and functions without a CFG are ignored.
void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
  const auto *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
  if (!D)
    return;

  CFG *CFGraph = AC.getCFG();
  if (!CFGraph)
    return;

  determineExpectedReturnState(AC, D);

  PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
  // AC.getCFG()->viewCFG(LangOptions());

  BlockInfo = ConsumedBlockInfo(CFGraph->getNumBlockIDs(), SortedGraph);

  CurrStates = std::make_unique<ConsumedStateMap>();
  ConsumedStmtVisitor Visitor(*this, CurrStates.get());

  // Add all trackable parameters to the state map.
  for (const auto *PI : D->parameters())
    Visitor.VisitParmVarDecl(PI);

  // Visit all of the function's basic blocks.
  for (const auto *CurrBlock : *SortedGraph) {
    // A non-null CurrStates here was kept by the previous block (see the
    // single-successor case below); otherwise fetch what predecessors left.
    // NOTE(review): this assumes the kept state belongs to the block visited
    // next - confirm against the visit order.
    if (!CurrStates)
      CurrStates = BlockInfo.getInfo(CurrBlock);

    // No state reached this block, or it is known unreachable: skip it and
    // propagate nothing to its successors.
    if (!CurrStates) {
      continue;
    } else if (!CurrStates->isReachable()) {
      CurrStates = nullptr;
      continue;
    }

    Visitor.reset(CurrStates.get());

    // Visit all of the basic block's statements.
    for (const auto &B : *CurrBlock) {
      switch (B.getKind()) {
      case CFGElement::Statement:
        Visitor.Visit(B.castAs<CFGStmt>().getStmt());
        break;

      case CFGElement::TemporaryDtor: {
        // Check that the temporary's destructor may be called in its current
        // state, then stop tracking the temporary.
        const CFGTemporaryDtor &DTor = B.castAs<CFGTemporaryDtor>();
        const CXXBindTemporaryExpr *BTE = DTor.getBindTemporaryExpr();

        Visitor.checkCallability(PropagationInfo(BTE),
                                 DTor.getDestructorDecl(AC.getASTContext()),
                                 BTE->getExprLoc());
        CurrStates->remove(BTE);
        break;
      }

      case CFGElement::AutomaticObjectDtor: {
        // Check the implicit destructor call of a local variable, blamed on
        // the end of the statement that triggers it.
        const CFGAutomaticObjDtor &DTor = B.castAs<CFGAutomaticObjDtor>();
        SourceLocation Loc = DTor.getTriggerStmt()->getEndLoc();
        const VarDecl *Var = DTor.getVarDecl();

        Visitor.checkCallability(PropagationInfo(Var),
                                 DTor.getDestructorDecl(AC.getASTContext()),
                                 Loc);
        break;
      }

      default:
        break;
      }
    }

    // TODO: Handle other forms of branching with precision, including while-
    //       and for-loops. (Deferred)
    // If the terminator's condition refines the state, splitState() hands the
    // per-branch states to the successors itself and leaves CurrStates null.
    if (!splitState(CurrBlock, Visitor)) {
      CurrStates->setSource(nullptr);

      // With several successors, or a single successor that is a join point,
      // the state is distributed through BlockInfo. A single successor with a
      // single predecessor instead keeps CurrStates for the next iteration.
      // NOTE(review): if the only successor is a null (unreachable) entry,
      // the pred_size() call below dereferences null - confirm this cannot
      // happen.
      if (CurrBlock->succ_size() > 1 ||
          (CurrBlock->succ_size() == 1 &&
           (*CurrBlock->succ_begin())->pred_size() > 1)) {

        // Non-owning alias: CurrStates may be moved into the first forward
        // successor by addInfo(), after which later successors copy from it.
        auto *RawState = CurrStates.get();

        for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
             SE = CurrBlock->succ_end(); SI != SE; ++SI) {
          if (*SI == nullptr) continue;

          if (BlockInfo.isBackEdge(CurrBlock, *SI)) {
            // Back edge: merge into the loop head's retained state (warning on
            // mismatches), and drop that state once no back edge remains.
            BlockInfo.borrowInfo(*SI)->intersectAtLoopHead(
                *SI, CurrBlock, RawState, WarningsHandler);

            if (BlockInfo.allBackEdgesVisited(CurrBlock, *SI))
              BlockInfo.discardInfo(*SI);
          } else {
            BlockInfo.addInfo(*SI, RawState, CurrStates);
          }
        }

        CurrStates = nullptr;
      }
    }

    // For void functions, parameter return-typestate checks happen at the
    // exit block (functions returning a value check them in VisitReturnStmt).
    // CurrStates is presumably still non-null here because the exit block has
    // no successors.
    if (CurrBlock == &AC.getCFG()->getExit() &&
        D->getCallResultType()->isVoidType())
      CurrStates->checkParamsForReturnTypestate(D->getLocation(),
                                                WarningsHandler);
  } // End of block iterator.

  // Delete the last existing state map.
  CurrStates = nullptr;

  WarningsHandler.emitDiagnostics();
}
1415