//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library.  First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle.  We only create one SCEV of a particular shape, so
// pointer comparisons for equality are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression.  These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
using namespace llvm;

#define DEBUG_TYPE "scalar-evolution"

STATISTIC(NumArrayLenItCounts,
          "Number of trip counts computed with array length");
STATISTIC(NumTripCountsComputed,
          "Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
          "Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");

static cl::opt<unsigned>
MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                        cl::desc("Maximum number of iterations SCEV will "
                                 "symbolically execute a constant-derived "
                                 "loop"),
                        cl::init(100));

// FIXME: Enable this with XDEBUG when the test suite is clean.
static cl::opt<bool>
VerifySCEV("verify-scev",
           cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));

INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution",
                      "Scalar Evolution Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution",
                    "Scalar Evolution Analysis", false, true)
char ScalarEvolution::ID = 0;

//===----------------------------------------------------------------------===//
//                           SCEV class definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//
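// For quick reference, these are the textual forms that SCEV::print below
// produces, shown with a hypothetical value %x and loop header %loop:
//
//   (trunc i64 %x to i32)     - a truncation
//   (zext i32 %x to i64)      - a zero extension
//   {0,+,4}<nuw><%loop>       - an add recurrence with a no-wrap flag
//   (%x + 4)<nsw>             - an n-ary expression with a no-wrap flag
//   (%x /u 4)                 - an unsigned division
//   sizeof(i64)               - a SCEVUnknown matching the sizeof idiom
//   ***COULDNOTCOMPUTE***     - the SCEVCouldNotCompute sentinel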

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void SCEV::dump() const {
  print(dbgs());
  dbgs() << '\n';
}
#endif

void SCEV::print(raw_ostream &OS) const {
  switch (static_cast<SCEVTypes>(getSCEVType())) {
  case scConstant:
    cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
    return;
  case scTruncate: {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
    const SCEV *Op = Trunc->getOperand();
    OS << "(trunc " << *Op->getType() << " " << *Op << " to "
       << *Trunc->getType() << ")";
    return;
  }
  case scZeroExtend: {
    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
    const SCEV *Op = ZExt->getOperand();
    OS << "(zext " << *Op->getType() << " " << *Op << " to "
       << *ZExt->getType() << ")";
    return;
  }
  case scSignExtend: {
    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
    const SCEV *Op = SExt->getOperand();
    OS << "(sext " << *Op->getType() << " " << *Op << " to "
       << *SExt->getType() << ")";
    return;
  }
  case scAddRecExpr: {
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
    OS << "{" << *AR->getOperand(0);
    for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
      OS << ",+," << *AR->getOperand(i);
    OS << "}<";
    if (AR->getNoWrapFlags(FlagNUW))
      OS << "nuw><";
    if (AR->getNoWrapFlags(FlagNSW))
      OS << "nsw><";
    if (AR->getNoWrapFlags(FlagNW) &&
        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
      OS << "nw><";
    AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ">";
    return;
  }
  case scAddExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr: {
    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
    const char *OpStr = nullptr;
    switch (NAry->getSCEVType()) {
    case scAddExpr: OpStr = " + "; break;
    case scMulExpr: OpStr = " * "; break;
    case scUMaxExpr: OpStr = " umax "; break;
    case scSMaxExpr: OpStr = " smax "; break;
    }
    OS << "(";
    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
         I != E; ++I) {
      OS << **I;
      if (std::next(I) != E)
        OS << OpStr;
    }
    OS << ")";
    switch (NAry->getSCEVType()) {
    case scAddExpr:
    case scMulExpr:
      if (NAry->getNoWrapFlags(FlagNUW))
        OS << "<nuw>";
      if (NAry->getNoWrapFlags(FlagNSW))
        OS << "<nsw>";
    }
    return;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
    OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
    return;
  }
  case scUnknown: {
    const SCEVUnknown *U = cast<SCEVUnknown>(this);
    Type *AllocTy;
    if (U->isSizeOf(AllocTy)) {
      OS << "sizeof(" << *AllocTy << ")";
      return;
    }
    if (U->isAlignOf(AllocTy)) {
      OS << "alignof(" << *AllocTy << ")";
      return;
    }

    Type *CTy;
    Constant *FieldNo;
    if (U->isOffsetOf(CTy, FieldNo)) {
      OS << "offsetof(" << *CTy << ", ";
      FieldNo->printAsOperand(OS, false);
      OS << ")";
      return;
    }

    // Otherwise just print it normally.
    U->getValue()->printAsOperand(OS, false);
    return;
  }
  case scCouldNotCompute:
    OS << "***COULDNOTCOMPUTE***";
    return;
  }
  llvm_unreachable("Unknown SCEV kind!");
}

Type *SCEV::getType() const {
  switch (static_cast<SCEVTypes>(getSCEVType())) {
  case scConstant:
    return cast<SCEVConstant>(this)->getType();
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
    return cast<SCEVCastExpr>(this)->getType();
  case scAddRecExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr:
    return cast<SCEVNAryExpr>(this)->getType();
  case scAddExpr:
    return cast<SCEVAddExpr>(this)->getType();
  case scUDivExpr:
    return cast<SCEVUDivExpr>(this)->getType();
  case scUnknown:
    return cast<SCEVUnknown>(this)->getType();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

bool SCEV::isZero() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isZero();
  return false;
}

bool SCEV::isOne() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isOne();
  return false;
}

bool SCEV::isAllOnesValue() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isAllOnesValue();
  return false;
}

/// isNonConstantNegative - Return true if the specified scev is negated, but
/// not a constant.
bool SCEV::isNonConstantNegative() const {
  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
  if (!Mul) return false;

  // If there is a constant factor, it will be first.
  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
  if (!SC) return false;

  // Return true if the value is negative; this matches things like (-42 * V).
  return SC->getValue()->getValue().isNegative();
}

SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}

const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
  return getConstant(ConstantInt::get(getContext(), Val));
}

const SCEV *
ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
  return getConstant(ConstantInt::get(ITy, V, isSigned));
}

SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
                           unsigned SCEVTy, const SCEV *op, Type *ty)
  : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}

SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
                                   const SCEV *op, Type *ty)
  : SCEVCastExpr(ID, scTruncate, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot truncate non-integer value!");
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
  : SCEVCastExpr(ID, scZeroExtend, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot zero extend non-integer value!");
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
  : SCEVCastExpr(ID, scSignExtend, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot sign extend non-integer value!");
}

void SCEVUnknown::deleted() {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Release the value.
  setValPtr(nullptr);
}

void SCEVUnknown::allUsesReplacedWith(Value *New) {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Update this SCEVUnknown to point to the new value. This is needed
  // because there may still be outstanding SCEVs which still point to
  // this SCEVUnknown.
  setValPtr(New);
}

bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue() &&
            CE->getNumOperands() == 2)
          if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
            if (CI->isOne()) {
              AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
                                 ->getElementType();
              return true;
            }

  return false;
}

bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue()) {
          Type *Ty =
            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          if (StructType *STy = dyn_cast<StructType>(Ty))
            if (!STy->isPacked() &&
                CE->getNumOperands() == 3 &&
                CE->getOperand(1)->isNullValue()) {
              if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
                if (CI->isOne() &&
                    STy->getNumElements() == 2 &&
                    STy->getElementType(0)->isIntegerTy(1)) {
                  AllocTy = STy->getElementType(1);
                  return true;
                }
            }
        }

  return false;
}

bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getNumOperands() == 3 &&
            CE->getOperand(0)->isNullValue() &&
            CE->getOperand(1)->isNullValue()) {
          Type *Ty =
            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          // Ignore vector types here so that ScalarEvolutionExpander doesn't
          // emit getelementptrs that index into vectors.
          if (Ty->isStructTy() || Ty->isArrayTy()) {
            CTy = Ty;
            FieldNo = CE->getOperand(2);
            return true;
          }
        }

  return false;
}

//===----------------------------------------------------------------------===//
//                               SCEV Utilities
//===----------------------------------------------------------------------===//

namespace {
  /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
  /// than the complexity of the RHS.  This comparator is used to canonicalize
  /// expressions.
  class SCEVComplexityCompare {
    const LoopInfo *const LI;
  public:
    explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}

    // Return true if LHS is less than RHS; return false otherwise.
    bool operator()(const SCEV *LHS, const SCEV *RHS) const {
      return compare(LHS, RHS) < 0;
    }

    // Return negative, zero, or positive, if LHS is less than, equal to, or
    // greater than RHS, respectively. A three-way result allows recursive
    // comparisons to be more efficient.
    int compare(const SCEV *LHS, const SCEV *RHS) const {
      // Fast-path: SCEVs are uniqued so we can do a quick equality check.
      if (LHS == RHS)
        return 0;

      // Primarily, sort the SCEVs by their getSCEVType().
      unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
      if (LType != RType)
        return (int)LType - (int)RType;

      // Aside from the getSCEVType() ordering, the particular ordering
      // isn't very important except that it's beneficial to be consistent,
      // so that (a + b) and (b + a) don't end up as different expressions.
      switch (static_cast<SCEVTypes>(LType)) {
      case scUnknown: {
        const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
        const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

        // Sort SCEVUnknown values with some loose heuristics. TODO: This is
        // not as complete as it could be.
        const Value *LV = LU->getValue(), *RV = RU->getValue();

        // Order pointer values after integer values. This helps SCEVExpander
        // form GEPs.
        bool LIsPointer = LV->getType()->isPointerTy(),
             RIsPointer = RV->getType()->isPointerTy();
        if (LIsPointer != RIsPointer)
          return (int)LIsPointer - (int)RIsPointer;

        // Compare getValueID values.
        unsigned LID = LV->getValueID(),
                 RID = RV->getValueID();
        if (LID != RID)
          return (int)LID - (int)RID;

        // Sort arguments by their position.
        if (const Argument *LA = dyn_cast<Argument>(LV)) {
          const Argument *RA = cast<Argument>(RV);
          unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
          return (int)LArgNo - (int)RArgNo;
        }

        // For instructions, compare their loop depth, and their operand
        // count.  This is pretty loose.
        if (const Instruction *LInst = dyn_cast<Instruction>(LV)) {
          const Instruction *RInst = cast<Instruction>(RV);

          // Compare loop depths.
          const BasicBlock *LParent = LInst->getParent(),
                           *RParent = RInst->getParent();
          if (LParent != RParent) {
            unsigned LDepth = LI->getLoopDepth(LParent),
                     RDepth = LI->getLoopDepth(RParent);
            if (LDepth != RDepth)
              return (int)LDepth - (int)RDepth;
          }

          // Compare the number of operands.
          unsigned LNumOps = LInst->getNumOperands(),
                   RNumOps = RInst->getNumOperands();
          return (int)LNumOps - (int)RNumOps;
        }

        return 0;
      }

      case scConstant: {
        const SCEVConstant *LC = cast<SCEVConstant>(LHS);
        const SCEVConstant *RC = cast<SCEVConstant>(RHS);

        // Compare constant values.
        const APInt &LA = LC->getValue()->getValue();
        const APInt &RA = RC->getValue()->getValue();
        unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
        if (LBitWidth != RBitWidth)
          return (int)LBitWidth - (int)RBitWidth;
        return LA.ult(RA) ? -1 : 1;
      }

      case scAddRecExpr: {
        const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
        const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);

        // Compare addrec loop depths.
        const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
        if (LLoop != RLoop) {
          unsigned LDepth = LLoop->getLoopDepth(),
                   RDepth = RLoop->getLoopDepth();
          if (LDepth != RDepth)
            return (int)LDepth - (int)RDepth;
        }

        // Addrec complexity grows with operand count.
        unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
        if (LNumOps != RNumOps)
          return (int)LNumOps - (int)RNumOps;

        // Lexicographically compare.
        for (unsigned i = 0; i != LNumOps; ++i) {
          long X = compare(LA->getOperand(i), RA->getOperand(i));
          if (X != 0)
            return X;
        }

        return 0;
      }

      case scAddExpr:
      case scMulExpr:
      case scSMaxExpr:
      case scUMaxExpr: {
        const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
        const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);

        // Lexicographically compare n-ary expressions.
        unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
        if (LNumOps != RNumOps)
          return (int)LNumOps - (int)RNumOps;

        // The operand counts match, so compare operands pairwise.
        for (unsigned i = 0; i != LNumOps; ++i) {
          long X = compare(LC->getOperand(i), RC->getOperand(i));
          if (X != 0)
            return X;
        }
        return 0;
      }

      case scUDivExpr: {
        const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
        const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);

        // Lexicographically compare udiv expressions.
        long X = compare(LC->getLHS(), RC->getLHS());
        if (X != 0)
          return X;
        return compare(LC->getRHS(), RC->getRHS());
      }

      case scTruncate:
      case scZeroExtend:
      case scSignExtend: {
        const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
        const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);

        // Compare cast expressions by operand.
        return compare(LC->getOperand(), RC->getOperand());
      }

      case scCouldNotCompute:
        llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
      }
      llvm_unreachable("Unknown SCEV kind!");
    }
  };
}

/// GroupByComplexity - Given a list of SCEV objects, order them by their
/// complexity, and group objects of the same complexity together by value.
/// When this routine is finished, we know that any duplicates in the vector are
/// consecutive and that complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine.  In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.
///
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI) {
  if (Ops.size() < 2) return;  // Noop
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
    if (SCEVComplexityCompare(LI)(RHS, LHS))
      std::swap(LHS, RHS);
    return;
  }

  // Do the rough sort by complexity.
  std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));

  // Now that we are sorted by complexity, group elements of the same
  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice.  Note that we take this approach because
  // we do not want to depend on the addresses of the objects we are grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}
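
// For example, given hypothetical SCEVs %a and %b, an operand list
// [(%a + %b), 4, (%a + %b)] is reordered to [4, (%a + %b), (%a + %b)]:
// the constant sorts first (scConstant has the smallest getSCEVType()
// value, which is what "complexity" means here), and the two identical
// add expressions become adjacent so that callers can find duplicates
// with a single linear scan.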

namespace {
struct FindSCEVSize {
  int Size;
  FindSCEVSize() : Size(0) {}

  bool follow(const SCEV *S) {
    ++Size;
    // Keep looking at all operands of S.
    return true;
  }
  bool isDone() const {
    return false;
  }
};
}

// Returns the size of the SCEV S.
static inline int sizeOfSCEV(const SCEV *S) {
  FindSCEVSize F;
  SCEVTraversal<FindSCEVSize> ST(F);
  ST.visitAll(S);
  return F.Size;
}

namespace {

struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
public:
  // Computes the Quotient and Remainder of the division of Numerator by
  // Denominator.
  static void divide(ScalarEvolution &SE, const SCEV *Numerator,
                     const SCEV *Denominator, const SCEV **Quotient,
                     const SCEV **Remainder) {
    assert(Numerator && Denominator && "Uninitialized SCEV");

    SCEVDivision D(SE, Numerator, Denominator);

    // Check for the trivial case here to avoid having to check for it in the
    // rest of the code.
    if (Numerator == Denominator) {
      *Quotient = D.One;
      *Remainder = D.Zero;
      return;
    }

    if (Numerator->isZero()) {
      *Quotient = D.Zero;
      *Remainder = D.Zero;
      return;
    }

    // Split the Denominator when it is a product.
    if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) {
      const SCEV *Q, *R;
      *Quotient = Numerator;
      for (const SCEV *Op : T->operands()) {
        divide(SE, *Quotient, Op, &Q, &R);
        *Quotient = Q;

        // Bail out when the Numerator is not divisible by one of the terms of
        // the Denominator.
        if (!R->isZero()) {
          *Quotient = D.Zero;
          *Remainder = Numerator;
          return;
        }
      }
      *Remainder = D.Zero;
      return;
    }

    D.visit(Numerator);
    *Quotient = D.Quotient;
    *Remainder = D.Remainder;
  }

  // Except in the trivial case described above, we do not know how to divide
  // Expr by Denominator, so the following visitors are intentionally empty.
  void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {}
  void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {}
  void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {}
  void visitUDivExpr(const SCEVUDivExpr *Numerator) {}
  void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {}
  void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {}
  void visitUnknown(const SCEVUnknown *Numerator) {}
  void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {}

  void visitConstant(const SCEVConstant *Numerator) {
    if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
      APInt NumeratorVal = Numerator->getValue()->getValue();
      APInt DenominatorVal = D->getValue()->getValue();
      uint32_t NumeratorBW = NumeratorVal.getBitWidth();
      uint32_t DenominatorBW = DenominatorVal.getBitWidth();

      if (NumeratorBW > DenominatorBW)
        DenominatorVal = DenominatorVal.sext(NumeratorBW);
      else if (NumeratorBW < DenominatorBW)
        NumeratorVal = NumeratorVal.sext(DenominatorBW);

      APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
      APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
      APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
      Quotient = SE.getConstant(QuotientVal);
      Remainder = SE.getConstant(RemainderVal);
      return;
    }
  }

  void visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
    const SCEV *StartQ, *StartR, *StepQ, *StepR;
    assert(Numerator->isAffine() && "Numerator should be affine");
    divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
    divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
    Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
                                Numerator->getNoWrapFlags());
    Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
                                 Numerator->getNoWrapFlags());
  }

  void visitAddExpr(const SCEVAddExpr *Numerator) {
    SmallVector<const SCEV *, 2> Qs, Rs;
    Type *Ty = Denominator->getType();

    for (const SCEV *Op : Numerator->operands()) {
      const SCEV *Q, *R;
      divide(SE, Op, Denominator, &Q, &R);

      // Bail out if types do not match.
      if (Ty != Q->getType() || Ty != R->getType()) {
        Quotient = Zero;
        Remainder = Numerator;
        return;
      }

      Qs.push_back(Q);
      Rs.push_back(R);
    }

    if (Qs.size() == 1) {
      Quotient = Qs[0];
      Remainder = Rs[0];
      return;
    }

    Quotient = SE.getAddExpr(Qs);
    Remainder = SE.getAddExpr(Rs);
  }
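
  // The division distributes over the operands of an add. For example, with
  // a hypothetical parameter %x, dividing ((4 * %x) + 8) by 4 divides each
  // operand separately: (4 * %x) gives quotient %x with remainder 0, and 8
  // gives quotient 2 with remainder 0, so the overall quotient is (%x + 2)
  // with remainder 0.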

  void visitMulExpr(const SCEVMulExpr *Numerator) {
    SmallVector<const SCEV *, 2> Qs;
    Type *Ty = Denominator->getType();

    bool FoundDenominatorTerm = false;
    for (const SCEV *Op : Numerator->operands()) {
      // Bail out if types do not match.
      if (Ty != Op->getType()) {
        Quotient = Zero;
        Remainder = Numerator;
        return;
      }

      if (FoundDenominatorTerm) {
        Qs.push_back(Op);
        continue;
      }

      // Check whether Denominator divides one of the product operands.
      const SCEV *Q, *R;
      divide(SE, Op, Denominator, &Q, &R);
      if (!R->isZero()) {
        Qs.push_back(Op);
        continue;
      }

      // Bail out if types do not match.
      if (Ty != Q->getType()) {
        Quotient = Zero;
        Remainder = Numerator;
        return;
      }

      FoundDenominatorTerm = true;
      Qs.push_back(Q);
    }

    if (FoundDenominatorTerm) {
      Remainder = Zero;
      if (Qs.size() == 1)
        Quotient = Qs[0];
      else
        Quotient = SE.getMulExpr(Qs);
      return;
    }

    if (!isa<SCEVUnknown>(Denominator)) {
      Quotient = Zero;
      Remainder = Numerator;
      return;
    }

    // The Remainder is obtained by replacing Denominator by 0 in Numerator.
    ValueToValueMap RewriteMap;
    RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
        cast<SCEVConstant>(Zero)->getValue();
    Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);

    if (Remainder->isZero()) {
      // The Quotient is obtained by replacing Denominator by 1 in Numerator.
      RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
          cast<SCEVConstant>(One)->getValue();
      Quotient =
          SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
      return;
    }

    // Quotient is (Numerator - Remainder) divided by Denominator.
    const SCEV *Q, *R;
    const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
    if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) {
      // This SCEV does not seem to simplify: fail the division here.
      Quotient = Zero;
      Remainder = Numerator;
      return;
    }
    divide(SE, Diff, Denominator, &Q, &R);
    assert(R == Zero &&
           "Denominator should evenly divide (Numerator - Remainder)");
    Quotient = Q;
  }

private:
  SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
               const SCEV *Denominator)
      : SE(S), Denominator(Denominator) {
    Zero = SE.getConstant(Denominator->getType(), 0);
    One = SE.getConstant(Denominator->getType(), 1);

    // By default, we don't know how to divide Expr by Denominator.
    // Providing the default here simplifies the rest of the code.
    Quotient = Zero;
    Remainder = Numerator;
  }

  ScalarEvolution &SE;
  const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
};

}
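
// A typical use of the divider, sketched with hypothetical SCEVs: dividing
// (8 * %x * %y) by the product (4 * %x) peels off one denominator factor at
// a time. Dividing by 4 matches the constant term (8 / 4 = 2), leaving
// (2 * %x * %y); dividing that by %x matches the unknown term, leaving the
// final quotient (2 * %y) with remainder 0:
//
//   const SCEV *Q, *R;
//   SCEVDivision::divide(SE, Numerator, Denominator, &Q, &R);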

//===----------------------------------------------------------------------===//
//                      Simple SCEV method implementations
//===----------------------------------------------------------------------===//

/// BinomialCoefficient - Compute BC(It, K).  The result has width W.
/// Assumes K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                       ScalarEvolution &SE,
                                       Type *ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose W is the bitwidth of the return value.  We must be prepared for
  // overflow.  Hence, we must assure that the result of our computation is
  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula.  However,
  // this formula can be implemented much more efficiently.  The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic.  To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse.  Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T.  The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula needs less than W + K bits.  Also, the first formula
  // requires a division step, whereas this formula only requires multiplies
  // and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway.  We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // register width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)
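  //
  // For example, for K = 4: K! = 24 = 2^3 * 3, so T = 3 and the odd factor
  // K! / 2^T is 3.  The product It * (It-1) * (It-2) * (It-3) is computed at
  // W + 3 bits, the division by 2^T is a right shift by 3, and after
  // truncating back to W bits the exact division by the odd factor 3 is a
  // multiplication by the multiplicative inverse of 3 modulo 2^W.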

  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    APInt Mult(W, i);
    unsigned TwoFactors = Mult.countTrailingZeros();
    T += TwoFactors;
    Mult = Mult.lshr(TwoFactors);
    OddFactorial *= Mult;
  }

  // We need at least W + T bits for the multiplication step
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt Mod = APInt::getSignedMinValue(W+1);
  APInt MultiplyFactor = OddFactorial.zext(W+1);
  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
  MultiplyFactor = MultiplyFactor.trunc(W);

  // Calculate the product, at width T+W
  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
                                                CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.

  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}

/// evaluateAtIteration - Return the value of this chain of recurrences at
/// the specified iteration number.  We can evaluate this recurrence by
/// multiplying each element in the chain by the binomial coefficient
/// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
///
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
  const SCEV *Result = getStart();
  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
  }
  return Result;
}
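
// For example, evaluating the quadratic recurrence {3,+,5,+,2} at iteration
// It expands to 3*BC(It,0) + 5*BC(It,1) + 2*BC(It,2), i.e.
// 3 + 5*It + 2*(It*(It-1)/2) = It^2 + 4*It + 3.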

//===----------------------------------------------------------------------===//
//                    SCEV Expression folder implementations
//===----------------------------------------------------------------------===//

const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
                                             Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty);

  // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
  // eliminate all the truncates.
  if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    bool hasTrunc = false;
    for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
      const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
      hasTrunc = isa<SCEVTruncateExpr>(S);
      Operands.push_back(S);
    }
    if (!hasTrunc)
      return getAddExpr(Operands);
    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
  }

  // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
  // eliminate all the truncates.
  if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    bool hasTrunc = false;
    for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
      const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
      hasTrunc = isa<SCEVTruncateExpr>(S);
      Operands.push_back(S);
    }
    if (!hasTrunc)
      return getMulExpr(Operands);
    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
      Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

// Get the limit of a recurrence such that incrementing by Step cannot cause
// signed overflow as long as the value of the recurrence within the loop does
// not exceed this limit before incrementing.
static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
                                                 ICmpInst::Predicate *Pred,
                                                 ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  if (SE->isKnownPositive(Step)) {
    *Pred = ICmpInst::ICMP_SLT;
    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
                           SE->getSignedRange(Step).getSignedMax());
  }
  if (SE->isKnownNegative(Step)) {
    *Pred = ICmpInst::ICMP_SGT;
    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
                           SE->getSignedRange(Step).getSignedMin());
  }
  return nullptr;
}

// Get the limit of a recurrence such that incrementing by Step cannot cause
// unsigned overflow as long as the value of the recurrence within the loop
// does not exceed this limit before incrementing.
static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
                                                   ICmpInst::Predicate *Pred,
                                                   ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  *Pred = ICmpInst::ICMP_ULT;

  return SE->getConstant(APInt::getMinValue(BitWidth) -
                         SE->getUnsignedRange(Step).getUnsignedMax());
}
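
// For example, with a hypothetical i8 recurrence whose Step is known
// positive with a signed maximum of 3, getSignedOverflowLimitForStep returns
// INT8_MIN - 3 (which wraps to 125) together with ICMP_SLT: any value that
// is signed-less-than 125 can be incremented by at most 3 without exceeding
// INT8_MAX (127), so the increment cannot sign-overflow.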

namespace {

struct ExtendOpTraitsBase {
  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *);
};

// Used to make code generic over signed and unsigned overflow.
template <typename ExtendOp> struct ExtendOpTraits {
  // Members present:
  //
  // static const SCEV::NoWrapFlags WrapType;
  //
  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
  //
  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
  //                                            ICmpInst::Predicate *Pred,
  //                                            ScalarEvolution *SE);
};

template <>
struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getSignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;

template <>
struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getUnsignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
}

// The recurrence AR has been shown to have no signed/unsigned wrap or
// something close to it. Typically, if we can prove NSW/NUW for AR, then we
// can just as easily prove NSW/NUW for its preincrement or postincrement
// sibling. This allows normalizing a sign/zero extended AddRec as such:
// {sext/zext(Step + Start),+,Step} => {Step + sext/zext(Start),+,Step}
// As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
// "sext/zext(PostIncAR)".
template <typename ExtendOpTy>
static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const Loop *L = AR->getLoop();
  const SCEV *Start = AR->getStart();
  const SCEV *Step = AR->getStepRecurrence(*SE);

  // Check for a simple looking step prior to loop entry.
  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
  if (!SA)
    return nullptr;

  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
  // subtraction is expensive. For this purpose, perform a quick and dirty
  // difference, by checking for Step in the operand list.
  SmallVector<const SCEV *, 4> DiffOps;
  for (const SCEV *Op : SA->operands())
    if (Op != Step)
      DiffOps.push_back(Op);

  if (DiffOps.size() == SA->getNumOperands())
    return nullptr;

  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
  // `Step`:

  // 1. NSW/NUW flags on the step increment.
  const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags());
  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));

  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
  // "S+X does not sign/unsign-overflow".
  //

  const SCEV *BECount = SE->getBackedgeTakenCount(L);
  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
      !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
    return PreStart;

  // 2. Direct overflow check on the step operation's expression.
  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
  const SCEV *OperandExtendedStart =
      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy),
                     (SE->*GetExtendExpr)(Step, WideTy));
  if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) {
    if (PreAR && AR->getNoWrapFlags(WrapType)) {
      // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
      // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
      // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`.  Cache this fact.
      const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(WrapType);
    }
    return PreStart;
  }

  // 3. Loop precondition.
  ICmpInst::Predicate Pred;
  const SCEV *OverflowLimit =
      ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);

  if (OverflowLimit &&
      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) {
    return PreStart;
  }
  return nullptr;
}
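
// For example, given a hypothetical recurrence {(4 + %x),+,4}<%loop> whose
// start expression contains its own step, getPreStartForExtend identifies
// PreStart = %x. When one of the three proofs above succeeds, the caller
// below can extend the start as (4 + zext(%x)) instead of zext((4 + %x)).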

// Get the normalized zero or sign extended expression for this AddRec's Start.
template <typename ExtendOpTy>
static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE) {
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE);
  if (!PreStart)
    return (SE->*GetExtendExpr)(AR->getStart(), Ty);

  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty),
                        (SE->*GetExtendExpr)(PreStart, Ty));
}

const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
                                               Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));

  // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scZeroExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // zext(trunc(x)) --> zext(x) or x or trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
    // It's possible the bits taken off by the truncate were all zero bits. If
    // so, we should be able to simplify this further.
    const SCEV *X = ST->getOperand();
    ConstantRange CR = getUnsignedRange(X);
    unsigned TruncBits = getTypeSizeInBits(ST->getType());
    unsigned NewBits = getTypeSizeInBits(Ty);
    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
            CR.zextOrTrunc(NewBits)))
      return getTruncateOrZeroExtend(X, Ty);
  }
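  // For example, if a hypothetical i32 value %x is known to have the
  // unsigned range [0, 200), then (zext i8 (trunc i32 %x to i8) to i32)
  // loses no bits: the truncated range round-trips through i8 unchanged,
  // so the check above folds the whole expression back to %x.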

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can zero extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->getNoWrapFlags(SCEV::FlagNUW))
        return getAddRecExpr(
            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
            getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());

      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.

        // Check whether the backedge-taken count can be losslessly cast to
        // the addrec's type.  The count is always unsigned.
        const SCEV *CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        const SCEV *RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
          // Check whether Start+Step*MaxBECount has no unsigned overflow.
          const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step);
          const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul), WideTy);
          const SCEV *WideStart = getZeroExtendExpr(Start, WideTy);
          const SCEV *WideMaxBECount =
            getZeroExtendExpr(CastedMaxBECount, WideTy);
          const SCEV *OperandExtendedAdd =
            getAddExpr(WideStart,
                       getMulExpr(WideMaxBECount,
                                  getZeroExtendExpr(Step, WideTy)));
          if (ZAdd == OperandExtendedAdd) {
            // Cache knowledge of AR NUW, which is propagated to this AddRec.
            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(
                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
                getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
          }
          // Similar to above, only this time treat the step value as signed.
          // This covers loops that count down.
          OperandExtendedAdd =
            getAddExpr(WideStart,
                       getMulExpr(WideMaxBECount,
                                  getSignExtendExpr(Step, WideTy)));
          if (ZAdd == OperandExtendedAdd) {
            // Cache knowledge of AR NW, which is propagated to this AddRec.
            // Negative step causes unsigned wrap, but it still can't self-wrap.
            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(
                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
                getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
          }
        }

        // If the backedge is guarded by a comparison with the pre-inc value
        // the addrec is safe. Also, if the entry is guarded by a comparison
        // with the start value and the backedge is guarded by a comparison
        // with the post-inc value, the addrec is safe.
        if (isKnownPositive(Step)) {
          const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
                                      getUnsignedRange(Step).getUnsignedMax());
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
              (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
                                           AR->getPostIncExpr(*this), N))) {
            // Cache knowledge of AR NUW, which is propagated to this AddRec.
            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(
                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
                getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
          }
        } else if (isKnownNegative(Step)) {
          const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
                                      getSignedRange(Step).getSignedMin());
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
              (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
                                           AR->getPostIncExpr(*this), N))) {
            // Cache knowledge of AR NW, which is propagated to this AddRec.
            // Negative step causes unsigned wrap, but it still can't self-wrap.
            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(
                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
                getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
          }
        }
      }
    }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                   Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
                                               Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));

  // sext(sext(x)) --> sext(x)
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getSignExtendExpr(SS->getOperand(), Ty);

  // sext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scSignExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // If the input value is provably positive, build a zext instead.
  if (isKnownNonNegative(Op))
    return getZeroExtendExpr(Op, Ty);

  // sext(trunc(x)) --> sext(x) or x or trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
    // It's possible the bits taken off by the truncate were all sign bits. If
    // so, we should be able to simplify this further.
    const SCEV *X = ST->getOperand();
    ConstantRange CR = getSignedRange(X);
    unsigned TruncBits = getTypeSizeInBits(ST->getType());
    unsigned NewBits = getTypeSizeInBits(Ty);
    if (CR.truncate(TruncBits).signExtend(NewBits).contains(
            CR.sextOrTrunc(NewBits)))
      return getTruncateOrSignExtend(X, Ty);
  }

  // sext(C1 + (C2 * x)) --> C1 + sext(C2 * x) if C1 < C2
  if (auto SA = dyn_cast<SCEVAddExpr>(Op)) {
    if (SA->getNumOperands() == 2) {
      auto SC1 = dyn_cast<SCEVConstant>(SA->getOperand(0));
      auto SMul = dyn_cast<SCEVMulExpr>(SA->getOperand(1));
      if (SMul && SC1) {
        if (auto SC2 = dyn_cast<SCEVConstant>(SMul->getOperand(0))) {
          const APInt &C1 = SC1->getValue()->getValue();
          const APInt &C2 = SC2->getValue()->getValue();
          if (C1.isStrictlyPositive() && C2.isStrictlyPositive() &&
              C2.ugt(C1) && C2.isPowerOf2())
            return getAddExpr(getSignExtendExpr(SC1, Ty),
                              getSignExtendExpr(SMul, Ty));
        }
      }
    }
  }
  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow in the old, smaller type, we can sign extend all of the
  // operands (often constants). This allows analysis of something like
  // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->getNoWrapFlags(SCEV::FlagNSW))
        return getAddRecExpr(
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
            getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW);

      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.

        // Check whether the backedge-taken count can be losslessly cast to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        const SCEV *RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
          // Check whether Start+Step*MaxBECount has no signed overflow.
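          // The proof works by redundancy: evaluate Start + Step*MaxBECount
          // in the narrow type and extend the result, then evaluate it again
          // with every operand pre-extended to twice the width. If the two
          // wide values agree, the narrow computation never overflowed.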
1591 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step); 1592 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul), WideTy); 1593 const SCEV *WideStart = getSignExtendExpr(Start, WideTy); 1594 const SCEV *WideMaxBECount = 1595 getZeroExtendExpr(CastedMaxBECount, WideTy); 1596 const SCEV *OperandExtendedAdd = 1597 getAddExpr(WideStart, 1598 getMulExpr(WideMaxBECount, 1599 getSignExtendExpr(Step, WideTy))); 1600 if (SAdd == OperandExtendedAdd) { 1601 // Cache knowledge of AR NSW, which is propagated to this AddRec. 1602 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); 1603 // Return the expression with the addrec on the outside. 1604 return getAddRecExpr( 1605 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), 1606 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); 1607 } 1608 // Similar to above, only this time treat the step value as unsigned. 1609 // This covers loops that count up with an unsigned step. 1610 OperandExtendedAdd = 1611 getAddExpr(WideStart, 1612 getMulExpr(WideMaxBECount, 1613 getZeroExtendExpr(Step, WideTy))); 1614 if (SAdd == OperandExtendedAdd) { 1615 // If AR wraps around then 1616 // 1617 // abs(Step) * MaxBECount > unsigned-max(AR->getType()) 1618 // => SAdd != OperandExtendedAdd 1619 // 1620 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=> 1621 // (SAdd == OperandExtendedAdd => AR is NW) 1622 1623 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); 1624 1625 // Return the expression with the addrec on the outside. 1626 return getAddRecExpr( 1627 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), 1628 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); 1629 } 1630 } 1631 1632 // If the backedge is guarded by a comparison with the pre-inc value 1633 // the addrec is safe. Also, if the entry is guarded by a comparison 1634 // with the start value and the backedge is guarded by a comparison 1635 // with the post-inc value, the addrec is safe. 1636 ICmpInst::Predicate Pred; 1637 const SCEV *OverflowLimit = 1638 getSignedOverflowLimitForStep(Step, &Pred, this); 1639 if (OverflowLimit && 1640 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) || 1641 (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) && 1642 isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this), 1643 OverflowLimit)))) { 1644 // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec. 1645 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); 1646 return getAddRecExpr( 1647 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), 1648 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); 1649 } 1650 } 1651 // If Start and Step are constants, check if we can apply this 1652 // transformation: 1653 // sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2 1654 auto SC1 = dyn_cast<SCEVConstant>(Start); 1655 auto SC2 = dyn_cast<SCEVConstant>(Step); 1656 if (SC1 && SC2) { 1657 const APInt &C1 = SC1->getValue()->getValue(); 1658 const APInt &C2 = SC2->getValue()->getValue(); 1659 if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && 1660 C2.isPowerOf2()) { 1661 Start = getSignExtendExpr(Start, Ty); 1662 const SCEV *NewAR = getAddRecExpr(getConstant(AR->getType(), 0), Step, 1663 L, AR->getNoWrapFlags()); 1664 return getAddExpr(Start, getSignExtendExpr(NewAR, Ty)); 1665 } 1666 } 1667 } 1668 1669 // The cast wasn't folded; create an explicit cast node. 1670 // Recompute the insert position, as it may have been invalidated. 
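  // (The recursive getters invoked above may themselves have interned new
  // SCEVs, which is what can invalidate the previously computed position.)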
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                   Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

/// getAnyExtendExpr - Return a SCEV for the given operand extended with
/// unspecified bits out to the given type.
///
const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
                                              Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Sign-extend negative constants.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    if (SC->getValue()->getValue().isNegative())
      return getSignExtendExpr(Op, Ty);

  // Peel off a truncate cast.
  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
    const SCEV *NewOp = T->getOperand();
    if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
      return getAnyExtendExpr(NewOp, Ty);
    return getTruncateOrNoop(NewOp, Ty);
  }

  // Next try a zext cast. If the cast is folded, use it.
  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
  if (!isa<SCEVZeroExtendExpr>(ZExt))
    return ZExt;

  // Next try a sext cast. If the cast is folded, use it.
  const SCEV *SExt = getSignExtendExpr(Op, Ty);
  if (!isa<SCEVSignExtendExpr>(SExt))
    return SExt;

  // Force the cast to be folded into the operands of an addrec.
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Ops;
    for (const SCEV *Op : AR->operands())
      Ops.push_back(getAnyExtendExpr(Op, Ty));
    return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
  }

  // If the expression is obviously signed, use the sext cast value.
  if (isa<SCEVSMaxExpr>(Op))
    return SExt;

  // Absent any other information, use the zext cast value.
  return ZExt;
}

/// CollectAddOperandsWithScales - Process the given Ops list, which is
/// a list of operands to be added under the given scale; update the given
/// map. This is a helper function for getAddExpr. As an example of
/// what it does, given a sequence of operands that would form an add
/// expression like this:
///
///    m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
///
/// where A and B are constants, update the map with these values:
///
///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
///
/// and add 13 + A*B*29 to AccumulatedConstant.
/// This will allow getAddExpr to produce this:
///
///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
///
/// This form often exposes folding opportunities that are hidden in
/// the original operand list.
///
/// Return true iff it appears that any interesting folding opportunities
/// may be exposed. This helps getAddExpr short-circuit extra work in
/// the common case where no interesting opportunities are present, and
/// is also used as a check to avoid infinite recursion.
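/// (Note how r and -1*r collide in the map above: their scales sum to zero,
/// which is precisely the kind of hidden cancellation this regrouping
/// exposes.)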
1752 /// 1753 static bool 1754 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M, 1755 SmallVectorImpl<const SCEV *> &NewOps, 1756 APInt &AccumulatedConstant, 1757 const SCEV *const *Ops, size_t NumOperands, 1758 const APInt &Scale, 1759 ScalarEvolution &SE) { 1760 bool Interesting = false; 1761 1762 // Iterate over the add operands. They are sorted, with constants first. 1763 unsigned i = 0; 1764 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) { 1765 ++i; 1766 // Pull a buried constant out to the outside. 1767 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero()) 1768 Interesting = true; 1769 AccumulatedConstant += Scale * C->getValue()->getValue(); 1770 } 1771 1772 // Next comes everything else. We're especially interested in multiplies 1773 // here, but they're in the middle, so just visit the rest with one loop. 1774 for (; i != NumOperands; ++i) { 1775 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]); 1776 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) { 1777 APInt NewScale = 1778 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue(); 1779 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) { 1780 // A multiplication of a constant with another add; recurse. 1781 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1)); 1782 Interesting |= 1783 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, 1784 Add->op_begin(), Add->getNumOperands(), 1785 NewScale, SE); 1786 } else { 1787 // A multiplication of a constant with some other value. Update 1788 // the map. 1789 SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end()); 1790 const SCEV *Key = SE.getMulExpr(MulOps); 1791 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair = 1792 M.insert(std::make_pair(Key, NewScale)); 1793 if (Pair.second) { 1794 NewOps.push_back(Pair.first->first); 1795 } else { 1796 Pair.first->second += NewScale; 1797 // The map already had an entry for this value, which may indicate 1798 // a folding opportunity. 1799 Interesting = true; 1800 } 1801 } 1802 } else { 1803 // An ordinary operand. Update the map. 1804 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair = 1805 M.insert(std::make_pair(Ops[i], Scale)); 1806 if (Pair.second) { 1807 NewOps.push_back(Pair.first->first); 1808 } else { 1809 Pair.first->second += Scale; 1810 // The map already had an entry for this value, which may indicate 1811 // a folding opportunity. 1812 Interesting = true; 1813 } 1814 } 1815 } 1816 1817 return Interesting; 1818 } 1819 1820 namespace { 1821 struct APIntCompare { 1822 bool operator()(const APInt &LHS, const APInt &RHS) const { 1823 return LHS.ult(RHS); 1824 } 1825 }; 1826 } 1827 1828 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and 1829 // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of 1830 // can't-overflow flags for the operation if possible. 
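// E.g., an add that is <nsw> and whose operands are all known non-negative
// stays within [0, signed-max], so it cannot wrap unsigned either and may
// be marked <nuw> as well.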
static SCEV::NoWrapFlags
StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
                      const SmallVectorImpl<const SCEV *> &Ops,
                      SCEV::NoWrapFlags OldFlags) {
  using namespace std::placeholders;

  bool CanAnalyze =
      Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
  (void)CanAnalyze;
  assert(CanAnalyze && "don't call from other places!");

  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
  SCEV::NoWrapFlags SignOrUnsignWrap =
      ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask);

  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
  auto IsKnownNonNegative =
    std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1);

  if (SignOrUnsignWrap == SCEV::FlagNSW &&
      std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative))
    return ScalarEvolution::setFlags(OldFlags,
                                     (SCEV::NoWrapFlags)SignOrUnsignMask);

  return OldFlags;
}

/// getAddExpr - Get a canonical add expression, or something simpler if
/// possible.
const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                                        SCEV::NoWrapFlags Flags) {
  assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
         "only nuw or nsw allowed");
  assert(!Ops.empty() && "Cannot get empty add!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
           "SCEVAddExpr operand types don't match!");
#endif

  Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);

  // Sort by complexity, this groups all similar expression types together.
  GroupByComplexity(Ops, LI);

  // If there are any constants, fold them together.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
    assert(Idx < Ops.size());
    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
      // We found two constants, fold them together!
      Ops[0] = getConstant(LHSC->getValue()->getValue() +
                           RHSC->getValue()->getValue());
      if (Ops.size() == 2) return Ops[0];
      Ops.erase(Ops.begin()+1);  // Erase the folded element
      LHSC = cast<SCEVConstant>(Ops[0]);
    }

    // If we are left with a constant zero being added, strip it off.
    if (LHSC->getValue()->isZero()) {
      Ops.erase(Ops.begin());
      --Idx;
    }

    if (Ops.size() == 1) return Ops[0];
  }

  // Okay, check to see if the same value occurs in the operand list more than
  // once. If so, merge them together into a multiply expression. Since we
  // sorted the list, these values are required to be adjacent.
  Type *Ty = Ops[0]->getType();
  bool FoundMatch = false;
  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
    if (Ops[i] == Ops[i+1]) {      // X + Y + Y  -->  X + Y*2
      // Scan ahead to count how many equal operands there are.
      unsigned Count = 2;
      while (i+Count != e && Ops[i+Count] == Ops[i])
        ++Count;
      // Merge the values into a multiply.
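      // E.g., X + Y + Y + Y finds Count == 3 and becomes X + Y*3.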
      const SCEV *Scale = getConstant(Ty, Count);
      const SCEV *Mul = getMulExpr(Scale, Ops[i]);
      if (Ops.size() == Count)
        return Mul;
      Ops[i] = Mul;
      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
      --i; e -= Count - 1;
      FoundMatch = true;
    }
  if (FoundMatch)
    return getAddExpr(Ops, Flags);

  // Check for truncates. If all the operands are truncated from the same
  // type, see if factoring out the truncate would permit the result to be
  // folded. e.g., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
  // if the contents of the resulting outer trunc fold to something simple.
  for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
    Type *DstType = Trunc->getType();
    Type *SrcType = Trunc->getOperand()->getType();
    SmallVector<const SCEV *, 8> LargeOps;
    bool Ok = true;
    // Check all the operands to see if they can be represented in the
    // source type of the truncate.
    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
      if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
        if (T->getOperand()->getType() != SrcType) {
          Ok = false;
          break;
        }
        LargeOps.push_back(T->getOperand());
      } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
        LargeOps.push_back(getAnyExtendExpr(C, SrcType));
      } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
        SmallVector<const SCEV *, 8> LargeMulOps;
        for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
          if (const SCEVTruncateExpr *T =
                dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
            if (T->getOperand()->getType() != SrcType) {
              Ok = false;
              break;
            }
            LargeMulOps.push_back(T->getOperand());
          } else if (const SCEVConstant *C =
                       dyn_cast<SCEVConstant>(M->getOperand(j))) {
            LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
          } else {
            Ok = false;
            break;
          }
        }
        if (Ok)
          LargeOps.push_back(getMulExpr(LargeMulOps));
      } else {
        Ok = false;
        break;
      }
    }
    if (Ok) {
      // Evaluate the expression in the larger type.
      const SCEV *Fold = getAddExpr(LargeOps, Flags);
      // If it folds to something simple, use it. Otherwise, don't.
      if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
        return getTruncateExpr(Fold, DstType);
    }
  }

  // Skip past any other cast SCEVs.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
    ++Idx;

  // If there are add operands they would be next.
  if (Idx < Ops.size()) {
    bool DeletedAdd = false;
    while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
      // If we have an add, expand the add operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      Ops.append(Add->op_begin(), Add->op_end());
      DeletedAdd = true;
    }

    // If we deleted at least one add, we added operands to the end of the
    // list, and they are not necessarily sorted. Recurse to resort and
    // resimplify any operands we just acquired.
    if (DeletedAdd)
      return getAddExpr(Ops);
  }

  // Skip over the add expression until we get to a multiply.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // Check to see if there are any folding opportunities present with
  // operands multiplied by constant values.
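  // E.g., 2*X + 4*Y + 2*Z can be regrouped as 2*(X + Z) + 4*Y, multiplying
  // by each distinct scale only once.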
  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
    uint64_t BitWidth = getTypeSizeInBits(Ty);
    DenseMap<const SCEV *, APInt> M;
    SmallVector<const SCEV *, 8> NewOps;
    APInt AccumulatedConstant(BitWidth, 0);
    if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
                                     Ops.data(), Ops.size(),
                                     APInt(BitWidth, 1), *this)) {
      // Some interesting folding opportunity is present, so it's worthwhile
      // to re-generate the operands list. Group the operands by constant
      // scale, to avoid multiplying by the same constant scale multiple
      // times.
      std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
      for (SmallVectorImpl<const SCEV *>::const_iterator I = NewOps.begin(),
           E = NewOps.end(); I != E; ++I)
        MulOpLists[M.find(*I)->second].push_back(*I);
      // Re-generate the operands list.
      Ops.clear();
      if (AccumulatedConstant != 0)
        Ops.push_back(getConstant(AccumulatedConstant));
      for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
           I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
        if (I->first != 0)
          Ops.push_back(getMulExpr(getConstant(I->first),
                                   getAddExpr(I->second)));
      if (Ops.empty())
        return getConstant(Ty, 0);
      if (Ops.size() == 1)
        return Ops[0];
      return getAddExpr(Ops);
    }
  }

  // If we are adding something to a multiply expression, make sure the
  // something is not already an operand of the multiply. If so, merge it into
  // the multiply.
  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
    for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
      const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
      if (isa<SCEVConstant>(MulOpSCEV))
        continue;
      for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
        if (MulOpSCEV == Ops[AddOp]) {
          // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
          const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
          if (Mul->getNumOperands() != 2) {
            // If the multiply has more than two operands, we must get the
            // Y*Z term.
            SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
                                                Mul->op_begin()+MulOp);
            MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
            InnerMul = getMulExpr(MulOps);
          }
          const SCEV *One = getConstant(Ty, 1);
          const SCEV *AddOne = getAddExpr(One, InnerMul);
          const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV);
          if (Ops.size() == 2) return OuterMul;
          if (AddOp < Idx) {
            Ops.erase(Ops.begin()+AddOp);
            Ops.erase(Ops.begin()+Idx-1);
          } else {
            Ops.erase(Ops.begin()+Idx);
            Ops.erase(Ops.begin()+AddOp-1);
          }
          Ops.push_back(OuterMul);
          return getAddExpr(Ops);
        }

      // Check this multiply against other multiplies being added together.
      for (unsigned OtherMulIdx = Idx+1;
           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
           ++OtherMulIdx) {
        const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
        // If MulOp occurs in OtherMul, we can fold the two multiplies
        // together.
2083 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands(); 2084 OMulOp != e; ++OMulOp) 2085 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) { 2086 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E)) 2087 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0); 2088 if (Mul->getNumOperands() != 2) { 2089 SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(), 2090 Mul->op_begin()+MulOp); 2091 MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); 2092 InnerMul1 = getMulExpr(MulOps); 2093 } 2094 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0); 2095 if (OtherMul->getNumOperands() != 2) { 2096 SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(), 2097 OtherMul->op_begin()+OMulOp); 2098 MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end()); 2099 InnerMul2 = getMulExpr(MulOps); 2100 } 2101 const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2); 2102 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); 2103 if (Ops.size() == 2) return OuterMul; 2104 Ops.erase(Ops.begin()+Idx); 2105 Ops.erase(Ops.begin()+OtherMulIdx-1); 2106 Ops.push_back(OuterMul); 2107 return getAddExpr(Ops); 2108 } 2109 } 2110 } 2111 } 2112 2113 // If there are any add recurrences in the operands list, see if any other 2114 // added values are loop invariant. If so, we can fold them into the 2115 // recurrence. 2116 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr) 2117 ++Idx; 2118 2119 // Scan over all recurrences, trying to fold loop invariants into them. 2120 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { 2121 // Scan all of the other operands to this add and add them to the vector if 2122 // they are loop invariant w.r.t. the recurrence. 2123 SmallVector<const SCEV *, 8> LIOps; 2124 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); 2125 const Loop *AddRecLoop = AddRec->getLoop(); 2126 for (unsigned i = 0, e = Ops.size(); i != e; ++i) 2127 if (isLoopInvariant(Ops[i], AddRecLoop)) { 2128 LIOps.push_back(Ops[i]); 2129 Ops.erase(Ops.begin()+i); 2130 --i; --e; 2131 } 2132 2133 // If we found some loop invariants, fold them into the recurrence. 2134 if (!LIOps.empty()) { 2135 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} 2136 LIOps.push_back(AddRec->getStart()); 2137 2138 SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(), 2139 AddRec->op_end()); 2140 AddRecOps[0] = getAddExpr(LIOps); 2141 2142 // Build the new addrec. Propagate the NUW and NSW flags if both the 2143 // outer add and the inner addrec are guaranteed to have no overflow. 2144 // Always propagate NW. 2145 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW)); 2146 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags); 2147 2148 // If all of the other operands were loop invariant, we are done. 2149 if (Ops.size() == 1) return NewRec; 2150 2151 // Otherwise, add the folded AddRec by the non-invariant parts. 2152 for (unsigned i = 0;; ++i) 2153 if (Ops[i] == AddRec) { 2154 Ops[i] = NewRec; 2155 break; 2156 } 2157 return getAddExpr(Ops); 2158 } 2159 2160 // Okay, if there weren't any loop invariants to be folded, check to see if 2161 // there are multiple AddRec's with the same loop induction variable being 2162 // added together. If so, we can fold them. 
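    // E.g., the two recurrences {1,+,2}<L> and {3,+,4}<L> fold into the
    // single recurrence {4,+,6}<L>.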
    for (unsigned OtherIdx = Idx+1;
         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
         ++OtherIdx)
      if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
        // Other + {A,+,B}<L> + {C,+,D}<L>  -->  Other + {A+C,+,B+D}<L>
        SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
                                               AddRec->op_end());
        for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
             ++OtherIdx)
          if (const SCEVAddRecExpr *OtherAddRec =
                dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
            if (OtherAddRec->getLoop() == AddRecLoop) {
              for (unsigned i = 0, e = OtherAddRec->getNumOperands();
                   i != e; ++i) {
                if (i >= AddRecOps.size()) {
                  AddRecOps.append(OtherAddRec->op_begin()+i,
                                   OtherAddRec->op_end());
                  break;
                }
                AddRecOps[i] = getAddExpr(AddRecOps[i],
                                          OtherAddRec->getOperand(i));
              }
              Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
            }
        // Step size has changed, so we cannot guarantee no self-wraparound.
        Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
        return getAddExpr(Ops);
      }

    // Otherwise couldn't fold anything into this recurrence. Move onto the
    // next one.
  }

  // Okay, it looks like we really DO need an add expr. Check to see if we
  // already have one, otherwise create a new one.
  FoldingSetNodeID ID;
  ID.AddInteger(scAddExpr);
  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
    ID.AddPointer(Ops[i]);
  void *IP = nullptr;
  SCEVAddExpr *S =
    static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
    std::uninitialized_copy(Ops.begin(), Ops.end(), O);
    S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator),
                                        O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
  }
  S->setNoWrapFlags(Flags);
  return S;
}

static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t k = i*j;
  if (j > 1 && k / j != i) Overflow = true;
  return k;
}

/// Compute the result of "n choose k", the binomial coefficient. If an
/// intermediate computation overflows, Overflow will be set and the return
/// will be garbage. Overflow is not cleared on absence of overflow.
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
  // We use the multiplicative formula:
  //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
  // At each iteration, we take the i-th term of the numerator and divide by
  // the i-th term of the denominator. This division will always produce an
  // integral result, and helps reduce the chance of overflow in the
  // intermediate computations. However, we can still overflow even when the
  // final result would fit.

  if (n == 0 || n == k) return 1;
  if (k > n) return 0;

  if (k > n/2)
    k = n-k;

  uint64_t r = 1;
  for (uint64_t i = 1; i <= k; ++i) {
    r = umul_ov(r, n-(i-1), Overflow);
    r /= i;
  }
  return r;
}

/// Determine if any of the operands in this SCEV are a constant or if
/// any of the add or multiply expressions in this SCEV contain a constant.
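/// The search uses an explicit worklist rather than recursion, so deeply
/// nested expressions cannot exhaust the call stack.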
2250 static bool containsConstantSomewhere(const SCEV *StartExpr) { 2251 SmallVector<const SCEV *, 4> Ops; 2252 Ops.push_back(StartExpr); 2253 while (!Ops.empty()) { 2254 const SCEV *CurrentExpr = Ops.pop_back_val(); 2255 if (isa<SCEVConstant>(*CurrentExpr)) 2256 return true; 2257 2258 if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) { 2259 const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr); 2260 Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end()); 2261 } 2262 } 2263 return false; 2264 } 2265 2266 /// getMulExpr - Get a canonical multiply expression, or something simpler if 2267 /// possible. 2268 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, 2269 SCEV::NoWrapFlags Flags) { 2270 assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) && 2271 "only nuw or nsw allowed"); 2272 assert(!Ops.empty() && "Cannot get empty mul!"); 2273 if (Ops.size() == 1) return Ops[0]; 2274 #ifndef NDEBUG 2275 Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); 2276 for (unsigned i = 1, e = Ops.size(); i != e; ++i) 2277 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && 2278 "SCEVMulExpr operand types don't match!"); 2279 #endif 2280 2281 Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); 2282 2283 // Sort by complexity, this groups all similar expression types together. 2284 GroupByComplexity(Ops, LI); 2285 2286 // If there are any constants, fold them together. 2287 unsigned Idx = 0; 2288 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { 2289 2290 // C1*(C2+V) -> C1*C2 + C1*V 2291 if (Ops.size() == 2) 2292 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) 2293 // If any of Add's ops are Adds or Muls with a constant, 2294 // apply this transformation as well. 2295 if (Add->getNumOperands() == 2) 2296 if (containsConstantSomewhere(Add)) 2297 return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), 2298 getMulExpr(LHSC, Add->getOperand(1))); 2299 2300 ++Idx; 2301 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { 2302 // We found two constants, fold them together! 2303 ConstantInt *Fold = ConstantInt::get(getContext(), 2304 LHSC->getValue()->getValue() * 2305 RHSC->getValue()->getValue()); 2306 Ops[0] = getConstant(Fold); 2307 Ops.erase(Ops.begin()+1); // Erase the folded element 2308 if (Ops.size() == 1) return Ops[0]; 2309 LHSC = cast<SCEVConstant>(Ops[0]); 2310 } 2311 2312 // If we are left with a constant one being multiplied, strip it off. 2313 if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) { 2314 Ops.erase(Ops.begin()); 2315 --Idx; 2316 } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) { 2317 // If we have a multiply of zero, it will always be zero. 2318 return Ops[0]; 2319 } else if (Ops[0]->isAllOnesValue()) { 2320 // If we have a mul by -1 of an add, try distributing the -1 among the 2321 // add operands. 2322 if (Ops.size() == 2) { 2323 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) { 2324 SmallVector<const SCEV *, 4> NewOps; 2325 bool AnyFolded = false; 2326 for (SCEVAddRecExpr::op_iterator I = Add->op_begin(), 2327 E = Add->op_end(); I != E; ++I) { 2328 const SCEV *Mul = getMulExpr(Ops[0], *I); 2329 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true; 2330 NewOps.push_back(Mul); 2331 } 2332 if (AnyFolded) 2333 return getAddExpr(NewOps); 2334 } 2335 else if (const SCEVAddRecExpr * 2336 AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) { 2337 // Negation preserves a recurrence's no self-wrap property. 
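          // E.g., -1 * {0,+,2}<nw> becomes {0,+,-2}, which still cannot
          // self-wrap.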
          SmallVector<const SCEV *, 4> Operands;
          for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(),
                 E = AddRec->op_end(); I != E; ++I) {
            Operands.push_back(getMulExpr(Ops[0], *I));
          }
          return getAddRecExpr(Operands, AddRec->getLoop(),
                               AddRec->getNoWrapFlags(SCEV::FlagNW));
        }
      }
    }

    if (Ops.size() == 1)
      return Ops[0];
  }

  // Skip over the add expression until we get to a multiply.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // If there are mul operands inline them all into this expression.
  if (Idx < Ops.size()) {
    bool DeletedMul = false;
    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
      // If we have a mul, expand the mul operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      Ops.append(Mul->op_begin(), Mul->op_end());
      DeletedMul = true;
    }

    // If we deleted at least one mul, we added operands to the end of the
    // list, and they are not necessarily sorted. Recurse to resort and
    // resimplify any operands we just acquired.
    if (DeletedMul)
      return getMulExpr(Ops);
  }

  // If there are any add recurrences in the operands list, see if any other
  // added values are loop invariant. If so, we can fold them into the
  // recurrence.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
    ++Idx;

  // Scan over all recurrences, trying to fold loop invariants into them.
  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
    // Scan all of the other operands to this mul and add them to the vector
    // if they are loop invariant w.r.t. the recurrence.
    SmallVector<const SCEV *, 8> LIOps;
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
    const Loop *AddRecLoop = AddRec->getLoop();
    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
      if (isLoopInvariant(Ops[i], AddRecLoop)) {
        LIOps.push_back(Ops[i]);
        Ops.erase(Ops.begin()+i);
        --i; --e;
      }

    // If we found some loop invariants, fold them into the recurrence.
    if (!LIOps.empty()) {
      //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
      SmallVector<const SCEV *, 4> NewOps;
      NewOps.reserve(AddRec->getNumOperands());
      const SCEV *Scale = getMulExpr(LIOps);
      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
        NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));

      // Build the new addrec. Propagate the NUW and NSW flags if both the
      // outer mul and the inner addrec are guaranteed to have no overflow.
      //
      // The no-self-wrap property cannot be preserved after changing the
      // step size, but NW will be inferred if either NUW or NSW is true.
      Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW));
      const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags);

      // If all of the other operands were loop invariant, we are done.
      if (Ops.size() == 1) return NewRec;

      // Otherwise, multiply the folded AddRec by the non-invariant parts.
      for (unsigned i = 0;; ++i)
        if (Ops[i] == AddRec) {
          Ops[i] = NewRec;
          break;
        }
      return getMulExpr(Ops);
    }

    // Okay, if there weren't any loop invariants to be folded, check to see
    // if there are multiple AddRec's with the same loop induction variable
    // being multiplied together. If so, we can fold them.

    // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
    // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
    //       choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
    //   ]]],+,...up to x=2n}.
    // Note that the arguments to choose() are always integers with values
    // known at compile time, never SCEV objects.
    //
    // The implementation avoids pointless extra computations when the two
    // addrec's are of different length (mathematically, it's equivalent to
    // an infinite stream of zeros on the right).
    bool OpsModified = false;
    for (unsigned OtherIdx = Idx+1;
         OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
         ++OtherIdx) {
      const SCEVAddRecExpr *OtherAddRec =
        dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
      if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
        continue;

      bool Overflow = false;
      Type *Ty = AddRec->getType();
      bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
      SmallVector<const SCEV*, 7> AddRecOps;
      for (int x = 0, xe = AddRec->getNumOperands() +
             OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
        const SCEV *Term = getConstant(Ty, 0);
        for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
          uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
          for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
                 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
               z < ze && !Overflow; ++z) {
            uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
            uint64_t Coeff;
            if (LargerThan64Bits)
              Coeff = umul_ov(Coeff1, Coeff2, Overflow);
            else
              Coeff = Coeff1*Coeff2;
            const SCEV *CoeffTerm = getConstant(Ty, Coeff);
            const SCEV *Term1 = AddRec->getOperand(y-z);
            const SCEV *Term2 = OtherAddRec->getOperand(z);
            Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1, Term2));
          }
        }
        AddRecOps.push_back(Term);
      }
      if (!Overflow) {
        const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
                                              SCEV::FlagAnyWrap);
        if (Ops.size() == 2) return NewAddRec;
        Ops[Idx] = NewAddRec;
        Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
        OpsModified = true;
        AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
        if (!AddRec)
          break;
      }
    }
    if (OpsModified)
      return getMulExpr(Ops);

    // Otherwise couldn't fold anything into this recurrence. Move onto the
    // next one.
  }

  // Okay, it looks like we really DO need a mul expr. Check to see if we
  // already have one, otherwise create a new one.
  FoldingSetNodeID ID;
  ID.AddInteger(scMulExpr);
  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
    ID.AddPointer(Ops[i]);
  void *IP = nullptr;
  SCEVMulExpr *S =
    static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
    std::uninitialized_copy(Ops.begin(), Ops.end(), O);
    S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                        O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
  }
  S->setNoWrapFlags(Flags);
  return S;
}

/// getUDivExpr - Get a canonical unsigned division expression, or something
/// simpler if possible.
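/// Unlike multiplication, udiv discards low bits, so the distributive
/// rewrites below first prove the division exact (via zero-extension
/// comparisons or remainder checks) before factoring it through.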
const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
                                         const SCEV *RHS) {
  assert(getEffectiveSCEVType(LHS->getType()) ==
         getEffectiveSCEVType(RHS->getType()) &&
         "SCEVUDivExpr operand types don't match!");

  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
    if (RHSC->getValue()->equalsInt(1))
      return LHS;                               // X udiv 1 --> X
    // If the denominator is zero, the result of the udiv is undefined. Don't
    // try to analyze it, because the resolution chosen here may differ from
    // the resolution chosen in other parts of the compiler.
    if (!RHSC->getValue()->isZero()) {
      // Determine if the division can be folded into the operands of LHS.
      // TODO: Generalize this to non-constants by using known-bits information.
      Type *Ty = LHS->getType();
      unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
      unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
      // For non-power-of-two values, effectively round the value up to the
      // nearest power of two.
      if (!RHSC->getValue()->getValue().isPowerOf2())
        ++MaxShiftAmt;
      IntegerType *ExtTy =
        IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
      if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
        if (const SCEVConstant *Step =
            dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
          // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
          const APInt &StepInt = Step->getValue()->getValue();
          const APInt &DivInt = RHSC->getValue()->getValue();
          if (!StepInt.urem(DivInt) &&
              getZeroExtendExpr(AR, ExtTy) ==
              getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
                            getZeroExtendExpr(Step, ExtTy),
                            AR->getLoop(), SCEV::FlagAnyWrap)) {
            SmallVector<const SCEV *, 4> Operands;
            for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
              Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
            return getAddRecExpr(Operands, AR->getLoop(),
                                 SCEV::FlagNW);
          }
          // Get a canonical UDivExpr for a recurrence:
          // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N==0.
          // We can currently only fold X%N if X is constant.
          const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
          if (StartC && !DivInt.urem(StepInt) &&
              getZeroExtendExpr(AR, ExtTy) ==
              getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
                            getZeroExtendExpr(Step, ExtTy),
                            AR->getLoop(), SCEV::FlagAnyWrap)) {
            const APInt &StartInt = StartC->getValue()->getValue();
            const APInt &StartRem = StartInt.urem(StepInt);
            if (StartRem != 0)
              LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step,
                                  AR->getLoop(), SCEV::FlagNW);
          }
        }
      // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
      if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
        SmallVector<const SCEV *, 4> Operands;
        for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
          Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
        if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
          // Find an operand that's safely divisible.
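          // E.g., for (8 * X) /u 4: 8 /u 4 == 2 and 2 * 4 == 8 exactly,
          // so the whole expression folds to 2 * X.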
2579 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) { 2580 const SCEV *Op = M->getOperand(i); 2581 const SCEV *Div = getUDivExpr(Op, RHSC); 2582 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) { 2583 Operands = SmallVector<const SCEV *, 4>(M->op_begin(), 2584 M->op_end()); 2585 Operands[i] = Div; 2586 return getMulExpr(Operands); 2587 } 2588 } 2589 } 2590 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded. 2591 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) { 2592 SmallVector<const SCEV *, 4> Operands; 2593 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) 2594 Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy)); 2595 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) { 2596 Operands.clear(); 2597 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { 2598 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS); 2599 if (isa<SCEVUDivExpr>(Op) || 2600 getMulExpr(Op, RHS) != A->getOperand(i)) 2601 break; 2602 Operands.push_back(Op); 2603 } 2604 if (Operands.size() == A->getNumOperands()) 2605 return getAddExpr(Operands); 2606 } 2607 } 2608 2609 // Fold if both operands are constant. 2610 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) { 2611 Constant *LHSCV = LHSC->getValue(); 2612 Constant *RHSCV = RHSC->getValue(); 2613 return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV, 2614 RHSCV))); 2615 } 2616 } 2617 } 2618 2619 FoldingSetNodeID ID; 2620 ID.AddInteger(scUDivExpr); 2621 ID.AddPointer(LHS); 2622 ID.AddPointer(RHS); 2623 void *IP = nullptr; 2624 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 2625 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), 2626 LHS, RHS); 2627 UniqueSCEVs.InsertNode(S, IP); 2628 return S; 2629 } 2630 2631 static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { 2632 APInt A = C1->getValue()->getValue().abs(); 2633 APInt B = C2->getValue()->getValue().abs(); 2634 uint32_t ABW = A.getBitWidth(); 2635 uint32_t BBW = B.getBitWidth(); 2636 2637 if (ABW > BBW) 2638 B = B.zext(ABW); 2639 else if (ABW < BBW) 2640 A = A.zext(BBW); 2641 2642 return APIntOps::GreatestCommonDivisor(A, B); 2643 } 2644 2645 /// getUDivExactExpr - Get a canonical unsigned division expression, or 2646 /// something simpler if possible. There is no representation for an exact udiv 2647 /// in SCEV IR, but we can attempt to remove factors from the LHS and RHS. 2648 /// We can't do this when it's not exact because the udiv may be clearing bits. 2649 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, 2650 const SCEV *RHS) { 2651 // TODO: we could try to find factors in all sorts of things, but for now we 2652 // just deal with u/exact (multiply, constant). See SCEVDivision towards the 2653 // end of this file for inspiration. 2654 2655 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS); 2656 if (!Mul) 2657 return getUDivExpr(LHS, RHS); 2658 2659 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) { 2660 // If the mulexpr multiplies by a constant, then that constant must be the 2661 // first element of the mulexpr. 2662 if (const SCEVConstant *LHSCst = 2663 dyn_cast<SCEVConstant>(Mul->getOperand(0))) { 2664 if (LHSCst == RHSCst) { 2665 SmallVector<const SCEV *, 2> Operands; 2666 Operands.append(Mul->op_begin() + 1, Mul->op_end()); 2667 return getMulExpr(Operands); 2668 } 2669 2670 // We can't just assume that LHSCst divides RHSCst cleanly, it could be 2671 // that there's a factor provided by one of the other terms. 
We need to
      // check.
      APInt Factor = gcd(LHSCst, RHSCst);
      if (!Factor.isIntN(1)) {
        LHSCst = cast<SCEVConstant>(
            getConstant(LHSCst->getValue()->getValue().udiv(Factor)));
        RHSCst = cast<SCEVConstant>(
            getConstant(RHSCst->getValue()->getValue().udiv(Factor)));
        SmallVector<const SCEV *, 2> Operands;
        Operands.push_back(LHSCst);
        Operands.append(Mul->op_begin() + 1, Mul->op_end());
        LHS = getMulExpr(Operands);
        RHS = RHSCst;
        Mul = dyn_cast<SCEVMulExpr>(LHS);
        if (!Mul)
          return getUDivExactExpr(LHS, RHS);
      }
    }
  }

  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
    if (Mul->getOperand(i) == RHS) {
      SmallVector<const SCEV *, 2> Operands;
      Operands.append(Mul->op_begin(), Mul->op_begin() + i);
      Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
      return getMulExpr(Operands);
    }
  }

  return getUDivExpr(LHS, RHS);
}

/// getAddRecExpr - Get an add recurrence expression for the specified loop.
/// Simplify the expression as much as possible.
const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
                                           const Loop *L,
                                           SCEV::NoWrapFlags Flags) {
  SmallVector<const SCEV *, 4> Operands;
  Operands.push_back(Start);
  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
    if (StepChrec->getLoop() == L) {
      Operands.append(StepChrec->op_begin(), StepChrec->op_end());
      return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
    }

  Operands.push_back(Step);
  return getAddRecExpr(Operands, L, Flags);
}

/// getAddRecExpr - Get an add recurrence expression for the specified loop.
/// Simplify the expression as much as possible.
const SCEV *
ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
                               const Loop *L, SCEV::NoWrapFlags Flags) {
  if (Operands.size() == 1) return Operands[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
           "SCEVAddRecExpr operand types don't match!");
  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
    assert(isLoopInvariant(Operands[i], L) &&
           "SCEVAddRecExpr operand is not loop-invariant!");
#endif

  if (Operands.back()->isZero()) {
    Operands.pop_back();
    return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0}  -->  X
  }

  // It's tempting to call getMaxBackedgeTakenCount here and
  // use that information to infer NUW and NSW flags. However, computing a
  // BE count requires calling getAddRecExpr, so we may not yet have a
  // meaningful BE count at this point (and if we don't, we'd be stuck
  // with a SCEVCouldNotCompute as the cached BE count).

  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);

  // Canonicalize nested AddRecs by nesting them in order of loop depth.
  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
    const Loop *NestedLoop = NestedAR->getLoop();
    if (L->contains(NestedLoop) ?
2753 (L->getLoopDepth() < NestedLoop->getLoopDepth()) : 2754 (!NestedLoop->contains(L) && 2755 DT->dominates(L->getHeader(), NestedLoop->getHeader()))) { 2756 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(), 2757 NestedAR->op_end()); 2758 Operands[0] = NestedAR->getStart(); 2759 // AddRecs require their operands be loop-invariant with respect to their 2760 // loops. Don't perform this transformation if it would break this 2761 // requirement. 2762 bool AllInvariant = true; 2763 for (unsigned i = 0, e = Operands.size(); i != e; ++i) 2764 if (!isLoopInvariant(Operands[i], L)) { 2765 AllInvariant = false; 2766 break; 2767 } 2768 if (AllInvariant) { 2769 // Create a recurrence for the outer loop with the same step size. 2770 // 2771 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the 2772 // inner recurrence has the same property. 2773 SCEV::NoWrapFlags OuterFlags = 2774 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags()); 2775 2776 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags); 2777 AllInvariant = true; 2778 for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i) 2779 if (!isLoopInvariant(NestedOperands[i], NestedLoop)) { 2780 AllInvariant = false; 2781 break; 2782 } 2783 if (AllInvariant) { 2784 // Ok, both add recurrences are valid after the transformation. 2785 // 2786 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if 2787 // the outer recurrence has the same property. 2788 SCEV::NoWrapFlags InnerFlags = 2789 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags); 2790 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags); 2791 } 2792 } 2793 // Reset Operands to its original state. 2794 Operands[0] = NestedAR; 2795 } 2796 } 2797 2798 // Okay, it looks like we really DO need an addrec expr. Check to see if we 2799 // already have one, otherwise create a new one. 2800 FoldingSetNodeID ID; 2801 ID.AddInteger(scAddRecExpr); 2802 for (unsigned i = 0, e = Operands.size(); i != e; ++i) 2803 ID.AddPointer(Operands[i]); 2804 ID.AddPointer(L); 2805 void *IP = nullptr; 2806 SCEVAddRecExpr *S = 2807 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); 2808 if (!S) { 2809 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size()); 2810 std::uninitialized_copy(Operands.begin(), Operands.end(), O); 2811 S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator), 2812 O, Operands.size(), L); 2813 UniqueSCEVs.InsertNode(S, IP); 2814 } 2815 S->setNoWrapFlags(Flags); 2816 return S; 2817 } 2818 2819 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, 2820 const SCEV *RHS) { 2821 SmallVector<const SCEV *, 2> Ops; 2822 Ops.push_back(LHS); 2823 Ops.push_back(RHS); 2824 return getSMaxExpr(Ops); 2825 } 2826 2827 const SCEV * 2828 ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { 2829 assert(!Ops.empty() && "Cannot get empty smax!"); 2830 if (Ops.size() == 1) return Ops[0]; 2831 #ifndef NDEBUG 2832 Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); 2833 for (unsigned i = 1, e = Ops.size(); i != e; ++i) 2834 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && 2835 "SCEVSMaxExpr operand types don't match!"); 2836 #endif 2837 2838 // Sort by complexity, this groups all similar expression types together. 2839 GroupByComplexity(Ops, LI); 2840 2841 // If there are any constants, fold them together. 
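  // E.g., smax(2, 7, X) folds its constants down to smax(7, X).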
2842 unsigned Idx = 0; 2843 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { 2844 ++Idx; 2845 assert(Idx < Ops.size()); 2846 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { 2847 // We found two constants, fold them together! 2848 ConstantInt *Fold = ConstantInt::get(getContext(), 2849 APIntOps::smax(LHSC->getValue()->getValue(), 2850 RHSC->getValue()->getValue())); 2851 Ops[0] = getConstant(Fold); 2852 Ops.erase(Ops.begin()+1); // Erase the folded element 2853 if (Ops.size() == 1) return Ops[0]; 2854 LHSC = cast<SCEVConstant>(Ops[0]); 2855 } 2856 2857 // If we are left with a constant minimum-int, strip it off. 2858 if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) { 2859 Ops.erase(Ops.begin()); 2860 --Idx; 2861 } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) { 2862 // If we have an smax with a constant maximum-int, it will always be 2863 // maximum-int. 2864 return Ops[0]; 2865 } 2866 2867 if (Ops.size() == 1) return Ops[0]; 2868 } 2869 2870 // Find the first SMax 2871 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr) 2872 ++Idx; 2873 2874 // Check to see if one of the operands is an SMax. If so, expand its operands 2875 // onto our operand list, and recurse to simplify. 2876 if (Idx < Ops.size()) { 2877 bool DeletedSMax = false; 2878 while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) { 2879 Ops.erase(Ops.begin()+Idx); 2880 Ops.append(SMax->op_begin(), SMax->op_end()); 2881 DeletedSMax = true; 2882 } 2883 2884 if (DeletedSMax) 2885 return getSMaxExpr(Ops); 2886 } 2887 2888 // Okay, check to see if the same value occurs in the operand list twice. If 2889 // so, delete one. Since we sorted the list, these values are required to 2890 // be adjacent. 2891 for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) 2892 // X smax Y smax Y --> X smax Y 2893 // X smax Y --> X, if X is always greater than Y 2894 if (Ops[i] == Ops[i+1] || 2895 isKnownPredicate(ICmpInst::ICMP_SGE, Ops[i], Ops[i+1])) { 2896 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2); 2897 --i; --e; 2898 } else if (isKnownPredicate(ICmpInst::ICMP_SLE, Ops[i], Ops[i+1])) { 2899 Ops.erase(Ops.begin()+i, Ops.begin()+i+1); 2900 --i; --e; 2901 } 2902 2903 if (Ops.size() == 1) return Ops[0]; 2904 2905 assert(!Ops.empty() && "Reduced smax down to nothing!"); 2906 2907 // Okay, it looks like we really DO need an smax expr. Check to see if we 2908 // already have one, otherwise create a new one. 
2909 FoldingSetNodeID ID; 2910 ID.AddInteger(scSMaxExpr); 2911 for (unsigned i = 0, e = Ops.size(); i != e; ++i) 2912 ID.AddPointer(Ops[i]); 2913 void *IP = nullptr; 2914 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 2915 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); 2916 std::uninitialized_copy(Ops.begin(), Ops.end(), O); 2917 SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator), 2918 O, Ops.size()); 2919 UniqueSCEVs.InsertNode(S, IP); 2920 return S; 2921 } 2922 2923 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, 2924 const SCEV *RHS) { 2925 SmallVector<const SCEV *, 2> Ops; 2926 Ops.push_back(LHS); 2927 Ops.push_back(RHS); 2928 return getUMaxExpr(Ops); 2929 } 2930 2931 const SCEV * 2932 ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { 2933 assert(!Ops.empty() && "Cannot get empty umax!"); 2934 if (Ops.size() == 1) return Ops[0]; 2935 #ifndef NDEBUG 2936 Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); 2937 for (unsigned i = 1, e = Ops.size(); i != e; ++i) 2938 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && 2939 "SCEVUMaxExpr operand types don't match!"); 2940 #endif 2941 2942 // Sort by complexity, this groups all similar expression types together. 2943 GroupByComplexity(Ops, LI); 2944 2945 // If there are any constants, fold them together. 2946 unsigned Idx = 0; 2947 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { 2948 ++Idx; 2949 assert(Idx < Ops.size()); 2950 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { 2951 // We found two constants, fold them together! 2952 ConstantInt *Fold = ConstantInt::get(getContext(), 2953 APIntOps::umax(LHSC->getValue()->getValue(), 2954 RHSC->getValue()->getValue())); 2955 Ops[0] = getConstant(Fold); 2956 Ops.erase(Ops.begin()+1); // Erase the folded element 2957 if (Ops.size() == 1) return Ops[0]; 2958 LHSC = cast<SCEVConstant>(Ops[0]); 2959 } 2960 2961 // If we are left with a constant minimum-int, strip it off. 2962 if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) { 2963 Ops.erase(Ops.begin()); 2964 --Idx; 2965 } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) { 2966 // If we have an umax with a constant maximum-int, it will always be 2967 // maximum-int. 2968 return Ops[0]; 2969 } 2970 2971 if (Ops.size() == 1) return Ops[0]; 2972 } 2973 2974 // Find the first UMax 2975 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr) 2976 ++Idx; 2977 2978 // Check to see if one of the operands is a UMax. If so, expand its operands 2979 // onto our operand list, and recurse to simplify. 2980 if (Idx < Ops.size()) { 2981 bool DeletedUMax = false; 2982 while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) { 2983 Ops.erase(Ops.begin()+Idx); 2984 Ops.append(UMax->op_begin(), UMax->op_end()); 2985 DeletedUMax = true; 2986 } 2987 2988 if (DeletedUMax) 2989 return getUMaxExpr(Ops); 2990 } 2991 2992 // Okay, check to see if the same value occurs in the operand list twice. If 2993 // so, delete one. Since we sorted the list, these values are required to 2994 // be adjacent. 
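  // E.g., umax(X, X, Y) drops the duplicate X, and umax(X, Y) simplifies
  // to X when X >= Y is known to hold unsigned.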
2995 for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) 2996 // X umax Y umax Y --> X umax Y 2997 // X umax Y --> X, if X is always greater than Y 2998 if (Ops[i] == Ops[i+1] || 2999 isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) { 3000 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2); 3001 --i; --e; 3002 } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) { 3003 Ops.erase(Ops.begin()+i, Ops.begin()+i+1); 3004 --i; --e; 3005 } 3006 3007 if (Ops.size() == 1) return Ops[0]; 3008 3009 assert(!Ops.empty() && "Reduced umax down to nothing!"); 3010 3011 // Okay, it looks like we really DO need a umax expr. Check to see if we 3012 // already have one, otherwise create a new one. 3013 FoldingSetNodeID ID; 3014 ID.AddInteger(scUMaxExpr); 3015 for (unsigned i = 0, e = Ops.size(); i != e; ++i) 3016 ID.AddPointer(Ops[i]); 3017 void *IP = nullptr; 3018 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 3019 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); 3020 std::uninitialized_copy(Ops.begin(), Ops.end(), O); 3021 SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator), 3022 O, Ops.size()); 3023 UniqueSCEVs.InsertNode(S, IP); 3024 return S; 3025 } 3026 3027 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS, 3028 const SCEV *RHS) { 3029 // ~smax(~x, ~y) == smin(x, y). 3030 return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS))); 3031 } 3032 3033 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, 3034 const SCEV *RHS) { 3035 // ~umax(~x, ~y) == umin(x, y) 3036 return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS))); 3037 } 3038 3039 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { 3040 // If we have DataLayout, we can bypass creating a target-independent 3041 // constant expression and then folding it back into a ConstantInt. 3042 // This is just a compile-time optimization. 3043 if (DL) 3044 return getConstant(IntTy, DL->getTypeAllocSize(AllocTy)); 3045 3046 Constant *C = ConstantExpr::getSizeOf(AllocTy); 3047 if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) 3048 if (Constant *Folded = ConstantFoldConstantExpression(CE, DL, TLI)) 3049 C = Folded; 3050 Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy)); 3051 assert(Ty == IntTy && "Effective SCEV type doesn't match"); 3052 return getTruncateOrZeroExtend(getSCEV(C), Ty); 3053 } 3054 3055 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, 3056 StructType *STy, 3057 unsigned FieldNo) { 3058 // If we have DataLayout, we can bypass creating a target-independent 3059 // constant expression and then folding it back into a ConstantInt. 3060 // This is just a compile-time optimization. 3061 if (DL) { 3062 return getConstant(IntTy, 3063 DL->getStructLayout(STy)->getElementOffset(FieldNo)); 3064 } 3065 3066 Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo); 3067 if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) 3068 if (Constant *Folded = ConstantFoldConstantExpression(CE, DL, TLI)) 3069 C = Folded; 3070 3071 Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy)); 3072 return getTruncateOrZeroExtend(getSCEV(C), Ty); 3073 } 3074 3075 const SCEV *ScalarEvolution::getUnknown(Value *V) { 3076 // Don't attempt to do anything other than create a SCEVUnknown object 3077 // here. createSCEV only calls getUnknown after checking for all other 3078 // interesting possibilities, and any other code that calls getUnknown 3079 // is doing so in order to hide a value from SCEV canonicalization. 
3080
3081 FoldingSetNodeID ID;
3082 ID.AddInteger(scUnknown);
3083 ID.AddPointer(V);
3084 void *IP = nullptr;
3085 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
3086 assert(cast<SCEVUnknown>(S)->getValue() == V &&
3087 "Stale SCEVUnknown in uniquing map!");
3088 return S;
3089 }
3090 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
3091 FirstUnknown);
3092 FirstUnknown = cast<SCEVUnknown>(S);
3093 UniqueSCEVs.InsertNode(S, IP);
3094 return S;
3095 }
3096
3097 //===----------------------------------------------------------------------===//
3098 // Basic SCEV Analysis and PHI Idiom Recognition Code
3099 //
3100
3101 /// isSCEVable - Test if values of the given type are analyzable within
3102 /// the SCEV framework. This primarily includes integer types, and it
3103 /// can optionally include pointer types if the ScalarEvolution class
3104 /// has access to target-specific information.
3105 bool ScalarEvolution::isSCEVable(Type *Ty) const {
3106 // Integers and pointers are always SCEVable.
3107 return Ty->isIntegerTy() || Ty->isPointerTy();
3108 }
3109
3110 /// getTypeSizeInBits - Return the size in bits of the specified type,
3111 /// for which isSCEVable must return true.
3112 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
3113 assert(isSCEVable(Ty) && "Type is not SCEVable!");
3114
3115 // If we have a DataLayout, use it!
3116 if (DL)
3117 return DL->getTypeSizeInBits(Ty);
3118
3119 // Integer types have fixed sizes.
3120 if (Ty->isIntegerTy())
3121 return Ty->getPrimitiveSizeInBits();
3122
3123 // The only other supported type is pointer. Without DataLayout, conservatively
3124 // assume pointers are 64-bit.
3125 assert(Ty->isPointerTy() && "isSCEVable permitted a non-SCEVable type!");
3126 return 64;
3127 }
3128
3129 /// getEffectiveSCEVType - Return a type with the same bitwidth as
3130 /// the given type and which represents how SCEV will treat the given
3131 /// type, for which isSCEVable must return true. For pointer types,
3132 /// this is the pointer-sized integer type.
3133 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
3134 assert(isSCEVable(Ty) && "Type is not SCEVable!");
3135
3136 if (Ty->isIntegerTy()) {
3137 return Ty;
3138 }
3139
3140 // The only other supported type is pointer.
3141 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
3142
3143 if (DL)
3144 return DL->getIntPtrType(Ty);
3145
3146 // Without DataLayout, conservatively assume pointers are 64-bit.
3147 return Type::getInt64Ty(getContext());
3148 }
3149
3150 const SCEV *ScalarEvolution::getCouldNotCompute() {
3151 return &CouldNotCompute;
3152 }
3153
3154 namespace {
3155 // Helper class working with SCEVTraversal to figure out if a SCEV contains
3156 // a SCEVUnknown with a null value-pointer. FindInvalidSCEVUnknown::FindOne
3157 // is set iff such a SCEVUnknown is found.
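// A null value-pointer arises when the Value wrapped by a SCEVUnknown is
// deleted out from under the cache via its value-handle callback; an
// expression that still refers to it is stale and must be recomputed.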
3158 // 3159 struct FindInvalidSCEVUnknown { 3160 bool FindOne; 3161 FindInvalidSCEVUnknown() { FindOne = false; } 3162 bool follow(const SCEV *S) { 3163 switch (static_cast<SCEVTypes>(S->getSCEVType())) { 3164 case scConstant: 3165 return false; 3166 case scUnknown: 3167 if (!cast<SCEVUnknown>(S)->getValue()) 3168 FindOne = true; 3169 return false; 3170 default: 3171 return true; 3172 } 3173 } 3174 bool isDone() const { return FindOne; } 3175 }; 3176 } 3177 3178 bool ScalarEvolution::checkValidity(const SCEV *S) const { 3179 FindInvalidSCEVUnknown F; 3180 SCEVTraversal<FindInvalidSCEVUnknown> ST(F); 3181 ST.visitAll(S); 3182 3183 return !F.FindOne; 3184 } 3185 3186 /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the 3187 /// expression and create a new one. 3188 const SCEV *ScalarEvolution::getSCEV(Value *V) { 3189 assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); 3190 3191 ValueExprMapType::iterator I = ValueExprMap.find_as(V); 3192 if (I != ValueExprMap.end()) { 3193 const SCEV *S = I->second; 3194 if (checkValidity(S)) 3195 return S; 3196 else 3197 ValueExprMap.erase(I); 3198 } 3199 const SCEV *S = createSCEV(V); 3200 3201 // The process of creating a SCEV for V may have caused other SCEVs 3202 // to have been created, so it's necessary to insert the new entry 3203 // from scratch, rather than trying to remember the insert position 3204 // above. 3205 ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S)); 3206 return S; 3207 } 3208 3209 /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V 3210 /// 3211 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) { 3212 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) 3213 return getConstant( 3214 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue()))); 3215 3216 Type *Ty = V->getType(); 3217 Ty = getEffectiveSCEVType(Ty); 3218 return getMulExpr(V, 3219 getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)))); 3220 } 3221 3222 /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V 3223 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { 3224 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) 3225 return getConstant( 3226 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue()))); 3227 3228 Type *Ty = V->getType(); 3229 Ty = getEffectiveSCEVType(Ty); 3230 const SCEV *AllOnes = 3231 getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))); 3232 return getMinusSCEV(AllOnes, V); 3233 } 3234 3235 /// getMinusSCEV - Return LHS-RHS. Minus is represented in SCEV as A+B*-1. 3236 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, 3237 SCEV::NoWrapFlags Flags) { 3238 assert(!maskFlags(Flags, SCEV::FlagNUW) && "subtraction does not have NUW"); 3239 3240 // Fast path: X - X --> 0. 3241 if (LHS == RHS) 3242 return getConstant(LHS->getType(), 0); 3243 3244 // X - Y --> X + -Y. 3245 // X -(nsw || nuw) Y --> X + -Y. 3246 return getAddExpr(LHS, getNegativeSCEV(RHS)); 3247 } 3248 3249 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the 3250 /// input value to the specified type. If the type must be extended, it is zero 3251 /// extended. 
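/// For example, an i64 input converts to i32 with a truncate, an i16 input
/// converts to i32 with a zero extend, and an i32 input is returned unchanged.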
3252 const SCEV *
3253 ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
3254 Type *SrcTy = V->getType();
3255 assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3256 (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3257 "Cannot truncate or zero extend with non-integer arguments!");
3258 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3259 return V; // No conversion
3260 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
3261 return getTruncateExpr(V, Ty);
3262 return getZeroExtendExpr(V, Ty);
3263 }
3264
3265 /// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
3266 /// input value to the specified type. If the type must be extended, it is sign
3267 /// extended.
3268 const SCEV *
3269 ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
3270 Type *Ty) {
3271 Type *SrcTy = V->getType();
3272 assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3273 (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3274 "Cannot truncate or sign extend with non-integer arguments!");
3275 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3276 return V; // No conversion
3277 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
3278 return getTruncateExpr(V, Ty);
3279 return getSignExtendExpr(V, Ty);
3280 }
3281
3282 /// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
3283 /// input value to the specified type. If the type must be extended, it is zero
3284 /// extended. The conversion must not be narrowing.
3285 const SCEV *
3286 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
3287 Type *SrcTy = V->getType();
3288 assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3289 (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3290 "Cannot noop or zero extend with non-integer arguments!");
3291 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
3292 "getNoopOrZeroExtend cannot truncate!");
3293 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3294 return V; // No conversion
3295 return getZeroExtendExpr(V, Ty);
3296 }
3297
3298 /// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
3299 /// input value to the specified type. If the type must be extended, it is sign
3300 /// extended. The conversion must not be narrowing.
3301 const SCEV *
3302 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
3303 Type *SrcTy = V->getType();
3304 assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3305 (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3306 "Cannot noop or sign extend with non-integer arguments!");
3307 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
3308 "getNoopOrSignExtend cannot truncate!");
3309 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3310 return V; // No conversion
3311 return getSignExtendExpr(V, Ty);
3312 }
3313
3314 /// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
3315 /// the input value to the specified type. If the type must be extended,
3316 /// it is extended with unspecified bits. The conversion must not be
3317 /// narrowing.
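/// For example, widening an i8 value to i32 preserves the low 8 bits, but
/// callers may not assume anything about the resulting upper 24 bits.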
3318 const SCEV * 3319 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) { 3320 Type *SrcTy = V->getType(); 3321 assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && 3322 (Ty->isIntegerTy() || Ty->isPointerTy()) && 3323 "Cannot noop or any extend with non-integer arguments!"); 3324 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && 3325 "getNoopOrAnyExtend cannot truncate!"); 3326 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) 3327 return V; // No conversion 3328 return getAnyExtendExpr(V, Ty); 3329 } 3330 3331 /// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the 3332 /// input value to the specified type. The conversion must not be widening. 3333 const SCEV * 3334 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) { 3335 Type *SrcTy = V->getType(); 3336 assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && 3337 (Ty->isIntegerTy() || Ty->isPointerTy()) && 3338 "Cannot truncate or noop with non-integer arguments!"); 3339 assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) && 3340 "getTruncateOrNoop cannot extend!"); 3341 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) 3342 return V; // No conversion 3343 return getTruncateExpr(V, Ty); 3344 } 3345 3346 /// getUMaxFromMismatchedTypes - Promote the operands to the wider of 3347 /// the types using zero-extension, and then perform a umax operation 3348 /// with them. 3349 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS, 3350 const SCEV *RHS) { 3351 const SCEV *PromotedLHS = LHS; 3352 const SCEV *PromotedRHS = RHS; 3353 3354 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType())) 3355 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType()); 3356 else 3357 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType()); 3358 3359 return getUMaxExpr(PromotedLHS, PromotedRHS); 3360 } 3361 3362 /// getUMinFromMismatchedTypes - Promote the operands to the wider of 3363 /// the types using zero-extension, and then perform a umin operation 3364 /// with them. 3365 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, 3366 const SCEV *RHS) { 3367 const SCEV *PromotedLHS = LHS; 3368 const SCEV *PromotedRHS = RHS; 3369 3370 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType())) 3371 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType()); 3372 else 3373 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType()); 3374 3375 return getUMinExpr(PromotedLHS, PromotedRHS); 3376 } 3377 3378 /// getPointerBase - Transitively follow the chain of pointer-type operands 3379 /// until reaching a SCEV that does not have a single pointer operand. This 3380 /// returns a SCEVUnknown pointer for well-formed pointer-type expressions, 3381 /// but corner cases do exist. 3382 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { 3383 // A pointer operand may evaluate to a nonpointer expression, such as null. 3384 if (!V->getType()->isPointerTy()) 3385 return V; 3386 3387 if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) { 3388 return getPointerBase(Cast->getOperand()); 3389 } 3390 else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) { 3391 const SCEV *PtrOp = nullptr; 3392 for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); 3393 I != E; ++I) { 3394 if ((*I)->getType()->isPointerTy()) { 3395 // Cannot find the base of an expression with multiple pointer operands. 
3396 if (PtrOp)
3397 return V;
3398 PtrOp = *I;
3399 }
3400 }
3401 if (!PtrOp)
3402 return V;
3403 return getPointerBase(PtrOp);
3404 }
3405 return V;
3406 }
3407
3408 /// PushDefUseChildren - Push users of the given Instruction
3409 /// onto the given Worklist.
3410 static void
3411 PushDefUseChildren(Instruction *I,
3412 SmallVectorImpl<Instruction *> &Worklist) {
3413 // Push the def-use children onto the Worklist stack.
3414 for (User *U : I->users())
3415 Worklist.push_back(cast<Instruction>(U));
3416 }
3417
3418 /// ForgetSymbolicName - This looks up computed SCEV values for all
3419 /// instructions that depend on the given instruction and removes them from
3420 /// ValueExprMap if they reference SymName. This is used during PHI
3421 /// resolution.
3422 void
3423 ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
3424 SmallVector<Instruction *, 16> Worklist;
3425 PushDefUseChildren(PN, Worklist);
3426
3427 SmallPtrSet<Instruction *, 8> Visited;
3428 Visited.insert(PN);
3429 while (!Worklist.empty()) {
3430 Instruction *I = Worklist.pop_back_val();
3431 if (!Visited.insert(I).second)
3432 continue;
3433
3434 ValueExprMapType::iterator It =
3435 ValueExprMap.find_as(static_cast<Value *>(I));
3436 if (It != ValueExprMap.end()) {
3437 const SCEV *Old = It->second;
3438
3439 // Short-circuit the def-use traversal if the symbolic name
3440 // ceases to appear in expressions.
3441 if (Old != SymName && !hasOperand(Old, SymName))
3442 continue;
3443
3444 // SCEVUnknown for a PHI either means that it has an unrecognized
3445 // structure, it's a PHI that's in the process of being computed
3446 // by createNodeForPHI, or it's a single-value PHI. In the first case,
3447 // additional loop trip count information isn't going to change anything.
3448 // In the second case, createNodeForPHI will perform the necessary
3449 // updates on its own when it gets to that point. In the third, we do
3450 // want to forget the SCEVUnknown.
3451 if (!isa<PHINode>(I) ||
3452 !isa<SCEVUnknown>(Old) ||
3453 (I != PN && Old == SymName)) {
3454 forgetMemoizedResults(Old);
3455 ValueExprMap.erase(It);
3456 }
3457 }
3458
3459 PushDefUseChildren(I, Worklist);
3460 }
3461 }
3462
3463 /// createNodeForPHI - PHI nodes have two cases. Either the PHI node exists in
3464 /// a loop header, making it a potential recurrence, or it doesn't.
3465 ///
3466 const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
3467 if (const Loop *L = LI->getLoopFor(PN->getParent()))
3468 if (L->getHeader() == PN->getParent()) {
3469 // The loop may have multiple entrances or multiple exits; we can analyze
3470 // this phi as an addrec if it has a unique entry value and a unique
3471 // backedge value.
3472 Value *BEValueV = nullptr, *StartValueV = nullptr;
3473 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
3474 Value *V = PN->getIncomingValue(i);
3475 if (L->contains(PN->getIncomingBlock(i))) {
3476 if (!BEValueV) {
3477 BEValueV = V;
3478 } else if (BEValueV != V) {
3479 BEValueV = nullptr;
3480 break;
3481 }
3482 } else if (!StartValueV) {
3483 StartValueV = V;
3484 } else if (StartValueV != V) {
3485 StartValueV = nullptr;
3486 break;
3487 }
3488 }
3489 if (BEValueV && StartValueV) {
3490 // While we are analyzing this PHI node, handle its value symbolically.
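// For example, in 'for (i = 0; i != n; ++i)' the PHI for i has start value 0
// and backedge value i+1; with the PHI mapped to a symbolic unknown, the
// backedge value analyzes to (symbolic i) + 1, which exposes the +1 step
// below and yields the addrec {0,+,1}.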
3491 const SCEV *SymbolicName = getUnknown(PN); 3492 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() && 3493 "PHI node already processed?"); 3494 ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName)); 3495 3496 // Using this symbolic name for the PHI, analyze the value coming around 3497 // the back-edge. 3498 const SCEV *BEValue = getSCEV(BEValueV); 3499 3500 // NOTE: If BEValue is loop invariant, we know that the PHI node just 3501 // has a special value for the first iteration of the loop. 3502 3503 // If the value coming around the backedge is an add with the symbolic 3504 // value we just inserted, then we found a simple induction variable! 3505 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) { 3506 // If there is a single occurrence of the symbolic value, replace it 3507 // with a recurrence. 3508 unsigned FoundIndex = Add->getNumOperands(); 3509 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) 3510 if (Add->getOperand(i) == SymbolicName) 3511 if (FoundIndex == e) { 3512 FoundIndex = i; 3513 break; 3514 } 3515 3516 if (FoundIndex != Add->getNumOperands()) { 3517 // Create an add with everything but the specified operand. 3518 SmallVector<const SCEV *, 8> Ops; 3519 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) 3520 if (i != FoundIndex) 3521 Ops.push_back(Add->getOperand(i)); 3522 const SCEV *Accum = getAddExpr(Ops); 3523 3524 // This is not a valid addrec if the step amount is varying each 3525 // loop iteration, but is not itself an addrec in this loop. 3526 if (isLoopInvariant(Accum, L) || 3527 (isa<SCEVAddRecExpr>(Accum) && 3528 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) { 3529 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; 3530 3531 // If the increment doesn't overflow, then neither the addrec nor 3532 // the post-increment will overflow. 3533 if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) { 3534 if (OBO->hasNoUnsignedWrap()) 3535 Flags = setFlags(Flags, SCEV::FlagNUW); 3536 if (OBO->hasNoSignedWrap()) 3537 Flags = setFlags(Flags, SCEV::FlagNSW); 3538 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) { 3539 // If the increment is an inbounds GEP, then we know the address 3540 // space cannot be wrapped around. We cannot make any guarantee 3541 // about signed or unsigned overflow because pointers are 3542 // unsigned but we may have a negative index from the base 3543 // pointer. We can guarantee that no unsigned wrap occurs if the 3544 // indices form a positive value. 3545 if (GEP->isInBounds()) { 3546 Flags = setFlags(Flags, SCEV::FlagNW); 3547 3548 const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); 3549 if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) 3550 Flags = setFlags(Flags, SCEV::FlagNUW); 3551 } 3552 3553 // We cannot transfer nuw and nsw flags from subtraction 3554 // operations -- sub nuw X, Y is not the same as add nuw X, -Y 3555 // for instance. 3556 } 3557 3558 const SCEV *StartVal = getSCEV(StartValueV); 3559 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); 3560 3561 // Since the no-wrap flags are on the increment, they apply to the 3562 // post-incremented value as well. 3563 if (isLoopInvariant(Accum, L)) 3564 (void)getAddRecExpr(getAddExpr(StartVal, Accum), 3565 Accum, L, Flags); 3566 3567 // Okay, for the entire analysis of this edge we assumed the PHI 3568 // to be symbolic. We now need to go back and purge all of the 3569 // entries for the scalars that use the symbolic expression. 
3570 ForgetSymbolicName(PN, SymbolicName); 3571 ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; 3572 return PHISCEV; 3573 } 3574 } 3575 } else if (const SCEVAddRecExpr *AddRec = 3576 dyn_cast<SCEVAddRecExpr>(BEValue)) { 3577 // Otherwise, this could be a loop like this: 3578 // i = 0; for (j = 1; ..; ++j) { .... i = j; } 3579 // In this case, j = {1,+,1} and BEValue is j. 3580 // Because the other in-value of i (0) fits the evolution of BEValue 3581 // i really is an addrec evolution. 3582 if (AddRec->getLoop() == L && AddRec->isAffine()) { 3583 const SCEV *StartVal = getSCEV(StartValueV); 3584 3585 // If StartVal = j.start - j.stride, we can use StartVal as the 3586 // initial step of the addrec evolution. 3587 if (StartVal == getMinusSCEV(AddRec->getOperand(0), 3588 AddRec->getOperand(1))) { 3589 // FIXME: For constant StartVal, we should be able to infer 3590 // no-wrap flags. 3591 const SCEV *PHISCEV = 3592 getAddRecExpr(StartVal, AddRec->getOperand(1), L, 3593 SCEV::FlagAnyWrap); 3594 3595 // Okay, for the entire analysis of this edge we assumed the PHI 3596 // to be symbolic. We now need to go back and purge all of the 3597 // entries for the scalars that use the symbolic expression. 3598 ForgetSymbolicName(PN, SymbolicName); 3599 ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; 3600 return PHISCEV; 3601 } 3602 } 3603 } 3604 } 3605 } 3606 3607 // If the PHI has a single incoming value, follow that value, unless the 3608 // PHI's incoming blocks are in a different loop, in which case doing so 3609 // risks breaking LCSSA form. Instcombine would normally zap these, but 3610 // it doesn't have DominatorTree information, so it may miss cases. 3611 if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AC)) 3612 if (LI->replacementPreservesLCSSAForm(PN, V)) 3613 return getSCEV(V); 3614 3615 // If it's not a loop phi, we can't handle it yet. 3616 return getUnknown(PN); 3617 } 3618 3619 /// createNodeForGEP - Expand GEP instructions into add and multiply 3620 /// operations. This allows them to be analyzed by regular SCEV code. 3621 /// 3622 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { 3623 Type *IntPtrTy = getEffectiveSCEVType(GEP->getType()); 3624 Value *Base = GEP->getOperand(0); 3625 // Don't attempt to analyze GEPs over unsized objects. 3626 if (!Base->getType()->getPointerElementType()->isSized()) 3627 return getUnknown(GEP); 3628 3629 // Don't blindly transfer the inbounds flag from the GEP instruction to the 3630 // Add expression, because the Instruction may be guarded by control flow 3631 // and the no-overflow bits may not be valid for the expression in any 3632 // context. 3633 SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW : SCEV::FlagAnyWrap; 3634 3635 const SCEV *TotalOffset = getConstant(IntPtrTy, 0); 3636 gep_type_iterator GTI = gep_type_begin(GEP); 3637 for (GetElementPtrInst::op_iterator I = std::next(GEP->op_begin()), 3638 E = GEP->op_end(); 3639 I != E; ++I) { 3640 Value *Index = *I; 3641 // Compute the (potentially symbolic) offset in bytes for this index. 3642 if (StructType *STy = dyn_cast<StructType>(*GTI++)) { 3643 // For a struct, add the member offset. 3644 unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue(); 3645 const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo); 3646 3647 // Add the field offset to the running total offset. 3648 TotalOffset = getAddExpr(TotalOffset, FieldOffset); 3649 } else { 3650 // For an array, add the element offset, explicitly scaled. 
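// For example, 'getelementptr i32* %p, i64 %i' contributes %i * 4 (the
// allocation size of i32) to the running offset.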
3651 const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, *GTI); 3652 const SCEV *IndexS = getSCEV(Index); 3653 // Getelementptr indices are signed. 3654 IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy); 3655 3656 // Multiply the index by the element size to compute the element offset. 3657 const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize, Wrap); 3658 3659 // Add the element offset to the running total offset. 3660 TotalOffset = getAddExpr(TotalOffset, LocalOffset); 3661 } 3662 } 3663 3664 // Get the SCEV for the GEP base. 3665 const SCEV *BaseS = getSCEV(Base); 3666 3667 // Add the total offset from all the GEP indices to the base. 3668 return getAddExpr(BaseS, TotalOffset, Wrap); 3669 } 3670 3671 /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is 3672 /// guaranteed to end in (at every loop iteration). It is, at the same time, 3673 /// the minimum number of times S is divisible by 2. For example, given {4,+,8} 3674 /// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S. 3675 uint32_t 3676 ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { 3677 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) 3678 return C->getValue()->getValue().countTrailingZeros(); 3679 3680 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S)) 3681 return std::min(GetMinTrailingZeros(T->getOperand()), 3682 (uint32_t)getTypeSizeInBits(T->getType())); 3683 3684 if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) { 3685 uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); 3686 return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? 3687 getTypeSizeInBits(E->getType()) : OpRes; 3688 } 3689 3690 if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) { 3691 uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); 3692 return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? 3693 getTypeSizeInBits(E->getType()) : OpRes; 3694 } 3695 3696 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) { 3697 // The result is the min of all operands results. 3698 uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); 3699 for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) 3700 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); 3701 return MinOpRes; 3702 } 3703 3704 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) { 3705 // The result is the sum of all operands results. 3706 uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); 3707 uint32_t BitWidth = getTypeSizeInBits(M->getType()); 3708 for (unsigned i = 1, e = M->getNumOperands(); 3709 SumOpRes != BitWidth && i != e; ++i) 3710 SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), 3711 BitWidth); 3712 return SumOpRes; 3713 } 3714 3715 if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) { 3716 // The result is the min of all operands results. 3717 uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); 3718 for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) 3719 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); 3720 return MinOpRes; 3721 } 3722 3723 if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) { 3724 // The result is the min of all operands results. 
3725 uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); 3726 for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) 3727 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); 3728 return MinOpRes; 3729 } 3730 3731 if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) { 3732 // The result is the min of all operands results. 3733 uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); 3734 for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) 3735 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); 3736 return MinOpRes; 3737 } 3738 3739 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { 3740 // For a SCEVUnknown, ask ValueTracking. 3741 unsigned BitWidth = getTypeSizeInBits(U->getType()); 3742 APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); 3743 computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT); 3744 return Zeros.countTrailingOnes(); 3745 } 3746 3747 // SCEVUDivExpr 3748 return 0; 3749 } 3750 3751 /// GetRangeFromMetadata - Helper method to assign a range to V from 3752 /// metadata present in the IR. 3753 static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { 3754 if (Instruction *I = dyn_cast<Instruction>(V)) { 3755 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) { 3756 ConstantRange TotalRange( 3757 cast<IntegerType>(I->getType())->getBitWidth(), false); 3758 3759 unsigned NumRanges = MD->getNumOperands() / 2; 3760 assert(NumRanges >= 1); 3761 3762 for (unsigned i = 0; i < NumRanges; ++i) { 3763 ConstantInt *Lower = 3764 mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 0)); 3765 ConstantInt *Upper = 3766 mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 1)); 3767 ConstantRange Range(Lower->getValue(), Upper->getValue()); 3768 TotalRange = TotalRange.unionWith(Range); 3769 } 3770 3771 return TotalRange; 3772 } 3773 } 3774 3775 return None; 3776 } 3777 3778 /// getUnsignedRange - Determine the unsigned range for a particular SCEV. 3779 /// 3780 ConstantRange 3781 ScalarEvolution::getUnsignedRange(const SCEV *S) { 3782 // See if we've computed this range already. 3783 DenseMap<const SCEV *, ConstantRange>::iterator I = UnsignedRanges.find(S); 3784 if (I != UnsignedRanges.end()) 3785 return I->second; 3786 3787 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) 3788 return setUnsignedRange(C, ConstantRange(C->getValue()->getValue())); 3789 3790 unsigned BitWidth = getTypeSizeInBits(S->getType()); 3791 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); 3792 3793 // If the value has known zeros, the maximum unsigned value will have those 3794 // known zeros as well. 
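// For example, an 8-bit value known to be a multiple of 4 (TZ == 2) gets the
// range [0, 253), i.e. its largest possible value is 252.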
3795 uint32_t TZ = GetMinTrailingZeros(S); 3796 if (TZ != 0) 3797 ConservativeResult = 3798 ConstantRange(APInt::getMinValue(BitWidth), 3799 APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1); 3800 3801 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { 3802 ConstantRange X = getUnsignedRange(Add->getOperand(0)); 3803 for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) 3804 X = X.add(getUnsignedRange(Add->getOperand(i))); 3805 return setUnsignedRange(Add, ConservativeResult.intersectWith(X)); 3806 } 3807 3808 if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) { 3809 ConstantRange X = getUnsignedRange(Mul->getOperand(0)); 3810 for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) 3811 X = X.multiply(getUnsignedRange(Mul->getOperand(i))); 3812 return setUnsignedRange(Mul, ConservativeResult.intersectWith(X)); 3813 } 3814 3815 if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) { 3816 ConstantRange X = getUnsignedRange(SMax->getOperand(0)); 3817 for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) 3818 X = X.smax(getUnsignedRange(SMax->getOperand(i))); 3819 return setUnsignedRange(SMax, ConservativeResult.intersectWith(X)); 3820 } 3821 3822 if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) { 3823 ConstantRange X = getUnsignedRange(UMax->getOperand(0)); 3824 for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) 3825 X = X.umax(getUnsignedRange(UMax->getOperand(i))); 3826 return setUnsignedRange(UMax, ConservativeResult.intersectWith(X)); 3827 } 3828 3829 if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) { 3830 ConstantRange X = getUnsignedRange(UDiv->getLHS()); 3831 ConstantRange Y = getUnsignedRange(UDiv->getRHS()); 3832 return setUnsignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y))); 3833 } 3834 3835 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) { 3836 ConstantRange X = getUnsignedRange(ZExt->getOperand()); 3837 return setUnsignedRange(ZExt, 3838 ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); 3839 } 3840 3841 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) { 3842 ConstantRange X = getUnsignedRange(SExt->getOperand()); 3843 return setUnsignedRange(SExt, 3844 ConservativeResult.intersectWith(X.signExtend(BitWidth))); 3845 } 3846 3847 if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) { 3848 ConstantRange X = getUnsignedRange(Trunc->getOperand()); 3849 return setUnsignedRange(Trunc, 3850 ConservativeResult.intersectWith(X.truncate(BitWidth))); 3851 } 3852 3853 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) { 3854 // If there's no unsigned wrap, the value will never be less than its 3855 // initial value. 
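// For example, {5,+,%s}<nuw> can never dip below its start, so its range is
// intersected with the wrapped interval [5, 0), i.e. everything from 5 up to
// the unsigned maximum.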
3856 if (AddRec->getNoWrapFlags(SCEV::FlagNUW)) 3857 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart())) 3858 if (!C->getValue()->isZero()) 3859 ConservativeResult = 3860 ConservativeResult.intersectWith( 3861 ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0))); 3862 3863 // TODO: non-affine addrec 3864 if (AddRec->isAffine()) { 3865 Type *Ty = AddRec->getType(); 3866 const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); 3867 if (!isa<SCEVCouldNotCompute>(MaxBECount) && 3868 getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { 3869 MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty); 3870 3871 const SCEV *Start = AddRec->getStart(); 3872 const SCEV *Step = AddRec->getStepRecurrence(*this); 3873 3874 ConstantRange StartRange = getUnsignedRange(Start); 3875 ConstantRange StepRange = getSignedRange(Step); 3876 ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); 3877 ConstantRange EndRange = 3878 StartRange.add(MaxBECountRange.multiply(StepRange)); 3879 3880 // Check for overflow. This must be done with ConstantRange arithmetic 3881 // because we could be called from within the ScalarEvolution overflow 3882 // checking code. 3883 ConstantRange ExtStartRange = StartRange.zextOrTrunc(BitWidth*2+1); 3884 ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1); 3885 ConstantRange ExtMaxBECountRange = 3886 MaxBECountRange.zextOrTrunc(BitWidth*2+1); 3887 ConstantRange ExtEndRange = EndRange.zextOrTrunc(BitWidth*2+1); 3888 if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) != 3889 ExtEndRange) 3890 return setUnsignedRange(AddRec, ConservativeResult); 3891 3892 APInt Min = APIntOps::umin(StartRange.getUnsignedMin(), 3893 EndRange.getUnsignedMin()); 3894 APInt Max = APIntOps::umax(StartRange.getUnsignedMax(), 3895 EndRange.getUnsignedMax()); 3896 if (Min.isMinValue() && Max.isMaxValue()) 3897 return setUnsignedRange(AddRec, ConservativeResult); 3898 return setUnsignedRange(AddRec, 3899 ConservativeResult.intersectWith(ConstantRange(Min, Max+1))); 3900 } 3901 } 3902 3903 return setUnsignedRange(AddRec, ConservativeResult); 3904 } 3905 3906 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { 3907 // Check if the IR explicitly contains !range metadata. 3908 Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue()); 3909 if (MDRange.hasValue()) 3910 ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue()); 3911 3912 // For a SCEVUnknown, ask ValueTracking. 3913 APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); 3914 computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT); 3915 if (Ones == ~Zeros + 1) 3916 return setUnsignedRange(U, ConservativeResult); 3917 return setUnsignedRange(U, 3918 ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1))); 3919 } 3920 3921 return setUnsignedRange(S, ConservativeResult); 3922 } 3923 3924 /// getSignedRange - Determine the signed range for a particular SCEV. 3925 /// 3926 ConstantRange 3927 ScalarEvolution::getSignedRange(const SCEV *S) { 3928 // See if we've computed this range already. 
3929 DenseMap<const SCEV *, ConstantRange>::iterator I = SignedRanges.find(S); 3930 if (I != SignedRanges.end()) 3931 return I->second; 3932 3933 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) 3934 return setSignedRange(C, ConstantRange(C->getValue()->getValue())); 3935 3936 unsigned BitWidth = getTypeSizeInBits(S->getType()); 3937 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); 3938 3939 // If the value has known zeros, the maximum signed value will have those 3940 // known zeros as well. 3941 uint32_t TZ = GetMinTrailingZeros(S); 3942 if (TZ != 0) 3943 ConservativeResult = 3944 ConstantRange(APInt::getSignedMinValue(BitWidth), 3945 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); 3946 3947 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { 3948 ConstantRange X = getSignedRange(Add->getOperand(0)); 3949 for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) 3950 X = X.add(getSignedRange(Add->getOperand(i))); 3951 return setSignedRange(Add, ConservativeResult.intersectWith(X)); 3952 } 3953 3954 if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) { 3955 ConstantRange X = getSignedRange(Mul->getOperand(0)); 3956 for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) 3957 X = X.multiply(getSignedRange(Mul->getOperand(i))); 3958 return setSignedRange(Mul, ConservativeResult.intersectWith(X)); 3959 } 3960 3961 if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) { 3962 ConstantRange X = getSignedRange(SMax->getOperand(0)); 3963 for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) 3964 X = X.smax(getSignedRange(SMax->getOperand(i))); 3965 return setSignedRange(SMax, ConservativeResult.intersectWith(X)); 3966 } 3967 3968 if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) { 3969 ConstantRange X = getSignedRange(UMax->getOperand(0)); 3970 for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) 3971 X = X.umax(getSignedRange(UMax->getOperand(i))); 3972 return setSignedRange(UMax, ConservativeResult.intersectWith(X)); 3973 } 3974 3975 if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) { 3976 ConstantRange X = getSignedRange(UDiv->getLHS()); 3977 ConstantRange Y = getSignedRange(UDiv->getRHS()); 3978 return setSignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y))); 3979 } 3980 3981 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) { 3982 ConstantRange X = getSignedRange(ZExt->getOperand()); 3983 return setSignedRange(ZExt, 3984 ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); 3985 } 3986 3987 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) { 3988 ConstantRange X = getSignedRange(SExt->getOperand()); 3989 return setSignedRange(SExt, 3990 ConservativeResult.intersectWith(X.signExtend(BitWidth))); 3991 } 3992 3993 if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) { 3994 ConstantRange X = getSignedRange(Trunc->getOperand()); 3995 return setSignedRange(Trunc, 3996 ConservativeResult.intersectWith(X.truncate(BitWidth))); 3997 } 3998 3999 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) { 4000 // If there's no signed wrap, and all the operands have the same sign or 4001 // zero, the value won't ever change sign. 
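// For example, every operand of {1,+,2}<nsw> is non-negative, so its signed
// range is intersected with [0, signed maximum].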
4002 if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) { 4003 bool AllNonNeg = true; 4004 bool AllNonPos = true; 4005 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { 4006 if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false; 4007 if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false; 4008 } 4009 if (AllNonNeg) 4010 ConservativeResult = ConservativeResult.intersectWith( 4011 ConstantRange(APInt(BitWidth, 0), 4012 APInt::getSignedMinValue(BitWidth))); 4013 else if (AllNonPos) 4014 ConservativeResult = ConservativeResult.intersectWith( 4015 ConstantRange(APInt::getSignedMinValue(BitWidth), 4016 APInt(BitWidth, 1))); 4017 } 4018 4019 // TODO: non-affine addrec 4020 if (AddRec->isAffine()) { 4021 Type *Ty = AddRec->getType(); 4022 const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); 4023 if (!isa<SCEVCouldNotCompute>(MaxBECount) && 4024 getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { 4025 MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty); 4026 4027 const SCEV *Start = AddRec->getStart(); 4028 const SCEV *Step = AddRec->getStepRecurrence(*this); 4029 4030 ConstantRange StartRange = getSignedRange(Start); 4031 ConstantRange StepRange = getSignedRange(Step); 4032 ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); 4033 ConstantRange EndRange = 4034 StartRange.add(MaxBECountRange.multiply(StepRange)); 4035 4036 // Check for overflow. This must be done with ConstantRange arithmetic 4037 // because we could be called from within the ScalarEvolution overflow 4038 // checking code. 4039 ConstantRange ExtStartRange = StartRange.sextOrTrunc(BitWidth*2+1); 4040 ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1); 4041 ConstantRange ExtMaxBECountRange = 4042 MaxBECountRange.zextOrTrunc(BitWidth*2+1); 4043 ConstantRange ExtEndRange = EndRange.sextOrTrunc(BitWidth*2+1); 4044 if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) != 4045 ExtEndRange) 4046 return setSignedRange(AddRec, ConservativeResult); 4047 4048 APInt Min = APIntOps::smin(StartRange.getSignedMin(), 4049 EndRange.getSignedMin()); 4050 APInt Max = APIntOps::smax(StartRange.getSignedMax(), 4051 EndRange.getSignedMax()); 4052 if (Min.isMinSignedValue() && Max.isMaxSignedValue()) 4053 return setSignedRange(AddRec, ConservativeResult); 4054 return setSignedRange(AddRec, 4055 ConservativeResult.intersectWith(ConstantRange(Min, Max+1))); 4056 } 4057 } 4058 4059 return setSignedRange(AddRec, ConservativeResult); 4060 } 4061 4062 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { 4063 // Check if the IR explicitly contains !range metadata. 4064 Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue()); 4065 if (MDRange.hasValue()) 4066 ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue()); 4067 4068 // For a SCEVUnknown, ask ValueTracking. 4069 if (!U->getValue()->getType()->isIntegerTy() && !DL) 4070 return setSignedRange(U, ConservativeResult); 4071 unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AC, nullptr, DT); 4072 if (NS <= 1) 4073 return setSignedRange(U, ConservativeResult); 4074 return setSignedRange(U, ConservativeResult.intersectWith( 4075 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), 4076 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1))); 4077 } 4078 4079 return setSignedRange(S, ConservativeResult); 4080 } 4081 4082 /// createSCEV - We know that there is no SCEV for the specified value. 4083 /// Analyze the expression. 
4084 /// 4085 const SCEV *ScalarEvolution::createSCEV(Value *V) { 4086 if (!isSCEVable(V->getType())) 4087 return getUnknown(V); 4088 4089 unsigned Opcode = Instruction::UserOp1; 4090 if (Instruction *I = dyn_cast<Instruction>(V)) { 4091 Opcode = I->getOpcode(); 4092 4093 // Don't attempt to analyze instructions in blocks that aren't 4094 // reachable. Such instructions don't matter, and they aren't required 4095 // to obey basic rules for definitions dominating uses which this 4096 // analysis depends on. 4097 if (!DT->isReachableFromEntry(I->getParent())) 4098 return getUnknown(V); 4099 } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) 4100 Opcode = CE->getOpcode(); 4101 else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) 4102 return getConstant(CI); 4103 else if (isa<ConstantPointerNull>(V)) 4104 return getConstant(V->getType(), 0); 4105 else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) 4106 return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee()); 4107 else 4108 return getUnknown(V); 4109 4110 Operator *U = cast<Operator>(V); 4111 switch (Opcode) { 4112 case Instruction::Add: { 4113 // The simple thing to do would be to just call getSCEV on both operands 4114 // and call getAddExpr with the result. However if we're looking at a 4115 // bunch of things all added together, this can be quite inefficient, 4116 // because it leads to N-1 getAddExpr calls for N ultimate operands. 4117 // Instead, gather up all the operands and make a single getAddExpr call. 4118 // LLVM IR canonical form means we need only traverse the left operands. 4119 // 4120 // Don't apply this instruction's NSW or NUW flags to the new 4121 // expression. The instruction may be guarded by control flow that the 4122 // no-wrap behavior depends on. Non-control-equivalent instructions can be 4123 // mapped to the same SCEV expression, and it would be incorrect to transfer 4124 // NSW/NUW semantics to those operations. 4125 SmallVector<const SCEV *, 4> AddOps; 4126 AddOps.push_back(getSCEV(U->getOperand(1))); 4127 for (Value *Op = U->getOperand(0); ; Op = U->getOperand(0)) { 4128 unsigned Opcode = Op->getValueID() - Value::InstructionVal; 4129 if (Opcode != Instruction::Add && Opcode != Instruction::Sub) 4130 break; 4131 U = cast<Operator>(Op); 4132 const SCEV *Op1 = getSCEV(U->getOperand(1)); 4133 if (Opcode == Instruction::Sub) 4134 AddOps.push_back(getNegativeSCEV(Op1)); 4135 else 4136 AddOps.push_back(Op1); 4137 } 4138 AddOps.push_back(getSCEV(U->getOperand(0))); 4139 return getAddExpr(AddOps); 4140 } 4141 case Instruction::Mul: { 4142 // Don't transfer NSW/NUW for the same reason as AddExpr. 4143 SmallVector<const SCEV *, 4> MulOps; 4144 MulOps.push_back(getSCEV(U->getOperand(1))); 4145 for (Value *Op = U->getOperand(0); 4146 Op->getValueID() == Instruction::Mul + Value::InstructionVal; 4147 Op = U->getOperand(0)) { 4148 U = cast<Operator>(Op); 4149 MulOps.push_back(getSCEV(U->getOperand(1))); 4150 } 4151 MulOps.push_back(getSCEV(U->getOperand(0))); 4152 return getMulExpr(MulOps); 4153 } 4154 case Instruction::UDiv: 4155 return getUDivExpr(getSCEV(U->getOperand(0)), 4156 getSCEV(U->getOperand(1))); 4157 case Instruction::Sub: 4158 return getMinusSCEV(getSCEV(U->getOperand(0)), 4159 getSCEV(U->getOperand(1))); 4160 case Instruction::And: 4161 // For an expression like x&255 that merely masks off the high bits, 4162 // use zext(trunc(x)) as the SCEV expression. 
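// For example, for an i32 %x, '%x & 255' becomes
// (zext i8 (trunc i32 %x) to i32), which the cast folders can reason about.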
4163 if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { 4164 if (CI->isNullValue()) 4165 return getSCEV(U->getOperand(1)); 4166 if (CI->isAllOnesValue()) 4167 return getSCEV(U->getOperand(0)); 4168 const APInt &A = CI->getValue(); 4169 4170 // Instcombine's ShrinkDemandedConstant may strip bits out of 4171 // constants, obscuring what would otherwise be a low-bits mask. 4172 // Use computeKnownBits to compute what ShrinkDemandedConstant 4173 // knew about to reconstruct a low-bits mask value. 4174 unsigned LZ = A.countLeadingZeros(); 4175 unsigned TZ = A.countTrailingZeros(); 4176 unsigned BitWidth = A.getBitWidth(); 4177 APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); 4178 computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, 0, AC, 4179 nullptr, DT); 4180 4181 APInt EffectiveMask = 4182 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); 4183 if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) { 4184 const SCEV *MulCount = getConstant( 4185 ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, TZ))); 4186 return getMulExpr( 4187 getZeroExtendExpr( 4188 getTruncateExpr( 4189 getUDivExactExpr(getSCEV(U->getOperand(0)), MulCount), 4190 IntegerType::get(getContext(), BitWidth - LZ - TZ)), 4191 U->getType()), 4192 MulCount); 4193 } 4194 } 4195 break; 4196 4197 case Instruction::Or: 4198 // If the RHS of the Or is a constant, we may have something like: 4199 // X*4+1 which got turned into X*4|1. Handle this as an Add so loop 4200 // optimizations will transparently handle this case. 4201 // 4202 // In order for this transformation to be safe, the LHS must be of the 4203 // form X*(2^n) and the Or constant must be less than 2^n. 4204 if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { 4205 const SCEV *LHS = getSCEV(U->getOperand(0)); 4206 const APInt &CIVal = CI->getValue(); 4207 if (GetMinTrailingZeros(LHS) >= 4208 (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { 4209 // Build a plain add SCEV. 4210 const SCEV *S = getAddExpr(LHS, getSCEV(CI)); 4211 // If the LHS of the add was an addrec and it has no-wrap flags, 4212 // transfer the no-wrap flags, since an or won't introduce a wrap. 4213 if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) { 4214 const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS); 4215 const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags( 4216 OldAR->getNoWrapFlags()); 4217 } 4218 return S; 4219 } 4220 } 4221 break; 4222 case Instruction::Xor: 4223 if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { 4224 // If the RHS of the xor is a signbit, then this is just an add. 4225 // Instcombine turns add of signbit into xor as a strength reduction step. 4226 if (CI->getValue().isSignBit()) 4227 return getAddExpr(getSCEV(U->getOperand(0)), 4228 getSCEV(U->getOperand(1))); 4229 4230 // If the RHS of xor is -1, then this is a not operation. 4231 if (CI->isAllOnesValue()) 4232 return getNotSCEV(getSCEV(U->getOperand(0))); 4233 4234 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. 4235 // This is a variant of the check for xor with -1, and it handles 4236 // the case where instcombine has trimmed non-demanded bits out 4237 // of an xor with -1. 
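// For example, if and(%x, 7) was analyzed as (zext i3 (trunc %x) to i32),
// then xor'ing that result with 7 complements the low three bits and is
// modeled as (zext i3 (not (trunc %x)) to i32).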
4238 if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0))) 4239 if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1))) 4240 if (BO->getOpcode() == Instruction::And && 4241 LCI->getValue() == CI->getValue()) 4242 if (const SCEVZeroExtendExpr *Z = 4243 dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) { 4244 Type *UTy = U->getType(); 4245 const SCEV *Z0 = Z->getOperand(); 4246 Type *Z0Ty = Z0->getType(); 4247 unsigned Z0TySize = getTypeSizeInBits(Z0Ty); 4248 4249 // If C is a low-bits mask, the zero extend is serving to 4250 // mask off the high bits. Complement the operand and 4251 // re-apply the zext. 4252 if (APIntOps::isMask(Z0TySize, CI->getValue())) 4253 return getZeroExtendExpr(getNotSCEV(Z0), UTy); 4254 4255 // If C is a single bit, it may be in the sign-bit position 4256 // before the zero-extend. In this case, represent the xor 4257 // using an add, which is equivalent, and re-apply the zext. 4258 APInt Trunc = CI->getValue().trunc(Z0TySize); 4259 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && 4260 Trunc.isSignBit()) 4261 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), 4262 UTy); 4263 } 4264 } 4265 break; 4266 4267 case Instruction::Shl: 4268 // Turn shift left of a constant amount into a multiply. 4269 if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) { 4270 uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth(); 4271 4272 // If the shift count is not less than the bitwidth, the result of 4273 // the shift is undefined. Don't try to analyze it, because the 4274 // resolution chosen here may differ from the resolution chosen in 4275 // other parts of the compiler. 4276 if (SA->getValue().uge(BitWidth)) 4277 break; 4278 4279 Constant *X = ConstantInt::get(getContext(), 4280 APInt::getOneBitSet(BitWidth, SA->getZExtValue())); 4281 return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X)); 4282 } 4283 break; 4284 4285 case Instruction::LShr: 4286 // Turn logical shift right of a constant into a unsigned divide. 4287 if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) { 4288 uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth(); 4289 4290 // If the shift count is not less than the bitwidth, the result of 4291 // the shift is undefined. Don't try to analyze it, because the 4292 // resolution chosen here may differ from the resolution chosen in 4293 // other parts of the compiler. 4294 if (SA->getValue().uge(BitWidth)) 4295 break; 4296 4297 Constant *X = ConstantInt::get(getContext(), 4298 APInt::getOneBitSet(BitWidth, SA->getZExtValue())); 4299 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X)); 4300 } 4301 break; 4302 4303 case Instruction::AShr: 4304 // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. 4305 if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) 4306 if (Operator *L = dyn_cast<Operator>(U->getOperand(0))) 4307 if (L->getOpcode() == Instruction::Shl && 4308 L->getOperand(1) == U->getOperand(1)) { 4309 uint64_t BitWidth = getTypeSizeInBits(U->getType()); 4310 4311 // If the shift count is not less than the bitwidth, the result of 4312 // the shift is undefined. Don't try to analyze it, because the 4313 // resolution chosen here may differ from the resolution chosen in 4314 // other parts of the compiler. 
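// For example, on i32 the pair 'shl %x, 24' followed by 'ashr ..., 24' is a
// sign extension of the low 8 bits and is modeled below as
// (sext i8 (trunc %x) to i32).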
4315 if (CI->getValue().uge(BitWidth)) 4316 break; 4317 4318 uint64_t Amt = BitWidth - CI->getZExtValue(); 4319 if (Amt == BitWidth) 4320 return getSCEV(L->getOperand(0)); // shift by zero --> noop 4321 return 4322 getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)), 4323 IntegerType::get(getContext(), 4324 Amt)), 4325 U->getType()); 4326 } 4327 break; 4328 4329 case Instruction::Trunc: 4330 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); 4331 4332 case Instruction::ZExt: 4333 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType()); 4334 4335 case Instruction::SExt: 4336 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType()); 4337 4338 case Instruction::BitCast: 4339 // BitCasts are no-op casts so we just eliminate the cast. 4340 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) 4341 return getSCEV(U->getOperand(0)); 4342 break; 4343 4344 // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can 4345 // lead to pointer expressions which cannot safely be expanded to GEPs, 4346 // because ScalarEvolution doesn't respect the GEP aliasing rules when 4347 // simplifying integer expressions. 4348 4349 case Instruction::GetElementPtr: 4350 return createNodeForGEP(cast<GEPOperator>(U)); 4351 4352 case Instruction::PHI: 4353 return createNodeForPHI(cast<PHINode>(U)); 4354 4355 case Instruction::Select: 4356 // This could be a smax or umax that was lowered earlier. 4357 // Try to recover it. 4358 if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) { 4359 Value *LHS = ICI->getOperand(0); 4360 Value *RHS = ICI->getOperand(1); 4361 switch (ICI->getPredicate()) { 4362 case ICmpInst::ICMP_SLT: 4363 case ICmpInst::ICMP_SLE: 4364 std::swap(LHS, RHS); 4365 // fall through 4366 case ICmpInst::ICMP_SGT: 4367 case ICmpInst::ICMP_SGE: 4368 // a >s b ? a+x : b+x -> smax(a, b)+x 4369 // a >s b ? b+x : a+x -> smin(a, b)+x 4370 if (getTypeSizeInBits(LHS->getType()) <= 4371 getTypeSizeInBits(U->getType())) { 4372 const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType()); 4373 const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType()); 4374 const SCEV *LA = getSCEV(U->getOperand(1)); 4375 const SCEV *RA = getSCEV(U->getOperand(2)); 4376 const SCEV *LDiff = getMinusSCEV(LA, LS); 4377 const SCEV *RDiff = getMinusSCEV(RA, RS); 4378 if (LDiff == RDiff) 4379 return getAddExpr(getSMaxExpr(LS, RS), LDiff); 4380 LDiff = getMinusSCEV(LA, RS); 4381 RDiff = getMinusSCEV(RA, LS); 4382 if (LDiff == RDiff) 4383 return getAddExpr(getSMinExpr(LS, RS), LDiff); 4384 } 4385 break; 4386 case ICmpInst::ICMP_ULT: 4387 case ICmpInst::ICMP_ULE: 4388 std::swap(LHS, RHS); 4389 // fall through 4390 case ICmpInst::ICMP_UGT: 4391 case ICmpInst::ICMP_UGE: 4392 // a >u b ? a+x : b+x -> umax(a, b)+x 4393 // a >u b ? b+x : a+x -> umin(a, b)+x 4394 if (getTypeSizeInBits(LHS->getType()) <= 4395 getTypeSizeInBits(U->getType())) { 4396 const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); 4397 const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType()); 4398 const SCEV *LA = getSCEV(U->getOperand(1)); 4399 const SCEV *RA = getSCEV(U->getOperand(2)); 4400 const SCEV *LDiff = getMinusSCEV(LA, LS); 4401 const SCEV *RDiff = getMinusSCEV(RA, RS); 4402 if (LDiff == RDiff) 4403 return getAddExpr(getUMaxExpr(LS, RS), LDiff); 4404 LDiff = getMinusSCEV(LA, RS); 4405 RDiff = getMinusSCEV(RA, LS); 4406 if (LDiff == RDiff) 4407 return getAddExpr(getUMinExpr(LS, RS), LDiff); 4408 } 4409 break; 4410 case ICmpInst::ICMP_NE: 4411 // n != 0 ? 
        //      n+x : 1+x  ->  umax(n, 1)+x
        if (getTypeSizeInBits(LHS->getType()) <=
            getTypeSizeInBits(U->getType()) &&
            isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
          const SCEV *One = getConstant(U->getType(), 1);
          const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
          const SCEV *LA = getSCEV(U->getOperand(1));
          const SCEV *RA = getSCEV(U->getOperand(2));
          const SCEV *LDiff = getMinusSCEV(LA, LS);
          const SCEV *RDiff = getMinusSCEV(RA, One);
          if (LDiff == RDiff)
            return getAddExpr(getUMaxExpr(One, LS), LDiff);
        }
        break;
      case ICmpInst::ICMP_EQ:
        // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
        if (getTypeSizeInBits(LHS->getType()) <=
            getTypeSizeInBits(U->getType()) &&
            isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
          const SCEV *One = getConstant(U->getType(), 1);
          const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
          const SCEV *LA = getSCEV(U->getOperand(1));
          const SCEV *RA = getSCEV(U->getOperand(2));
          const SCEV *LDiff = getMinusSCEV(LA, One);
          const SCEV *RDiff = getMinusSCEV(RA, LS);
          if (LDiff == RDiff)
            return getAddExpr(getUMaxExpr(One, LS), LDiff);
        }
        break;
      default:
        break;
      }
    }
    break;

  default: // We cannot analyze this expression.
    break;
  }

  return getUnknown(V);
}



//===----------------------------------------------------------------------===//
//                   Iteration Count Computation Code
//

unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) {
  if (BasicBlock *ExitingBB = L->getExitingBlock())
    return getSmallConstantTripCount(L, ExitingBB);

  // No trip count information for multiple exits.
  return 0;
}

/// getSmallConstantTripCount - Returns the maximum trip count of this loop
/// as a normal unsigned value. Returns 0 if the trip count is unknown or not
/// constant. Will also return 0 if the maximum trip count is very large (>=
/// 2^32).
///
/// This "trip count" assumes that control exits via ExitingBlock. More
/// precisely, it is the number of times that control may reach ExitingBlock
/// before taking the branch. For loops with multiple exits, it may not be the
/// number of times that the loop header executes because the loop may exit
/// prematurely via another branch.
unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L,
                                                    BasicBlock *ExitingBlock) {
  assert(ExitingBlock && "Must pass a non-null exiting block!");
  assert(L->isLoopExiting(ExitingBlock) &&
         "Exiting block must actually branch out of the loop!");
  const SCEVConstant *ExitCount =
      dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
  if (!ExitCount)
    return 0;

  ConstantInt *ExitConst = ExitCount->getValue();

  // Guard against huge trip counts.
  if (ExitConst->getValue().getActiveBits() > 32)
    return 0;

  // In case of integer overflow, this returns 0, which is correct.
  return ((unsigned)ExitConst->getZExtValue()) + 1;
}

unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) {
  if (BasicBlock *ExitingBB = L->getExitingBlock())
    return getSmallConstantTripMultiple(L, ExitingBB);

  // No trip multiple information for multiple exits.
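  // (With more than one exiting block the loop may leave through any of
  // them, and this interface makes no attempt to combine the per-exit
  // information, so it conservatively reports nothing.)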
  return 0;
}

/// getSmallConstantTripMultiple - Returns the largest constant divisor of the
/// trip count of this loop as a normal unsigned value, if possible. This
/// means that the actual trip count is always a multiple of the returned
/// value (don't forget the trip count could very well be zero as well!).
///
/// Returns 1 if the trip count is unknown or not guaranteed to be a
/// multiple of a constant (which is also the case if the trip count is
/// simply constant; use getSmallConstantTripCount for that case). It will
/// also return 1 if the trip count is very large (>= 2^32).
///
/// As explained in the comments for getSmallConstantTripCount, this assumes
/// that control exits the loop via ExitingBlock.
unsigned
ScalarEvolution::getSmallConstantTripMultiple(Loop *L,
                                              BasicBlock *ExitingBlock) {
  assert(ExitingBlock && "Must pass a non-null exiting block!");
  assert(L->isLoopExiting(ExitingBlock) &&
         "Exiting block must actually branch out of the loop!");
  const SCEV *ExitCount = getExitCount(L, ExitingBlock);
  if (ExitCount == getCouldNotCompute())
    return 1;

  // Get the trip count from the BE count by adding 1.
  const SCEV *TCMul = getAddExpr(ExitCount,
                                 getConstant(ExitCount->getType(), 1));
  // FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt
  // to factor simple cases.
  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(TCMul))
    TCMul = Mul->getOperand(0);

  const SCEVConstant *MulC = dyn_cast<SCEVConstant>(TCMul);
  if (!MulC)
    return 1;

  ConstantInt *Result = MulC->getValue();

  // Guard against huge trip counts (this requires checking
  // for zero to handle the case where the trip count == -1 and the
  // addition wraps).
  if (!Result || Result->getValue().getActiveBits() > 32 ||
      Result->getValue().getActiveBits() == 0)
    return 1;

  return (unsigned)Result->getZExtValue();
}

// getExitCount - Get the expression for the number of loop iterations for
// which this loop is guaranteed not to exit via ExitingBlock. Otherwise
// return SCEVCouldNotCompute.
const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) {
  return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
}

/// getBackedgeTakenCount - If the specified loop has a predictable
/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
/// object. The backedge-taken count is the number of times the loop header
/// will be branched to from within the loop. This is one less than the
/// trip count of the loop, since it doesn't count the first iteration,
/// when the header is branched to from outside the loop.
///
/// Note that it is not valid to call this method on a loop without a
/// loop-invariant backedge-taken count (see
/// hasLoopInvariantBackedgeTakenCount).
///
const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
  return getBackedgeTakenInfo(L).getExact(this);
}

/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
/// return the least SCEV value that is known never to be less than the
/// actual backedge taken count.
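/// (Illustrative example: for a loop "for (i8 i = 0; i != n; ++i)", the
/// exact backedge-taken count is the symbolic value n, while the max count
/// may conservatively be 255, the largest value an i8 can hold.)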
const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
  return getBackedgeTakenInfo(L).getMax(this);
}

/// PushLoopPHIs - Push PHI nodes in the header of the given loop
/// onto the given Worklist.
static void
PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
  BasicBlock *Header = L->getHeader();

  // Push all Loop-header PHIs onto the Worklist stack.
  for (BasicBlock::iterator I = Header->begin();
       PHINode *PN = dyn_cast<PHINode>(I); ++I)
    Worklist.push_back(PN);
}

const ScalarEvolution::BackedgeTakenInfo &
ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
  // Initially insert an invalid entry for this loop. If the insertion
  // succeeds, proceed to actually compute a backedge-taken count and
  // update the value. The temporary CouldNotCompute value tells SCEV
  // code elsewhere that it shouldn't attempt to request a new
  // backedge-taken count, which could result in infinite recursion.
  std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
      BackedgeTakenCounts.insert(std::make_pair(L, BackedgeTakenInfo()));
  if (!Pair.second)
    return Pair.first->second;

  // ComputeBackedgeTakenCount may allocate memory for its result. Inserting
  // it into the BackedgeTakenCounts map transfers ownership. Otherwise, the
  // result must be cleared in this scope.
  BackedgeTakenInfo Result = ComputeBackedgeTakenCount(L);

  if (Result.getExact(this) != getCouldNotCompute()) {
    assert(isLoopInvariant(Result.getExact(this), L) &&
           isLoopInvariant(Result.getMax(this), L) &&
           "Computed backedge-taken count isn't loop invariant for loop!");
    ++NumTripCountsComputed;
  }
  else if (Result.getMax(this) == getCouldNotCompute() &&
           isa<PHINode>(L->getHeader()->begin())) {
    // Only count loops that have phi nodes as not being computable.
    ++NumTripCountsNotComputed;
  }

  // Now that we know more about the trip count for this loop, forget any
  // existing SCEV values for PHI nodes in this loop since they are only
  // conservative estimates made without the benefit of trip count
  // information. This is similar to the code in forgetLoop, except that
  // it handles SCEVUnknown PHI nodes specially.
  if (Result.hasAnyInfo()) {
    SmallVector<Instruction *, 16> Worklist;
    PushLoopPHIs(L, Worklist);

    SmallPtrSet<Instruction *, 8> Visited;
    while (!Worklist.empty()) {
      Instruction *I = Worklist.pop_back_val();
      if (!Visited.insert(I).second)
        continue;

      ValueExprMapType::iterator It =
          ValueExprMap.find_as(static_cast<Value *>(I));
      if (It != ValueExprMap.end()) {
        const SCEV *Old = It->second;

        // SCEVUnknown for a PHI either means that it has an unrecognized
        // structure, or it's a PHI that's in the process of being computed
        // by createNodeForPHI. In the former case, additional loop trip
        // count information isn't going to change anything. In the latter
        // case, createNodeForPHI will perform the necessary updates on its
        // own when it gets to that point.
        if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
          forgetMemoizedResults(Old);
          ValueExprMap.erase(It);
        }
        if (PHINode *PN = dyn_cast<PHINode>(I))
          ConstantEvolutionLoopExitValue.erase(PN);
      }

      PushDefUseChildren(I, Worklist);
    }
  }

  // Re-lookup the insert position, since the call to
  // ComputeBackedgeTakenCount above could result in a
  // recursive call to getBackedgeTakenInfo (on a different
  // loop), which would invalidate the iterator computed
  // earlier.
  return BackedgeTakenCounts.find(L)->second = Result;
}

/// forgetLoop - This method should be called by the client when it has
/// changed a loop in a way that may affect ScalarEvolution's ability to
/// compute a trip count, or if the loop is deleted.
void ScalarEvolution::forgetLoop(const Loop *L) {
  // Drop any stored trip count value.
  DenseMap<const Loop*, BackedgeTakenInfo>::iterator BTCPos =
      BackedgeTakenCounts.find(L);
  if (BTCPos != BackedgeTakenCounts.end()) {
    BTCPos->second.clear();
    BackedgeTakenCounts.erase(BTCPos);
  }

  // Drop information about expressions based on loop-header PHIs.
  SmallVector<Instruction *, 16> Worklist;
  PushLoopPHIs(L, Worklist);

  SmallPtrSet<Instruction *, 8> Visited;
  while (!Worklist.empty()) {
    Instruction *I = Worklist.pop_back_val();
    if (!Visited.insert(I).second)
      continue;

    ValueExprMapType::iterator It =
        ValueExprMap.find_as(static_cast<Value *>(I));
    if (It != ValueExprMap.end()) {
      forgetMemoizedResults(It->second);
      ValueExprMap.erase(It);
      if (PHINode *PN = dyn_cast<PHINode>(I))
        ConstantEvolutionLoopExitValue.erase(PN);
    }

    PushDefUseChildren(I, Worklist);
  }

  // Forget all contained loops too, to avoid dangling entries in the
  // ValuesAtScopes map.
  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
    forgetLoop(*I);
}

/// forgetValue - This method should be called by the client when it has
/// changed a value in a way that may affect its value, or which may
/// disconnect it from a def-use chain linking it to a loop.
void ScalarEvolution::forgetValue(Value *V) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I) return;

  // Drop information about expressions based on loop-header PHIs.
  SmallVector<Instruction *, 16> Worklist;
  Worklist.push_back(I);

  SmallPtrSet<Instruction *, 8> Visited;
  while (!Worklist.empty()) {
    I = Worklist.pop_back_val();
    if (!Visited.insert(I).second)
      continue;

    ValueExprMapType::iterator It =
        ValueExprMap.find_as(static_cast<Value *>(I));
    if (It != ValueExprMap.end()) {
      forgetMemoizedResults(It->second);
      ValueExprMap.erase(It);
      if (PHINode *PN = dyn_cast<PHINode>(I))
        ConstantEvolutionLoopExitValue.erase(PN);
    }

    PushDefUseChildren(I, Worklist);
  }
}

/// getExact - Get the exact loop backedge taken count considering all loop
/// exits. A computable result can only be returned for loops with a single
/// exit. Returning the minimum taken count among all exits is incorrect
/// because one of the loop's exit limits may have been skipped. HowFarToZero
/// assumes that the limit of each loop test is never skipped. This is a
/// valid assumption as long as the loop exits via that test. For precise
/// results, it is the caller's responsibility to specify the relevant loop
/// exit using getExact(ExitingBlock, SE).
const SCEV *
ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const {
  // If any exits were not computable, the loop is not computable.
  if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute();

  // We need exactly one computable exit.
  if (!ExitNotTaken.ExitingBlock) return SE->getCouldNotCompute();
  assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info");

  const SCEV *BECount = nullptr;
  for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
       ENT != nullptr; ENT = ENT->getNextExit()) {

    assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV");

    if (!BECount)
      BECount = ENT->ExactNotTaken;
    else if (BECount != ENT->ExactNotTaken)
      return SE->getCouldNotCompute();
  }
  assert(BECount && "Invalid not taken count for loop exit");
  return BECount;
}

/// getExact - Get the exact not taken count for this loop exit.
const SCEV *
ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
                                             ScalarEvolution *SE) const {
  for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
       ENT != nullptr; ENT = ENT->getNextExit()) {

    if (ENT->ExitingBlock == ExitingBlock)
      return ENT->ExactNotTaken;
  }
  return SE->getCouldNotCompute();
}

/// getMax - Get the max backedge taken count for the loop.
const SCEV *
ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
  return Max ? Max : SE->getCouldNotCompute();
}

bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
                                                    ScalarEvolution *SE) const {
  if (Max && Max != SE->getCouldNotCompute() && SE->hasOperand(Max, S))
    return true;

  if (!ExitNotTaken.ExitingBlock)
    return false;

  for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
       ENT != nullptr; ENT = ENT->getNextExit()) {

    if (ENT->ExactNotTaken != SE->getCouldNotCompute()
        && SE->hasOperand(ENT->ExactNotTaken, S)) {
      return true;
    }
  }
  return false;
}

/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
/// computable exit into a persistent ExitNotTakenInfo array.
ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
    SmallVectorImpl< std::pair<BasicBlock *, const SCEV *> > &ExitCounts,
    bool Complete, const SCEV *MaxCount) : Max(MaxCount) {

  if (!Complete)
    ExitNotTaken.setIncomplete();

  unsigned NumExits = ExitCounts.size();
  if (NumExits == 0) return;

  ExitNotTaken.ExitingBlock = ExitCounts[0].first;
  ExitNotTaken.ExactNotTaken = ExitCounts[0].second;
  if (NumExits == 1) return;

  // Handle the rare case of multiple computable exits.
  ExitNotTakenInfo *ENT = new ExitNotTakenInfo[NumExits-1];

  ExitNotTakenInfo *PrevENT = &ExitNotTaken;
  for (unsigned i = 1; i < NumExits; ++i, PrevENT = ENT, ++ENT) {
    PrevENT->setNextExit(ENT);
    ENT->ExitingBlock = ExitCounts[i].first;
    ENT->ExactNotTaken = ExitCounts[i].second;
  }
}

/// clear - Invalidate this result and free the ExitNotTakenInfo array.
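/// (The head ExitNotTakenInfo is stored inline in BackedgeTakenInfo; only
/// the overflow entries reachable through getNextExit() live in the heap
/// array allocated by the constructor above, which is why the delete[] below
/// starts at getNextExit() rather than at the head node.)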
4834 void ScalarEvolution::BackedgeTakenInfo::clear() { 4835 ExitNotTaken.ExitingBlock = nullptr; 4836 ExitNotTaken.ExactNotTaken = nullptr; 4837 delete[] ExitNotTaken.getNextExit(); 4838 } 4839 4840 /// ComputeBackedgeTakenCount - Compute the number of times the backedge 4841 /// of the specified loop will execute. 4842 ScalarEvolution::BackedgeTakenInfo 4843 ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { 4844 SmallVector<BasicBlock *, 8> ExitingBlocks; 4845 L->getExitingBlocks(ExitingBlocks); 4846 4847 SmallVector<std::pair<BasicBlock *, const SCEV *>, 4> ExitCounts; 4848 bool CouldComputeBECount = true; 4849 BasicBlock *Latch = L->getLoopLatch(); // may be NULL. 4850 const SCEV *MustExitMaxBECount = nullptr; 4851 const SCEV *MayExitMaxBECount = nullptr; 4852 4853 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts 4854 // and compute maxBECount. 4855 for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { 4856 BasicBlock *ExitBB = ExitingBlocks[i]; 4857 ExitLimit EL = ComputeExitLimit(L, ExitBB); 4858 4859 // 1. For each exit that can be computed, add an entry to ExitCounts. 4860 // CouldComputeBECount is true only if all exits can be computed. 4861 if (EL.Exact == getCouldNotCompute()) 4862 // We couldn't compute an exact value for this exit, so 4863 // we won't be able to compute an exact value for the loop. 4864 CouldComputeBECount = false; 4865 else 4866 ExitCounts.push_back(std::make_pair(ExitBB, EL.Exact)); 4867 4868 // 2. Derive the loop's MaxBECount from each exit's max number of 4869 // non-exiting iterations. Partition the loop exits into two kinds: 4870 // LoopMustExits and LoopMayExits. 4871 // 4872 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it 4873 // is a LoopMayExit. If any computable LoopMustExit is found, then 4874 // MaxBECount is the minimum EL.Max of computable LoopMustExits. Otherwise, 4875 // MaxBECount is conservatively the maximum EL.Max, where CouldNotCompute is 4876 // considered greater than any computable EL.Max. 4877 if (EL.Max != getCouldNotCompute() && Latch && 4878 DT->dominates(ExitBB, Latch)) { 4879 if (!MustExitMaxBECount) 4880 MustExitMaxBECount = EL.Max; 4881 else { 4882 MustExitMaxBECount = 4883 getUMinFromMismatchedTypes(MustExitMaxBECount, EL.Max); 4884 } 4885 } else if (MayExitMaxBECount != getCouldNotCompute()) { 4886 if (!MayExitMaxBECount || EL.Max == getCouldNotCompute()) 4887 MayExitMaxBECount = EL.Max; 4888 else { 4889 MayExitMaxBECount = 4890 getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.Max); 4891 } 4892 } 4893 } 4894 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount : 4895 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); 4896 return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount); 4897 } 4898 4899 /// ComputeExitLimit - Compute the number of times the backedge of the specified 4900 /// loop will execute if it exits via the specified block. 4901 ScalarEvolution::ExitLimit 4902 ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { 4903 4904 // Okay, we've chosen an exiting block. See what condition causes us to 4905 // exit at this block and remember the exit block and whether all other targets 4906 // lead to the loop header. 4907 bool MustExecuteLoopHeader = true; 4908 BasicBlock *Exit = nullptr; 4909 for (succ_iterator SI = succ_begin(ExitingBlock), SE = succ_end(ExitingBlock); 4910 SI != SE; ++SI) 4911 if (!L->contains(*SI)) { 4912 if (Exit) // Multiple exit successors. 
4913 return getCouldNotCompute(); 4914 Exit = *SI; 4915 } else if (*SI != L->getHeader()) { 4916 MustExecuteLoopHeader = false; 4917 } 4918 4919 // At this point, we know we have a conditional branch that determines whether 4920 // the loop is exited. However, we don't know if the branch is executed each 4921 // time through the loop. If not, then the execution count of the branch will 4922 // not be equal to the trip count of the loop. 4923 // 4924 // Currently we check for this by checking to see if the Exit branch goes to 4925 // the loop header. If so, we know it will always execute the same number of 4926 // times as the loop. We also handle the case where the exit block *is* the 4927 // loop header. This is common for un-rotated loops. 4928 // 4929 // If both of those tests fail, walk up the unique predecessor chain to the 4930 // header, stopping if there is an edge that doesn't exit the loop. If the 4931 // header is reached, the execution count of the branch will be equal to the 4932 // trip count of the loop. 4933 // 4934 // More extensive analysis could be done to handle more cases here. 4935 // 4936 if (!MustExecuteLoopHeader && ExitingBlock != L->getHeader()) { 4937 // The simple checks failed, try climbing the unique predecessor chain 4938 // up to the header. 4939 bool Ok = false; 4940 for (BasicBlock *BB = ExitingBlock; BB; ) { 4941 BasicBlock *Pred = BB->getUniquePredecessor(); 4942 if (!Pred) 4943 return getCouldNotCompute(); 4944 TerminatorInst *PredTerm = Pred->getTerminator(); 4945 for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) { 4946 BasicBlock *PredSucc = PredTerm->getSuccessor(i); 4947 if (PredSucc == BB) 4948 continue; 4949 // If the predecessor has a successor that isn't BB and isn't 4950 // outside the loop, assume the worst. 4951 if (L->contains(PredSucc)) 4952 return getCouldNotCompute(); 4953 } 4954 if (Pred == L->getHeader()) { 4955 Ok = true; 4956 break; 4957 } 4958 BB = Pred; 4959 } 4960 if (!Ok) 4961 return getCouldNotCompute(); 4962 } 4963 4964 bool IsOnlyExit = (L->getExitingBlock() != nullptr); 4965 TerminatorInst *Term = ExitingBlock->getTerminator(); 4966 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) { 4967 assert(BI->isConditional() && "If unconditional, it can't be in loop!"); 4968 // Proceed to the next level to examine the exit condition expression. 4969 return ComputeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0), 4970 BI->getSuccessor(1), 4971 /*ControlsExit=*/IsOnlyExit); 4972 } 4973 4974 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) 4975 return ComputeExitLimitFromSingleExitSwitch(L, SI, Exit, 4976 /*ControlsExit=*/IsOnlyExit); 4977 4978 return getCouldNotCompute(); 4979 } 4980 4981 /// ComputeExitLimitFromCond - Compute the number of times the 4982 /// backedge of the specified loop will execute if its exit condition 4983 /// were a conditional branch of ExitCond, TBB, and FBB. 4984 /// 4985 /// @param ControlsExit is true if ExitCond directly controls the exit 4986 /// branch. In this case, we can assume that the loop exits only if the 4987 /// condition is true and can infer that failing to meet the condition prior to 4988 /// integer wraparound results in undefined behavior. 4989 ScalarEvolution::ExitLimit 4990 ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, 4991 Value *ExitCond, 4992 BasicBlock *TBB, 4993 BasicBlock *FBB, 4994 bool ControlsExit) { 4995 // Check if the controlling expression for this loop is an And or Or. 
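  // (Illustrative example: given an exit branch
  //    br i1 %cond, label %body, label %exit
  //  where %cond = and i1 %c1, %c2 and %body stays inside the loop, either
  //  operand going false leaves the loop, so the exact count is taken below
  //  as the umin of the two operands' exit limits.)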
4996 if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) { 4997 if (BO->getOpcode() == Instruction::And) { 4998 // Recurse on the operands of the and. 4999 bool EitherMayExit = L->contains(TBB); 5000 ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, 5001 ControlsExit && !EitherMayExit); 5002 ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, 5003 ControlsExit && !EitherMayExit); 5004 const SCEV *BECount = getCouldNotCompute(); 5005 const SCEV *MaxBECount = getCouldNotCompute(); 5006 if (EitherMayExit) { 5007 // Both conditions must be true for the loop to continue executing. 5008 // Choose the less conservative count. 5009 if (EL0.Exact == getCouldNotCompute() || 5010 EL1.Exact == getCouldNotCompute()) 5011 BECount = getCouldNotCompute(); 5012 else 5013 BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact); 5014 if (EL0.Max == getCouldNotCompute()) 5015 MaxBECount = EL1.Max; 5016 else if (EL1.Max == getCouldNotCompute()) 5017 MaxBECount = EL0.Max; 5018 else 5019 MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); 5020 } else { 5021 // Both conditions must be true at the same time for the loop to exit. 5022 // For now, be conservative. 5023 assert(L->contains(FBB) && "Loop block has no successor in loop!"); 5024 if (EL0.Max == EL1.Max) 5025 MaxBECount = EL0.Max; 5026 if (EL0.Exact == EL1.Exact) 5027 BECount = EL0.Exact; 5028 } 5029 5030 return ExitLimit(BECount, MaxBECount); 5031 } 5032 if (BO->getOpcode() == Instruction::Or) { 5033 // Recurse on the operands of the or. 5034 bool EitherMayExit = L->contains(FBB); 5035 ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, 5036 ControlsExit && !EitherMayExit); 5037 ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, 5038 ControlsExit && !EitherMayExit); 5039 const SCEV *BECount = getCouldNotCompute(); 5040 const SCEV *MaxBECount = getCouldNotCompute(); 5041 if (EitherMayExit) { 5042 // Both conditions must be false for the loop to continue executing. 5043 // Choose the less conservative count. 5044 if (EL0.Exact == getCouldNotCompute() || 5045 EL1.Exact == getCouldNotCompute()) 5046 BECount = getCouldNotCompute(); 5047 else 5048 BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact); 5049 if (EL0.Max == getCouldNotCompute()) 5050 MaxBECount = EL1.Max; 5051 else if (EL1.Max == getCouldNotCompute()) 5052 MaxBECount = EL0.Max; 5053 else 5054 MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); 5055 } else { 5056 // Both conditions must be false at the same time for the loop to exit. 5057 // For now, be conservative. 5058 assert(L->contains(TBB) && "Loop block has no successor in loop!"); 5059 if (EL0.Max == EL1.Max) 5060 MaxBECount = EL0.Max; 5061 if (EL0.Exact == EL1.Exact) 5062 BECount = EL0.Exact; 5063 } 5064 5065 return ExitLimit(BECount, MaxBECount); 5066 } 5067 } 5068 5069 // With an icmp, it may be feasible to compute an exact backedge-taken count. 5070 // Proceed to the next level to examine the icmp. 5071 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) 5072 return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); 5073 5074 // Check for a constant condition. These are normally stripped out by 5075 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to 5076 // preserve the CFG and is temporarily leaving constant conditions 5077 // in place. 
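  // (Illustrative example: an unsimplified exit branch such as
  //    br i1 true, label %body, label %exit
  //  with %body inside the loop never exits through this branch, so we
  //  return CouldNotCompute below, while "br i1 false, ..." with the same
  //  successors exits before any backedge is taken, giving a count of zero.)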
5078 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) { 5079 if (L->contains(FBB) == !CI->getZExtValue()) 5080 // The backedge is always taken. 5081 return getCouldNotCompute(); 5082 else 5083 // The backedge is never taken. 5084 return getConstant(CI->getType(), 0); 5085 } 5086 5087 // If it's not an integer or pointer comparison then compute it the hard way. 5088 return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); 5089 } 5090 5091 /// ComputeExitLimitFromICmp - Compute the number of times the 5092 /// backedge of the specified loop will execute if its exit condition 5093 /// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB. 5094 ScalarEvolution::ExitLimit 5095 ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, 5096 ICmpInst *ExitCond, 5097 BasicBlock *TBB, 5098 BasicBlock *FBB, 5099 bool ControlsExit) { 5100 5101 // If the condition was exit on true, convert the condition to exit on false 5102 ICmpInst::Predicate Cond; 5103 if (!L->contains(FBB)) 5104 Cond = ExitCond->getPredicate(); 5105 else 5106 Cond = ExitCond->getInversePredicate(); 5107 5108 // Handle common loops like: for (X = "string"; *X; ++X) 5109 if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0))) 5110 if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) { 5111 ExitLimit ItCnt = 5112 ComputeLoadConstantCompareExitLimit(LI, RHS, L, Cond); 5113 if (ItCnt.hasAnyInfo()) 5114 return ItCnt; 5115 } 5116 5117 const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); 5118 const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); 5119 5120 // Try to evaluate any dependencies out of the loop. 5121 LHS = getSCEVAtScope(LHS, L); 5122 RHS = getSCEVAtScope(RHS, L); 5123 5124 // At this point, we would like to compute how many iterations of the 5125 // loop the predicate will return true for these inputs. 5126 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) { 5127 // If there is a loop-invariant, force it into the RHS. 5128 std::swap(LHS, RHS); 5129 Cond = ICmpInst::getSwappedPredicate(Cond); 5130 } 5131 5132 // Simplify the operands before analyzing them. 5133 (void)SimplifyICmpOperands(Cond, LHS, RHS); 5134 5135 // If we have a comparison of a chrec against a constant, try to use value 5136 // ranges to answer this query. 5137 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) 5138 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS)) 5139 if (AddRec->getLoop() == L) { 5140 // Form the constant range. 
5141 ConstantRange CompRange( 5142 ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue())); 5143 5144 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); 5145 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret; 5146 } 5147 5148 switch (Cond) { 5149 case ICmpInst::ICMP_NE: { // while (X != Y) 5150 // Convert to: while (X-Y != 0) 5151 ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); 5152 if (EL.hasAnyInfo()) return EL; 5153 break; 5154 } 5155 case ICmpInst::ICMP_EQ: { // while (X == Y) 5156 // Convert to: while (X-Y == 0) 5157 ExitLimit EL = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); 5158 if (EL.hasAnyInfo()) return EL; 5159 break; 5160 } 5161 case ICmpInst::ICMP_SLT: 5162 case ICmpInst::ICMP_ULT: { // while (X < Y) 5163 bool IsSigned = Cond == ICmpInst::ICMP_SLT; 5164 ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit); 5165 if (EL.hasAnyInfo()) return EL; 5166 break; 5167 } 5168 case ICmpInst::ICMP_SGT: 5169 case ICmpInst::ICMP_UGT: { // while (X > Y) 5170 bool IsSigned = Cond == ICmpInst::ICMP_SGT; 5171 ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit); 5172 if (EL.hasAnyInfo()) return EL; 5173 break; 5174 } 5175 default: 5176 #if 0 5177 dbgs() << "ComputeBackedgeTakenCount "; 5178 if (ExitCond->getOperand(0)->getType()->isUnsigned()) 5179 dbgs() << "[unsigned] "; 5180 dbgs() << *LHS << " " 5181 << Instruction::getOpcodeName(Instruction::ICmp) 5182 << " " << *RHS << "\n"; 5183 #endif 5184 break; 5185 } 5186 return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); 5187 } 5188 5189 ScalarEvolution::ExitLimit 5190 ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L, 5191 SwitchInst *Switch, 5192 BasicBlock *ExitingBlock, 5193 bool ControlsExit) { 5194 assert(!L->contains(ExitingBlock) && "Not an exiting block!"); 5195 5196 // Give up if the exit is the default dest of a switch. 5197 if (Switch->getDefaultDest() == ExitingBlock) 5198 return getCouldNotCompute(); 5199 5200 assert(L->contains(Switch->getDefaultDest()) && 5201 "Default case must not exit the loop!"); 5202 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L); 5203 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); 5204 5205 // while (X != Y) --> while (X-Y != 0) 5206 ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); 5207 if (EL.hasAnyInfo()) 5208 return EL; 5209 5210 return getCouldNotCompute(); 5211 } 5212 5213 static ConstantInt * 5214 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, 5215 ScalarEvolution &SE) { 5216 const SCEV *InVal = SE.getConstant(C); 5217 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE); 5218 assert(isa<SCEVConstant>(Val) && 5219 "Evaluation of SCEV at constant didn't fold correctly?"); 5220 return cast<SCEVConstant>(Val)->getValue(); 5221 } 5222 5223 /// ComputeLoadConstantCompareExitLimit - Given an exit condition of 5224 /// 'icmp op load X, cst', try to see if we can compute the backedge 5225 /// execution count. 5226 ScalarEvolution::ExitLimit 5227 ScalarEvolution::ComputeLoadConstantCompareExitLimit( 5228 LoadInst *LI, 5229 Constant *RHS, 5230 const Loop *L, 5231 ICmpInst::Predicate predicate) { 5232 5233 if (LI->isVolatile()) return getCouldNotCompute(); 5234 5235 // Check to see if the loaded pointer is a getelementptr of a global. 5236 // TODO: Use SCEV instead of manually grubbing with GEPs. 
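  // (Illustrative shape being matched, using a hypothetical global @s:
  //    @s = internal constant [6 x i8] c"hello\00"
  //    %p = getelementptr [6 x i8]* @s, i64 0, i64 %i  ; %i an affine AddRec
  //    %c = load i8* %p
  //    %cmp = icmp eq i8 %c, 0
  //  Each candidate iteration is folded through the initializer below.)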
5237 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0)); 5238 if (!GEP) return getCouldNotCompute(); 5239 5240 // Make sure that it is really a constant global we are gepping, with an 5241 // initializer, and make sure the first IDX is really 0. 5242 GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)); 5243 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() || 5244 GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) || 5245 !cast<Constant>(GEP->getOperand(1))->isNullValue()) 5246 return getCouldNotCompute(); 5247 5248 // Okay, we allow one non-constant index into the GEP instruction. 5249 Value *VarIdx = nullptr; 5250 std::vector<Constant*> Indexes; 5251 unsigned VarIdxNum = 0; 5252 for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i) 5253 if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) { 5254 Indexes.push_back(CI); 5255 } else if (!isa<ConstantInt>(GEP->getOperand(i))) { 5256 if (VarIdx) return getCouldNotCompute(); // Multiple non-constant idx's. 5257 VarIdx = GEP->getOperand(i); 5258 VarIdxNum = i-2; 5259 Indexes.push_back(nullptr); 5260 } 5261 5262 // Loop-invariant loads may be a byproduct of loop optimization. Skip them. 5263 if (!VarIdx) 5264 return getCouldNotCompute(); 5265 5266 // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant. 5267 // Check to see if X is a loop variant variable value now. 5268 const SCEV *Idx = getSCEV(VarIdx); 5269 Idx = getSCEVAtScope(Idx, L); 5270 5271 // We can only recognize very limited forms of loop index expressions, in 5272 // particular, only affine AddRec's like {C1,+,C2}. 5273 const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx); 5274 if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) || 5275 !isa<SCEVConstant>(IdxExpr->getOperand(0)) || 5276 !isa<SCEVConstant>(IdxExpr->getOperand(1))) 5277 return getCouldNotCompute(); 5278 5279 unsigned MaxSteps = MaxBruteForceIterations; 5280 for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) { 5281 ConstantInt *ItCst = ConstantInt::get( 5282 cast<IntegerType>(IdxExpr->getType()), IterationNum); 5283 ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this); 5284 5285 // Form the GEP offset. 5286 Indexes[VarIdxNum] = Val; 5287 5288 Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(), 5289 Indexes); 5290 if (!Result) break; // Cannot compute! 5291 5292 // Evaluate the condition for this iteration. 5293 Result = ConstantExpr::getICmp(predicate, Result, RHS); 5294 if (!isa<ConstantInt>(Result)) break; // Couldn't decide for sure 5295 if (cast<ConstantInt>(Result)->getValue().isMinValue()) { 5296 #if 0 5297 dbgs() << "\n***\n*** Computed loop count " << *ItCst 5298 << "\n*** From global " << *GV << "*** BB: " << *L->getHeader() 5299 << "***\n"; 5300 #endif 5301 ++NumArrayLenItCounts; 5302 return getConstant(ItCst); // Found terminating iteration! 5303 } 5304 } 5305 return getCouldNotCompute(); 5306 } 5307 5308 5309 /// CanConstantFold - Return true if we can constant fold an instruction of the 5310 /// specified type, assuming that all operands were constants. 
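/// (For example, an add or icmp with constant operands always folds, and a
/// call folds only when canConstantFoldCallTo approves its callee; everything
/// else is rejected below.)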
static bool CanConstantFold(const Instruction *I) {
  if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
      isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
      isa<LoadInst>(I))
    return true;

  if (const CallInst *CI = dyn_cast<CallInst>(I))
    if (const Function *F = CI->getCalledFunction())
      return canConstantFoldCallTo(F);
  return false;
}

/// Determine whether this instruction can constant evolve within this loop
/// assuming its operands can all constant evolve.
static bool canConstantEvolve(Instruction *I, const Loop *L) {
  // An instruction outside of the loop can't be derived from a loop PHI.
  if (!L->contains(I)) return false;

  if (isa<PHINode>(I)) {
    if (L->getHeader() == I->getParent())
      return true;
    else
      // We don't currently keep track of the control flow needed to evaluate
      // PHIs, so we cannot handle PHIs inside of loops.
      return false;
  }

  // If we won't be able to constant fold this expression even if the operands
  // are constants, bail early.
  return CanConstantFold(I);
}

/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
/// recursing through each instruction operand until reaching a loop header
/// phi.
static PHINode *
getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
                               DenseMap<Instruction *, PHINode *> &PHIMap) {

  // We can evaluate this instruction if all of its operands are constant or
  // derived from a PHI node themselves.
  PHINode *PHI = nullptr;
  for (Instruction::op_iterator OpI = UseInst->op_begin(),
       OpE = UseInst->op_end(); OpI != OpE; ++OpI) {

    if (isa<Constant>(*OpI)) continue;

    Instruction *OpInst = dyn_cast<Instruction>(*OpI);
    if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;

    PHINode *P = dyn_cast<PHINode>(OpInst);
    if (!P)
      // If this operand is already visited, reuse the prior result.
      // We may have P != PHI if this is the deepest point at which the
      // inconsistent paths meet.
      P = PHIMap.lookup(OpInst);
    if (!P) {
      // Recurse and memoize the results, whether a phi is found or not.
      // This recursive call invalidates pointers into PHIMap.
      P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap);
      PHIMap[OpInst] = P;
    }
    if (!P)
      return nullptr; // Not evolving from PHI
    if (PHI && PHI != P)
      return nullptr; // Evolving from multiple different PHIs.
    PHI = P;
  }
  // This is an expression evolving from a constant PHI!
  return PHI;
}

/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
/// in the loop that V is derived from. We allow arbitrary operations along
/// the way, but the operands of an operation must either be constants or a
/// value derived from a constant PHI. If this expression does not fit with
/// these constraints, return null.
static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I || !canConstantEvolve(I, L)) return nullptr;

  if (PHINode *PN = dyn_cast<PHINode>(I)) {
    return PN;
  }

  // Record non-constant instructions contained by the loop.
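  // (Illustrative walk: for "%n = add i32 %iv, 1" with %iv a header PHI, the
  // recursion below records PHIMap[%n] = %iv and returns %iv; had the
  // operands reached two different header PHIs, it would return null
  // instead.)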
5396 DenseMap<Instruction *, PHINode *> PHIMap; 5397 return getConstantEvolvingPHIOperands(I, L, PHIMap); 5398 } 5399 5400 /// EvaluateExpression - Given an expression that passes the 5401 /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node 5402 /// in the loop has the value PHIVal. If we can't fold this expression for some 5403 /// reason, return null. 5404 static Constant *EvaluateExpression(Value *V, const Loop *L, 5405 DenseMap<Instruction *, Constant *> &Vals, 5406 const DataLayout *DL, 5407 const TargetLibraryInfo *TLI) { 5408 // Convenient constant check, but redundant for recursive calls. 5409 if (Constant *C = dyn_cast<Constant>(V)) return C; 5410 Instruction *I = dyn_cast<Instruction>(V); 5411 if (!I) return nullptr; 5412 5413 if (Constant *C = Vals.lookup(I)) return C; 5414 5415 // An instruction inside the loop depends on a value outside the loop that we 5416 // weren't given a mapping for, or a value such as a call inside the loop. 5417 if (!canConstantEvolve(I, L)) return nullptr; 5418 5419 // An unmapped PHI can be due to a branch or another loop inside this loop, 5420 // or due to this not being the initial iteration through a loop where we 5421 // couldn't compute the evolution of this particular PHI last time. 5422 if (isa<PHINode>(I)) return nullptr; 5423 5424 std::vector<Constant*> Operands(I->getNumOperands()); 5425 5426 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { 5427 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i)); 5428 if (!Operand) { 5429 Operands[i] = dyn_cast<Constant>(I->getOperand(i)); 5430 if (!Operands[i]) return nullptr; 5431 continue; 5432 } 5433 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI); 5434 Vals[Operand] = C; 5435 if (!C) return nullptr; 5436 Operands[i] = C; 5437 } 5438 5439 if (CmpInst *CI = dyn_cast<CmpInst>(I)) 5440 return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], 5441 Operands[1], DL, TLI); 5442 if (LoadInst *LI = dyn_cast<LoadInst>(I)) { 5443 if (!LI->isVolatile()) 5444 return ConstantFoldLoadFromConstPtr(Operands[0], DL); 5445 } 5446 return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, DL, 5447 TLI); 5448 } 5449 5450 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is 5451 /// in the header of its containing loop, we know the loop executes a 5452 /// constant number of times, and the PHI node is just a recurrence 5453 /// involving constants, fold it. 5454 Constant * 5455 ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, 5456 const APInt &BEs, 5457 const Loop *L) { 5458 DenseMap<PHINode*, Constant*>::const_iterator I = 5459 ConstantEvolutionLoopExitValue.find(PN); 5460 if (I != ConstantEvolutionLoopExitValue.end()) 5461 return I->second; 5462 5463 if (BEs.ugt(MaxBruteForceIterations)) 5464 return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it. 5465 5466 Constant *&RetVal = ConstantEvolutionLoopExitValue[PN]; 5467 5468 DenseMap<Instruction *, Constant *> CurrentIterVals; 5469 BasicBlock *Header = L->getHeader(); 5470 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!"); 5471 5472 // Since the loop is canonicalized, the PHI node must have two entries. One 5473 // entry must be a constant (coming in from outside of the loop), and the 5474 // second must be derived from the same PHI. 
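  // (Illustrative: a canonical induction PHI has the form
  //    %iv = phi i32 [ 0, %preheader ], [ %iv.next, %latch ]
  //  where the first entry is the constant start value and the second
  //  arrives via the backedge.)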
  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
  PHINode *PHI = nullptr;
  for (BasicBlock::iterator I = Header->begin();
       (PHI = dyn_cast<PHINode>(I)); ++I) {
    Constant *StartCST =
        dyn_cast<Constant>(PHI->getIncomingValue(!SecondIsBackedge));
    if (!StartCST) continue;
    CurrentIterVals[PHI] = StartCST;
  }
  if (!CurrentIterVals.count(PN))
    return RetVal = nullptr;

  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);

  // Execute the loop symbolically to determine the exit value.
  if (BEs.getActiveBits() >= 32)
    return RetVal = nullptr; // More than 2^32-1 iterations?? Not doing it!

  unsigned NumIterations = BEs.getZExtValue(); // must be in range
  unsigned IterationNum = 0;
  for (; ; ++IterationNum) {
    if (IterationNum == NumIterations)
      return RetVal = CurrentIterVals[PN]; // Got exit value!

    // Compute the value of the PHIs for the next iteration.
    // EvaluateExpression adds non-phi values to the CurrentIterVals map.
    DenseMap<Instruction *, Constant *> NextIterVals;
    Constant *NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL,
                                           TLI);
    if (!NextPHI)
      return nullptr; // Couldn't evaluate!
    NextIterVals[PN] = NextPHI;

    bool StoppedEvolving = NextPHI == CurrentIterVals[PN];

    // Also evaluate the other PHI nodes. However, we don't get to stop if we
    // cease to be able to evaluate one of them or if they stop evolving,
    // because that doesn't necessarily prevent us from computing PN.
    SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
    for (DenseMap<Instruction *, Constant *>::const_iterator
         I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I) {
      PHINode *PHI = dyn_cast<PHINode>(I->first);
      if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
      PHIsToCompute.push_back(std::make_pair(PHI, I->second));
    }
    // We use two distinct loops because EvaluateExpression may invalidate any
    // iterators into CurrentIterVals.
    for (SmallVectorImpl<std::pair<PHINode *, Constant*> >::const_iterator
         I = PHIsToCompute.begin(), E = PHIsToCompute.end(); I != E; ++I) {
      PHINode *PHI = I->first;
      Constant *&NextPHI = NextIterVals[PHI];
      if (!NextPHI) { // Not already computed.
        Value *BEValue = PHI->getIncomingValue(SecondIsBackedge);
        NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, TLI);
      }
      if (NextPHI != I->second)
        StoppedEvolving = false;
    }

    // If all entries in CurrentIterVals == NextIterVals then we can stop
    // iterating, the loop can't continue to change.
    if (StoppedEvolving)
      return RetVal = CurrentIterVals[PN];

    CurrentIterVals.swap(NextIterVals);
  }
}

/// ComputeExitCountExhaustively - If the loop is known to execute a
/// constant number of times (the condition evolves only from constants),
/// try to evaluate a few iterations of the loop until the exit condition
/// gets a value of ExitWhen (true or false). If we cannot evaluate the
/// trip count of the loop, return getCouldNotCompute().
const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L,
                                                          Value *Cond,
                                                          bool ExitWhen) {
  PHINode *PN = getConstantEvolvingPHI(Cond, L);
  if (!PN) return getCouldNotCompute();

  // If the loop is canonicalized, the PHI will have exactly two entries.
  // That's the only form we support here.
  if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();

  DenseMap<Instruction *, Constant *> CurrentIterVals;
  BasicBlock *Header = L->getHeader();
  assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");

  // One entry must be a constant (coming in from outside of the loop), and
  // the second must be derived from the same PHI.
  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
  PHINode *PHI = nullptr;
  for (BasicBlock::iterator I = Header->begin();
       (PHI = dyn_cast<PHINode>(I)); ++I) {
    Constant *StartCST =
        dyn_cast<Constant>(PHI->getIncomingValue(!SecondIsBackedge));
    if (!StartCST) continue;
    CurrentIterVals[PHI] = StartCST;
  }
  if (!CurrentIterVals.count(PN))
    return getCouldNotCompute();

  // Okay, we found a PHI node that defines the trip count of this loop.
  // Execute the loop symbolically to determine when the condition gets a
  // value of "ExitWhen".

  unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
  for (unsigned IterationNum = 0; IterationNum != MaxIterations;
       ++IterationNum) {
    ConstantInt *CondVal =
        dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, L,
                                                         CurrentIterVals,
                                                         DL, TLI));

    // Couldn't symbolically evaluate.
    if (!CondVal) return getCouldNotCompute();

    if (CondVal->getValue() == uint64_t(ExitWhen)) {
      ++NumBruteForceTripCountsComputed;
      return getConstant(Type::getInt32Ty(getContext()), IterationNum);
    }

    // Update all the PHI nodes for the next iteration.
    DenseMap<Instruction *, Constant *> NextIterVals;

    // Create a list of which PHIs we need to compute. We want to do this
    // before calling EvaluateExpression on them because that may invalidate
    // iterators into CurrentIterVals.
    SmallVector<PHINode *, 8> PHIsToCompute;
    for (DenseMap<Instruction *, Constant *>::const_iterator
         I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I) {
      PHINode *PHI = dyn_cast<PHINode>(I->first);
      if (!PHI || PHI->getParent() != Header) continue;
      PHIsToCompute.push_back(PHI);
    }
    for (SmallVectorImpl<PHINode *>::const_iterator I = PHIsToCompute.begin(),
         E = PHIsToCompute.end(); I != E; ++I) {
      PHINode *PHI = *I;
      Constant *&NextPHI = NextIterVals[PHI];
      if (NextPHI) continue; // Already computed!

      Value *BEValue = PHI->getIncomingValue(SecondIsBackedge);
      NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, TLI);
    }
    CurrentIterVals.swap(NextIterVals);
  }

  // Too many iterations were needed to evaluate.
  return getCouldNotCompute();
}

/// getSCEVAtScope - Return a SCEV expression for the specified value
/// at the specified scope in the program. The L value specifies the loop
/// nest to evaluate the expression in, where null means the top level and a
/// non-null loop means the scope immediately inside that loop.
///
/// This method can be used to compute the exit value for a variable defined
/// in a loop by querying what the value will hold in the parent loop.
///
/// In the case that a relevant loop exit value cannot be computed, the
/// original value V is returned.
const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
  // Check to see if we've folded this expression at this loop before.
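  // (The cache maps V to a small list of (Loop, SCEV) pairs; a null SCEV
  // marks a computation that is still in flight, in which case V itself is
  // returned below to break the recursion.)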
5635 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = ValuesAtScopes[V]; 5636 for (unsigned u = 0; u < Values.size(); u++) { 5637 if (Values[u].first == L) 5638 return Values[u].second ? Values[u].second : V; 5639 } 5640 Values.push_back(std::make_pair(L, static_cast<const SCEV *>(nullptr))); 5641 // Otherwise compute it. 5642 const SCEV *C = computeSCEVAtScope(V, L); 5643 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values2 = ValuesAtScopes[V]; 5644 for (unsigned u = Values2.size(); u > 0; u--) { 5645 if (Values2[u - 1].first == L) { 5646 Values2[u - 1].second = C; 5647 break; 5648 } 5649 } 5650 return C; 5651 } 5652 5653 /// This builds up a Constant using the ConstantExpr interface. That way, we 5654 /// will return Constants for objects which aren't represented by a 5655 /// SCEVConstant, because SCEVConstant is restricted to ConstantInt. 5656 /// Returns NULL if the SCEV isn't representable as a Constant. 5657 static Constant *BuildConstantFromSCEV(const SCEV *V) { 5658 switch (static_cast<SCEVTypes>(V->getSCEVType())) { 5659 case scCouldNotCompute: 5660 case scAddRecExpr: 5661 break; 5662 case scConstant: 5663 return cast<SCEVConstant>(V)->getValue(); 5664 case scUnknown: 5665 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue()); 5666 case scSignExtend: { 5667 const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V); 5668 if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand())) 5669 return ConstantExpr::getSExt(CastOp, SS->getType()); 5670 break; 5671 } 5672 case scZeroExtend: { 5673 const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V); 5674 if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand())) 5675 return ConstantExpr::getZExt(CastOp, SZ->getType()); 5676 break; 5677 } 5678 case scTruncate: { 5679 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V); 5680 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand())) 5681 return ConstantExpr::getTrunc(CastOp, ST->getType()); 5682 break; 5683 } 5684 case scAddExpr: { 5685 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V); 5686 if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) { 5687 if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) { 5688 unsigned AS = PTy->getAddressSpace(); 5689 Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); 5690 C = ConstantExpr::getBitCast(C, DestPtrTy); 5691 } 5692 for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) { 5693 Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i)); 5694 if (!C2) return nullptr; 5695 5696 // First pointer! 5697 if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) { 5698 unsigned AS = C2->getType()->getPointerAddressSpace(); 5699 std::swap(C, C2); 5700 Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); 5701 // The offsets have been converted to bytes. We can add bytes to an 5702 // i8* by GEP with the byte count in the first index. 5703 C = ConstantExpr::getBitCast(C, DestPtrTy); 5704 } 5705 5706 // Don't bother trying to sum two pointers. We probably can't 5707 // statically compute a load that results from it anyway. 
5708 if (C2->getType()->isPointerTy()) 5709 return nullptr; 5710 5711 if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) { 5712 if (PTy->getElementType()->isStructTy()) 5713 C2 = ConstantExpr::getIntegerCast( 5714 C2, Type::getInt32Ty(C->getContext()), true); 5715 C = ConstantExpr::getGetElementPtr(C, C2); 5716 } else 5717 C = ConstantExpr::getAdd(C, C2); 5718 } 5719 return C; 5720 } 5721 break; 5722 } 5723 case scMulExpr: { 5724 const SCEVMulExpr *SM = cast<SCEVMulExpr>(V); 5725 if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) { 5726 // Don't bother with pointers at all. 5727 if (C->getType()->isPointerTy()) return nullptr; 5728 for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) { 5729 Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i)); 5730 if (!C2 || C2->getType()->isPointerTy()) return nullptr; 5731 C = ConstantExpr::getMul(C, C2); 5732 } 5733 return C; 5734 } 5735 break; 5736 } 5737 case scUDivExpr: { 5738 const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V); 5739 if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS())) 5740 if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS())) 5741 if (LHS->getType() == RHS->getType()) 5742 return ConstantExpr::getUDiv(LHS, RHS); 5743 break; 5744 } 5745 case scSMaxExpr: 5746 case scUMaxExpr: 5747 break; // TODO: smax, umax. 5748 } 5749 return nullptr; 5750 } 5751 5752 const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { 5753 if (isa<SCEVConstant>(V)) return V; 5754 5755 // If this instruction is evolved from a constant-evolving PHI, compute the 5756 // exit value from the loop without using SCEVs. 5757 if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) { 5758 if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) { 5759 const Loop *LI = (*this->LI)[I->getParent()]; 5760 if (LI && LI->getParentLoop() == L) // Looking for loop exit value. 5761 if (PHINode *PN = dyn_cast<PHINode>(I)) 5762 if (PN->getParent() == LI->getHeader()) { 5763 // Okay, there is no closed form solution for the PHI node. Check 5764 // to see if the loop that contains it has a known backedge-taken 5765 // count. If so, we may be able to force computation of the exit 5766 // value. 5767 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI); 5768 if (const SCEVConstant *BTCC = 5769 dyn_cast<SCEVConstant>(BackedgeTakenCount)) { 5770 // Okay, we know how many times the containing loop executes. If 5771 // this is a constant evolving PHI node, get the final value at 5772 // the specified iteration number. 5773 Constant *RV = getConstantEvolutionLoopExitValue(PN, 5774 BTCC->getValue()->getValue(), 5775 LI); 5776 if (RV) return getSCEV(RV); 5777 } 5778 } 5779 5780 // Okay, this is an expression that we cannot symbolically evaluate 5781 // into a SCEV. Check to see if it's possible to symbolically evaluate 5782 // the arguments into constants, and if so, try to constant propagate the 5783 // result. This is particularly useful for computing loop exit values. 5784 if (CanConstantFold(I)) { 5785 SmallVector<Constant *, 4> Operands; 5786 bool MadeImprovement = false; 5787 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { 5788 Value *Op = I->getOperand(i); 5789 if (Constant *C = dyn_cast<Constant>(Op)) { 5790 Operands.push_back(C); 5791 continue; 5792 } 5793 5794 // If any of the operands is non-constant and if they are 5795 // non-integer and non-pointer, don't even try to analyze them 5796 // with scev techniques. 
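        // (Illustrative: a floating-point or vector operand has no SCEV
        // representation, so we give up and return the original V below.)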
5797 if (!isSCEVable(Op->getType())) 5798 return V; 5799 5800 const SCEV *OrigV = getSCEV(Op); 5801 const SCEV *OpV = getSCEVAtScope(OrigV, L); 5802 MadeImprovement |= OrigV != OpV; 5803 5804 Constant *C = BuildConstantFromSCEV(OpV); 5805 if (!C) return V; 5806 if (C->getType() != Op->getType()) 5807 C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false, 5808 Op->getType(), 5809 false), 5810 C, Op->getType()); 5811 Operands.push_back(C); 5812 } 5813 5814 // Check to see if getSCEVAtScope actually made an improvement. 5815 if (MadeImprovement) { 5816 Constant *C = nullptr; 5817 if (const CmpInst *CI = dyn_cast<CmpInst>(I)) 5818 C = ConstantFoldCompareInstOperands(CI->getPredicate(), 5819 Operands[0], Operands[1], DL, 5820 TLI); 5821 else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) { 5822 if (!LI->isVolatile()) 5823 C = ConstantFoldLoadFromConstPtr(Operands[0], DL); 5824 } else 5825 C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), 5826 Operands, DL, TLI); 5827 if (!C) return V; 5828 return getSCEV(C); 5829 } 5830 } 5831 } 5832 5833 // This is some other type of SCEVUnknown, just return it. 5834 return V; 5835 } 5836 5837 if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) { 5838 // Avoid performing the look-up in the common case where the specified 5839 // expression has no loop-variant portions. 5840 for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) { 5841 const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); 5842 if (OpAtScope != Comm->getOperand(i)) { 5843 // Okay, at least one of these operands is loop variant but might be 5844 // foldable. Build a new instance of the folded commutative expression. 5845 SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(), 5846 Comm->op_begin()+i); 5847 NewOps.push_back(OpAtScope); 5848 5849 for (++i; i != e; ++i) { 5850 OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); 5851 NewOps.push_back(OpAtScope); 5852 } 5853 if (isa<SCEVAddExpr>(Comm)) 5854 return getAddExpr(NewOps); 5855 if (isa<SCEVMulExpr>(Comm)) 5856 return getMulExpr(NewOps); 5857 if (isa<SCEVSMaxExpr>(Comm)) 5858 return getSMaxExpr(NewOps); 5859 if (isa<SCEVUMaxExpr>(Comm)) 5860 return getUMaxExpr(NewOps); 5861 llvm_unreachable("Unknown commutative SCEV type!"); 5862 } 5863 } 5864 // If we got here, all operands are loop invariant. 5865 return Comm; 5866 } 5867 5868 if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) { 5869 const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L); 5870 const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L); 5871 if (LHS == Div->getLHS() && RHS == Div->getRHS()) 5872 return Div; // must be loop invariant 5873 return getUDivExpr(LHS, RHS); 5874 } 5875 5876 // If this is a loop recurrence for a loop that does not contain L, then we 5877 // are dealing with the final value computed by the loop. 5878 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) { 5879 // First, attempt to evaluate each operand. 5880 // Avoid performing the look-up in the common case where the specified 5881 // expression has no loop-variant portions. 5882 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { 5883 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L); 5884 if (OpAtScope == AddRec->getOperand(i)) 5885 continue; 5886 5887 // Okay, at least one of these operands is loop variant but might be 5888 // foldable. Build a new instance of the folded commutative expression. 
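      // (Illustrative: for {%m,+,1}<L1> queried at a scope outside L1 where
      // %m itself folds to the constant 4, the recurrence is rebuilt below
      // as {4,+,1}<L1> before its exit value is computed.)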
5889 SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(), 5890 AddRec->op_begin()+i); 5891 NewOps.push_back(OpAtScope); 5892 for (++i; i != e; ++i) 5893 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L)); 5894 5895 const SCEV *FoldedRec = 5896 getAddRecExpr(NewOps, AddRec->getLoop(), 5897 AddRec->getNoWrapFlags(SCEV::FlagNW)); 5898 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec); 5899 // The addrec may be folded to a nonrecurrence, for example, if the 5900 // induction variable is multiplied by zero after constant folding. Go 5901 // ahead and return the folded value. 5902 if (!AddRec) 5903 return FoldedRec; 5904 break; 5905 } 5906 5907 // If the scope is outside the addrec's loop, evaluate it by using the 5908 // loop exit value of the addrec. 5909 if (!AddRec->getLoop()->contains(L)) { 5910 // To evaluate this recurrence, we need to know how many times the AddRec 5911 // loop iterates. Compute this now. 5912 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); 5913 if (BackedgeTakenCount == getCouldNotCompute()) return AddRec; 5914 5915 // Then, evaluate the AddRec. 5916 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this); 5917 } 5918 5919 return AddRec; 5920 } 5921 5922 if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) { 5923 const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); 5924 if (Op == Cast->getOperand()) 5925 return Cast; // must be loop invariant 5926 return getZeroExtendExpr(Op, Cast->getType()); 5927 } 5928 5929 if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) { 5930 const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); 5931 if (Op == Cast->getOperand()) 5932 return Cast; // must be loop invariant 5933 return getSignExtendExpr(Op, Cast->getType()); 5934 } 5935 5936 if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) { 5937 const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); 5938 if (Op == Cast->getOperand()) 5939 return Cast; // must be loop invariant 5940 return getTruncateExpr(Op, Cast->getType()); 5941 } 5942 5943 llvm_unreachable("Unknown SCEV type!"); 5944 } 5945 5946 /// getSCEVAtScope - This is a convenience function which does 5947 /// getSCEVAtScope(getSCEV(V), L). 5948 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { 5949 return getSCEVAtScope(getSCEV(V), L); 5950 } 5951 5952 /// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the 5953 /// following equation: 5954 /// 5955 /// A * X = B (mod N) 5956 /// 5957 /// where N = 2^BW and BW is the common bit width of A and B. The signedness of 5958 /// A and B isn't important. 5959 /// 5960 /// If the equation does not have a solution, SCEVCouldNotCompute is returned. 5961 static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, 5962 ScalarEvolution &SE) { 5963 uint32_t BW = A.getBitWidth(); 5964 assert(BW == B.getBitWidth() && "Bit widths must be the same."); 5965 assert(A != 0 && "A must be non-zero."); 5966 5967 // 1. D = gcd(A, N) 5968 // 5969 // The gcd of A and N may have only one prime factor: 2. The number of 5970 // trailing zeros in A is its multiplicity 5971 uint32_t Mult2 = A.countTrailingZeros(); 5972 // D = 2^Mult2 5973 5974 // 2. Check if B is divisible by D. 5975 // 5976 // B is divisible by D if and only if the multiplicity of prime factor 2 for B 5977 // is not less than multiplicity of this prime factor for D. 5978 if (B.countTrailingZeros() < Mult2) 5979 return SE.getCouldNotCompute(); 5980 5981 // 3. 
Compute I: the multiplicative inverse of (A / D) in arithmetic 5982 // modulo (N / D). 5983 // 5984 // (N / D) may need BW+1 bits in its representation. Hence, we'll use this 5985 // bit width during computations. 5986 APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D 5987 APInt Mod(BW + 1, 0); 5988 Mod.setBit(BW - Mult2); // Mod = N / D 5989 APInt I = AD.multiplicativeInverse(Mod); 5990 5991 // 4. Compute the minimum unsigned root of the equation: 5992 // I * (B / D) mod (N / D) 5993 APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod); 5994 5995 // The result is guaranteed to be less than 2^BW so we may truncate it to BW 5996 // bits. 5997 return SE.getConstant(Result.trunc(BW)); 5998 } 5999 6000 /// SolveQuadraticEquation - Find the roots of the quadratic equation for the 6001 /// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which 6002 /// might be the same) or two SCEVCouldNotCompute objects. 6003 /// 6004 static std::pair<const SCEV *,const SCEV *> 6005 SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { 6006 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); 6007 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0)); 6008 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1)); 6009 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2)); 6010 6011 // We currently can only solve this if the coefficients are constants. 6012 if (!LC || !MC || !NC) { 6013 const SCEV *CNC = SE.getCouldNotCompute(); 6014 return std::make_pair(CNC, CNC); 6015 } 6016 6017 uint32_t BitWidth = LC->getValue()->getValue().getBitWidth(); 6018 const APInt &L = LC->getValue()->getValue(); 6019 const APInt &M = MC->getValue()->getValue(); 6020 const APInt &N = NC->getValue()->getValue(); 6021 APInt Two(BitWidth, 2); 6022 APInt Four(BitWidth, 4); 6023 6024 { 6025 using namespace APIntOps; 6026 const APInt& C = L; 6027 // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C 6028 // The B coefficient is M-N/2 6029 APInt B(M); 6030 B -= sdiv(N,Two); 6031 6032 // The A coefficient is N/2 6033 APInt A(N.sdiv(Two)); 6034 6035 // Compute the B^2-4ac term. 6036 APInt SqrtTerm(B); 6037 SqrtTerm *= B; 6038 SqrtTerm -= Four * (A * C); 6039 6040 if (SqrtTerm.isNegative()) { 6041 // The loop is provably infinite. 6042 const SCEV *CNC = SE.getCouldNotCompute(); 6043 return std::make_pair(CNC, CNC); 6044 } 6045 6046 // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest 6047 // integer value or else APInt::sqrt() will assert. 6048 APInt SqrtVal(SqrtTerm.sqrt()); 6049 6050 // Compute the two solutions for the quadratic formula. 6051 // The divisions must be performed as signed divisions. 6052 APInt NegB(-B); 6053 APInt TwoA(A << 1); 6054 if (TwoA.isMinValue()) { 6055 const SCEV *CNC = SE.getCouldNotCompute(); 6056 return std::make_pair(CNC, CNC); 6057 } 6058 6059 LLVMContext &Context = SE.getContext(); 6060 6061 ConstantInt *Solution1 = 6062 ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); 6063 ConstantInt *Solution2 = 6064 ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); 6065 6066 return std::make_pair(SE.getConstant(Solution1), 6067 SE.getConstant(Solution2)); 6068 } // end APIntOps namespace 6069 } 6070 6071 /// HowFarToZero - Return the number of times a backedge comparing the specified 6072 /// value to zero will execute. If not computable, return CouldNotCompute. 6073 /// 6074 /// This is only used for loops with a "x != y" exit test. 
The exit condition is
6075 /// now expressed as a single expression, V = x-y. So the exit test is
6076 /// effectively V != 0. We know, and take advantage of, the fact that this
6077 /// expression is only used in a comparison-with-zero context.
6078 ScalarEvolution::ExitLimit
6079 ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
6080 // If the value is a constant.
6081 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
6082 // If the value is already zero, the branch will execute zero times.
6083 if (C->getValue()->isZero()) return C;
6084 return getCouldNotCompute(); // Otherwise it will loop infinitely.
6085 }
6086
6087 const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
6088 if (!AddRec || AddRec->getLoop() != L)
6089 return getCouldNotCompute();
6090
6091 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
6092 // the quadratic equation to solve it.
6093 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
6094 std::pair<const SCEV *,const SCEV *> Roots =
6095 SolveQuadraticEquation(AddRec, *this);
6096 const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
6097 const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
6098 if (R1 && R2) {
6099 #if 0
6100 dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1
6101 << " sol#2: " << *R2 << "\n";
6102 #endif
6103 // Pick the smallest positive root value.
6104 if (ConstantInt *CB =
6105 dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
6106 R1->getValue(),
6107 R2->getValue()))) {
6108 if (CB->getZExtValue() == false)
6109 std::swap(R1, R2); // R1 is the minimum root now.
6110
6111 // We can only use this value if the chrec ends up with an exact zero
6112 // value at this index. When solving for "X*X != 5", for example, we
6113 // should not accept a root of 2.
6114 const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
6115 if (Val->isZero())
6116 return R1; // We found a quadratic root!
6117 }
6118 }
6119 return getCouldNotCompute();
6120 }
6121
6122 // Otherwise we can only handle this if it is affine.
6123 if (!AddRec->isAffine())
6124 return getCouldNotCompute();
6125
6126 // If this is an affine expression, the execution count of this branch is
6127 // the minimum unsigned root of the following equation:
6128 //
6129 // Start + Step*N = 0 (mod 2^BW)
6130 //
6131 // equivalent to:
6132 //
6133 // Step*N = -Start (mod 2^BW)
6134 //
6135 // where BW is the common bit width of Start and Step.
6136
6137 // Get the initial value for the loop.
6138 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
6139 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
6140
6141 // For now we handle only constant steps.
6142 //
6143 // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
6144 // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
6145 // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
6146 // We have not yet seen any such cases.
6147 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
6148 if (!StepC || StepC->getValue()->equalsInt(0))
6149 return getCouldNotCompute();
6150
6151 // For positive steps (counting up until unsigned overflow):
6152 // N = -Start/Step (as unsigned)
6153 // For negative steps (counting down to zero):
6154 // N = Start/-Step
6155 // First compute the unsigned distance from zero in the direction of Step.
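// For example (illustrative): for {10,+,-2} we count down and Distance =
// Start = 10, matching 10 - 2*N == 0 at N == 5; for {-10,+,2} we count up
// and Distance = -Start = 10, matching -10 + 2*N == 0 at N == 5.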
6156 bool CountDown = StepC->getValue()->getValue().isNegative();
6157 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
6158
6159 // Handle unitary steps, which cannot wrap around.
6160 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
6161 // N = Distance (as unsigned)
6162 if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) {
6163 ConstantRange CR = getUnsignedRange(Start);
6164 const SCEV *MaxBECount;
6165 if (!CountDown && CR.getUnsignedMin().isMinValue())
6166 // When counting up, the worst starting value is 1, not 0.
6167 MaxBECount = CR.getUnsignedMax().isMinValue()
6168 ? getConstant(APInt::getMinValue(CR.getBitWidth()))
6169 : getConstant(APInt::getMaxValue(CR.getBitWidth()));
6170 else
6171 MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
6172 : -CR.getUnsignedMin());
6173 return ExitLimit(Distance, MaxBECount);
6174 }
6175
6176 // As a special case, handle the instance where Step is a positive power of
6177 // two. In this case, determining whether Step divides Distance evenly can be
6178 // done by counting and comparing the number of trailing zeros of Step and
6179 // Distance.
6180 if (!CountDown) {
6181 const APInt &StepV = StepC->getValue()->getValue();
6182 // StepV.isPowerOf2() returns true if StepV is a positive power of two. It
6183 // also returns true if StepV is maximally negative (e.g., INT_MIN), but that
6184 // case is not handled as this code is guarded by !CountDown.
6185 if (StepV.isPowerOf2() &&
6186 GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros())
6187 return getUDivExactExpr(Distance, Step);
6188 }
6189
6190 // If the condition controls loop exit (the loop exits only if the expression
6191 // is true) and the addition is no-wrap we can use unsigned divide to
6192 // compute the backedge count. In this case, the step may not divide the
6193 // distance, but we don't care because if the condition is "missed" the loop
6194 // will have undefined behavior due to wrapping.
6195 if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) {
6196 const SCEV *Exact =
6197 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
6198 return ExitLimit(Exact, Exact);
6199 }
6200
6201 // Then, try to solve the above equation provided that Start is constant.
6202 if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
6203 return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
6204 -StartC->getValue()->getValue(),
6205 *this);
6206 return getCouldNotCompute();
6207 }
6208
6209 /// HowFarToNonZero - Return the number of times a backedge checking the
6210 /// specified value for nonzero will execute. If not computable, return
6211 /// CouldNotCompute.
6212 ScalarEvolution::ExitLimit
6213 ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
6214 // Loops that look like: while (X == 0) are very strange indeed. We don't
6215 // handle them yet except for the trivial case. This could be expanded in the
6216 // future as needed.
6217
6218 // If the value is a constant, check to see if it is known to be non-zero
6219 // already. If so, the backedge will execute zero times.
6220 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
6221 if (!C->getValue()->isNullValue())
6222 return getConstant(C->getType(), 0);
6223 return getCouldNotCompute(); // Otherwise it will loop infinitely.
6224 }
6225
6226 // We could implement others, but I really doubt anyone writes loops like
6227 // this, and if they did, they would already be constant folded.
6228 return getCouldNotCompute();
6229 }
6230
6231 /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
6232 /// (which may not be an immediate predecessor) which has exactly one
6233 /// successor from which BB is reachable, paired with that successor; return
6234 /// a pair of null blocks if no such predecessor is found.
6235 ///
6236 std::pair<BasicBlock *, BasicBlock *>
6237 ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
6238 // If the block has a unique predecessor, then there is no path from the
6239 // predecessor to the block that does not go through the direct edge
6240 // from the predecessor to the block.
6241 if (BasicBlock *Pred = BB->getSinglePredecessor())
6242 return std::make_pair(Pred, BB);
6243
6244 // A loop's header is defined to be a block that dominates the loop.
6245 // If the header has a unique predecessor outside the loop, it must be
6246 // a block that has exactly one successor that can reach the loop.
6247 if (Loop *L = LI->getLoopFor(BB))
6248 return std::make_pair(L->getLoopPredecessor(), L->getHeader());
6249
6250 return std::pair<BasicBlock *, BasicBlock *>();
6251 }
6252
6253 /// HasSameValue - SCEV structural equivalence is usually sufficient for
6254 /// testing whether two expressions are equal; however, for the purposes of
6255 /// looking for a condition guarding a loop, it can be useful to be a little
6256 /// more general, since a front-end may have replicated the controlling
6257 /// expression.
6258 ///
6259 static bool HasSameValue(const SCEV *A, const SCEV *B) {
6260 // Quick check to see if they are the same SCEV.
6261 if (A == B) return true;
6262
6263 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
6264 // two different instructions with the same value. Check for this case.
6265 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
6266 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
6267 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
6268 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
6269 if (AI->isIdenticalTo(BI) && !AI->mayReadFromMemory())
6270 return true;
6271
6272 // Otherwise assume they may have a different value.
6273 return false;
6274 }
6275
6276 /// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with
6277 /// predicate Pred. Return true iff any changes were made.
6278 ///
6279 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
6280 const SCEV *&LHS, const SCEV *&RHS,
6281 unsigned Depth) {
6282 bool Changed = false;
6283
6284 // If we hit the max recursion limit, bail out.
6285 if (Depth >= 3)
6286 return false;
6287
6288 // Canonicalize a constant to the right side.
6289 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
6290 // Check for both operands constant.
6291 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
6292 if (ConstantExpr::getICmp(Pred,
6293 LHSC->getValue(),
6294 RHSC->getValue())->isNullValue())
6295 goto trivially_false;
6296 else
6297 goto trivially_true;
6298 }
6299 // Otherwise swap the operands to put the constant on the right.
6300 std::swap(LHS, RHS);
6301 Pred = ICmpInst::getSwappedPredicate(Pred);
6302 Changed = true;
6303 }
6304
6305 // If we're comparing an addrec with a value which is loop-invariant in the
6306 // addrec's loop, put the addrec on the left. Also make a dominance check,
6307 // as both operands could be addrecs loop-invariant in each other's loop.
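// For example (illustrative): given %n <u {0,+,1}<%loop> with %n invariant
// in %loop, the operands and predicate are swapped to yield
// {0,+,1}<%loop> >u %n.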
6308 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) { 6309 const Loop *L = AR->getLoop(); 6310 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) { 6311 std::swap(LHS, RHS); 6312 Pred = ICmpInst::getSwappedPredicate(Pred); 6313 Changed = true; 6314 } 6315 } 6316 6317 // If there's a constant operand, canonicalize comparisons with boundary 6318 // cases, and canonicalize *-or-equal comparisons to regular comparisons. 6319 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) { 6320 const APInt &RA = RC->getValue()->getValue(); 6321 switch (Pred) { 6322 default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); 6323 case ICmpInst::ICMP_EQ: 6324 case ICmpInst::ICMP_NE: 6325 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b. 6326 if (!RA) 6327 if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS)) 6328 if (const SCEVMulExpr *ME = dyn_cast<SCEVMulExpr>(AE->getOperand(0))) 6329 if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 && 6330 ME->getOperand(0)->isAllOnesValue()) { 6331 RHS = AE->getOperand(1); 6332 LHS = ME->getOperand(1); 6333 Changed = true; 6334 } 6335 break; 6336 case ICmpInst::ICMP_UGE: 6337 if ((RA - 1).isMinValue()) { 6338 Pred = ICmpInst::ICMP_NE; 6339 RHS = getConstant(RA - 1); 6340 Changed = true; 6341 break; 6342 } 6343 if (RA.isMaxValue()) { 6344 Pred = ICmpInst::ICMP_EQ; 6345 Changed = true; 6346 break; 6347 } 6348 if (RA.isMinValue()) goto trivially_true; 6349 6350 Pred = ICmpInst::ICMP_UGT; 6351 RHS = getConstant(RA - 1); 6352 Changed = true; 6353 break; 6354 case ICmpInst::ICMP_ULE: 6355 if ((RA + 1).isMaxValue()) { 6356 Pred = ICmpInst::ICMP_NE; 6357 RHS = getConstant(RA + 1); 6358 Changed = true; 6359 break; 6360 } 6361 if (RA.isMinValue()) { 6362 Pred = ICmpInst::ICMP_EQ; 6363 Changed = true; 6364 break; 6365 } 6366 if (RA.isMaxValue()) goto trivially_true; 6367 6368 Pred = ICmpInst::ICMP_ULT; 6369 RHS = getConstant(RA + 1); 6370 Changed = true; 6371 break; 6372 case ICmpInst::ICMP_SGE: 6373 if ((RA - 1).isMinSignedValue()) { 6374 Pred = ICmpInst::ICMP_NE; 6375 RHS = getConstant(RA - 1); 6376 Changed = true; 6377 break; 6378 } 6379 if (RA.isMaxSignedValue()) { 6380 Pred = ICmpInst::ICMP_EQ; 6381 Changed = true; 6382 break; 6383 } 6384 if (RA.isMinSignedValue()) goto trivially_true; 6385 6386 Pred = ICmpInst::ICMP_SGT; 6387 RHS = getConstant(RA - 1); 6388 Changed = true; 6389 break; 6390 case ICmpInst::ICMP_SLE: 6391 if ((RA + 1).isMaxSignedValue()) { 6392 Pred = ICmpInst::ICMP_NE; 6393 RHS = getConstant(RA + 1); 6394 Changed = true; 6395 break; 6396 } 6397 if (RA.isMinSignedValue()) { 6398 Pred = ICmpInst::ICMP_EQ; 6399 Changed = true; 6400 break; 6401 } 6402 if (RA.isMaxSignedValue()) goto trivially_true; 6403 6404 Pred = ICmpInst::ICMP_SLT; 6405 RHS = getConstant(RA + 1); 6406 Changed = true; 6407 break; 6408 case ICmpInst::ICMP_UGT: 6409 if (RA.isMinValue()) { 6410 Pred = ICmpInst::ICMP_NE; 6411 Changed = true; 6412 break; 6413 } 6414 if ((RA + 1).isMaxValue()) { 6415 Pred = ICmpInst::ICMP_EQ; 6416 RHS = getConstant(RA + 1); 6417 Changed = true; 6418 break; 6419 } 6420 if (RA.isMaxValue()) goto trivially_false; 6421 break; 6422 case ICmpInst::ICMP_ULT: 6423 if (RA.isMaxValue()) { 6424 Pred = ICmpInst::ICMP_NE; 6425 Changed = true; 6426 break; 6427 } 6428 if ((RA - 1).isMinValue()) { 6429 Pred = ICmpInst::ICMP_EQ; 6430 RHS = getConstant(RA - 1); 6431 Changed = true; 6432 break; 6433 } 6434 if (RA.isMinValue()) goto trivially_false; 6435 break; 6436 case ICmpInst::ICMP_SGT: 6437 if 
(RA.isMinSignedValue()) { 6438 Pred = ICmpInst::ICMP_NE; 6439 Changed = true; 6440 break; 6441 } 6442 if ((RA + 1).isMaxSignedValue()) { 6443 Pred = ICmpInst::ICMP_EQ; 6444 RHS = getConstant(RA + 1); 6445 Changed = true; 6446 break; 6447 } 6448 if (RA.isMaxSignedValue()) goto trivially_false; 6449 break; 6450 case ICmpInst::ICMP_SLT: 6451 if (RA.isMaxSignedValue()) { 6452 Pred = ICmpInst::ICMP_NE; 6453 Changed = true; 6454 break; 6455 } 6456 if ((RA - 1).isMinSignedValue()) { 6457 Pred = ICmpInst::ICMP_EQ; 6458 RHS = getConstant(RA - 1); 6459 Changed = true; 6460 break; 6461 } 6462 if (RA.isMinSignedValue()) goto trivially_false; 6463 break; 6464 } 6465 } 6466 6467 // Check for obvious equality. 6468 if (HasSameValue(LHS, RHS)) { 6469 if (ICmpInst::isTrueWhenEqual(Pred)) 6470 goto trivially_true; 6471 if (ICmpInst::isFalseWhenEqual(Pred)) 6472 goto trivially_false; 6473 } 6474 6475 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by 6476 // adding or subtracting 1 from one of the operands. 6477 switch (Pred) { 6478 case ICmpInst::ICMP_SLE: 6479 if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) { 6480 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, 6481 SCEV::FlagNSW); 6482 Pred = ICmpInst::ICMP_SLT; 6483 Changed = true; 6484 } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) { 6485 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS, 6486 SCEV::FlagNSW); 6487 Pred = ICmpInst::ICMP_SLT; 6488 Changed = true; 6489 } 6490 break; 6491 case ICmpInst::ICMP_SGE: 6492 if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) { 6493 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, 6494 SCEV::FlagNSW); 6495 Pred = ICmpInst::ICMP_SGT; 6496 Changed = true; 6497 } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) { 6498 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, 6499 SCEV::FlagNSW); 6500 Pred = ICmpInst::ICMP_SGT; 6501 Changed = true; 6502 } 6503 break; 6504 case ICmpInst::ICMP_ULE: 6505 if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) { 6506 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, 6507 SCEV::FlagNUW); 6508 Pred = ICmpInst::ICMP_ULT; 6509 Changed = true; 6510 } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) { 6511 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS, 6512 SCEV::FlagNUW); 6513 Pred = ICmpInst::ICMP_ULT; 6514 Changed = true; 6515 } 6516 break; 6517 case ICmpInst::ICMP_UGE: 6518 if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) { 6519 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, 6520 SCEV::FlagNUW); 6521 Pred = ICmpInst::ICMP_UGT; 6522 Changed = true; 6523 } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) { 6524 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, 6525 SCEV::FlagNUW); 6526 Pred = ICmpInst::ICMP_UGT; 6527 Changed = true; 6528 } 6529 break; 6530 default: 6531 break; 6532 } 6533 6534 // TODO: More simplifications are possible here. 6535 6536 // Recursively simplify until we either hit a recursion limit or nothing 6537 // changes. 6538 if (Changed) 6539 return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1); 6540 6541 return Changed; 6542 6543 trivially_true: 6544 // Return 0 == 0. 6545 LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); 6546 Pred = ICmpInst::ICMP_EQ; 6547 return true; 6548 6549 trivially_false: 6550 // Return 0 != 0. 
6551 LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); 6552 Pred = ICmpInst::ICMP_NE; 6553 return true; 6554 } 6555 6556 bool ScalarEvolution::isKnownNegative(const SCEV *S) { 6557 return getSignedRange(S).getSignedMax().isNegative(); 6558 } 6559 6560 bool ScalarEvolution::isKnownPositive(const SCEV *S) { 6561 return getSignedRange(S).getSignedMin().isStrictlyPositive(); 6562 } 6563 6564 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) { 6565 return !getSignedRange(S).getSignedMin().isNegative(); 6566 } 6567 6568 bool ScalarEvolution::isKnownNonPositive(const SCEV *S) { 6569 return !getSignedRange(S).getSignedMax().isStrictlyPositive(); 6570 } 6571 6572 bool ScalarEvolution::isKnownNonZero(const SCEV *S) { 6573 return isKnownNegative(S) || isKnownPositive(S); 6574 } 6575 6576 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, 6577 const SCEV *LHS, const SCEV *RHS) { 6578 // Canonicalize the inputs first. 6579 (void)SimplifyICmpOperands(Pred, LHS, RHS); 6580 6581 // If LHS or RHS is an addrec, check to see if the condition is true in 6582 // every iteration of the loop. 6583 // If LHS and RHS are both addrec, both conditions must be true in 6584 // every iteration of the loop. 6585 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS); 6586 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS); 6587 bool LeftGuarded = false; 6588 bool RightGuarded = false; 6589 if (LAR) { 6590 const Loop *L = LAR->getLoop(); 6591 if (isLoopEntryGuardedByCond(L, Pred, LAR->getStart(), RHS) && 6592 isLoopBackedgeGuardedByCond(L, Pred, LAR->getPostIncExpr(*this), RHS)) { 6593 if (!RAR) return true; 6594 LeftGuarded = true; 6595 } 6596 } 6597 if (RAR) { 6598 const Loop *L = RAR->getLoop(); 6599 if (isLoopEntryGuardedByCond(L, Pred, LHS, RAR->getStart()) && 6600 isLoopBackedgeGuardedByCond(L, Pred, LHS, RAR->getPostIncExpr(*this))) { 6601 if (!LAR) return true; 6602 RightGuarded = true; 6603 } 6604 } 6605 if (LeftGuarded && RightGuarded) 6606 return true; 6607 6608 // Otherwise see what can be done with known constant ranges. 6609 return isKnownPredicateWithRanges(Pred, LHS, RHS); 6610 } 6611 6612 bool 6613 ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred, 6614 const SCEV *LHS, const SCEV *RHS) { 6615 if (HasSameValue(LHS, RHS)) 6616 return ICmpInst::isTrueWhenEqual(Pred); 6617 6618 // This code is split out from isKnownPredicate because it is called from 6619 // within isLoopEntryGuardedByCond. 
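// For example (illustrative): if the unsigned range of LHS is [0, 8) and
// that of RHS is [8, 16), then LHS <u RHS is known to hold, because the
// largest possible LHS (7) is still below the smallest possible RHS (8).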
6620 switch (Pred) {
6621 default:
6622 llvm_unreachable("Unexpected ICmpInst::Predicate value!");
6623 case ICmpInst::ICMP_SGT:
6624 std::swap(LHS, RHS); // fall through
6625 case ICmpInst::ICMP_SLT: {
6626 ConstantRange LHSRange = getSignedRange(LHS);
6627 ConstantRange RHSRange = getSignedRange(RHS);
6628 if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin()))
6629 return true;
6630 if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax()))
6631 return false;
6632 break;
6633 }
6634 case ICmpInst::ICMP_SGE:
6635 std::swap(LHS, RHS); // fall through
6636 case ICmpInst::ICMP_SLE: {
6637 ConstantRange LHSRange = getSignedRange(LHS);
6638 ConstantRange RHSRange = getSignedRange(RHS);
6639 if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin()))
6640 return true;
6641 if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax()))
6642 return false;
6643 break;
6644 }
6645 case ICmpInst::ICMP_UGT:
6646 std::swap(LHS, RHS); // fall through
6647 case ICmpInst::ICMP_ULT: {
6648 ConstantRange LHSRange = getUnsignedRange(LHS);
6649 ConstantRange RHSRange = getUnsignedRange(RHS);
6650 if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin()))
6651 return true;
6652 if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax()))
6653 return false;
6654 break;
6655 }
6656 case ICmpInst::ICMP_UGE:
6657 std::swap(LHS, RHS); // fall through
6658 case ICmpInst::ICMP_ULE: {
6659 ConstantRange LHSRange = getUnsignedRange(LHS);
6660 ConstantRange RHSRange = getUnsignedRange(RHS);
6661 if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin()))
6662 return true;
6663 if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax()))
6664 return false;
6665 break;
6666 }
6667 case ICmpInst::ICMP_NE: {
6668 if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet())
6669 return true;
6670 if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet())
6671 return true;
6672
6673 const SCEV *Diff = getMinusSCEV(LHS, RHS);
6674 if (isKnownNonZero(Diff))
6675 return true;
6676 break;
6677 }
6678 case ICmpInst::ICMP_EQ:
6679 // The check at the top of the function catches the case where
6680 // the values are known to be equal.
6681 break;
6682 }
6683 return false;
6684 }
6685
6686 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
6687 /// protected by a conditional between LHS and RHS. This is used
6688 /// to eliminate casts.
6689 bool
6690 ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
6691 ICmpInst::Predicate Pred,
6692 const SCEV *LHS, const SCEV *RHS) {
6693 // Interpret a null as meaning no loop, where there is obviously no guard
6694 // (interprocedural conditions notwithstanding).
6695 if (!L) return true;
6696
6697 if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
6698
6699 BasicBlock *Latch = L->getLoopLatch();
6700 if (!Latch)
6701 return false;
6702
6703 BranchInst *LoopContinuePredicate =
6704 dyn_cast<BranchInst>(Latch->getTerminator());
6705 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
6706 isImpliedCond(Pred, LHS, RHS,
6707 LoopContinuePredicate->getCondition(),
6708 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
6709 return true;
6710
6711 // Check conditions due to any @llvm.assume intrinsics.
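// For example (illustrative IR):
//   %cmp = icmp ult i32 %i, %n
//   call void @llvm.assume(i1 %cmp)
// An assume that dominates the latch lets us treat %i <u %n as a known
// guard for the backedge.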
6712 for (auto &AssumeVH : AC->assumptions()) { 6713 if (!AssumeVH) 6714 continue; 6715 auto *CI = cast<CallInst>(AssumeVH); 6716 if (!DT->dominates(CI, Latch->getTerminator())) 6717 continue; 6718 6719 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) 6720 return true; 6721 } 6722 6723 return false; 6724 } 6725 6726 /// isLoopEntryGuardedByCond - Test whether entry to the loop is protected 6727 /// by a conditional between LHS and RHS. This is used to help avoid max 6728 /// expressions in loop trip counts, and to eliminate casts. 6729 bool 6730 ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, 6731 ICmpInst::Predicate Pred, 6732 const SCEV *LHS, const SCEV *RHS) { 6733 // Interpret a null as meaning no loop, where there is obviously no guard 6734 // (interprocedural conditions notwithstanding). 6735 if (!L) return false; 6736 6737 if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true; 6738 6739 // Starting at the loop predecessor, climb up the predecessor chain, as long 6740 // as there are predecessors that can be found that have unique successors 6741 // leading to the original header. 6742 for (std::pair<BasicBlock *, BasicBlock *> 6743 Pair(L->getLoopPredecessor(), L->getHeader()); 6744 Pair.first; 6745 Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) { 6746 6747 BranchInst *LoopEntryPredicate = 6748 dyn_cast<BranchInst>(Pair.first->getTerminator()); 6749 if (!LoopEntryPredicate || 6750 LoopEntryPredicate->isUnconditional()) 6751 continue; 6752 6753 if (isImpliedCond(Pred, LHS, RHS, 6754 LoopEntryPredicate->getCondition(), 6755 LoopEntryPredicate->getSuccessor(0) != Pair.second)) 6756 return true; 6757 } 6758 6759 // Check conditions due to any @llvm.assume intrinsics. 6760 for (auto &AssumeVH : AC->assumptions()) { 6761 if (!AssumeVH) 6762 continue; 6763 auto *CI = cast<CallInst>(AssumeVH); 6764 if (!DT->dominates(CI, L->getHeader())) 6765 continue; 6766 6767 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) 6768 return true; 6769 } 6770 6771 return false; 6772 } 6773 6774 /// RAII wrapper to prevent recursive application of isImpliedCond. 6775 /// ScalarEvolution's PendingLoopPredicates set must be empty unless we are 6776 /// currently evaluating isImpliedCond. 6777 struct MarkPendingLoopPredicate { 6778 Value *Cond; 6779 DenseSet<Value*> &LoopPreds; 6780 bool Pending; 6781 6782 MarkPendingLoopPredicate(Value *C, DenseSet<Value*> &LP) 6783 : Cond(C), LoopPreds(LP) { 6784 Pending = !LoopPreds.insert(Cond).second; 6785 } 6786 ~MarkPendingLoopPredicate() { 6787 if (!Pending) 6788 LoopPreds.erase(Cond); 6789 } 6790 }; 6791 6792 /// isImpliedCond - Test whether the condition described by Pred, LHS, 6793 /// and RHS is true whenever the given Cond value evaluates to true. 6794 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, 6795 const SCEV *LHS, const SCEV *RHS, 6796 Value *FoundCondValue, 6797 bool Inverse) { 6798 MarkPendingLoopPredicate Mark(FoundCondValue, PendingLoopPredicates); 6799 if (Mark.Pending) 6800 return false; 6801 6802 // Recursively handle And and Or conditions. 
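// For a branch on (%a && %b) taken when true (Inverse == false), both
// operands are known true on that edge, so it suffices for either one
// alone to imply the target predicate; dually, on the false edge of
// (%a || %b) (Inverse == true), both operands are known false.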
6803 if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
6804 if (BO->getOpcode() == Instruction::And) {
6805 if (!Inverse)
6806 return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
6807 isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
6808 } else if (BO->getOpcode() == Instruction::Or) {
6809 if (Inverse)
6810 return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
6811 isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
6812 }
6813 }
6814
6815 ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
6816 if (!ICI) return false;
6817
6818 // Bail if the ICmp's operands' types are wider than the needed type
6819 // before attempting to call getSCEV on them. This avoids infinite
6820 // recursion, since the analysis of widening casts can require loop
6821 // exit condition information for overflow checking, which would
6822 // lead back here.
6823 if (getTypeSizeInBits(LHS->getType()) <
6824 getTypeSizeInBits(ICI->getOperand(0)->getType()))
6825 return false;
6826
6827 // Now that we have found a conditional branch that dominates the loop or
6828 // controls the loop latch, check to see if it is the comparison we are looking for.
6829 ICmpInst::Predicate FoundPred;
6830 if (Inverse)
6831 FoundPred = ICI->getInversePredicate();
6832 else
6833 FoundPred = ICI->getPredicate();
6834
6835 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
6836 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
6837
6838 // Balance the types. The case where FoundLHS' type is wider than
6839 // LHS' type is checked for above.
6840 if (getTypeSizeInBits(LHS->getType()) >
6841 getTypeSizeInBits(FoundLHS->getType())) {
6842 if (CmpInst::isSigned(FoundPred)) {
6843 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
6844 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
6845 } else {
6846 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
6847 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
6848 }
6849 }
6850
6851 // Canonicalize the query to match the way instcombine will have
6852 // canonicalized the comparison.
6853 if (SimplifyICmpOperands(Pred, LHS, RHS))
6854 if (LHS == RHS)
6855 return CmpInst::isTrueWhenEqual(Pred);
6856 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
6857 if (FoundLHS == FoundRHS)
6858 return CmpInst::isFalseWhenEqual(FoundPred);
6859
6860 // Check to see if we can make the LHS or RHS match.
6861 if (LHS == FoundRHS || RHS == FoundLHS) {
6862 if (isa<SCEVConstant>(RHS)) {
6863 std::swap(FoundLHS, FoundRHS);
6864 FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
6865 } else {
6866 std::swap(LHS, RHS);
6867 Pred = ICmpInst::getSwappedPredicate(Pred);
6868 }
6869 }
6870
6871 // Check whether the found predicate is the same as the desired predicate.
6872 if (FoundPred == Pred)
6873 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
6874
6875 // Check whether swapping the found predicate makes it the same as the
6876 // desired predicate.
6877 if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
6878 if (isa<SCEVConstant>(RHS))
6879 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
6880 else
6881 return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
6882 RHS, LHS, FoundLHS, FoundRHS);
6883 }
6884
6885 // Check if we can make progress by sharpening ranges.
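// For example (illustrative): if the guard gives us %v != 5 and the known
// unsigned minimum of %v is 5, then %v >=u 6, which may be exactly the
// sharper fact needed to prove the desired predicate.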
6886 if (FoundPred == ICmpInst::ICMP_NE &&
6887 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
6888
6889 const SCEVConstant *C = nullptr;
6890 const SCEV *V = nullptr;
6891
6892 if (isa<SCEVConstant>(FoundLHS)) {
6893 C = cast<SCEVConstant>(FoundLHS);
6894 V = FoundRHS;
6895 } else {
6896 C = cast<SCEVConstant>(FoundRHS);
6897 V = FoundLHS;
6898 }
6899
6900 // The guarding predicate tells us that C != V. If the known range
6901 // of V is [C, t), we can sharpen the range to [C + 1, t). The
6902 // range we consider has to correspond to the same signedness as the
6903 // predicate we're interested in folding.
6904
6905 APInt Min = ICmpInst::isSigned(Pred) ?
6906 getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin();
6907
6908 if (Min == C->getValue()->getValue()) {
6909 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
6910 // This is true even if (Min + 1) wraps around -- in case of
6911 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
6912
6913 APInt SharperMin = Min + 1;
6914
6915 switch (Pred) {
6916 case ICmpInst::ICMP_SGE:
6917 case ICmpInst::ICMP_UGE:
6918 // We know V `Pred` SharperMin. If this implies LHS `Pred`
6919 // RHS, we're done.
6920 if (isImpliedCondOperands(Pred, LHS, RHS, V,
6921 getConstant(SharperMin)))
6922 return true;
6923 // fall through
6924 case ICmpInst::ICMP_SGT:
6925 case ICmpInst::ICMP_UGT:
6926 // We know from the range information that (V `Pred` Min ||
6927 // V == Min). We know from the guarding condition that !(V
6928 // == Min). This gives us
6929 //
6930 // V `Pred` Min || V == Min && !(V == Min)
6931 // => V `Pred` Min
6932 //
6933 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
6934
6935 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
6936 return true;
6937
6938 default:
6939 // No change
6940 break;
6941 }
6942 }
6943 }
6944
6945 // Check whether the actual condition is beyond sufficient.
6946 if (FoundPred == ICmpInst::ICMP_EQ)
6947 if (ICmpInst::isTrueWhenEqual(Pred))
6948 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
6949 return true;
6950 if (Pred == ICmpInst::ICMP_NE)
6951 if (!ICmpInst::isTrueWhenEqual(FoundPred))
6952 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
6953 return true;
6954
6955 // Otherwise assume the worst.
6956 return false;
6957 }
6958
6959 /// isImpliedCondOperands - Test whether the condition described by Pred,
6960 /// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
6961 /// and FoundRHS is true.
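/// For example (an illustrative identity): from x <u y it also follows that
/// ~y <u ~x, since bitwise-not (~k == -1 - k) reverses the unsigned order;
/// the helper is therefore also tried with the operands notted and swapped.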
6962 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
6963 const SCEV *LHS, const SCEV *RHS,
6964 const SCEV *FoundLHS,
6965 const SCEV *FoundRHS) {
6966 return isImpliedCondOperandsHelper(Pred, LHS, RHS,
6967 FoundLHS, FoundRHS) ||
6968 // ~x < ~y --> x > y
6969 isImpliedCondOperandsHelper(Pred, LHS, RHS,
6970 getNotSCEV(FoundRHS),
6971 getNotSCEV(FoundLHS));
6972 }
6973
6974
6975 /// If Expr computes ~A, return A; else return nullptr.
6976 static const SCEV *MatchNotExpr(const SCEV *Expr) {
6977 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
6978 if (!Add || Add->getNumOperands() != 2) return nullptr;
6979
6980 const SCEVConstant *AddLHS = dyn_cast<SCEVConstant>(Add->getOperand(0));
6981 if (!(AddLHS && AddLHS->getValue()->getValue().isAllOnesValue()))
6982 return nullptr;
6983
6984 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
6985 if (!AddRHS || AddRHS->getNumOperands() != 2) return nullptr;
6986
6987 const SCEVConstant *MulLHS = dyn_cast<SCEVConstant>(AddRHS->getOperand(0));
6988 if (!(MulLHS && MulLHS->getValue()->getValue().isAllOnesValue()))
6989 return nullptr;
6990
6991 return AddRHS->getOperand(1);
6992 }
6993
6994
6995 /// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values?
6996 template<typename MaxExprType>
6997 static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
6998 const SCEV *Candidate) {
6999 const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr);
7000 if (!MaxExpr) return false;
7001
7002 auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate);
7003 return It != MaxExpr->op_end();
7004 }
7005
7006
7007 /// Is MaybeMinExpr an SMin or UMin of Candidate and some other values?
7008 template<typename MaxExprType>
7009 static bool IsMinConsistingOf(ScalarEvolution &SE,
7010 const SCEV *MaybeMinExpr,
7011 const SCEV *Candidate) {
7012 const SCEV *MaybeMaxExpr = MatchNotExpr(MaybeMinExpr);
7013 if (!MaybeMaxExpr)
7014 return false;
7015
7016 return IsMaxConsistingOf<MaxExprType>(MaybeMaxExpr, SE.getNotSCEV(Candidate));
7017 }
7018
7019
7020 /// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
7021 /// expression?
7022 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
7023 ICmpInst::Predicate Pred,
7024 const SCEV *LHS, const SCEV *RHS) {
7025 switch (Pred) {
7026 default:
7027 return false;
7028
7029 case ICmpInst::ICMP_SGE:
7030 std::swap(LHS, RHS);
7031 // fall through
7032 case ICmpInst::ICMP_SLE:
7033 return
7034 // min(A, ...) <= A
7035 IsMinConsistingOf<SCEVSMaxExpr>(SE, LHS, RHS) ||
7036 // A <= max(A, ...)
7037 IsMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
7038
7039 case ICmpInst::ICMP_UGE:
7040 std::swap(LHS, RHS);
7041 // fall through
7042 case ICmpInst::ICMP_ULE:
7043 return
7044 // min(A, ...) <= A
7045 IsMinConsistingOf<SCEVUMaxExpr>(SE, LHS, RHS) ||
7046 // A <= max(A, ...)
7047 IsMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
7048 }
7049
7050 llvm_unreachable("covered switch fell through?!");
7051 }
7052
7053 /// isImpliedCondOperandsHelper - Test whether the condition described by
7054 /// Pred, LHS, and RHS is true whenever the condition described by Pred,
7055 /// FoundLHS, and FoundRHS is true.
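/// For example (illustrative): FoundLHS <s FoundRHS implies LHS <s RHS
/// whenever LHS <=s FoundLHS and RHS >=s FoundRHS both hold; that is
/// exactly the combination checked for the signed less-than cases below.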
7056 bool
7057 ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
7058 const SCEV *LHS, const SCEV *RHS,
7059 const SCEV *FoundLHS,
7060 const SCEV *FoundRHS) {
7061 auto IsKnownPredicateFull =
7062 [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
7063 return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
7064 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS);
7065 };
7066
7067 switch (Pred) {
7068 default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
7069 case ICmpInst::ICMP_EQ:
7070 case ICmpInst::ICMP_NE:
7071 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
7072 return true;
7073 break;
7074 case ICmpInst::ICMP_SLT:
7075 case ICmpInst::ICMP_SLE:
7076 if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
7077 IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS))
7078 return true;
7079 break;
7080 case ICmpInst::ICMP_SGT:
7081 case ICmpInst::ICMP_SGE:
7082 if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
7083 IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS))
7084 return true;
7085 break;
7086 case ICmpInst::ICMP_ULT:
7087 case ICmpInst::ICMP_ULE:
7088 if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
7089 IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS))
7090 return true;
7091 break;
7092 case ICmpInst::ICMP_UGT:
7093 case ICmpInst::ICMP_UGE:
7094 if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
7095 IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS))
7096 return true;
7097 break;
7098 }
7099
7100 return false;
7101 }
7102
7103 // Verify whether a linear IV with positive stride can overflow when in a
7104 // less-than comparison, given the invariant term of the comparison, the
7105 // stride, and the NSW/NUW flags on the recurrence.
7106 bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
7107 bool IsSigned, bool NoWrap) {
7108 if (NoWrap) return false;
7109
7110 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
7111 const SCEV *One = getConstant(Stride->getType(), 1);
7112
7113 if (IsSigned) {
7114 APInt MaxRHS = getSignedRange(RHS).getSignedMax();
7115 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
7116 APInt MaxStrideMinusOne = getSignedRange(getMinusSCEV(Stride, One))
7117 .getSignedMax();
7118
7119 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
7120 return (MaxValue - MaxStrideMinusOne).slt(MaxRHS);
7121 }
7122
7123 APInt MaxRHS = getUnsignedRange(RHS).getUnsignedMax();
7124 APInt MaxValue = APInt::getMaxValue(BitWidth);
7125 APInt MaxStrideMinusOne = getUnsignedRange(getMinusSCEV(Stride, One))
7126 .getUnsignedMax();
7127
7128 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
7129 return (MaxValue - MaxStrideMinusOne).ult(MaxRHS);
7130 }
7131
7132 // Verify whether a linear IV with negative stride can overflow when in a
7133 // greater-than comparison, given the invariant term of the comparison,
7134 // the stride, and the NSW/NUW flags on the recurrence.
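// A hypothetical 8-bit signed example: if smin(RHS) == -126 and the stride
// is known to be at most 4, then MinValue + MaxStrideMinusOne == -128 + 3 ==
// -125, which is signed-greater than -126, so a final step could wrap past
// the minimum and we conservatively report possible overflow.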
7135 bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
7136 bool IsSigned, bool NoWrap) {
7137 if (NoWrap) return false;
7138
7139 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
7140 const SCEV *One = getConstant(Stride->getType(), 1);
7141
7142 if (IsSigned) {
7143 APInt MinRHS = getSignedRange(RHS).getSignedMin();
7144 APInt MinValue = APInt::getSignedMinValue(BitWidth);
7145 APInt MaxStrideMinusOne = getSignedRange(getMinusSCEV(Stride, One))
7146 .getSignedMax();
7147
7148 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
7149 return (MinValue + MaxStrideMinusOne).sgt(MinRHS);
7150 }
7151
7152 APInt MinRHS = getUnsignedRange(RHS).getUnsignedMin();
7153 APInt MinValue = APInt::getMinValue(BitWidth);
7154 APInt MaxStrideMinusOne = getUnsignedRange(getMinusSCEV(Stride, One))
7155 .getUnsignedMax();
7156
7157 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
7158 return (MinValue + MaxStrideMinusOne).ugt(MinRHS);
7159 }
7160
7161 // Compute the backedge taken count given the interval difference, the
7162 // stride, and whether the comparison includes equality.
7163 const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
7164 bool Equality) {
7165 const SCEV *One = getConstant(Step->getType(), 1);
7166 Delta = Equality ? getAddExpr(Delta, Step)
7167 : getAddExpr(Delta, getMinusSCEV(Step, One));
7168 return getUDivExpr(Delta, Step);
7169 }
7170
7171 /// HowManyLessThans - Return the number of times a backedge containing the
7172 /// specified less-than comparison will execute. If not computable, return
7173 /// CouldNotCompute.
7174 ///
7175 /// @param ControlsExit is true when the LHS < RHS condition directly controls
7176 /// the branch (the loop exits only if the condition is true). In this case, we
7177 /// can use NoWrapFlags to skip overflow checks.
7178 ScalarEvolution::ExitLimit
7179 ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
7180 const Loop *L, bool IsSigned,
7181 bool ControlsExit) {
7182 // We handle only IV < Invariant.
7183 if (!isLoopInvariant(RHS, L))
7184 return getCouldNotCompute();
7185
7186 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
7187
7188 // Avoid weird loops.
7189 if (!IV || IV->getLoop() != L || !IV->isAffine())
7190 return getCouldNotCompute();
7191
7192 bool NoWrap = ControlsExit &&
7193 IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
7194
7195 const SCEV *Stride = IV->getStepRecurrence(*this);
7196
7197 // Avoid negative or zero stride values.
7198 if (!isKnownPositive(Stride))
7199 return getCouldNotCompute();
7200
7201 // Avoid proven overflow cases: this will ensure that the backedge taken count
7202 // will not generate any unsigned overflow. Relaxed no-overflow conditions
7203 // exploit NoWrapFlags, allowing us to optimize in the presence of undefined
7204 // behavior, as in C.
7205 if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
7206 return getCouldNotCompute();
7207
7208 ICmpInst::Predicate Cond = IsSigned ?
ICmpInst::ICMP_SLT
7209 : ICmpInst::ICMP_ULT;
7210 const SCEV *Start = IV->getStart();
7211 const SCEV *End = RHS;
7212 if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) {
7213 const SCEV *Diff = getMinusSCEV(RHS, Start);
7214 // If we have NoWrap set, then we can assume that the increment won't
7215 // overflow, in which case if RHS - Start is a constant, we don't need to
7216 // do a max operation since we can just figure it out statically.
7217 if (NoWrap && isa<SCEVConstant>(Diff)) {
7218 APInt D = dyn_cast<const SCEVConstant>(Diff)->getValue()->getValue();
7219 if (D.isNegative())
7220 End = Start;
7221 } else
7222 End = IsSigned ? getSMaxExpr(RHS, Start)
7223 : getUMaxExpr(RHS, Start);
7224 }
7225
7226 const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
7227
7228 APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin()
7229 : getUnsignedRange(Start).getUnsignedMin();
7230
7231 APInt MinStride = IsSigned ? getSignedRange(Stride).getSignedMin()
7232 : getUnsignedRange(Stride).getUnsignedMin();
7233
7234 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
7235 APInt Limit = IsSigned ? APInt::getSignedMaxValue(BitWidth) - (MinStride - 1)
7236 : APInt::getMaxValue(BitWidth) - (MinStride - 1);
7237
7238 // Although End can be a MAX expression, we estimate MaxEnd considering only
7239 // the case End = RHS. This is safe because in the other case (End - Start)
7240 // is zero, leading to a zero maximum backedge taken count.
7241 APInt MaxEnd =
7242 IsSigned ? APIntOps::smin(getSignedRange(RHS).getSignedMax(), Limit)
7243 : APIntOps::umin(getUnsignedRange(RHS).getUnsignedMax(), Limit);
7244
7245 const SCEV *MaxBECount;
7246 if (isa<SCEVConstant>(BECount))
7247 MaxBECount = BECount;
7248 else
7249 MaxBECount = computeBECount(getConstant(MaxEnd - MinStart),
7250 getConstant(MinStride), false);
7251
7252 if (isa<SCEVCouldNotCompute>(MaxBECount))
7253 MaxBECount = BECount;
7254
7255 return ExitLimit(BECount, MaxBECount);
7256 }
7257
7258 ScalarEvolution::ExitLimit
7259 ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
7260 const Loop *L, bool IsSigned,
7261 bool ControlsExit) {
7262 // We handle only IV > Invariant.
7263 if (!isLoopInvariant(RHS, L))
7264 return getCouldNotCompute();
7265
7266 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
7267
7268 // Avoid weird loops.
7269 if (!IV || IV->getLoop() != L || !IV->isAffine())
7270 return getCouldNotCompute();
7271
7272 bool NoWrap = ControlsExit &&
7273 IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
7274
7275 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
7276
7277 // Avoid negative or zero stride values.
7278 if (!isKnownPositive(Stride))
7279 return getCouldNotCompute();
7280
7281 // Avoid proven overflow cases: this will ensure that the backedge taken count
7282 // will not generate any unsigned overflow. Relaxed no-overflow conditions
7283 // exploit NoWrapFlags, allowing us to optimize in the presence of undefined
7284 // behavior, as in C.
7285 if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap))
7286 return getCouldNotCompute();
7287
7288 ICmpInst::Predicate Cond = IsSigned ?
ICmpInst::ICMP_SGT
7289 : ICmpInst::ICMP_UGT;
7290
7291 const SCEV *Start = IV->getStart();
7292 const SCEV *End = RHS;
7293 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
7294 const SCEV *Diff = getMinusSCEV(RHS, Start);
7295 // If we have NoWrap set, then we can assume that the increment won't
7296 // overflow, in which case if RHS - Start is a constant, we don't need to
7297 // do a max operation since we can just figure it out statically.
7298 if (NoWrap && isa<SCEVConstant>(Diff)) {
7299 APInt D = dyn_cast<const SCEVConstant>(Diff)->getValue()->getValue();
7300 if (!D.isNegative())
7301 End = Start;
7302 } else
7303 End = IsSigned ? getSMinExpr(RHS, Start)
7304 : getUMinExpr(RHS, Start);
7305 }
7306
7307 const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false);
7308
7309 APInt MaxStart = IsSigned ? getSignedRange(Start).getSignedMax()
7310 : getUnsignedRange(Start).getUnsignedMax();
7311
7312 APInt MinStride = IsSigned ? getSignedRange(Stride).getSignedMin()
7313 : getUnsignedRange(Stride).getUnsignedMin();
7314
7315 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
7316 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
7317 : APInt::getMinValue(BitWidth) + (MinStride - 1);
7318
7319 // Although End can be a MIN expression, we estimate MinEnd considering only
7320 // the case End = RHS. This is safe because in the other case (Start - End)
7321 // is zero, leading to a zero maximum backedge taken count.
7322 APInt MinEnd =
7323 IsSigned ? APIntOps::smax(getSignedRange(RHS).getSignedMin(), Limit)
7324 : APIntOps::umax(getUnsignedRange(RHS).getUnsignedMin(), Limit);
7325
7326
7327 const SCEV *MaxBECount = getCouldNotCompute();
7328 if (isa<SCEVConstant>(BECount))
7329 MaxBECount = BECount;
7330 else
7331 MaxBECount = computeBECount(getConstant(MaxStart - MinEnd),
7332 getConstant(MinStride), false);
7333
7334 if (isa<SCEVCouldNotCompute>(MaxBECount))
7335 MaxBECount = BECount;
7336
7337 return ExitLimit(BECount, MaxBECount);
7338 }
7339
7340 /// getNumIterationsInRange - Return the number of iterations of this loop that
7341 /// produce values in the specified constant range. Another way of looking at
7342 /// this is that it returns the first iteration number at which the value is no
7343 /// longer in the range, thus computing the exit count. If the iteration count
7344 /// can't be computed, an instance of SCEVCouldNotCompute is returned.
7345 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
7346 ScalarEvolution &SE) const {
7347 if (Range.isFullSet()) // Infinite loop.
7348 return SE.getCouldNotCompute();
7349
7350 // If the start is a non-zero constant, shift the range to simplify things.
7351 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
7352 if (!SC->getValue()->isZero()) {
7353 SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
7354 Operands[0] = SE.getConstant(SC->getType(), 0);
7355 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
7356 getNoWrapFlags(FlagNW));
7357 if (const SCEVAddRecExpr *ShiftedAddRec =
7358 dyn_cast<SCEVAddRecExpr>(Shifted))
7359 return ShiftedAddRec->getNumIterationsInRange(
7360 Range.subtract(SC->getValue()->getValue()), SE);
7361 // This is strange and shouldn't happen.
7362 return SE.getCouldNotCompute();
7363 }
7364
7365 // The only time we can solve this is when we have all constant indices.
7366 // Otherwise, we cannot determine the overflow conditions.
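// For example (illustrative): {0,+,4} over the range [0,10) produces 0, 4,
// 8, 12, ...; iteration 3 yields 12, the first value outside the range, and
// the affine case below recovers this as ExitVal = (9 + 4) udiv 4 = 3.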
7367 for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
7368 if (!isa<SCEVConstant>(getOperand(i)))
7369 return SE.getCouldNotCompute();
7370
7371
7372 // Okay, at this point we know that all elements of the chrec are constants and
7373 // that the start element is zero.
7374
7375 // First check to see if the range contains zero. If not, the first
7376 // iteration exits.
7377 unsigned BitWidth = SE.getTypeSizeInBits(getType());
7378 if (!Range.contains(APInt(BitWidth, 0)))
7379 return SE.getConstant(getType(), 0);
7380
7381 if (isAffine()) {
7382 // If this is an affine expression then we have this situation:
7383 // Solve {0,+,A} in Range === Ax in Range
7384
7385 // We know that zero is in the range. If A is positive then we know that
7386 // the upper value of the range must be the first possible exit value.
7387 // If A is negative then the lower of the range is the last possible loop
7388 // value. Also note that we already checked for a full range.
7389 APInt One(BitWidth,1);
7390 APInt A = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
7391 APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
7392
7393 // The exit value should be (End+A)/A.
7394 APInt ExitVal = (End + A).udiv(A);
7395 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
7396
7397 // Evaluate at the exit value. If we really did fall out of the valid
7398 // range, then we computed our trip count; otherwise, wraparound or other
7399 // things must have happened.
7400 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
7401 if (Range.contains(Val->getValue()))
7402 return SE.getCouldNotCompute(); // Something strange happened
7403
7404 // Ensure that the previous value is in the range. This is a sanity check.
7405 assert(Range.contains(
7406 EvaluateConstantChrecAtConstant(this,
7407 ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) &&
7408 "Linear scev computation is off in a bad way!");
7409 return SE.getConstant(ExitValue);
7410 } else if (isQuadratic()) {
7411 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
7412 // quadratic equation to solve it. To do this, we must frame our problem in
7413 // terms of figuring out when zero is crossed, instead of when
7414 // Range.getUpper() is crossed.
7415 SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());
7416 NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
7417 const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(),
7418 // getNoWrapFlags(FlagNW)
7419 FlagAnyWrap);
7420
7421 // Next, solve the constructed addrec.
7422 std::pair<const SCEV *,const SCEV *> Roots =
7423 SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
7424 const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
7425 const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
7426 if (R1) {
7427 // Pick the smallest positive root value.
7428 if (ConstantInt *CB =
7429 dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
7430 R1->getValue(), R2->getValue()))) {
7431 if (CB->getZExtValue() == false)
7432 std::swap(R1, R2); // R1 is the minimum root now.
7433
7434 // Make sure the root is not off by one. The returned iteration should
7435 // not be in the range, but the previous one should be. When solving
7436 // for "X*X < 5", for example, we should not return a root of 2.
        ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
                                                             R1->getValue(),
                                                             SE);
        if (Range.contains(R1Val->getValue())) {
          // The next iteration must be out of the range...
          ConstantInt *NextVal =
              ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1);

          R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
          if (!Range.contains(R1Val->getValue()))
            return SE.getConstant(NextVal);
          return SE.getCouldNotCompute();  // Something strange happened.
        }

        // If R1 was not in the range, then it is a good return value.  Make
        // sure that R1-1 WAS in the range though, just in case.
        ConstantInt *NextVal =
            ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1);
        R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
        if (Range.contains(R1Val->getValue()))
          return R1;
        return SE.getCouldNotCompute();  // Something strange happened.
      }
    }
  }

  return SE.getCouldNotCompute();
}

namespace {
struct FindUndefs {
  bool Found;
  FindUndefs() : Found(false) {}

  bool follow(const SCEV *S) {
    if (const SCEVUnknown *C = dyn_cast<SCEVUnknown>(S)) {
      if (isa<UndefValue>(C->getValue()))
        Found = true;
    } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
      if (isa<UndefValue>(C->getValue()))
        Found = true;
    }

    // Keep looking if we haven't found it yet.
    return !Found;
  }
  bool isDone() const {
    // Stop recursion if we have found an undef.
    return Found;
  }
};
}

// Return true when S contains at least one undef value.
static inline bool
containsUndefs(const SCEV *S) {
  FindUndefs F;
  SCEVTraversal<FindUndefs> ST(F);
  ST.visitAll(S);

  return F.Found;
}

namespace {
// Collect the steps of all AddRec subexpressions.
struct SCEVCollectStrides {
  ScalarEvolution &SE;
  SmallVectorImpl<const SCEV *> &Strides;

  SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &S)
      : SE(SE), Strides(S) {}

  bool follow(const SCEV *S) {
    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
      Strides.push_back(AR->getStepRecurrence(SE));
    return true;
  }
  bool isDone() const { return false; }
};

// Collect all SCEVUnknown and SCEVMulExpr expressions.
struct SCEVCollectTerms {
  SmallVectorImpl<const SCEV *> &Terms;

  SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T)
      : Terms(T) {}

  bool follow(const SCEV *S) {
    if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S)) {
      if (!containsUndefs(S))
        Terms.push_back(S);

      // Stop recursion: once we collected a term, do not walk its operands.
      return false;
    }

    // Keep looking.
    return true;
  }
  bool isDone() const { return false; }
};
}

/// Find parametric terms in this SCEVAddRecExpr.
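/// For example (an illustrative instance mirroring the delinearization
/// example documented further below), for the access function
/// {{%A,+,(8 * %m * %o)}<%L1>,+,(8 * %o)}<%L2> the collected strides are
/// (8 * %m * %o) and (8 * %o); both are SCEVMulExprs containing a parameter,
/// so both are recorded as terms, while a purely constant stride such as 8
/// contributes no term.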
void SCEVAddRecExpr::collectParametricTerms(
    ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &Terms) const {
  SmallVector<const SCEV *, 4> Strides;
  SCEVCollectStrides StrideCollector(SE, Strides);
  visitAll(this, StrideCollector);

  DEBUG({
    dbgs() << "Strides:\n";
    for (const SCEV *S : Strides)
      dbgs() << *S << "\n";
  });

  for (const SCEV *S : Strides) {
    SCEVCollectTerms TermCollector(Terms);
    visitAll(S, TermCollector);
  }

  DEBUG({
    dbgs() << "Terms:\n";
    for (const SCEV *T : Terms)
      dbgs() << *T << "\n";
  });
}

static bool findArrayDimensionsRec(ScalarEvolution &SE,
                                   SmallVectorImpl<const SCEV *> &Terms,
                                   SmallVectorImpl<const SCEV *> &Sizes) {
  int Last = Terms.size() - 1;
  const SCEV *Step = Terms[Last];

  // End of recursion.
  if (Last == 0) {
    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Step)) {
      SmallVector<const SCEV *, 2> Qs;
      for (const SCEV *Op : M->operands())
        if (!isa<SCEVConstant>(Op))
          Qs.push_back(Op);

      Step = SE.getMulExpr(Qs);
    }

    Sizes.push_back(Step);
    return true;
  }

  for (const SCEV *&Term : Terms) {
    // Normalize the terms before the next call to findArrayDimensionsRec.
    const SCEV *Q, *R;
    SCEVDivision::divide(SE, Term, Step, &Q, &R);

    // Bail out when the GCD does not evenly divide one of the terms.
    if (!R->isZero())
      return false;

    Term = Q;
  }

  // Remove all SCEVConstants.
  Terms.erase(std::remove_if(Terms.begin(), Terms.end(), [](const SCEV *E) {
                return isa<SCEVConstant>(E);
              }),
              Terms.end());

  if (!Terms.empty())
    if (!findArrayDimensionsRec(SE, Terms, Sizes))
      return false;

  Sizes.push_back(Step);
  return true;
}

namespace {
struct FindParameter {
  bool FoundParameter;
  FindParameter() : FoundParameter(false) {}

  bool follow(const SCEV *S) {
    if (isa<SCEVUnknown>(S)) {
      FoundParameter = true;
      // Stop recursion: we found a parameter.
      return false;
    }
    // Keep looking.
    return true;
  }
  bool isDone() const {
    // Stop recursion if we have found a parameter.
    return FoundParameter;
  }
};
}

// Returns true when S contains at least one SCEVUnknown parameter.
static inline bool
containsParameters(const SCEV *S) {
  FindParameter F;
  SCEVTraversal<FindParameter> ST(F);
  ST.visitAll(S);

  return F.FoundParameter;
}

// Returns true when one of the SCEVs of Terms contains a SCEVUnknown
// parameter.
static inline bool
containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
  for (const SCEV *T : Terms)
    if (containsParameters(T))
      return true;
  return false;
}

// Return the number of product terms in S.
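// For example, numberOfTerms(8 * %m * %o) == 3, while numberOfTerms(%n) == 1
// because a lone parameter is not a SCEVMulExpr.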
static inline int numberOfTerms(const SCEV *S) {
  if (const SCEVMulExpr *Expr = dyn_cast<SCEVMulExpr>(S))
    return Expr->getNumOperands();
  return 1;
}

static const SCEV *removeConstantFactors(ScalarEvolution &SE, const SCEV *T) {
  if (isa<SCEVConstant>(T))
    return nullptr;

  if (isa<SCEVUnknown>(T))
    return T;

  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(T)) {
    SmallVector<const SCEV *, 2> Factors;
    for (const SCEV *Op : M->operands())
      if (!isa<SCEVConstant>(Op))
        Factors.push_back(Op);

    return SE.getMulExpr(Factors);
  }

  return T;
}

/// Return the size of an element read or written by Inst.
const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
  Type *Ty;
  if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
    Ty = Store->getValueOperand()->getType();
  else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
    Ty = Load->getType();
  else
    return nullptr;

  Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
  return getSizeOfExpr(ETy, Ty);
}

/// Second step of delinearization: compute the array dimensions Sizes from
/// the set of Terms extracted from the memory access function of this
/// SCEVAddRec.
void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
                                          SmallVectorImpl<const SCEV *> &Sizes,
                                          const SCEV *ElementSize) const {
  if (Terms.empty() || !ElementSize)
    return;

  // Early return when Terms do not contain parameters: we do not delinearize
  // non-parametric SCEVs.
  if (!containsParameters(Terms))
    return;

  DEBUG({
    dbgs() << "Terms:\n";
    for (const SCEV *T : Terms)
      dbgs() << *T << "\n";
  });

  // Remove duplicates.
  std::sort(Terms.begin(), Terms.end());
  Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end());

  // Put larger terms first.
  std::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) {
    return numberOfTerms(LHS) > numberOfTerms(RHS);
  });

  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);

  // Divide all terms by the element size.
  for (const SCEV *&Term : Terms) {
    const SCEV *Q, *R;
    SCEVDivision::divide(SE, Term, ElementSize, &Q, &R);
    Term = Q;
  }

  SmallVector<const SCEV *, 4> NewTerms;

  // Remove constant factors.
  for (const SCEV *T : Terms)
    if (const SCEV *NewT = removeConstantFactors(SE, T))
      NewTerms.push_back(NewT);

  DEBUG({
    dbgs() << "Terms after sorting and removing constant factors:\n";
    for (const SCEV *T : NewTerms)
      dbgs() << *T << "\n";
  });

  if (NewTerms.empty() ||
      !findArrayDimensionsRec(SE, NewTerms, Sizes)) {
    Sizes.clear();
    return;
  }

  // The last element to be pushed into Sizes is the size of an element.
  Sizes.push_back(ElementSize);

  DEBUG({
    dbgs() << "Sizes:\n";
    for (const SCEV *S : Sizes)
      dbgs() << *S << "\n";
  });
}

/// Third step of delinearization: compute the access functions for the
/// Subscripts based on the dimensions in Sizes.
void SCEVAddRecExpr::computeAccessFunctions(
    ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &Subscripts,
    SmallVectorImpl<const SCEV *> &Sizes) const {
  // Early exit in case this SCEV is not an affine multivariate function.
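  // (For example, a quadratic recurrence such as {0,+,1,+,2}<%L> has three
  // operands and is rejected by the isAffine() check just below.)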
  if (Sizes.empty() || !this->isAffine())
    return;

  const SCEV *Res = this;
  int Last = Sizes.size() - 1;
  for (int i = Last; i >= 0; i--) {
    const SCEV *Q, *R;
    SCEVDivision::divide(SE, Res, Sizes[i], &Q, &R);

    DEBUG({
      dbgs() << "Res: " << *Res << "\n";
      dbgs() << "Sizes[i]: " << *Sizes[i] << "\n";
      dbgs() << "Res divided by Sizes[i]:\n";
      dbgs() << "Quotient: " << *Q << "\n";
      dbgs() << "Remainder: " << *R << "\n";
    });

    Res = Q;

    // Do not record the last subscript, which corresponds to the size of the
    // elements in the array.
    if (i == Last) {

      // Bail out if the remainder is too complex.
      if (isa<SCEVAddRecExpr>(R)) {
        Subscripts.clear();
        Sizes.clear();
        return;
      }

      continue;
    }

    // Record the access function for the current subscript.
    Subscripts.push_back(R);
  }

  // Also push in last position the quotient of the last division: after the
  // reversal below it becomes the access function of the outermost dimension.
  Subscripts.push_back(Res);

  std::reverse(Subscripts.begin(), Subscripts.end());

  DEBUG({
    dbgs() << "Subscripts:\n";
    for (const SCEV *S : Subscripts)
      dbgs() << *S << "\n";
  });
}

/// Splits the SCEV into two vectors of SCEVs representing the subscripts and
/// sizes of an array access. Returns the remainder of the delinearization,
/// which is the offset of the start of the array. The SCEV->delinearize
/// algorithm computes the multiples of SCEV coefficients: that is, a pattern
/// matching of subexpressions in the stride and base of a SCEV corresponding
/// to the computation of a GCD (greatest common divisor) of base and stride.
/// When SCEV->delinearize fails, it returns the SCEV unchanged.
///
/// For example, when analyzing the memory access A[i][j][k] in this loop nest
///
///  void foo(long n, long m, long o, double A[n][m][o]) {
///
///    for (long i = 0; i < n; i++)
///      for (long j = 0; j < m; j++)
///        for (long k = 0; k < o; k++)
///          A[i][j][k] = 1.0;
///  }
///
/// the delinearization input is the following AddRec SCEV:
///
///  AddRec: {{{%A,+,(8 * %m * %o)}<%for.i>,+,(8 * %o)}<%for.j>,+,8}<%for.k>
///
/// From this SCEV, we are able to say that the base offset of the access is
/// %A because it appears as an offset that does not divide any of the
/// strides in the loops:
///
///  CHECK: Base offset: %A
///
/// and then SCEV->delinearize determines the size of some of the dimensions
/// of the array as these are the multiples by which the strides are
/// happening:
///
///  CHECK: ArrayDecl[UnknownSize][%m][%o] with elements of sizeof(double)
///  bytes.
///
/// Note that the outermost dimension remains of UnknownSize because there
/// are no strides that would help identify the size of the last dimension:
/// when the array has been statically allocated, one could compute the size
/// of that dimension by dividing the overall size of the array by the size
/// of the known dimensions: %m * %o * 8.
///
/// Finally delinearize provides the access functions for the array reference
/// that corresponds to A[i][j][k] of the above C testcase:
///
///  CHECK: ArrayRef[{0,+,1}<%for.i>][{0,+,1}<%for.j>][{0,+,1}<%for.k>]
///
/// The testcases check the output of a function pass, DelinearizationPass,
/// that walks through all loads and stores of a function asking for the SCEV
/// of the memory access with respect to all enclosing loops, calling
/// SCEV->delinearize on that and printing the results.
void SCEVAddRecExpr::delinearize(ScalarEvolution &SE,
                                 SmallVectorImpl<const SCEV *> &Subscripts,
                                 SmallVectorImpl<const SCEV *> &Sizes,
                                 const SCEV *ElementSize) const {
  // First step: collect parametric terms.
  SmallVector<const SCEV *, 4> Terms;
  collectParametricTerms(SE, Terms);

  if (Terms.empty())
    return;

  // Second step: find subscript sizes.
  SE.findArrayDimensions(Terms, Sizes, ElementSize);

  if (Sizes.empty())
    return;

  // Third step: compute the access functions for each subscript.
  computeAccessFunctions(SE, Subscripts, Sizes);

  if (Subscripts.empty())
    return;

  DEBUG({
    dbgs() << "successfully delinearized " << *this << "\n";
    dbgs() << "ArrayDecl[UnknownSize]";
    for (const SCEV *S : Sizes)
      dbgs() << "[" << *S << "]";

    dbgs() << "\nArrayRef";
    for (const SCEV *S : Subscripts)
      dbgs() << "[" << *S << "]";
    dbgs() << "\n";
  });
}

//===----------------------------------------------------------------------===//
//                   SCEVCallbackVH Class Implementation
//===----------------------------------------------------------------------===//

void ScalarEvolution::SCEVCallbackVH::deleted() {
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
    SE->ConstantEvolutionLoopExitValue.erase(PN);
  SE->ValueExprMap.erase(getValPtr());
  // this now dangles!
}

void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");

  // Forget all the expressions associated with users of the old value,
  // so that future queries will recompute the expressions using the new
  // value.
  Value *Old = getValPtr();
  SmallVector<User *, 16> Worklist(Old->user_begin(), Old->user_end());
  SmallPtrSet<User *, 8> Visited;
  while (!Worklist.empty()) {
    User *U = Worklist.pop_back_val();
    // Deleting the Old value will cause this to dangle. Postpone
    // that until everything else is done.
    if (U == Old)
      continue;
    if (!Visited.insert(U).second)
      continue;
    if (PHINode *PN = dyn_cast<PHINode>(U))
      SE->ConstantEvolutionLoopExitValue.erase(PN);
    SE->ValueExprMap.erase(U);
    Worklist.insert(Worklist.end(), U->user_begin(), U->user_end());
  }
  // Delete the Old value.
  if (PHINode *PN = dyn_cast<PHINode>(Old))
    SE->ConstantEvolutionLoopExitValue.erase(PN);
  SE->ValueExprMap.erase(Old);
  // this now dangles!
}

ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
  : CallbackVH(V), SE(se) {}

//===----------------------------------------------------------------------===//
//                   ScalarEvolution Class Implementation
//===----------------------------------------------------------------------===//

ScalarEvolution::ScalarEvolution()
  : FunctionPass(ID), ValuesAtScopes(64), LoopDispositions(64),
    BlockDispositions(64), FirstUnknown(nullptr) {
  initializeScalarEvolutionPass(*PassRegistry::getPassRegistry());
}

bool ScalarEvolution::runOnFunction(Function &F) {
  this->F = &F;
  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
  DL = DLP ? &DLP->getDataLayout() : nullptr;
  TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  return false;
}

void ScalarEvolution::releaseMemory() {
  // Iterate through all the SCEVUnknown instances and call their
  // destructors, so that they release their references to their values.
  for (SCEVUnknown *U = FirstUnknown; U; U = U->Next)
    U->~SCEVUnknown();
  FirstUnknown = nullptr;

  ValueExprMap.clear();

  // Free any extra memory created for ExitNotTakenInfo in the unlikely event
  // that a loop had multiple computable exits.
  for (DenseMap<const Loop*, BackedgeTakenInfo>::iterator I =
         BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end();
       I != E; ++I) {
    I->second.clear();
  }

  assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");

  BackedgeTakenCounts.clear();
  ConstantEvolutionLoopExitValue.clear();
  ValuesAtScopes.clear();
  LoopDispositions.clear();
  BlockDispositions.clear();
  UnsignedRanges.clear();
  SignedRanges.clear();
  UniqueSCEVs.clear();
  SCEVAllocator.Reset();
}

void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<AssumptionCacheTracker>();
  AU.addRequiredTransitive<LoopInfoWrapperPass>();
  AU.addRequiredTransitive<DominatorTreeWrapperPass>();
  AU.addRequired<TargetLibraryInfoWrapperPass>();
}

bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
}

static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
                          const Loop *L) {
  // Print all inner loops first.
  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
    PrintLoopInfo(OS, SE, *I);

  OS << "Loop ";
  L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
  OS << ": ";

  SmallVector<BasicBlock *, 8> ExitBlocks;
  L->getExitBlocks(ExitBlocks);
  if (ExitBlocks.size() != 1)
    OS << "<multiple exits> ";

  if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
    OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
  } else {
    OS << "Unpredictable backedge-taken count. ";
  }

  OS << "\n"
        "Loop ";
  L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
  OS << ": ";

  if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
    OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
  } else {
    OS << "Unpredictable max backedge-taken count. ";
  }

  OS << "\n";
}

void ScalarEvolution::print(raw_ostream &OS, const Module *) const {
  // ScalarEvolution's implementation of the print method is to print
  // out SCEV values of all instructions that are interesting. Doing
  // this potentially causes it to create new SCEV objects though,
  // which technically conflicts with the const qualifier. This isn't
  // observable from outside the class though, so casting away the
  // const isn't dangerous.
  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);

  OS << "Classifying expressions for: ";
  F->printAsOperand(OS, /*PrintType=*/false);
  OS << "\n";
  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
    if (isSCEVable(I->getType()) && !isa<CmpInst>(*I)) {
      OS << *I << '\n';
      OS << " --> ";
      const SCEV *SV = SE.getSCEV(&*I);
      SV->print(OS);

      const Loop *L = LI->getLoopFor((*I).getParent());

      const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
      if (AtUse != SV) {
        OS << " --> ";
        AtUse->print(OS);
      }

      if (L) {
        OS << "\t\t" "Exits: ";
        const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
        if (!SE.isLoopInvariant(ExitValue, L)) {
          OS << "<<Unknown>>";
        } else {
          OS << *ExitValue;
        }
      }

      OS << "\n";
    }

  OS << "Determining loop execution counts for: ";
  F->printAsOperand(OS, /*PrintType=*/false);
  OS << "\n";
  for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
    PrintLoopInfo(OS, &SE, *I);
}

ScalarEvolution::LoopDisposition
ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
  auto &Values = LoopDispositions[S];
  for (auto &V : Values) {
    if (V.getPointer() == L)
      return V.getInt();
  }
  Values.emplace_back(L, LoopVariant);
  LoopDisposition D = computeLoopDisposition(S, L);
  auto &Values2 = LoopDispositions[S];
  for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
    if (V.getPointer() == L) {
      V.setInt(D);
      break;
    }
  }
  return D;
}

ScalarEvolution::LoopDisposition
ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
  switch (static_cast<SCEVTypes>(S->getSCEVType())) {
  case scConstant:
    return LoopInvariant;
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
    return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
  case scAddRecExpr: {
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);

    // If L is the addrec's loop, it's computable.
    if (AR->getLoop() == L)
      return LoopComputable;

    // Add recurrences are never invariant in the function-body (null loop).
    if (!L)
      return LoopVariant;

    // This recurrence is variant w.r.t. L if L contains AR's loop.
    if (L->contains(AR->getLoop()))
      return LoopVariant;

    // This recurrence is invariant w.r.t. L if AR's loop contains L.
    if (AR->getLoop()->contains(L))
      return LoopInvariant;

    // This recurrence is variant w.r.t. L if any of its operands
    // are variant.
    for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
         I != E; ++I)
      if (!isLoopInvariant(*I, L))
        return LoopVariant;

    // Otherwise it's loop-invariant.
    return LoopInvariant;
  }
  case scAddExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr: {
    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
    bool HasVarying = false;
    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
         I != E; ++I) {
      LoopDisposition D = getLoopDisposition(*I, L);
      if (D == LoopVariant)
        return LoopVariant;
      if (D == LoopComputable)
        HasVarying = true;
    }
    return HasVarying ? LoopComputable : LoopInvariant;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
    LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
    if (LD == LoopVariant)
      return LoopVariant;
    LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
    if (RD == LoopVariant)
      return LoopVariant;
    return (LD == LoopInvariant && RD == LoopInvariant) ?
           LoopInvariant : LoopComputable;
  }
  case scUnknown:
    // All non-instruction values are loop invariant.  All instructions are
    // loop invariant if they are not contained in the specified loop.
    // Instructions are never considered invariant in the function body
    // (null loop) because they are defined within the "loop".
    if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
      return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
    return LoopInvariant;
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
  return getLoopDisposition(S, L) == LoopInvariant;
}

bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
  return getLoopDisposition(S, L) == LoopComputable;
}

ScalarEvolution::BlockDisposition
ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
  auto &Values = BlockDispositions[S];
  for (auto &V : Values) {
    if (V.getPointer() == BB)
      return V.getInt();
  }
  Values.emplace_back(BB, DoesNotDominateBlock);
  BlockDisposition D = computeBlockDisposition(S, BB);
  auto &Values2 = BlockDispositions[S];
  for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
    if (V.getPointer() == BB) {
      V.setInt(D);
      break;
    }
  }
  return D;
}

ScalarEvolution::BlockDisposition
ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
  switch (static_cast<SCEVTypes>(S->getSCEVType())) {
  case scConstant:
    return ProperlyDominatesBlock;
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
    return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
  case scAddRecExpr: {
    // This uses a "dominates" query instead of "properly dominates" query
    // to test for proper dominance too, because the instruction which
    // produces the addrec's value is a PHI, and a PHI effectively properly
    // dominates its entire containing block.
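    // For example, {0,+,1}<%loop> is produced by a PHI in %loop's header,
    // so (provided its operands also dominate) it dominates every block the
    // header dominates, including the header itself.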
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
    if (!DT->dominates(AR->getLoop()->getHeader(), BB))
      return DoesNotDominateBlock;
  }
  // FALL THROUGH into SCEVNAryExpr handling.
  case scAddExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr: {
    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
    bool Proper = true;
    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
         I != E; ++I) {
      BlockDisposition D = getBlockDisposition(*I, BB);
      if (D == DoesNotDominateBlock)
        return DoesNotDominateBlock;
      if (D == DominatesBlock)
        Proper = false;
    }
    return Proper ? ProperlyDominatesBlock : DominatesBlock;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
    const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
    BlockDisposition LD = getBlockDisposition(LHS, BB);
    if (LD == DoesNotDominateBlock)
      return DoesNotDominateBlock;
    BlockDisposition RD = getBlockDisposition(RHS, BB);
    if (RD == DoesNotDominateBlock)
      return DoesNotDominateBlock;
    return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
           ProperlyDominatesBlock : DominatesBlock;
  }
  case scUnknown:
    if (Instruction *I =
          dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
      if (I->getParent() == BB)
        return DominatesBlock;
      if (DT->properlyDominates(I->getParent(), BB))
        return ProperlyDominatesBlock;
      return DoesNotDominateBlock;
    }
    return ProperlyDominatesBlock;
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
  return getBlockDisposition(S, BB) >= DominatesBlock;
}

bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
  return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
}

namespace {
// Search for a SCEV expression node within an expression tree.
// Implements SCEVTraversal::Visitor.
struct SCEVSearch {
  const SCEV *Node;
  bool IsFound;

  SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}

  bool follow(const SCEV *S) {
    IsFound |= (S == Node);
    return !IsFound;
  }
  bool isDone() const { return IsFound; }
};
}

bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
  SCEVSearch Search(Op);
  visitAll(S, Search);
  return Search.IsFound;
}

void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
  ValuesAtScopes.erase(S);
  LoopDispositions.erase(S);
  BlockDispositions.erase(S);
  UnsignedRanges.erase(S);
  SignedRanges.erase(S);

  for (DenseMap<const Loop*, BackedgeTakenInfo>::iterator I =
         BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end(); I != E; ) {
    BackedgeTakenInfo &BEInfo = I->second;
    if (BEInfo.hasOperand(S, this)) {
      BEInfo.clear();
      BackedgeTakenCounts.erase(I++);
    } else
      ++I;
  }
}

typedef DenseMap<const Loop *, std::string> VerifyMap;

/// replaceSubString - Replaces all occurrences of From in Str with To.
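/// For example, replaceSubString(S, "<nsw>", "") strips all nsw annotations
/// from a stringified SCEV, which is exactly how it is used below.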
static void replaceSubString(std::string &Str, StringRef From, StringRef To) {
  size_t Pos = 0;
  while ((Pos = Str.find(From, Pos)) != std::string::npos) {
    Str.replace(Pos, From.size(), To.data(), To.size());
    Pos += To.size();
  }
}

/// getLoopBackedgeTakenCounts - Helper method for verifyAnalysis.
static void
getLoopBackedgeTakenCounts(Loop *L, VerifyMap &Map, ScalarEvolution &SE) {
  for (Loop::reverse_iterator I = L->rbegin(), E = L->rend(); I != E; ++I) {
    getLoopBackedgeTakenCounts(*I, Map, SE); // recurse.

    std::string &S = Map[L];
    if (S.empty()) {
      raw_string_ostream OS(S);
      SE.getBackedgeTakenCount(L)->print(OS);

      // false and 0 are semantically equivalent. This can happen in dead
      // loops.
      replaceSubString(OS.str(), "false", "0");
      // Remove wrap flags, their use in SCEV is highly fragile.
      // FIXME: Remove this when SCEV gets smarter about them.
      replaceSubString(OS.str(), "<nw>", "");
      replaceSubString(OS.str(), "<nsw>", "");
      replaceSubString(OS.str(), "<nuw>", "");
    }
  }
}

void ScalarEvolution::verifyAnalysis() const {
  if (!VerifySCEV)
    return;

  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);

  // Gather stringified backedge taken counts for all loops using SCEV's
  // caches.
  // FIXME: It would be much better to store actual values instead of
  // strings, but SCEV pointers will change if we drop the caches.
  VerifyMap BackedgeDumpsOld, BackedgeDumpsNew;
  for (LoopInfo::reverse_iterator I = LI->rbegin(), E = LI->rend(); I != E; ++I)
    getLoopBackedgeTakenCounts(*I, BackedgeDumpsOld, SE);

  // Gather stringified backedge taken counts for all loops without using
  // SCEV's caches.
  SE.releaseMemory();
  for (LoopInfo::reverse_iterator I = LI->rbegin(), E = LI->rend(); I != E; ++I)
    getLoopBackedgeTakenCounts(*I, BackedgeDumpsNew, SE);

  // Now compare whether they're the same with and without caches. This
  // allows verifying that no pass changed the cache.
  assert(BackedgeDumpsOld.size() == BackedgeDumpsNew.size() &&
         "New loops suddenly appeared!");

  for (VerifyMap::iterator OldI = BackedgeDumpsOld.begin(),
                           OldE = BackedgeDumpsOld.end(),
                           NewI = BackedgeDumpsNew.begin();
       OldI != OldE; ++OldI, ++NewI) {
    assert(OldI->first == NewI->first && "Loop order changed!");

    // Compare the stringified SCEVs. We don't care if the backedge-taken
    // count changes from or to undef.
    // FIXME: We currently ignore SCEV changes from/to CouldNotCompute. This
    // means that a pass is buggy or SCEV has to learn a new pattern but is
    // usually not harmful.
    if (OldI->second != NewI->second &&
        OldI->second.find("undef") == std::string::npos &&
        NewI->second.find("undef") == std::string::npos &&
        OldI->second != "***COULDNOTCOMPUTE***" &&
        NewI->second != "***COULDNOTCOMPUTE***") {
      dbgs() << "SCEVValidator: SCEV for loop '"
             << OldI->first->getHeader()->getName()
             << "' changed from '" << OldI->second
             << "' to '" << NewI->second << "'!\n";
      std::abort();
    }
  }

  // TODO: Verify more things.
}
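// Note: the cache consistency check above only runs when the -verify-scev
// flag (declared near the top of this file) is passed, and it is invoked by
// the pass manager's analysis verification hooks. An illustrative way to
// exercise it (flag placement is an assumption, not a prescribed workflow)
// is:
//
//   opt -verify-scev -O2 input.ll -o out.ll
//
// which recomputes every backedge-taken count with cold caches and aborts if
// any stringified count differs from its previously cached form.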