xref: /llvm-project/clang/lib/Analysis/Consumed.cpp (revision fc368259af1a580b1843174cae3067d1b3c9aafc)
1 //===- Consumed.cpp --------------------------------------------*- C++ --*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // An intra-procedural analysis for checking consumed properties.  This is based,
11 // in part, on research on linear types.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "clang/AST/ASTContext.h"
16 #include "clang/AST/Attr.h"
17 #include "clang/AST/DeclCXX.h"
18 #include "clang/AST/ExprCXX.h"
19 #include "clang/AST/RecursiveASTVisitor.h"
20 #include "clang/AST/StmtVisitor.h"
21 #include "clang/AST/StmtCXX.h"
22 #include "clang/AST/Type.h"
23 #include "clang/Analysis/Analyses/PostOrderCFGView.h"
24 #include "clang/Analysis/AnalysisContext.h"
25 #include "clang/Analysis/CFG.h"
26 #include "clang/Analysis/Analyses/Consumed.h"
27 #include "clang/Basic/OperatorKinds.h"
28 #include "clang/Basic/SourceLocation.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/Support/Compiler.h"
32 #include "llvm/Support/raw_ostream.h"
33 
34 // TODO: Add notes about the actual and expected state for
35 // TODO: Correctly identify unreachable blocks when chaining boolean operators.
36 // TODO: Warn about unreachable code.
37 // TODO: Switch to using a bitmap to track unreachable blocks.
38 // TODO: Mark variables as Unknown going into while- or for-loops only if they
39 //       are referenced inside that block. (Deferred)
40 // TODO: Handle variable definitions, e.g. bool valid = x.isValid();
41 //       if (valid) ...; (Deferred)
42 // TODO: Add a method(s) to identify which method calls perform what state
43 //       transitions. (Deferred)
44 // TODO: Take notes on state transitions to provide better warning messages.
45 //       (Deferred)
46 // TODO: Test nested conditionals: 1) Checking the same value multiple times,
47 //       and 2) Checking different values. (Deferred)
48 
49 using namespace clang;
50 using namespace consumed;
51 
52 // Key method definition
53 ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() {}
54 
55 static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
56   switch (State) {
57   case CS_Unconsumed:
58     return CS_Consumed;
59   case CS_Consumed:
60     return CS_Unconsumed;
61   case CS_None:
62     return CS_None;
63   case CS_Unknown:
64     return CS_Unknown;
65   }
66   llvm_unreachable("invalid enum");
67 }
68 
69 static bool isConsumableType(const QualType &QT) {
70   if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
71     return RD->hasAttr<ConsumableAttr>();
72   else
73     return false;
74 }
75 
76 static bool isKnownState(ConsumedState State) {
77   switch (State) {
78   case CS_Unconsumed:
79   case CS_Consumed:
80     return true;
81   case CS_None:
82   case CS_Unknown:
83     return false;
84   }
85   llvm_unreachable("invalid enum");
86 }
87 
88 static bool isTestingFunction(const FunctionDecl *FunDecl) {
89   return FunDecl->hasAttr<TestsUnconsumedAttr>();
90 }
91 
92 static ConsumedState mapReturnTypestateAttrState(
93   const ReturnTypestateAttr *RTSAttr) {
94 
95   switch (RTSAttr->getState()) {
96   case ReturnTypestateAttr::Unknown:
97     return CS_Unknown;
98   case ReturnTypestateAttr::Unconsumed:
99     return CS_Unconsumed;
100   case ReturnTypestateAttr::Consumed:
101     return CS_Consumed;
102   }
103 }
104 
105 static StringRef stateToString(ConsumedState State) {
106   switch (State) {
107   case consumed::CS_None:
108     return "none";
109 
110   case consumed::CS_Unknown:
111     return "unknown";
112 
113   case consumed::CS_Unconsumed:
114     return "unconsumed";
115 
116   case consumed::CS_Consumed:
117     return "consumed";
118   }
119   llvm_unreachable("invalid enum");
120 }
121 
122 namespace {
123 struct VarTestResult {
124   const VarDecl *Var;
125   ConsumedState TestsFor;
126 };
127 } // end anonymous::VarTestResult
128 
129 namespace clang {
130 namespace consumed {
131 
/// The logical connective joining the two halves of a binary test.
/// The values are spelled out so that any code relying on their numeric
/// value stays correct if enumerators are ever added.
enum EffectiveOp {
  EO_And = 0,
  EO_Or  = 1
};
136 
137 class PropagationInfo {
138   enum {
139     IT_None,
140     IT_State,
141     IT_Test,
142     IT_BinTest,
143     IT_Var
144   } InfoType;
145 
146   struct BinTestTy {
147     const BinaryOperator *Source;
148     EffectiveOp EOp;
149     VarTestResult LTest;
150     VarTestResult RTest;
151   };
152 
153   union {
154     ConsumedState State;
155     VarTestResult Test;
156     const VarDecl *Var;
157     BinTestTy BinTest;
158   };
159 
160 public:
161   PropagationInfo() : InfoType(IT_None) {}
162 
163   PropagationInfo(const VarTestResult &Test) : InfoType(IT_Test), Test(Test) {}
164   PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
165     : InfoType(IT_Test) {
166 
167     Test.Var      = Var;
168     Test.TestsFor = TestsFor;
169   }
170 
171   PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
172                   const VarTestResult &LTest, const VarTestResult &RTest)
173     : InfoType(IT_BinTest) {
174 
175     BinTest.Source  = Source;
176     BinTest.EOp     = EOp;
177     BinTest.LTest   = LTest;
178     BinTest.RTest   = RTest;
179   }
180 
181   PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
182                   const VarDecl *LVar, ConsumedState LTestsFor,
183                   const VarDecl *RVar, ConsumedState RTestsFor)
184     : InfoType(IT_BinTest) {
185 
186     BinTest.Source         = Source;
187     BinTest.EOp            = EOp;
188     BinTest.LTest.Var      = LVar;
189     BinTest.LTest.TestsFor = LTestsFor;
190     BinTest.RTest.Var      = RVar;
191     BinTest.RTest.TestsFor = RTestsFor;
192   }
193 
194   PropagationInfo(ConsumedState State) : InfoType(IT_State), State(State) {}
195   PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
196 
197   const ConsumedState & getState() const {
198     assert(InfoType == IT_State);
199     return State;
200   }
201 
202   const VarTestResult & getTest() const {
203     assert(InfoType == IT_Test);
204     return Test;
205   }
206 
207   const VarTestResult & getLTest() const {
208     assert(InfoType == IT_BinTest);
209     return BinTest.LTest;
210   }
211 
212   const VarTestResult & getRTest() const {
213     assert(InfoType == IT_BinTest);
214     return BinTest.RTest;
215   }
216 
217   const VarDecl * getVar() const {
218     assert(InfoType == IT_Var);
219     return Var;
220   }
221 
222   EffectiveOp testEffectiveOp() const {
223     assert(InfoType == IT_BinTest);
224     return BinTest.EOp;
225   }
226 
227   const BinaryOperator * testSourceNode() const {
228     assert(InfoType == IT_BinTest);
229     return BinTest.Source;
230   }
231 
232   bool isValid()   const { return InfoType != IT_None;     }
233   bool isState()   const { return InfoType == IT_State;    }
234   bool isTest()    const { return InfoType == IT_Test;     }
235   bool isBinTest() const { return InfoType == IT_BinTest;  }
236   bool isVar()     const { return InfoType == IT_Var;      }
237 
238   PropagationInfo invertTest() const {
239     assert(InfoType == IT_Test || InfoType == IT_BinTest);
240 
241     if (InfoType == IT_Test) {
242       return PropagationInfo(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
243 
244     } else if (InfoType == IT_BinTest) {
245       return PropagationInfo(BinTest.Source,
246         BinTest.EOp == EO_And ? EO_Or : EO_And,
247         BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
248         BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
249     } else {
250       return PropagationInfo();
251     }
252   }
253 };
254 
255 class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
256 
257   typedef llvm::DenseMap<const Stmt *, PropagationInfo> MapType;
258   typedef std::pair<const Stmt *, PropagationInfo> PairType;
259   typedef MapType::iterator InfoEntry;
260   typedef MapType::const_iterator ConstInfoEntry;
261 
262   AnalysisDeclContext &AC;
263   ConsumedAnalyzer &Analyzer;
264   ConsumedStateMap *StateMap;
265   MapType PropagationMap;
266 
267   void checkCallability(const PropagationInfo &PInfo,
268                         const FunctionDecl *FunDecl,
269                         const CallExpr *Call);
270   void forwardInfo(const Stmt *From, const Stmt *To);
271   void handleTestingFunctionCall(const CallExpr *Call, const VarDecl *Var);
272   bool isLikeMoveAssignment(const CXXMethodDecl *MethodDecl);
273   void propagateReturnType(const Stmt *Call, const FunctionDecl *Fun,
274                            QualType ReturnType);
275 
276 public:
277 
278   void Visit(const Stmt *StmtNode);
279 
280   void VisitBinaryOperator(const BinaryOperator *BinOp);
281   void VisitCallExpr(const CallExpr *Call);
282   void VisitCastExpr(const CastExpr *Cast);
283   void VisitCXXConstructExpr(const CXXConstructExpr *Call);
284   void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
285   void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
286   void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
287   void VisitDeclStmt(const DeclStmt *DelcS);
288   void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
289   void VisitMemberExpr(const MemberExpr *MExpr);
290   void VisitParmVarDecl(const ParmVarDecl *Param);
291   void VisitReturnStmt(const ReturnStmt *Ret);
292   void VisitUnaryOperator(const UnaryOperator *UOp);
293   void VisitVarDecl(const VarDecl *Var);
294 
295   ConsumedStmtVisitor(AnalysisDeclContext &AC, ConsumedAnalyzer &Analyzer,
296                       ConsumedStateMap *StateMap)
297       : AC(AC), Analyzer(Analyzer), StateMap(StateMap) {}
298 
299   PropagationInfo getInfo(const Stmt *StmtNode) const {
300     ConstInfoEntry Entry = PropagationMap.find(StmtNode);
301 
302     if (Entry != PropagationMap.end())
303       return Entry->second;
304     else
305       return PropagationInfo();
306   }
307 
308   void reset(ConsumedStateMap *NewStateMap) {
309     StateMap = NewStateMap;
310   }
311 };
312 
313 // TODO: When we support CallableWhenConsumed this will have to check for
314 //       the different attributes and change the behavior bellow. (Deferred)
315 void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
316                                            const FunctionDecl *FunDecl,
317                                            const CallExpr *Call) {
318 
319   if (!FunDecl->hasAttr<CallableWhenUnconsumedAttr>()) return;
320 
321   if (PInfo.isVar()) {
322     const VarDecl *Var = PInfo.getVar();
323 
324     switch (StateMap->getState(Var)) {
325     case CS_Consumed:
326       Analyzer.WarningsHandler.warnUseWhileConsumed(
327         FunDecl->getNameAsString(), Var->getNameAsString(),
328         Call->getExprLoc());
329       break;
330 
331     case CS_Unknown:
332       Analyzer.WarningsHandler.warnUseInUnknownState(
333         FunDecl->getNameAsString(), Var->getNameAsString(),
334         Call->getExprLoc());
335       break;
336 
337     case CS_None:
338     case CS_Unconsumed:
339       break;
340     }
341 
342   } else {
343     switch (PInfo.getState()) {
344     case CS_Consumed:
345       Analyzer.WarningsHandler.warnUseOfTempWhileConsumed(
346         FunDecl->getNameAsString(), Call->getExprLoc());
347       break;
348 
349     case CS_Unknown:
350       Analyzer.WarningsHandler.warnUseOfTempInUnknownState(
351         FunDecl->getNameAsString(), Call->getExprLoc());
352       break;
353 
354     case CS_None:
355     case CS_Unconsumed:
356       break;
357     }
358   }
359 }
360 
361 void ConsumedStmtVisitor::forwardInfo(const Stmt *From, const Stmt *To) {
362   InfoEntry Entry = PropagationMap.find(From);
363 
364   if (Entry != PropagationMap.end())
365     PropagationMap.insert(PairType(To, Entry->second));
366 }
367 
368 void ConsumedStmtVisitor::handleTestingFunctionCall(const CallExpr *Call,
369                                                     const VarDecl  *Var) {
370 
371   ConsumedState VarState = StateMap->getState(Var);
372 
373   if (VarState != CS_Unknown) {
374     SourceLocation CallLoc = Call->getExprLoc();
375 
376     if (!CallLoc.isMacroID())
377       Analyzer.WarningsHandler.warnUnnecessaryTest(Var->getNameAsString(),
378         stateToString(VarState), CallLoc);
379   }
380 
381   PropagationMap.insert(PairType(Call, PropagationInfo(Var, CS_Unconsumed)));
382 }
383 
384 bool ConsumedStmtVisitor::isLikeMoveAssignment(
385   const CXXMethodDecl *MethodDecl) {
386 
387   return MethodDecl->isMoveAssignmentOperator() ||
388          (MethodDecl->getOverloadedOperator() == OO_Equal &&
389           MethodDecl->getNumParams() == 1 &&
390           MethodDecl->getParamDecl(0)->getType()->isRValueReferenceType());
391 }
392 
393 void ConsumedStmtVisitor::propagateReturnType(const Stmt *Call,
394                                               const FunctionDecl *Fun,
395                                               QualType ReturnType) {
396   if (isConsumableType(ReturnType)) {
397 
398     ConsumedState ReturnState;
399 
400     if (Fun->hasAttr<ReturnTypestateAttr>())
401       ReturnState = mapReturnTypestateAttrState(
402         Fun->getAttr<ReturnTypestateAttr>());
403     else
404       ReturnState = CS_Unknown;
405 
406     PropagationMap.insert(PairType(Call,
407       PropagationInfo(ReturnState)));
408   }
409 }
410 
411 void ConsumedStmtVisitor::Visit(const Stmt *StmtNode) {
412 
413   ConstStmtVisitor<ConsumedStmtVisitor>::Visit(StmtNode);
414 
415   for (Stmt::const_child_iterator CI = StmtNode->child_begin(),
416        CE = StmtNode->child_end(); CI != CE; ++CI) {
417 
418     PropagationMap.erase(*CI);
419   }
420 }
421 
422 void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
423   switch (BinOp->getOpcode()) {
424   case BO_LAnd:
425   case BO_LOr : {
426     InfoEntry LEntry = PropagationMap.find(BinOp->getLHS()),
427               REntry = PropagationMap.find(BinOp->getRHS());
428 
429     VarTestResult LTest, RTest;
430 
431     if (LEntry != PropagationMap.end() && LEntry->second.isTest()) {
432       LTest = LEntry->second.getTest();
433 
434     } else {
435       LTest.Var      = NULL;
436       LTest.TestsFor = CS_None;
437     }
438 
439     if (REntry != PropagationMap.end() && REntry->second.isTest()) {
440       RTest = REntry->second.getTest();
441 
442     } else {
443       RTest.Var      = NULL;
444       RTest.TestsFor = CS_None;
445     }
446 
447     if (!(LTest.Var == NULL && RTest.Var == NULL))
448       PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
449         static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
450 
451     break;
452   }
453 
454   case BO_PtrMemD:
455   case BO_PtrMemI:
456     forwardInfo(BinOp->getLHS(), BinOp);
457     break;
458 
459   default:
460     break;
461   }
462 }
463 
464 void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
465   if (const FunctionDecl *FunDecl =
466     dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee())) {
467 
468     // Special case for the std::move function.
469     // TODO: Make this more specific. (Deferred)
470     if (FunDecl->getNameAsString() == "move") {
471       InfoEntry Entry = PropagationMap.find(Call->getArg(0));
472 
473       if (Entry != PropagationMap.end()) {
474         PropagationMap.insert(PairType(Call, Entry->second));
475       }
476 
477       return;
478     }
479 
480     unsigned Offset = Call->getNumArgs() - FunDecl->getNumParams();
481 
482     for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
483       QualType ParamType = FunDecl->getParamDecl(Index - Offset)->getType();
484 
485       InfoEntry Entry = PropagationMap.find(Call->getArg(Index));
486 
487       if (Entry == PropagationMap.end() || !Entry->second.isVar()) {
488         continue;
489       }
490 
491       PropagationInfo PInfo = Entry->second;
492 
493       if (ParamType->isRValueReferenceType() ||
494           (ParamType->isLValueReferenceType() &&
495            !cast<LValueReferenceType>(*ParamType).isSpelledAsLValue())) {
496 
497         StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
498 
499       } else if (!(ParamType.isConstQualified() ||
500                    ((ParamType->isReferenceType() ||
501                      ParamType->isPointerType()) &&
502                     ParamType->getPointeeType().isConstQualified()))) {
503 
504         StateMap->setState(PInfo.getVar(), consumed::CS_Unknown);
505       }
506     }
507 
508     propagateReturnType(Call, FunDecl, FunDecl->getCallResultType());
509   }
510 }
511 
512 void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
513   forwardInfo(Cast->getSubExpr(), Cast);
514 }
515 
516 void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
517   CXXConstructorDecl *Constructor = Call->getConstructor();
518 
519   ASTContext &CurrContext = AC.getASTContext();
520   QualType ThisType = Constructor->getThisType(CurrContext)->getPointeeType();
521 
522   if (isConsumableType(ThisType)) {
523     if (Constructor->isDefaultConstructor()) {
524 
525       PropagationMap.insert(PairType(Call,
526         PropagationInfo(consumed::CS_Consumed)));
527 
528     } else if (Constructor->isMoveConstructor()) {
529 
530       PropagationInfo PInfo =
531         PropagationMap.find(Call->getArg(0))->second;
532 
533       if (PInfo.isVar()) {
534         const VarDecl* Var = PInfo.getVar();
535 
536         PropagationMap.insert(PairType(Call,
537           PropagationInfo(StateMap->getState(Var))));
538 
539         StateMap->setState(Var, consumed::CS_Consumed);
540 
541       } else {
542         PropagationMap.insert(PairType(Call, PInfo));
543       }
544 
545     } else if (Constructor->isCopyConstructor()) {
546       MapType::iterator Entry = PropagationMap.find(Call->getArg(0));
547 
548       if (Entry != PropagationMap.end())
549         PropagationMap.insert(PairType(Call, Entry->second));
550 
551     } else {
552       propagateReturnType(Call, Constructor, ThisType);
553     }
554   }
555 }
556 
557 void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
558   const CXXMemberCallExpr *Call) {
559 
560   VisitCallExpr(Call);
561 
562   InfoEntry Entry = PropagationMap.find(Call->getCallee()->IgnoreParens());
563 
564   if (Entry != PropagationMap.end()) {
565     PropagationInfo PInfo = Entry->second;
566     const CXXMethodDecl *MethodDecl = Call->getMethodDecl();
567 
568     checkCallability(PInfo, MethodDecl, Call);
569 
570     if (PInfo.isVar()) {
571       if (isTestingFunction(MethodDecl))
572         handleTestingFunctionCall(Call, PInfo.getVar());
573       else if (MethodDecl->hasAttr<ConsumesAttr>())
574         StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
575     }
576   }
577 }
578 
579 void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
580   const CXXOperatorCallExpr *Call) {
581 
582   const FunctionDecl *FunDecl =
583     dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
584 
585   if (!FunDecl) return;
586 
587   if (isa<CXXMethodDecl>(FunDecl) &&
588       isLikeMoveAssignment(cast<CXXMethodDecl>(FunDecl))) {
589 
590     InfoEntry LEntry = PropagationMap.find(Call->getArg(0));
591     InfoEntry REntry = PropagationMap.find(Call->getArg(1));
592 
593     PropagationInfo LPInfo, RPInfo;
594 
595     if (LEntry != PropagationMap.end() &&
596         REntry != PropagationMap.end()) {
597 
598       LPInfo = LEntry->second;
599       RPInfo = REntry->second;
600 
601       if (LPInfo.isVar() && RPInfo.isVar()) {
602         StateMap->setState(LPInfo.getVar(),
603           StateMap->getState(RPInfo.getVar()));
604 
605         StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
606 
607         PropagationMap.insert(PairType(Call, LPInfo));
608 
609       } else if (LPInfo.isVar() && !RPInfo.isVar()) {
610         StateMap->setState(LPInfo.getVar(), RPInfo.getState());
611 
612         PropagationMap.insert(PairType(Call, LPInfo));
613 
614       } else if (!LPInfo.isVar() && RPInfo.isVar()) {
615         PropagationMap.insert(PairType(Call,
616           PropagationInfo(StateMap->getState(RPInfo.getVar()))));
617 
618         StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
619 
620       } else {
621         PropagationMap.insert(PairType(Call, RPInfo));
622       }
623 
624     } else if (LEntry != PropagationMap.end() &&
625                REntry == PropagationMap.end()) {
626 
627       LPInfo = LEntry->second;
628 
629       if (LPInfo.isVar()) {
630         StateMap->setState(LPInfo.getVar(), consumed::CS_Unknown);
631 
632         PropagationMap.insert(PairType(Call, LPInfo));
633 
634       } else {
635         PropagationMap.insert(PairType(Call,
636           PropagationInfo(consumed::CS_Unknown)));
637       }
638 
639     } else if (LEntry == PropagationMap.end() &&
640                REntry != PropagationMap.end()) {
641 
642       RPInfo = REntry->second;
643 
644       if (RPInfo.isVar()) {
645         const VarDecl *Var = RPInfo.getVar();
646 
647         PropagationMap.insert(PairType(Call,
648           PropagationInfo(StateMap->getState(Var))));
649 
650         StateMap->setState(Var, consumed::CS_Consumed);
651 
652       } else {
653         PropagationMap.insert(PairType(Call, RPInfo));
654       }
655     }
656 
657   } else {
658 
659     VisitCallExpr(Call);
660 
661     InfoEntry Entry = PropagationMap.find(Call->getArg(0));
662 
663     if (Entry != PropagationMap.end()) {
664       PropagationInfo PInfo = Entry->second;
665 
666       checkCallability(PInfo, FunDecl, Call);
667 
668       if (PInfo.isVar()) {
669         if (isTestingFunction(FunDecl))
670           handleTestingFunctionCall(Call, PInfo.getVar());
671         else if (FunDecl->hasAttr<ConsumesAttr>())
672           StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
673       }
674     }
675   }
676 }
677 
678 void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
679   if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
680     if (StateMap->getState(Var) != consumed::CS_None)
681       PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
682 }
683 
684 void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
685   for (DeclStmt::const_decl_iterator DI = DeclS->decl_begin(),
686        DE = DeclS->decl_end(); DI != DE; ++DI) {
687 
688     if (isa<VarDecl>(*DI)) VisitVarDecl(cast<VarDecl>(*DI));
689   }
690 
691   if (DeclS->isSingleDecl())
692     if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
693       PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
694 }
695 
696 void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
697   const MaterializeTemporaryExpr *Temp) {
698 
699   InfoEntry Entry = PropagationMap.find(Temp->GetTemporaryExpr());
700 
701   if (Entry != PropagationMap.end())
702     PropagationMap.insert(PairType(Temp, Entry->second));
703 }
704 
705 void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
706   forwardInfo(MExpr->getBase(), MExpr);
707 }
708 
709 
710 void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
711   if (isConsumableType(Param->getType()))
712     StateMap->setState(Param, consumed::CS_Unknown);
713 }
714 
715 void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
716   if (ConsumedState ExpectedState = Analyzer.getExpectedReturnState()) {
717     InfoEntry Entry = PropagationMap.find(Ret->getRetValue());
718 
719     if (Entry != PropagationMap.end()) {
720       assert(Entry->second.isState() || Entry->second.isVar());
721 
722       ConsumedState RetState = Entry->second.isState() ?
723         Entry->second.getState() : StateMap->getState(Entry->second.getVar());
724 
725       if (RetState != ExpectedState)
726         Analyzer.WarningsHandler.warnReturnTypestateMismatch(
727           Ret->getReturnLoc(), stateToString(ExpectedState),
728           stateToString(RetState));
729     }
730   }
731 }
732 
733 void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
734   InfoEntry Entry = PropagationMap.find(UOp->getSubExpr()->IgnoreParens());
735   if (Entry == PropagationMap.end()) return;
736 
737   switch (UOp->getOpcode()) {
738   case UO_AddrOf:
739     PropagationMap.insert(PairType(UOp, Entry->second));
740     break;
741 
742   case UO_LNot:
743     if (Entry->second.isTest() || Entry->second.isBinTest())
744       PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
745     break;
746 
747   default:
748     break;
749   }
750 }
751 
752 void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
753   if (isConsumableType(Var->getType())) {
754     if (Var->hasInit()) {
755       PropagationInfo PInfo =
756         PropagationMap.find(Var->getInit())->second;
757 
758       StateMap->setState(Var, PInfo.isVar() ?
759         StateMap->getState(PInfo.getVar()) : PInfo.getState());
760 
761     } else {
762       StateMap->setState(Var, consumed::CS_Unknown);
763     }
764   }
765 }
766 }} // end clang::consumed::ConsumedStmtVisitor
767 
768 namespace clang {
769 namespace consumed {
770 
771 void splitVarStateForIf(const IfStmt * IfNode, const VarTestResult &Test,
772                         ConsumedStateMap *ThenStates,
773                         ConsumedStateMap *ElseStates) {
774 
775   ConsumedState VarState = ThenStates->getState(Test.Var);
776 
777   if (VarState == CS_Unknown) {
778     ThenStates->setState(Test.Var, Test.TestsFor);
779     if (ElseStates)
780       ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
781 
782   } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
783     ThenStates->markUnreachable();
784 
785   } else if (VarState == Test.TestsFor && ElseStates) {
786     ElseStates->markUnreachable();
787   }
788 }
789 
790 void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
791   ConsumedStateMap *ThenStates, ConsumedStateMap *ElseStates) {
792 
793   const VarTestResult &LTest = PInfo.getLTest(),
794                       &RTest = PInfo.getRTest();
795 
796   ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
797                 RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;
798 
799   if (LTest.Var) {
800     if (PInfo.testEffectiveOp() == EO_And) {
801       if (LState == CS_Unknown) {
802         ThenStates->setState(LTest.Var, LTest.TestsFor);
803 
804       } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
805         ThenStates->markUnreachable();
806 
807       } else if (LState == LTest.TestsFor && isKnownState(RState)) {
808         if (RState == RTest.TestsFor) {
809           if (ElseStates)
810             ElseStates->markUnreachable();
811         } else {
812           ThenStates->markUnreachable();
813         }
814       }
815 
816     } else {
817       if (LState == CS_Unknown && ElseStates) {
818         ElseStates->setState(LTest.Var,
819                              invertConsumedUnconsumed(LTest.TestsFor));
820 
821       } else if (LState == LTest.TestsFor && ElseStates) {
822         ElseStates->markUnreachable();
823 
824       } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
825                  isKnownState(RState)) {
826 
827         if (RState == RTest.TestsFor) {
828           if (ElseStates)
829             ElseStates->markUnreachable();
830         } else {
831           ThenStates->markUnreachable();
832         }
833       }
834     }
835   }
836 
837   if (RTest.Var) {
838     if (PInfo.testEffectiveOp() == EO_And) {
839       if (RState == CS_Unknown)
840         ThenStates->setState(RTest.Var, RTest.TestsFor);
841       else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
842         ThenStates->markUnreachable();
843 
844     } else if (ElseStates) {
845       if (RState == CS_Unknown)
846         ElseStates->setState(RTest.Var,
847                              invertConsumedUnconsumed(RTest.TestsFor));
848       else if (RState == RTest.TestsFor)
849         ElseStates->markUnreachable();
850     }
851   }
852 }
853 
854 void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
855                                 ConsumedStateMap *StateMap,
856                                 bool &AlreadyOwned) {
857 
858   if (VisitedBlocks.alreadySet(Block)) return;
859 
860   ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
861 
862   if (Entry) {
863     Entry->intersect(StateMap);
864 
865   } else if (AlreadyOwned) {
866     StateMapsArray[Block->getBlockID()] = new ConsumedStateMap(*StateMap);
867 
868   } else {
869     StateMapsArray[Block->getBlockID()] = StateMap;
870     AlreadyOwned = true;
871   }
872 }
873 
874 void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
875                                 ConsumedStateMap *StateMap) {
876 
877   if (VisitedBlocks.alreadySet(Block)) {
878     delete StateMap;
879     return;
880   }
881 
882   ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
883 
884   if (Entry) {
885     Entry->intersect(StateMap);
886     delete StateMap;
887 
888   } else {
889     StateMapsArray[Block->getBlockID()] = StateMap;
890   }
891 }
892 
893 ConsumedStateMap* ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
894   return StateMapsArray[Block->getBlockID()];
895 }
896 
897 void ConsumedBlockInfo::markVisited(const CFGBlock *Block) {
898   VisitedBlocks.insert(Block);
899 }
900 
901 ConsumedState ConsumedStateMap::getState(const VarDecl *Var) {
902   MapType::const_iterator Entry = Map.find(Var);
903 
904   if (Entry != Map.end()) {
905     return Entry->second;
906 
907   } else {
908     return CS_None;
909   }
910 }
911 
912 void ConsumedStateMap::intersect(const ConsumedStateMap *Other) {
913   ConsumedState LocalState;
914 
915   if (this->From && this->From == Other->From && !Other->Reachable) {
916     this->markUnreachable();
917     return;
918   }
919 
920   for (MapType::const_iterator DMI = Other->Map.begin(),
921        DME = Other->Map.end(); DMI != DME; ++DMI) {
922 
923     LocalState = this->getState(DMI->first);
924 
925     if (LocalState == CS_None)
926       continue;
927 
928     if (LocalState != DMI->second)
929        Map[DMI->first] = CS_Unknown;
930   }
931 }
932 
933 void ConsumedStateMap::markUnreachable() {
934   this->Reachable = false;
935   Map.clear();
936 }
937 
938 void ConsumedStateMap::makeUnknown() {
939   for (MapType::const_iterator DMI = Map.begin(), DME = Map.end(); DMI != DME;
940        ++DMI) {
941 
942     Map[DMI->first] = CS_Unknown;
943   }
944 }
945 
946 void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
947   Map[Var] = State;
948 }
949 
950 void ConsumedStateMap::remove(const VarDecl *Var) {
951   Map.erase(Var);
952 }
953 
// Tries to split the current state map at the conditional branch that
// terminates CurrBlock.
//
// Two terminators are recognized:
//  - an IfStmt whose condition yields a test or a binary test. If the
//    condition yields no valid info and is a binary operator, its RHS is
//    tried instead.
//  - a BinaryOperator whose LHS yields a test. If it does not, and the LHS
//    is itself a binary operator, that operator's RHS is tried instead.
//
// In those cases CurrStates is refined for the first successor and a copy of
// it, FalseStates, for the second (the false edge, by its name). Each map is
// handed to BlockInfo for its successor, or deleted if that successor is
// NULL. CurrStates is then set to NULL and true is returned.
//
// Otherwise the copy is freed, CurrStates is left untouched, and false is
// returned, so the caller propagates CurrStates itself.
//
// NOTE(review): the final successor walk dereferences two successors
// unconditionally. This assumes CurrBlock has at least two successors
// whenever a split succeeds; confirm for every handled terminator.
bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
                                  const ConsumedStmtVisitor &Visitor) {

  // Snapshot of the incoming state, used for the second (false) successor.
  ConsumedStateMap *FalseStates = new ConsumedStateMap(*CurrStates);
  PropagationInfo PInfo;

  if (const IfStmt *IfNode =
    dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {

    bool HasElse = IfNode->getElse() != NULL;
    const Stmt *Cond = IfNode->getCond();

    // If the condition as a whole carries no valid propagation info but is a
    // binary operator, fall back to the info for its right-hand operand.
    PInfo = Visitor.getInfo(Cond);
    if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
      PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());

    if (PInfo.isTest()) {
      CurrStates->setSource(Cond);
      FalseStates->setSource(Cond);

      // Without an else branch, NULL is passed in place of FalseStates.
      splitVarStateForIf(IfNode, PInfo.getTest(), CurrStates,
                         HasElse ? FalseStates : NULL);

    } else if (PInfo.isBinTest()) {
      CurrStates->setSource(PInfo.testSourceNode());
      FalseStates->setSource(PInfo.testSourceNode());

      splitVarStateForIfBinOp(PInfo, CurrStates, HasElse ? FalseStates : NULL);

    } else {
      // Not a recognized test: nothing to split.
      delete FalseStates;
      return false;
    }

  } else if (const BinaryOperator *BinOp =
    dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {

    PInfo = Visitor.getInfo(BinOp->getLHS());
    if (!PInfo.isTest()) {
      // The LHS is not itself a test. If it is a binary operator, use the test
      // on its RHS instead. BinOp is rebound to that inner operator, so the
      // setSource calls and opcode checks below refer to it.
      if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
        PInfo = Visitor.getInfo(BinOp->getRHS());

        if (!PInfo.isTest()) {
          delete FalseStates;
          return false;
        }

      } else {
        delete FalseStates;
        return false;
      }
    }

    CurrStates->setSource(BinOp);
    FalseStates->setSource(BinOp);

    const VarTestResult &Test = PInfo.getTest();
    ConsumedState VarState = CurrStates->getState(Test.Var);

    if (BinOp->getOpcode() == BO_LAnd) {
      // First successor of '&&': an unknown variable takes the tested state.
      // A variable already in the opposite state makes this edge unreachable.
      if (VarState == CS_Unknown)
        CurrStates->setState(Test.Var, Test.TestsFor);
      else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
        CurrStates->markUnreachable();

    } else if (BinOp->getOpcode() == BO_LOr) {
      // Second successor of '||': an unknown variable takes the inverse of the
      // tested state. A variable already in the tested state makes this edge
      // unreachable.
      if (VarState == CS_Unknown)
        FalseStates->setState(Test.Var,
                              invertConsumedUnconsumed(Test.TestsFor));
      else if (VarState == Test.TestsFor)
        FalseStates->markUnreachable();
    }

  } else {
    // The terminator is neither an if statement nor a binary operator.
    delete FalseStates;
    return false;
  }

  // Hand each map to its successor. Maps for NULL successors are freed here.
  // The others go to BlockInfo.addInfo, which presumably takes ownership
  // (confirm against ConsumedBlockInfo).
  CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();

  if (*SI)
    BlockInfo.addInfo(*SI, CurrStates);
  else
    delete CurrStates;

  if (*++SI)
    BlockInfo.addInfo(*SI, FalseStates);
  else
    delete FalseStates;

  CurrStates = NULL;
  return true;
}
1047 
1048 void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
1049   const FunctionDecl *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
1050 
1051   if (!D) return;
1052 
1053   // FIXME: This should be removed when template instantiation propagates
1054   //        attributes at template specialization definition, not declaration.
1055   //        When it is removed the test needs to be enabled in SemaDeclAttr.cpp.
1056   QualType ReturnType;
1057   if (const CXXConstructorDecl *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
1058     ASTContext &CurrContext = AC.getASTContext();
1059     ReturnType = Constructor->getThisType(CurrContext)->getPointeeType();
1060 
1061   } else {
1062     ReturnType = D->getCallResultType();
1063   }
1064 
1065   // Determine the expected return value.
1066   if (D->hasAttr<ReturnTypestateAttr>()) {
1067 
1068     ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>();
1069 
1070     const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
1071     if (!RD || !RD->hasAttr<ConsumableAttr>()) {
1072         // FIXME: This branch can be removed with the code above.
1073         WarningsHandler.warnReturnTypestateForUnconsumableType(
1074           RTSAttr->getLocation(), ReturnType.getAsString());
1075         ExpectedReturnState = CS_None;
1076 
1077     } else {
1078       switch (RTSAttr->getState()) {
1079       case ReturnTypestateAttr::Unknown:
1080         ExpectedReturnState = CS_Unknown;
1081         break;
1082 
1083       case ReturnTypestateAttr::Unconsumed:
1084         ExpectedReturnState = CS_Unconsumed;
1085         break;
1086 
1087       case ReturnTypestateAttr::Consumed:
1088         ExpectedReturnState = CS_Consumed;
1089         break;
1090       }
1091     }
1092 
1093   } else if (isConsumableType(ReturnType)) {
1094     ExpectedReturnState = CS_Unknown;
1095 
1096   } else {
1097     ExpectedReturnState = CS_None;
1098   }
1099 
1100   BlockInfo = ConsumedBlockInfo(AC.getCFG());
1101 
1102   PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
1103 
1104   CurrStates = new ConsumedStateMap();
1105   ConsumedStmtVisitor Visitor(AC, *this, CurrStates);
1106 
1107   // Add all trackable parameters to the state map.
1108   for (FunctionDecl::param_const_iterator PI = D->param_begin(),
1109        PE = D->param_end(); PI != PE; ++PI) {
1110     Visitor.VisitParmVarDecl(*PI);
1111   }
1112 
1113   // Visit all of the function's basic blocks.
1114   for (PostOrderCFGView::iterator I = SortedGraph->begin(),
1115        E = SortedGraph->end(); I != E; ++I) {
1116 
1117     const CFGBlock *CurrBlock = *I;
1118     BlockInfo.markVisited(CurrBlock);
1119 
1120     if (CurrStates == NULL)
1121       CurrStates = BlockInfo.getInfo(CurrBlock);
1122 
1123     if (!CurrStates) {
1124       continue;
1125 
1126     } else if (!CurrStates->isReachable()) {
1127       delete CurrStates;
1128       CurrStates = NULL;
1129       continue;
1130     }
1131 
1132     Visitor.reset(CurrStates);
1133 
1134     // Visit all of the basic block's statements.
1135     for (CFGBlock::const_iterator BI = CurrBlock->begin(),
1136          BE = CurrBlock->end(); BI != BE; ++BI) {
1137 
1138       switch (BI->getKind()) {
1139       case CFGElement::Statement:
1140         Visitor.Visit(BI->castAs<CFGStmt>().getStmt());
1141         break;
1142       case CFGElement::AutomaticObjectDtor:
1143         CurrStates->remove(BI->castAs<CFGAutomaticObjDtor>().getVarDecl());
1144       default:
1145         break;
1146       }
1147     }
1148 
1149     // TODO: Handle other forms of branching with precision, including while-
1150     //       and for-loops. (Deferred)
1151     if (!splitState(CurrBlock, Visitor)) {
1152       CurrStates->setSource(NULL);
1153 
1154       if (CurrBlock->succ_size() > 1) {
1155         CurrStates->makeUnknown();
1156 
1157         bool OwnershipTaken = false;
1158 
1159         for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
1160              SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1161 
1162           if (*SI) BlockInfo.addInfo(*SI, CurrStates, OwnershipTaken);
1163         }
1164 
1165         if (!OwnershipTaken)
1166           delete CurrStates;
1167 
1168         CurrStates = NULL;
1169 
1170       } else if (CurrBlock->succ_size() == 1 &&
1171                  (*CurrBlock->succ_begin())->pred_size() > 1) {
1172 
1173         BlockInfo.addInfo(*CurrBlock->succ_begin(), CurrStates);
1174         CurrStates = NULL;
1175       }
1176     }
1177   } // End of block iterator.
1178 
1179   // Delete the last existing state map.
1180   delete CurrStates;
1181 
1182   WarningsHandler.emitDiagnostics();
1183 }
1184 }} // end namespace clang::consumed
1185