xref: /llvm-project/clang/lib/Analysis/Consumed.cpp (revision 210791a021a18e3eb60d4a67df6cd452006aa338)
1 //===- Consumed.cpp --------------------------------------------*- C++ --*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
// An intra-procedural analysis for checking consumed properties.  This is
// based, in part, on research on linear types.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "clang/AST/ASTContext.h"
16 #include "clang/AST/Attr.h"
17 #include "clang/AST/DeclCXX.h"
18 #include "clang/AST/ExprCXX.h"
19 #include "clang/AST/RecursiveASTVisitor.h"
20 #include "clang/AST/StmtVisitor.h"
21 #include "clang/AST/StmtCXX.h"
22 #include "clang/AST/Type.h"
23 #include "clang/Analysis/Analyses/PostOrderCFGView.h"
24 #include "clang/Analysis/AnalysisContext.h"
25 #include "clang/Analysis/CFG.h"
26 #include "clang/Analysis/Analyses/Consumed.h"
27 #include "clang/Basic/OperatorKinds.h"
28 #include "clang/Basic/SourceLocation.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/Support/Compiler.h"
32 #include "llvm/Support/raw_ostream.h"
33 
34 // TODO: Add notes about the actual and expected state for
35 // TODO: Correctly identify unreachable blocks when chaining boolean operators.
36 // TODO: Adjust the parser and AttributesList class to support lists of
37 //       identifiers.
38 // TODO: Warn about unreachable code.
39 // TODO: Switch to using a bitmap to track unreachable blocks.
40 // TODO: Mark variables as Unknown going into while- or for-loops only if they
41 //       are referenced inside that block. (Deferred)
42 // TODO: Handle variable definitions, e.g. bool valid = x.isValid();
43 //       if (valid) ...; (Deferred)
44 // TODO: Add a method(s) to identify which method calls perform what state
45 //       transitions. (Deferred)
46 // TODO: Take notes on state transitions to provide better warning messages.
47 //       (Deferred)
// TODO: Test nested conditionals: A) Checking the same value multiple times,
//       and B) Checking different values. (Deferred)
50 
51 using namespace clang;
52 using namespace consumed;
53 
54 // Key method definition
55 ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() {}
56 
57 static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
58   switch (State) {
59   case CS_Unconsumed:
60     return CS_Consumed;
61   case CS_Consumed:
62     return CS_Unconsumed;
63   case CS_None:
64     return CS_None;
65   case CS_Unknown:
66     return CS_Unknown;
67   }
68   llvm_unreachable("invalid enum");
69 }
70 
71 static bool isCallableInState(const CallableWhenAttr *CWAttr,
72                               ConsumedState State) {
73 
74   CallableWhenAttr::callableState_iterator I = CWAttr->callableState_begin(),
75                                            E = CWAttr->callableState_end();
76 
77   for (; I != E; ++I) {
78 
79     ConsumedState MappedAttrState = CS_None;
80 
81     switch (*I) {
82     case CallableWhenAttr::Unknown:
83       MappedAttrState = CS_Unknown;
84       break;
85 
86     case CallableWhenAttr::Unconsumed:
87       MappedAttrState = CS_Unconsumed;
88       break;
89 
90     case CallableWhenAttr::Consumed:
91       MappedAttrState = CS_Consumed;
92       break;
93     }
94 
95     if (MappedAttrState == State)
96       return true;
97   }
98 
99   return false;
100 }
101 
102 static bool isConsumableType(const QualType &QT) {
103   if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
104     return RD->hasAttr<ConsumableAttr>();
105   else
106     return false;
107 }
108 
109 static bool isKnownState(ConsumedState State) {
110   switch (State) {
111   case CS_Unconsumed:
112   case CS_Consumed:
113     return true;
114   case CS_None:
115   case CS_Unknown:
116     return false;
117   }
118   llvm_unreachable("invalid enum");
119 }
120 
121 static bool isTestingFunction(const FunctionDecl *FunDecl) {
122   return FunDecl->hasAttr<TestsUnconsumedAttr>();
123 }
124 
125 static ConsumedState mapConsumableAttrState(const QualType QT) {
126   assert(isConsumableType(QT));
127 
128   const ConsumableAttr *CAttr =
129       QT->getAsCXXRecordDecl()->getAttr<ConsumableAttr>();
130 
131   switch (CAttr->getDefaultState()) {
132   case ConsumableAttr::Unknown:
133     return CS_Unknown;
134   case ConsumableAttr::Unconsumed:
135     return CS_Unconsumed;
136   case ConsumableAttr::Consumed:
137     return CS_Consumed;
138   }
139   llvm_unreachable("invalid enum");
140 }
141 
142 static ConsumedState
143 mapReturnTypestateAttrState(const ReturnTypestateAttr *RTSAttr) {
144   switch (RTSAttr->getState()) {
145   case ReturnTypestateAttr::Unknown:
146     return CS_Unknown;
147   case ReturnTypestateAttr::Unconsumed:
148     return CS_Unconsumed;
149   case ReturnTypestateAttr::Consumed:
150     return CS_Consumed;
151   }
152   llvm_unreachable("invalid enum");
153 }
154 
155 static StringRef stateToString(ConsumedState State) {
156   switch (State) {
157   case consumed::CS_None:
158     return "none";
159 
160   case consumed::CS_Unknown:
161     return "unknown";
162 
163   case consumed::CS_Unconsumed:
164     return "unconsumed";
165 
166   case consumed::CS_Consumed:
167     return "consumed";
168   }
169   llvm_unreachable("invalid enum");
170 }
171 
172 namespace {
173 struct VarTestResult {
174   const VarDecl *Var;
175   ConsumedState TestsFor;
176 };
177 } // end anonymous::VarTestResult
178 
179 namespace clang {
180 namespace consumed {
181 
// Logical operator that effectively joins the two halves of a binary test.
// The numeric values are spelled out because VisitBinaryOperator converts a
// boolean ("is this a logical or?") directly into this enum.
enum EffectiveOp {
  EO_And = 0,
  EO_Or  = 1
};
186 
187 class PropagationInfo {
188   enum {
189     IT_None,
190     IT_State,
191     IT_Test,
192     IT_BinTest,
193     IT_Var
194   } InfoType;
195 
196   struct BinTestTy {
197     const BinaryOperator *Source;
198     EffectiveOp EOp;
199     VarTestResult LTest;
200     VarTestResult RTest;
201   };
202 
203   union {
204     ConsumedState State;
205     VarTestResult Test;
206     const VarDecl *Var;
207     BinTestTy BinTest;
208   };
209 
210   QualType TempType;
211 
212 public:
213   PropagationInfo() : InfoType(IT_None) {}
214 
215   PropagationInfo(const VarTestResult &Test) : InfoType(IT_Test), Test(Test) {}
216   PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
217     : InfoType(IT_Test) {
218 
219     Test.Var      = Var;
220     Test.TestsFor = TestsFor;
221   }
222 
223   PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
224                   const VarTestResult &LTest, const VarTestResult &RTest)
225     : InfoType(IT_BinTest) {
226 
227     BinTest.Source  = Source;
228     BinTest.EOp     = EOp;
229     BinTest.LTest   = LTest;
230     BinTest.RTest   = RTest;
231   }
232 
233   PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
234                   const VarDecl *LVar, ConsumedState LTestsFor,
235                   const VarDecl *RVar, ConsumedState RTestsFor)
236     : InfoType(IT_BinTest) {
237 
238     BinTest.Source         = Source;
239     BinTest.EOp            = EOp;
240     BinTest.LTest.Var      = LVar;
241     BinTest.LTest.TestsFor = LTestsFor;
242     BinTest.RTest.Var      = RVar;
243     BinTest.RTest.TestsFor = RTestsFor;
244   }
245 
246   PropagationInfo(ConsumedState State, QualType TempType)
247     : InfoType(IT_State), State(State), TempType(TempType) {}
248 
249   PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
250 
251   const ConsumedState & getState() const {
252     assert(InfoType == IT_State);
253     return State;
254   }
255 
256   const QualType & getTempType() const {
257     assert(InfoType == IT_State);
258     return TempType;
259   }
260 
261   const VarTestResult & getTest() const {
262     assert(InfoType == IT_Test);
263     return Test;
264   }
265 
266   const VarTestResult & getLTest() const {
267     assert(InfoType == IT_BinTest);
268     return BinTest.LTest;
269   }
270 
271   const VarTestResult & getRTest() const {
272     assert(InfoType == IT_BinTest);
273     return BinTest.RTest;
274   }
275 
276   const VarDecl * getVar() const {
277     assert(InfoType == IT_Var);
278     return Var;
279   }
280 
281   EffectiveOp testEffectiveOp() const {
282     assert(InfoType == IT_BinTest);
283     return BinTest.EOp;
284   }
285 
286   const BinaryOperator * testSourceNode() const {
287     assert(InfoType == IT_BinTest);
288     return BinTest.Source;
289   }
290 
291   bool isValid()   const { return InfoType != IT_None;     }
292   bool isState()   const { return InfoType == IT_State;    }
293   bool isTest()    const { return InfoType == IT_Test;     }
294   bool isBinTest() const { return InfoType == IT_BinTest;  }
295   bool isVar()     const { return InfoType == IT_Var;      }
296 
297   PropagationInfo invertTest() const {
298     assert(InfoType == IT_Test || InfoType == IT_BinTest);
299 
300     if (InfoType == IT_Test) {
301       return PropagationInfo(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
302 
303     } else if (InfoType == IT_BinTest) {
304       return PropagationInfo(BinTest.Source,
305         BinTest.EOp == EO_And ? EO_Or : EO_And,
306         BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
307         BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
308     } else {
309       return PropagationInfo();
310     }
311   }
312 };
313 
314 class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
315 
316   typedef llvm::DenseMap<const Stmt *, PropagationInfo> MapType;
317   typedef std::pair<const Stmt *, PropagationInfo> PairType;
318   typedef MapType::iterator InfoEntry;
319   typedef MapType::const_iterator ConstInfoEntry;
320 
321   AnalysisDeclContext &AC;
322   ConsumedAnalyzer &Analyzer;
323   ConsumedStateMap *StateMap;
324   MapType PropagationMap;
325 
326   void checkCallability(const PropagationInfo &PInfo,
327                         const FunctionDecl *FunDecl,
328                         const CallExpr *Call);
329   void forwardInfo(const Stmt *From, const Stmt *To);
330   void handleTestingFunctionCall(const CallExpr *Call, const VarDecl *Var);
331   bool isLikeMoveAssignment(const CXXMethodDecl *MethodDecl);
332   void propagateReturnType(const Stmt *Call, const FunctionDecl *Fun,
333                            QualType ReturnType);
334 
335 public:
336 
337   void Visit(const Stmt *StmtNode);
338 
339   void VisitBinaryOperator(const BinaryOperator *BinOp);
340   void VisitCallExpr(const CallExpr *Call);
341   void VisitCastExpr(const CastExpr *Cast);
342   void VisitCXXConstructExpr(const CXXConstructExpr *Call);
343   void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
344   void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
345   void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
346   void VisitDeclStmt(const DeclStmt *DelcS);
347   void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
348   void VisitMemberExpr(const MemberExpr *MExpr);
349   void VisitParmVarDecl(const ParmVarDecl *Param);
350   void VisitReturnStmt(const ReturnStmt *Ret);
351   void VisitUnaryOperator(const UnaryOperator *UOp);
352   void VisitVarDecl(const VarDecl *Var);
353 
354   ConsumedStmtVisitor(AnalysisDeclContext &AC, ConsumedAnalyzer &Analyzer,
355                       ConsumedStateMap *StateMap)
356       : AC(AC), Analyzer(Analyzer), StateMap(StateMap) {}
357 
358   PropagationInfo getInfo(const Stmt *StmtNode) const {
359     ConstInfoEntry Entry = PropagationMap.find(StmtNode);
360 
361     if (Entry != PropagationMap.end())
362       return Entry->second;
363     else
364       return PropagationInfo();
365   }
366 
367   void reset(ConsumedStateMap *NewStateMap) {
368     StateMap = NewStateMap;
369   }
370 };
371 
372 void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
373                                            const FunctionDecl *FunDecl,
374                                            const CallExpr *Call) {
375 
376   if (!FunDecl->hasAttr<CallableWhenAttr>())
377     return;
378 
379   const CallableWhenAttr *CWAttr = FunDecl->getAttr<CallableWhenAttr>();
380 
381   if (PInfo.isVar()) {
382     const VarDecl *Var = PInfo.getVar();
383     ConsumedState VarState = StateMap->getState(Var);
384 
385     assert(VarState != CS_None && "Invalid state");
386 
387     if (isCallableInState(CWAttr, VarState))
388       return;
389 
390     Analyzer.WarningsHandler.warnUseInInvalidState(
391       FunDecl->getNameAsString(), Var->getNameAsString(),
392       stateToString(VarState), Call->getExprLoc());
393 
394   } else if (PInfo.isState()) {
395 
396     assert(PInfo.getState() != CS_None && "Invalid state");
397 
398     if (isCallableInState(CWAttr, PInfo.getState()))
399       return;
400 
401     Analyzer.WarningsHandler.warnUseOfTempInInvalidState(
402       FunDecl->getNameAsString(), stateToString(PInfo.getState()),
403       Call->getExprLoc());
404   }
405 }
406 
407 void ConsumedStmtVisitor::forwardInfo(const Stmt *From, const Stmt *To) {
408   InfoEntry Entry = PropagationMap.find(From);
409 
410   if (Entry != PropagationMap.end())
411     PropagationMap.insert(PairType(To, Entry->second));
412 }
413 
414 void ConsumedStmtVisitor::handleTestingFunctionCall(const CallExpr *Call,
415                                                     const VarDecl  *Var) {
416 
417   ConsumedState VarState = StateMap->getState(Var);
418 
419   if (VarState != CS_Unknown) {
420     SourceLocation CallLoc = Call->getExprLoc();
421 
422     if (!CallLoc.isMacroID())
423       Analyzer.WarningsHandler.warnUnnecessaryTest(Var->getNameAsString(),
424         stateToString(VarState), CallLoc);
425   }
426 
427   PropagationMap.insert(PairType(Call, PropagationInfo(Var, CS_Unconsumed)));
428 }
429 
430 bool ConsumedStmtVisitor::isLikeMoveAssignment(
431   const CXXMethodDecl *MethodDecl) {
432 
433   return MethodDecl->isMoveAssignmentOperator() ||
434          (MethodDecl->getOverloadedOperator() == OO_Equal &&
435           MethodDecl->getNumParams() == 1 &&
436           MethodDecl->getParamDecl(0)->getType()->isRValueReferenceType());
437 }
438 
439 void ConsumedStmtVisitor::propagateReturnType(const Stmt *Call,
440                                               const FunctionDecl *Fun,
441                                               QualType ReturnType) {
442   if (isConsumableType(ReturnType)) {
443 
444     ConsumedState ReturnState;
445 
446     if (Fun->hasAttr<ReturnTypestateAttr>())
447       ReturnState = mapReturnTypestateAttrState(
448         Fun->getAttr<ReturnTypestateAttr>());
449     else
450       ReturnState = mapConsumableAttrState(ReturnType);
451 
452     PropagationMap.insert(PairType(Call,
453       PropagationInfo(ReturnState, ReturnType)));
454   }
455 }
456 
457 void ConsumedStmtVisitor::Visit(const Stmt *StmtNode) {
458 
459   ConstStmtVisitor<ConsumedStmtVisitor>::Visit(StmtNode);
460 
461   for (Stmt::const_child_iterator CI = StmtNode->child_begin(),
462        CE = StmtNode->child_end(); CI != CE; ++CI) {
463 
464     PropagationMap.erase(*CI);
465   }
466 }
467 
468 void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
469   switch (BinOp->getOpcode()) {
470   case BO_LAnd:
471   case BO_LOr : {
472     InfoEntry LEntry = PropagationMap.find(BinOp->getLHS()),
473               REntry = PropagationMap.find(BinOp->getRHS());
474 
475     VarTestResult LTest, RTest;
476 
477     if (LEntry != PropagationMap.end() && LEntry->second.isTest()) {
478       LTest = LEntry->second.getTest();
479 
480     } else {
481       LTest.Var      = NULL;
482       LTest.TestsFor = CS_None;
483     }
484 
485     if (REntry != PropagationMap.end() && REntry->second.isTest()) {
486       RTest = REntry->second.getTest();
487 
488     } else {
489       RTest.Var      = NULL;
490       RTest.TestsFor = CS_None;
491     }
492 
493     if (!(LTest.Var == NULL && RTest.Var == NULL))
494       PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
495         static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
496 
497     break;
498   }
499 
500   case BO_PtrMemD:
501   case BO_PtrMemI:
502     forwardInfo(BinOp->getLHS(), BinOp);
503     break;
504 
505   default:
506     break;
507   }
508 }
509 
510 void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
511   if (const FunctionDecl *FunDecl =
512     dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee())) {
513 
514     // Special case for the std::move function.
515     // TODO: Make this more specific. (Deferred)
516     if (FunDecl->getNameAsString() == "move") {
517       InfoEntry Entry = PropagationMap.find(Call->getArg(0));
518 
519       if (Entry != PropagationMap.end()) {
520         PropagationMap.insert(PairType(Call, Entry->second));
521       }
522 
523       return;
524     }
525 
526     unsigned Offset = Call->getNumArgs() - FunDecl->getNumParams();
527 
528     for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
529       QualType ParamType = FunDecl->getParamDecl(Index - Offset)->getType();
530 
531       InfoEntry Entry = PropagationMap.find(Call->getArg(Index));
532 
533       if (Entry == PropagationMap.end() || !Entry->second.isVar()) {
534         continue;
535       }
536 
537       PropagationInfo PInfo = Entry->second;
538 
539       if (ParamType->isRValueReferenceType() ||
540           (ParamType->isLValueReferenceType() &&
541            !cast<LValueReferenceType>(*ParamType).isSpelledAsLValue())) {
542 
543         StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
544 
545       } else if (!(ParamType.isConstQualified() ||
546                    ((ParamType->isReferenceType() ||
547                      ParamType->isPointerType()) &&
548                     ParamType->getPointeeType().isConstQualified()))) {
549 
550         StateMap->setState(PInfo.getVar(), consumed::CS_Unknown);
551       }
552     }
553 
554     QualType RetType = FunDecl->getCallResultType();
555     if (RetType->isReferenceType())
556       RetType = RetType->getPointeeType();
557 
558     propagateReturnType(Call, FunDecl, RetType);
559   }
560 }
561 
562 void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
563   forwardInfo(Cast->getSubExpr(), Cast);
564 }
565 
566 void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
567   CXXConstructorDecl *Constructor = Call->getConstructor();
568 
569   ASTContext &CurrContext = AC.getASTContext();
570   QualType ThisType = Constructor->getThisType(CurrContext)->getPointeeType();
571 
572   if (isConsumableType(ThisType)) {
573     if (Constructor->isDefaultConstructor()) {
574 
575       PropagationMap.insert(PairType(Call,
576         PropagationInfo(consumed::CS_Consumed, ThisType)));
577 
578     } else if (Constructor->isMoveConstructor()) {
579 
580       PropagationInfo PInfo =
581         PropagationMap.find(Call->getArg(0))->second;
582 
583       if (PInfo.isVar()) {
584         const VarDecl* Var = PInfo.getVar();
585 
586         PropagationMap.insert(PairType(Call,
587           PropagationInfo(StateMap->getState(Var), ThisType)));
588 
589         StateMap->setState(Var, consumed::CS_Consumed);
590 
591       } else {
592         PropagationMap.insert(PairType(Call, PInfo));
593       }
594 
595     } else if (Constructor->isCopyConstructor()) {
596       MapType::iterator Entry = PropagationMap.find(Call->getArg(0));
597 
598       if (Entry != PropagationMap.end())
599         PropagationMap.insert(PairType(Call, Entry->second));
600 
601     } else {
602       propagateReturnType(Call, Constructor, ThisType);
603     }
604   }
605 }
606 
607 void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
608   const CXXMemberCallExpr *Call) {
609 
610   VisitCallExpr(Call);
611 
612   InfoEntry Entry = PropagationMap.find(Call->getCallee()->IgnoreParens());
613 
614   if (Entry != PropagationMap.end()) {
615     PropagationInfo PInfo = Entry->second;
616     const CXXMethodDecl *MethodDecl = Call->getMethodDecl();
617 
618     checkCallability(PInfo, MethodDecl, Call);
619 
620     if (PInfo.isVar()) {
621       if (isTestingFunction(MethodDecl))
622         handleTestingFunctionCall(Call, PInfo.getVar());
623       else if (MethodDecl->hasAttr<ConsumesAttr>())
624         StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
625     }
626   }
627 }
628 
629 void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
630   const CXXOperatorCallExpr *Call) {
631 
632   const FunctionDecl *FunDecl =
633     dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
634 
635   if (!FunDecl) return;
636 
637   if (isa<CXXMethodDecl>(FunDecl) &&
638       isLikeMoveAssignment(cast<CXXMethodDecl>(FunDecl))) {
639 
640     InfoEntry LEntry = PropagationMap.find(Call->getArg(0));
641     InfoEntry REntry = PropagationMap.find(Call->getArg(1));
642 
643     PropagationInfo LPInfo, RPInfo;
644 
645     if (LEntry != PropagationMap.end() &&
646         REntry != PropagationMap.end()) {
647 
648       LPInfo = LEntry->second;
649       RPInfo = REntry->second;
650 
651       if (LPInfo.isVar() && RPInfo.isVar()) {
652         StateMap->setState(LPInfo.getVar(),
653           StateMap->getState(RPInfo.getVar()));
654 
655         StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
656 
657         PropagationMap.insert(PairType(Call, LPInfo));
658 
659       } else if (LPInfo.isVar() && !RPInfo.isVar()) {
660         StateMap->setState(LPInfo.getVar(), RPInfo.getState());
661 
662         PropagationMap.insert(PairType(Call, LPInfo));
663 
664       } else if (!LPInfo.isVar() && RPInfo.isVar()) {
665         PropagationMap.insert(PairType(Call,
666           PropagationInfo(StateMap->getState(RPInfo.getVar()),
667                           LPInfo.getTempType())));
668 
669         StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
670 
671       } else {
672         PropagationMap.insert(PairType(Call, RPInfo));
673       }
674 
675     } else if (LEntry != PropagationMap.end() &&
676                REntry == PropagationMap.end()) {
677 
678       LPInfo = LEntry->second;
679 
680       if (LPInfo.isVar()) {
681         StateMap->setState(LPInfo.getVar(), consumed::CS_Unknown);
682 
683         PropagationMap.insert(PairType(Call, LPInfo));
684 
685       } else if (LPInfo.isState()) {
686         PropagationMap.insert(PairType(Call,
687           PropagationInfo(consumed::CS_Unknown, LPInfo.getTempType())));
688       }
689 
690     } else if (LEntry == PropagationMap.end() &&
691                REntry != PropagationMap.end()) {
692 
693       if (REntry->second.isVar())
694         StateMap->setState(REntry->second.getVar(), consumed::CS_Consumed);
695     }
696 
697   } else {
698 
699     VisitCallExpr(Call);
700 
701     InfoEntry Entry = PropagationMap.find(Call->getArg(0));
702 
703     if (Entry != PropagationMap.end()) {
704       PropagationInfo PInfo = Entry->second;
705 
706       checkCallability(PInfo, FunDecl, Call);
707 
708       if (PInfo.isVar()) {
709         if (isTestingFunction(FunDecl))
710           handleTestingFunctionCall(Call, PInfo.getVar());
711         else if (FunDecl->hasAttr<ConsumesAttr>())
712           StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
713       }
714     }
715   }
716 }
717 
718 void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
719   if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
720     if (StateMap->getState(Var) != consumed::CS_None)
721       PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
722 }
723 
724 void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
725   for (DeclStmt::const_decl_iterator DI = DeclS->decl_begin(),
726        DE = DeclS->decl_end(); DI != DE; ++DI) {
727 
728     if (isa<VarDecl>(*DI)) VisitVarDecl(cast<VarDecl>(*DI));
729   }
730 
731   if (DeclS->isSingleDecl())
732     if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
733       PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
734 }
735 
736 void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
737   const MaterializeTemporaryExpr *Temp) {
738 
739   InfoEntry Entry = PropagationMap.find(Temp->GetTemporaryExpr());
740 
741   if (Entry != PropagationMap.end())
742     PropagationMap.insert(PairType(Temp, Entry->second));
743 }
744 
745 void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
746   forwardInfo(MExpr->getBase(), MExpr);
747 }
748 
749 
750 void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
751   QualType ParamType = Param->getType();
752   ConsumedState ParamState = consumed::CS_None;
753 
754   if (!(ParamType->isPointerType() || ParamType->isReferenceType()) &&
755       isConsumableType(ParamType))
756     ParamState = mapConsumableAttrState(ParamType);
757   else if (ParamType->isReferenceType() &&
758            isConsumableType(ParamType->getPointeeType()))
759     ParamState = consumed::CS_Unknown;
760 
761   if (ParamState)
762     StateMap->setState(Param, ParamState);
763 }
764 
765 void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
766   if (ConsumedState ExpectedState = Analyzer.getExpectedReturnState()) {
767     InfoEntry Entry = PropagationMap.find(Ret->getRetValue());
768 
769     if (Entry != PropagationMap.end()) {
770       assert(Entry->second.isState() || Entry->second.isVar());
771 
772       ConsumedState RetState = Entry->second.isState() ?
773         Entry->second.getState() : StateMap->getState(Entry->second.getVar());
774 
775       if (RetState != ExpectedState)
776         Analyzer.WarningsHandler.warnReturnTypestateMismatch(
777           Ret->getReturnLoc(), stateToString(ExpectedState),
778           stateToString(RetState));
779     }
780   }
781 }
782 
783 void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
784   InfoEntry Entry = PropagationMap.find(UOp->getSubExpr()->IgnoreParens());
785   if (Entry == PropagationMap.end()) return;
786 
787   switch (UOp->getOpcode()) {
788   case UO_AddrOf:
789     PropagationMap.insert(PairType(UOp, Entry->second));
790     break;
791 
792   case UO_LNot:
793     if (Entry->second.isTest() || Entry->second.isBinTest())
794       PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
795     break;
796 
797   default:
798     break;
799   }
800 }
801 
802 // TODO: See if I need to check for reference types here.
803 void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
804   if (isConsumableType(Var->getType())) {
805     if (Var->hasInit()) {
806       PropagationInfo PInfo =
807         PropagationMap.find(Var->getInit())->second;
808 
809       StateMap->setState(Var, PInfo.isVar() ?
810         StateMap->getState(PInfo.getVar()) : PInfo.getState());
811 
812     } else {
813       StateMap->setState(Var, consumed::CS_Unknown);
814     }
815   }
816 }
817 }} // end clang::consumed::ConsumedStmtVisitor
818 
819 namespace clang {
820 namespace consumed {
821 
822 void splitVarStateForIf(const IfStmt * IfNode, const VarTestResult &Test,
823                         ConsumedStateMap *ThenStates,
824                         ConsumedStateMap *ElseStates) {
825 
826   ConsumedState VarState = ThenStates->getState(Test.Var);
827 
828   if (VarState == CS_Unknown) {
829     ThenStates->setState(Test.Var, Test.TestsFor);
830     ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
831 
832   } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
833     ThenStates->markUnreachable();
834 
835   } else if (VarState == Test.TestsFor) {
836     ElseStates->markUnreachable();
837   }
838 }
839 
840 void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
841   ConsumedStateMap *ThenStates, ConsumedStateMap *ElseStates) {
842 
843   const VarTestResult &LTest = PInfo.getLTest(),
844                       &RTest = PInfo.getRTest();
845 
846   ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
847                 RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;
848 
849   if (LTest.Var) {
850     if (PInfo.testEffectiveOp() == EO_And) {
851       if (LState == CS_Unknown) {
852         ThenStates->setState(LTest.Var, LTest.TestsFor);
853 
854       } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
855         ThenStates->markUnreachable();
856 
857       } else if (LState == LTest.TestsFor && isKnownState(RState)) {
858         if (RState == RTest.TestsFor)
859           ElseStates->markUnreachable();
860         else
861           ThenStates->markUnreachable();
862       }
863 
864     } else {
865       if (LState == CS_Unknown) {
866         ElseStates->setState(LTest.Var,
867                              invertConsumedUnconsumed(LTest.TestsFor));
868 
869       } else if (LState == LTest.TestsFor) {
870         ElseStates->markUnreachable();
871 
872       } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
873                  isKnownState(RState)) {
874 
875         if (RState == RTest.TestsFor)
876           ElseStates->markUnreachable();
877         else
878           ThenStates->markUnreachable();
879       }
880     }
881   }
882 
883   if (RTest.Var) {
884     if (PInfo.testEffectiveOp() == EO_And) {
885       if (RState == CS_Unknown)
886         ThenStates->setState(RTest.Var, RTest.TestsFor);
887       else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
888         ThenStates->markUnreachable();
889 
890     } else {
891       if (RState == CS_Unknown)
892         ElseStates->setState(RTest.Var,
893                              invertConsumedUnconsumed(RTest.TestsFor));
894       else if (RState == RTest.TestsFor)
895         ElseStates->markUnreachable();
896     }
897   }
898 }
899 
900 void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
901                                 ConsumedStateMap *StateMap,
902                                 bool &AlreadyOwned) {
903 
904   if (VisitedBlocks.alreadySet(Block)) return;
905 
906   ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
907 
908   if (Entry) {
909     Entry->intersect(StateMap);
910 
911   } else if (AlreadyOwned) {
912     StateMapsArray[Block->getBlockID()] = new ConsumedStateMap(*StateMap);
913 
914   } else {
915     StateMapsArray[Block->getBlockID()] = StateMap;
916     AlreadyOwned = true;
917   }
918 }
919 
920 void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
921                                 ConsumedStateMap *StateMap) {
922 
923   if (VisitedBlocks.alreadySet(Block)) {
924     delete StateMap;
925     return;
926   }
927 
928   ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
929 
930   if (Entry) {
931     Entry->intersect(StateMap);
932     delete StateMap;
933 
934   } else {
935     StateMapsArray[Block->getBlockID()] = StateMap;
936   }
937 }
938 
939 ConsumedStateMap* ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
940   return StateMapsArray[Block->getBlockID()];
941 }
942 
943 void ConsumedBlockInfo::markVisited(const CFGBlock *Block) {
944   VisitedBlocks.insert(Block);
945 }
946 
947 ConsumedState ConsumedStateMap::getState(const VarDecl *Var) {
948   MapType::const_iterator Entry = Map.find(Var);
949 
950   if (Entry != Map.end()) {
951     return Entry->second;
952 
953   } else {
954     return CS_None;
955   }
956 }
957 
958 void ConsumedStateMap::intersect(const ConsumedStateMap *Other) {
959   ConsumedState LocalState;
960 
961   if (this->From && this->From == Other->From && !Other->Reachable) {
962     this->markUnreachable();
963     return;
964   }
965 
966   for (MapType::const_iterator DMI = Other->Map.begin(),
967        DME = Other->Map.end(); DMI != DME; ++DMI) {
968 
969     LocalState = this->getState(DMI->first);
970 
971     if (LocalState == CS_None)
972       continue;
973 
974     if (LocalState != DMI->second)
975        Map[DMI->first] = CS_Unknown;
976   }
977 }
978 
979 void ConsumedStateMap::markUnreachable() {
980   this->Reachable = false;
981   Map.clear();
982 }
983 
984 void ConsumedStateMap::makeUnknown() {
985   for (MapType::const_iterator DMI = Map.begin(), DME = Map.end(); DMI != DME;
986        ++DMI) {
987 
988     Map[DMI->first] = CS_Unknown;
989   }
990 }
991 
992 void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
993   Map[Var] = State;
994 }
995 
996 void ConsumedStateMap::remove(const VarDecl *Var) {
997   Map.erase(Var);
998 }
999 
1000 void ConsumedAnalyzer::determineExpectedReturnState(AnalysisDeclContext &AC,
1001                                                     const FunctionDecl *D) {
1002   QualType ReturnType;
1003   if (const CXXConstructorDecl *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
1004     ASTContext &CurrContext = AC.getASTContext();
1005     ReturnType = Constructor->getThisType(CurrContext)->getPointeeType();
1006   } else
1007     ReturnType = D->getCallResultType();
1008 
1009   if (D->hasAttr<ReturnTypestateAttr>()) {
1010     const ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>();
1011 
1012     const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
1013     if (!RD || !RD->hasAttr<ConsumableAttr>()) {
1014       // FIXME: This should be removed when template instantiation propagates
1015       //        attributes at template specialization definition, not
1016       //        declaration. When it is removed the test needs to be enabled
1017       //        in SemaDeclAttr.cpp.
1018       WarningsHandler.warnReturnTypestateForUnconsumableType(
1019           RTSAttr->getLocation(), ReturnType.getAsString());
1020       ExpectedReturnState = CS_None;
1021     } else
1022       ExpectedReturnState = mapReturnTypestateAttrState(RTSAttr);
1023   } else if (isConsumableType(ReturnType))
1024     ExpectedReturnState = mapConsumableAttrState(ReturnType);
1025   else
1026     ExpectedReturnState = CS_None;
1027 }
1028 
1029 bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
1030                                   const ConsumedStmtVisitor &Visitor) {
1031 
1032   ConsumedStateMap *FalseStates = new ConsumedStateMap(*CurrStates);
1033   PropagationInfo PInfo;
1034 
1035   if (const IfStmt *IfNode =
1036     dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
1037 
1038     const Stmt *Cond = IfNode->getCond();
1039 
1040     PInfo = Visitor.getInfo(Cond);
1041     if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
1042       PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
1043 
1044     if (PInfo.isTest()) {
1045       CurrStates->setSource(Cond);
1046       FalseStates->setSource(Cond);
1047       splitVarStateForIf(IfNode, PInfo.getTest(), CurrStates, FalseStates);
1048 
1049     } else if (PInfo.isBinTest()) {
1050       CurrStates->setSource(PInfo.testSourceNode());
1051       FalseStates->setSource(PInfo.testSourceNode());
1052       splitVarStateForIfBinOp(PInfo, CurrStates, FalseStates);
1053 
1054     } else {
1055       delete FalseStates;
1056       return false;
1057     }
1058 
1059   } else if (const BinaryOperator *BinOp =
1060     dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {
1061 
1062     PInfo = Visitor.getInfo(BinOp->getLHS());
1063     if (!PInfo.isTest()) {
1064       if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
1065         PInfo = Visitor.getInfo(BinOp->getRHS());
1066 
1067         if (!PInfo.isTest()) {
1068           delete FalseStates;
1069           return false;
1070         }
1071 
1072       } else {
1073         delete FalseStates;
1074         return false;
1075       }
1076     }
1077 
1078     CurrStates->setSource(BinOp);
1079     FalseStates->setSource(BinOp);
1080 
1081     const VarTestResult &Test = PInfo.getTest();
1082     ConsumedState VarState = CurrStates->getState(Test.Var);
1083 
1084     if (BinOp->getOpcode() == BO_LAnd) {
1085       if (VarState == CS_Unknown)
1086         CurrStates->setState(Test.Var, Test.TestsFor);
1087       else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
1088         CurrStates->markUnreachable();
1089 
1090     } else if (BinOp->getOpcode() == BO_LOr) {
1091       if (VarState == CS_Unknown)
1092         FalseStates->setState(Test.Var,
1093                               invertConsumedUnconsumed(Test.TestsFor));
1094       else if (VarState == Test.TestsFor)
1095         FalseStates->markUnreachable();
1096     }
1097 
1098   } else {
1099     delete FalseStates;
1100     return false;
1101   }
1102 
1103   CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();
1104 
1105   if (*SI)
1106     BlockInfo.addInfo(*SI, CurrStates);
1107   else
1108     delete CurrStates;
1109 
1110   if (*++SI)
1111     BlockInfo.addInfo(*SI, FalseStates);
1112   else
1113     delete FalseStates;
1114 
1115   CurrStates = NULL;
1116   return true;
1117 }
1118 
1119 void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
1120   const FunctionDecl *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
1121   if (!D)
1122     return;
1123 
1124   CFG *CFGraph = AC.getCFG();
1125   if (!CFGraph)
1126     return;
1127 
1128   determineExpectedReturnState(AC, D);
1129 
1130   BlockInfo = ConsumedBlockInfo(CFGraph);
1131 
1132   PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
1133 
1134   CurrStates = new ConsumedStateMap();
1135   ConsumedStmtVisitor Visitor(AC, *this, CurrStates);
1136 
1137   // Add all trackable parameters to the state map.
1138   for (FunctionDecl::param_const_iterator PI = D->param_begin(),
1139        PE = D->param_end(); PI != PE; ++PI) {
1140     Visitor.VisitParmVarDecl(*PI);
1141   }
1142 
1143   // Visit all of the function's basic blocks.
1144   for (PostOrderCFGView::iterator I = SortedGraph->begin(),
1145        E = SortedGraph->end(); I != E; ++I) {
1146 
1147     const CFGBlock *CurrBlock = *I;
1148     BlockInfo.markVisited(CurrBlock);
1149 
1150     if (CurrStates == NULL)
1151       CurrStates = BlockInfo.getInfo(CurrBlock);
1152 
1153     if (!CurrStates) {
1154       continue;
1155 
1156     } else if (!CurrStates->isReachable()) {
1157       delete CurrStates;
1158       CurrStates = NULL;
1159       continue;
1160     }
1161 
1162     Visitor.reset(CurrStates);
1163 
1164     // Visit all of the basic block's statements.
1165     for (CFGBlock::const_iterator BI = CurrBlock->begin(),
1166          BE = CurrBlock->end(); BI != BE; ++BI) {
1167 
1168       switch (BI->getKind()) {
1169       case CFGElement::Statement:
1170         Visitor.Visit(BI->castAs<CFGStmt>().getStmt());
1171         break;
1172       case CFGElement::AutomaticObjectDtor:
1173         CurrStates->remove(BI->castAs<CFGAutomaticObjDtor>().getVarDecl());
1174       default:
1175         break;
1176       }
1177     }
1178 
1179     // TODO: Handle other forms of branching with precision, including while-
1180     //       and for-loops. (Deferred)
1181     if (!splitState(CurrBlock, Visitor)) {
1182       CurrStates->setSource(NULL);
1183 
1184       if (CurrBlock->succ_size() > 1) {
1185         CurrStates->makeUnknown();
1186 
1187         bool OwnershipTaken = false;
1188 
1189         for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
1190              SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1191 
1192           if (*SI) BlockInfo.addInfo(*SI, CurrStates, OwnershipTaken);
1193         }
1194 
1195         if (!OwnershipTaken)
1196           delete CurrStates;
1197 
1198         CurrStates = NULL;
1199 
1200       } else if (CurrBlock->succ_size() == 1 &&
1201                  (*CurrBlock->succ_begin())->pred_size() > 1) {
1202 
1203         BlockInfo.addInfo(*CurrBlock->succ_begin(), CurrStates);
1204         CurrStates = NULL;
1205       }
1206     }
1207   } // End of block iterator.
1208 
1209   // Delete the last existing state map.
1210   delete CurrStates;
1211 
1212   WarningsHandler.emitDiagnostics();
1213 }
1214 }} // end namespace clang::consumed
1215