1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// This file contains functions which are used to decide if a loop worth to be 10 /// unrolled. Moreover, these functions manages the stack of loop which is 11 /// tracked by the ProgramState. 12 /// 13 //===----------------------------------------------------------------------===// 14 15 #include "clang/ASTMatchers/ASTMatchers.h" 16 #include "clang/ASTMatchers/ASTMatchFinder.h" 17 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h" 18 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h" 19 #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h" 20 #include <optional> 21 22 using namespace clang; 23 using namespace ento; 24 using namespace clang::ast_matchers; 25 26 static const int MAXIMUM_STEP_UNROLLED = 128; 27 28 namespace { 29 struct LoopState { 30 private: 31 enum Kind { Normal, Unrolled } K; 32 const Stmt *LoopStmt; 33 const LocationContext *LCtx; 34 unsigned maxStep; 35 LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N) 36 : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {} 37 38 public: 39 static LoopState getNormal(const Stmt *S, const LocationContext *L, 40 unsigned N) { 41 return LoopState(Normal, S, L, N); 42 } 43 static LoopState getUnrolled(const Stmt *S, const LocationContext *L, 44 unsigned N) { 45 return LoopState(Unrolled, S, L, N); 46 } 47 bool isUnrolled() const { return K == Unrolled; } 48 unsigned getMaxStep() const { return maxStep; } 49 const Stmt *getLoopStmt() const { return LoopStmt; } 50 const LocationContext *getLocationContext() const { return LCtx; } 51 bool operator==(const LoopState &X) const { 52 return K == X.K && LoopStmt == X.LoopStmt; 53 } 54 void Profile(llvm::FoldingSetNodeID &ID) const { 55 ID.AddInteger(K); 56 ID.AddPointer(LoopStmt); 57 ID.AddPointer(LCtx); 58 ID.AddInteger(maxStep); 59 } 60 }; 61 } // namespace 62 63 // The tracked stack of loops. The stack indicates that which loops the 64 // simulated element contained by. The loops are marked depending if we decided 65 // to unroll them. 66 // TODO: The loop stack should not need to be in the program state since it is 67 // lexical in nature. Instead, the stack of loops should be tracked in the 68 // LocationContext. 69 REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState) 70 71 namespace clang { 72 namespace ento { 73 74 static bool isLoopStmt(const Stmt *S) { 75 return isa_and_nonnull<ForStmt, WhileStmt, DoStmt>(S); 76 } 77 78 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) { 79 auto LS = State->get<LoopStack>(); 80 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt) 81 State = State->set<LoopStack>(LS.getTail()); 82 return State; 83 } 84 85 static internal::Matcher<Stmt> simpleCondition(StringRef BindName, 86 StringRef RefName) { 87 return binaryOperator( 88 anyOf(hasOperatorName("<"), hasOperatorName(">"), 89 hasOperatorName("<="), hasOperatorName(">="), 90 hasOperatorName("!=")), 91 hasEitherOperand(ignoringParenImpCasts( 92 declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName))) 93 .bind(RefName))), 94 hasEitherOperand( 95 ignoringParenImpCasts(integerLiteral().bind("boundNum")))) 96 .bind("conditionOperator"); 97 } 98 99 static internal::Matcher<Stmt> 100 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) { 101 return anyOf( 102 unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")), 103 hasUnaryOperand(ignoringParenImpCasts( 104 declRefExpr(to(varDecl(VarNodeMatcher)))))), 105 binaryOperator(isAssignmentOperator(), 106 hasLHS(ignoringParenImpCasts( 107 declRefExpr(to(varDecl(VarNodeMatcher))))))); 108 } 109 110 static internal::Matcher<Stmt> 111 callByRef(internal::Matcher<Decl> VarNodeMatcher) { 112 return callExpr(forEachArgumentWithParam( 113 declRefExpr(to(varDecl(VarNodeMatcher))), 114 parmVarDecl(hasType(references(qualType(unless(isConstQualified()))))))); 115 } 116 117 static internal::Matcher<Stmt> 118 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) { 119 return declStmt(hasDescendant(varDecl( 120 allOf(hasType(referenceType()), 121 hasInitializer(anyOf( 122 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))), 123 declRefExpr(to(varDecl(VarNodeMatcher))))))))); 124 } 125 126 static internal::Matcher<Stmt> 127 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) { 128 return unaryOperator( 129 hasOperatorName("&"), 130 hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher)))); 131 } 132 133 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) { 134 return hasDescendant(stmt( 135 anyOf(gotoStmt(), switchStmt(), returnStmt(), 136 // Escaping and not known mutation of the loop counter is handled 137 // by exclusion of assigning and address-of operators and 138 // pass-by-ref function calls on the loop counter from the body. 139 changeIntBoundNode(equalsBoundNode(std::string(NodeName))), 140 callByRef(equalsBoundNode(std::string(NodeName))), 141 getAddrTo(equalsBoundNode(std::string(NodeName))), 142 assignedToRef(equalsBoundNode(std::string(NodeName)))))); 143 } 144 145 static internal::Matcher<Stmt> forLoopMatcher() { 146 return forStmt( 147 hasCondition(simpleCondition("initVarName", "initVarRef")), 148 // Initialization should match the form: 'int i = 6' or 'i = 42'. 149 hasLoopInit( 150 anyOf(declStmt(hasSingleDecl( 151 varDecl(allOf(hasInitializer(ignoringParenImpCasts( 152 integerLiteral().bind("initNum"))), 153 equalsBoundNode("initVarName"))))), 154 binaryOperator(hasLHS(declRefExpr(to(varDecl( 155 equalsBoundNode("initVarName"))))), 156 hasRHS(ignoringParenImpCasts( 157 integerLiteral().bind("initNum")))))), 158 // Incrementation should be a simple increment or decrement 159 // operator call. 160 hasIncrement(unaryOperator( 161 anyOf(hasOperatorName("++"), hasOperatorName("--")), 162 hasUnaryOperand(declRefExpr( 163 to(varDecl(allOf(equalsBoundNode("initVarName"), 164 hasType(isInteger())))))))), 165 unless(hasBody(hasSuspiciousStmt("initVarName")))) 166 .bind("forLoop"); 167 } 168 169 static bool isCapturedByReference(ExplodedNode *N, const DeclRefExpr *DR) { 170 171 // Get the lambda CXXRecordDecl 172 assert(DR->refersToEnclosingVariableOrCapture()); 173 const LocationContext *LocCtxt = N->getLocationContext(); 174 const Decl *D = LocCtxt->getDecl(); 175 const auto *MD = cast<CXXMethodDecl>(D); 176 assert(MD && MD->getParent()->isLambda() && 177 "Captured variable should only be seen while evaluating a lambda"); 178 const CXXRecordDecl *LambdaCXXRec = MD->getParent(); 179 180 // Lookup the fields of the lambda 181 llvm::DenseMap<const ValueDecl *, FieldDecl *> LambdaCaptureFields; 182 FieldDecl *LambdaThisCaptureField; 183 LambdaCXXRec->getCaptureFields(LambdaCaptureFields, LambdaThisCaptureField); 184 185 // Check if the counter is captured by reference 186 const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl()); 187 assert(VD); 188 const FieldDecl *FD = LambdaCaptureFields[VD]; 189 assert(FD && "Captured variable without a corresponding field"); 190 return FD->getType()->isReferenceType(); 191 } 192 193 static bool isFoundInStmt(const Stmt *S, const VarDecl *VD) { 194 if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) { 195 for (const Decl *D : DS->decls()) { 196 // Once we reach the declaration of the VD we can return. 197 if (D->getCanonicalDecl() == VD) 198 return true; 199 } 200 } 201 return false; 202 } 203 204 // A loop counter is considered escaped if: 205 // case 1: It is a global variable. 206 // case 2: It is a reference parameter or a reference capture. 207 // case 3: It is assigned to a non-const reference variable or parameter. 208 // case 4: Has its address taken. 209 static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR) { 210 const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl()); 211 assert(VD); 212 // Case 1: 213 if (VD->hasGlobalStorage()) 214 return true; 215 216 const bool IsRefParamOrCapture = 217 isa<ParmVarDecl>(VD) || DR->refersToEnclosingVariableOrCapture(); 218 // Case 2: 219 if ((DR->refersToEnclosingVariableOrCapture() && 220 isCapturedByReference(N, DR)) || 221 (IsRefParamOrCapture && VD->getType()->isReferenceType())) 222 return true; 223 224 while (!N->pred_empty()) { 225 // FIXME: getStmtForDiagnostics() does nasty things in order to provide 226 // a valid statement for body farms, do we need this behavior here? 227 const Stmt *S = N->getStmtForDiagnostics(); 228 if (!S) { 229 N = N->getFirstPred(); 230 continue; 231 } 232 233 if (isFoundInStmt(S, VD)) { 234 return false; 235 } 236 237 if (const auto *SS = dyn_cast<SwitchStmt>(S)) { 238 if (const auto *CST = dyn_cast<CompoundStmt>(SS->getBody())) { 239 for (const Stmt *CB : CST->body()) { 240 if (isFoundInStmt(CB, VD)) 241 return false; 242 } 243 } 244 } 245 246 // Check the usage of the pass-by-ref function calls and adress-of operator 247 // on VD and reference initialized by VD. 248 ASTContext &ASTCtx = 249 N->getLocationContext()->getAnalysisDeclContext()->getASTContext(); 250 // Case 3 and 4: 251 auto Match = 252 match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)), 253 assignedToRef(equalsNode(VD)))), 254 *S, ASTCtx); 255 if (!Match.empty()) 256 return true; 257 258 N = N->getFirstPred(); 259 } 260 261 // Reference parameter and reference capture will not be found. 262 if (IsRefParamOrCapture) 263 return false; 264 265 llvm_unreachable("Reached root without finding the declaration of VD"); 266 } 267 268 static bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx, 269 ExplodedNode *Pred, unsigned &maxStep) { 270 271 if (!isLoopStmt(LoopStmt)) 272 return false; 273 274 // TODO: Match the cases where the bound is not a concrete literal but an 275 // integer with known value 276 auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx); 277 if (Matches.empty()) 278 return false; 279 280 const auto *CounterVarRef = Matches[0].getNodeAs<DeclRefExpr>("initVarRef"); 281 llvm::APInt BoundNum = 282 Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue(); 283 llvm::APInt InitNum = 284 Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue(); 285 auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator"); 286 unsigned MaxWidth = std::max(InitNum.getBitWidth(), BoundNum.getBitWidth()); 287 288 InitNum = InitNum.zext(MaxWidth); 289 BoundNum = BoundNum.zext(MaxWidth); 290 291 if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE) 292 maxStep = (BoundNum - InitNum + 1).abs().getZExtValue(); 293 else 294 maxStep = (BoundNum - InitNum).abs().getZExtValue(); 295 296 // Check if the counter of the loop is not escaped before. 297 return !isPossiblyEscaped(Pred, CounterVarRef); 298 } 299 300 static bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) { 301 const Stmt *S = nullptr; 302 while (!N->pred_empty()) { 303 if (N->succ_size() > 1) 304 return true; 305 306 ProgramPoint P = N->getLocation(); 307 if (std::optional<BlockEntrance> BE = P.getAs<BlockEntrance>()) 308 S = BE->getBlock()->getTerminatorStmt(); 309 310 if (S == LoopStmt) 311 return false; 312 313 N = N->getFirstPred(); 314 } 315 316 llvm_unreachable("Reached root without encountering the previous step"); 317 } 318 319 // updateLoopStack is called on every basic block, therefore it needs to be fast 320 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx, 321 ExplodedNode *Pred, unsigned maxVisitOnPath) { 322 auto State = Pred->getState(); 323 auto LCtx = Pred->getLocationContext(); 324 325 if (!isLoopStmt(LoopStmt)) 326 return State; 327 328 auto LS = State->get<LoopStack>(); 329 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() && 330 LCtx == LS.getHead().getLocationContext()) { 331 if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) { 332 State = State->set<LoopStack>(LS.getTail()); 333 State = State->add<LoopStack>( 334 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); 335 } 336 return State; 337 } 338 unsigned maxStep; 339 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) { 340 State = State->add<LoopStack>( 341 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); 342 return State; 343 } 344 345 unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep()); 346 347 unsigned innerMaxStep = maxStep * outerStep; 348 if (innerMaxStep > MAXIMUM_STEP_UNROLLED) 349 State = State->add<LoopStack>( 350 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); 351 else 352 State = State->add<LoopStack>( 353 LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep)); 354 return State; 355 } 356 357 bool isUnrolledState(ProgramStateRef State) { 358 auto LS = State->get<LoopStack>(); 359 if (LS.isEmpty() || !LS.getHead().isUnrolled()) 360 return false; 361 return true; 362 } 363 } 364 } 365