xref: /freebsd-src/contrib/llvm-project/clang/lib/StaticAnalyzer/Core/LoopUnrolling.cpp (revision 19261079b74319502c6ffa1249920079f0f69a72)
1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// This file contains functions which are used to decide whether a loop is
/// worth unrolling. Moreover, these functions manage the stack of loops which
/// is tracked by the ProgramState.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "clang/ASTMatchers/ASTMatchers.h"
16 #include "clang/ASTMatchers/ASTMatchFinder.h"
17 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
18 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
19 #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
20 
21 using namespace clang;
22 using namespace ento;
23 using namespace clang::ast_matchers;
24 
// Upper bound on the (nested) step count a loop may have and still be
// completely unrolled; see updateLoopStack().
static constexpr int MAXIMUM_STEP_UNROLLED = 128;
26 
27 struct LoopState {
28 private:
29   enum Kind { Normal, Unrolled } K;
30   const Stmt *LoopStmt;
31   const LocationContext *LCtx;
32   unsigned maxStep;
33   LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
34       : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
35 
36 public:
37   static LoopState getNormal(const Stmt *S, const LocationContext *L,
38                              unsigned N) {
39     return LoopState(Normal, S, L, N);
40   }
41   static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
42                                unsigned N) {
43     return LoopState(Unrolled, S, L, N);
44   }
45   bool isUnrolled() const { return K == Unrolled; }
46   unsigned getMaxStep() const { return maxStep; }
47   const Stmt *getLoopStmt() const { return LoopStmt; }
48   const LocationContext *getLocationContext() const { return LCtx; }
49   bool operator==(const LoopState &X) const {
50     return K == X.K && LoopStmt == X.LoopStmt;
51   }
52   void Profile(llvm::FoldingSetNodeID &ID) const {
53     ID.AddInteger(K);
54     ID.AddPointer(LoopStmt);
55     ID.AddPointer(LCtx);
56     ID.AddInteger(maxStep);
57   }
58 };
59 
60 // The tracked stack of loops. The stack indicates that which loops the
61 // simulated element contained by. The loops are marked depending if we decided
62 // to unroll them.
63 // TODO: The loop stack should not need to be in the program state since it is
64 // lexical in nature. Instead, the stack of loops should be tracked in the
65 // LocationContext.
66 REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
67 
68 namespace clang {
69 namespace ento {
70 
71 static bool isLoopStmt(const Stmt *S) {
72   return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S));
73 }
74 
75 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
76   auto LS = State->get<LoopStack>();
77   if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
78     State = State->set<LoopStack>(LS.getTail());
79   return State;
80 }
81 
82 static internal::Matcher<Stmt> simpleCondition(StringRef BindName) {
83   return binaryOperator(anyOf(hasOperatorName("<"), hasOperatorName(">"),
84                               hasOperatorName("<="), hasOperatorName(">="),
85                               hasOperatorName("!=")),
86                         hasEitherOperand(ignoringParenImpCasts(declRefExpr(
87                             to(varDecl(hasType(isInteger())).bind(BindName))))),
88                         hasEitherOperand(ignoringParenImpCasts(
89                             integerLiteral().bind("boundNum"))))
90       .bind("conditionOperator");
91 }
92 
93 static internal::Matcher<Stmt>
94 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
95   return anyOf(
96       unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
97                     hasUnaryOperand(ignoringParenImpCasts(
98                         declRefExpr(to(varDecl(VarNodeMatcher)))))),
99       binaryOperator(isAssignmentOperator(),
100                      hasLHS(ignoringParenImpCasts(
101                          declRefExpr(to(varDecl(VarNodeMatcher)))))));
102 }
103 
104 static internal::Matcher<Stmt>
105 callByRef(internal::Matcher<Decl> VarNodeMatcher) {
106   return callExpr(forEachArgumentWithParam(
107       declRefExpr(to(varDecl(VarNodeMatcher))),
108       parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
109 }
110 
111 static internal::Matcher<Stmt>
112 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
113   return declStmt(hasDescendant(varDecl(
114       allOf(hasType(referenceType()),
115             hasInitializer(anyOf(
116                 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
117                 declRefExpr(to(varDecl(VarNodeMatcher)))))))));
118 }
119 
120 static internal::Matcher<Stmt>
121 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
122   return unaryOperator(
123       hasOperatorName("&"),
124       hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
125 }
126 
127 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
128   return hasDescendant(stmt(
129       anyOf(gotoStmt(), switchStmt(), returnStmt(),
130             // Escaping and not known mutation of the loop counter is handled
131             // by exclusion of assigning and address-of operators and
132             // pass-by-ref function calls on the loop counter from the body.
133             changeIntBoundNode(equalsBoundNode(std::string(NodeName))),
134             callByRef(equalsBoundNode(std::string(NodeName))),
135             getAddrTo(equalsBoundNode(std::string(NodeName))),
136             assignedToRef(equalsBoundNode(std::string(NodeName))))));
137 }
138 
139 static internal::Matcher<Stmt> forLoopMatcher() {
140   return forStmt(
141              hasCondition(simpleCondition("initVarName")),
142              // Initialization should match the form: 'int i = 6' or 'i = 42'.
143              hasLoopInit(
144                  anyOf(declStmt(hasSingleDecl(
145                            varDecl(allOf(hasInitializer(ignoringParenImpCasts(
146                                              integerLiteral().bind("initNum"))),
147                                          equalsBoundNode("initVarName"))))),
148                        binaryOperator(hasLHS(declRefExpr(to(varDecl(
149                                           equalsBoundNode("initVarName"))))),
150                                       hasRHS(ignoringParenImpCasts(
151                                           integerLiteral().bind("initNum")))))),
152              // Incrementation should be a simple increment or decrement
153              // operator call.
154              hasIncrement(unaryOperator(
155                  anyOf(hasOperatorName("++"), hasOperatorName("--")),
156                  hasUnaryOperand(declRefExpr(
157                      to(varDecl(allOf(equalsBoundNode("initVarName"),
158                                       hasType(isInteger())))))))),
159              unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop");
160 }
161 
162 static bool isPossiblyEscaped(const VarDecl *VD, ExplodedNode *N) {
163   // Global variables assumed as escaped variables.
164   if (VD->hasGlobalStorage())
165     return true;
166 
167   const bool isParm = isa<ParmVarDecl>(VD);
168   // Reference parameters are assumed as escaped variables.
169   if (isParm && VD->getType()->isReferenceType())
170     return true;
171 
172   while (!N->pred_empty()) {
173     // FIXME: getStmtForDiagnostics() does nasty things in order to provide
174     // a valid statement for body farms, do we need this behavior here?
175     const Stmt *S = N->getStmtForDiagnostics();
176     if (!S) {
177       N = N->getFirstPred();
178       continue;
179     }
180 
181     if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
182       for (const Decl *D : DS->decls()) {
183         // Once we reach the declaration of the VD we can return.
184         if (D->getCanonicalDecl() == VD)
185           return false;
186       }
187     }
188     // Check the usage of the pass-by-ref function calls and adress-of operator
189     // on VD and reference initialized by VD.
190     ASTContext &ASTCtx =
191         N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
192     auto Match =
193         match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
194                          assignedToRef(equalsNode(VD)))),
195               *S, ASTCtx);
196     if (!Match.empty())
197       return true;
198 
199     N = N->getFirstPred();
200   }
201 
202   // Parameter declaration will not be found.
203   if (isParm)
204     return false;
205 
206   llvm_unreachable("Reached root without finding the declaration of VD");
207 }
208 
209 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
210                             ExplodedNode *Pred, unsigned &maxStep) {
211 
212   if (!isLoopStmt(LoopStmt))
213     return false;
214 
215   // TODO: Match the cases where the bound is not a concrete literal but an
216   // integer with known value
217   auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
218   if (Matches.empty())
219     return false;
220 
221   auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName");
222   llvm::APInt BoundNum =
223       Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue();
224   llvm::APInt InitNum =
225       Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue();
226   auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator");
227   if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
228     InitNum = InitNum.zextOrSelf(BoundNum.getBitWidth());
229     BoundNum = BoundNum.zextOrSelf(InitNum.getBitWidth());
230   }
231 
232   if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
233     maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
234   else
235     maxStep = (BoundNum - InitNum).abs().getZExtValue();
236 
237   // Check if the counter of the loop is not escaped before.
238   return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred);
239 }
240 
241 bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
242   const Stmt *S = nullptr;
243   while (!N->pred_empty()) {
244     if (N->succ_size() > 1)
245       return true;
246 
247     ProgramPoint P = N->getLocation();
248     if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
249       S = BE->getBlock()->getTerminatorStmt();
250 
251     if (S == LoopStmt)
252       return false;
253 
254     N = N->getFirstPred();
255   }
256 
257   llvm_unreachable("Reached root without encountering the previous step");
258 }
259 
260 // updateLoopStack is called on every basic block, therefore it needs to be fast
261 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
262                                 ExplodedNode *Pred, unsigned maxVisitOnPath) {
263   auto State = Pred->getState();
264   auto LCtx = Pred->getLocationContext();
265 
266   if (!isLoopStmt(LoopStmt))
267     return State;
268 
269   auto LS = State->get<LoopStack>();
270   if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
271       LCtx == LS.getHead().getLocationContext()) {
272     if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
273       State = State->set<LoopStack>(LS.getTail());
274       State = State->add<LoopStack>(
275           LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
276     }
277     return State;
278   }
279   unsigned maxStep;
280   if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
281     State = State->add<LoopStack>(
282         LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
283     return State;
284   }
285 
286   unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
287 
288   unsigned innerMaxStep = maxStep * outerStep;
289   if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
290     State = State->add<LoopStack>(
291         LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
292   else
293     State = State->add<LoopStack>(
294         LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep));
295   return State;
296 }
297 
298 bool isUnrolledState(ProgramStateRef State) {
299   auto LS = State->get<LoopStack>();
300   if (LS.isEmpty() || !LS.getHead().isUnrolled())
301     return false;
302   return true;
303 }
304 }
305 }
306