xref: /llvm-project/clang/lib/Analysis/Consumed.cpp (revision 16f76d27ae9c921ad167ce4d50759e8e9dd10926)
1 //===- Consumed.cpp --------------------------------------------*- C++ --*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
// An intra-procedural analysis for checking consumed properties.  This is based,
11 // in part, on research on linear types.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "clang/AST/ASTContext.h"
16 #include "clang/AST/Attr.h"
17 #include "clang/AST/DeclCXX.h"
18 #include "clang/AST/ExprCXX.h"
19 #include "clang/AST/RecursiveASTVisitor.h"
20 #include "clang/AST/StmtVisitor.h"
21 #include "clang/AST/StmtCXX.h"
22 #include "clang/AST/Type.h"
23 #include "clang/Analysis/Analyses/PostOrderCFGView.h"
24 #include "clang/Analysis/AnalysisContext.h"
25 #include "clang/Analysis/CFG.h"
26 #include "clang/Analysis/Analyses/Consumed.h"
27 #include "clang/Basic/OperatorKinds.h"
28 #include "clang/Basic/SourceLocation.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/Support/Compiler.h"
32 #include "llvm/Support/raw_ostream.h"
33 
34 // TODO: Add notes about the actual and expected state for
35 // TODO: Correctly identify unreachable blocks when chaining boolean operators.
36 // TODO: Warn about unreachable code.
37 // TODO: Switch to using a bitmap to track unreachable blocks.
38 // TODO: Mark variables as Unknown going into while- or for-loops only if they
39 //       are referenced inside that block. (Deferred)
40 // TODO: Handle variable definitions, e.g. bool valid = x.isValid();
41 //       if (valid) ...; (Deferred)
42 // TODO: Add a method(s) to identify which method calls perform what state
43 //       transitions. (Deferred)
44 // TODO: Take notes on state transitions to provide better warning messages.
45 //       (Deferred)
// TODO: Test nested conditionals: 1) checking the same value multiple times,
//       and 2) checking different values. (Deferred)
48 
49 using namespace clang;
50 using namespace consumed;
51 
52 // Key method definition
53 ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() {}
54 
55 static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
56   switch (State) {
57   case CS_Unconsumed:
58     return CS_Consumed;
59   case CS_Consumed:
60     return CS_Unconsumed;
61   case CS_None:
62     return CS_None;
63   case CS_Unknown:
64     return CS_Unknown;
65   }
66   llvm_unreachable("invalid enum");
67 }
68 
69 static bool isConsumableType(const QualType &QT) {
70   if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
71     return RD->hasAttr<ConsumableAttr>();
72   else
73     return false;
74 }
75 
76 static bool isKnownState(ConsumedState State) {
77   switch (State) {
78   case CS_Unconsumed:
79   case CS_Consumed:
80     return true;
81   case CS_None:
82   case CS_Unknown:
83     return false;
84   }
85   llvm_unreachable("invalid enum");
86 }
87 
88 static bool isTestingFunction(const FunctionDecl *FunDecl) {
89   return FunDecl->hasAttr<TestsUnconsumedAttr>();
90 }
91 
92 static ConsumedState mapConsumableAttrState(const QualType QT) {
93   assert(isConsumableType(QT));
94 
95   const ConsumableAttr *CAttr =
96       QT->getAsCXXRecordDecl()->getAttr<ConsumableAttr>();
97 
98   switch (CAttr->getDefaultState()) {
99   case ConsumableAttr::Unknown:
100     return CS_Unknown;
101   case ConsumableAttr::Unconsumed:
102     return CS_Unconsumed;
103   case ConsumableAttr::Consumed:
104     return CS_Consumed;
105   }
106   llvm_unreachable("invalid enum");
107 }
108 
109 static ConsumedState
110 mapReturnTypestateAttrState(const ReturnTypestateAttr *RTSAttr) {
111   switch (RTSAttr->getState()) {
112   case ReturnTypestateAttr::Unknown:
113     return CS_Unknown;
114   case ReturnTypestateAttr::Unconsumed:
115     return CS_Unconsumed;
116   case ReturnTypestateAttr::Consumed:
117     return CS_Consumed;
118   }
119   llvm_unreachable("invalid enum");
120 }
121 
122 static StringRef stateToString(ConsumedState State) {
123   switch (State) {
124   case consumed::CS_None:
125     return "none";
126 
127   case consumed::CS_Unknown:
128     return "unknown";
129 
130   case consumed::CS_Unconsumed:
131     return "unconsumed";
132 
133   case consumed::CS_Consumed:
134     return "consumed";
135   }
136   llvm_unreachable("invalid enum");
137 }
138 
namespace {
/// The outcome of testing a variable's state: when the tested expression is
/// true, Var is taken to be in state TestsFor (see splitVarStateForIf).
struct VarTestResult {
  const VarDecl *Var;      // The tested variable; NULL when there is no test.
  ConsumedState TestsFor;  // State implied by the test evaluating to true.
};
} // end anonymous::VarTestResult
145 
146 namespace clang {
147 namespace consumed {
148 
/// How the operand tests of a logical binary operator combine.  The values
/// are pinned explicitly (EO_And == 0, EO_Or == 1) so that a boolean
/// "is logical-or" flag converts directly to the right enumerator.
enum EffectiveOp {
  EO_And = 0,
  EO_Or = 1
};
153 
154 class PropagationInfo {
155   enum {
156     IT_None,
157     IT_State,
158     IT_Test,
159     IT_BinTest,
160     IT_Var
161   } InfoType;
162 
163   struct BinTestTy {
164     const BinaryOperator *Source;
165     EffectiveOp EOp;
166     VarTestResult LTest;
167     VarTestResult RTest;
168   };
169 
170   union {
171     ConsumedState State;
172     VarTestResult Test;
173     const VarDecl *Var;
174     BinTestTy BinTest;
175   };
176 
177 public:
178   PropagationInfo() : InfoType(IT_None) {}
179 
180   PropagationInfo(const VarTestResult &Test) : InfoType(IT_Test), Test(Test) {}
181   PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
182     : InfoType(IT_Test) {
183 
184     Test.Var      = Var;
185     Test.TestsFor = TestsFor;
186   }
187 
188   PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
189                   const VarTestResult &LTest, const VarTestResult &RTest)
190     : InfoType(IT_BinTest) {
191 
192     BinTest.Source  = Source;
193     BinTest.EOp     = EOp;
194     BinTest.LTest   = LTest;
195     BinTest.RTest   = RTest;
196   }
197 
198   PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
199                   const VarDecl *LVar, ConsumedState LTestsFor,
200                   const VarDecl *RVar, ConsumedState RTestsFor)
201     : InfoType(IT_BinTest) {
202 
203     BinTest.Source         = Source;
204     BinTest.EOp            = EOp;
205     BinTest.LTest.Var      = LVar;
206     BinTest.LTest.TestsFor = LTestsFor;
207     BinTest.RTest.Var      = RVar;
208     BinTest.RTest.TestsFor = RTestsFor;
209   }
210 
211   PropagationInfo(ConsumedState State) : InfoType(IT_State), State(State) {}
212   PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
213 
214   const ConsumedState & getState() const {
215     assert(InfoType == IT_State);
216     return State;
217   }
218 
219   const VarTestResult & getTest() const {
220     assert(InfoType == IT_Test);
221     return Test;
222   }
223 
224   const VarTestResult & getLTest() const {
225     assert(InfoType == IT_BinTest);
226     return BinTest.LTest;
227   }
228 
229   const VarTestResult & getRTest() const {
230     assert(InfoType == IT_BinTest);
231     return BinTest.RTest;
232   }
233 
234   const VarDecl * getVar() const {
235     assert(InfoType == IT_Var);
236     return Var;
237   }
238 
239   EffectiveOp testEffectiveOp() const {
240     assert(InfoType == IT_BinTest);
241     return BinTest.EOp;
242   }
243 
244   const BinaryOperator * testSourceNode() const {
245     assert(InfoType == IT_BinTest);
246     return BinTest.Source;
247   }
248 
249   bool isValid()   const { return InfoType != IT_None;     }
250   bool isState()   const { return InfoType == IT_State;    }
251   bool isTest()    const { return InfoType == IT_Test;     }
252   bool isBinTest() const { return InfoType == IT_BinTest;  }
253   bool isVar()     const { return InfoType == IT_Var;      }
254 
255   PropagationInfo invertTest() const {
256     assert(InfoType == IT_Test || InfoType == IT_BinTest);
257 
258     if (InfoType == IT_Test) {
259       return PropagationInfo(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
260 
261     } else if (InfoType == IT_BinTest) {
262       return PropagationInfo(BinTest.Source,
263         BinTest.EOp == EO_And ? EO_Or : EO_And,
264         BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
265         BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
266     } else {
267       return PropagationInfo();
268     }
269   }
270 };
271 
272 class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
273 
274   typedef llvm::DenseMap<const Stmt *, PropagationInfo> MapType;
275   typedef std::pair<const Stmt *, PropagationInfo> PairType;
276   typedef MapType::iterator InfoEntry;
277   typedef MapType::const_iterator ConstInfoEntry;
278 
279   AnalysisDeclContext &AC;
280   ConsumedAnalyzer &Analyzer;
281   ConsumedStateMap *StateMap;
282   MapType PropagationMap;
283 
284   void checkCallability(const PropagationInfo &PInfo,
285                         const FunctionDecl *FunDecl,
286                         const CallExpr *Call);
287   void forwardInfo(const Stmt *From, const Stmt *To);
288   void handleTestingFunctionCall(const CallExpr *Call, const VarDecl *Var);
289   bool isLikeMoveAssignment(const CXXMethodDecl *MethodDecl);
290   void propagateReturnType(const Stmt *Call, const FunctionDecl *Fun,
291                            QualType ReturnType);
292 
293 public:
294 
295   void Visit(const Stmt *StmtNode);
296 
297   void VisitBinaryOperator(const BinaryOperator *BinOp);
298   void VisitCallExpr(const CallExpr *Call);
299   void VisitCastExpr(const CastExpr *Cast);
300   void VisitCXXConstructExpr(const CXXConstructExpr *Call);
301   void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
302   void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
303   void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
304   void VisitDeclStmt(const DeclStmt *DelcS);
305   void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
306   void VisitMemberExpr(const MemberExpr *MExpr);
307   void VisitParmVarDecl(const ParmVarDecl *Param);
308   void VisitReturnStmt(const ReturnStmt *Ret);
309   void VisitUnaryOperator(const UnaryOperator *UOp);
310   void VisitVarDecl(const VarDecl *Var);
311 
312   ConsumedStmtVisitor(AnalysisDeclContext &AC, ConsumedAnalyzer &Analyzer,
313                       ConsumedStateMap *StateMap)
314       : AC(AC), Analyzer(Analyzer), StateMap(StateMap) {}
315 
316   PropagationInfo getInfo(const Stmt *StmtNode) const {
317     ConstInfoEntry Entry = PropagationMap.find(StmtNode);
318 
319     if (Entry != PropagationMap.end())
320       return Entry->second;
321     else
322       return PropagationInfo();
323   }
324 
325   void reset(ConsumedStateMap *NewStateMap) {
326     StateMap = NewStateMap;
327   }
328 };
329 
330 // TODO: When we support CallableWhenConsumed this will have to check for
331 //       the different attributes and change the behavior bellow. (Deferred)
332 void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
333                                            const FunctionDecl *FunDecl,
334                                            const CallExpr *Call) {
335 
336   if (!FunDecl->hasAttr<CallableWhenUnconsumedAttr>()) return;
337 
338   if (PInfo.isVar()) {
339     const VarDecl *Var = PInfo.getVar();
340 
341     switch (StateMap->getState(Var)) {
342     case CS_Consumed:
343       Analyzer.WarningsHandler.warnUseWhileConsumed(
344         FunDecl->getNameAsString(), Var->getNameAsString(),
345         Call->getExprLoc());
346       break;
347 
348     case CS_Unknown:
349       Analyzer.WarningsHandler.warnUseInUnknownState(
350         FunDecl->getNameAsString(), Var->getNameAsString(),
351         Call->getExprLoc());
352       break;
353 
354     case CS_None:
355     case CS_Unconsumed:
356       break;
357     }
358 
359   } else {
360     switch (PInfo.getState()) {
361     case CS_Consumed:
362       Analyzer.WarningsHandler.warnUseOfTempWhileConsumed(
363         FunDecl->getNameAsString(), Call->getExprLoc());
364       break;
365 
366     case CS_Unknown:
367       Analyzer.WarningsHandler.warnUseOfTempInUnknownState(
368         FunDecl->getNameAsString(), Call->getExprLoc());
369       break;
370 
371     case CS_None:
372     case CS_Unconsumed:
373       break;
374     }
375   }
376 }
377 
378 void ConsumedStmtVisitor::forwardInfo(const Stmt *From, const Stmt *To) {
379   InfoEntry Entry = PropagationMap.find(From);
380 
381   if (Entry != PropagationMap.end())
382     PropagationMap.insert(PairType(To, Entry->second));
383 }
384 
385 void ConsumedStmtVisitor::handleTestingFunctionCall(const CallExpr *Call,
386                                                     const VarDecl  *Var) {
387 
388   ConsumedState VarState = StateMap->getState(Var);
389 
390   if (VarState != CS_Unknown) {
391     SourceLocation CallLoc = Call->getExprLoc();
392 
393     if (!CallLoc.isMacroID())
394       Analyzer.WarningsHandler.warnUnnecessaryTest(Var->getNameAsString(),
395         stateToString(VarState), CallLoc);
396   }
397 
398   PropagationMap.insert(PairType(Call, PropagationInfo(Var, CS_Unconsumed)));
399 }
400 
401 bool ConsumedStmtVisitor::isLikeMoveAssignment(
402   const CXXMethodDecl *MethodDecl) {
403 
404   return MethodDecl->isMoveAssignmentOperator() ||
405          (MethodDecl->getOverloadedOperator() == OO_Equal &&
406           MethodDecl->getNumParams() == 1 &&
407           MethodDecl->getParamDecl(0)->getType()->isRValueReferenceType());
408 }
409 
410 void ConsumedStmtVisitor::propagateReturnType(const Stmt *Call,
411                                               const FunctionDecl *Fun,
412                                               QualType ReturnType) {
413   if (isConsumableType(ReturnType)) {
414 
415     ConsumedState ReturnState;
416 
417     if (Fun->hasAttr<ReturnTypestateAttr>())
418       ReturnState = mapReturnTypestateAttrState(
419         Fun->getAttr<ReturnTypestateAttr>());
420     else
421       ReturnState = mapConsumableAttrState(ReturnType);
422 
423     PropagationMap.insert(PairType(Call,
424       PropagationInfo(ReturnState)));
425   }
426 }
427 
428 void ConsumedStmtVisitor::Visit(const Stmt *StmtNode) {
429 
430   ConstStmtVisitor<ConsumedStmtVisitor>::Visit(StmtNode);
431 
432   for (Stmt::const_child_iterator CI = StmtNode->child_begin(),
433        CE = StmtNode->child_end(); CI != CE; ++CI) {
434 
435     PropagationMap.erase(*CI);
436   }
437 }
438 
439 void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
440   switch (BinOp->getOpcode()) {
441   case BO_LAnd:
442   case BO_LOr : {
443     InfoEntry LEntry = PropagationMap.find(BinOp->getLHS()),
444               REntry = PropagationMap.find(BinOp->getRHS());
445 
446     VarTestResult LTest, RTest;
447 
448     if (LEntry != PropagationMap.end() && LEntry->second.isTest()) {
449       LTest = LEntry->second.getTest();
450 
451     } else {
452       LTest.Var      = NULL;
453       LTest.TestsFor = CS_None;
454     }
455 
456     if (REntry != PropagationMap.end() && REntry->second.isTest()) {
457       RTest = REntry->second.getTest();
458 
459     } else {
460       RTest.Var      = NULL;
461       RTest.TestsFor = CS_None;
462     }
463 
464     if (!(LTest.Var == NULL && RTest.Var == NULL))
465       PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
466         static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
467 
468     break;
469   }
470 
471   case BO_PtrMemD:
472   case BO_PtrMemI:
473     forwardInfo(BinOp->getLHS(), BinOp);
474     break;
475 
476   default:
477     break;
478   }
479 }
480 
481 void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
482   if (const FunctionDecl *FunDecl =
483     dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee())) {
484 
485     // Special case for the std::move function.
486     // TODO: Make this more specific. (Deferred)
487     if (FunDecl->getNameAsString() == "move") {
488       InfoEntry Entry = PropagationMap.find(Call->getArg(0));
489 
490       if (Entry != PropagationMap.end()) {
491         PropagationMap.insert(PairType(Call, Entry->second));
492       }
493 
494       return;
495     }
496 
497     unsigned Offset = Call->getNumArgs() - FunDecl->getNumParams();
498 
499     for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
500       QualType ParamType = FunDecl->getParamDecl(Index - Offset)->getType();
501 
502       InfoEntry Entry = PropagationMap.find(Call->getArg(Index));
503 
504       if (Entry == PropagationMap.end() || !Entry->second.isVar()) {
505         continue;
506       }
507 
508       PropagationInfo PInfo = Entry->second;
509 
510       if (ParamType->isRValueReferenceType() ||
511           (ParamType->isLValueReferenceType() &&
512            !cast<LValueReferenceType>(*ParamType).isSpelledAsLValue())) {
513 
514         StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
515 
516       } else if (!(ParamType.isConstQualified() ||
517                    ((ParamType->isReferenceType() ||
518                      ParamType->isPointerType()) &&
519                     ParamType->getPointeeType().isConstQualified()))) {
520 
521         StateMap->setState(PInfo.getVar(), consumed::CS_Unknown);
522       }
523     }
524 
525     propagateReturnType(Call, FunDecl, FunDecl->getCallResultType());
526   }
527 }
528 
529 void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
530   forwardInfo(Cast->getSubExpr(), Cast);
531 }
532 
533 void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
534   CXXConstructorDecl *Constructor = Call->getConstructor();
535 
536   ASTContext &CurrContext = AC.getASTContext();
537   QualType ThisType = Constructor->getThisType(CurrContext)->getPointeeType();
538 
539   if (isConsumableType(ThisType)) {
540     if (Constructor->isDefaultConstructor()) {
541 
542       PropagationMap.insert(PairType(Call,
543         PropagationInfo(consumed::CS_Consumed)));
544 
545     } else if (Constructor->isMoveConstructor()) {
546 
547       PropagationInfo PInfo =
548         PropagationMap.find(Call->getArg(0))->second;
549 
550       if (PInfo.isVar()) {
551         const VarDecl* Var = PInfo.getVar();
552 
553         PropagationMap.insert(PairType(Call,
554           PropagationInfo(StateMap->getState(Var))));
555 
556         StateMap->setState(Var, consumed::CS_Consumed);
557 
558       } else {
559         PropagationMap.insert(PairType(Call, PInfo));
560       }
561 
562     } else if (Constructor->isCopyConstructor()) {
563       MapType::iterator Entry = PropagationMap.find(Call->getArg(0));
564 
565       if (Entry != PropagationMap.end())
566         PropagationMap.insert(PairType(Call, Entry->second));
567 
568     } else {
569       propagateReturnType(Call, Constructor, ThisType);
570     }
571   }
572 }
573 
574 void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
575   const CXXMemberCallExpr *Call) {
576 
577   VisitCallExpr(Call);
578 
579   InfoEntry Entry = PropagationMap.find(Call->getCallee()->IgnoreParens());
580 
581   if (Entry != PropagationMap.end()) {
582     PropagationInfo PInfo = Entry->second;
583     const CXXMethodDecl *MethodDecl = Call->getMethodDecl();
584 
585     checkCallability(PInfo, MethodDecl, Call);
586 
587     if (PInfo.isVar()) {
588       if (isTestingFunction(MethodDecl))
589         handleTestingFunctionCall(Call, PInfo.getVar());
590       else if (MethodDecl->hasAttr<ConsumesAttr>())
591         StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
592     }
593   }
594 }
595 
/// Handle an overloaded-operator call.
///
/// For an operator= that looks like a move assignment (isLikeMoveAssignment),
/// the left-hand object takes over the state of the right-hand object, and a
/// right-hand variable becomes consumed.  All other operators go through the
/// generic call handling.  Argument 0 is then treated as the object the
/// operator acts on, for the callability check and the testing/consuming
/// attributes.
void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
  const CXXOperatorCallExpr *Call) {

  const FunctionDecl *FunDecl =
    dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());

  if (!FunDecl) return;

  if (isa<CXXMethodDecl>(FunDecl) &&
      isLikeMoveAssignment(cast<CXXMethodDecl>(FunDecl))) {

    // Member operator=: argument 0 is the assigned-to object, argument 1 the
    // source.
    InfoEntry LEntry = PropagationMap.find(Call->getArg(0));
    InfoEntry REntry = PropagationMap.find(Call->getArg(1));

    PropagationInfo LPInfo, RPInfo;

    if (LEntry != PropagationMap.end() &&
        REntry != PropagationMap.end()) {

      // Both sides have info: copy it before any insertion below.
      LPInfo = LEntry->second;
      RPInfo = REntry->second;

      if (LPInfo.isVar() && RPInfo.isVar()) {
        // Variable = variable: transfer the state, consume the source.
        StateMap->setState(LPInfo.getVar(),
          StateMap->getState(RPInfo.getVar()));

        StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);

        PropagationMap.insert(PairType(Call, LPInfo));

      } else if (LPInfo.isVar() && !RPInfo.isVar()) {
        // Variable = temporary: the variable takes the temporary's state.
        // NOTE(review): getState() asserts that RPInfo holds a state; a test
        // result on the right-hand side would trip that assertion.
        StateMap->setState(LPInfo.getVar(), RPInfo.getState());

        PropagationMap.insert(PairType(Call, LPInfo));

      } else if (!LPInfo.isVar() && RPInfo.isVar()) {
        // Non-variable = variable: the result has the source's prior state,
        // and the source becomes consumed.
        PropagationMap.insert(PairType(Call,
          PropagationInfo(StateMap->getState(RPInfo.getVar()))));

        StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);

      } else {
        // Neither side is a variable: the result carries the source's info.
        PropagationMap.insert(PairType(Call, RPInfo));
      }

    } else if (LEntry != PropagationMap.end() &&
               REntry == PropagationMap.end()) {

      // Nothing is known about the source, so the destination's state
      // becomes unknown.
      LPInfo = LEntry->second;

      if (LPInfo.isVar()) {
        StateMap->setState(LPInfo.getVar(), consumed::CS_Unknown);

        PropagationMap.insert(PairType(Call, LPInfo));

      } else {
        PropagationMap.insert(PairType(Call,
          PropagationInfo(consumed::CS_Unknown)));
      }

    } else if (LEntry == PropagationMap.end() &&
               REntry != PropagationMap.end()) {

      // Only the source has info: the result takes the source's state, and
      // a source variable becomes consumed.
      RPInfo = REntry->second;

      if (RPInfo.isVar()) {
        const VarDecl *Var = RPInfo.getVar();

        PropagationMap.insert(PairType(Call,
          PropagationInfo(StateMap->getState(Var))));

        StateMap->setState(Var, consumed::CS_Consumed);

      } else {
        PropagationMap.insert(PairType(Call, RPInfo));
      }
    }

  } else {

    // Any other operator: generic call handling, then treat argument 0 as
    // the object the operator is applied to.
    VisitCallExpr(Call);

    InfoEntry Entry = PropagationMap.find(Call->getArg(0));

    if (Entry != PropagationMap.end()) {
      PropagationInfo PInfo = Entry->second;

      checkCallability(PInfo, FunDecl, Call);

      if (PInfo.isVar()) {
        if (isTestingFunction(FunDecl))
          handleTestingFunctionCall(Call, PInfo.getVar());
        else if (FunDecl->hasAttr<ConsumesAttr>())
          StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
      }
    }
  }
}
694 
695 void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
696   if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
697     if (StateMap->getState(Var) != consumed::CS_None)
698       PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
699 }
700 
701 void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
702   for (DeclStmt::const_decl_iterator DI = DeclS->decl_begin(),
703        DE = DeclS->decl_end(); DI != DE; ++DI) {
704 
705     if (isa<VarDecl>(*DI)) VisitVarDecl(cast<VarDecl>(*DI));
706   }
707 
708   if (DeclS->isSingleDecl())
709     if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
710       PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
711 }
712 
713 void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
714   const MaterializeTemporaryExpr *Temp) {
715 
716   InfoEntry Entry = PropagationMap.find(Temp->GetTemporaryExpr());
717 
718   if (Entry != PropagationMap.end())
719     PropagationMap.insert(PairType(Temp, Entry->second));
720 }
721 
722 void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
723   forwardInfo(MExpr->getBase(), MExpr);
724 }
725 
726 
727 void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
728   QualType ParamType = Param->getType();
729   ConsumedState ParamState = consumed::CS_None;
730 
731   if (!(ParamType->isPointerType() || ParamType->isReferenceType()) &&
732       isConsumableType(ParamType))
733     ParamState = mapConsumableAttrState(ParamType);
734   else if (ParamType->isReferenceType() &&
735            isConsumableType(ParamType->getPointeeType()))
736     ParamState = consumed::CS_Unknown;
737 
738   if (ParamState)
739     StateMap->setState(Param, ParamState);
740 }
741 
742 void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
743   if (ConsumedState ExpectedState = Analyzer.getExpectedReturnState()) {
744     InfoEntry Entry = PropagationMap.find(Ret->getRetValue());
745 
746     if (Entry != PropagationMap.end()) {
747       assert(Entry->second.isState() || Entry->second.isVar());
748 
749       ConsumedState RetState = Entry->second.isState() ?
750         Entry->second.getState() : StateMap->getState(Entry->second.getVar());
751 
752       if (RetState != ExpectedState)
753         Analyzer.WarningsHandler.warnReturnTypestateMismatch(
754           Ret->getReturnLoc(), stateToString(ExpectedState),
755           stateToString(RetState));
756     }
757   }
758 }
759 
760 void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
761   InfoEntry Entry = PropagationMap.find(UOp->getSubExpr()->IgnoreParens());
762   if (Entry == PropagationMap.end()) return;
763 
764   switch (UOp->getOpcode()) {
765   case UO_AddrOf:
766     PropagationMap.insert(PairType(UOp, Entry->second));
767     break;
768 
769   case UO_LNot:
770     if (Entry->second.isTest() || Entry->second.isBinTest())
771       PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
772     break;
773 
774   default:
775     break;
776   }
777 }
778 
779 void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
780   if (isConsumableType(Var->getType())) {
781     if (Var->hasInit()) {
782       PropagationInfo PInfo =
783         PropagationMap.find(Var->getInit())->second;
784 
785       StateMap->setState(Var, PInfo.isVar() ?
786         StateMap->getState(PInfo.getVar()) : PInfo.getState());
787 
788     } else {
789       StateMap->setState(Var, consumed::CS_Unknown);
790     }
791   }
792 }
793 }} // end clang::consumed::ConsumedStmtVisitor
794 
795 namespace clang {
796 namespace consumed {
797 
798 void splitVarStateForIf(const IfStmt * IfNode, const VarTestResult &Test,
799                         ConsumedStateMap *ThenStates,
800                         ConsumedStateMap *ElseStates) {
801 
802   ConsumedState VarState = ThenStates->getState(Test.Var);
803 
804   if (VarState == CS_Unknown) {
805     ThenStates->setState(Test.Var, Test.TestsFor);
806     if (ElseStates)
807       ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
808 
809   } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
810     ThenStates->markUnreachable();
811 
812   } else if (VarState == Test.TestsFor && ElseStates) {
813     ElseStates->markUnreachable();
814   }
815 }
816 
/// Refine the branch states of an if statement whose condition is a logical
/// && / || of up to two variable tests (PInfo must be a binary test).
///
/// For &&, the "then" branch may assume both tests hold.  For ||, the "else"
/// branch may assume both tests fail.  When operand states are already known,
/// the branch they rule out is marked unreachable.  ElseStates may be NULL.
void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
  ConsumedStateMap *ThenStates, ConsumedStateMap *ElseStates) {

  const VarTestResult &LTest = PInfo.getLTest(),
                      &RTest = PInfo.getRTest();

  // States before the branch; CS_None stands in for a missing operand test.
  ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
                RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;

  if (LTest.Var) {
    if (PInfo.testEffectiveOp() == EO_And) {
      // L && R: "then" implies the left test held.
      if (LState == CS_Unknown) {
        ThenStates->setState(LTest.Var, LTest.TestsFor);

      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
        // Left test is known to fail, so the conjunction cannot be true.
        ThenStates->markUnreachable();

      } else if (LState == LTest.TestsFor && isKnownState(RState)) {
        // Left test is known to pass: the outcome is decided by the right.
        if (RState == RTest.TestsFor) {
          if (ElseStates)
            ElseStates->markUnreachable();
        } else {
          ThenStates->markUnreachable();
        }
      }

    } else {
      // L || R: "else" implies the left test failed.
      if (LState == CS_Unknown && ElseStates) {
        ElseStates->setState(LTest.Var,
                             invertConsumedUnconsumed(LTest.TestsFor));

      } else if (LState == LTest.TestsFor && ElseStates) {
        // Left test is known to pass, so the disjunction cannot be false.
        ElseStates->markUnreachable();

      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
                 isKnownState(RState)) {

        // Left test is known to fail: the outcome is decided by the right.
        if (RState == RTest.TestsFor) {
          if (ElseStates)
            ElseStates->markUnreachable();
        } else {
          ThenStates->markUnreachable();
        }
      }
    }
  }

  if (RTest.Var) {
    if (PInfo.testEffectiveOp() == EO_And) {
      // L && R: "then" implies the right test held as well.
      if (RState == CS_Unknown)
        ThenStates->setState(RTest.Var, RTest.TestsFor);
      else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
        ThenStates->markUnreachable();

    } else if (ElseStates) {
      // L || R: "else" implies the right test failed as well.
      if (RState == CS_Unknown)
        ElseStates->setState(RTest.Var,
                             invertConsumedUnconsumed(RTest.TestsFor));
      else if (RState == RTest.TestsFor)
        ElseStates->markUnreachable();
    }
  }
}
880 
881 void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
882                                 ConsumedStateMap *StateMap,
883                                 bool &AlreadyOwned) {
884 
885   if (VisitedBlocks.alreadySet(Block)) return;
886 
887   ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
888 
889   if (Entry) {
890     Entry->intersect(StateMap);
891 
892   } else if (AlreadyOwned) {
893     StateMapsArray[Block->getBlockID()] = new ConsumedStateMap(*StateMap);
894 
895   } else {
896     StateMapsArray[Block->getBlockID()] = StateMap;
897     AlreadyOwned = true;
898   }
899 }
900 
901 void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
902                                 ConsumedStateMap *StateMap) {
903 
904   if (VisitedBlocks.alreadySet(Block)) {
905     delete StateMap;
906     return;
907   }
908 
909   ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
910 
911   if (Entry) {
912     Entry->intersect(StateMap);
913     delete StateMap;
914 
915   } else {
916     StateMapsArray[Block->getBlockID()] = StateMap;
917   }
918 }
919 
920 ConsumedStateMap* ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
921   return StateMapsArray[Block->getBlockID()];
922 }
923 
924 void ConsumedBlockInfo::markVisited(const CFGBlock *Block) {
925   VisitedBlocks.insert(Block);
926 }
927 
928 ConsumedState ConsumedStateMap::getState(const VarDecl *Var) {
929   MapType::const_iterator Entry = Map.find(Var);
930 
931   if (Entry != Map.end()) {
932     return Entry->second;
933 
934   } else {
935     return CS_None;
936   }
937 }
938 
939 void ConsumedStateMap::intersect(const ConsumedStateMap *Other) {
940   ConsumedState LocalState;
941 
942   if (this->From && this->From == Other->From && !Other->Reachable) {
943     this->markUnreachable();
944     return;
945   }
946 
947   for (MapType::const_iterator DMI = Other->Map.begin(),
948        DME = Other->Map.end(); DMI != DME; ++DMI) {
949 
950     LocalState = this->getState(DMI->first);
951 
952     if (LocalState == CS_None)
953       continue;
954 
955     if (LocalState != DMI->second)
956        Map[DMI->first] = CS_Unknown;
957   }
958 }
959 
960 void ConsumedStateMap::markUnreachable() {
961   this->Reachable = false;
962   Map.clear();
963 }
964 
965 void ConsumedStateMap::makeUnknown() {
966   for (MapType::const_iterator DMI = Map.begin(), DME = Map.end(); DMI != DME;
967        ++DMI) {
968 
969     Map[DMI->first] = CS_Unknown;
970   }
971 }
972 
973 void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
974   Map[Var] = State;
975 }
976 
977 void ConsumedStateMap::remove(const VarDecl *Var) {
978   Map.erase(Var);
979 }
980 
981 void ConsumedAnalyzer::determineExpectedReturnState(AnalysisDeclContext &AC,
982                                                     const FunctionDecl *D) {
983   QualType ReturnType;
984   if (const CXXConstructorDecl *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
985     ASTContext &CurrContext = AC.getASTContext();
986     ReturnType = Constructor->getThisType(CurrContext)->getPointeeType();
987   } else
988     ReturnType = D->getCallResultType();
989 
990   if (D->hasAttr<ReturnTypestateAttr>()) {
991     const ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>();
992 
993     const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
994     if (!RD || !RD->hasAttr<ConsumableAttr>()) {
995       // FIXME: This should be removed when template instantiation propagates
996       //        attributes at template specialization definition, not
997       //        declaration. When it is removed the test needs to be enabled
998       //        in SemaDeclAttr.cpp.
999       WarningsHandler.warnReturnTypestateForUnconsumableType(
1000           RTSAttr->getLocation(), ReturnType.getAsString());
1001       ExpectedReturnState = CS_None;
1002     } else
1003       ExpectedReturnState = mapReturnTypestateAttrState(RTSAttr);
1004   } else if (isConsumableType(ReturnType))
1005     ExpectedReturnState = mapConsumableAttrState(ReturnType);
1006   else
1007     ExpectedReturnState = CS_None;
1008 }
1009 
/// Try to split CurrStates into separate states for the two successors of
/// \p CurrBlock, refining variable typestates based on the block's terminator.
///
/// Handled terminators are an IfStmt whose condition carries a (binary)
/// typestate test, and a BinaryOperator whose LHS (or, for a nested
/// BinaryOperator LHS, that operator's RHS) carries a test.
///
/// Returns true if the split happened. In that case CurrStates is handed to
/// the first successor, a copy (FalseStates) is handed to the second, null
/// successors get their map deleted, and CurrStates is set to NULL.
/// Returns false if the terminator is not understood. In that case CurrStates
/// is left untouched and the temporary copy is freed.
///
/// NOTE(review): the successor handling dereferences succ_begin() and
/// ++succ_begin(), so it assumes these terminators always have two successor
/// slots. Confirm this against CFG construction.
bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
                                  const ConsumedStmtVisitor &Visitor) {

  // Candidate state for the second successor. It must be freed on every
  // early-return path below.
  ConsumedStateMap *FalseStates = new ConsumedStateMap(*CurrStates);
  PropagationInfo PInfo;

  if (const IfStmt *IfNode =
    dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {

    bool HasElse = IfNode->getElse() != NULL;
    const Stmt *Cond = IfNode->getCond();

    // Fall back to the RHS when the condition itself is a binary operator
    // without propagation info of its own.
    PInfo = Visitor.getInfo(Cond);
    if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
      PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());

    if (PInfo.isTest()) {
      CurrStates->setSource(Cond);
      FalseStates->setSource(Cond);

      // Without an else branch the false-side map is not refined (NULL is
      // passed) but is still forwarded to the second successor below.
      splitVarStateForIf(IfNode, PInfo.getTest(), CurrStates,
                         HasElse ? FalseStates : NULL);

    } else if (PInfo.isBinTest()) {
      CurrStates->setSource(PInfo.testSourceNode());
      FalseStates->setSource(PInfo.testSourceNode());

      splitVarStateForIfBinOp(PInfo, CurrStates, HasElse ? FalseStates : NULL);

    } else {
      delete FalseStates;
      return false;
    }

  } else if (const BinaryOperator *BinOp =
    dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {

    PInfo = Visitor.getInfo(BinOp->getLHS());
    if (!PInfo.isTest()) {
      // Chained operators: if the LHS is itself a BinaryOperator, use the
      // test on its RHS. BinOp is re-pointed at that inner operator, so the
      // source and opcode used below come from it.
      if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
        PInfo = Visitor.getInfo(BinOp->getRHS());

        if (!PInfo.isTest()) {
          delete FalseStates;
          return false;
        }

      } else {
        delete FalseStates;
        return false;
      }
    }

    CurrStates->setSource(BinOp);
    FalseStates->setSource(BinOp);

    const VarTestResult &Test = PInfo.getTest();
    ConsumedState VarState = CurrStates->getState(Test.Var);

    // For '&&', refine the first successor's state (CurrStates) to the tested
    // state, or mark it unreachable when the known state contradicts the test.
    // For '||', refine or kill the second successor's state (FalseStates)
    // with the inverted test.
    if (BinOp->getOpcode() == BO_LAnd) {
      if (VarState == CS_Unknown)
        CurrStates->setState(Test.Var, Test.TestsFor);
      else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
        CurrStates->markUnreachable();

    } else if (BinOp->getOpcode() == BO_LOr) {
      if (VarState == CS_Unknown)
        FalseStates->setState(Test.Var,
                              invertConsumedUnconsumed(Test.TestsFor));
      else if (VarState == Test.TestsFor)
        FalseStates->markUnreachable();
    }

  } else {
    delete FalseStates;
    return false;
  }

  // Hand each map to its successor. A null successor slot means nobody takes
  // ownership, so the map is freed here.
  CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();

  if (*SI)
    BlockInfo.addInfo(*SI, CurrStates);
  else
    delete CurrStates;

  if (*++SI)
    BlockInfo.addInfo(*SI, FalseStates);
  else
    delete FalseStates;

  CurrStates = NULL;
  return true;
}
1103 
1104 void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
1105   const FunctionDecl *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
1106 
1107   if (!D) return;
1108 
1109   determineExpectedReturnState(AC, D);
1110 
1111   BlockInfo = ConsumedBlockInfo(AC.getCFG());
1112 
1113   PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
1114 
1115   CurrStates = new ConsumedStateMap();
1116   ConsumedStmtVisitor Visitor(AC, *this, CurrStates);
1117 
1118   // Add all trackable parameters to the state map.
1119   for (FunctionDecl::param_const_iterator PI = D->param_begin(),
1120        PE = D->param_end(); PI != PE; ++PI) {
1121     Visitor.VisitParmVarDecl(*PI);
1122   }
1123 
1124   // Visit all of the function's basic blocks.
1125   for (PostOrderCFGView::iterator I = SortedGraph->begin(),
1126        E = SortedGraph->end(); I != E; ++I) {
1127 
1128     const CFGBlock *CurrBlock = *I;
1129     BlockInfo.markVisited(CurrBlock);
1130 
1131     if (CurrStates == NULL)
1132       CurrStates = BlockInfo.getInfo(CurrBlock);
1133 
1134     if (!CurrStates) {
1135       continue;
1136 
1137     } else if (!CurrStates->isReachable()) {
1138       delete CurrStates;
1139       CurrStates = NULL;
1140       continue;
1141     }
1142 
1143     Visitor.reset(CurrStates);
1144 
1145     // Visit all of the basic block's statements.
1146     for (CFGBlock::const_iterator BI = CurrBlock->begin(),
1147          BE = CurrBlock->end(); BI != BE; ++BI) {
1148 
1149       switch (BI->getKind()) {
1150       case CFGElement::Statement:
1151         Visitor.Visit(BI->castAs<CFGStmt>().getStmt());
1152         break;
1153       case CFGElement::AutomaticObjectDtor:
1154         CurrStates->remove(BI->castAs<CFGAutomaticObjDtor>().getVarDecl());
1155       default:
1156         break;
1157       }
1158     }
1159 
1160     // TODO: Handle other forms of branching with precision, including while-
1161     //       and for-loops. (Deferred)
1162     if (!splitState(CurrBlock, Visitor)) {
1163       CurrStates->setSource(NULL);
1164 
1165       if (CurrBlock->succ_size() > 1) {
1166         CurrStates->makeUnknown();
1167 
1168         bool OwnershipTaken = false;
1169 
1170         for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
1171              SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1172 
1173           if (*SI) BlockInfo.addInfo(*SI, CurrStates, OwnershipTaken);
1174         }
1175 
1176         if (!OwnershipTaken)
1177           delete CurrStates;
1178 
1179         CurrStates = NULL;
1180 
1181       } else if (CurrBlock->succ_size() == 1 &&
1182                  (*CurrBlock->succ_begin())->pred_size() > 1) {
1183 
1184         BlockInfo.addInfo(*CurrBlock->succ_begin(), CurrStates);
1185         CurrStates = NULL;
1186       }
1187     }
1188   } // End of block iterator.
1189 
1190   // Delete the last existing state map.
1191   delete CurrStates;
1192 
1193   WarningsHandler.emitDiagnostics();
1194 }
1195 }} // end namespace clang::consumed
1196