//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library.  First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle. We only create one SCEV of a particular shape, so
// pointer-comparisons for equality are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node).  If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
//
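// For example (illustrative), the canonical induction variable of
//   for (i = 0; i != n; ++i) ...
// becomes the affine recurrence {0,+,1} for that loop, while an opaque value
// such as a load inside the loop is represented as a SCEVUnknown.
//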
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression.  These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
using namespace llvm;

#define DEBUG_TYPE "scalar-evolution"

STATISTIC(NumArrayLenItCounts,
          "Number of trip counts computed with array length");
STATISTIC(NumTripCountsComputed,
          "Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
          "Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");

static cl::opt<unsigned>
MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                        cl::desc("Maximum number of iterations SCEV will "
                                 "symbolically execute a constant "
                                 "derived loop"),
                        cl::init(100));

// FIXME: Enable this with XDEBUG when the test suite is clean.
static cl::opt<bool>
VerifySCEV("verify-scev",
           cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));

INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution",
                "Scalar Evolution Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution",
                "Scalar Evolution Analysis", false, true)
char ScalarEvolution::ID = 0;

//===----------------------------------------------------------------------===//
//                           SCEV class definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void SCEV::dump() const {
  print(dbgs());
  dbgs() << '\n';
}
#endif

void SCEV::print(raw_ostream &OS) const {
  switch (static_cast<SCEVTypes>(getSCEVType())) {
  case scConstant:
    cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
    return;
  case scTruncate: {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
    const SCEV *Op = Trunc->getOperand();
    OS << "(trunc " << *Op->getType() << " " << *Op << " to "
       << *Trunc->getType() << ")";
    return;
  }
  case scZeroExtend: {
    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
    const SCEV *Op = ZExt->getOperand();
    OS << "(zext " << *Op->getType() << " " << *Op << " to "
       << *ZExt->getType() << ")";
    return;
  }
  case scSignExtend: {
    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
    const SCEV *Op = SExt->getOperand();
    OS << "(sext " << *Op->getType() << " " << *Op << " to "
       << *SExt->getType() << ")";
    return;
  }
  case scAddRecExpr: {
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
    OS << "{" << *AR->getOperand(0);
    for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
      OS << ",+," << *AR->getOperand(i);
    OS << "}<";
    if (AR->getNoWrapFlags(FlagNUW))
      OS << "nuw><";
    if (AR->getNoWrapFlags(FlagNSW))
      OS << "nsw><";
    if (AR->getNoWrapFlags(FlagNW) &&
        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
      OS << "nw><";
    AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ">";
    return;
  }
  case scAddExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr: {
    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
    const char *OpStr = nullptr;
    switch (NAry->getSCEVType()) {
    case scAddExpr: OpStr = " + "; break;
    case scMulExpr: OpStr = " * "; break;
    case scUMaxExpr: OpStr = " umax "; break;
    case scSMaxExpr: OpStr = " smax "; break;
    }
    OS << "(";
    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
         I != E; ++I) {
      OS << **I;
      if (std::next(I) != E)
        OS << OpStr;
    }
    OS << ")";
    switch (NAry->getSCEVType()) {
    case scAddExpr:
    case scMulExpr:
      if (NAry->getNoWrapFlags(FlagNUW))
        OS << "<nuw>";
      if (NAry->getNoWrapFlags(FlagNSW))
        OS << "<nsw>";
    }
    return;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
    OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
    return;
  }
  case scUnknown: {
    const SCEVUnknown *U = cast<SCEVUnknown>(this);
    Type *AllocTy;
    if (U->isSizeOf(AllocTy)) {
      OS << "sizeof(" << *AllocTy << ")";
      return;
    }
    if (U->isAlignOf(AllocTy)) {
      OS << "alignof(" << *AllocTy << ")";
      return;
    }

    Type *CTy;
    Constant *FieldNo;
    if (U->isOffsetOf(CTy, FieldNo)) {
      OS << "offsetof(" << *CTy << ", ";
      FieldNo->printAsOperand(OS, false);
      OS << ")";
      return;
    }

    // Otherwise just print it normally.
    U->getValue()->printAsOperand(OS, false);
    return;
  }
  case scCouldNotCompute:
    OS << "***COULDNOTCOMPUTE***";
    return;
  }
  llvm_unreachable("Unknown SCEV kind!");
}

Type *SCEV::getType() const {
  switch (static_cast<SCEVTypes>(getSCEVType())) {
  case scConstant:
    return cast<SCEVConstant>(this)->getType();
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
    return cast<SCEVCastExpr>(this)->getType();
  case scAddRecExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr:
    return cast<SCEVNAryExpr>(this)->getType();
  case scAddExpr:
    return cast<SCEVAddExpr>(this)->getType();
  case scUDivExpr:
    return cast<SCEVUDivExpr>(this)->getType();
  case scUnknown:
    return cast<SCEVUnknown>(this)->getType();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

bool SCEV::isZero() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isZero();
  return false;
}

bool SCEV::isOne() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isOne();
  return false;
}

bool SCEV::isAllOnesValue() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isAllOnesValue();
  return false;
}

/// isNonConstantNegative - Return true if the specified scev is known to be
/// negative but is not itself a constant, e.g. (-42 * %V).
bool SCEV::isNonConstantNegative() const {
  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
  if (!Mul) return false;

  // If there is a constant factor, it will be first.
  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
  if (!SC) return false;

  // Return true if the value is negative; this matches things like (-42 * V).
  return SC->getValue()->getValue().isNegative();
}

SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}

const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
  return getConstant(ConstantInt::get(getContext(), Val));
}

const SCEV *
ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
  return getConstant(ConstantInt::get(ITy, V, isSigned));
}

SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
                           unsigned SCEVTy, const SCEV *op, Type *ty)
  : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}

SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
                                   const SCEV *op, Type *ty)
  : SCEVCastExpr(ID, scTruncate, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot truncate non-integer value!");
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
  : SCEVCastExpr(ID, scZeroExtend, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot zero extend non-integer value!");
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
  : SCEVCastExpr(ID, scSignExtend, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot sign extend non-integer value!");
}

void SCEVUnknown::deleted() {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Release the value.
  setValPtr(nullptr);
}

void SCEVUnknown::allUsesReplacedWith(Value *New) {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Update this SCEVUnknown to point to the new value. This is needed
  // because there may still be outstanding SCEVs which still point to
  // this SCEVUnknown.
  setValPtr(New);
}

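// sizeof, alignof, and offsetof have no direct IR operation; they are
// canonically expressed as a ptrtoint of a getelementptr from a null base
// pointer (e.g. sizeof(T) is ptrtoint(gep T* null, 1)).  The three
// recognizers below match those shapes.
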
bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue() &&
            CE->getNumOperands() == 2)
          if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
            if (CI->isOne()) {
              AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
                                 ->getElementType();
              return true;
            }

  return false;
}

bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue()) {
          Type *Ty =
            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          if (StructType *STy = dyn_cast<StructType>(Ty))
            if (!STy->isPacked() &&
                CE->getNumOperands() == 3 &&
                CE->getOperand(1)->isNullValue()) {
              if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
                if (CI->isOne() &&
                    STy->getNumElements() == 2 &&
                    STy->getElementType(0)->isIntegerTy(1)) {
                  AllocTy = STy->getElementType(1);
                  return true;
                }
            }
        }

  return false;
}

bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getNumOperands() == 3 &&
            CE->getOperand(0)->isNullValue() &&
            CE->getOperand(1)->isNullValue()) {
          Type *Ty =
            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          // Ignore vector types here so that ScalarEvolutionExpander doesn't
          // emit getelementptrs that index into vectors.
          if (Ty->isStructTy() || Ty->isArrayTy()) {
            CTy = Ty;
            FieldNo = CE->getOperand(2);
            return true;
          }
        }

  return false;
}

//===----------------------------------------------------------------------===//
//                               SCEV Utilities
//===----------------------------------------------------------------------===//

namespace {
  /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
  /// than the complexity of the RHS.  This comparator is used to canonicalize
  /// expressions.
  class SCEVComplexityCompare {
    const LoopInfo *const LI;
  public:
    explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}

    // Return true if LHS is less than RHS; false if LHS is at least RHS.
    bool operator()(const SCEV *LHS, const SCEV *RHS) const {
      return compare(LHS, RHS) < 0;
    }

    // Return negative, zero, or positive, if LHS is less than, equal to, or
    // greater than RHS, respectively. A three-way result allows recursive
    // comparisons to be more efficient.
    int compare(const SCEV *LHS, const SCEV *RHS) const {
      // Fast-path: SCEVs are uniqued so we can do a quick equality check.
      if (LHS == RHS)
        return 0;

      // Primarily, sort the SCEVs by their getSCEVType().
      unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
      if (LType != RType)
        return (int)LType - (int)RType;

      // Aside from the getSCEVType() ordering, the particular ordering
      // isn't very important except that it's beneficial to be consistent,
      // so that (a + b) and (b + a) don't end up as different expressions.
      switch (static_cast<SCEVTypes>(LType)) {
      case scUnknown: {
        const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
        const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

        // Sort SCEVUnknown values with some loose heuristics. TODO: This is
        // not as complete as it could be.
        const Value *LV = LU->getValue(), *RV = RU->getValue();

        // Order pointer values after integer values. This helps SCEVExpander
        // form GEPs.
        bool LIsPointer = LV->getType()->isPointerTy(),
             RIsPointer = RV->getType()->isPointerTy();
        if (LIsPointer != RIsPointer)
          return (int)LIsPointer - (int)RIsPointer;

        // Compare getValueID values.
        unsigned LID = LV->getValueID(),
                 RID = RV->getValueID();
        if (LID != RID)
          return (int)LID - (int)RID;

        // Sort arguments by their position.
        if (const Argument *LA = dyn_cast<Argument>(LV)) {
          const Argument *RA = cast<Argument>(RV);
          unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
          return (int)LArgNo - (int)RArgNo;
        }

        // For instructions, compare their loop depth, and their operand
        // count.  This is pretty loose.
        if (const Instruction *LInst = dyn_cast<Instruction>(LV)) {
          const Instruction *RInst = cast<Instruction>(RV);

          // Compare loop depths.
          const BasicBlock *LParent = LInst->getParent(),
                           *RParent = RInst->getParent();
          if (LParent != RParent) {
            unsigned LDepth = LI->getLoopDepth(LParent),
                     RDepth = LI->getLoopDepth(RParent);
            if (LDepth != RDepth)
              return (int)LDepth - (int)RDepth;
          }

          // Compare the number of operands.
          unsigned LNumOps = LInst->getNumOperands(),
                   RNumOps = RInst->getNumOperands();
          return (int)LNumOps - (int)RNumOps;
        }

        return 0;
      }

      case scConstant: {
        const SCEVConstant *LC = cast<SCEVConstant>(LHS);
        const SCEVConstant *RC = cast<SCEVConstant>(RHS);

        // Compare constant values.
        const APInt &LA = LC->getValue()->getValue();
        const APInt &RA = RC->getValue()->getValue();
        unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
        if (LBitWidth != RBitWidth)
          return (int)LBitWidth - (int)RBitWidth;
        return LA.ult(RA) ? -1 : 1;
      }

      case scAddRecExpr: {
        const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
        const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);

        // Compare addrec loop depths.
        const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
        if (LLoop != RLoop) {
          unsigned LDepth = LLoop->getLoopDepth(),
                   RDepth = RLoop->getLoopDepth();
          if (LDepth != RDepth)
            return (int)LDepth - (int)RDepth;
        }

        // Addrec complexity grows with operand count.
        unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
        if (LNumOps != RNumOps)
          return (int)LNumOps - (int)RNumOps;

        // Lexicographically compare.
        for (unsigned i = 0; i != LNumOps; ++i) {
          long X = compare(LA->getOperand(i), RA->getOperand(i));
          if (X != 0)
            return X;
        }

        return 0;
      }

      case scAddExpr:
      case scMulExpr:
      case scSMaxExpr:
      case scUMaxExpr: {
        const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
        const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);

        // Lexicographically compare n-ary expressions.
        unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
        if (LNumOps != RNumOps)
          return (int)LNumOps - (int)RNumOps;

        // Operand counts are equal at this point, so the loop bounds match.
        for (unsigned i = 0; i != LNumOps; ++i) {
          long X = compare(LC->getOperand(i), RC->getOperand(i));
          if (X != 0)
            return X;
        }
        return 0;
      }

      case scUDivExpr: {
        const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
        const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);

        // Lexicographically compare udiv expressions.
        long X = compare(LC->getLHS(), RC->getLHS());
        if (X != 0)
          return X;
        return compare(LC->getRHS(), RC->getRHS());
      }

      case scTruncate:
      case scZeroExtend:
      case scSignExtend: {
        const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
        const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);

        // Compare cast expressions by operand.
        return compare(LC->getOperand(), RC->getOperand());
      }

      case scCouldNotCompute:
        llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
      }
      llvm_unreachable("Unknown SCEV kind!");
    }
  };
}

/// GroupByComplexity - Given a list of SCEV objects, order them by their
/// complexity, and group objects of the same complexity together by value.
/// When this routine is finished, we know that any duplicates in the vector are
/// consecutive and that complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine.  In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.
///
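/// For instance, given the operand list (%b, %a, 2, %a), the constant 2 sorts
/// first and the two occurrences of %a become adjacent, e.g. (2, %a, %a, %b),
/// so that later folding can combine the duplicates (illustrative ordering).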
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI) {
  if (Ops.size() < 2) return;  // Noop
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
    if (SCEVComplexityCompare(LI)(RHS, LHS))
      std::swap(LHS, RHS);
    return;
  }

  // Do the rough sort by complexity.
  std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));

  // Now that we are sorted by complexity, group elements of the same
  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice.  Note that we take this approach because we
  // do not want to depend on the addresses of the objects we are grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}

namespace {
struct FindSCEVSize {
  int Size;
  FindSCEVSize() : Size(0) {}

  bool follow(const SCEV *S) {
    ++Size;
    // Keep looking at all operands of S.
    return true;
  }
  bool isDone() const {
    return false;
  }
};
}

// Returns the size of the SCEV S.
static inline int sizeOfSCEV(const SCEV *S) {
  FindSCEVSize F;
  SCEVTraversal<FindSCEVSize> ST(F);
  ST.visitAll(S);
  return F.Size;
}

namespace {

struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
public:
  // Computes the Quotient and Remainder of the division of Numerator by
  // Denominator.
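  //
  // Illustrative example: dividing the SCEV (4 * %n + 8) by the constant 4
  // yields Quotient == (%n + 2) and Remainder == 0, whereas dividing %n by 4
  // yields Quotient == 0 and Remainder == %n, the default "don't know" answer
  // set up in the constructor below.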
  static void divide(ScalarEvolution &SE, const SCEV *Numerator,
                     const SCEV *Denominator, const SCEV **Quotient,
                     const SCEV **Remainder) {
    assert(Numerator && Denominator && "Uninitialized SCEV");

    SCEVDivision D(SE, Numerator, Denominator);

    // Check for the trivial case here to avoid having to check for it in the
    // rest of the code.
    if (Numerator == Denominator) {
      *Quotient = D.One;
      *Remainder = D.Zero;
      return;
    }

    if (Numerator->isZero()) {
      *Quotient = D.Zero;
      *Remainder = D.Zero;
      return;
    }

    // The simple case of N/1: the quotient is N.
    if (Denominator->isOne()) {
      *Quotient = Numerator;
      *Remainder = D.Zero;
      return;
    }

    // Split the Denominator when it is a product.
    if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) {
      const SCEV *Q, *R;
      *Quotient = Numerator;
      for (const SCEV *Op : T->operands()) {
        divide(SE, *Quotient, Op, &Q, &R);
        *Quotient = Q;

        // Bail out when the Numerator is not divisible by one of the terms of
        // the Denominator.
        if (!R->isZero()) {
          *Quotient = D.Zero;
          *Remainder = Numerator;
          return;
        }
      }
      *Remainder = D.Zero;
      return;
    }

    D.visit(Numerator);
    *Quotient = D.Quotient;
    *Remainder = D.Remainder;
  }

  // Except in the trivial case described above, we do not know how to divide
  // Expr by Denominator for the following functions with empty implementation.
  void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {}
  void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {}
  void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {}
  void visitUDivExpr(const SCEVUDivExpr *Numerator) {}
  void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {}
  void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {}
  void visitUnknown(const SCEVUnknown *Numerator) {}
  void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {}

  void visitConstant(const SCEVConstant *Numerator) {
    if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
      APInt NumeratorVal = Numerator->getValue()->getValue();
      APInt DenominatorVal = D->getValue()->getValue();
      uint32_t NumeratorBW = NumeratorVal.getBitWidth();
      uint32_t DenominatorBW = DenominatorVal.getBitWidth();

      if (NumeratorBW > DenominatorBW)
        DenominatorVal = DenominatorVal.sext(NumeratorBW);
      else if (NumeratorBW < DenominatorBW)
        NumeratorVal = NumeratorVal.sext(DenominatorBW);

      APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
      APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
      APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
      Quotient = SE.getConstant(QuotientVal);
      Remainder = SE.getConstant(RemainderVal);
      return;
    }
  }

  void visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
    const SCEV *StartQ, *StartR, *StepQ, *StepR;
    assert(Numerator->isAffine() && "Numerator should be affine");
    divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
    divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
    Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
                                Numerator->getNoWrapFlags());
    Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
                                 Numerator->getNoWrapFlags());
  }

  void visitAddExpr(const SCEVAddExpr *Numerator) {
    SmallVector<const SCEV *, 2> Qs, Rs;
    Type *Ty = Denominator->getType();

    for (const SCEV *Op : Numerator->operands()) {
      const SCEV *Q, *R;
      divide(SE, Op, Denominator, &Q, &R);

      // Bail out if types do not match.
      if (Ty != Q->getType() || Ty != R->getType()) {
        Quotient = Zero;
        Remainder = Numerator;
        return;
      }

      Qs.push_back(Q);
      Rs.push_back(R);
    }

    if (Qs.size() == 1) {
      Quotient = Qs[0];
      Remainder = Rs[0];
      return;
    }

    Quotient = SE.getAddExpr(Qs);
    Remainder = SE.getAddExpr(Rs);
  }

  void visitMulExpr(const SCEVMulExpr *Numerator) {
    SmallVector<const SCEV *, 2> Qs;
    Type *Ty = Denominator->getType();

    bool FoundDenominatorTerm = false;
    for (const SCEV *Op : Numerator->operands()) {
      // Bail out if types do not match.
      if (Ty != Op->getType()) {
        Quotient = Zero;
        Remainder = Numerator;
        return;
      }

      if (FoundDenominatorTerm) {
        Qs.push_back(Op);
        continue;
      }

      // Check whether Denominator divides one of the product operands.
      const SCEV *Q, *R;
      divide(SE, Op, Denominator, &Q, &R);
      if (!R->isZero()) {
        Qs.push_back(Op);
        continue;
      }

      // Bail out if types do not match.
      if (Ty != Q->getType()) {
        Quotient = Zero;
        Remainder = Numerator;
        return;
      }

      FoundDenominatorTerm = true;
      Qs.push_back(Q);
    }

    if (FoundDenominatorTerm) {
      Remainder = Zero;
      if (Qs.size() == 1)
        Quotient = Qs[0];
      else
        Quotient = SE.getMulExpr(Qs);
      return;
    }

    if (!isa<SCEVUnknown>(Denominator)) {
      Quotient = Zero;
      Remainder = Numerator;
      return;
    }

    // The Remainder is obtained by replacing Denominator by 0 in Numerator.
    ValueToValueMap RewriteMap;
    RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
        cast<SCEVConstant>(Zero)->getValue();
    Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);

    if (Remainder->isZero()) {
      // The Quotient is obtained by replacing Denominator by 1 in Numerator.
      RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
          cast<SCEVConstant>(One)->getValue();
      Quotient =
          SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
      return;
    }

    // Quotient is (Numerator - Remainder) divided by Denominator.
    const SCEV *Q, *R;
    const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
    if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) {
      // This SCEV does not seem to simplify: fail the division here.
      Quotient = Zero;
      Remainder = Numerator;
      return;
    }
    divide(SE, Diff, Denominator, &Q, &R);
    assert(R == Zero &&
           "Denominator should evenly divide (Numerator - Remainder)");
    Quotient = Q;
  }

private:
  SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
               const SCEV *Denominator)
      : SE(S), Denominator(Denominator) {
    Zero = SE.getConstant(Denominator->getType(), 0);
    One = SE.getConstant(Denominator->getType(), 1);

    // By default, we don't know how to divide Expr by Denominator.
    // Providing the default here simplifies the rest of the code.
    Quotient = Zero;
    Remainder = Numerator;
  }

  ScalarEvolution &SE;
  const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
};

}

//===----------------------------------------------------------------------===//
//                      Simple SCEV method implementations
//===----------------------------------------------------------------------===//

/// BinomialCoefficient - Compute BC(It, K).  The result has width W.
/// Assumes K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                       ScalarEvolution &SE,
                                       Type *ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose W is the bitwidth of the return value.  We must be prepared for
  // overflow.  Hence, we must ensure that the result of our computation is
  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula.  However,
  // this formula can be implemented much more efficiently.  The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic.  To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse.  Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T.  The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula requires less than W + K bits. Also, the first formula
  // requires a division step, whereas this formula only requires multiplies
  // and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway.  We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // register width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)

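  // Illustrative worked example for K = 3, i.e. BC(It, 3) = It*(It-1)*(It-2)/6:
  // 3! = 6 = 2^1 * 3, so T = 1 and the odd part of the factorial is 3.  The
  // product It*(It-1)*(It-2) is formed at width W+1, divided by 2^1, truncated
  // back to W bits, and multiplied by the multiplicative inverse of 3 mod 2^W.
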
  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    APInt Mult(W, i);
    unsigned TwoFactors = Mult.countTrailingZeros();
    T += TwoFactors;
    Mult = Mult.lshr(TwoFactors);
    OddFactorial *= Mult;
  }

  // We need at least W + T bits for the multiplication step
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt Mod = APInt::getSignedMinValue(W+1);
  APInt MultiplyFactor = OddFactorial.zext(W+1);
  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
  MultiplyFactor = MultiplyFactor.trunc(W);

  // Calculate the product, at width T+W
  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
                                                CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.

  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}

/// evaluateAtIteration - Return the value of this chain of recurrences at
/// the specified iteration number.  We can evaluate this recurrence by
/// multiplying each element in the chain by the binomial coefficient
/// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
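///
/// For example, the affine recurrence {1,+,2} evaluates to 1*BC(It,0) +
/// 2*BC(It,1) = 1 + 2*It, and {0,+,1,+,1} evaluates to BC(It,1) + BC(It,2)
/// = It + It*(It-1)/2 = It*(It+1)/2 (illustrative).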
///
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
  const SCEV *Result = getStart();
  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
  }
  return Result;
}

//===----------------------------------------------------------------------===//
//                    SCEV Expression folder implementations
//===----------------------------------------------------------------------===//

const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
                                             Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
      cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty);

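  // For instance (illustrative), (trunc i64 (zext i32 %x) to i16) folds to
  // (trunc i32 %x to i16): only the low 16 bits survive either way.
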
  // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
  // eliminate all the truncates, or we replace other casts with truncates.
  if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    bool hasTrunc = false;
    for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
      const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
      if (!isa<SCEVCastExpr>(SA->getOperand(i)))
        hasTrunc = isa<SCEVTruncateExpr>(S);
      Operands.push_back(S);
    }
    if (!hasTrunc)
      return getAddExpr(Operands);
    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
  }

  // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
  // eliminate all the truncates, or we replace other casts with truncates.
  if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    bool hasTrunc = false;
    for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
      const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
      if (!isa<SCEVCastExpr>(SM->getOperand(i)))
        hasTrunc = isa<SCEVTruncateExpr>(S);
      Operands.push_back(S);
    }
    if (!hasTrunc)
      return getMulExpr(Operands);
    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
      Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

// Get the limit of a recurrence such that incrementing by Step cannot cause
// signed overflow as long as the value of the recurrence within the
// loop does not exceed this limit before incrementing.
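//
// For instance (i8, illustrative): if Step is known positive with signed max
// 5, the returned limit is -128 - 5 == 123 (mod 2^8) with predicate SLT; any
// value SLT 123 can be incremented by up to 5 without signed wrap.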
static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
                                                 ICmpInst::Predicate *Pred,
                                                 ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  if (SE->isKnownPositive(Step)) {
    *Pred = ICmpInst::ICMP_SLT;
    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
                           SE->getSignedRange(Step).getSignedMax());
  }
  if (SE->isKnownNegative(Step)) {
    *Pred = ICmpInst::ICMP_SGT;
    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
                           SE->getSignedRange(Step).getSignedMin());
  }
  return nullptr;
}

// Get the limit of a recurrence such that incrementing by Step cannot cause
// unsigned overflow as long as the value of the recurrence within the loop does
// not exceed this limit before incrementing.
static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
                                                   ICmpInst::Predicate *Pred,
                                                   ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  *Pred = ICmpInst::ICMP_ULT;

  return SE->getConstant(APInt::getMinValue(BitWidth) -
                         SE->getUnsignedRange(Step).getUnsignedMax());
}

namespace {

struct ExtendOpTraitsBase {
  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *);
};

// Used to make code generic over signed and unsigned overflow.
template <typename ExtendOp> struct ExtendOpTraits {
  // Members present:
  //
  // static const SCEV::NoWrapFlags WrapType;
  //
  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
  //
  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
  //                                           ICmpInst::Predicate *Pred,
  //                                           ScalarEvolution *SE);
};

template <>
struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getSignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;

template <>
struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getUnsignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
}

// The recurrence AR has been shown to have no signed/unsigned wrap or something
// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
// easily prove NSW/NUW for its preincrement or postincrement sibling. This
// allows normalizing a sign/zero extended AddRec as follows:
//   {sext/zext(Step + Start),+,Step} => {Step + sext/zext(Start),+,Step}
// As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
// "sext/zext(PostIncAR)".
1252 template <typename ExtendOpTy>
1253 static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1254                                         ScalarEvolution *SE) {
1255   auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1256   auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1257 
1258   const Loop *L = AR->getLoop();
1259   const SCEV *Start = AR->getStart();
1260   const SCEV *Step = AR->getStepRecurrence(*SE);
1261 
1262   // Check for a simple looking step prior to loop entry.
1263   const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1264   if (!SA)
1265     return nullptr;
1266 
1267   // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1268   // subtraction is expensive. For this purpose, perform a quick and dirty
1269   // difference, by checking for Step in the operand list.
1270   SmallVector<const SCEV *, 4> DiffOps;
1271   for (const SCEV *Op : SA->operands())
1272     if (Op != Step)
1273       DiffOps.push_back(Op);
1274 
1275   if (DiffOps.size() == SA->getNumOperands())
1276     return nullptr;
1277 
1278   // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1279   // `Step`:
1280 
1281   // 1. NSW/NUW flags on the step increment.
1282   const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags());
1283   const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1284       SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1285 
1286   // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1287   // "S+X does not sign/unsign-overflow".
1288   //
1289 
1290   const SCEV *BECount = SE->getBackedgeTakenCount(L);
1291   if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1292       !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1293     return PreStart;
1294 
1295   // 2. Direct overflow check on the step operation's expression.
1296   unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1297   Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1298   const SCEV *OperandExtendedStart =
1299       SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy),
1300                      (SE->*GetExtendExpr)(Step, WideTy));
1301   if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) {
1302     if (PreAR && AR->getNoWrapFlags(WrapType)) {
1303       // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1304       // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1305       // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`.  Cache this fact.
1306       const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(WrapType);
1307     }
1308     return PreStart;
1309   }
1310 
1311   // 3. Loop precondition.
1312   ICmpInst::Predicate Pred;
1313   const SCEV *OverflowLimit =
1314       ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1315 
1316   if (OverflowLimit &&
1317       SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) {
1318     return PreStart;
1319   }
1320   return nullptr;
1321 }
1322 
1323 // Get the normalized zero or sign extended expression for this AddRec's Start.
1324 template <typename ExtendOpTy>
1325 static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1326                                         ScalarEvolution *SE) {
1327   auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1328 
1329   const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE);
1330   if (!PreStart)
1331     return (SE->*GetExtendExpr)(AR->getStart(), Ty);
1332 
1333   return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty),
1334                         (SE->*GetExtendExpr)(PreStart, Ty));
1335 }
1336 
1337 // Try to prove away overflow by looking at "nearby" add recurrences.  A
1338 // motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1339 // does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1340 //
1341 // Formally:
1342 //
1343 //     {S,+,X} == {S-T,+,X} + T
1344 //  => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1345 //
1346 // If ({S-T,+,X} + T) does not overflow  ... (1)
1347 //
1348 //  RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1349 //
1350 // If {S-T,+,X} does not overflow  ... (2)
1351 //
1352 //  RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1353 //      == {Ext(S-T)+Ext(T),+,Ext(X)}
1354 //
1355 // If (S-T)+T does not overflow  ... (3)
1356 //
1357 //  RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1358 //      == {Ext(S),+,Ext(X)} == LHS
1359 //
1360 // Thus, if (1), (2) and (3) are true for some T, then
1361 //   Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1362 //
1363 // (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1364 // does not overflow" restricted to the 0th iteration.  Therefore we only need
1365 // to check for (1) and (2).
1366 //
1367 // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1368 // is `Delta` (defined below).
1369 //
1370 template <typename ExtendOpTy>
1371 bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1372                                                 const SCEV *Step,
1373                                                 const Loop *L) {
1374   auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1375 
1376   // We restrict `Start` to a constant to prevent SCEV from spending too much
1377   // time here.  It is correct (but more expensive) to continue with a
1378   // non-constant `Start` and do a general SCEV subtraction to compute
1379   // `PreStart` below.
1380   //
1381   const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1382   if (!StartC)
1383     return false;
1384 
1385   APInt StartAI = StartC->getValue()->getValue();
1386 
1387   for (unsigned Delta : {-2, -1, 1, 2}) {
1388     const SCEV *PreStart = getConstant(StartAI - Delta);
1389 
1390     // Give up if we don't already have the add recurrence we need because
1391     // actually constructing an add recurrence is relatively expensive.
1392     const SCEVAddRecExpr *PreAR = [&]() {
1393       FoldingSetNodeID ID;
1394       ID.AddInteger(scAddRecExpr);
1395       ID.AddPointer(PreStart);
1396       ID.AddPointer(Step);
1397       ID.AddPointer(L);
1398       void *IP = nullptr;
1399       return static_cast<SCEVAddRecExpr *>(
1400           this->UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1401     }();
1402 
1403     if (PreAR && PreAR->getNoWrapFlags(WrapType)) {  // proves (2)
1404       const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1405       ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1406       const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1407           DeltaS, &Pred, this);
1408       if (Limit && isKnownPredicate(Pred, PreAR, Limit))  // proves (1)
1409         return true;
1410     }
1411   }
1412 
1413   return false;
1414 }
1415 
1416 const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
1417                                                Type *Ty) {
1418   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1419          "This is not an extending conversion!");
1420   assert(isSCEVable(Ty) &&
1421          "This is not a conversion to a SCEVable type!");
1422   Ty = getEffectiveSCEVType(Ty);
1423 
1424   // Fold if the operand is constant.
1425   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1426     return getConstant(
1427       cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
1428 
1429   // zext(zext(x)) --> zext(x)
1430   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1431     return getZeroExtendExpr(SZ->getOperand(), Ty);
1432 
1433   // Before doing any expensive analysis, check to see if we've already
1434   // computed a SCEV for this Op and Ty.
1435   FoldingSetNodeID ID;
1436   ID.AddInteger(scZeroExtend);
1437   ID.AddPointer(Op);
1438   ID.AddPointer(Ty);
1439   void *IP = nullptr;
1440   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1441 
1442   // zext(trunc(x)) --> zext(x) or x or trunc(x)
1443   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1444     // It's possible the bits taken off by the truncate were all zero bits. If
1445     // so, we should be able to simplify this further.
1446     const SCEV *X = ST->getOperand();
1447     ConstantRange CR = getUnsignedRange(X);
1448     unsigned TruncBits = getTypeSizeInBits(ST->getType());
1449     unsigned NewBits = getTypeSizeInBits(Ty);
1450     if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1451             CR.zextOrTrunc(NewBits)))
1452       return getTruncateOrZeroExtend(X, Ty);
1453   }
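  // For example (a sketch): if x is an i32 known to lie in [0, 100) and Op is
  // trunc(x) to i8, then zext of that trunc back to i32 loses nothing: the
  // range [0, 100) truncated to 8 bits and zero-extended again still contains
  // the original range, so the cast pair folds back to x itself.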
1454 
1455   // If the input value is a chrec scev, and we can prove that the value
1456   // did not overflow in the old, smaller type, we can zero extend all of the
1457   // operands (often constants).  This allows analysis of something like
1458   // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1459   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1460     if (AR->isAffine()) {
1461       const SCEV *Start = AR->getStart();
1462       const SCEV *Step = AR->getStepRecurrence(*this);
1463       unsigned BitWidth = getTypeSizeInBits(AR->getType());
1464       const Loop *L = AR->getLoop();
1465 
1466       // If we have special knowledge that this addrec won't overflow,
1467       // we don't need to do any further analysis.
1468       if (AR->getNoWrapFlags(SCEV::FlagNUW))
1469         return getAddRecExpr(
1470             getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
1471             getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1472 
1473       // Check whether the backedge-taken count is SCEVCouldNotCompute.
1474       // Note that this serves two purposes: It filters out loops that are
1475       // simply not analyzable, and it covers the case where this code is
1476       // being called from within backedge-taken count analysis, such that
1477       // attempting to ask for the backedge-taken count would likely result
1478       // in infinite recursion. In the latter case, the analysis code will
1479       // cope with a conservative value, and it will take care to purge
1480       // that value once it has finished.
1481       const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
1482       if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1483         // Manually compute the final value for AR, checking for
1484         // overflow.
1485 
1486         // Check whether the backedge-taken count can be losslessly cast to
1487         // the addrec's type. The count is always unsigned.
1488         const SCEV *CastedMaxBECount =
1489           getTruncateOrZeroExtend(MaxBECount, Start->getType());
1490         const SCEV *RecastedMaxBECount =
1491           getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
1492         if (MaxBECount == RecastedMaxBECount) {
1493           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1494           // Check whether Start+Step*MaxBECount has no unsigned overflow.
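          // (The trick: evaluate the final value both ways in a type twice as
          // wide.  If extending the narrow sum gives the same result as
          // summing the individually extended operands, the narrow
          // computation cannot have wrapped.)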
1495           const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step);
1496           const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul), WideTy);
1497           const SCEV *WideStart = getZeroExtendExpr(Start, WideTy);
1498           const SCEV *WideMaxBECount =
1499             getZeroExtendExpr(CastedMaxBECount, WideTy);
1500           const SCEV *OperandExtendedAdd =
1501             getAddExpr(WideStart,
1502                        getMulExpr(WideMaxBECount,
1503                                   getZeroExtendExpr(Step, WideTy)));
1504           if (ZAdd == OperandExtendedAdd) {
1505             // Cache knowledge of AR NUW, which is propagated to this AddRec.
1506             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1507             // Return the expression with the addrec on the outside.
1508             return getAddRecExpr(
1509                 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
1510                 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1511           }
1512           // Similar to above, only this time treat the step value as signed.
1513           // This covers loops that count down.
1514           OperandExtendedAdd =
1515             getAddExpr(WideStart,
1516                        getMulExpr(WideMaxBECount,
1517                                   getSignExtendExpr(Step, WideTy)));
1518           if (ZAdd == OperandExtendedAdd) {
1519             // Cache knowledge of AR NW, which is propagated to this AddRec.
1520             // Negative step causes unsigned wrap, but it still can't self-wrap.
1521             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1522             // Return the expression with the addrec on the outside.
1523             return getAddRecExpr(
1524                 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
1525                 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1526           }
1527         }
1528 
1529         // If the backedge is guarded by a comparison with the pre-inc value,
1530         // the addrec is safe. Also, if the entry is guarded by a comparison
1531         // with the start value and the backedge is guarded by a comparison
1532         // with the post-inc value, the addrec is safe.
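        // For example (a sketch): for an i8 addrec with Step == 1, N below is
        // 0 - 1 == 255, so a backedge guard AR <u 255 shows AR + 1 never
        // exceeds 255 and therefore cannot wrap unsigned.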
1533         if (isKnownPositive(Step)) {
1534           const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
1535                                       getUnsignedRange(Step).getUnsignedMax());
1536           if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
1537               (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
1538                isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
1539                                            AR->getPostIncExpr(*this), N))) {
1540             // Cache knowledge of AR NUW, which is propagated to this AddRec.
1541             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1542             // Return the expression with the addrec on the outside.
1543             return getAddRecExpr(
1544                 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
1545                 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1546           }
1547         } else if (isKnownNegative(Step)) {
1548           const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1549                                       getSignedRange(Step).getSignedMin());
1550           if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1551               (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
1552                isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
1553                                            AR->getPostIncExpr(*this), N))) {
1554             // Cache knowledge of AR NW, which is propagated to this AddRec.
1555             // Negative step causes unsigned wrap, but it still can't self-wrap.
1556             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1557             // Return the expression with the addrec on the outside.
1558             return getAddRecExpr(
1559                 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
1560                 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1561           }
1562         }
1563       }
1564 
1565       if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1566         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1567         return getAddRecExpr(
1568             getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
1569             getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1570       }
1571     }
1572 
1573   // The cast wasn't folded; create an explicit cast node.
1574   // Recompute the insert position, as it may have been invalidated.
1575   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1576   SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1577                                                    Op, Ty);
1578   UniqueSCEVs.InsertNode(S, IP);
1579   return S;
1580 }
1581 
1582 const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
1583                                                Type *Ty) {
1584   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1585          "This is not an extending conversion!");
1586   assert(isSCEVable(Ty) &&
1587          "This is not a conversion to a SCEVable type!");
1588   Ty = getEffectiveSCEVType(Ty);
1589 
1590   // Fold if the operand is constant.
1591   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1592     return getConstant(
1593       cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1594 
1595   // sext(sext(x)) --> sext(x)
1596   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1597     return getSignExtendExpr(SS->getOperand(), Ty);
1598 
1599   // sext(zext(x)) --> zext(x)
1600   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1601     return getZeroExtendExpr(SZ->getOperand(), Ty);
1602 
1603   // Before doing any expensive analysis, check to see if we've already
1604   // computed a SCEV for this Op and Ty.
1605   FoldingSetNodeID ID;
1606   ID.AddInteger(scSignExtend);
1607   ID.AddPointer(Op);
1608   ID.AddPointer(Ty);
1609   void *IP = nullptr;
1610   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1611 
1612   // If the input value is provably non-negative, build a zext instead.
1613   if (isKnownNonNegative(Op))
1614     return getZeroExtendExpr(Op, Ty);
1615 
1616   // sext(trunc(x)) --> sext(x) or x or trunc(x)
1617   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1618     // It's possible the bits taken off by the truncate were all sign bits. If
1619     // so, we should be able to simplify this further.
1620     const SCEV *X = ST->getOperand();
1621     ConstantRange CR = getSignedRange(X);
1622     unsigned TruncBits = getTypeSizeInBits(ST->getType());
1623     unsigned NewBits = getTypeSizeInBits(Ty);
1624     if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1625             CR.sextOrTrunc(NewBits)))
1626       return getTruncateOrSignExtend(X, Ty);
1627   }
1628 
1629   // sext(C1 + (C2 * x)) --> C1 + sext(C2 * x) if 0 < C1 < C2, C2 a power of 2
1630   if (auto SA = dyn_cast<SCEVAddExpr>(Op)) {
1631     if (SA->getNumOperands() == 2) {
1632       auto SC1 = dyn_cast<SCEVConstant>(SA->getOperand(0));
1633       auto SMul = dyn_cast<SCEVMulExpr>(SA->getOperand(1));
1634       if (SMul && SC1) {
1635         if (auto SC2 = dyn_cast<SCEVConstant>(SMul->getOperand(0))) {
1636           const APInt &C1 = SC1->getValue()->getValue();
1637           const APInt &C2 = SC2->getValue()->getValue();
1638           if (C1.isStrictlyPositive() && C2.isStrictlyPositive() &&
1639               C2.ugt(C1) && C2.isPowerOf2())
1640             return getAddExpr(getSignExtendExpr(SC1, Ty),
1641                               getSignExtendExpr(SMul, Ty));
1642         }
1643       }
1644     }
1645   }
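  // For example (a sketch): sext i8 (1 + 4*x) to i32 can become
  // 1 + sext i8 (4*x) to i32: 4*x is always a multiple of 4 in i8, and adding
  // a positive constant below 4 can never carry across the i8 sign boundary.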
1646   // If the input value is a chrec scev, and we can prove that the value
1647   // did not overflow in the old, smaller type, we can sign extend all of the
1648   // operands (often constants).  This allows analysis of something like
1649   // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
1650   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1651     if (AR->isAffine()) {
1652       const SCEV *Start = AR->getStart();
1653       const SCEV *Step = AR->getStepRecurrence(*this);
1654       unsigned BitWidth = getTypeSizeInBits(AR->getType());
1655       const Loop *L = AR->getLoop();
1656 
1657       // If we have special knowledge that this addrec won't overflow,
1658       // we don't need to do any further analysis.
1659       if (AR->getNoWrapFlags(SCEV::FlagNSW))
1660         return getAddRecExpr(
1661             getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
1662             getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW);
1663 
1664       // Check whether the backedge-taken count is SCEVCouldNotCompute.
1665       // Note that this serves two purposes: It filters out loops that are
1666       // simply not analyzable, and it covers the case where this code is
1667       // being called from within backedge-taken count analysis, such that
1668       // attempting to ask for the backedge-taken count would likely result
1669       // in infinite recursion. In the latter case, the analysis code will
1670       // cope with a conservative value, and it will take care to purge
1671       // that value once it has finished.
1672       const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
1673       if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1674         // Manually compute the final value for AR, checking for
1675         // overflow.
1676 
1677         // Check whether the backedge-taken count can be losslessly cast to
1678         // the addrec's type. The count is always unsigned.
1679         const SCEV *CastedMaxBECount =
1680           getTruncateOrZeroExtend(MaxBECount, Start->getType());
1681         const SCEV *RecastedMaxBECount =
1682           getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
1683         if (MaxBECount == RecastedMaxBECount) {
1684           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1685           // Check whether Start+Step*MaxBECount has no signed overflow.
1686           const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
1687           const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul), WideTy);
1688           const SCEV *WideStart = getSignExtendExpr(Start, WideTy);
1689           const SCEV *WideMaxBECount =
1690             getZeroExtendExpr(CastedMaxBECount, WideTy);
1691           const SCEV *OperandExtendedAdd =
1692             getAddExpr(WideStart,
1693                        getMulExpr(WideMaxBECount,
1694                                   getSignExtendExpr(Step, WideTy)));
1695           if (SAdd == OperandExtendedAdd) {
1696             // Cache knowledge of AR NSW, which is propagated to this AddRec.
1697             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1698             // Return the expression with the addrec on the outside.
1699             return getAddRecExpr(
1700                 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
1701                 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1702           }
1703           // Similar to above, only this time treat the step value as unsigned.
1704           // This covers loops that count up with an unsigned step.
1705           OperandExtendedAdd =
1706             getAddExpr(WideStart,
1707                        getMulExpr(WideMaxBECount,
1708                                   getZeroExtendExpr(Step, WideTy)));
1709           if (SAdd == OperandExtendedAdd) {
1710             // If AR wraps around then
1711             //
1712             //    abs(Step) * MaxBECount > unsigned-max(AR->getType())
1713             // => SAdd != OperandExtendedAdd
1714             //
1715             // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
1716             // (SAdd == OperandExtendedAdd => AR is NW)
1717 
1718             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1719 
1720             // Return the expression with the addrec on the outside.
1721             return getAddRecExpr(
1722                 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
1723                 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1724           }
1725         }
1726 
1727         // If the backedge is guarded by a comparison with the pre-inc value,
1728         // the addrec is safe. Also, if the entry is guarded by a comparison
1729         // with the start value and the backedge is guarded by a comparison
1730         // with the post-inc value, the addrec is safe.
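        // For example (a sketch): for an i8 addrec with Step == 1 the limit
        // works out to 127 - 1 == 126 with Pred == slt, so a backedge guard
        // AR <s 126 keeps AR + 1 at or below 126, within signed i8 range.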
1731         ICmpInst::Predicate Pred;
1732         const SCEV *OverflowLimit =
1733             getSignedOverflowLimitForStep(Step, &Pred, this);
1734         if (OverflowLimit &&
1735             (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
1736              (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) &&
1737               isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this),
1738                                           OverflowLimit)))) {
1739           // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
1740           const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1741           return getAddRecExpr(
1742               getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
1743               getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1744         }
1745       }
1746       // If Start and Step are constants, check if we can apply this
1747       // transformation:
1748       // sext{C1,+,C2} --> C1 + sext{0,+,C2} if 0 < C1 < C2, C2 a power of 2
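      // For example (a sketch): sext i8 {1,+,4} to i32 can be rewritten as
      // 1 + sext i8 {0,+,4} to i32, since {0,+,4} only visits multiples of 4
      // and adding 1 never crosses the i8 sign boundary.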
1749       auto SC1 = dyn_cast<SCEVConstant>(Start);
1750       auto SC2 = dyn_cast<SCEVConstant>(Step);
1751       if (SC1 && SC2) {
1752         const APInt &C1 = SC1->getValue()->getValue();
1753         const APInt &C2 = SC2->getValue()->getValue();
1754         if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) &&
1755             C2.isPowerOf2()) {
1756           Start = getSignExtendExpr(Start, Ty);
1757           const SCEV *NewAR = getAddRecExpr(getConstant(AR->getType(), 0), Step,
1758                                             L, AR->getNoWrapFlags());
1759           return getAddExpr(Start, getSignExtendExpr(NewAR, Ty));
1760         }
1761       }
1762 
1763       if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
1764         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1765         return getAddRecExpr(
1766             getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
1767             getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1768       }
1769     }
1770 
1771   // The cast wasn't folded; create an explicit cast node.
1772   // Recompute the insert position, as it may have been invalidated.
1773   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1774   SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1775                                                    Op, Ty);
1776   UniqueSCEVs.InsertNode(S, IP);
1777   return S;
1778 }
1779 
1780 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
1781 /// unspecified bits out to the given type.
1782 ///
1783 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
1784                                               Type *Ty) {
1785   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1786          "This is not an extending conversion!");
1787   assert(isSCEVable(Ty) &&
1788          "This is not a conversion to a SCEVable type!");
1789   Ty = getEffectiveSCEVType(Ty);
1790 
1791   // Sign-extend negative constants.
1792   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1793     if (SC->getValue()->getValue().isNegative())
1794       return getSignExtendExpr(Op, Ty);
1795 
1796   // Peel off a truncate cast.
1797   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
1798     const SCEV *NewOp = T->getOperand();
1799     if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
1800       return getAnyExtendExpr(NewOp, Ty);
1801     return getTruncateOrNoop(NewOp, Ty);
1802   }
1803 
1804   // Next try a zext cast. If the cast is folded, use it.
1805   const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
1806   if (!isa<SCEVZeroExtendExpr>(ZExt))
1807     return ZExt;
1808 
1809   // Next try a sext cast. If the cast is folded, use it.
1810   const SCEV *SExt = getSignExtendExpr(Op, Ty);
1811   if (!isa<SCEVSignExtendExpr>(SExt))
1812     return SExt;
1813 
1814   // Force the cast to be folded into the operands of an addrec.
1815   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
1816     SmallVector<const SCEV *, 4> Ops;
1817     for (const SCEV *Op : AR->operands())
1818       Ops.push_back(getAnyExtendExpr(Op, Ty));
1819     return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
1820   }
1821 
1822   // If the expression is obviously signed, use the sext cast value.
1823   if (isa<SCEVSMaxExpr>(Op))
1824     return SExt;
1825 
1826   // Absent any other information, use the zext cast value.
1827   return ZExt;
1828 }
1829 
1830 /// CollectAddOperandsWithScales - Process the given Ops list, which is
1831 /// a list of operands to be added under the given scale, and update the given
1832 /// map. This is a helper function for getAddExpr. As an example of
1833 /// what it does, given a sequence of operands that would form an add
1834 /// expression like this:
1835 ///
1836 ///    m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
1837 ///
1838 /// where A and B are constants, update the map with these values:
1839 ///
1840 ///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
1841 ///
1842 /// and add 13 + A*B*29 to AccumulatedConstant.
1843 /// This will allow getAddExpr to produce this:
1844 ///
1845 ///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
1846 ///
1847 /// This form often exposes folding opportunities that are hidden in
1848 /// the original operand list.
1849 ///
1850 /// Return true iff it appears that any interesting folding opportunities
1851 /// may be exposed. This helps getAddExpr short-circuit extra work in
1852 /// the common case where no interesting opportunities are present, and
1853 /// is also used as a check to avoid infinite recursion.
1854 ///
1855 static bool
1856 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
1857                              SmallVectorImpl<const SCEV *> &NewOps,
1858                              APInt &AccumulatedConstant,
1859                              const SCEV *const *Ops, size_t NumOperands,
1860                              const APInt &Scale,
1861                              ScalarEvolution &SE) {
1862   bool Interesting = false;
1863 
1864   // Iterate over the add operands. They are sorted, with constants first.
1865   unsigned i = 0;
1866   while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1867     ++i;
1868     // Pull a buried constant out to the outside.
1869     if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
1870       Interesting = true;
1871     AccumulatedConstant += Scale * C->getValue()->getValue();
1872   }
1873 
1874   // Next comes everything else. We're especially interested in multiplies
1875   // here, but they're in the middle, so just visit the rest with one loop.
1876   for (; i != NumOperands; ++i) {
1877     const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
1878     if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
1879       APInt NewScale =
1880         Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue();
1881       if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
1882         // A multiplication of a constant with another add; recurse.
1883         const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
1884         Interesting |=
1885           CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
1886                                        Add->op_begin(), Add->getNumOperands(),
1887                                        NewScale, SE);
1888       } else {
1889         // A multiplication of a constant with some other value. Update
1890         // the map.
1891         SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
1892         const SCEV *Key = SE.getMulExpr(MulOps);
1893         std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1894           M.insert(std::make_pair(Key, NewScale));
1895         if (Pair.second) {
1896           NewOps.push_back(Pair.first->first);
1897         } else {
1898           Pair.first->second += NewScale;
1899           // The map already had an entry for this value, which may indicate
1900           // a folding opportunity.
1901           Interesting = true;
1902         }
1903       }
1904     } else {
1905       // An ordinary operand. Update the map.
1906       std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1907         M.insert(std::make_pair(Ops[i], Scale));
1908       if (Pair.second) {
1909         NewOps.push_back(Pair.first->first);
1910       } else {
1911         Pair.first->second += Scale;
1912         // The map already had an entry for this value, which may indicate
1913         // a folding opportunity.
1914         Interesting = true;
1915       }
1916     }
1917   }
1918 
1919   return Interesting;
1920 }
1921 
1922 namespace {
1923   struct APIntCompare {
1924     bool operator()(const APInt &LHS, const APInt &RHS) const {
1925       return LHS.ult(RHS);
1926     }
1927   };
1928 }
1929 
1930 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and
1931 // `OldFlags' as can't-wrap behavior.  Infer a more aggressive set of
1932 // can't-overflow flags for the operation if possible.
1933 static SCEV::NoWrapFlags
1934 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
1935                       const SmallVectorImpl<const SCEV *> &Ops,
1936                       SCEV::NoWrapFlags OldFlags) {
1937   using namespace std::placeholders;
1938 
1939   bool CanAnalyze =
1940       Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
1941   (void)CanAnalyze;
1942   assert(CanAnalyze && "don't call from other places!");
1943 
1944   int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
1945   SCEV::NoWrapFlags SignOrUnsignWrap =
1946       ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask);
1947 
1948   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
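  // (For non-negative operands the signed and unsigned interpretations of the
  // operation coincide, so no-signed-wrap implies no-unsigned-wrap as well.)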
1949   auto IsKnownNonNegative =
1950     std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1);
1951 
1952   if (SignOrUnsignWrap == SCEV::FlagNSW &&
1953       std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative))
1954     return ScalarEvolution::setFlags(OldFlags,
1955                                      (SCEV::NoWrapFlags)SignOrUnsignMask);
1956 
1957   return OldFlags;
1958 }
1959 
1960 /// getAddExpr - Get a canonical add expression, or something simpler if
1961 /// possible.
1962 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
1963                                         SCEV::NoWrapFlags Flags) {
1964   assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
1965          "only nuw or nsw allowed");
1966   assert(!Ops.empty() && "Cannot get empty add!");
1967   if (Ops.size() == 1) return Ops[0];
1968 #ifndef NDEBUG
1969   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
1970   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1971     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
1972            "SCEVAddExpr operand types don't match!");
1973 #endif
1974 
1975   Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
1976 
1977   // Sort by complexity; this groups all similar expression types together.
1978   GroupByComplexity(Ops, LI);
1979 
1980   // If there are any constants, fold them together.
1981   unsigned Idx = 0;
1982   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1983     ++Idx;
1984     assert(Idx < Ops.size());
1985     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1986       // We found two constants, fold them together!
1987       Ops[0] = getConstant(LHSC->getValue()->getValue() +
1988                            RHSC->getValue()->getValue());
1989       if (Ops.size() == 2) return Ops[0];
1990       Ops.erase(Ops.begin()+1);  // Erase the folded element
1991       LHSC = cast<SCEVConstant>(Ops[0]);
1992     }
1993 
1994     // If we are left with a constant zero being added, strip it off.
1995     if (LHSC->getValue()->isZero()) {
1996       Ops.erase(Ops.begin());
1997       --Idx;
1998     }
1999 
2000     if (Ops.size() == 1) return Ops[0];
2001   }
2002 
2003   // Okay, check to see if the same value occurs in the operand list more than
2004   // once.  If so, merge them together into a multiply expression.  Since we
2005   // sorted the list, these values are required to be adjacent.
2006   Type *Ty = Ops[0]->getType();
2007   bool FoundMatch = false;
2008   for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2009     if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
2010       // Scan ahead to count how many equal operands there are.
2011       unsigned Count = 2;
2012       while (i+Count != e && Ops[i+Count] == Ops[i])
2013         ++Count;
2014       // Merge the values into a multiply.
2015       const SCEV *Scale = getConstant(Ty, Count);
2016       const SCEV *Mul = getMulExpr(Scale, Ops[i]);
2017       if (Ops.size() == Count)
2018         return Mul;
2019       Ops[i] = Mul;
2020       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2021       --i; e -= Count - 1;
2022       FoundMatch = true;
2023     }
2024   if (FoundMatch)
2025     return getAddExpr(Ops, Flags);
2026 
2027   // Check for truncates. If all the operands are truncated from the same
2028   // type, see if factoring out the truncate would permit the result to be
2029   // folded, e.g., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
2030   // if the contents of the resulting outer trunc fold to something simple.
2031   for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
2032     const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
2033     Type *DstType = Trunc->getType();
2034     Type *SrcType = Trunc->getOperand()->getType();
2035     SmallVector<const SCEV *, 8> LargeOps;
2036     bool Ok = true;
2037     // Check all the operands to see if they can be represented in the
2038     // source type of the truncate.
2039     for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2040       if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2041         if (T->getOperand()->getType() != SrcType) {
2042           Ok = false;
2043           break;
2044         }
2045         LargeOps.push_back(T->getOperand());
2046       } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2047         LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2048       } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2049         SmallVector<const SCEV *, 8> LargeMulOps;
2050         for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2051           if (const SCEVTruncateExpr *T =
2052                 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2053             if (T->getOperand()->getType() != SrcType) {
2054               Ok = false;
2055               break;
2056             }
2057             LargeMulOps.push_back(T->getOperand());
2058           } else if (const SCEVConstant *C =
2059                        dyn_cast<SCEVConstant>(M->getOperand(j))) {
2060             LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2061           } else {
2062             Ok = false;
2063             break;
2064           }
2065         }
2066         if (Ok)
2067           LargeOps.push_back(getMulExpr(LargeMulOps));
2068       } else {
2069         Ok = false;
2070         break;
2071       }
2072     }
2073     if (Ok) {
2074       // Evaluate the expression in the larger type.
2075       const SCEV *Fold = getAddExpr(LargeOps, Flags);
2076       // If it folds to something simple, use it. Otherwise, don't.
2077       if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2078         return getTruncateExpr(Fold, DstType);
2079     }
2080   }
2081 
2082   // Skip past any other cast SCEVs.
2083   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2084     ++Idx;
2085 
2086   // If there are add operands they would be next.
2087   if (Idx < Ops.size()) {
2088     bool DeletedAdd = false;
2089     while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2090       // If we have an add, expand the add operands onto the end of the operands
2091       // list.
2092       Ops.erase(Ops.begin()+Idx);
2093       Ops.append(Add->op_begin(), Add->op_end());
2094       DeletedAdd = true;
2095     }
2096 
2097     // If we deleted at least one add, we added operands to the end of the list,
2098     // and they are not necessarily sorted.  Recurse to resort and resimplify
2099     // any operands we just acquired.
2100     if (DeletedAdd)
2101       return getAddExpr(Ops);
2102   }
2103 
2104   // Skip over the add expression until we get to a multiply.
2105   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2106     ++Idx;
2107 
2108   // Check to see if there are any folding opportunities present with
2109   // operands multiplied by constant values.
2110   if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2111     uint64_t BitWidth = getTypeSizeInBits(Ty);
2112     DenseMap<const SCEV *, APInt> M;
2113     SmallVector<const SCEV *, 8> NewOps;
2114     APInt AccumulatedConstant(BitWidth, 0);
2115     if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2116                                      Ops.data(), Ops.size(),
2117                                      APInt(BitWidth, 1), *this)) {
2118       // Some interesting folding opportunity is present, so it's worthwhile to
2119       // re-generate the operands list. Group the operands by constant scale,
2120       // to avoid multiplying by the same constant scale multiple times.
2121       std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2122       for (SmallVectorImpl<const SCEV *>::const_iterator I = NewOps.begin(),
2123            E = NewOps.end(); I != E; ++I)
2124         MulOpLists[M.find(*I)->second].push_back(*I);
2125       // Re-generate the operands list.
2126       Ops.clear();
2127       if (AccumulatedConstant != 0)
2128         Ops.push_back(getConstant(AccumulatedConstant));
2129       for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
2130            I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
2131         if (I->first != 0)
2132           Ops.push_back(getMulExpr(getConstant(I->first),
2133                                    getAddExpr(I->second)));
2134       if (Ops.empty())
2135         return getConstant(Ty, 0);
2136       if (Ops.size() == 1)
2137         return Ops[0];
2138       return getAddExpr(Ops);
2139     }
2140   }
2141 
2142   // If we are adding something to a multiply expression, make sure the
2143   // something is not already an operand of the multiply.  If so, merge it into
2144   // the multiply.
2145   for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2146     const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2147     for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2148       const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2149       if (isa<SCEVConstant>(MulOpSCEV))
2150         continue;
2151       for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2152         if (MulOpSCEV == Ops[AddOp]) {
2153           // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
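          // (When the multiply has exactly two operands, (MulOp == 0) picks
          // the other one: index 1 if MulOp is 0, index 0 otherwise.)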
2154           const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2155           if (Mul->getNumOperands() != 2) {
2156             // If the multiply has more than two operands, we must get the
2157             // Y*Z term.
2158             SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2159                                                 Mul->op_begin()+MulOp);
2160             MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2161             InnerMul = getMulExpr(MulOps);
2162           }
2163           const SCEV *One = getConstant(Ty, 1);
2164           const SCEV *AddOne = getAddExpr(One, InnerMul);
2165           const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV);
2166           if (Ops.size() == 2) return OuterMul;
2167           if (AddOp < Idx) {
2168             Ops.erase(Ops.begin()+AddOp);
2169             Ops.erase(Ops.begin()+Idx-1);
2170           } else {
2171             Ops.erase(Ops.begin()+Idx);
2172             Ops.erase(Ops.begin()+AddOp-1);
2173           }
2174           Ops.push_back(OuterMul);
2175           return getAddExpr(Ops);
2176         }
2177 
2178       // Check this multiply against other multiplies being added together.
2179       for (unsigned OtherMulIdx = Idx+1;
2180            OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2181            ++OtherMulIdx) {
2182         const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2183         // If MulOp occurs in OtherMul, we can fold the two multiplies
2184         // together.
2185         for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2186              OMulOp != e; ++OMulOp)
2187           if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2188             // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2189             const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2190             if (Mul->getNumOperands() != 2) {
2191               SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2192                                                   Mul->op_begin()+MulOp);
2193               MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2194               InnerMul1 = getMulExpr(MulOps);
2195             }
2196             const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2197             if (OtherMul->getNumOperands() != 2) {
2198               SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
2199                                                   OtherMul->op_begin()+OMulOp);
2200               MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
2201               InnerMul2 = getMulExpr(MulOps);
2202             }
2203             const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
2204             const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
2205             if (Ops.size() == 2) return OuterMul;
2206             Ops.erase(Ops.begin()+Idx);
2207             Ops.erase(Ops.begin()+OtherMulIdx-1);
2208             Ops.push_back(OuterMul);
2209             return getAddExpr(Ops);
2210           }
2211       }
2212     }
2213   }
2214 
2215   // If there are any add recurrences in the operands list, see if any other
2216   // added values are loop invariant.  If so, we can fold them into the
2217   // recurrence.
2218   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2219     ++Idx;
2220 
2221   // Scan over all recurrences, trying to fold loop invariants into them.
2222   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2223     // Scan all of the other operands to this add and add them to the vector if
2224     // they are loop invariant w.r.t. the recurrence.
2225     SmallVector<const SCEV *, 8> LIOps;
2226     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2227     const Loop *AddRecLoop = AddRec->getLoop();
2228     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2229       if (isLoopInvariant(Ops[i], AddRecLoop)) {
2230         LIOps.push_back(Ops[i]);
2231         Ops.erase(Ops.begin()+i);
2232         --i; --e;
2233       }
2234 
2235     // If we found some loop invariants, fold them into the recurrence.
2236     if (!LIOps.empty()) {
2237       //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
2238       LIOps.push_back(AddRec->getStart());
2239 
2240       SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
2241                                              AddRec->op_end());
2242       AddRecOps[0] = getAddExpr(LIOps);
2243 
2244       // Build the new addrec. Propagate the NUW and NSW flags if both the
2245       // outer add and the inner addrec are guaranteed to have no overflow.
2246       // Always propagate NW.
2247       Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2248       const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2249 
2250       // If all of the other operands were loop invariant, we are done.
2251       if (Ops.size() == 1) return NewRec;
2252 
2253       // Otherwise, add the folded AddRec by the non-invariant parts.
2254       for (unsigned i = 0;; ++i)
2255         if (Ops[i] == AddRec) {
2256           Ops[i] = NewRec;
2257           break;
2258         }
2259       return getAddExpr(Ops);
2260     }
2261 
2262     // Okay, if there weren't any loop invariants to be folded, check to see if
2263     // there are multiple AddRec's with the same loop induction variable being
2264     // added together.  If so, we can fold them.
2265     for (unsigned OtherIdx = Idx+1;
2266          OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2267          ++OtherIdx)
2268       if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2269         // Other + {A,+,B}<L> + {C,+,D}<L>  -->  Other + {A+C,+,B+D}<L>
2270         SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
2271                                                AddRec->op_end());
2272         for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2273              ++OtherIdx)
2274           if (const SCEVAddRecExpr *OtherAddRec =
2275                 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
2276             if (OtherAddRec->getLoop() == AddRecLoop) {
2277               for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2278                    i != e; ++i) {
2279                 if (i >= AddRecOps.size()) {
2280                   AddRecOps.append(OtherAddRec->op_begin()+i,
2281                                    OtherAddRec->op_end());
2282                   break;
2283                 }
2284                 AddRecOps[i] = getAddExpr(AddRecOps[i],
2285                                           OtherAddRec->getOperand(i));
2286               }
2287               Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2288             }
2289         // Step size has changed, so we cannot guarantee no self-wraparound.
2290         Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2291         return getAddExpr(Ops);
2292       }
2293 
2294     // Otherwise couldn't fold anything into this recurrence.  Move onto the
2295     // next one.
2296   }
2297 
2298   // Okay, it looks like we really DO need an add expr.  Check to see if we
2299   // already have one, otherwise create a new one.
2300   FoldingSetNodeID ID;
2301   ID.AddInteger(scAddExpr);
2302   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2303     ID.AddPointer(Ops[i]);
2304   void *IP = nullptr;
2305   SCEVAddExpr *S =
2306     static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2307   if (!S) {
2308     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2309     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2310     S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator),
2311                                         O, Ops.size());
2312     UniqueSCEVs.InsertNode(S, IP);
2313   }
2314   S->setNoWrapFlags(Flags);
2315   return S;
2316 }
2317 
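/// Multiply i and j, setting Overflow if the product wraps a uint64_t.  For
/// example (a sketch): umul_ov(1ULL << 63, 2, Ov) wraps to 0, and since
/// 0 / 2 != 1ULL << 63 the flag is set; the j > 1 guard also avoids dividing
/// by zero.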
2318 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
2319   uint64_t k = i*j;
2320   if (j > 1 && k / j != i) Overflow = true;
2321   return k;
2322 }
2323 
2324 /// Compute the result of "n choose k", the binomial coefficient.  If an
2325 /// intermediate computation overflows, Overflow will be set and the return will
2326 /// be garbage. Overflow is not cleared on absence of overflow.
2327 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
2328   // We use the multiplicative formula:
2329   //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
2330   // At the i-th iteration we take the i-th term of the numerator and divide
2331   // by i (the denominator's terms in increasing order).  Each division gives an
2332   // integral result, and helps reduce the chance of overflow in the
2333   // intermediate computations. However, we can still overflow even when the
2334   // final result would fit.
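  // For example (a sketch): Choose(6, 2, Ov) computes r = (1 * 6) / 1 == 6,
  // then r = (6 * 5) / 2 == 15, which is the expected C(6, 2).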
2335 
2336   if (n == 0 || n == k) return 1;
2337   if (k > n) return 0;
2338 
2339   if (k > n/2)
2340     k = n-k;
2341 
2342   uint64_t r = 1;
2343   for (uint64_t i = 1; i <= k; ++i) {
2344     r = umul_ov(r, n-(i-1), Overflow);
2345     r /= i;
2346   }
2347   return r;
2348 }
2349 
2350 /// Determine if any of the operands in this SCEV are a constant or if
2351 /// any of the add or multiply expressions in this SCEV contain a constant.
2352 static bool containsConstantSomewhere(const SCEV *StartExpr) {
2353   SmallVector<const SCEV *, 4> Ops;
2354   Ops.push_back(StartExpr);
2355   while (!Ops.empty()) {
2356     const SCEV *CurrentExpr = Ops.pop_back_val();
2357     if (isa<SCEVConstant>(*CurrentExpr))
2358       return true;
2359 
2360     if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) {
2361       const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr);
2362       Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end());
2363     }
2364   }
2365   return false;
2366 }
2367 
2368 /// getMulExpr - Get a canonical multiply expression, or something simpler if
2369 /// possible.
2370 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
2371                                         SCEV::NoWrapFlags Flags) {
2372   assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
2373          "only nuw or nsw allowed");
2374   assert(!Ops.empty() && "Cannot get empty mul!");
2375   if (Ops.size() == 1) return Ops[0];
2376 #ifndef NDEBUG
2377   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2378   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2379     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2380            "SCEVMulExpr operand types don't match!");
2381 #endif
2382 
2383   Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags);
2384 
2385   // Sort by complexity; this groups all similar expression types together.
2386   GroupByComplexity(Ops, LI);
2387 
2388   // If there are any constants, fold them together.
2389   unsigned Idx = 0;
2390   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2391 
2392     // C1*(C2+V) -> C1*C2 + C1*V
2393     if (Ops.size() == 2)
2394         if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
2395           // If any of Add's ops are Adds or Muls with a constant,
2396           // apply this transformation as well.
2397           if (Add->getNumOperands() == 2)
2398             if (containsConstantSomewhere(Add))
2399               return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
2400                                 getMulExpr(LHSC, Add->getOperand(1)));
2401 
2402     ++Idx;
2403     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2404       // We found two constants, fold them together!
2405       ConstantInt *Fold = ConstantInt::get(getContext(),
2406                                            LHSC->getValue()->getValue() *
2407                                            RHSC->getValue()->getValue());
2408       Ops[0] = getConstant(Fold);
2409       Ops.erase(Ops.begin()+1);  // Erase the folded element
2410       if (Ops.size() == 1) return Ops[0];
2411       LHSC = cast<SCEVConstant>(Ops[0]);
2412     }
2413 
2414     // If we are left with a constant one being multiplied, strip it off.
2415     if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
2416       Ops.erase(Ops.begin());
2417       --Idx;
2418     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
2419       // If we have a multiply of zero, it will always be zero.
2420       return Ops[0];
2421     } else if (Ops[0]->isAllOnesValue()) {
2422       // If we have a mul by -1 of an add, try distributing the -1 among the
2423       // add operands.
2424       if (Ops.size() == 2) {
2425         if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
2426           SmallVector<const SCEV *, 4> NewOps;
2427           bool AnyFolded = false;
2428           for (SCEVAddRecExpr::op_iterator I = Add->op_begin(),
2429                  E = Add->op_end(); I != E; ++I) {
2430             const SCEV *Mul = getMulExpr(Ops[0], *I);
2431             if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
2432             NewOps.push_back(Mul);
2433           }
2434           if (AnyFolded)
2435             return getAddExpr(NewOps);
2436         }
2437         else if (const SCEVAddRecExpr *
2438                  AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
2439           // Negation preserves a recurrence's no self-wrap property.
2440           SmallVector<const SCEV *, 4> Operands;
2441           for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(),
2442                  E = AddRec->op_end(); I != E; ++I) {
2443             Operands.push_back(getMulExpr(Ops[0], *I));
2444           }
2445           return getAddRecExpr(Operands, AddRec->getLoop(),
2446                                AddRec->getNoWrapFlags(SCEV::FlagNW));
2447         }
2448       }
2449     }
2450 
2451     if (Ops.size() == 1)
2452       return Ops[0];
2453   }
2454 
2455   // Skip over the add expression until we get to a multiply.
2456   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2457     ++Idx;
2458 
2459   // If there are mul operands inline them all into this expression.
2460   // If there are mul operands, inline them all into this expression.
2461     bool DeletedMul = false;
2462     while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2463       // If we have a mul, expand the mul operands onto the end of the operands
2464       // list.
2465       Ops.erase(Ops.begin()+Idx);
2466       Ops.append(Mul->op_begin(), Mul->op_end());
2467       DeletedMul = true;
2468     }
2469 
2470     // If we deleted at least one mul, we added operands to the end of the list,
2471     // and they are not necessarily sorted.  Recurse to resort and resimplify
2472     // any operands we just acquired.
2473     if (DeletedMul)
2474       return getMulExpr(Ops);
2475   }
2476 
2477   // If there are any add recurrences in the operands list, see if any other
2478   // multiplied values are loop invariant.  If so, we can fold them into the
2479   // recurrence.
2480   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2481     ++Idx;
2482 
2483   // Scan over all recurrences, trying to fold loop invariants into them.
2484   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2485     // Scan all of the other operands to this mul and add them to the vector if
2486     // they are loop invariant w.r.t. the recurrence.
2487     SmallVector<const SCEV *, 8> LIOps;
2488     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2489     const Loop *AddRecLoop = AddRec->getLoop();
2490     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2491       if (isLoopInvariant(Ops[i], AddRecLoop)) {
2492         LIOps.push_back(Ops[i]);
2493         Ops.erase(Ops.begin()+i);
2494         --i; --e;
2495       }
2496 
2497     // If we found some loop invariants, fold them into the recurrence.
2498     if (!LIOps.empty()) {
2499       //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
2500       SmallVector<const SCEV *, 4> NewOps;
2501       NewOps.reserve(AddRec->getNumOperands());
2502       const SCEV *Scale = getMulExpr(LIOps);
2503       for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
2504         NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
2505 
2506       // Build the new addrec. Propagate the NUW and NSW flags if both the
2507       // outer mul and the inner addrec are guaranteed to have no overflow.
2508       //
2509       // No self-wrap cannot be guaranteed after changing the step size, but
2510       // will be inferred if either NUW or NSW is true.
2511       Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW));
2512       const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags);
2513 
2514       // If all of the other operands were loop invariant, we are done.
2515       if (Ops.size() == 1) return NewRec;
2516 
2517       // Otherwise, multiply the folded AddRec by the non-invariant parts.
2518       for (unsigned i = 0;; ++i)
2519         if (Ops[i] == AddRec) {
2520           Ops[i] = NewRec;
2521           break;
2522         }
2523       return getMulExpr(Ops);
2524     }
2525 
2526     // Okay, if there weren't any loop invariants to be folded, check to see if
2527     // there are multiple AddRec's with the same loop induction variable being
2528     // multiplied together.  If so, we can fold them.
2529 
2530     // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
2531     // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
2532     //       choose(x, 2x-y)*choose(2x-y, x-z)*A_{y-z}*B_z
2533     //   ]]],+,...up to x=2n}.
2534     // Note that the arguments to choose() are always integers with values
2535     // known at compile time, never SCEV objects.
2536     //
2537     // The implementation avoids pointless extra computations when the two
2538     // addrecs are of different lengths (mathematically, it's equivalent to
2539     // an infinite stream of zeros on the right).
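    // In the affine case this reduces to the familiar identity (a sketch):
    //   {A,+,B}<L> * {C,+,D}<L> == {A*C,+,A*D+B*C+B*D,+,2*B*D}<L>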
2540     bool OpsModified = false;
2541     for (unsigned OtherIdx = Idx+1;
2542          OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2543          ++OtherIdx) {
2544       const SCEVAddRecExpr *OtherAddRec =
2545         dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2546       if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
2547         continue;
2548 
2549       bool Overflow = false;
2550       Type *Ty = AddRec->getType();
2551       bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
2552       SmallVector<const SCEV*, 7> AddRecOps;
2553       for (int x = 0, xe = AddRec->getNumOperands() +
2554              OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
2555         const SCEV *Term = getConstant(Ty, 0);
2556         for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
2557           uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
2558           for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
2559                  ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
2560                z < ze && !Overflow; ++z) {
2561             uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
2562             uint64_t Coeff;
2563             if (LargerThan64Bits)
2564               Coeff = umul_ov(Coeff1, Coeff2, Overflow);
2565             else
2566               Coeff = Coeff1*Coeff2;
2567             const SCEV *CoeffTerm = getConstant(Ty, Coeff);
2568             const SCEV *Term1 = AddRec->getOperand(y-z);
2569             const SCEV *Term2 = OtherAddRec->getOperand(z);
2570             Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2));
2571           }
2572         }
2573         AddRecOps.push_back(Term);
2574       }
2575       if (!Overflow) {
2576         const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
2577                                               SCEV::FlagAnyWrap);
2578         if (Ops.size() == 2) return NewAddRec;
2579         Ops[Idx] = NewAddRec;
2580         Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2581         OpsModified = true;
2582         AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
2583         if (!AddRec)
2584           break;
2585       }
2586     }
2587     if (OpsModified)
2588       return getMulExpr(Ops);
2589 
2590     // Otherwise couldn't fold anything into this recurrence.  Move onto the
2591     // next one.
2592   }
2593 
2594   // Okay, it looks like we really DO need a mul expr.  Check to see if we
2595   // already have one, otherwise create a new one.
2596   FoldingSetNodeID ID;
2597   ID.AddInteger(scMulExpr);
2598   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2599     ID.AddPointer(Ops[i]);
2600   void *IP = nullptr;
2601   SCEVMulExpr *S =
2602     static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2603   if (!S) {
2604     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2605     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2606     S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2607                                         O, Ops.size());
2608     UniqueSCEVs.InsertNode(S, IP);
2609   }
2610   S->setNoWrapFlags(Flags);
2611   return S;
2612 }
2613 
2614 /// getUDivExpr - Get a canonical unsigned division expression, or something
2615 /// simpler if possible.
2616 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
2617                                          const SCEV *RHS) {
2618   assert(getEffectiveSCEVType(LHS->getType()) ==
2619          getEffectiveSCEVType(RHS->getType()) &&
2620          "SCEVUDivExpr operand types don't match!");
2621 
2622   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
2623     if (RHSC->getValue()->equalsInt(1))
2624       return LHS;                               // X udiv 1 --> X
2625     // If the denominator is zero, the result of the udiv is undefined. Don't
2626     // try to analyze it, because the resolution chosen here may differ from
2627     // the resolution chosen in other parts of the compiler.
2628     if (!RHSC->getValue()->isZero()) {
2629       // Determine whether the division can be folded into the operands
2630       // of LHS.
2631       // TODO: Generalize this to non-constants by using known-bits information.
2632       Type *Ty = LHS->getType();
2633       unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
2634       unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
2635       // For non-power-of-two values, effectively round the value up to the
2636       // nearest power of two.
2637       if (!RHSC->getValue()->getValue().isPowerOf2())
2638         ++MaxShiftAmt;
2639       IntegerType *ExtTy =
2640         IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
2641       if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
2642         if (const SCEVConstant *Step =
2643             dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
2644           // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
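               // For example, {8,+,4}<L> /u 2 is 8, 12, 16, ... divided by 2,
               // i.e. 4, 6, 8, ..., which is exactly {4,+,2}<L>.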
2645           const APInt &StepInt = Step->getValue()->getValue();
2646           const APInt &DivInt = RHSC->getValue()->getValue();
2647           if (!StepInt.urem(DivInt) &&
2648               getZeroExtendExpr(AR, ExtTy) ==
2649               getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2650                             getZeroExtendExpr(Step, ExtTy),
2651                             AR->getLoop(), SCEV::FlagAnyWrap)) {
2652             SmallVector<const SCEV *, 4> Operands;
2653             for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
2654               Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
2655             return getAddRecExpr(Operands, AR->getLoop(),
2656                                  SCEV::FlagNW);
2657           }
2658           // Get a canonical UDivExpr for a recurrence:
2659           // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
2660           // We can currently only fold X%N if X is constant.
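               //
               // For example (wrapping permitting): {3,+,2}<L> /u 4 becomes
               // {2,+,2}<L> /u 4, since 3 - (3 % 2) == 2 and 4 % 2 == 0;
               // both sides are the sequence 0, 1, 1, 2, ...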
2661           const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
2662           if (StartC && !DivInt.urem(StepInt) &&
2663               getZeroExtendExpr(AR, ExtTy) ==
2664               getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2665                             getZeroExtendExpr(Step, ExtTy),
2666                             AR->getLoop(), SCEV::FlagAnyWrap)) {
2667             const APInt &StartInt = StartC->getValue()->getValue();
2668             const APInt &StartRem = StartInt.urem(StepInt);
2669             if (StartRem != 0)
2670               LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step,
2671                                   AR->getLoop(), SCEV::FlagNW);
2672           }
2673         }
2674       // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
2675       if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
2676         SmallVector<const SCEV *, 4> Operands;
2677         for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
2678           Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
2679         if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
2680           // Find an operand that's safely divisible.
2681           for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
2682             const SCEV *Op = M->getOperand(i);
2683             const SCEV *Div = getUDivExpr(Op, RHSC);
2684             if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
2685               Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
2686                                                       M->op_end());
2687               Operands[i] = Div;
2688               return getMulExpr(Operands);
2689             }
2690           }
2691       }
2692       // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
2693       if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
2694         SmallVector<const SCEV *, 4> Operands;
2695         for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
2696           Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
2697         if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
2698           Operands.clear();
2699           for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
2700             const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
2701             if (isa<SCEVUDivExpr>(Op) ||
2702                 getMulExpr(Op, RHS) != A->getOperand(i))
2703               break;
2704             Operands.push_back(Op);
2705           }
2706           if (Operands.size() == A->getNumOperands())
2707             return getAddExpr(Operands);
2708         }
2709       }
2710 
2711       // Fold if both operands are constant.
2712       if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
2713         Constant *LHSCV = LHSC->getValue();
2714         Constant *RHSCV = RHSC->getValue();
2715         return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
2716                                                                    RHSCV)));
2717       }
2718     }
2719   }
2720 
2721   FoldingSetNodeID ID;
2722   ID.AddInteger(scUDivExpr);
2723   ID.AddPointer(LHS);
2724   ID.AddPointer(RHS);
2725   void *IP = nullptr;
2726   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2727   SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
2728                                              LHS, RHS);
2729   UniqueSCEVs.InsertNode(S, IP);
2730   return S;
2731 }
2732 
2733 static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
2734   APInt A = C1->getValue()->getValue().abs();
2735   APInt B = C2->getValue()->getValue().abs();
2736   uint32_t ABW = A.getBitWidth();
2737   uint32_t BBW = B.getBitWidth();
2738 
2739   if (ABW > BBW)
2740     B = B.zext(ABW);
2741   else if (ABW < BBW)
2742     A = A.zext(BBW);
2743 
2744   return APIntOps::GreatestCommonDivisor(A, B);
2745 }
2746 
2747 /// getUDivExactExpr - Get a canonical unsigned division expression, or
2748 /// something simpler if possible. There is no representation for an exact udiv
2749 /// in SCEV IR, but we can attempt to remove factors from the LHS and RHS.
2750 /// We can't do this when it's not exact because the udiv may be clearing bits.
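     ///
     /// For example, assuming the division really is exact: (4 * %a) /u 2
     /// folds to 2 * %a, and (6 * %a) /u 4 folds to (3 * %a) /u 2 after both
     /// sides are divided by their common factor gcd(6, 4) == 2 (here %a is
     /// just some arbitrary SCEVUnknown).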
2751 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
2752                                               const SCEV *RHS) {
2753   // TODO: we could try to find factors in all sorts of things, but for now we
2754   // just deal with u/exact (multiply, constant). See SCEVDivision towards the
2755   // end of this file for inspiration.
2756 
2757   const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
2758   if (!Mul)
2759     return getUDivExpr(LHS, RHS);
2760 
2761   if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
2762     // If the mulexpr multiplies by a constant, then that constant must be the
2763     // first element of the mulexpr.
2764     if (const SCEVConstant *LHSCst =
2765             dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
2766       if (LHSCst == RHSCst) {
2767         SmallVector<const SCEV *, 2> Operands;
2768         Operands.append(Mul->op_begin() + 1, Mul->op_end());
2769         return getMulExpr(Operands);
2770       }
2771 
2772       // We can't just assume that LHSCst divides RHSCst cleanly; it could be
2773       // that there's a factor provided by one of the other terms. We need to
2774       // check.
2775       APInt Factor = gcd(LHSCst, RHSCst);
2776       if (!Factor.isIntN(1)) {
2777         LHSCst = cast<SCEVConstant>(
2778             getConstant(LHSCst->getValue()->getValue().udiv(Factor)));
2779         RHSCst = cast<SCEVConstant>(
2780             getConstant(RHSCst->getValue()->getValue().udiv(Factor)));
2781         SmallVector<const SCEV *, 2> Operands;
2782         Operands.push_back(LHSCst);
2783         Operands.append(Mul->op_begin() + 1, Mul->op_end());
2784         LHS = getMulExpr(Operands);
2785         RHS = RHSCst;
2786         Mul = dyn_cast<SCEVMulExpr>(LHS);
2787         if (!Mul)
2788           return getUDivExactExpr(LHS, RHS);
2789       }
2790     }
2791   }
2792 
2793   for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
2794     if (Mul->getOperand(i) == RHS) {
2795       SmallVector<const SCEV *, 2> Operands;
2796       Operands.append(Mul->op_begin(), Mul->op_begin() + i);
2797       Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
2798       return getMulExpr(Operands);
2799     }
2800   }
2801 
2802   return getUDivExpr(LHS, RHS);
2803 }
2804 
2805 /// getAddRecExpr - Get an add recurrence expression for the specified loop.
2806 /// Simplify the expression as much as possible.
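     ///
     /// If Step is itself an addrec for the same loop, the recurrence is
     /// flattened: e.g. Start = X and Step = {Y,+,Z}<L> yield {X,+,Y,+,Z}<L>.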
2807 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
2808                                            const Loop *L,
2809                                            SCEV::NoWrapFlags Flags) {
2810   SmallVector<const SCEV *, 4> Operands;
2811   Operands.push_back(Start);
2812   if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
2813     if (StepChrec->getLoop() == L) {
2814       Operands.append(StepChrec->op_begin(), StepChrec->op_end());
2815       return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
2816     }
2817 
2818   Operands.push_back(Step);
2819   return getAddRecExpr(Operands, L, Flags);
2820 }
2821 
2822 /// getAddRecExpr - Get an add recurrence expression for the specified loop.
2823 /// Simplify the expression as much as possible.
2824 const SCEV *
2825 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
2826                                const Loop *L, SCEV::NoWrapFlags Flags) {
2827   if (Operands.size() == 1) return Operands[0];
2828 #ifndef NDEBUG
2829   Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
2830   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
2831     assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
2832            "SCEVAddRecExpr operand types don't match!");
2833   for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2834     assert(isLoopInvariant(Operands[i], L) &&
2835            "SCEVAddRecExpr operand is not loop-invariant!");
2836 #endif
2837 
2838   if (Operands.back()->isZero()) {
2839     Operands.pop_back();
2840     return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0}  -->  X
2841   }
2842 
2843   // It's tempting to want to call getMaxBackedgeTakenCount here and
2844   // use that information to infer NUW and NSW flags. However, computing a
2845   // BE count requires calling getAddRecExpr, so we may not yet have a
2846   // meaningful BE count at this point (and if we don't, we'd be stuck
2847   // with a SCEVCouldNotCompute as the cached BE count).
2848 
2849   Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
2850 
2851   // Canonicalize nested AddRecs by nesting them in order of loop depth.
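       //
       // For example, when the swap is legal, {{S,+,X}<L_inner>,+,Y}<L_outer>
       // becomes {{S,+,Y}<L_outer>,+,X}<L_inner> (loop names illustrative), so
       // the outermost recurrence ends up in the start position.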
2852   if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
2853     const Loop *NestedLoop = NestedAR->getLoop();
2854     if (L->contains(NestedLoop) ?
2855         (L->getLoopDepth() < NestedLoop->getLoopDepth()) :
2856         (!NestedLoop->contains(L) &&
2857          DT->dominates(L->getHeader(), NestedLoop->getHeader()))) {
2858       SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
2859                                                   NestedAR->op_end());
2860       Operands[0] = NestedAR->getStart();
2861       // AddRecs require their operands be loop-invariant with respect to their
2862       // loops. Don't perform this transformation if it would break this
2863       // requirement.
2864       bool AllInvariant = true;
2865       for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2866         if (!isLoopInvariant(Operands[i], L)) {
2867           AllInvariant = false;
2868           break;
2869         }
2870       if (AllInvariant) {
2871         // Create a recurrence for the outer loop with the same step size.
2872         //
2873         // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
2874         // inner recurrence has the same property.
2875         SCEV::NoWrapFlags OuterFlags =
2876           maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
2877 
2878         NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
2879         AllInvariant = true;
2880         for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i)
2881           if (!isLoopInvariant(NestedOperands[i], NestedLoop)) {
2882             AllInvariant = false;
2883             break;
2884           }
2885         if (AllInvariant) {
2886           // Ok, both add recurrences are valid after the transformation.
2887           //
2888           // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
2889           // the outer recurrence has the same property.
2890           SCEV::NoWrapFlags InnerFlags =
2891             maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
2892           return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
2893         }
2894       }
2895       // Reset Operands to its original state.
2896       Operands[0] = NestedAR;
2897     }
2898   }
2899 
2900   // Okay, it looks like we really DO need an addrec expr.  Check to see if we
2901   // already have one, otherwise create a new one.
2902   FoldingSetNodeID ID;
2903   ID.AddInteger(scAddRecExpr);
2904   for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2905     ID.AddPointer(Operands[i]);
2906   ID.AddPointer(L);
2907   void *IP = nullptr;
2908   SCEVAddRecExpr *S =
2909     static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2910   if (!S) {
2911     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size());
2912     std::uninitialized_copy(Operands.begin(), Operands.end(), O);
2913     S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
2914                                            O, Operands.size(), L);
2915     UniqueSCEVs.InsertNode(S, IP);
2916   }
2917   S->setNoWrapFlags(Flags);
2918   return S;
2919 }
2920 
2921 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
2922                                          const SCEV *RHS) {
2923   SmallVector<const SCEV *, 2> Ops;
2924   Ops.push_back(LHS);
2925   Ops.push_back(RHS);
2926   return getSMaxExpr(Ops);
2927 }
2928 
2929 const SCEV *
2930 ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
2931   assert(!Ops.empty() && "Cannot get empty smax!");
2932   if (Ops.size() == 1) return Ops[0];
2933 #ifndef NDEBUG
2934   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2935   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2936     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2937            "SCEVSMaxExpr operand types don't match!");
2938 #endif
2939 
2940   // Sort by complexity; this groups all similar expression types together.
2941   GroupByComplexity(Ops, LI);
2942 
2943   // If there are any constants, fold them together.
2944   unsigned Idx = 0;
2945   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2946     ++Idx;
2947     assert(Idx < Ops.size());
2948     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2949       // We found two constants, fold them together!
2950       ConstantInt *Fold = ConstantInt::get(getContext(),
2951                               APIntOps::smax(LHSC->getValue()->getValue(),
2952                                              RHSC->getValue()->getValue()));
2953       Ops[0] = getConstant(Fold);
2954       Ops.erase(Ops.begin()+1);  // Erase the folded element
2955       if (Ops.size() == 1) return Ops[0];
2956       LHSC = cast<SCEVConstant>(Ops[0]);
2957     }
2958 
2959     // If we are left with a constant minimum-int, strip it off.
2960     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
2961       Ops.erase(Ops.begin());
2962       --Idx;
2963     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) {
2964       // If we have an smax with a constant maximum-int, it will always be
2965       // maximum-int.
2966       return Ops[0];
2967     }
2968 
2969     if (Ops.size() == 1) return Ops[0];
2970   }
2971 
2972   // Find the first SMax
2973   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
2974     ++Idx;
2975 
2976   // Check to see if one of the operands is an SMax. If so, expand its operands
2977   // onto our operand list, and recurse to simplify.
2978   if (Idx < Ops.size()) {
2979     bool DeletedSMax = false;
2980     while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
2981       Ops.erase(Ops.begin()+Idx);
2982       Ops.append(SMax->op_begin(), SMax->op_end());
2983       DeletedSMax = true;
2984     }
2985 
2986     if (DeletedSMax)
2987       return getSMaxExpr(Ops);
2988   }
2989 
2990   // Okay, check to see if the same value occurs in the operand list twice.  If
2991   // so, delete one.  Since we sorted the list, these values are required to
2992   // be adjacent.
2993   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
2994     //  X smax Y smax Y  -->  X smax Y
2995     //  X smax Y         -->  X, if X is always greater than Y
2996     if (Ops[i] == Ops[i+1] ||
2997         isKnownPredicate(ICmpInst::ICMP_SGE, Ops[i], Ops[i+1])) {
2998       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
2999       --i; --e;
3000     } else if (isKnownPredicate(ICmpInst::ICMP_SLE, Ops[i], Ops[i+1])) {
3001       Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
3002       --i; --e;
3003     }
3004 
3005   if (Ops.size() == 1) return Ops[0];
3006 
3007   assert(!Ops.empty() && "Reduced smax down to nothing!");
3008 
3009   // Okay, it looks like we really DO need an smax expr.  Check to see if we
3010   // already have one, otherwise create a new one.
3011   FoldingSetNodeID ID;
3012   ID.AddInteger(scSMaxExpr);
3013   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3014     ID.AddPointer(Ops[i]);
3015   void *IP = nullptr;
3016   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3017   const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3018   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3019   SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator),
3020                                              O, Ops.size());
3021   UniqueSCEVs.InsertNode(S, IP);
3022   return S;
3023 }
3024 
3025 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
3026                                          const SCEV *RHS) {
3027   SmallVector<const SCEV *, 2> Ops;
3028   Ops.push_back(LHS);
3029   Ops.push_back(RHS);
3030   return getUMaxExpr(Ops);
3031 }
3032 
3033 const SCEV *
3034 ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
3035   assert(!Ops.empty() && "Cannot get empty umax!");
3036   if (Ops.size() == 1) return Ops[0];
3037 #ifndef NDEBUG
3038   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3039   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3040     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3041            "SCEVUMaxExpr operand types don't match!");
3042 #endif
3043 
3044   // Sort by complexity; this groups all similar expression types together.
3045   GroupByComplexity(Ops, LI);
3046 
3047   // If there are any constants, fold them together.
3048   unsigned Idx = 0;
3049   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3050     ++Idx;
3051     assert(Idx < Ops.size());
3052     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3053       // We found two constants, fold them together!
3054       ConstantInt *Fold = ConstantInt::get(getContext(),
3055                               APIntOps::umax(LHSC->getValue()->getValue(),
3056                                              RHSC->getValue()->getValue()));
3057       Ops[0] = getConstant(Fold);
3058       Ops.erase(Ops.begin()+1);  // Erase the folded element
3059       if (Ops.size() == 1) return Ops[0];
3060       LHSC = cast<SCEVConstant>(Ops[0]);
3061     }
3062 
3063     // If we are left with a constant minimum-int, strip it off.
3064     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
3065       Ops.erase(Ops.begin());
3066       --Idx;
3067     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) {
3068       // If we have an umax with a constant maximum-int, it will always be
3069       // maximum-int.
3070       return Ops[0];
3071     }
3072 
3073     if (Ops.size() == 1) return Ops[0];
3074   }
3075 
3076   // Find the first UMax
3077   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
3078     ++Idx;
3079 
3080   // Check to see if one of the operands is a UMax. If so, expand its operands
3081   // onto our operand list, and recurse to simplify.
3082   if (Idx < Ops.size()) {
3083     bool DeletedUMax = false;
3084     while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
3085       Ops.erase(Ops.begin()+Idx);
3086       Ops.append(UMax->op_begin(), UMax->op_end());
3087       DeletedUMax = true;
3088     }
3089 
3090     if (DeletedUMax)
3091       return getUMaxExpr(Ops);
3092   }
3093 
3094   // Okay, check to see if the same value occurs in the operand list twice.  If
3095   // so, delete one.  Since we sorted the list, these values are required to
3096   // be adjacent.
3097   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
3098     //  X umax Y umax Y  -->  X umax Y
3099     //  X umax Y         -->  X, if X is always greater than Y
3100     if (Ops[i] == Ops[i+1] ||
3101         isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) {
3102       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
3103       --i; --e;
3104     } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) {
3105       Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
3106       --i; --e;
3107     }
3108 
3109   if (Ops.size() == 1) return Ops[0];
3110 
3111   assert(!Ops.empty() && "Reduced umax down to nothing!");
3112 
3113   // Okay, it looks like we really DO need a umax expr.  Check to see if we
3114   // already have one, otherwise create a new one.
3115   FoldingSetNodeID ID;
3116   ID.AddInteger(scUMaxExpr);
3117   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3118     ID.AddPointer(Ops[i]);
3119   void *IP = nullptr;
3120   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3121   const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3122   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3123   SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator),
3124                                              O, Ops.size());
3125   UniqueSCEVs.InsertNode(S, IP);
3126   return S;
3127 }
3128 
3129 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
3130                                          const SCEV *RHS) {
3131   // ~smax(~x, ~y) == smin(x, y).
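       // For example, for i8 values x = 3, y = 5: ~3 == -4, ~5 == -6,
       // smax(-4, -6) == -4, and ~(-4) == 3 == smin(3, 5).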
3132   return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
3133 }
3134 
3135 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
3136                                          const SCEV *RHS) {
3137   // ~umax(~x, ~y) == umin(x, y)
3138   return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
3139 }
3140 
3141 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
3142   // We can bypass creating a target-independent
3143   // constant expression and then folding it back into a ConstantInt.
3144   // This is just a compile-time optimization.
3145   return getConstant(IntTy,
3146                      F->getParent()->getDataLayout().getTypeAllocSize(AllocTy));
3147 }
3148 
3149 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
3150                                              StructType *STy,
3151                                              unsigned FieldNo) {
3152   // We can bypass creating a target-independent
3153   // constant expression and then folding it back into a ConstantInt.
3154   // This is just a compile-time optimization.
3155   return getConstant(
3156       IntTy,
3157       F->getParent()->getDataLayout().getStructLayout(STy)->getElementOffset(
3158           FieldNo));
3159 }
3160 
3161 const SCEV *ScalarEvolution::getUnknown(Value *V) {
3162   // Don't attempt to do anything other than create a SCEVUnknown object
3163   // here.  createSCEV only calls getUnknown after checking for all other
3164   // interesting possibilities, and any other code that calls getUnknown
3165   // is doing so in order to hide a value from SCEV canonicalization.
3166 
3167   FoldingSetNodeID ID;
3168   ID.AddInteger(scUnknown);
3169   ID.AddPointer(V);
3170   void *IP = nullptr;
3171   if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
3172     assert(cast<SCEVUnknown>(S)->getValue() == V &&
3173            "Stale SCEVUnknown in uniquing map!");
3174     return S;
3175   }
3176   SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
3177                                             FirstUnknown);
3178   FirstUnknown = cast<SCEVUnknown>(S);
3179   UniqueSCEVs.InsertNode(S, IP);
3180   return S;
3181 }
3182 
3183 //===----------------------------------------------------------------------===//
3184 //            Basic SCEV Analysis and PHI Idiom Recognition Code
3185 //
3186 
3187 /// isSCEVable - Test if values of the given type are analyzable within
3188 /// the SCEV framework. This primarily includes integer types, and it
3189 /// can optionally include pointer types if the ScalarEvolution class
3190 /// has access to target-specific information.
3191 bool ScalarEvolution::isSCEVable(Type *Ty) const {
3192   // Integers and pointers are always SCEVable.
3193   return Ty->isIntegerTy() || Ty->isPointerTy();
3194 }
3195 
3196 /// getTypeSizeInBits - Return the size in bits of the specified type,
3197 /// for which isSCEVable must return true.
3198 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
3199   assert(isSCEVable(Ty) && "Type is not SCEVable!");
3200   return F->getParent()->getDataLayout().getTypeSizeInBits(Ty);
3201 }
3202 
3203 /// getEffectiveSCEVType - Return a type with the same bitwidth as
3204 /// the given type and which represents how SCEV will treat the given
3205 /// type, for which isSCEVable must return true. For pointer types,
3206 /// this is the pointer-sized integer type.
3207 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
3208   assert(isSCEVable(Ty) && "Type is not SCEVable!");
3209 
3210   if (Ty->isIntegerTy()) {
3211     return Ty;
3212   }
3213 
3214   // The only other supported type is pointer.
3215   assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
3216   return F->getParent()->getDataLayout().getIntPtrType(Ty);
3217 }
3218 
3219 const SCEV *ScalarEvolution::getCouldNotCompute() {
3220   return &CouldNotCompute;
3221 }
3222 
3223 namespace {
3224   // Helper class working with SCEVTraversal to figure out if a SCEV contains
3225   // a SCEVUnknown with a null value-pointer. FindInvalidSCEVUnknown::FindOne
3226   // is set iff we find such a SCEVUnknown.
3227   //
3228   struct FindInvalidSCEVUnknown {
3229     bool FindOne;
3230     FindInvalidSCEVUnknown() { FindOne = false; }
3231     bool follow(const SCEV *S) {
3232       switch (static_cast<SCEVTypes>(S->getSCEVType())) {
3233       case scConstant:
3234         return false;
3235       case scUnknown:
3236         if (!cast<SCEVUnknown>(S)->getValue())
3237           FindOne = true;
3238         return false;
3239       default:
3240         return true;
3241       }
3242     }
3243     bool isDone() const { return FindOne; }
3244   };
3245 }
3246 
3247 bool ScalarEvolution::checkValidity(const SCEV *S) const {
3248   FindInvalidSCEVUnknown F;
3249   SCEVTraversal<FindInvalidSCEVUnknown> ST(F);
3250   ST.visitAll(S);
3251 
3252   return !F.FindOne;
3253 }
3254 
3255 /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
3256 /// expression and create a new one.
3257 const SCEV *ScalarEvolution::getSCEV(Value *V) {
3258   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
3259 
3260   ValueExprMapType::iterator I = ValueExprMap.find_as(V);
3261   if (I != ValueExprMap.end()) {
3262     const SCEV *S = I->second;
3263     if (checkValidity(S))
3264       return S;
3265     else
3266       ValueExprMap.erase(I);
3267   }
3268   const SCEV *S = createSCEV(V);
3269 
3270   // The process of creating a SCEV for V may have caused other SCEVs
3271   // to have been created, so it's necessary to insert the new entry
3272   // from scratch, rather than trying to remember the insert position
3273   // above.
3274   ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S));
3275   return S;
3276 }
3277 
3278 /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
3279 ///
3280 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) {
3281   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
3282     return getConstant(
3283                cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
3284 
3285   Type *Ty = V->getType();
3286   Ty = getEffectiveSCEVType(Ty);
3287   return getMulExpr(V,
3288                   getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))));
3289 }
3290 
3291 /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
3292 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
3293   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
3294     return getConstant(
3295                 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
3296 
3297   Type *Ty = V->getType();
3298   Ty = getEffectiveSCEVType(Ty);
3299   const SCEV *AllOnes =
3300                    getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)));
3301   return getMinusSCEV(AllOnes, V);
3302 }
3303 
3304 /// getMinusSCEV - Return LHS-RHS.  Minus is represented in SCEV as A+B*-1.
3305 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
3306                                           SCEV::NoWrapFlags Flags) {
3307   assert(!maskFlags(Flags, SCEV::FlagNUW) && "subtraction does not have NUW");
3308 
3309   // Fast path: X - X --> 0.
3310   if (LHS == RHS)
3311     return getConstant(LHS->getType(), 0);
3312 
3313   // X - Y --> X + -Y.
3314   // X -(nsw || nuw) Y --> X + -Y.
3315   return getAddExpr(LHS, getNegativeSCEV(RHS));
3316 }
3317 
3318 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
3319 /// input value to the specified type.  If the type must be extended, it is zero
3320 /// extended.
3321 const SCEV *
3322 ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
3323   Type *SrcTy = V->getType();
3324   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3325          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3326          "Cannot truncate or zero extend with non-integer arguments!");
3327   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3328     return V;  // No conversion
3329   if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
3330     return getTruncateExpr(V, Ty);
3331   return getZeroExtendExpr(V, Ty);
3332 }
3333 
3334 /// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
3335 /// input value to the specified type.  If the type must be extended, it is sign
3336 /// extended.
3337 const SCEV *
3338 ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
3339                                          Type *Ty) {
3340   Type *SrcTy = V->getType();
3341   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3342          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3343          "Cannot truncate or zero extend with non-integer arguments!");
3344   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3345     return V;  // No conversion
3346   if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
3347     return getTruncateExpr(V, Ty);
3348   return getSignExtendExpr(V, Ty);
3349 }
3350 
3351 /// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
3352 /// input value to the specified type.  If the type must be extended, it is zero
3353 /// extended.  The conversion must not be narrowing.
3354 const SCEV *
3355 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
3356   Type *SrcTy = V->getType();
3357   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3358          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3359          "Cannot noop or zero extend with non-integer arguments!");
3360   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
3361          "getNoopOrZeroExtend cannot truncate!");
3362   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3363     return V;  // No conversion
3364   return getZeroExtendExpr(V, Ty);
3365 }
3366 
3367 /// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
3368 /// input value to the specified type.  If the type must be extended, it is sign
3369 /// extended.  The conversion must not be narrowing.
3370 const SCEV *
3371 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
3372   Type *SrcTy = V->getType();
3373   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3374          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3375          "Cannot noop or sign extend with non-integer arguments!");
3376   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
3377          "getNoopOrSignExtend cannot truncate!");
3378   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3379     return V;  // No conversion
3380   return getSignExtendExpr(V, Ty);
3381 }
3382 
3383 /// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
3384 /// the input value to the specified type. If the type must be extended,
3385 /// it is extended with unspecified bits. The conversion must not be
3386 /// narrowing.
3387 const SCEV *
3388 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
3389   Type *SrcTy = V->getType();
3390   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3391          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3392          "Cannot noop or any extend with non-integer arguments!");
3393   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
3394          "getNoopOrAnyExtend cannot truncate!");
3395   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3396     return V;  // No conversion
3397   return getAnyExtendExpr(V, Ty);
3398 }
3399 
3400 /// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
3401 /// input value to the specified type.  The conversion must not be widening.
3402 const SCEV *
3403 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
3404   Type *SrcTy = V->getType();
3405   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3406          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3407          "Cannot truncate or noop with non-integer arguments!");
3408   assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
3409          "getTruncateOrNoop cannot extend!");
3410   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3411     return V;  // No conversion
3412   return getTruncateExpr(V, Ty);
3413 }
3414 
3415 /// getUMaxFromMismatchedTypes - Promote the operands to the wider of
3416 /// the types using zero-extension, and then perform a umax operation
3417 /// with them.
3418 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
3419                                                         const SCEV *RHS) {
3420   const SCEV *PromotedLHS = LHS;
3421   const SCEV *PromotedRHS = RHS;
3422 
3423   if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
3424     PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
3425   else
3426     PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
3427 
3428   return getUMaxExpr(PromotedLHS, PromotedRHS);
3429 }
3430 
3431 /// getUMinFromMismatchedTypes - Promote the operands to the wider of
3432 /// the types using zero-extension, and then perform a umin operation
3433 /// with them.
3434 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
3435                                                         const SCEV *RHS) {
3436   const SCEV *PromotedLHS = LHS;
3437   const SCEV *PromotedRHS = RHS;
3438 
3439   if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
3440     PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
3441   else
3442     PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
3443 
3444   return getUMinExpr(PromotedLHS, PromotedRHS);
3445 }
3446 
3447 /// getPointerBase - Transitively follow the chain of pointer-type operands
3448 /// until reaching a SCEV that does not have a single pointer operand. This
3449 /// returns a SCEVUnknown pointer for well-formed pointer-type expressions,
3450 /// but corner cases do exist.
3451 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
3452   // A pointer operand may evaluate to a nonpointer expression, such as null.
3453   if (!V->getType()->isPointerTy())
3454     return V;
3455 
3456   if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
3457     return getPointerBase(Cast->getOperand());
3458   }
3459   else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
3460     const SCEV *PtrOp = nullptr;
3461     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
3462          I != E; ++I) {
3463       if ((*I)->getType()->isPointerTy()) {
3464         // Cannot find the base of an expression with multiple pointer operands.
3465         if (PtrOp)
3466           return V;
3467         PtrOp = *I;
3468       }
3469     }
3470     if (!PtrOp)
3471       return V;
3472     return getPointerBase(PtrOp);
3473   }
3474   return V;
3475 }
3476 
3477 /// PushDefUseChildren - Push users of the given Instruction
3478 /// onto the given Worklist.
3479 static void
3480 PushDefUseChildren(Instruction *I,
3481                    SmallVectorImpl<Instruction *> &Worklist) {
3482   // Push the def-use children onto the Worklist stack.
3483   for (User *U : I->users())
3484     Worklist.push_back(cast<Instruction>(U));
3485 }
3486 
3487 /// ForgetSymbolicName - This looks up computed SCEV values for all
3488 /// instructions that depend on the given instruction and removes them from
3489 /// ValueExprMap if they reference SymName. This is used during PHI
3490 /// resolution.
3491 void
3492 ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
3493   SmallVector<Instruction *, 16> Worklist;
3494   PushDefUseChildren(PN, Worklist);
3495 
3496   SmallPtrSet<Instruction *, 8> Visited;
3497   Visited.insert(PN);
3498   while (!Worklist.empty()) {
3499     Instruction *I = Worklist.pop_back_val();
3500     if (!Visited.insert(I).second)
3501       continue;
3502 
3503     ValueExprMapType::iterator It =
3504       ValueExprMap.find_as(static_cast<Value *>(I));
3505     if (It != ValueExprMap.end()) {
3506       const SCEV *Old = It->second;
3507 
3508       // Short-circuit the def-use traversal if the symbolic name
3509       // ceases to appear in expressions.
3510       if (Old != SymName && !hasOperand(Old, SymName))
3511         continue;
3512 
3513       // SCEVUnknown for a PHI either means that it has an unrecognized
3514       // structure, it's a PHI that's in the process of being computed
3515       // by createNodeForPHI, or it's a single-value PHI. In the first case,
3516       // additional loop trip count information isn't going to change anything.
3517       // In the second case, createNodeForPHI will perform the necessary
3518       // updates on its own when it gets to that point. In the third, we do
3519       // want to forget the SCEVUnknown.
3520       if (!isa<PHINode>(I) ||
3521           !isa<SCEVUnknown>(Old) ||
3522           (I != PN && Old == SymName)) {
3523         forgetMemoizedResults(Old);
3524         ValueExprMap.erase(It);
3525       }
3526     }
3527 
3528     PushDefUseChildren(I, Worklist);
3529   }
3530 }
3531 
3532 /// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
3533 /// a loop header, making it a potential recurrence, or it doesn't.
3534 ///
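     /// For example, given the canonical induction pattern
     ///   %iv = phi i32 [ 0, %preheader ], [ %iv.next, %latch ]
     ///   %iv.next = add nuw nsw i32 %iv, 1
     /// (names illustrative), the phi is recognized as the recurrence
     /// {0,+,1}<nuw><nsw> for the enclosing loop.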
3535 const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
3536   if (const Loop *L = LI->getLoopFor(PN->getParent()))
3537     if (L->getHeader() == PN->getParent()) {
3538       // The loop may have multiple entrances or multiple exits; we can analyze
3539       // this phi as an addrec if it has a unique entry value and a unique
3540       // backedge value.
3541       Value *BEValueV = nullptr, *StartValueV = nullptr;
3542       for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
3543         Value *V = PN->getIncomingValue(i);
3544         if (L->contains(PN->getIncomingBlock(i))) {
3545           if (!BEValueV) {
3546             BEValueV = V;
3547           } else if (BEValueV != V) {
3548             BEValueV = nullptr;
3549             break;
3550           }
3551         } else if (!StartValueV) {
3552           StartValueV = V;
3553         } else if (StartValueV != V) {
3554           StartValueV = nullptr;
3555           break;
3556         }
3557       }
3558       if (BEValueV && StartValueV) {
3559         // While we are analyzing this PHI node, handle its value symbolically.
3560         const SCEV *SymbolicName = getUnknown(PN);
3561         assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
3562                "PHI node already processed?");
3563         ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
3564 
3565         // Using this symbolic name for the PHI, analyze the value coming around
3566         // the back-edge.
3567         const SCEV *BEValue = getSCEV(BEValueV);
3568 
3569         // NOTE: If BEValue is loop invariant, we know that the PHI node just
3570         // has a special value for the first iteration of the loop.
3571 
3572         // If the value coming around the backedge is an add with the symbolic
3573         // value we just inserted, then we found a simple induction variable!
3574         if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
3575           // If there is a single occurrence of the symbolic value, replace it
3576           // with a recurrence.
3577           unsigned FoundIndex = Add->getNumOperands();
3578           for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
3579             if (Add->getOperand(i) == SymbolicName)
3580               if (FoundIndex == e) {
3581                 FoundIndex = i;
3582                 break;
3583               }
3584 
3585           if (FoundIndex != Add->getNumOperands()) {
3586             // Create an add with everything but the specified operand.
3587             SmallVector<const SCEV *, 8> Ops;
3588             for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
3589               if (i != FoundIndex)
3590                 Ops.push_back(Add->getOperand(i));
3591             const SCEV *Accum = getAddExpr(Ops);
3592 
3593             // This is not a valid addrec if the step amount is varying each
3594             // loop iteration, but is not itself an addrec in this loop.
3595             if (isLoopInvariant(Accum, L) ||
3596                 (isa<SCEVAddRecExpr>(Accum) &&
3597                  cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
3598               SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
3599 
3600               // If the increment doesn't overflow, then neither the addrec nor
3601               // the post-increment will overflow.
3602               if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
3603                 if (OBO->getOperand(0) == PN) {
3604                   if (OBO->hasNoUnsignedWrap())
3605                     Flags = setFlags(Flags, SCEV::FlagNUW);
3606                   if (OBO->hasNoSignedWrap())
3607                     Flags = setFlags(Flags, SCEV::FlagNSW);
3608                 }
3609               } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
3610                 // If the increment is an inbounds GEP, then we know the address
3611                 // space cannot be wrapped around. We cannot make any guarantee
3612                 // about signed or unsigned overflow because pointers are
3613                 // unsigned but we may have a negative index from the base
3614                 // pointer. We can guarantee that no unsigned wrap occurs if the
3615                 // indices form a positive value.
3616                 if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
3617                   Flags = setFlags(Flags, SCEV::FlagNW);
3618 
3619                   const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
3620                   if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
3621                     Flags = setFlags(Flags, SCEV::FlagNUW);
3622                 }
3623 
3624                 // We cannot transfer nuw and nsw flags from subtraction
3625                 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
3626                 // for instance.
3627               }
3628 
3629               const SCEV *StartVal = getSCEV(StartValueV);
3630               const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
3631 
3632               // Since the no-wrap flags are on the increment, they apply to the
3633               // post-incremented value as well.
3634               if (isLoopInvariant(Accum, L))
3635                 (void)getAddRecExpr(getAddExpr(StartVal, Accum),
3636                                     Accum, L, Flags);
3637 
3638               // Okay, for the entire analysis of this edge we assumed the PHI
3639               // to be symbolic.  We now need to go back and purge all of the
3640               // entries for the scalars that use the symbolic expression.
3641               ForgetSymbolicName(PN, SymbolicName);
3642               ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
3643               return PHISCEV;
3644             }
3645           }
3646         } else if (const SCEVAddRecExpr *AddRec =
3647                      dyn_cast<SCEVAddRecExpr>(BEValue)) {
3648           // Otherwise, this could be a loop like this:
3649           //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
3650           // In this case, j = {1,+,1}  and BEValue is j.
3651           // Because the other in-value of i (0) fits the evolution of
3652           // BEValue, i really is an addrec evolution.
3653           if (AddRec->getLoop() == L && AddRec->isAffine()) {
3654             const SCEV *StartVal = getSCEV(StartValueV);
3655 
3656             // If StartVal = j.start - j.stride, we can use StartVal as the
3657             // initial value of the addrec evolution.
3658             if (StartVal == getMinusSCEV(AddRec->getOperand(0),
3659                                          AddRec->getOperand(1))) {
3660               // FIXME: For constant StartVal, we should be able to infer
3661               // no-wrap flags.
3662               const SCEV *PHISCEV =
3663                 getAddRecExpr(StartVal, AddRec->getOperand(1), L,
3664                               SCEV::FlagAnyWrap);
3665 
3666               // Okay, for the entire analysis of this edge we assumed the PHI
3667               // to be symbolic.  We now need to go back and purge all of the
3668               // entries for the scalars that use the symbolic expression.
3669               ForgetSymbolicName(PN, SymbolicName);
3670               ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
3671               return PHISCEV;
3672             }
3673           }
3674         }
3675       }
3676     }
3677 
3678   // If the PHI has a single incoming value, follow that value, unless the
3679   // PHI's incoming blocks are in a different loop, in which case doing so
3680   // risks breaking LCSSA form. Instcombine would normally zap these, but
3681   // it doesn't have DominatorTree information, so it may miss cases.
3682   if (Value *V =
3683           SimplifyInstruction(PN, F->getParent()->getDataLayout(), TLI, DT, AC))
3684     if (LI->replacementPreservesLCSSAForm(PN, V))
3685       return getSCEV(V);
3686 
3687   // If it's not a loop phi, we can't handle it yet.
3688   return getUnknown(PN);
3689 }
3690 
3691 /// createNodeForGEP - Expand GEP instructions into add and multiply
3692 /// operations. This allows them to be analyzed by regular SCEV code.
3693 ///
3694 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
3695   Type *IntPtrTy = getEffectiveSCEVType(GEP->getType());
3696   Value *Base = GEP->getOperand(0);
3697   // Don't attempt to analyze GEPs over unsized objects.
3698   if (!Base->getType()->getPointerElementType()->isSized())
3699     return getUnknown(GEP);
3700 
3701   // Don't blindly transfer the inbounds flag from the GEP instruction to the
3702   // Add expression, because the Instruction may be guarded by control flow
3703   // and the no-overflow bits may not be valid for the expression in any
3704   // context.
3705   SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3706 
3707   const SCEV *TotalOffset = getConstant(IntPtrTy, 0);
3708   gep_type_iterator GTI = gep_type_begin(GEP);
3709   for (GetElementPtrInst::op_iterator I = std::next(GEP->op_begin()),
3710                                       E = GEP->op_end();
3711        I != E; ++I) {
3712     Value *Index = *I;
3713     // Compute the (potentially symbolic) offset in bytes for this index.
3714     if (StructType *STy = dyn_cast<StructType>(*GTI++)) {
3715       // For a struct, add the member offset.
3716       unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
3717       const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo);
3718 
3719       // Add the field offset to the running total offset.
3720       TotalOffset = getAddExpr(TotalOffset, FieldOffset);
3721     } else {
3722       // For an array, add the element offset, explicitly scaled.
3723       const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, *GTI);
3724       const SCEV *IndexS = getSCEV(Index);
3725       // Getelementptr indices are signed.
3726       IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy);
3727 
3728       // Multiply the index by the element size to compute the element offset.
3729       const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize, Wrap);
3730 
3731       // Add the element offset to the running total offset.
3732       TotalOffset = getAddExpr(TotalOffset, LocalOffset);
3733     }
3734   }
3735 
3736   // Get the SCEV for the GEP base.
3737   const SCEV *BaseS = getSCEV(Base);
3738 
3739   // Add the total offset from all the GEP indices to the base.
3740   return getAddExpr(BaseS, TotalOffset, Wrap);
3741 }
3742 
3743 /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
3744 /// guaranteed to end in (at every loop iteration).  It is, at the same time,
3745 /// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
3746 /// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
3747 uint32_t
3748 ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
3749   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3750     return C->getValue()->getValue().countTrailingZeros();
3751 
3752   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
3753     return std::min(GetMinTrailingZeros(T->getOperand()),
3754                     (uint32_t)getTypeSizeInBits(T->getType()));
3755 
3756   if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
3757     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
3758     return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
3759              getTypeSizeInBits(E->getType()) : OpRes;
3760   }
3761 
3762   if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
3763     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
3764     return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
3765              getTypeSizeInBits(E->getType()) : OpRes;
3766   }
3767 
3768   if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
3769     // The result is the min of all operands results.
3770     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
3771     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
3772       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
3773     return MinOpRes;
3774   }
3775 
3776   if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
3777     // The result is the sum of all operands results.
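         // (For example, 4 * 6 == 24 ends in min(2 + 1, BitWidth) == 3 zero
         // bits.)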
3778     uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
3779     uint32_t BitWidth = getTypeSizeInBits(M->getType());
3780     for (unsigned i = 1, e = M->getNumOperands();
3781          SumOpRes != BitWidth && i != e; ++i)
3782       SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
3783                           BitWidth);
3784     return SumOpRes;
3785   }
3786 
3787   if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
3788     // The result is the min of all operands results.
3789     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
3790     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
3791       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
3792     return MinOpRes;
3793   }
3794 
3795   if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
3796     // The result is the min of all operands results.
3797     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
3798     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
3799       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
3800     return MinOpRes;
3801   }
3802 
3803   if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
3804     // The result is the min of all operands results.
3805     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
3806     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
3807       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
3808     return MinOpRes;
3809   }
3810 
3811   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3812     // For a SCEVUnknown, ask ValueTracking.
3813     unsigned BitWidth = getTypeSizeInBits(U->getType());
3814     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
3815     computeKnownBits(U->getValue(), Zeros, Ones,
3816                      F->getParent()->getDataLayout(), 0, AC, nullptr, DT);
3817     return Zeros.countTrailingOnes();
3818   }
3819 
3820   // SCEVUDivExpr
3821   return 0;
3822 }
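
// A quick illustration of how the rules above compose (8-bit values, shown
// in binary): TZ(12) == 2 (0b00001100) and TZ(20) == 2 (0b00010100), so an
// add of the two is guaranteed at least min(2, 2) == 2 trailing zeros (the
// actual sum, 32, happens to have 5 -- the result is a lower bound, not an
// exact count), while a mul is guaranteed 2 + 2 == 4 trailing zeros, as in
// 12 * 20 == 240 == 0b11110000. SCEVUDivExpr falls through to 0 above
// because division can destroy trailing zeros, e.g. 8 /u 3 == 2 leaves
// only one.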
3823 
3824 /// GetRangeFromMetadata - Helper method to assign a range to V from
3825 /// metadata present in the IR.
3826 static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
3827   if (Instruction *I = dyn_cast<Instruction>(V)) {
3828     if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) {
3829       ConstantRange TotalRange(
3830           cast<IntegerType>(I->getType())->getBitWidth(), false);
3831 
3832       unsigned NumRanges = MD->getNumOperands() / 2;
3833       assert(NumRanges >= 1);
3834 
3835       for (unsigned i = 0; i < NumRanges; ++i) {
3836         ConstantInt *Lower =
3837             mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 0));
3838         ConstantInt *Upper =
3839             mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 1));
3840         ConstantRange Range(Lower->getValue(), Upper->getValue());
3841         TotalRange = TotalRange.unionWith(Range);
3842       }
3843 
3844       return TotalRange;
3845     }
3846   }
3847 
3848   return None;
3849 }
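
// For illustration, range metadata is a flat list of [Lo, Hi) pairs; a load
// annotated as, say,
//
//   %v = load i8, i8* %p, !range !0
//   !0 = !{i8 0, i8 2, i8 64, i8 66}
//
// produces the union [0, 2) u [64, 66) here. TotalRange starts out as the
// empty set (the 'false' argument above), so unioning in each pair yields
// exactly the annotated ranges.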
3850 
3851 /// getRange - Determine the range for a particular SCEV.  If SignHint is
3852 /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
3853 /// with a "cleaner" unsigned (resp. signed) representation.
3854 ///
3855 ConstantRange
3856 ScalarEvolution::getRange(const SCEV *S,
3857                           ScalarEvolution::RangeSignHint SignHint) {
3858   DenseMap<const SCEV *, ConstantRange> &Cache =
3859       SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
3860                                                        : SignedRanges;
3861 
3862   // See if we've computed this range already.
3863   DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
3864   if (I != Cache.end())
3865     return I->second;
3866 
3867   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3868     return setRange(C, SignHint, ConstantRange(C->getValue()->getValue()));
3869 
3870   unsigned BitWidth = getTypeSizeInBits(S->getType());
3871   ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
3872 
3873   // If the value is known to end in TZ zero bits (it is a multiple of 2^TZ),
3874   // the maximum value can be rounded down to a multiple of 2^TZ as well.
3875   uint32_t TZ = GetMinTrailingZeros(S);
3876   if (TZ != 0) {
3877     if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
3878       ConservativeResult =
3879           ConstantRange(APInt::getMinValue(BitWidth),
3880                         APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
3881     else
3882       ConservativeResult = ConstantRange(
3883           APInt::getSignedMinValue(BitWidth),
3884           APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
3885   }
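
  // For instance, with BitWidth == 8 and TZ == 3 every possible value is a
  // multiple of 8, so the unsigned hint clamps the range to [0, 249)
  // (maximum 248, the largest multiple of 8 that fits), and the signed hint
  // clamps it to [-128, 121) (maximum 120).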
3886 
3887   if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
3888     ConstantRange X = getRange(Add->getOperand(0), SignHint);
3889     for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
3890       X = X.add(getRange(Add->getOperand(i), SignHint));
3891     return setRange(Add, SignHint, ConservativeResult.intersectWith(X));
3892   }
3893 
3894   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
3895     ConstantRange X = getRange(Mul->getOperand(0), SignHint);
3896     for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
3897       X = X.multiply(getRange(Mul->getOperand(i), SignHint));
3898     return setRange(Mul, SignHint, ConservativeResult.intersectWith(X));
3899   }
3900 
3901   if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
3902     ConstantRange X = getRange(SMax->getOperand(0), SignHint);
3903     for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
3904       X = X.smax(getRange(SMax->getOperand(i), SignHint));
3905     return setRange(SMax, SignHint, ConservativeResult.intersectWith(X));
3906   }
3907 
3908   if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
3909     ConstantRange X = getRange(UMax->getOperand(0), SignHint);
3910     for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
3911       X = X.umax(getRange(UMax->getOperand(i), SignHint));
3912     return setRange(UMax, SignHint, ConservativeResult.intersectWith(X));
3913   }
3914 
3915   if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
3916     ConstantRange X = getRange(UDiv->getLHS(), SignHint);
3917     ConstantRange Y = getRange(UDiv->getRHS(), SignHint);
3918     return setRange(UDiv, SignHint,
3919                     ConservativeResult.intersectWith(X.udiv(Y)));
3920   }
3921 
3922   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
3923     ConstantRange X = getRange(ZExt->getOperand(), SignHint);
3924     return setRange(ZExt, SignHint,
3925                     ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
3926   }
3927 
3928   if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
3929     ConstantRange X = getRange(SExt->getOperand(), SignHint);
3930     return setRange(SExt, SignHint,
3931                     ConservativeResult.intersectWith(X.signExtend(BitWidth)));
3932   }
3933 
3934   if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
3935     ConstantRange X = getRange(Trunc->getOperand(), SignHint);
3936     return setRange(Trunc, SignHint,
3937                     ConservativeResult.intersectWith(X.truncate(BitWidth)));
3938   }
3939 
3940   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
3941     // If there's no unsigned wrap, the value will never be less than its
3942     // initial value.
3943     if (AddRec->getNoWrapFlags(SCEV::FlagNUW))
3944       if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
3945         if (!C->getValue()->isZero())
3946           ConservativeResult =
3947             ConservativeResult.intersectWith(
3948               ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0)));
3949 
3950     // If there's no signed wrap, and all the operands have the same sign or
3951     // zero, the value won't ever change sign.
3952     if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) {
3953       bool AllNonNeg = true;
3954       bool AllNonPos = true;
3955       for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3956         if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false;
3957         if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false;
3958       }
3959       if (AllNonNeg)
3960         ConservativeResult = ConservativeResult.intersectWith(
3961           ConstantRange(APInt(BitWidth, 0),
3962                         APInt::getSignedMinValue(BitWidth)));
3963       else if (AllNonPos)
3964         ConservativeResult = ConservativeResult.intersectWith(
3965           ConstantRange(APInt::getSignedMinValue(BitWidth),
3966                         APInt(BitWidth, 1)));
3967     }
3968 
3969     // TODO: non-affine addrec
3970     if (AddRec->isAffine()) {
3971       Type *Ty = AddRec->getType();
3972       const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
3973       if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
3974           getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
3975 
3976         // Check for overflow.  This must be done with ConstantRange arithmetic
3977         // because we could be called from within the ScalarEvolution overflow
3978         // checking code.
3979 
3980         MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
3981         ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
3982         ConstantRange ZExtMaxBECountRange =
3983             MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1);
3984 
3985         const SCEV *Start = AddRec->getStart();
3986         const SCEV *Step = AddRec->getStepRecurrence(*this);
3987         ConstantRange StepSRange = getSignedRange(Step);
3988         ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1);
3989 
3990         ConstantRange StartURange = getUnsignedRange(Start);
3991         ConstantRange EndURange =
3992             StartURange.add(MaxBECountRange.multiply(StepSRange));
3993 
3994         // Check for unsigned overflow.
3995         ConstantRange ZExtStartURange =
3996             StartURange.zextOrTrunc(BitWidth * 2 + 1);
3997         ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1);
3998         if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
3999             ZExtEndURange) {
4000           APInt Min = APIntOps::umin(StartURange.getUnsignedMin(),
4001                                      EndURange.getUnsignedMin());
4002           APInt Max = APIntOps::umax(StartURange.getUnsignedMax(),
4003                                      EndURange.getUnsignedMax());
4004           bool IsFullRange = Min.isMinValue() && Max.isMaxValue();
4005           if (!IsFullRange)
4006             ConservativeResult =
4007                 ConservativeResult.intersectWith(ConstantRange(Min, Max + 1));
4008         }
4009 
4010         ConstantRange StartSRange = getSignedRange(Start);
4011         ConstantRange EndSRange =
4012             StartSRange.add(MaxBECountRange.multiply(StepSRange));
4013 
4014         // Check for signed overflow. This must be done with ConstantRange
4015         // arithmetic because we could be called from within the ScalarEvolution
4016         // overflow checking code.
4017         ConstantRange SExtStartSRange =
4018             StartSRange.sextOrTrunc(BitWidth * 2 + 1);
4019         ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1);
4020         if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
4021             SExtEndSRange) {
4022           APInt Min = APIntOps::smin(StartSRange.getSignedMin(),
4023                                      EndSRange.getSignedMin());
4024           APInt Max = APIntOps::smax(StartSRange.getSignedMax(),
4025                                      EndSRange.getSignedMax());
4026           bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue();
4027           if (!IsFullRange)
4028             ConservativeResult =
4029                 ConservativeResult.intersectWith(ConstantRange(Min, Max + 1));
4030         }
4031       }
4032     }
4033 
4034     return setRange(AddRec, SignHint, ConservativeResult);
4035   }
4036 
4037   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
4038     // Check if the IR explicitly contains !range metadata.
4039     Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
4040     if (MDRange.hasValue())
4041       ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue());
4042 
4043     // Split here to avoid paying the compile-time cost of calling both
4044     // computeKnownBits and ComputeNumSignBits.  This restriction can be lifted
4045     // if needed.
4046     const DataLayout &DL = F->getParent()->getDataLayout();
4047     if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
4048       // For a SCEVUnknown, ask ValueTracking.
4049       APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
4050       computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT);
4051       if (Ones != ~Zeros + 1)
4052         ConservativeResult =
4053             ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1));
4054     } else {
4055       assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED &&
4056              "generalize as needed!");
4057       unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AC, nullptr, DT);
4058       if (NS > 1)
4059         ConservativeResult = ConservativeResult.intersectWith(
4060             ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
4061                           APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1));
4062     }
4063 
4064     return setRange(U, SignHint, ConservativeResult);
4065   }
4066 
4067   return setRange(S, SignHint, ConservativeResult);
4068 }
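
// A worked instance of the affine addrec logic above, assuming an i8
// recurrence {0,+,2} in a loop with a constant max backedge-taken count of 9:
//   StartURange     = [0, 1)
//   StepSRange      = [2, 3)
//   MaxBECountRange = [9, 10)
//   EndURange       = [0, 1) + [9, 10) * [2, 3) = [18, 19)
// The redundant 17-bit (2 * 8 + 1) computation matches, so no unsigned wrap
// occurred and the unsigned range tightens to [umin(0, 18), umax(0, 18) + 1)
// == [0, 19).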
4069 
4070 /// createSCEV - We know that there is no SCEV for the specified value.
4071 /// Analyze the expression.
4072 ///
4073 const SCEV *ScalarEvolution::createSCEV(Value *V) {
4074   if (!isSCEVable(V->getType()))
4075     return getUnknown(V);
4076 
4077   unsigned Opcode = Instruction::UserOp1;
4078   if (Instruction *I = dyn_cast<Instruction>(V)) {
4079     Opcode = I->getOpcode();
4080 
4081     // Don't attempt to analyze instructions in blocks that aren't
4082     // reachable. Such instructions don't matter, and they aren't required
4083     // to obey basic rules for definitions dominating uses which this
4084     // analysis depends on.
4085     if (!DT->isReachableFromEntry(I->getParent()))
4086       return getUnknown(V);
4087   } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
4088     Opcode = CE->getOpcode();
4089   else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
4090     return getConstant(CI);
4091   else if (isa<ConstantPointerNull>(V))
4092     return getConstant(V->getType(), 0);
4093   else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
4094     return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee());
4095   else
4096     return getUnknown(V);
4097 
4098   Operator *U = cast<Operator>(V);
4099   switch (Opcode) {
4100   case Instruction::Add: {
4101     // The simple thing to do would be to just call getSCEV on both operands
4102     // and call getAddExpr with the result. However if we're looking at a
4103     // bunch of things all added together, this can be quite inefficient,
4104     // because it leads to N-1 getAddExpr calls for N ultimate operands.
4105     // Instead, gather up all the operands and make a single getAddExpr call.
4106     // LLVM IR canonical form means we need only traverse the left operands.
4107     //
4108     // Don't apply this instruction's NSW or NUW flags to the new
4109     // expression. The instruction may be guarded by control flow that the
4110     // no-wrap behavior depends on. Non-control-equivalent instructions can be
4111     // mapped to the same SCEV expression, and it would be incorrect to transfer
4112     // NSW/NUW semantics to those operations.
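    //
    // As an illustration, canonical left-leaning IR such as
    //   %t0 = add i32 %a, %b
    //   %t1 = sub i32 %t0, %c
    //   %t2 = add i32 %t1, %d
    // is gathered (starting from %t2) into one getAddExpr({%d, -%c, %b, %a})
    // call instead of three nested ones.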
4113     SmallVector<const SCEV *, 4> AddOps;
4114     AddOps.push_back(getSCEV(U->getOperand(1)));
4115     for (Value *Op = U->getOperand(0); ; Op = U->getOperand(0)) {
4116       unsigned Opcode = Op->getValueID() - Value::InstructionVal;
4117       if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
4118         break;
4119       U = cast<Operator>(Op);
4120       const SCEV *Op1 = getSCEV(U->getOperand(1));
4121       if (Opcode == Instruction::Sub)
4122         AddOps.push_back(getNegativeSCEV(Op1));
4123       else
4124         AddOps.push_back(Op1);
4125     }
4126     AddOps.push_back(getSCEV(U->getOperand(0)));
4127     return getAddExpr(AddOps);
4128   }
4129   case Instruction::Mul: {
4130     // Don't transfer NSW/NUW for the same reason as AddExpr.
4131     SmallVector<const SCEV *, 4> MulOps;
4132     MulOps.push_back(getSCEV(U->getOperand(1)));
4133     for (Value *Op = U->getOperand(0);
4134          Op->getValueID() == Instruction::Mul + Value::InstructionVal;
4135          Op = U->getOperand(0)) {
4136       U = cast<Operator>(Op);
4137       MulOps.push_back(getSCEV(U->getOperand(1)));
4138     }
4139     MulOps.push_back(getSCEV(U->getOperand(0)));
4140     return getMulExpr(MulOps);
4141   }
4142   case Instruction::UDiv:
4143     return getUDivExpr(getSCEV(U->getOperand(0)),
4144                        getSCEV(U->getOperand(1)));
4145   case Instruction::Sub:
4146     return getMinusSCEV(getSCEV(U->getOperand(0)),
4147                         getSCEV(U->getOperand(1)));
4148   case Instruction::And:
4149     // For an expression like x&255 that merely masks off the high bits,
4150     // use zext(trunc(x)) as the SCEV expression.
4151     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
4152       if (CI->isNullValue())
4153         return getSCEV(U->getOperand(1));
4154       if (CI->isAllOnesValue())
4155         return getSCEV(U->getOperand(0));
4156       const APInt &A = CI->getValue();
4157 
4158       // Instcombine's ShrinkDemandedConstant may strip bits out of
4159       // constants, obscuring what would otherwise be a low-bits mask.
4160       // Use computeKnownBits to compute what ShrinkDemandedConstant
4161       // knew about to reconstruct a low-bits mask value.
4162       unsigned LZ = A.countLeadingZeros();
4163       unsigned TZ = A.countTrailingZeros();
4164       unsigned BitWidth = A.getBitWidth();
4165       APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
4166       computeKnownBits(U->getOperand(0), KnownZero, KnownOne,
4167                        F->getParent()->getDataLayout(), 0, AC, nullptr, DT);
4168 
4169       APInt EffectiveMask =
4170           APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
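      // For instance, in i8 with A == 0x78 we get LZ == 1, TZ == 3, and
      // EffectiveMask == 0b01111000; when the check below succeeds, (x & 0x78)
      // is rebuilt as (zext (trunc (x /u 8) to i4) to i8) * 8.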
4171       if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) {
4172         const SCEV *MulCount = getConstant(
4173             ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, TZ)));
4174         return getMulExpr(
4175             getZeroExtendExpr(
4176                 getTruncateExpr(
4177                     getUDivExactExpr(getSCEV(U->getOperand(0)), MulCount),
4178                     IntegerType::get(getContext(), BitWidth - LZ - TZ)),
4179                 U->getType()),
4180             MulCount);
4181       }
4182     }
4183     break;
4184 
4185   case Instruction::Or:
4186     // If the RHS of the Or is a constant, we may have something like:
4187     // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
4188     // optimizations will transparently handle this case.
4189     //
4190     // In order for this transformation to be safe, the LHS must be of the
4191     // form X*(2^n) and the Or constant must be less than 2^n.
4192     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
4193       const SCEV *LHS = getSCEV(U->getOperand(0));
4194       const APInt &CIVal = CI->getValue();
4195       if (GetMinTrailingZeros(LHS) >=
4196           (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
4197         // Build a plain add SCEV.
4198         const SCEV *S = getAddExpr(LHS, getSCEV(CI));
4199         // If the LHS of the add was an addrec and it has no-wrap flags,
4200         // transfer the no-wrap flags, since an or won't introduce a wrap.
4201         if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
4202           const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
4203           const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags(
4204             OldAR->getNoWrapFlags());
4205         }
4206         return S;
4207       }
4208     }
4209     break;
4210   case Instruction::Xor:
4211     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
4212       // If the RHS of the xor is a signbit, then this is just an add.
4213       // Instcombine turns add of signbit into xor as a strength reduction step.
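      //
      // For example, in i8 the sign bit is 0x80, and (x ^ 0x80) equals
      // (x + 0x80) for every x: the carry out of the top bit is simply
      // discarded, so no lower bits are disturbed.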
4214       if (CI->getValue().isSignBit())
4215         return getAddExpr(getSCEV(U->getOperand(0)),
4216                           getSCEV(U->getOperand(1)));
4217 
4218       // If the RHS of xor is -1, then this is a not operation.
4219       if (CI->isAllOnesValue())
4220         return getNotSCEV(getSCEV(U->getOperand(0)));
4221 
4222       // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
4223       // This is a variant of the check for xor with -1, and it handles
4224       // the case where instcombine has trimmed non-demanded bits out
4225       // of an xor with -1.
4226       if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
4227         if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
4228           if (BO->getOpcode() == Instruction::And &&
4229               LCI->getValue() == CI->getValue())
4230             if (const SCEVZeroExtendExpr *Z =
4231                   dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) {
4232               Type *UTy = U->getType();
4233               const SCEV *Z0 = Z->getOperand();
4234               Type *Z0Ty = Z0->getType();
4235               unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
4236 
4237               // If C is a low-bits mask, the zero extend is serving to
4238               // mask off the high bits. Complement the operand and
4239               // re-apply the zext.
4240               if (APIntOps::isMask(Z0TySize, CI->getValue()))
4241                 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
4242 
4243               // If C is a single bit, it may be in the sign-bit position
4244               // before the zero-extend. In this case, represent the xor
4245               // using an add, which is equivalent, and re-apply the zext.
4246               APInt Trunc = CI->getValue().trunc(Z0TySize);
4247               if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
4248                   Trunc.isSignBit())
4249                 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
4250                                          UTy);
4251             }
4252     }
4253     break;
4254 
4255   case Instruction::Shl:
4256     // Turn shift left of a constant amount into a multiply.
4257     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
4258       uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
4259 
4260       // If the shift count is not less than the bitwidth, the result of
4261       // the shift is undefined. Don't try to analyze it, because the
4262       // resolution chosen here may differ from the resolution chosen in
4263       // other parts of the compiler.
4264       if (SA->getValue().uge(BitWidth))
4265         break;
4266 
4267       Constant *X = ConstantInt::get(getContext(),
4268         APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
4269       return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
4270     }
4271     break;
4272 
4273   case Instruction::LShr:
4274     // Turn logical shift right of a constant into an unsigned divide.
4275     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
4276       uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
4277 
4278       // If the shift count is not less than the bitwidth, the result of
4279       // the shift is undefined. Don't try to analyze it, because the
4280       // resolution chosen here may differ from the resolution chosen in
4281       // other parts of the compiler.
4282       if (SA->getValue().uge(BitWidth))
4283         break;
4284 
4285       Constant *X = ConstantInt::get(getContext(),
4286         APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
4287       return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
4288     }
4289     break;
4290 
4291   case Instruction::AShr:
4292     // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
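    // For example, in i32 the pair
    //   %s = shl i32 %x, 24
    //   %r = ashr i32 %s, 24
    // sign-extends the low 8 bits of %x, so %r becomes
    // (sext (trunc %x to i8) to i32); here Amt == 32 - 24 == 8.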
4293     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
4294       if (Operator *L = dyn_cast<Operator>(U->getOperand(0)))
4295         if (L->getOpcode() == Instruction::Shl &&
4296             L->getOperand(1) == U->getOperand(1)) {
4297           uint64_t BitWidth = getTypeSizeInBits(U->getType());
4298 
4299           // If the shift count is not less than the bitwidth, the result of
4300           // the shift is undefined. Don't try to analyze it, because the
4301           // resolution chosen here may differ from the resolution chosen in
4302           // other parts of the compiler.
4303           if (CI->getValue().uge(BitWidth))
4304             break;
4305 
4306           uint64_t Amt = BitWidth - CI->getZExtValue();
4307           if (Amt == BitWidth)
4308             return getSCEV(L->getOperand(0));       // shift by zero --> noop
4309           return
4310             getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
4311                                               IntegerType::get(getContext(),
4312                                                                Amt)),
4313                               U->getType());
4314         }
4315     break;
4316 
4317   case Instruction::Trunc:
4318     return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
4319 
4320   case Instruction::ZExt:
4321     return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
4322 
4323   case Instruction::SExt:
4324     return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
4325 
4326   case Instruction::BitCast:
4327     // BitCasts are no-op casts so we just eliminate the cast.
4328     if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
4329       return getSCEV(U->getOperand(0));
4330     break;
4331 
4332   // It's tempting to handle inttoptr and ptrtoint as no-ops; however, this can
4333   // lead to pointer expressions which cannot safely be expanded to GEPs,
4334   // because ScalarEvolution doesn't respect the GEP aliasing rules when
4335   // simplifying integer expressions.
4336 
4337   case Instruction::GetElementPtr:
4338     return createNodeForGEP(cast<GEPOperator>(U));
4339 
4340   case Instruction::PHI:
4341     return createNodeForPHI(cast<PHINode>(U));
4342 
4343   case Instruction::Select:
4344     // This could be a smax or umax that was lowered earlier.
4345     // Try to recover it.
4346     if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
4347       Value *LHS = ICI->getOperand(0);
4348       Value *RHS = ICI->getOperand(1);
4349       switch (ICI->getPredicate()) {
4350       case ICmpInst::ICMP_SLT:
4351       case ICmpInst::ICMP_SLE:
4352         std::swap(LHS, RHS);
4353         // fall through
4354       case ICmpInst::ICMP_SGT:
4355       case ICmpInst::ICMP_SGE:
4356         // a >s b ? a+x : b+x  ->  smax(a, b)+x
4357         // a >s b ? b+x : a+x  ->  smin(a, b)+x
4358         if (getTypeSizeInBits(LHS->getType()) <=
4359             getTypeSizeInBits(U->getType())) {
4360           const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType());
4361           const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType());
4362           const SCEV *LA = getSCEV(U->getOperand(1));
4363           const SCEV *RA = getSCEV(U->getOperand(2));
4364           const SCEV *LDiff = getMinusSCEV(LA, LS);
4365           const SCEV *RDiff = getMinusSCEV(RA, RS);
4366           if (LDiff == RDiff)
4367             return getAddExpr(getSMaxExpr(LS, RS), LDiff);
4368           LDiff = getMinusSCEV(LA, RS);
4369           RDiff = getMinusSCEV(RA, LS);
4370           if (LDiff == RDiff)
4371             return getAddExpr(getSMinExpr(LS, RS), LDiff);
4372         }
4373         break;
4374       case ICmpInst::ICMP_ULT:
4375       case ICmpInst::ICMP_ULE:
4376         std::swap(LHS, RHS);
4377         // fall through
4378       case ICmpInst::ICMP_UGT:
4379       case ICmpInst::ICMP_UGE:
4380         // a >u b ? a+x : b+x  ->  umax(a, b)+x
4381         // a >u b ? b+x : a+x  ->  umin(a, b)+x
4382         if (getTypeSizeInBits(LHS->getType()) <=
4383             getTypeSizeInBits(U->getType())) {
4384           const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
4385           const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType());
4386           const SCEV *LA = getSCEV(U->getOperand(1));
4387           const SCEV *RA = getSCEV(U->getOperand(2));
4388           const SCEV *LDiff = getMinusSCEV(LA, LS);
4389           const SCEV *RDiff = getMinusSCEV(RA, RS);
4390           if (LDiff == RDiff)
4391             return getAddExpr(getUMaxExpr(LS, RS), LDiff);
4392           LDiff = getMinusSCEV(LA, RS);
4393           RDiff = getMinusSCEV(RA, LS);
4394           if (LDiff == RDiff)
4395             return getAddExpr(getUMinExpr(LS, RS), LDiff);
4396         }
4397         break;
4398       case ICmpInst::ICMP_NE:
4399         // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
4400         if (getTypeSizeInBits(LHS->getType()) <=
4401                 getTypeSizeInBits(U->getType()) &&
4402             isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
4403           const SCEV *One = getConstant(U->getType(), 1);
4404           const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
4405           const SCEV *LA = getSCEV(U->getOperand(1));
4406           const SCEV *RA = getSCEV(U->getOperand(2));
4407           const SCEV *LDiff = getMinusSCEV(LA, LS);
4408           const SCEV *RDiff = getMinusSCEV(RA, One);
4409           if (LDiff == RDiff)
4410             return getAddExpr(getUMaxExpr(One, LS), LDiff);
4411         }
4412         break;
4413       case ICmpInst::ICMP_EQ:
4414         // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
4415         if (getTypeSizeInBits(LHS->getType()) <=
4416                 getTypeSizeInBits(U->getType()) &&
4417             isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
4418           const SCEV *One = getConstant(U->getType(), 1);
4419           const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
4420           const SCEV *LA = getSCEV(U->getOperand(1));
4421           const SCEV *RA = getSCEV(U->getOperand(2));
4422           const SCEV *LDiff = getMinusSCEV(LA, One);
4423           const SCEV *RDiff = getMinusSCEV(RA, LS);
4424           if (LDiff == RDiff)
4425             return getAddExpr(getUMaxExpr(One, LS), LDiff);
4426         }
4427         break;
4428       default:
4429         break;
4430       }
4431     }
4432 
4433   default: // We cannot analyze this expression.
4434     break;
4435   }
4436 
4437   return getUnknown(V);
4438 }
4439 
4440 
4441 
4442 //===----------------------------------------------------------------------===//
4443 //                   Iteration Count Computation Code
4444 //
4445 
4446 unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) {
4447   if (BasicBlock *ExitingBB = L->getExitingBlock())
4448     return getSmallConstantTripCount(L, ExitingBB);
4449 
4450   // No trip count information for multiple exits.
4451   return 0;
4452 }
4453 
4454 /// getSmallConstantTripCount - Returns the maximum trip count of this loop as a
4455 /// normal unsigned value. Returns 0 if the trip count is unknown or not
4456 /// constant. Will also return 0 if the maximum trip count is very large (>=
4457 /// 2^32).
4458 ///
4459 /// This "trip count" assumes that control exits via ExitingBlock. More
4460 /// precisely, it is the number of times that control may reach ExitingBlock
4461 /// before taking the branch. For loops with multiple exits, it may not be the
4462 /// number of times that the loop header executes, because the loop may exit
4463 /// prematurely via another branch.
4464 unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L,
4465                                                     BasicBlock *ExitingBlock) {
4466   assert(ExitingBlock && "Must pass a non-null exiting block!");
4467   assert(L->isLoopExiting(ExitingBlock) &&
4468          "Exiting block must actually branch out of the loop!");
4469   const SCEVConstant *ExitCount =
4470       dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
4471   if (!ExitCount)
4472     return 0;
4473 
4474   ConstantInt *ExitConst = ExitCount->getValue();
4475 
4476   // Guard against huge trip counts.
4477   if (ExitConst->getValue().getActiveBits() > 32)
4478     return 0;
4479 
4480   // In case of integer overflow, this returns 0, which is correct.
4481   return ((unsigned)ExitConst->getZExtValue()) + 1;
4482 }
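
// For instance, a constant backedge-taken count of 99 for the sole exiting
// block gives a trip count of 100. A backedge-taken count of all-ones wraps
// the +1 around to 0, which the code treats as the correct "unknown" answer.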
4483 
4484 unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) {
4485   if (BasicBlock *ExitingBB = L->getExitingBlock())
4486     return getSmallConstantTripMultiple(L, ExitingBB);
4487 
4488   // No trip multiple information for multiple exits.
4489   return 0;
4490 }
4491 
4492 /// getSmallConstantTripMultiple - Returns the largest constant divisor of the
4493 /// trip count of this loop as a normal unsigned value, if possible. This
4494 /// means that the actual trip count is always a multiple of the returned
4495 /// value (don't forget the trip count could very well be zero as well!).
4496 ///
4497 /// Returns 1 if the trip count is unknown or not guaranteed to be a
4498 /// multiple of a constant (which is also the case if the trip count is simply
4499 /// constant; use getSmallConstantTripCount for that case). Will also return 1
4500 /// if the trip count is very large (>= 2^32).
4501 ///
4502 /// As explained in the comments for getSmallConstantTripCount, this assumes
4503 /// that control exits the loop via ExitingBlock.
4504 unsigned
4505 ScalarEvolution::getSmallConstantTripMultiple(Loop *L,
4506                                               BasicBlock *ExitingBlock) {
4507   assert(ExitingBlock && "Must pass a non-null exiting block!");
4508   assert(L->isLoopExiting(ExitingBlock) &&
4509          "Exiting block must actually branch out of the loop!");
4510   const SCEV *ExitCount = getExitCount(L, ExitingBlock);
4511   if (ExitCount == getCouldNotCompute())
4512     return 1;
4513 
4514   // Get the trip count from the BE count by adding 1.
4515   const SCEV *TCMul = getAddExpr(ExitCount,
4516                                  getConstant(ExitCount->getType(), 1));
4517   // FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt
4518   // to factor simple cases.
4519   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(TCMul))
4520     TCMul = Mul->getOperand(0);
4521 
4522   const SCEVConstant *MulC = dyn_cast<SCEVConstant>(TCMul);
4523   if (!MulC)
4524     return 1;
4525 
4526   ConstantInt *Result = MulC->getValue();
4527 
4528   // Guard against huge trip counts (this requires checking
4529   // for zero to handle the case where the trip count == -1 and the
4530   // addition wraps).
4531   if (!Result || Result->getValue().getActiveBits() > 32 ||
4532       Result->getValue().getActiveBits() == 0)
4533     return 1;
4534 
4535   return (unsigned)Result->getZExtValue();
4536 }
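
// As an example, if the backedge-taken count is (-1 + (4 * %n)), adding 1
// folds to the trip count (4 * %n); peeling the leading constant operand off
// the multiply yields 4, so the actual trip count is always a multiple of 4.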
4537 
4538 /// getExitCount - Get the expression for the number of loop iterations for which
4539 /// this loop is guaranteed not to exit via ExitingBlock. Otherwise return
4540 /// SCEVCouldNotCompute.
4541 const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) {
4542   return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
4543 }
4544 
4545 /// getBackedgeTakenCount - If the specified loop has a predictable
4546 /// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
4547 /// object. The backedge-taken count is the number of times the loop header
4548 /// will be branched to from within the loop. This is one less than the
4549 /// trip count of the loop, since it doesn't count the first iteration,
4550 /// when the header is branched to from outside the loop.
4551 ///
4552 /// Note that it is not valid to call this method on a loop without a
4553 /// loop-invariant backedge-taken count (see
4554 /// hasLoopInvariantBackedgeTakenCount).
4555 ///
4556 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
4557   return getBackedgeTakenInfo(L).getExact(this);
4558 }
4559 
4560 /// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
4561 /// return the least SCEV value that is known never to be less than the
4562 /// actual backedge taken count.
4563 const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
4564   return getBackedgeTakenInfo(L).getMax(this);
4565 }
4566 
4567 /// PushLoopPHIs - Push PHI nodes in the header of the given loop
4568 /// onto the given Worklist.
4569 static void
4570 PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
4571   BasicBlock *Header = L->getHeader();
4572 
4573   // Push all Loop-header PHIs onto the Worklist stack.
4574   for (BasicBlock::iterator I = Header->begin();
4575        PHINode *PN = dyn_cast<PHINode>(I); ++I)
4576     Worklist.push_back(PN);
4577 }
4578 
4579 const ScalarEvolution::BackedgeTakenInfo &
4580 ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
4581   // Initially insert an invalid entry for this loop. If the insertion
4582   // succeeds, proceed to actually compute a backedge-taken count and
4583   // update the value. The temporary CouldNotCompute value tells SCEV
4584   // code elsewhere that it shouldn't attempt to request a new
4585   // backedge-taken count, which could result in infinite recursion.
4586   std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
4587     BackedgeTakenCounts.insert(std::make_pair(L, BackedgeTakenInfo()));
4588   if (!Pair.second)
4589     return Pair.first->second;
4590 
4591   // ComputeBackedgeTakenCount may allocate memory for its result. Inserting it
4592   // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
4593   // must be cleared in this scope.
4594   BackedgeTakenInfo Result = ComputeBackedgeTakenCount(L);
4595 
4596   if (Result.getExact(this) != getCouldNotCompute()) {
4597     assert(isLoopInvariant(Result.getExact(this), L) &&
4598            isLoopInvariant(Result.getMax(this), L) &&
4599            "Computed backedge-taken count isn't loop invariant for loop!");
4600     ++NumTripCountsComputed;
4601   }
4602   else if (Result.getMax(this) == getCouldNotCompute() &&
4603            isa<PHINode>(L->getHeader()->begin())) {
4604     // Only count loops that have phi nodes as not being computable.
4605     ++NumTripCountsNotComputed;
4606   }
4607 
4608   // Now that we know more about the trip count for this loop, forget any
4609   // existing SCEV values for PHI nodes in this loop since they are only
4610   // conservative estimates made without the benefit of trip count
4611   // information. This is similar to the code in forgetLoop, except that
4612   // it handles SCEVUnknown PHI nodes specially.
4613   if (Result.hasAnyInfo()) {
4614     SmallVector<Instruction *, 16> Worklist;
4615     PushLoopPHIs(L, Worklist);
4616 
4617     SmallPtrSet<Instruction *, 8> Visited;
4618     while (!Worklist.empty()) {
4619       Instruction *I = Worklist.pop_back_val();
4620       if (!Visited.insert(I).second)
4621         continue;
4622 
4623       ValueExprMapType::iterator It =
4624         ValueExprMap.find_as(static_cast<Value *>(I));
4625       if (It != ValueExprMap.end()) {
4626         const SCEV *Old = It->second;
4627 
4628         // SCEVUnknown for a PHI either means that it has an unrecognized
4629         // structure, or it's a PHI that's in the process of being computed
4630         // by createNodeForPHI.  In the former case, additional loop trip
4631         // count information isn't going to change anything. In the latter
4632         // case, createNodeForPHI will perform the necessary updates on its
4633         // own when it gets to that point.
4634         if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
4635           forgetMemoizedResults(Old);
4636           ValueExprMap.erase(It);
4637         }
4638         if (PHINode *PN = dyn_cast<PHINode>(I))
4639           ConstantEvolutionLoopExitValue.erase(PN);
4640       }
4641 
4642       PushDefUseChildren(I, Worklist);
4643     }
4644   }
4645 
4646   // Re-lookup the insert position, since the call to
4647   // ComputeBackedgeTakenCount above could result in a
4648   // recursive call to getBackedgeTakenInfo (on a different
4649   // loop), which would invalidate the iterator computed
4650   // earlier.
4651   return BackedgeTakenCounts.find(L)->second = Result;
4652 }
4653 
4654 /// forgetLoop - This method should be called by the client when it has
4655 /// changed a loop in a way that may affect ScalarEvolution's ability to
4656 /// compute a trip count, or if the loop is deleted.
4657 void ScalarEvolution::forgetLoop(const Loop *L) {
4658   // Drop any stored trip count value.
4659   DenseMap<const Loop*, BackedgeTakenInfo>::iterator BTCPos =
4660     BackedgeTakenCounts.find(L);
4661   if (BTCPos != BackedgeTakenCounts.end()) {
4662     BTCPos->second.clear();
4663     BackedgeTakenCounts.erase(BTCPos);
4664   }
4665 
4666   // Drop information about expressions based on loop-header PHIs.
4667   SmallVector<Instruction *, 16> Worklist;
4668   PushLoopPHIs(L, Worklist);
4669 
4670   SmallPtrSet<Instruction *, 8> Visited;
4671   while (!Worklist.empty()) {
4672     Instruction *I = Worklist.pop_back_val();
4673     if (!Visited.insert(I).second)
4674       continue;
4675 
4676     ValueExprMapType::iterator It =
4677       ValueExprMap.find_as(static_cast<Value *>(I));
4678     if (It != ValueExprMap.end()) {
4679       forgetMemoizedResults(It->second);
4680       ValueExprMap.erase(It);
4681       if (PHINode *PN = dyn_cast<PHINode>(I))
4682         ConstantEvolutionLoopExitValue.erase(PN);
4683     }
4684 
4685     PushDefUseChildren(I, Worklist);
4686   }
4687 
4688   // Forget all contained loops too, to avoid dangling entries in the
4689   // ValuesAtScopes map.
4690   for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
4691     forgetLoop(*I);
4692 }
4693 
4694 /// forgetValue - This method should be called by the client when it has
4695 /// changed a value in a way that may affect its value, or which may
4696 /// disconnect it from a def-use chain linking it to a loop.
4697 void ScalarEvolution::forgetValue(Value *V) {
4698   Instruction *I = dyn_cast<Instruction>(V);
4699   if (!I) return;
4700 
4701   // Drop information about expressions based on loop-header PHIs.
4702   SmallVector<Instruction *, 16> Worklist;
4703   Worklist.push_back(I);
4704 
4705   SmallPtrSet<Instruction *, 8> Visited;
4706   while (!Worklist.empty()) {
4707     I = Worklist.pop_back_val();
4708     if (!Visited.insert(I).second)
4709       continue;
4710 
4711     ValueExprMapType::iterator It =
4712       ValueExprMap.find_as(static_cast<Value *>(I));
4713     if (It != ValueExprMap.end()) {
4714       forgetMemoizedResults(It->second);
4715       ValueExprMap.erase(It);
4716       if (PHINode *PN = dyn_cast<PHINode>(I))
4717         ConstantEvolutionLoopExitValue.erase(PN);
4718     }
4719 
4720     PushDefUseChildren(I, Worklist);
4721   }
4722 }
4723 
4724 /// getExact - Get the exact loop backedge taken count considering all loop
4725 /// exits. A computable result can only be returned for loops with a single exit.
4726 /// Returning the minimum taken count among all exits is incorrect because one
4727 /// of the loop's exit limits may have been skipped. HowFarToZero assumes that
4728 /// the limit of each loop test is never skipped. This is a valid assumption as
4729 /// long as the loop exits via that test. For precise results, it is the
4730 /// caller's responsibility to specify the relevant loop exit using
4731 /// getExact(ExitingBlock, SE).
4732 const SCEV *
4733 ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const {
4734   // If any exits were not computable, the loop is not computable.
4735   if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute();
4736 
4737   // We need exactly one computable exit.
4738   if (!ExitNotTaken.ExitingBlock) return SE->getCouldNotCompute();
4739   assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info");
4740 
4741   const SCEV *BECount = nullptr;
4742   for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
4743        ENT != nullptr; ENT = ENT->getNextExit()) {
4744 
4745     assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV");
4746 
4747     if (!BECount)
4748       BECount = ENT->ExactNotTaken;
4749     else if (BECount != ENT->ExactNotTaken)
4750       return SE->getCouldNotCompute();
4751   }
4752   assert(BECount && "Invalid not taken count for loop exit");
4753   return BECount;
4754 }
4755 
4756 /// getExact - Get the exact not taken count for this loop exit.
4757 const SCEV *
4758 ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
4759                                              ScalarEvolution *SE) const {
4760   for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
4761        ENT != nullptr; ENT = ENT->getNextExit()) {
4762 
4763     if (ENT->ExitingBlock == ExitingBlock)
4764       return ENT->ExactNotTaken;
4765   }
4766   return SE->getCouldNotCompute();
4767 }
4768 
4769 /// getMax - Get the max backedge taken count for the loop.
4770 const SCEV *
4771 ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
4772   return Max ? Max : SE->getCouldNotCompute();
4773 }
4774 
4775 bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
4776                                                     ScalarEvolution *SE) const {
4777   if (Max && Max != SE->getCouldNotCompute() && SE->hasOperand(Max, S))
4778     return true;
4779 
4780   if (!ExitNotTaken.ExitingBlock)
4781     return false;
4782 
4783   for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
4784        ENT != nullptr; ENT = ENT->getNextExit()) {
4785 
4786     if (ENT->ExactNotTaken != SE->getCouldNotCompute()
4787         && SE->hasOperand(ENT->ExactNotTaken, S)) {
4788       return true;
4789     }
4790   }
4791   return false;
4792 }
4793 
4794 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
4795 /// computable exit into a persistent ExitNotTakenInfo array.
4796 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
4797   SmallVectorImpl< std::pair<BasicBlock *, const SCEV *> > &ExitCounts,
4798   bool Complete, const SCEV *MaxCount) : Max(MaxCount) {
4799 
4800   if (!Complete)
4801     ExitNotTaken.setIncomplete();
4802 
4803   unsigned NumExits = ExitCounts.size();
4804   if (NumExits == 0) return;
4805 
4806   ExitNotTaken.ExitingBlock = ExitCounts[0].first;
4807   ExitNotTaken.ExactNotTaken = ExitCounts[0].second;
4808   if (NumExits == 1) return;
4809 
4810   // Handle the rare case of multiple computable exits.
4811   ExitNotTakenInfo *ENT = new ExitNotTakenInfo[NumExits-1];
4812 
4813   ExitNotTakenInfo *PrevENT = &ExitNotTaken;
4814   for (unsigned i = 1; i < NumExits; ++i, PrevENT = ENT, ++ENT) {
4815     PrevENT->setNextExit(ENT);
4816     ENT->ExitingBlock = ExitCounts[i].first;
4817     ENT->ExactNotTaken = ExitCounts[i].second;
4818   }
4819 }
4820 
4821 /// clear - Invalidate this result and free the ExitNotTakenInfo array.
4822 void ScalarEvolution::BackedgeTakenInfo::clear() {
4823   ExitNotTaken.ExitingBlock = nullptr;
4824   ExitNotTaken.ExactNotTaken = nullptr;
4825   delete[] ExitNotTaken.getNextExit();
4826 }
4827 
4828 /// ComputeBackedgeTakenCount - Compute the number of times the backedge
4829 /// of the specified loop will execute.
4830 ScalarEvolution::BackedgeTakenInfo
4831 ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
4832   SmallVector<BasicBlock *, 8> ExitingBlocks;
4833   L->getExitingBlocks(ExitingBlocks);
4834 
4835   SmallVector<std::pair<BasicBlock *, const SCEV *>, 4> ExitCounts;
4836   bool CouldComputeBECount = true;
4837   BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
4838   const SCEV *MustExitMaxBECount = nullptr;
4839   const SCEV *MayExitMaxBECount = nullptr;
4840 
4841   // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
4842   // and compute maxBECount.
4843   for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
4844     BasicBlock *ExitBB = ExitingBlocks[i];
4845     ExitLimit EL = ComputeExitLimit(L, ExitBB);
4846 
4847     // 1. For each exit that can be computed, add an entry to ExitCounts.
4848     // CouldComputeBECount is true only if all exits can be computed.
4849     if (EL.Exact == getCouldNotCompute())
4850       // We couldn't compute an exact value for this exit, so
4851       // we won't be able to compute an exact value for the loop.
4852       CouldComputeBECount = false;
4853     else
4854       ExitCounts.push_back(std::make_pair(ExitBB, EL.Exact));
4855 
4856     // 2. Derive the loop's MaxBECount from each exit's max number of
4857     // non-exiting iterations. Partition the loop exits into two kinds:
4858     // LoopMustExits and LoopMayExits.
4859     //
4860     // If the exit dominates the loop latch, it is a LoopMustExit; otherwise it
4861     // is a LoopMayExit.  If any computable LoopMustExit is found, then
4862     // MaxBECount is the minimum EL.Max of computable LoopMustExits. Otherwise,
4863     // MaxBECount is conservatively the maximum EL.Max, where CouldNotCompute is
4864     // considered greater than any computable EL.Max.
4865     if (EL.Max != getCouldNotCompute() && Latch &&
4866         DT->dominates(ExitBB, Latch)) {
4867       if (!MustExitMaxBECount)
4868         MustExitMaxBECount = EL.Max;
4869       else {
4870         MustExitMaxBECount =
4871           getUMinFromMismatchedTypes(MustExitMaxBECount, EL.Max);
4872       }
4873     } else if (MayExitMaxBECount != getCouldNotCompute()) {
4874       if (!MayExitMaxBECount || EL.Max == getCouldNotCompute())
4875         MayExitMaxBECount = EL.Max;
4876       else {
4877         MayExitMaxBECount =
4878           getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.Max);
4879       }
4880     }
4881   }
4882   const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
4883     (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
4884   return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount);
4885 }
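
// For instance, two computable must-exits with EL.Max of 100 and %n combine
// to MaxBECount == umin(100, %n), while may-exits with EL.Max of 100 and
// CouldNotCompute leave MaxBECount as CouldNotCompute, since an exit that
// may be skipped and has no bound cannot bound the loop.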
4886 
4887 /// ComputeExitLimit - Compute the number of times the backedge of the specified
4888 /// loop will execute if it exits via the specified block.
4889 ScalarEvolution::ExitLimit
4890 ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
4891 
4892   // Okay, we've chosen an exiting block.  See what condition causes us to
4893   // exit at this block and remember the exit block and whether all other targets
4894   // lead to the loop header.
4895   bool MustExecuteLoopHeader = true;
4896   BasicBlock *Exit = nullptr;
4897   for (succ_iterator SI = succ_begin(ExitingBlock), SE = succ_end(ExitingBlock);
4898        SI != SE; ++SI)
4899     if (!L->contains(*SI)) {
4900       if (Exit) // Multiple exit successors.
4901         return getCouldNotCompute();
4902       Exit = *SI;
4903     } else if (*SI != L->getHeader()) {
4904       MustExecuteLoopHeader = false;
4905     }
4906 
4907   // At this point, we know we have a conditional branch that determines whether
4908   // the loop is exited.  However, we don't know if the branch is executed each
4909   // time through the loop.  If not, then the execution count of the branch will
4910   // not be equal to the trip count of the loop.
4911   //
4912   // Currently we check for this by checking to see if the Exit branch goes to
4913   // the loop header.  If so, we know it will always execute the same number of
4914   // times as the loop.  We also handle the case where the exit block *is* the
4915   // loop header.  This is common for un-rotated loops.
4916   //
4917   // If both of those tests fail, walk up the unique predecessor chain to the
4918   // header, stopping if there is an edge that doesn't exit the loop. If the
4919   // header is reached, the execution count of the branch will be equal to the
4920   // trip count of the loop.
4921   //
4922   //  More extensive analysis could be done to handle more cases here.
4923   //
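  //  For instance, in an un-rotated loop of the shape
  //
  //    header:  br i1 %cond, label %body, label %exit
  //    body:    ...
  //             br label %header
  //
  //  the header is the exiting block and trivially runs once per iteration,
  //  whereas an exiting block reached only on one arm of a conditional in the
  //  body may run fewer times than the loop itself; the predecessor walk below
  //  rejects that case.
  //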
4924   if (!MustExecuteLoopHeader && ExitingBlock != L->getHeader()) {
4925     // The simple checks failed, try climbing the unique predecessor chain
4926     // up to the header.
4927     bool Ok = false;
4928     for (BasicBlock *BB = ExitingBlock; BB; ) {
4929       BasicBlock *Pred = BB->getUniquePredecessor();
4930       if (!Pred)
4931         return getCouldNotCompute();
4932       TerminatorInst *PredTerm = Pred->getTerminator();
4933       for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) {
4934         BasicBlock *PredSucc = PredTerm->getSuccessor(i);
4935         if (PredSucc == BB)
4936           continue;
4937         // If the predecessor has a successor that isn't BB and isn't
4938         // outside the loop, assume the worst.
4939         if (L->contains(PredSucc))
4940           return getCouldNotCompute();
4941       }
4942       if (Pred == L->getHeader()) {
4943         Ok = true;
4944         break;
4945       }
4946       BB = Pred;
4947     }
4948     if (!Ok)
4949       return getCouldNotCompute();
4950   }
4951 
4952   bool IsOnlyExit = (L->getExitingBlock() != nullptr);
4953   TerminatorInst *Term = ExitingBlock->getTerminator();
4954   if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
4955     assert(BI->isConditional() && "If unconditional, it can't be in loop!");
4956     // Proceed to the next level to examine the exit condition expression.
4957     return ComputeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0),
4958                                     BI->getSuccessor(1),
4959                                     /*ControlsExit=*/IsOnlyExit);
4960   }
4961 
4962   if (SwitchInst *SI = dyn_cast<SwitchInst>(Term))
4963     return ComputeExitLimitFromSingleExitSwitch(L, SI, Exit,
4964                                                 /*ControlsExit=*/IsOnlyExit);
4965 
4966   return getCouldNotCompute();
4967 }
4968 
4969 /// ComputeExitLimitFromCond - Compute the number of times the
4970 /// backedge of the specified loop will execute if its exit condition
4971 /// were a conditional branch of ExitCond, TBB, and FBB.
4972 ///
4973 /// @param ControlsExit is true if ExitCond directly controls the exit
4974 /// branch. In this case, we can assume that the loop exits only if the
4975 /// condition is true and can infer that failing to meet the condition prior to
4976 /// integer wraparound results in undefined behavior.
4977 ScalarEvolution::ExitLimit
4978 ScalarEvolution::ComputeExitLimitFromCond(const Loop *L,
4979                                           Value *ExitCond,
4980                                           BasicBlock *TBB,
4981                                           BasicBlock *FBB,
4982                                           bool ControlsExit) {
4983   // Check if the controlling expression for this loop is an And or Or.
4984   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
4985     if (BO->getOpcode() == Instruction::And) {
4986       // Recurse on the operands of the and.
4987       bool EitherMayExit = L->contains(TBB);
4988       ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB,
4989                                                ControlsExit && !EitherMayExit);
4990       ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB,
4991                                                ControlsExit && !EitherMayExit);
4992       const SCEV *BECount = getCouldNotCompute();
4993       const SCEV *MaxBECount = getCouldNotCompute();
4994       if (EitherMayExit) {
4995         // Both conditions must be true for the loop to continue executing.
4996         // Choose the less conservative count.
4997         if (EL0.Exact == getCouldNotCompute() ||
4998             EL1.Exact == getCouldNotCompute())
4999           BECount = getCouldNotCompute();
5000         else
5001           BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact);
5002         if (EL0.Max == getCouldNotCompute())
5003           MaxBECount = EL1.Max;
5004         else if (EL1.Max == getCouldNotCompute())
5005           MaxBECount = EL0.Max;
5006         else
5007           MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max);
5008       } else {
5009         // Both conditions must be true at the same time for the loop to exit.
5010         // For now, be conservative.
5011         assert(L->contains(FBB) && "Loop block has no successor in loop!");
5012         if (EL0.Max == EL1.Max)
5013           MaxBECount = EL0.Max;
5014         if (EL0.Exact == EL1.Exact)
5015           BECount = EL0.Exact;
5016       }
5017 
5018       return ExitLimit(BECount, MaxBECount);
5019     }
5020     if (BO->getOpcode() == Instruction::Or) {
5021       // Recurse on the operands of the or.
5022       bool EitherMayExit = L->contains(FBB);
5023       ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB,
5024                                                ControlsExit && !EitherMayExit);
5025       ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB,
5026                                                ControlsExit && !EitherMayExit);
5027       const SCEV *BECount = getCouldNotCompute();
5028       const SCEV *MaxBECount = getCouldNotCompute();
5029       if (EitherMayExit) {
5030         // Both conditions must be false for the loop to continue executing.
5031         // Choose the less conservative count.
5032         if (EL0.Exact == getCouldNotCompute() ||
5033             EL1.Exact == getCouldNotCompute())
5034           BECount = getCouldNotCompute();
5035         else
5036           BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact);
5037         if (EL0.Max == getCouldNotCompute())
5038           MaxBECount = EL1.Max;
5039         else if (EL1.Max == getCouldNotCompute())
5040           MaxBECount = EL0.Max;
5041         else
5042           MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max);
5043       } else {
5044         // Both conditions must be false at the same time for the loop to exit.
5045         // For now, be conservative.
5046         assert(L->contains(TBB) && "Loop block has no successor in loop!");
5047         if (EL0.Max == EL1.Max)
5048           MaxBECount = EL0.Max;
5049         if (EL0.Exact == EL1.Exact)
5050           BECount = EL0.Exact;
5051       }
5052 
5053       return ExitLimit(BECount, MaxBECount);
5054     }
5055   }
5056 
5057   // With an icmp, it may be feasible to compute an exact backedge-taken count.
5058   // Proceed to the next level to examine the icmp.
5059   if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
5060     return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit);
5061 
5062   // Check for a constant condition. These are normally stripped out by
5063   // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
5064   // preserve the CFG and is temporarily leaving constant conditions
5065   // in place.
5066   if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
5067     if (L->contains(FBB) == !CI->getZExtValue())
5068       // The backedge is always taken.
5069       return getCouldNotCompute();
5070     else
5071       // The backedge is never taken.
5072       return getConstant(CI->getType(), 0);
5073   }
5074 
5075   // If it's not an integer or pointer comparison then compute it the hard way.
5076   return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
5077 }
5078 
5079 /// ComputeExitLimitFromICmp - Compute the number of times the
5080 /// backedge of the specified loop will execute if its exit condition
5081 /// were a conditional branch on the ICmpInst ExitCond with successors TBB and FBB.
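///
/// For illustration: an exit test "i != n" is rewritten below as
/// "i - n != 0" and handed to HowFarToZero, while "i < n" (signed or
/// unsigned) is handed to HowManyLessThans, per the switch over the
/// (possibly inverted) predicate.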
5082 ScalarEvolution::ExitLimit
5083 ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L,
5084                                           ICmpInst *ExitCond,
5085                                           BasicBlock *TBB,
5086                                           BasicBlock *FBB,
5087                                           bool ControlsExit) {
5088 
5089   // If the condition was exit on true, convert the condition to exit on false
5090   ICmpInst::Predicate Cond;
5091   if (!L->contains(FBB))
5092     Cond = ExitCond->getPredicate();
5093   else
5094     Cond = ExitCond->getInversePredicate();
5095 
5096   // Handle common loops like: for (X = "string"; *X; ++X)
5097   if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
5098     if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
5099       ExitLimit ItCnt =
5100         ComputeLoadConstantCompareExitLimit(LI, RHS, L, Cond);
5101       if (ItCnt.hasAnyInfo())
5102         return ItCnt;
5103     }
5104 
5105   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
5106   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
5107 
5108   // Try to evaluate any dependencies out of the loop.
5109   LHS = getSCEVAtScope(LHS, L);
5110   RHS = getSCEVAtScope(RHS, L);
5111 
5112   // At this point, we would like to compute for how many iterations of
5113   // the loop the predicate will evaluate to true for these inputs.
5114   if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
5115     // If there is a loop-invariant operand, force it into the RHS.
5116     std::swap(LHS, RHS);
5117     Cond = ICmpInst::getSwappedPredicate(Cond);
5118   }
5119 
5120   // Simplify the operands before analyzing them.
5121   (void)SimplifyICmpOperands(Cond, LHS, RHS);
5122 
5123   // If we have a comparison of a chrec against a constant, try to use value
5124   // ranges to answer this query.
5125   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
5126     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
5127       if (AddRec->getLoop() == L) {
5128         // Form the constant range.
5129         ConstantRange CompRange(
5130             ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));
5131 
5132         const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
5133         if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
5134       }
5135 
5136   switch (Cond) {
5137   case ICmpInst::ICMP_NE: {                     // while (X != Y)
5138     // Convert to: while (X-Y != 0)
5139     ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
5140     if (EL.hasAnyInfo()) return EL;
5141     break;
5142   }
5143   case ICmpInst::ICMP_EQ: {                     // while (X == Y)
5144     // Convert to: while (X-Y == 0)
5145     ExitLimit EL = HowFarToNonZero(getMinusSCEV(LHS, RHS), L);
5146     if (EL.hasAnyInfo()) return EL;
5147     break;
5148   }
5149   case ICmpInst::ICMP_SLT:
5150   case ICmpInst::ICMP_ULT: {                    // while (X < Y)
5151     bool IsSigned = Cond == ICmpInst::ICMP_SLT;
5152     ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit);
5153     if (EL.hasAnyInfo()) return EL;
5154     break;
5155   }
5156   case ICmpInst::ICMP_SGT:
5157   case ICmpInst::ICMP_UGT: {                    // while (X > Y)
5158     bool IsSigned = Cond == ICmpInst::ICMP_SGT;
5159     ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit);
5160     if (EL.hasAnyInfo()) return EL;
5161     break;
5162   }
5163   default:
5164 #if 0
5165     dbgs() << "ComputeBackedgeTakenCount ";
5166     if (ExitCond->getOperand(0)->getType()->isUnsigned())
5167       dbgs() << "[unsigned] ";
5168     dbgs() << *LHS << "   "
5169          << Instruction::getOpcodeName(Instruction::ICmp)
5170          << "   " << *RHS << "\n";
5171 #endif
5172     break;
5173   }
5174   return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
5175 }
5176 
5177 ScalarEvolution::ExitLimit
5178 ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L,
5179                                                       SwitchInst *Switch,
5180                                                       BasicBlock *ExitingBlock,
5181                                                       bool ControlsExit) {
5182   assert(!L->contains(ExitingBlock) && "Not an exiting block!");
5183 
5184   // Give up if the exit is the default dest of a switch.
5185   if (Switch->getDefaultDest() == ExitingBlock)
5186     return getCouldNotCompute();
5187 
5188   assert(L->contains(Switch->getDefaultDest()) &&
5189          "Default case must not exit the loop!");
5190   const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
5191   const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
5192 
5193   // while (X != Y) --> while (X-Y != 0)
5194   ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
5195   if (EL.hasAnyInfo())
5196     return EL;
5197 
5198   return getCouldNotCompute();
5199 }
5200 
5201 static ConstantInt *
5202 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
5203                                 ScalarEvolution &SE) {
5204   const SCEV *InVal = SE.getConstant(C);
5205   const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
5206   assert(isa<SCEVConstant>(Val) &&
5207          "Evaluation of SCEV at constant didn't fold correctly?");
5208   return cast<SCEVConstant>(Val)->getValue();
5209 }
5210 
5211 /// ComputeLoadConstantCompareExitLimit - Given an exit condition of
5212 /// 'icmp op load X, cst', try to see if we can compute the backedge
5213 /// execution count.
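///
/// A typical shape this handles (names are illustrative):
///   @Table = constant [4 x i32] [i32 1, i32 2, i32 3, i32 0]
///   for (i = 0; Table[i] != 0; ++i) ;
/// Each candidate iteration folds the load at the constant index until
/// the comparison evaluates to false.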
5214 ScalarEvolution::ExitLimit
5215 ScalarEvolution::ComputeLoadConstantCompareExitLimit(
5216   LoadInst *LI,
5217   Constant *RHS,
5218   const Loop *L,
5219   ICmpInst::Predicate predicate) {
5220 
5221   if (LI->isVolatile()) return getCouldNotCompute();
5222 
5223   // Check to see if the loaded pointer is a getelementptr of a global.
5224   // TODO: Use SCEV instead of manually grubbing with GEPs.
5225   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
5226   if (!GEP) return getCouldNotCompute();
5227 
5228   // Make sure that it is really a constant global we are gepping, with an
5229   // initializer, and make sure the first IDX is really 0.
5230   GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
5231   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() ||
5232       GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
5233       !cast<Constant>(GEP->getOperand(1))->isNullValue())
5234     return getCouldNotCompute();
5235 
5236   // Okay, we allow one non-constant index into the GEP instruction.
5237   Value *VarIdx = nullptr;
5238   std::vector<Constant*> Indexes;
5239   unsigned VarIdxNum = 0;
5240   for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
5241     if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
5242       Indexes.push_back(CI);
5243     } else { // The dyn_cast above failed, so this index is not constant.
5244       if (VarIdx) return getCouldNotCompute();  // Multiple non-constant idx's.
5245       VarIdx = GEP->getOperand(i);
5246       VarIdxNum = i-2;
5247       Indexes.push_back(nullptr);
5248     }
5249 
5250   // Loop-invariant loads may be a byproduct of loop optimization. Skip them.
5251   if (!VarIdx)
5252     return getCouldNotCompute();
5253 
5254   // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
5255   // Check to see if X is a loop variant variable value now.
5256   const SCEV *Idx = getSCEV(VarIdx);
5257   Idx = getSCEVAtScope(Idx, L);
5258 
5259   // We can only recognize very limited forms of loop index expressions, in
5260   // particular, only affine AddRec's like {C1,+,C2}.
5261   const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
5262   if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) ||
5263       !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
5264       !isa<SCEVConstant>(IdxExpr->getOperand(1)))
5265     return getCouldNotCompute();
5266 
5267   unsigned MaxSteps = MaxBruteForceIterations;
5268   for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
5269     ConstantInt *ItCst = ConstantInt::get(
5270                            cast<IntegerType>(IdxExpr->getType()), IterationNum);
5271     ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
5272 
5273     // Form the GEP offset.
5274     Indexes[VarIdxNum] = Val;
5275 
5276     Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(),
5277                                                          Indexes);
5278     if (!Result) break;  // Cannot compute!
5279 
5280     // Evaluate the condition for this iteration.
5281     Result = ConstantExpr::getICmp(predicate, Result, RHS);
5282     if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
5283     if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
5284 #if 0
5285       dbgs() << "\n***\n*** Computed loop count " << *ItCst
5286              << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
5287              << "***\n";
5288 #endif
5289       ++NumArrayLenItCounts;
5290       return getConstant(ItCst);   // Found terminating iteration!
5291     }
5292   }
5293   return getCouldNotCompute();
5294 }
5295 
5296 
5297 /// CanConstantFold - Return true if we can constant fold an instruction of the
5298 /// specified type, assuming that all operands were constants.
5299 static bool CanConstantFold(const Instruction *I) {
5300   if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
5301       isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
5302       isa<LoadInst>(I))
5303     return true;
5304 
5305   if (const CallInst *CI = dyn_cast<CallInst>(I))
5306     if (const Function *F = CI->getCalledFunction())
5307       return canConstantFoldCallTo(F);
5308   return false;
5309 }
5310 
5311 /// Determine whether this instruction can constant evolve within this loop
5312 /// assuming its operands can all constant evolve.
5313 static bool canConstantEvolve(Instruction *I, const Loop *L) {
5314   // An instruction outside of the loop can't be derived from a loop PHI.
5315   if (!L->contains(I)) return false;
5316 
5317   if (isa<PHINode>(I)) {
5318     // We don't currently keep track of the control flow needed to evaluate
5319     // PHIs, so we cannot handle PHIs inside of loops.
5320     return L->getHeader() == I->getParent();
5321   }
5322 
5323   // If we won't be able to constant fold this expression even if the operands
5324   // are constants, bail early.
5325   return CanConstantFold(I);
5326 }
5327 
5328 /// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
5329 /// recursing through each instruction operand until reaching a loop header phi.
5330 static PHINode *
5331 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
5332                                DenseMap<Instruction *, PHINode *> &PHIMap) {
5333 
5334   // We can evaluate this instruction if all of its operands are constant
5335   // or derived from a PHI node themselves.
5336   PHINode *PHI = nullptr;
5337   for (Instruction::op_iterator OpI = UseInst->op_begin(),
5338          OpE = UseInst->op_end(); OpI != OpE; ++OpI) {
5339 
5340     if (isa<Constant>(*OpI)) continue;
5341 
5342     Instruction *OpInst = dyn_cast<Instruction>(*OpI);
5343     if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
5344 
5345     PHINode *P = dyn_cast<PHINode>(OpInst);
5346     if (!P)
5347       // If this operand is already visited, reuse the prior result.
5348       // We may have P != PHI if this is the deepest point at which the
5349       // inconsistent paths meet.
5350       P = PHIMap.lookup(OpInst);
5351     if (!P) {
5352       // Recurse and memoize the results, whether a phi is found or not.
5353       // This recursive call invalidates pointers into PHIMap.
5354       P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap);
5355       PHIMap[OpInst] = P;
5356     }
5357     if (!P)
5358       return nullptr;  // Not evolving from PHI
5359     if (PHI && PHI != P)
5360       return nullptr;  // Evolving from multiple different PHIs.
5361     PHI = P;
5362   }
5363   // This is an expression evolving from a constant PHI!
5364   return PHI;
5365 }
5366 
5367 /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
5368 /// in the loop that V is derived from.  We allow arbitrary operations along the
5369 /// way, but the operands of an operation must either be constants or a value
5370 /// derived from a constant PHI.  If this expression does not fit with these
5371 /// constraints, return null.
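///
/// For example, given the (illustrative) header phi
///   %i = phi i32 [ 0, %preheader ], [ %i.next, %latch ]
///   %i.next = add i32 %i, 1
/// getConstantEvolvingPHI(%i.next, L) returns %i, whereas an expression
/// mixing two different header phis yields null.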
5372 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
5373   Instruction *I = dyn_cast<Instruction>(V);
5374   if (!I || !canConstantEvolve(I, L)) return nullptr;
5375 
5376   if (PHINode *PN = dyn_cast<PHINode>(I)) {
5377     return PN;
5378   }
5379 
5380   // Record non-constant instructions contained by the loop.
5381   DenseMap<Instruction *, PHINode *> PHIMap;
5382   return getConstantEvolvingPHIOperands(I, L, PHIMap);
5383 }
5384 
5385 /// EvaluateExpression - Given an expression that passes the
5386 /// getConstantEvolvingPHI predicate, evaluate its value assuming the
5387 /// instructions in the loop have the constant values recorded in Vals.  If
5388 /// we can't fold this expression for some reason, return null.
5389 static Constant *EvaluateExpression(Value *V, const Loop *L,
5390                                     DenseMap<Instruction *, Constant *> &Vals,
5391                                     const DataLayout &DL,
5392                                     const TargetLibraryInfo *TLI) {
5393   // Convenient constant check, but redundant for recursive calls.
5394   if (Constant *C = dyn_cast<Constant>(V)) return C;
5395   Instruction *I = dyn_cast<Instruction>(V);
5396   if (!I) return nullptr;
5397 
5398   if (Constant *C = Vals.lookup(I)) return C;
5399 
5400   // An instruction inside the loop depends on a value outside the loop that we
5401   // weren't given a mapping for, or a value such as a call inside the loop.
5402   if (!canConstantEvolve(I, L)) return nullptr;
5403 
5404   // An unmapped PHI can be due to a branch or another loop inside this loop,
5405   // or due to this not being the initial iteration through a loop where we
5406   // couldn't compute the evolution of this particular PHI last time.
5407   if (isa<PHINode>(I)) return nullptr;
5408 
5409   std::vector<Constant*> Operands(I->getNumOperands());
5410 
5411   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
5412     Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
5413     if (!Operand) {
5414       Operands[i] = dyn_cast<Constant>(I->getOperand(i));
5415       if (!Operands[i]) return nullptr;
5416       continue;
5417     }
5418     Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
5419     Vals[Operand] = C;
5420     if (!C) return nullptr;
5421     Operands[i] = C;
5422   }
5423 
5424   if (CmpInst *CI = dyn_cast<CmpInst>(I))
5425     return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
5426                                            Operands[1], DL, TLI);
5427   if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
5428     if (!LI->isVolatile())
5429       return ConstantFoldLoadFromConstPtr(Operands[0], DL);
5430   }
5431   return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, DL,
5432                                   TLI);
5433 }
5434 
5435 /// getConstantEvolutionLoopExitValue - If the specified PHI is in the header
5436 /// of its containing loop, the loop executes a constant number of times
5437 /// (BEs backedges), and the PHI node is just a recurrence involving
5438 /// constants, fold it and return the exit value, or null on failure.
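///
/// A hypothetical loop this covers:
///   for (i = 0; i != 8; ++i) x = x * 3 + 1;   // x enters at a constant
/// Given BEs = 8, the recurrence for x is replayed from its constant
/// start value for eight iterations and the resulting constant returned.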
5439 Constant *
5440 ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
5441                                                    const APInt &BEs,
5442                                                    const Loop *L) {
5443   DenseMap<PHINode*, Constant*>::const_iterator I =
5444     ConstantEvolutionLoopExitValue.find(PN);
5445   if (I != ConstantEvolutionLoopExitValue.end())
5446     return I->second;
5447 
5448   if (BEs.ugt(MaxBruteForceIterations))
5449     return ConstantEvolutionLoopExitValue[PN] = nullptr;  // Not going to evaluate it.
5450 
5451   Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
5452 
5453   DenseMap<Instruction *, Constant *> CurrentIterVals;
5454   BasicBlock *Header = L->getHeader();
5455   assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
5456 
5457   // Since the loop is canonicalized, the PHI node must have two entries.  One
5458   // entry must be a constant (coming in from outside of the loop), and the
5459   // second must be derived from the same PHI.
5460   bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
5461   PHINode *PHI = nullptr;
5462   for (BasicBlock::iterator I = Header->begin();
5463        (PHI = dyn_cast<PHINode>(I)); ++I) {
5464     Constant *StartCST =
5465       dyn_cast<Constant>(PHI->getIncomingValue(!SecondIsBackedge));
5466     if (!StartCST) continue;
5467     CurrentIterVals[PHI] = StartCST;
5468   }
5469   if (!CurrentIterVals.count(PN))
5470     return RetVal = nullptr;
5471 
5472   Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
5473 
5474   // Execute the loop symbolically to determine the exit value.
5475   if (BEs.getActiveBits() >= 32)
5476     return RetVal = nullptr; // More than 2^32-1 iterations?? Not doing it!
5477 
5478   unsigned NumIterations = BEs.getZExtValue(); // must be in range
5479   unsigned IterationNum = 0;
5480   const DataLayout &DL = F->getParent()->getDataLayout();
5481   for (; ; ++IterationNum) {
5482     if (IterationNum == NumIterations)
5483       return RetVal = CurrentIterVals[PN];  // Got exit value!
5484 
5485     // Compute the value of the PHIs for the next iteration.
5486     // EvaluateExpression adds non-phi values to the CurrentIterVals map.
5487     DenseMap<Instruction *, Constant *> NextIterVals;
5488     Constant *NextPHI =
5489         EvaluateExpression(BEValue, L, CurrentIterVals, DL, TLI);
5490     if (!NextPHI)
5491       return nullptr;        // Couldn't evaluate!
5492     NextIterVals[PN] = NextPHI;
5493 
5494     bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
5495 
5496     // Also evaluate the other PHI nodes.  However, we don't get to stop if we
5497     // cease to be able to evaluate one of them or if they stop evolving,
5498     // because that doesn't necessarily prevent us from computing PN.
5499     SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
5500     for (DenseMap<Instruction *, Constant *>::const_iterator
5501            I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I){
5502       PHINode *PHI = dyn_cast<PHINode>(I->first);
5503       if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
5504       PHIsToCompute.push_back(std::make_pair(PHI, I->second));
5505     }
5506     // We use two distinct loops because EvaluateExpression may invalidate any
5507     // iterators into CurrentIterVals.
5508     for (SmallVectorImpl<std::pair<PHINode *, Constant*> >::const_iterator
5509              I = PHIsToCompute.begin(), E = PHIsToCompute.end(); I != E; ++I) {
5510       PHINode *PHI = I->first;
5511       Constant *&NextPHI = NextIterVals[PHI];
5512       if (!NextPHI) {   // Not already computed.
5513         Value *BEValue = PHI->getIncomingValue(SecondIsBackedge);
5514         NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, TLI);
5515       }
5516       if (NextPHI != I->second)
5517         StoppedEvolving = false;
5518     }
5519 
5520     // If all entries in CurrentIterVals == NextIterVals then we can stop
5521     // iterating; the values can't continue to change.
5522     if (StoppedEvolving)
5523       return RetVal = CurrentIterVals[PN];
5524 
5525     CurrentIterVals.swap(NextIterVals);
5526   }
5527 }
5528 
5529 /// ComputeExitCountExhaustively - If the loop is known to execute a
5530 /// constant number of times (the condition evolves only from constants),
5531 /// try to evaluate a few iterations of the loop until the exit
5532 /// condition gets a value of ExitWhen (true or false).  If we cannot
5533 /// evaluate the trip count of the loop, return getCouldNotCompute().
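///
/// For example, in the (illustrative) loop
///   for (i = 3; i * i != 25; i += 2) ;
/// the condition evolves only from the phi for i, so it is re-evaluated
/// iteration by iteration (i = 3, then i = 5) until it yields ExitWhen,
/// giving a backedge-taken count of 1 here.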
5534 const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L,
5535                                                           Value *Cond,
5536                                                           bool ExitWhen) {
5537   PHINode *PN = getConstantEvolvingPHI(Cond, L);
5538   if (!PN) return getCouldNotCompute();
5539 
5540   // If the loop is canonicalized, the PHI will have exactly two entries.
5541   // That's the only form we support here.
5542   if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
5543 
5544   DenseMap<Instruction *, Constant *> CurrentIterVals;
5545   BasicBlock *Header = L->getHeader();
5546   assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
5547 
5548   // One entry must be a constant (coming in from outside of the loop), and the
5549   // second must be derived from the same PHI.
5550   bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
5551   PHINode *PHI = nullptr;
5552   for (BasicBlock::iterator I = Header->begin();
5553        (PHI = dyn_cast<PHINode>(I)); ++I) {
5554     Constant *StartCST =
5555       dyn_cast<Constant>(PHI->getIncomingValue(!SecondIsBackedge));
5556     if (!StartCST) continue;
5557     CurrentIterVals[PHI] = StartCST;
5558   }
5559   if (!CurrentIterVals.count(PN))
5560     return getCouldNotCompute();
5561 
5562   // Okay, we found a PHI node that defines the trip count of this loop.  Execute
5563   // the loop symbolically to determine when the condition gets a value of
5564   // "ExitWhen".
5565   unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
5566   const DataLayout &DL = F->getParent()->getDataLayout();
5567   for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
5568     ConstantInt *CondVal = dyn_cast_or_null<ConstantInt>(
5569         EvaluateExpression(Cond, L, CurrentIterVals, DL, TLI));
5570 
5571     // Couldn't symbolically evaluate.
5572     if (!CondVal) return getCouldNotCompute();
5573 
5574     if (CondVal->getValue() == uint64_t(ExitWhen)) {
5575       ++NumBruteForceTripCountsComputed;
5576       return getConstant(Type::getInt32Ty(getContext()), IterationNum);
5577     }
5578 
5579     // Update all the PHI nodes for the next iteration.
5580     DenseMap<Instruction *, Constant *> NextIterVals;
5581 
5582     // Create a list of which PHIs we need to compute. We want to do this before
5583     // calling EvaluateExpression on them because that may invalidate iterators
5584     // into CurrentIterVals.
5585     SmallVector<PHINode *, 8> PHIsToCompute;
5586     for (DenseMap<Instruction *, Constant *>::const_iterator
5587            I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I){
5588       PHINode *PHI = dyn_cast<PHINode>(I->first);
5589       if (!PHI || PHI->getParent() != Header) continue;
5590       PHIsToCompute.push_back(PHI);
5591     }
5592     for (SmallVectorImpl<PHINode *>::const_iterator I = PHIsToCompute.begin(),
5593              E = PHIsToCompute.end(); I != E; ++I) {
5594       PHINode *PHI = *I;
5595       Constant *&NextPHI = NextIterVals[PHI];
5596       if (NextPHI) continue;    // Already computed!
5597 
5598       Value *BEValue = PHI->getIncomingValue(SecondIsBackedge);
5599       NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, TLI);
5600     }
5601     CurrentIterVals.swap(NextIterVals);
5602   }
5603 
5604   // Too many iterations were needed to evaluate.
5605   return getCouldNotCompute();
5606 }
5607 
5608 /// getSCEVAtScope - Return a SCEV expression for the specified value
5609 /// at the specified scope in the program.  The L value specifies a loop
5610 /// nest to evaluate the expression at, where null means the top-level scope
5611 /// and a specified loop means the scope immediately inside of that loop.
5612 ///
5613 /// This method can be used to compute the exit value for a variable defined
5614 /// in a loop by querying what the value will hold in the parent loop.
5615 ///
5616 /// In the case that a relevant loop exit value cannot be computed, the
5617 /// original value V is returned.
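///
/// For example, if V is {0,+,2}<L> and the backedge-taken count of L is
/// known to be 10, then getSCEVAtScope(V, L->getParentLoop()) folds V to
/// the constant 20, the value the recurrence holds when L exits.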
5618 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
5619   // Check to see if we've folded this expression at this loop before.
5620   SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = ValuesAtScopes[V];
5621   for (unsigned u = 0; u < Values.size(); u++) {
5622     if (Values[u].first == L)
5623       return Values[u].second ? Values[u].second : V;
5624   }
5625   Values.push_back(std::make_pair(L, static_cast<const SCEV *>(nullptr)));
5626   // Otherwise compute it.
5627   const SCEV *C = computeSCEVAtScope(V, L);
5628   SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values2 = ValuesAtScopes[V];
5629   for (unsigned u = Values2.size(); u > 0; u--) {
5630     if (Values2[u - 1].first == L) {
5631       Values2[u - 1].second = C;
5632       break;
5633     }
5634   }
5635   return C;
5636 }
5637 
5638 /// This builds up a Constant using the ConstantExpr interface.  That way, we
5639 /// will return Constants for objects which aren't represented by a
5640 /// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
5641 /// Returns NULL if the SCEV isn't representable as a Constant.
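///
/// For example, an add SCEV of a global @g and the constant 4 is rebuilt
/// below as a getelementptr constant expression over an i8* bitcast of
/// @g, since the offset operands have already been converted to bytes.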
5642 static Constant *BuildConstantFromSCEV(const SCEV *V) {
5643   switch (static_cast<SCEVTypes>(V->getSCEVType())) {
5644     case scCouldNotCompute:
5645     case scAddRecExpr:
5646       break;
5647     case scConstant:
5648       return cast<SCEVConstant>(V)->getValue();
5649     case scUnknown:
5650       return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
5651     case scSignExtend: {
5652       const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
5653       if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
5654         return ConstantExpr::getSExt(CastOp, SS->getType());
5655       break;
5656     }
5657     case scZeroExtend: {
5658       const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
5659       if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
5660         return ConstantExpr::getZExt(CastOp, SZ->getType());
5661       break;
5662     }
5663     case scTruncate: {
5664       const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
5665       if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
5666         return ConstantExpr::getTrunc(CastOp, ST->getType());
5667       break;
5668     }
5669     case scAddExpr: {
5670       const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
5671       if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
5672         if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
5673           unsigned AS = PTy->getAddressSpace();
5674           Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
5675           C = ConstantExpr::getBitCast(C, DestPtrTy);
5676         }
5677         for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
5678           Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
5679           if (!C2) return nullptr;
5680 
5681           // First pointer!
5682           if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) {
5683             unsigned AS = C2->getType()->getPointerAddressSpace();
5684             std::swap(C, C2);
5685             Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
5686             // The offsets have been converted to bytes.  We can add bytes to an
5687             // i8* by GEP with the byte count in the first index.
5688             C = ConstantExpr::getBitCast(C, DestPtrTy);
5689           }
5690 
5691           // Don't bother trying to sum two pointers. We probably can't
5692           // statically compute a load that results from it anyway.
5693           if (C2->getType()->isPointerTy())
5694             return nullptr;
5695 
5696           if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
5697             if (PTy->getElementType()->isStructTy())
5698               C2 = ConstantExpr::getIntegerCast(
5699                   C2, Type::getInt32Ty(C->getContext()), true);
5700             C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2);
5701           } else
5702             C = ConstantExpr::getAdd(C, C2);
5703         }
5704         return C;
5705       }
5706       break;
5707     }
5708     case scMulExpr: {
5709       const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
5710       if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
5711         // Don't bother with pointers at all.
5712         if (C->getType()->isPointerTy()) return nullptr;
5713         for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
5714           Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
5715           if (!C2 || C2->getType()->isPointerTy()) return nullptr;
5716           C = ConstantExpr::getMul(C, C2);
5717         }
5718         return C;
5719       }
5720       break;
5721     }
5722     case scUDivExpr: {
5723       const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
5724       if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
5725         if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
5726           if (LHS->getType() == RHS->getType())
5727             return ConstantExpr::getUDiv(LHS, RHS);
5728       break;
5729     }
5730     case scSMaxExpr:
5731     case scUMaxExpr:
5732       break; // TODO: smax, umax.
5733   }
5734   return nullptr;
5735 }
5736 
5737 const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
5738   if (isa<SCEVConstant>(V)) return V;
5739 
5740   // If this instruction is evolved from a constant-evolving PHI, compute the
5741   // exit value from the loop without using SCEVs.
5742   if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
5743     if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
5744       const Loop *LI = (*this->LI)[I->getParent()];
5745       if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
5746         if (PHINode *PN = dyn_cast<PHINode>(I))
5747           if (PN->getParent() == LI->getHeader()) {
5748             // Okay, there is no closed form solution for the PHI node.  Check
5749             // to see if the loop that contains it has a known backedge-taken
5750             // count.  If so, we may be able to force computation of the exit
5751             // value.
5752             const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
5753             if (const SCEVConstant *BTCC =
5754                   dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
5755               // Okay, we know how many times the containing loop executes.  If
5756               // this is a constant evolving PHI node, get the final value at
5757               // the specified iteration number.
5758               Constant *RV = getConstantEvolutionLoopExitValue(PN,
5759                                                    BTCC->getValue()->getValue(),
5760                                                                LI);
5761               if (RV) return getSCEV(RV);
5762             }
5763           }
5764 
5765       // Okay, this is an expression that we cannot symbolically evaluate
5766       // into a SCEV.  Check to see if it's possible to symbolically evaluate
5767       // the arguments into constants, and if so, try to constant propagate the
5768       // result.  This is particularly useful for computing loop exit values.
5769       if (CanConstantFold(I)) {
5770         SmallVector<Constant *, 4> Operands;
5771         bool MadeImprovement = false;
5772         for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
5773           Value *Op = I->getOperand(i);
5774           if (Constant *C = dyn_cast<Constant>(Op)) {
5775             Operands.push_back(C);
5776             continue;
5777           }
5778 
5779           // If any of the operands is non-constant and if they are
5780           // non-integer and non-pointer, don't even try to analyze them
5781           // with scev techniques.
5782           if (!isSCEVable(Op->getType()))
5783             return V;
5784 
5785           const SCEV *OrigV = getSCEV(Op);
5786           const SCEV *OpV = getSCEVAtScope(OrigV, L);
5787           MadeImprovement |= OrigV != OpV;
5788 
5789           Constant *C = BuildConstantFromSCEV(OpV);
5790           if (!C) return V;
5791           if (C->getType() != Op->getType())
5792             C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
5793                                                               Op->getType(),
5794                                                               false),
5795                                       C, Op->getType());
5796           Operands.push_back(C);
5797         }
5798 
5799         // Check to see if getSCEVAtScope actually made an improvement.
5800         if (MadeImprovement) {
5801           Constant *C = nullptr;
5802           const DataLayout &DL = F->getParent()->getDataLayout();
5803           if (const CmpInst *CI = dyn_cast<CmpInst>(I))
5804             C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
5805                                                 Operands[1], DL, TLI);
5806           else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) {
5807             if (!LI->isVolatile())
5808               C = ConstantFoldLoadFromConstPtr(Operands[0], DL);
5809           } else
5810             C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands,
5811                                          DL, TLI);
5812           if (!C) return V;
5813           return getSCEV(C);
5814         }
5815       }
5816     }
5817 
5818     // This is some other type of SCEVUnknown, just return it.
5819     return V;
5820   }
5821 
5822   if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
5823     // Avoid performing the look-up in the common case where the specified
5824     // expression has no loop-variant portions.
5825     for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
5826       const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
5827       if (OpAtScope != Comm->getOperand(i)) {
5828         // Okay, at least one of these operands is loop variant but might be
5829         // foldable.  Build a new instance of the folded commutative expression.
5830         SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
5831                                             Comm->op_begin()+i);
5832         NewOps.push_back(OpAtScope);
5833 
5834         for (++i; i != e; ++i) {
5835           OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
5836           NewOps.push_back(OpAtScope);
5837         }
5838         if (isa<SCEVAddExpr>(Comm))
5839           return getAddExpr(NewOps);
5840         if (isa<SCEVMulExpr>(Comm))
5841           return getMulExpr(NewOps);
5842         if (isa<SCEVSMaxExpr>(Comm))
5843           return getSMaxExpr(NewOps);
5844         if (isa<SCEVUMaxExpr>(Comm))
5845           return getUMaxExpr(NewOps);
5846         llvm_unreachable("Unknown commutative SCEV type!");
5847       }
5848     }
5849     // If we got here, all operands are loop invariant.
5850     return Comm;
5851   }
5852 
5853   if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
5854     const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
5855     const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
5856     if (LHS == Div->getLHS() && RHS == Div->getRHS())
5857       return Div;   // must be loop invariant
5858     return getUDivExpr(LHS, RHS);
5859   }
5860 
5861   // If this is a loop recurrence for a loop that does not contain L, then we
5862   // are dealing with the final value computed by the loop.
5863   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
5864     // First, attempt to evaluate each operand.
5865     // Avoid performing the look-up in the common case where the specified
5866     // expression has no loop-variant portions.
5867     for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
5868       const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
5869       if (OpAtScope == AddRec->getOperand(i))
5870         continue;
5871 
5872       // Okay, at least one of these operands is loop variant but might be
5873       // foldable.  Build a new instance of the folded add recurrence.
5874       SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
5875                                           AddRec->op_begin()+i);
5876       NewOps.push_back(OpAtScope);
5877       for (++i; i != e; ++i)
5878         NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
5879 
5880       const SCEV *FoldedRec =
5881         getAddRecExpr(NewOps, AddRec->getLoop(),
5882                       AddRec->getNoWrapFlags(SCEV::FlagNW));
5883       AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
5884       // The addrec may be folded to a nonrecurrence, for example, if the
5885       // induction variable is multiplied by zero after constant folding. Go
5886       // ahead and return the folded value.
5887       if (!AddRec)
5888         return FoldedRec;
5889       break;
5890     }
5891 
5892     // If the scope is outside the addrec's loop, evaluate it by using the
5893     // loop exit value of the addrec.
5894     if (!AddRec->getLoop()->contains(L)) {
5895       // To evaluate this recurrence, we need to know how many times the AddRec
5896       // loop iterates.  Compute this now.
5897       const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
5898       if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
5899 
5900       // Then, evaluate the AddRec.
5901       return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
5902     }
5903 
5904     return AddRec;
5905   }
5906 
5907   if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
5908     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
5909     if (Op == Cast->getOperand())
5910       return Cast;  // must be loop invariant
5911     return getZeroExtendExpr(Op, Cast->getType());
5912   }
5913 
5914   if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
5915     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
5916     if (Op == Cast->getOperand())
5917       return Cast;  // must be loop invariant
5918     return getSignExtendExpr(Op, Cast->getType());
5919   }
5920 
5921   if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
5922     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
5923     if (Op == Cast->getOperand())
5924       return Cast;  // must be loop invariant
5925     return getTruncateExpr(Op, Cast->getType());
5926   }
5927 
5928   llvm_unreachable("Unknown SCEV type!");
5929 }
5930 
5931 /// getSCEVAtScope - This is a convenience function which does
5932 /// getSCEVAtScope(getSCEV(V), L).
5933 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
5934   return getSCEVAtScope(getSCEV(V), L);
5935 }
5936 
5937 /// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
5938 /// following equation:
5939 ///
5940 ///     A * X = B (mod N)
5941 ///
5942 /// where N = 2^BW and BW is the common bit width of A and B. The signedness of
5943 /// A and B isn't important.
5944 ///
5945 /// If the equation does not have a solution, SCEVCouldNotCompute is returned.
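///
/// A worked instance with BW = 8: to solve 4 * X = 8 (mod 256), we have
/// D = gcd(4, 256) = 4 (Mult2 = 2).  B = 8 has at least 2 trailing zeros,
/// so a solution exists.  The inverse of (4/4) modulo 64 is 1, so the
/// minimum root is X = 1 * (8/4) mod 64 = 2; indeed 4 * 2 = 8 (mod 256).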
5946 static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
5947                                                ScalarEvolution &SE) {
5948   uint32_t BW = A.getBitWidth();
5949   assert(BW == B.getBitWidth() && "Bit widths must be the same.");
5950   assert(A != 0 && "A must be non-zero.");
5951 
5952   // 1. D = gcd(A, N)
5953   //
5954   // The gcd of A and N may have only one prime factor: 2. The number of
5955   // trailing zeros in A is the multiplicity of that factor.
5956   uint32_t Mult2 = A.countTrailingZeros();
5957   // D = 2^Mult2
5958 
5959   // 2. Check if B is divisible by D.
5960   //
5961   // B is divisible by D if and only if the multiplicity of prime factor 2 for B
5962   // is not less than multiplicity of this prime factor for D.
5963   if (B.countTrailingZeros() < Mult2)
5964     return SE.getCouldNotCompute();
5965 
5966   // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
5967   // modulo (N / D).
5968   //
5969   // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
5970   // bit width during computations.
5971   APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
5972   APInt Mod(BW + 1, 0);
5973   Mod.setBit(BW - Mult2);  // Mod = N / D
5974   APInt I = AD.multiplicativeInverse(Mod);
5975 
5976   // 4. Compute the minimum unsigned root of the equation:
5977   // I * (B / D) mod (N / D)
5978   APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
5979 
5980   // The result is guaranteed to be less than 2^BW so we may truncate it to BW
5981   // bits.
5982   return SE.getConstant(Result.trunc(BW));
5983 }
5984 
5985 /// SolveQuadraticEquation - Find the roots of the quadratic equation for the
5986 /// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
5987 /// might be the same) or two SCEVCouldNotCompute objects.
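///
/// As a reminder of the conversion performed below: the chrec {L,+,M,+,N}
/// at iteration i has the value L + M*i + N*(i*(i-1)/2), i.e. it is the
/// polynomial (N/2)*i^2 + (M - N/2)*i + L, which is where the A, B, and C
/// coefficients used in the quadratic formula come from.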
5988 ///
5989 static std::pair<const SCEV *,const SCEV *>
5990 SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
5991   assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
5992   const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
5993   const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
5994   const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
5995 
5996   // We currently can only solve this if the coefficients are constants.
5997   if (!LC || !MC || !NC) {
5998     const SCEV *CNC = SE.getCouldNotCompute();
5999     return std::make_pair(CNC, CNC);
6000   }
6001 
6002   uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
6003   const APInt &L = LC->getValue()->getValue();
6004   const APInt &M = MC->getValue()->getValue();
6005   const APInt &N = NC->getValue()->getValue();
6006   APInt Two(BitWidth, 2);
6007   APInt Four(BitWidth, 4);
6008 
6009   {
6010     using namespace APIntOps;
6011     const APInt& C = L;
6012     // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
6013     // The B coefficient is M-N/2
6014     APInt B(M);
6015     B -= sdiv(N,Two);
6016 
6017     // The A coefficient is N/2
6018     APInt A(N.sdiv(Two));
6019 
6020     // Compute the B^2-4ac term.
6021     APInt SqrtTerm(B);
6022     SqrtTerm *= B;
6023     SqrtTerm -= Four * (A * C);
6024 
6025     if (SqrtTerm.isNegative()) {
6026       // The loop is provably infinite.
6027       const SCEV *CNC = SE.getCouldNotCompute();
6028       return std::make_pair(CNC, CNC);
6029     }
6030 
6031     // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
6032     // integer value or else APInt::sqrt() will assert.
6033     APInt SqrtVal(SqrtTerm.sqrt());
6034 
6035     // Compute the two solutions for the quadratic formula.
6036     // The divisions must be performed as signed divisions.
6037     APInt NegB(-B);
6038     APInt TwoA(A << 1);
6039     if (TwoA.isMinValue()) {
6040       const SCEV *CNC = SE.getCouldNotCompute();
6041       return std::make_pair(CNC, CNC);
6042     }
6043 
6044     LLVMContext &Context = SE.getContext();
6045 
6046     ConstantInt *Solution1 =
6047       ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA));
6048     ConstantInt *Solution2 =
6049       ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA));
6050 
6051     return std::make_pair(SE.getConstant(Solution1),
6052                           SE.getConstant(Solution2));
6053   } // end scope using APIntOps
6054 }
6055 
6056 /// HowFarToZero - Return the number of times a backedge comparing the specified
6057 /// value to zero will execute.  If not computable, return CouldNotCompute.
6058 ///
6059 /// This is only used for loops with an "x != y" exit test. The exit
6060 /// condition is now expressed as a single expression, V = x-y. So the exit
6061 /// test is effectively V != 0.  We know and take advantage of the fact that
6062 /// this expression is only used in a comparison-with-zero context.
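///
/// For example, the exit test of the (illustrative) loop
///   for (i = 10; i != 0; --i) ;
/// yields V = {10,+,-1}, an affine addrec with a unitary step, and the
/// distance from zero, 10, is returned as the exact backedge-taken count.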
6063 ScalarEvolution::ExitLimit
6064 ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
6065   // If the value is a constant
6066   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
6067     // If the value is already zero, the branch will execute zero times.
6068     if (C->getValue()->isZero()) return C;
6069     return getCouldNotCompute();  // Otherwise it will loop infinitely.
6070   }
6071 
6072   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
6073   if (!AddRec || AddRec->getLoop() != L)
6074     return getCouldNotCompute();
6075 
6076   // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
6077   // the quadratic equation to solve it.
6078   if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
6079     std::pair<const SCEV *,const SCEV *> Roots =
6080       SolveQuadraticEquation(AddRec, *this);
6081     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
6082     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
6083     if (R1 && R2) {
6084 #if 0
6085       dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1
6086              << "  sol#2: " << *R2 << "\n";
6087 #endif
6088       // Pick the smallest positive root value.
6089       if (ConstantInt *CB =
6090           dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
6091                                                       R1->getValue(),
6092                                                       R2->getValue()))) {
6093         if (!CB->getZExtValue())
6094           std::swap(R1, R2);   // R1 is the minimum root now.
6095 
6096         // We can only use this value if the chrec ends up with an exact zero
6097         // value at this index.  When solving for "X*X != 5", for example, we
6098         // should not accept a root of 2.
6099         const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
6100         if (Val->isZero())
6101           return R1;  // We found a quadratic root!
6102       }
6103     }
6104     return getCouldNotCompute();
6105   }
6106 
6107   // Otherwise we can only handle this if it is affine.
6108   if (!AddRec->isAffine())
6109     return getCouldNotCompute();
6110 
6111   // If this is an affine expression, the execution count of this branch is
6112   // the minimum unsigned root of the following equation:
6113   //
6114   //     Start + Step*N = 0 (mod 2^BW)
6115   //
6116   // equivalent to:
6117   //
6118   //             Step*N = -Start (mod 2^BW)
6119   //
6120   // where BW is the common bit width of Start and Step.
6121 
6122   // Get the initial value for the loop.
6123   const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
6124   const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
6125 
6126   // For now we handle only constant steps.
6127   //
6128   // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
6129   // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
6130   // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
6131   // We have not yet seen any such cases.
6132   const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
6133   if (!StepC || StepC->getValue()->equalsInt(0))
6134     return getCouldNotCompute();
6135 
6136   // For positive steps (counting up until unsigned overflow):
6137   //   N = -Start/Step (as unsigned)
6138   // For negative steps (counting down to zero):
6139   //   N = Start/-Step
6140   // First compute the unsigned distance from zero in the direction of Step.
6141   bool CountDown = StepC->getValue()->getValue().isNegative();
6142   const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
6143 
6144   // Handle unitary steps, which cannot wraparound.
6145   // 1*N = -Start; -1*N = Start (mod 2^BW), so:
6146   //   N = Distance (as unsigned)
6147   if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) {
6148     ConstantRange CR = getUnsignedRange(Start);
6149     const SCEV *MaxBECount;
6150     if (!CountDown && CR.getUnsignedMin().isMinValue())
6151       // When counting up, the worst starting value is 1, not 0.
6152       MaxBECount = CR.getUnsignedMax().isMinValue()
6153         ? getConstant(APInt::getMinValue(CR.getBitWidth()))
6154         : getConstant(APInt::getMaxValue(CR.getBitWidth()));
6155     else
6156       MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
6157                                          : -CR.getUnsignedMin());
6158     return ExitLimit(Distance, MaxBECount);
6159   }
6160 
6161   // As a special case, handle the instance where Step is a positive power of
6162   // two. In this case, determining whether Step divides Distance evenly can be
6163   // done by counting and comparing the number of trailing zeros of Step and
6164   // Distance.
6165   if (!CountDown) {
6166     const APInt &StepV = StepC->getValue()->getValue();
6167     // StepV.isPowerOf2() returns true if StepV is a positive power of two.  It
6168     // also returns true if StepV is maximally negative (e.g., INT_MIN), but that
6169     // case is not handled as this code is guarded by !CountDown.
6170     if (StepV.isPowerOf2() &&
6171         GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros())
6172       return getUDivExactExpr(Distance, Step);
6173   }
6174 
6175   // If the condition controls loop exit (the loop exits only if the expression
6176   // is true) and the addition is no-wrap we can use unsigned divide to
6177   // compute the backedge count.  In this case, the step may not divide the
6178   // distance, but we don't care because if the condition is "missed" the loop
6179   // will have undefined behavior due to wrapping.
6180   if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) {
6181     const SCEV *Exact =
6182         getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
6183     return ExitLimit(Exact, Exact);
6184   }
6185 
6186   // Then, try to solve the above equation provided that Start is constant.
6187   if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
6188     return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
6189                                         -StartC->getValue()->getValue(),
6190                                         *this);
6191   return getCouldNotCompute();
6192 }
6193 
6194 /// HowFarToNonZero - Return the number of times a backedge checking the
6195 /// specified value for nonzero will execute.  If not computable, return
6196 /// CouldNotCompute
6197 ScalarEvolution::ExitLimit
6198 ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
6199   // Loops that look like: while (X == 0) are very strange indeed.  We don't
6200   // handle them yet except for the trivial case.  This could be expanded in the
6201   // future as needed.
6202 
6203   // If the value is a constant, check to see if it is known to be non-zero
6204   // already.  If so, the backedge will execute zero times.
6205   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
6206     if (!C->getValue()->isNullValue())
6207       return getConstant(C->getType(), 0);
6208     return getCouldNotCompute();  // Otherwise it will loop infinitely.
6209   }
6210 
6211   // We could implement others, but I really doubt anyone writes loops like
6212   // this, and if they did, they would already be constant folded.
6213   return getCouldNotCompute();
6214 }
6215 
6216 /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
6217 /// (which may not be an immediate predecessor) which has exactly one
6218 /// successor from which BB is reachable, paired with that successor, or a
6219 /// pair of null pointers if no such block is found.
6220 ///
6221 std::pair<BasicBlock *, BasicBlock *>
6222 ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
6223   // If the block has a unique predecessor, then there is no path from the
6224   // predecessor to the block that does not go through the direct edge
6225   // from the predecessor to the block.
6226   if (BasicBlock *Pred = BB->getSinglePredecessor())
6227     return std::make_pair(Pred, BB);
6228 
6229   // A loop's header is defined to be a block that dominates the loop.
6230   // If the header has a unique predecessor outside the loop, it must be
6231   // a block that has exactly one successor that can reach the loop.
6232   if (Loop *L = LI->getLoopFor(BB))
6233     return std::make_pair(L->getLoopPredecessor(), L->getHeader());
6234 
6235   return std::pair<BasicBlock *, BasicBlock *>();
6236 }
6237 
6238 /// HasSameValue - SCEV structural equivalence is usually sufficient for
6239 /// testing whether two expressions are equal; however, for the purposes of
6240 /// looking for a condition guarding a loop, it can be useful to be a little
6241 /// more general, since a front-end may have replicated the controlling
6242 /// expression.
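///
/// For instance, two distinct "sdiv i32 %a, %b" instructions map to
/// different SCEVUnknowns, but they are identical and do not read memory,
/// so they are treated here as having the same value.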
6243 ///
6244 static bool HasSameValue(const SCEV *A, const SCEV *B) {
6245   // Quick check to see if they are the same SCEV.
6246   if (A == B) return true;
6247 
6248   // Otherwise, if they're both SCEVUnknown, it's possible that they hold
6249   // two different instructions with the same value. Check for this case.
6250   if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
6251     if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
6252       if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
6253         if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
6254           if (AI->isIdenticalTo(BI) && !AI->mayReadFromMemory())
6255             return true;
6256 
6257   // Otherwise assume they may have a different value.
6258   return false;
6259 }
6260 
6261 /// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with
6262 /// predicate Pred. Return true iff any changes were made.
6263 ///
6264 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
6265                                            const SCEV *&LHS, const SCEV *&RHS,
6266                                            unsigned Depth) {
6267   bool Changed = false;
6268 
6269   // If we hit the max recursion limit, bail out.
6270   if (Depth >= 3)
6271     return false;
6272 
6273   // Canonicalize a constant to the right side.
6274   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
6275     // Check for both operands constant.
6276     if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
6277       if (ConstantExpr::getICmp(Pred,
6278                                 LHSC->getValue(),
6279                                 RHSC->getValue())->isNullValue())
6280         goto trivially_false;
6281       else
6282         goto trivially_true;
6283     }
6284     // Otherwise swap the operands to put the constant on the right.
6285     std::swap(LHS, RHS);
6286     Pred = ICmpInst::getSwappedPredicate(Pred);
6287     Changed = true;
6288   }
6289 
6290   // If we're comparing an addrec with a value which is loop-invariant in the
6291   // addrec's loop, put the addrec on the left. Also make a dominance check,
6292   // as both operands could be addrecs loop-invariant in each other's loop.
6293   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
6294     const Loop *L = AR->getLoop();
6295     if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
6296       std::swap(LHS, RHS);
6297       Pred = ICmpInst::getSwappedPredicate(Pred);
6298       Changed = true;
6299     }
6300   }
6301 
6302   // If there's a constant operand, canonicalize comparisons with boundary
6303   // cases, and canonicalize *-or-equal comparisons to regular comparisons.
6304   if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
6305     const APInt &RA = RC->getValue()->getValue();
6306     switch (Pred) {
6307     default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
6308     case ICmpInst::ICMP_EQ:
6309     case ICmpInst::ICMP_NE:
6310       // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
6311       if (!RA)
6312         if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
6313           if (const SCEVMulExpr *ME = dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
6314             if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
6315                 ME->getOperand(0)->isAllOnesValue()) {
6316               RHS = AE->getOperand(1);
6317               LHS = ME->getOperand(1);
6318               Changed = true;
6319             }
6320       break;
6321     case ICmpInst::ICMP_UGE:
6322       if ((RA - 1).isMinValue()) {
6323         Pred = ICmpInst::ICMP_NE;
6324         RHS = getConstant(RA - 1);
6325         Changed = true;
6326         break;
6327       }
6328       if (RA.isMaxValue()) {
6329         Pred = ICmpInst::ICMP_EQ;
6330         Changed = true;
6331         break;
6332       }
6333       if (RA.isMinValue()) goto trivially_true;
6334 
6335       Pred = ICmpInst::ICMP_UGT;
6336       RHS = getConstant(RA - 1);
6337       Changed = true;
6338       break;
6339     case ICmpInst::ICMP_ULE:
6340       if ((RA + 1).isMaxValue()) {
6341         Pred = ICmpInst::ICMP_NE;
6342         RHS = getConstant(RA + 1);
6343         Changed = true;
6344         break;
6345       }
6346       if (RA.isMinValue()) {
6347         Pred = ICmpInst::ICMP_EQ;
6348         Changed = true;
6349         break;
6350       }
6351       if (RA.isMaxValue()) goto trivially_true;
6352 
6353       Pred = ICmpInst::ICMP_ULT;
6354       RHS = getConstant(RA + 1);
6355       Changed = true;
6356       break;
6357     case ICmpInst::ICMP_SGE:
6358       if ((RA - 1).isMinSignedValue()) {
6359         Pred = ICmpInst::ICMP_NE;
6360         RHS = getConstant(RA - 1);
6361         Changed = true;
6362         break;
6363       }
6364       if (RA.isMaxSignedValue()) {
6365         Pred = ICmpInst::ICMP_EQ;
6366         Changed = true;
6367         break;
6368       }
6369       if (RA.isMinSignedValue()) goto trivially_true;
6370 
6371       Pred = ICmpInst::ICMP_SGT;
6372       RHS = getConstant(RA - 1);
6373       Changed = true;
6374       break;
6375     case ICmpInst::ICMP_SLE:
6376       if ((RA + 1).isMaxSignedValue()) {
6377         Pred = ICmpInst::ICMP_NE;
6378         RHS = getConstant(RA + 1);
6379         Changed = true;
6380         break;
6381       }
6382       if (RA.isMinSignedValue()) {
6383         Pred = ICmpInst::ICMP_EQ;
6384         Changed = true;
6385         break;
6386       }
6387       if (RA.isMaxSignedValue()) goto trivially_true;
6388 
6389       Pred = ICmpInst::ICMP_SLT;
6390       RHS = getConstant(RA + 1);
6391       Changed = true;
6392       break;
6393     case ICmpInst::ICMP_UGT:
6394       if (RA.isMinValue()) {
6395         Pred = ICmpInst::ICMP_NE;
6396         Changed = true;
6397         break;
6398       }
6399       if ((RA + 1).isMaxValue()) {
6400         Pred = ICmpInst::ICMP_EQ;
6401         RHS = getConstant(RA + 1);
6402         Changed = true;
6403         break;
6404       }
6405       if (RA.isMaxValue()) goto trivially_false;
6406       break;
6407     case ICmpInst::ICMP_ULT:
6408       if (RA.isMaxValue()) {
6409         Pred = ICmpInst::ICMP_NE;
6410         Changed = true;
6411         break;
6412       }
6413       if ((RA - 1).isMinValue()) {
6414         Pred = ICmpInst::ICMP_EQ;
6415         RHS = getConstant(RA - 1);
6416         Changed = true;
6417         break;
6418       }
6419       if (RA.isMinValue()) goto trivially_false;
6420       break;
6421     case ICmpInst::ICMP_SGT:
6422       if (RA.isMinSignedValue()) {
6423         Pred = ICmpInst::ICMP_NE;
6424         Changed = true;
6425         break;
6426       }
6427       if ((RA + 1).isMaxSignedValue()) {
6428         Pred = ICmpInst::ICMP_EQ;
6429         RHS = getConstant(RA + 1);
6430         Changed = true;
6431         break;
6432       }
6433       if (RA.isMaxSignedValue()) goto trivially_false;
6434       break;
6435     case ICmpInst::ICMP_SLT:
6436       if (RA.isMaxSignedValue()) {
6437         Pred = ICmpInst::ICMP_NE;
6438         Changed = true;
6439         break;
6440       }
6441       if ((RA - 1).isMinSignedValue()) {
6442         Pred = ICmpInst::ICMP_EQ;
6443         RHS = getConstant(RA - 1);
6444         Changed = true;
6445         break;
6446       }
6447       if (RA.isMinSignedValue()) goto trivially_false;
6448       break;
6449     }
6450   }
6451 
6452   // Check for obvious equality.
6453   if (HasSameValue(LHS, RHS)) {
6454     if (ICmpInst::isTrueWhenEqual(Pred))
6455       goto trivially_true;
6456     if (ICmpInst::isFalseWhenEqual(Pred))
6457       goto trivially_false;
6458   }
6459 
6460   // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
6461   // adding or subtracting 1 from one of the operands.
6462   switch (Pred) {
6463   case ICmpInst::ICMP_SLE:
6464     if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) {
6465       RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
6466                        SCEV::FlagNSW);
6467       Pred = ICmpInst::ICMP_SLT;
6468       Changed = true;
6469     } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) {
6470       LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
6471                        SCEV::FlagNSW);
6472       Pred = ICmpInst::ICMP_SLT;
6473       Changed = true;
6474     }
6475     break;
6476   case ICmpInst::ICMP_SGE:
6477     if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) {
6478       RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
6479                        SCEV::FlagNSW);
6480       Pred = ICmpInst::ICMP_SGT;
6481       Changed = true;
6482     } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) {
6483       LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
6484                        SCEV::FlagNSW);
6485       Pred = ICmpInst::ICMP_SGT;
6486       Changed = true;
6487     }
6488     break;
6489   case ICmpInst::ICMP_ULE:
6490     if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) {
6491       RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
6492                        SCEV::FlagNUW);
6493       Pred = ICmpInst::ICMP_ULT;
6494       Changed = true;
6495     } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
6496       LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
6497                        SCEV::FlagNUW);
6498       Pred = ICmpInst::ICMP_ULT;
6499       Changed = true;
6500     }
6501     break;
6502   case ICmpInst::ICMP_UGE:
6503     if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
6504       RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
6505                        SCEV::FlagNUW);
6506       Pred = ICmpInst::ICMP_UGT;
6507       Changed = true;
6508     } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
6509       LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
6510                        SCEV::FlagNUW);
6511       Pred = ICmpInst::ICMP_UGT;
6512       Changed = true;
6513     }
6514     break;
6515   default:
6516     break;
6517   }
6518 
6519   // TODO: More simplifications are possible here.
6520 
6521   // Recursively simplify until we either hit a recursion limit or nothing
6522   // changes.
6523   if (Changed)
6524     return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1);
6525 
6526   return Changed;
6527 
6528 trivially_true:
6529   // Return 0 == 0.
6530   LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
6531   Pred = ICmpInst::ICMP_EQ;
6532   return true;
6533 
6534 trivially_false:
6535   // Return 0 != 0.
6536   LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
6537   Pred = ICmpInst::ICMP_NE;
6538   return true;
6539 }
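// A few worked instances of the rewrites above, assuming hypothetical i8
// operands:
//
//   %x s<= 5    becomes  %x s< 6          (*-or-equal to strict form)
//   %x u>= 1    becomes  %x != 0          (RA - 1 is the unsigned min)
//   %x u< 1     becomes  %x == 0          (RA - 1 is the unsigned min)
//   %x u> 255   becomes  0 != 0           (RA is the unsigned max)
//
// Each rewrite sets Changed, so the function recurses (up to Depth == 3)
// to give later rules a chance to fire on the simplified form.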
6540 
6541 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
6542   return getSignedRange(S).getSignedMax().isNegative();
6543 }
6544 
6545 bool ScalarEvolution::isKnownPositive(const SCEV *S) {
6546   return getSignedRange(S).getSignedMin().isStrictlyPositive();
6547 }
6548 
6549 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
6550   return !getSignedRange(S).getSignedMin().isNegative();
6551 }
6552 
6553 bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
6554   return !getSignedRange(S).getSignedMax().isStrictlyPositive();
6555 }
6556 
6557 bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
6558   return isKnownNegative(S) || isKnownPositive(S);
6559 }
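// These predicates reduce sign questions to range queries.  For example, if
// getSignedRange(S) were [1, 100), S would be known positive, and therefore
// also known non-negative and non-zero; for a range like [-5, 10) none of
// these facts can be concluded.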
6560 
6561 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
6562                                        const SCEV *LHS, const SCEV *RHS) {
6563   // Canonicalize the inputs first.
6564   (void)SimplifyICmpOperands(Pred, LHS, RHS);
6565 
6566   // If LHS or RHS is an addrec, check to see if the condition is true in
6567   // every iteration of the loop.
6568   // If LHS and RHS are both addrec, both conditions must be true in
6569   // every iteration of the loop.
6570   const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
6571   const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
6572   bool LeftGuarded = false;
6573   bool RightGuarded = false;
6574   if (LAR) {
6575     const Loop *L = LAR->getLoop();
6576     if (isLoopEntryGuardedByCond(L, Pred, LAR->getStart(), RHS) &&
6577         isLoopBackedgeGuardedByCond(L, Pred, LAR->getPostIncExpr(*this), RHS)) {
6578       if (!RAR) return true;
6579       LeftGuarded = true;
6580     }
6581   }
6582   if (RAR) {
6583     const Loop *L = RAR->getLoop();
6584     if (isLoopEntryGuardedByCond(L, Pred, LHS, RAR->getStart()) &&
6585         isLoopBackedgeGuardedByCond(L, Pred, LHS, RAR->getPostIncExpr(*this))) {
6586       if (!LAR) return true;
6587       RightGuarded = true;
6588     }
6589   }
6590   if (LeftGuarded && RightGuarded)
6591     return true;
6592 
6593   // Otherwise see what can be done with known constant ranges.
6594   return isKnownPredicateWithRanges(Pred, LHS, RHS);
6595 }
6596 
6597 bool
6598 ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
6599                                             const SCEV *LHS, const SCEV *RHS) {
6600   if (HasSameValue(LHS, RHS))
6601     return ICmpInst::isTrueWhenEqual(Pred);
6602 
6603   // This code is split out from isKnownPredicate because it is called from
6604   // within isLoopEntryGuardedByCond.
6605   switch (Pred) {
6606   default:
6607     llvm_unreachable("Unexpected ICmpInst::Predicate value!");
6608   case ICmpInst::ICMP_SGT:
6609     std::swap(LHS, RHS); // fall through
6610   case ICmpInst::ICMP_SLT: {
6611     ConstantRange LHSRange = getSignedRange(LHS);
6612     ConstantRange RHSRange = getSignedRange(RHS);
6613     if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin()))
6614       return true;
6615     if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax()))
6616       return false;
6617     break;
6618   }
6619   case ICmpInst::ICMP_SGE:
6620     std::swap(LHS, RHS); // fall through
6621   case ICmpInst::ICMP_SLE: {
6622     ConstantRange LHSRange = getSignedRange(LHS);
6623     ConstantRange RHSRange = getSignedRange(RHS);
6624     if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin()))
6625       return true;
6626     if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax()))
6627       return false;
6628     break;
6629   }
6630   case ICmpInst::ICMP_UGT:
6631     std::swap(LHS, RHS); // fall through
6632   case ICmpInst::ICMP_ULT: {
6633     ConstantRange LHSRange = getUnsignedRange(LHS);
6634     ConstantRange RHSRange = getUnsignedRange(RHS);
6635     if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin()))
6636       return true;
6637     if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax()))
6638       return false;
6639     break;
6640   }
6641   case ICmpInst::ICMP_UGE:
6642     std::swap(LHS, RHS); // fall through
6643   case ICmpInst::ICMP_ULE: {
6644     ConstantRange LHSRange = getUnsignedRange(LHS);
6645     ConstantRange RHSRange = getUnsignedRange(RHS);
6646     if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin()))
6647       return true;
6648     if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax()))
6649       return false;
6650     break;
6651   }
6652   case ICmpInst::ICMP_NE: {
6653     if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet())
6654       return true;
6655     if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet())
6656       return true;
6657 
6658     const SCEV *Diff = getMinusSCEV(LHS, RHS);
6659     if (isKnownNonZero(Diff))
6660       return true;
6661     break;
6662   }
6663   case ICmpInst::ICMP_EQ:
6664     // The check at the top of the function catches the case where
6665     // the values are known to be equal.
6666     break;
6667   }
6668   return false;
6669 }
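// A hypothetical instance of the range reasoning above: if
// getSignedRange(LHS) == [0, 10) and getSignedRange(RHS) == [10, 20), then
// for ICMP_SLT the signed max of LHS (9) is below the signed min of RHS
// (10), so the predicate is known true.  With ranges [0, 10) and [5, 20)
// neither bound test fires, and we conservatively answer false ("not
// known"), not "known false".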
6670 
6671 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
6672 /// protected by a conditional between LHS and RHS.  This is used to
6673 /// eliminate casts.
6674 bool
6675 ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
6676                                              ICmpInst::Predicate Pred,
6677                                              const SCEV *LHS, const SCEV *RHS) {
6678   // Interpret a null as meaning no loop, where there is obviously no guard
6679   // (interprocedural conditions notwithstanding).
6680   if (!L) return true;
6681 
6682   if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
6683 
6684   BasicBlock *Latch = L->getLoopLatch();
6685   if (!Latch)
6686     return false;
6687 
6688   BranchInst *LoopContinuePredicate =
6689     dyn_cast<BranchInst>(Latch->getTerminator());
6690   if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
6691       isImpliedCond(Pred, LHS, RHS,
6692                     LoopContinuePredicate->getCondition(),
6693                     LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
6694     return true;
6695 
6696   // Check conditions due to any @llvm.assume intrinsics.
6697   for (auto &AssumeVH : AC->assumptions()) {
6698     if (!AssumeVH)
6699       continue;
6700     auto *CI = cast<CallInst>(AssumeVH);
6701     if (!DT->dominates(CI, Latch->getTerminator()))
6702       continue;
6703 
6704     if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
6705       return true;
6706   }
6707 
6708   struct ClearWalkingBEDominatingCondsOnExit {
6709     ScalarEvolution &SE;
6710 
6711     explicit ClearWalkingBEDominatingCondsOnExit(ScalarEvolution &SE)
6712         : SE(SE) {}
6713 
6714     ~ClearWalkingBEDominatingCondsOnExit() {
6715       SE.WalkingBEDominatingConds = false;
6716     }
6717   };
6718 
6719   // We don't want more than one activation of the following loop on the stack
6720   // -- that can lead to O(n!) time complexity.
6721   if (WalkingBEDominatingConds)
6722     return false;
6723 
6724   WalkingBEDominatingConds = true;
6725   ClearWalkingBEDominatingCondsOnExit ClearOnExit(*this);
6726 
6727   // If the loop is not reachable from the entry block, we risk running into an
6728   // infinite loop as we walk up into the dom tree.  These loops do not matter
6729   // anyway, so we just return a conservative answer when we see them.
6730   if (!DT->isReachableFromEntry(L->getHeader()))
6731     return false;
6732 
6733   for (DomTreeNode *DTN = (*DT)[Latch], *HeaderDTN = (*DT)[L->getHeader()];
6734        DTN != HeaderDTN;
6735        DTN = DTN->getIDom()) {
6736 
6737     assert(DTN && "should reach the loop header before reaching the root!");
6738 
6739     BasicBlock *BB = DTN->getBlock();
6740     BasicBlock *PBB = BB->getSinglePredecessor();
6741     if (!PBB)
6742       continue;
6743 
6744     BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
6745     if (!ContinuePredicate || !ContinuePredicate->isConditional())
6746       continue;
6747 
6748     Value *Condition = ContinuePredicate->getCondition();
6749 
6750     // If we have an edge `E` within the loop body that dominates the only
6751     // latch, the condition guarding `E` also guards the backedge.  This
6752     // reasoning works only for loops with a single latch.
6753 
6754     BasicBlockEdge DominatingEdge(PBB, BB);
6755     if (DominatingEdge.isSingleEdge()) {
6756       // We're constructively (and conservatively) enumerating edges within the
6757       // loop body that dominate the latch.  The dominator tree better agree
6758       // with us on this:
6759       assert(DT->dominates(DominatingEdge, Latch) && "should be!");
6760 
6761       if (isImpliedCond(Pred, LHS, RHS, Condition,
6762                         BB != ContinuePredicate->getSuccessor(0)))
6763         return true;
6764     }
6765   }
6766 
6767   return false;
6768 }
6769 
6770 /// isLoopEntryGuardedByCond - Test whether entry to the loop is protected
6771 /// by a conditional between LHS and RHS.  This is used to help avoid max
6772 /// expressions in loop trip counts, and to eliminate casts.
6773 bool
6774 ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
6775                                           ICmpInst::Predicate Pred,
6776                                           const SCEV *LHS, const SCEV *RHS) {
6777   // Interpret a null as meaning no loop, where there is obviously no guard
6778   // (interprocedural conditions notwithstanding).
6779   if (!L) return false;
6780 
6781   if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
6782 
6783   // Starting at the loop predecessor, climb up the predecessor chain, as long
6784   // as there are predecessors that can be found that have unique successors
6785   // leading to the original header.
6786   for (std::pair<BasicBlock *, BasicBlock *>
6787          Pair(L->getLoopPredecessor(), L->getHeader());
6788        Pair.first;
6789        Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
6790 
6791     BranchInst *LoopEntryPredicate =
6792       dyn_cast<BranchInst>(Pair.first->getTerminator());
6793     if (!LoopEntryPredicate ||
6794         LoopEntryPredicate->isUnconditional())
6795       continue;
6796 
6797     if (isImpliedCond(Pred, LHS, RHS,
6798                       LoopEntryPredicate->getCondition(),
6799                       LoopEntryPredicate->getSuccessor(0) != Pair.second))
6800       return true;
6801   }
6802 
6803   // Check conditions due to any @llvm.assume intrinsics.
6804   for (auto &AssumeVH : AC->assumptions()) {
6805     if (!AssumeVH)
6806       continue;
6807     auto *CI = cast<CallInst>(AssumeVH);
6808     if (!DT->dominates(CI, L->getHeader()))
6809       continue;
6810 
6811     if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
6812       return true;
6813   }
6814 
6815   return false;
6816 }
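// For example, for source code such as
//
//   if (n > 0)
//     for (i = 0; i < n; ++i) ...
//
// the loop predecessor is only reached when n > 0 holds, so a query like
// isLoopEntryGuardedByCond(L, ICMP_SGT, n, 0) can succeed by climbing to
// the branch on (n > 0), whose true edge leads toward the loop header.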
6817 
6818 /// RAII wrapper to prevent recursive application of isImpliedCond.
6819 /// ScalarEvolution's PendingLoopPredicates set must be empty unless we are
6820 /// currently evaluating isImpliedCond.
6821 struct MarkPendingLoopPredicate {
6822   Value *Cond;
6823   DenseSet<Value*> &LoopPreds;
6824   bool Pending;
6825 
6826   MarkPendingLoopPredicate(Value *C, DenseSet<Value*> &LP)
6827     : Cond(C), LoopPreds(LP) {
6828     Pending = !LoopPreds.insert(Cond).second;
6829   }
6830   ~MarkPendingLoopPredicate() {
6831     if (!Pending)
6832       LoopPreds.erase(Cond);
6833   }
6834 };
6835 
6836 /// isImpliedCond - Test whether the condition described by Pred, LHS,
6837 /// and RHS is true whenever the given Cond value evaluates to true.
6838 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
6839                                     const SCEV *LHS, const SCEV *RHS,
6840                                     Value *FoundCondValue,
6841                                     bool Inverse) {
6842   MarkPendingLoopPredicate Mark(FoundCondValue, PendingLoopPredicates);
6843   if (Mark.Pending)
6844     return false;
6845 
6846   // Recursively handle And and Or conditions.
6847   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
6848     if (BO->getOpcode() == Instruction::And) {
6849       if (!Inverse)
6850         return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
6851                isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
6852     } else if (BO->getOpcode() == Instruction::Or) {
6853       if (Inverse)
6854         return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
6855                isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
6856     }
6857   }
6858 
6859   ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
6860   if (!ICI) return false;
6861 
6862   // We have found a conditional branch that dominates the loop or controls
6863   // the loop latch. Check to see if it is the comparison we are looking for.
6864   ICmpInst::Predicate FoundPred;
6865   if (Inverse)
6866     FoundPred = ICI->getInversePredicate();
6867   else
6868     FoundPred = ICI->getPredicate();
6869 
6870   const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
6871   const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
6872 
6873   // Balance the types.
6874   if (getTypeSizeInBits(LHS->getType()) <
6875       getTypeSizeInBits(FoundLHS->getType())) {
6876     if (CmpInst::isSigned(Pred)) {
6877       LHS = getSignExtendExpr(LHS, FoundLHS->getType());
6878       RHS = getSignExtendExpr(RHS, FoundLHS->getType());
6879     } else {
6880       LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
6881       RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
6882     }
6883   } else if (getTypeSizeInBits(LHS->getType()) >
6884       getTypeSizeInBits(FoundLHS->getType())) {
6885     if (CmpInst::isSigned(FoundPred)) {
6886       FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
6887       FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
6888     } else {
6889       FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
6890       FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
6891     }
6892   }
6893 
6894   // Canonicalize the query to match the way instcombine will have
6895   // canonicalized the comparison.
6896   if (SimplifyICmpOperands(Pred, LHS, RHS))
6897     if (LHS == RHS)
6898       return CmpInst::isTrueWhenEqual(Pred);
6899   if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
6900     if (FoundLHS == FoundRHS)
6901       return CmpInst::isFalseWhenEqual(FoundPred);
6902 
6903   // Check to see if we can make the LHS or RHS match.
6904   if (LHS == FoundRHS || RHS == FoundLHS) {
6905     if (isa<SCEVConstant>(RHS)) {
6906       std::swap(FoundLHS, FoundRHS);
6907       FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
6908     } else {
6909       std::swap(LHS, RHS);
6910       Pred = ICmpInst::getSwappedPredicate(Pred);
6911     }
6912   }
6913 
6914   // Check whether the found predicate is the same as the desired predicate.
6915   if (FoundPred == Pred)
6916     return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
6917 
6918   // Check whether swapping the found predicate makes it the same as the
6919   // desired predicate.
6920   if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
6921     if (isa<SCEVConstant>(RHS))
6922       return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
6923     else
6924       return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
6925                                    RHS, LHS, FoundLHS, FoundRHS);
6926   }
6927 
6928   // Check if we can make progress by sharpening ranges.
6929   if (FoundPred == ICmpInst::ICMP_NE &&
6930       (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
6931 
6932     const SCEVConstant *C = nullptr;
6933     const SCEV *V = nullptr;
6934 
6935     if (isa<SCEVConstant>(FoundLHS)) {
6936       C = cast<SCEVConstant>(FoundLHS);
6937       V = FoundRHS;
6938     } else {
6939       C = cast<SCEVConstant>(FoundRHS);
6940       V = FoundLHS;
6941     }
6942 
6943     // The guarding predicate tells us that C != V. If the known range
6944     // of V is [C, t), we can sharpen the range to [C + 1, t).  The
6945     // range we consider has to correspond to the same signedness as the
6946     // predicate we're interested in folding.
6947 
6948     APInt Min = ICmpInst::isSigned(Pred) ?
6949         getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin();
6950 
6951     if (Min == C->getValue()->getValue()) {
6952       // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
6953       // This is true even if (Min + 1) wraps around -- in case of
6954       // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
6955 
6956       APInt SharperMin = Min + 1;
6957 
6958       switch (Pred) {
6959         case ICmpInst::ICMP_SGE:
6960         case ICmpInst::ICMP_UGE:
6961           // We know V `Pred` SharperMin.  If this implies LHS `Pred`
6962           // RHS, we're done.
6963           if (isImpliedCondOperands(Pred, LHS, RHS, V,
6964                                     getConstant(SharperMin)))
6965             return true;
6966           // Fall through -- the weaker V `Pred` Min fact below may also suffice.
6967         case ICmpInst::ICMP_SGT:
6968         case ICmpInst::ICMP_UGT:
6969           // We know from the range information that (V `Pred` Min ||
6970           // V == Min).  We know from the guarding condition that !(V
6971           // == Min).  This gives us
6972           //
6973           //       V `Pred` Min || V == Min && !(V == Min)
6974           //   =>  V `Pred` Min
6975           //
6976           // If V `Pred` Min implies LHS `Pred` RHS, we're done.
6977 
6978           if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
6979             return true;
6980 
6981         default:
6982           // No change
6983           break;
6984       }
6985     }
6986   }
6987 
6988   // Check whether the actual condition is beyond sufficient.
6989   if (FoundPred == ICmpInst::ICMP_EQ)
6990     if (ICmpInst::isTrueWhenEqual(Pred))
6991       if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
6992         return true;
6993   if (Pred == ICmpInst::ICMP_NE)
6994     if (!ICmpInst::isTrueWhenEqual(FoundPred))
6995       if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
6996         return true;
6997 
6998   // Otherwise assume the worst.
6999   return false;
7000 }
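// A hypothetical instance of the range sharpening above: the guard is
// (%v != 0), the query is (%v u>= 1), and the known unsigned range of %v
// is [0, 100).  Then Min == 0 == C, so the guard excludes exactly the one
// value for which the range fact is too weak; with SharperMin == 1 the
// call to isImpliedCondOperands can prove the query directly.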
7001 
7002 /// isImpliedCondOperands - Test whether the condition described by Pred,
7003 /// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
7004 /// and FoundRHS is true.
7005 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
7006                                             const SCEV *LHS, const SCEV *RHS,
7007                                             const SCEV *FoundLHS,
7008                                             const SCEV *FoundRHS) {
7009   if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
7010     return true;
7011 
7012   return isImpliedCondOperandsHelper(Pred, LHS, RHS,
7013                                      FoundLHS, FoundRHS) ||
7014          // ~x < ~y --> x > y
7015          isImpliedCondOperandsHelper(Pred, LHS, RHS,
7016                                      getNotSCEV(FoundRHS),
7017                                      getNotSCEV(FoundLHS));
7018 }
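// The second isImpliedCondOperandsHelper call relies on the identity
// ~x < ~y <=> x > y: since ~a == -1 - a, bitwise negation reverses the
// order, so the antecedent may be retried with its operands negated and
// swapped.  For i8, for instance, 3 u< 10 and ~10 u< ~3 (that is,
// 245 u< 252) are the same fact.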
7019 
7020 
7021 /// If Expr computes ~A, return A else return nullptr
7022 static const SCEV *MatchNotExpr(const SCEV *Expr) {
7023   const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
7024   if (!Add || Add->getNumOperands() != 2) return nullptr;
7025 
7026   const SCEVConstant *AddLHS = dyn_cast<SCEVConstant>(Add->getOperand(0));
7027   if (!(AddLHS && AddLHS->getValue()->getValue().isAllOnesValue()))
7028     return nullptr;
7029 
7030   const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
7031   if (!AddRHS || AddRHS->getNumOperands() != 2) return nullptr;
7032 
7033   const SCEVConstant *MulLHS = dyn_cast<SCEVConstant>(AddRHS->getOperand(0));
7034   if (!(MulLHS && MulLHS->getValue()->getValue().isAllOnesValue()))
7035     return nullptr;
7036 
7037   return AddRHS->getOperand(1);
7038 }
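// SCEV has no dedicated "not" node: ~A is represented as (-1 + (-1 * A)),
// since ~A == -1 - A.  MatchNotExpr therefore looks for an add whose first
// operand is the all-ones constant and whose second is an all-ones multiple
// of A.  In i8, for example, ~5 == 0xFA == (-1) + (-1 * 5).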
7039 
7040 
7041 /// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values?
7042 template<typename MaxExprType>
7043 static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
7044                               const SCEV *Candidate) {
7045   const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr);
7046   if (!MaxExpr) return false;
7047 
7048   auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate);
7049   return It != MaxExpr->op_end();
7050 }
7051 
7052 
7053 /// Is MaybeMinExpr an SMin or UMin of Candidate and some other values?
7054 template<typename MaxExprType>
7055 static bool IsMinConsistingOf(ScalarEvolution &SE,
7056                               const SCEV *MaybeMinExpr,
7057                               const SCEV *Candidate) {
7058   const SCEV *MaybeMaxExpr = MatchNotExpr(MaybeMinExpr);
7059   if (!MaybeMaxExpr)
7060     return false;
7061 
7062   return IsMaxConsistingOf<MaxExprType>(MaybeMaxExpr, SE.getNotSCEV(Candidate));
7063 }
7064 
7065 
7066 /// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
7067 /// expression?
7068 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
7069                                         ICmpInst::Predicate Pred,
7070                                         const SCEV *LHS, const SCEV *RHS) {
7071   switch (Pred) {
7072   default:
7073     return false;
7074 
7075   case ICmpInst::ICMP_SGE:
7076     std::swap(LHS, RHS);
7077     // fall through
7078   case ICmpInst::ICMP_SLE:
7079     return
7080       // min(A, ...) <= A
7081       IsMinConsistingOf<SCEVSMaxExpr>(SE, LHS, RHS) ||
7082       // A <= max(A, ...)
7083       IsMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
7084 
7085   case ICmpInst::ICMP_UGE:
7086     std::swap(LHS, RHS);
7087     // fall through
7088   case ICmpInst::ICMP_ULE:
7089     return
7090       // min(A, ...) <= A
7091       IsMinConsistingOf<SCEVUMaxExpr>(SE, LHS, RHS) ||
7092       // A <= max(A, ...)
7093       IsMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
7094   }
7095 
7096   llvm_unreachable("covered switch fell through?!");
7097 }
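// SCEV likewise has no SMin/UMin node: smin(A, B) is built as
// ~smax(~A, ~B).  That is why IsMinConsistingOf first strips a "not" with
// MatchNotExpr and then searches for the negated candidate among the max's
// operands; e.g. smin(%a, %b) s<= %a is recognized by finding ~%a inside
// smax(~%a, ~%b).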
7098 
7099 /// isImpliedCondOperandsHelper - Test whether the condition described by
7100 /// Pred, LHS, and RHS is true whenever the condition described by Pred,
7101 /// FoundLHS, and FoundRHS is true.
7102 bool
7103 ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
7104                                              const SCEV *LHS, const SCEV *RHS,
7105                                              const SCEV *FoundLHS,
7106                                              const SCEV *FoundRHS) {
7107   auto IsKnownPredicateFull =
7108       [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
7109     return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
7110         IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS);
7111   };
7112 
7113   switch (Pred) {
7114   default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
7115   case ICmpInst::ICMP_EQ:
7116   case ICmpInst::ICMP_NE:
7117     if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
7118       return true;
7119     break;
7120   case ICmpInst::ICMP_SLT:
7121   case ICmpInst::ICMP_SLE:
7122     if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
7123         IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS))
7124       return true;
7125     break;
7126   case ICmpInst::ICMP_SGT:
7127   case ICmpInst::ICMP_SGE:
7128     if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
7129         IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS))
7130       return true;
7131     break;
7132   case ICmpInst::ICMP_ULT:
7133   case ICmpInst::ICMP_ULE:
7134     if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
7135         IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS))
7136       return true;
7137     break;
7138   case ICmpInst::ICMP_UGT:
7139   case ICmpInst::ICMP_UGE:
7140     if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
7141         IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS))
7142       return true;
7143     break;
7144   }
7145 
7146   return false;
7147 }
7148 
7149 /// isImpliedCondOperandsViaRanges - Helper function for isImpliedCondOperands.
7150 /// Tries to prove implications like "X `sgt` 0 => X - 1 `sgt` -1".
7151 bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
7152                                                      const SCEV *LHS,
7153                                                      const SCEV *RHS,
7154                                                      const SCEV *FoundLHS,
7155                                                      const SCEV *FoundRHS) {
7156   if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
7157     // The restriction on `FoundRHS` can be lifted easily -- it exists only to
7158     // reduce the compile time impact of this optimization.
7159     return false;
7160 
7161   const SCEVAddExpr *AddLHS = dyn_cast<SCEVAddExpr>(LHS);
7162   if (!AddLHS || AddLHS->getOperand(1) != FoundLHS ||
7163       !isa<SCEVConstant>(AddLHS->getOperand(0)))
7164     return false;
7165 
7166   APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getValue()->getValue();
7167 
7168   // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
7169   // antecedent "`FoundLHS` `Pred` `FoundRHS`".
7170   ConstantRange FoundLHSRange =
7171       ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS);
7172 
7173   // Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range
7174   // for `LHS`:
7175   APInt Addend =
7176       cast<SCEVConstant>(AddLHS->getOperand(0))->getValue()->getValue();
7177   ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend));
7178 
7179   // We can also compute the range of values for `LHS` that satisfy the
7180   // consequent, "`LHS` `Pred` `RHS`":
7181   APInt ConstRHS = cast<SCEVConstant>(RHS)->getValue()->getValue();
7182   ConstantRange SatisfyingLHSRange =
7183       ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS);
7184 
7185   // The antecedent implies the consequent if every value of `LHS` that
7186   // satisfies the antecedent also satisfies the consequent.
7187   return SatisfyingLHSRange.contains(LHSRange);
7188 }
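// Worked i8 instance of the machinery above: from the antecedent X s> 0,
// FoundLHSRange covers exactly {1, ..., 127}.  For LHS == (-1) + X, adding
// the constant shifts this to LHSRange == {0, ..., 126}.  The consequent
// LHS s> -1 is satisfied on {0, ..., 127}, which contains LHSRange, so the
// implication is proved.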
7189 
7190 // Verify whether a linear IV with a positive stride can overflow in a
7191 // less-than comparison, given the invariant term of the comparison, the
7192 // stride, and the NSW/NUW flags on the recurrence.
7193 bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
7194                                          bool IsSigned, bool NoWrap) {
7195   if (NoWrap) return false;
7196 
7197   unsigned BitWidth = getTypeSizeInBits(RHS->getType());
7198   const SCEV *One = getConstant(Stride->getType(), 1);
7199 
7200   if (IsSigned) {
7201     APInt MaxRHS = getSignedRange(RHS).getSignedMax();
7202     APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
7203     APInt MaxStrideMinusOne = getSignedRange(getMinusSCEV(Stride, One))
7204                                 .getSignedMax();
7205 
7206     // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
7207     return (MaxValue - MaxStrideMinusOne).slt(MaxRHS);
7208   }
7209 
7210   APInt MaxRHS = getUnsignedRange(RHS).getUnsignedMax();
7211   APInt MaxValue = APInt::getMaxValue(BitWidth);
7212   APInt MaxStrideMinusOne = getUnsignedRange(getMinusSCEV(Stride, One))
7213                               .getUnsignedMax();
7214 
7215   // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
7216   return (MaxValue - MaxStrideMinusOne).ult(MaxRHS);
7217 }
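// Numeric illustration in unsigned i8: if MaxRHS == 250 and the stride is
// the constant 10, then MaxStrideMinusOne == 9 and 255 - 9 == 246 u< 250,
// so we report possible overflow -- the IV could legally reach 249 and
// then step past 255, wrapping before the comparison fails.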
7218 
7219 // Verify whether a linear IV with a negative stride can overflow in a
7220 // greater-than comparison, given the invariant term of the comparison,
7221 // the stride, and the NSW/NUW flags on the recurrence.
7222 bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
7223                                          bool IsSigned, bool NoWrap) {
7224   if (NoWrap) return false;
7225 
7226   unsigned BitWidth = getTypeSizeInBits(RHS->getType());
7227   const SCEV *One = getConstant(Stride->getType(), 1);
7228 
7229   if (IsSigned) {
7230     APInt MinRHS = getSignedRange(RHS).getSignedMin();
7231     APInt MinValue = APInt::getSignedMinValue(BitWidth);
7232     APInt MaxStrideMinusOne = getSignedRange(getMinusSCEV(Stride, One))
7233                                .getSignedMax();
7234 
7235     // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
7236     return (MinValue + MaxStrideMinusOne).sgt(MinRHS);
7237   }
7238 
7239   APInt MinRHS = getUnsignedRange(RHS).getUnsignedMin();
7240   APInt MinValue = APInt::getMinValue(BitWidth);
7241   APInt MaxStrideMinusOne = getUnsignedRange(getMinusSCEV(Stride, One))
7242                             .getUnsignedMax();
7243 
7244   // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
7245   return (MinValue + MaxStrideMinusOne).ugt(MinRHS);
7246 }
7247 
7248 // Compute the backedge-taken count given the interval difference (Delta),
7249 // the stride, and whether the comparison includes equality.
7250 const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
7251                                             bool Equality) {
7252   const SCEV *One = getConstant(Step->getType(), 1);
7253   Delta = Equality ? getAddExpr(Delta, Step)
7254                    : getAddExpr(Delta, getMinusSCEV(Step, One));
7255   return getUDivExpr(Delta, Step);
7256 }
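// For example, with Delta == 9 and Step == 3: a strict comparison
// (Equality == false) yields (9 + (3 - 1)) /u 3 == 3, the ceiling of 9/3,
// while Equality == true yields (9 + 3) /u 3 == 4, counting the extra
// iteration taken when the IV lands exactly on the bound.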
7257 
7258 /// HowManyLessThans - Return the number of times a backedge containing the
7259 /// specified less-than comparison will execute.  If not computable, return
7260 /// CouldNotCompute.
7261 ///
7262 /// @param ControlsExit is true when the LHS < RHS condition directly controls
7263 /// the branch (loops exits only if condition is true). In this case, we can use
7264 /// NoWrapFlags to skip overflow checks.
7265 ScalarEvolution::ExitLimit
7266 ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
7267                                   const Loop *L, bool IsSigned,
7268                                   bool ControlsExit) {
7269   // We handle only IV < Invariant
7270   if (!isLoopInvariant(RHS, L))
7271     return getCouldNotCompute();
7272 
7273   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
7274 
7275   // Avoid weird loops
7276   if (!IV || IV->getLoop() != L || !IV->isAffine())
7277     return getCouldNotCompute();
7278 
7279   bool NoWrap = ControlsExit &&
7280                 IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
7281 
7282   const SCEV *Stride = IV->getStepRecurrence(*this);
7283 
7284   // Avoid negative or zero stride values
7285   if (!isKnownPositive(Stride))
7286     return getCouldNotCompute();
7287 
7288   // Avoid proven overflow cases: this will ensure that the backedge-taken
7289   // count will not wrap. The relaxed no-overflow conditions exploit
7290   // NoWrapFlags, allowing us to optimize in the presence of undefined
7291   // behavior, as in C.
7292   if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
7293     return getCouldNotCompute();
7294 
7295   ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT
7296                                       : ICmpInst::ICMP_ULT;
7297   const SCEV *Start = IV->getStart();
7298   const SCEV *End = RHS;
7299   if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) {
7300     const SCEV *Diff = getMinusSCEV(RHS, Start);
7301     // If we have NoWrap set, then we can assume that the increment won't
7302     // overflow, in which case if RHS - Start is a constant, we don't need to
7303     // do a max operation since we can just figure it out statically.
7304     if (NoWrap && isa<SCEVConstant>(Diff)) {
7305       APInt D = cast<SCEVConstant>(Diff)->getValue()->getValue();
7306       if (D.isNegative())
7307         End = Start;
7308     } else
7309       End = IsSigned ? getSMaxExpr(RHS, Start)
7310                      : getUMaxExpr(RHS, Start);
7311   }
7312 
7313   const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
7314 
7315   APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin()
7316                             : getUnsignedRange(Start).getUnsignedMin();
7317 
7318   APInt MinStride = IsSigned ? getSignedRange(Stride).getSignedMin()
7319                              : getUnsignedRange(Stride).getUnsignedMin();
7320 
7321   unsigned BitWidth = getTypeSizeInBits(LHS->getType());
7322   APInt Limit = IsSigned ? APInt::getSignedMaxValue(BitWidth) - (MinStride - 1)
7323                          : APInt::getMaxValue(BitWidth) - (MinStride - 1);
7324 
7325   // Although End can be a MAX expression, we estimate MaxEnd considering only
7326   // the case End = RHS. This is safe because in the other case (End - Start)
7327   // is zero, leading to a zero maximum backedge taken count.
7328   APInt MaxEnd =
7329     IsSigned ? APIntOps::smin(getSignedRange(RHS).getSignedMax(), Limit)
7330              : APIntOps::umin(getUnsignedRange(RHS).getUnsignedMax(), Limit);
7331 
7332   const SCEV *MaxBECount;
7333   if (isa<SCEVConstant>(BECount))
7334     MaxBECount = BECount;
7335   else
7336     MaxBECount = computeBECount(getConstant(MaxEnd - MinStart),
7337                                 getConstant(MinStride), false);
7338 
7339   if (isa<SCEVCouldNotCompute>(MaxBECount))
7340     MaxBECount = BECount;
7341 
7342   return ExitLimit(BECount, MaxBECount);
7343 }
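// For the canonical loop
//
//   for (i = 0; i < n; ++i) ...     // i is {0,+,1} in SCEV terms
//
// Start == 0 and Stride == 1, and End is n (possibly wrapped in a
// umax/smax with Start to protect against a loop that never executes), so
// the backedge-taken count comes out as (End - 0 + 1 - 1) /u 1 == End.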
7344 
7345 ScalarEvolution::ExitLimit
7346 ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
7347                                      const Loop *L, bool IsSigned,
7348                                      bool ControlsExit) {
7349   // We handle only IV > Invariant
7350   if (!isLoopInvariant(RHS, L))
7351     return getCouldNotCompute();
7352 
7353   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
7354 
7355   // Avoid weird loops
7356   if (!IV || IV->getLoop() != L || !IV->isAffine())
7357     return getCouldNotCompute();
7358 
7359   bool NoWrap = ControlsExit &&
7360                 IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
7361 
7362   const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
7363 
7364   // Avoid negative or zero stride values
7365   if (!isKnownPositive(Stride))
7366     return getCouldNotCompute();
7367 
7368   // Avoid proven overflow cases: this will ensure that the backedge-taken
7369   // count will not wrap. The relaxed no-overflow conditions exploit
7370   // NoWrapFlags, allowing us to optimize in the presence of undefined
7371   // behavior, as in C.
7372   if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap))
7373     return getCouldNotCompute();
7374 
7375   ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT
7376                                       : ICmpInst::ICMP_UGT;
7377 
7378   const SCEV *Start = IV->getStart();
7379   const SCEV *End = RHS;
7380   if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
7381     const SCEV *Diff = getMinusSCEV(RHS, Start);
7382     // If we have NoWrap set, then we can assume that the increment won't
7383     // overflow, in which case if RHS - Start is a constant, we don't need to
7384     // do a max operation since we can just figure it out statically.
7385     if (NoWrap && isa<SCEVConstant>(Diff)) {
7386       APInt D = cast<SCEVConstant>(Diff)->getValue()->getValue();
7387       if (!D.isNegative())
7388         End = Start;
7389     } else
7390       End = IsSigned ? getSMinExpr(RHS, Start)
7391                      : getUMinExpr(RHS, Start);
7392   }
7393 
7394   const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false);
7395 
7396   APInt MaxStart = IsSigned ? getSignedRange(Start).getSignedMax()
7397                             : getUnsignedRange(Start).getUnsignedMax();
7398 
7399   APInt MinStride = IsSigned ? getSignedRange(Stride).getSignedMin()
7400                              : getUnsignedRange(Stride).getUnsignedMin();
7401 
7402   unsigned BitWidth = getTypeSizeInBits(LHS->getType());
7403   APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
7404                          : APInt::getMinValue(BitWidth) + (MinStride - 1);
7405 
7406   // Although End can be a MIN expression, we estimate MinEnd considering only
7407   // the case End = RHS. This is safe because in the other case (Start - End)
7408   // is zero, leading to a zero maximum backedge taken count.
7409   APInt MinEnd =
7410     IsSigned ? APIntOps::smax(getSignedRange(RHS).getSignedMin(), Limit)
7411              : APIntOps::umax(getUnsignedRange(RHS).getUnsignedMin(), Limit);
7412 
7413 
7414   const SCEV *MaxBECount = getCouldNotCompute();
7415   if (isa<SCEVConstant>(BECount))
7416     MaxBECount = BECount;
7417   else
7418     MaxBECount = computeBECount(getConstant(MaxStart - MinEnd),
7419                                 getConstant(MinStride), false);
7420 
7421   if (isa<SCEVCouldNotCompute>(MaxBECount))
7422     MaxBECount = BECount;
7423 
7424   return ExitLimit(BECount, MaxBECount);
7425 }
7426 
7427 /// getNumIterationsInRange - Return the number of iterations of this loop that
7428 /// produce values in the specified constant range.  Another way of looking at
7429 /// this is that it returns the first iteration number where the value is not in
7430 /// the range, thus computing the exit count. If the iteration count can't
7431 /// be computed, an instance of SCEVCouldNotCompute is returned.
7432 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
7433                                                     ScalarEvolution &SE) const {
7434   if (Range.isFullSet())  // Infinite loop.
7435     return SE.getCouldNotCompute();
7436 
7437   // If the start is a non-zero constant, shift the range to simplify things.
7438   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
7439     if (!SC->getValue()->isZero()) {
7440       SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
7441       Operands[0] = SE.getConstant(SC->getType(), 0);
7442       const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
7443                                              getNoWrapFlags(FlagNW));
7444       if (const SCEVAddRecExpr *ShiftedAddRec =
7445             dyn_cast<SCEVAddRecExpr>(Shifted))
7446         return ShiftedAddRec->getNumIterationsInRange(
7447                            Range.subtract(SC->getValue()->getValue()), SE);
7448       // This is strange and shouldn't happen.
7449       return SE.getCouldNotCompute();
7450     }
7451 
7452   // The only time we can solve this is when we have all constant indices.
7453   // Otherwise, we cannot determine the overflow conditions.
7454   for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
7455     if (!isa<SCEVConstant>(getOperand(i)))
7456       return SE.getCouldNotCompute();
7457 
7458 
7459   // Okay at this point we know that all elements of the chrec are constants and
7460   // that the start element is zero.
7461 
7462   // First check to see if the range contains zero.  If not, the first
7463   // iteration exits.
7464   unsigned BitWidth = SE.getTypeSizeInBits(getType());
7465   if (!Range.contains(APInt(BitWidth, 0)))
7466     return SE.getConstant(getType(), 0);
7467 
7468   if (isAffine()) {
7469     // If this is an affine expression then we have this situation:
7470     //   Solve {0,+,A} in Range  ===  Ax in Range
7471 
7472     // We know that zero is in the range.  If A is positive then we know that
7473     // the upper value of the range must be the first possible exit value.
7474     // If A is negative then the lower of the range is the last possible loop
7475     // value.  Also note that we already checked for a full range.
7476     APInt One(BitWidth,1);
7477     APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
7478     APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
7479 
7480     // The exit value should be (End+A)/A.
7481     APInt ExitVal = (End + A).udiv(A);
7482     ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
7483 
7484     // Evaluate at the exit value.  If we really did fall out of the valid
7485     // range, then we computed our trip count, otherwise wrap around or other
7486     // things must have happened.
7487     ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
7488     if (Range.contains(Val->getValue()))
7489       return SE.getCouldNotCompute();  // Something strange happened
7490 
7491     // Ensure that the previous value is in the range.  This is a sanity check.
7492     assert(Range.contains(
7493            EvaluateConstantChrecAtConstant(this,
7494            ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) &&
7495            "Linear scev computation is off in a bad way!");
7496     return SE.getConstant(ExitValue);
7497   } else if (isQuadratic()) {
7498     // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
7499     // quadratic equation to solve it.  To do this, we must frame our problem in
7500     // terms of figuring out when zero is crossed, instead of when
7501     // Range.getUpper() is crossed.
7502     SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());
7503     NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
7504     const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(),
7505                                              // getNoWrapFlags(FlagNW)
7506                                              FlagAnyWrap);
7507 
7508     // Next, solve the constructed addrec
7509     std::pair<const SCEV *,const SCEV *> Roots =
7510       SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
7511     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
7512     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
7513     if (R1) {
7514       // Pick the smallest positive root value.
7515       if (ConstantInt *CB =
7516           dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
7517                          R1->getValue(), R2->getValue()))) {
7518         if (!CB->getZExtValue())
7519           std::swap(R1, R2);   // R1 is the minimum root now.
7520 
7521         // Make sure the root is not off by one.  The returned iteration should
7522         // not be in the range, but the previous one should be.  When solving
7523         // for "X*X < 5", for example, we should not return a root of 2.
7524         ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
7525                                                              R1->getValue(),
7526                                                              SE);
7527         if (Range.contains(R1Val->getValue())) {
7528           // The next iteration must be out of the range...
7529           ConstantInt *NextVal =
7530                 ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1);
7531 
7532           R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
7533           if (!Range.contains(R1Val->getValue()))
7534             return SE.getConstant(NextVal);
7535           return SE.getCouldNotCompute();  // Something strange happened
7536         }
7537 
7538         // If R1 was not in the range, then it is a good return value.  Make
7539         // sure that R1-1 WAS in the range though, just in case.
7540         ConstantInt *NextVal =
7541                ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1);
7542         R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
7543         if (Range.contains(R1Val->getValue()))
7544           return R1;
7545         return SE.getCouldNotCompute();  // Something strange happened
7546       }
7547     }
7548   }
7549 
7550   return SE.getCouldNotCompute();
7551 }
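// Affine illustration: for {0,+,4} and Range == [0, 100), A == 4 is
// positive, so End == 99 and ExitVal == (99 + 4) /u 4 == 25.  Iteration 25
// produces 100, the first value outside the range, while iteration 24
// produces 96, which the assertion above checks is still inside it.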
7552 
7553 namespace {
7554 struct FindUndefs {
7555   bool Found;
7556   FindUndefs() : Found(false) {}
7557 
7558   bool follow(const SCEV *S) {
7559     if (const SCEVUnknown *C = dyn_cast<SCEVUnknown>(S)) {
7560       if (isa<UndefValue>(C->getValue()))
7561         Found = true;
7562     } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
7563       if (isa<UndefValue>(C->getValue()))
7564         Found = true;
7565     }
7566 
7567     // Keep looking if we haven't found it yet.
7568     return !Found;
7569   }
7570   bool isDone() const {
7571     // Stop recursion if we have found an undef.
7572     return Found;
7573   }
7574 };
7575 }
7576 
7577 // Return true when S contains at least one undef value.
7578 static inline bool
7579 containsUndefs(const SCEV *S) {
7580   FindUndefs F;
7581   SCEVTraversal<FindUndefs> ST(F);
7582   ST.visitAll(S);
7583 
7584   return F.Found;
7585 }
7586 
7587 namespace {
7588 // Collect all steps of SCEV expressions.
7589 struct SCEVCollectStrides {
7590   ScalarEvolution &SE;
7591   SmallVectorImpl<const SCEV *> &Strides;
7592 
7593   SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &S)
7594       : SE(SE), Strides(S) {}
7595 
7596   bool follow(const SCEV *S) {
7597     if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
7598       Strides.push_back(AR->getStepRecurrence(SE));
7599     return true;
7600   }
7601   bool isDone() const { return false; }
7602 };
7603 
7604 // Collect all SCEVUnknown and SCEVMulExpr expressions.
7605 struct SCEVCollectTerms {
7606   SmallVectorImpl<const SCEV *> &Terms;
7607 
7608   SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T)
7609       : Terms(T) {}
7610 
7611   bool follow(const SCEV *S) {
7612     if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S)) {
7613       if (!containsUndefs(S))
7614         Terms.push_back(S);
7615 
7616       // Stop recursion: once we collected a term, do not walk its operands.
7617       return false;
7618     }
7619 
7620     // Keep looking.
7621     return true;
7622   }
7623   bool isDone() const { return false; }
7624 };
7625 }
7626 
7627 /// Find parametric terms in this SCEVAddRecExpr.
7628 void SCEVAddRecExpr::collectParametricTerms(
7629     ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &Terms) const {
7630   SmallVector<const SCEV *, 4> Strides;
7631   SCEVCollectStrides StrideCollector(SE, Strides);
7632   visitAll(this, StrideCollector);
7633 
7634   DEBUG({
7635       dbgs() << "Strides:\n";
7636       for (const SCEV *S : Strides)
7637         dbgs() << *S << "\n";
7638     });
7639 
7640   for (const SCEV *S : Strides) {
7641     SCEVCollectTerms TermCollector(Terms);
7642     visitAll(S, TermCollector);
7643   }
7644 
7645   DEBUG({
7646       dbgs() << "Terms:\n";
7647       for (const SCEV *T : Terms)
7648         dbgs() << *T << "\n";
7649     });
7650 }
7651 
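// Helper for findArrayDimensions: peel off one array dimension per level of
// recursion. Terms arrive sorted with larger products first, so the smallest
// term (Terms.back()) serves as the divisor at each level. An illustrative
// trace on the running example: Terms = {%m * %o, %o}; Step = %o divides both
// terms exactly, leaving {%m, 1}; the constant is erased and the recursion on
// {%m} pushes %m into Sizes, after which the unwinding pushes %o, giving
// Sizes = {%m, %o}.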
7652 static bool findArrayDimensionsRec(ScalarEvolution &SE,
7653                                    SmallVectorImpl<const SCEV *> &Terms,
7654                                    SmallVectorImpl<const SCEV *> &Sizes) {
7655   int Last = Terms.size() - 1;
7656   const SCEV *Step = Terms[Last];
7657 
7658   // End of recursion.
7659   if (Last == 0) {
7660     if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Step)) {
7661       SmallVector<const SCEV *, 2> Qs;
7662       for (const SCEV *Op : M->operands())
7663         if (!isa<SCEVConstant>(Op))
7664           Qs.push_back(Op);
7665 
7666       Step = SE.getMulExpr(Qs);
7667     }
7668 
7669     Sizes.push_back(Step);
7670     return true;
7671   }
7672 
7673   for (const SCEV *&Term : Terms) {
7674     // Normalize the terms before the next call to findArrayDimensionsRec.
7675     const SCEV *Q, *R;
7676     SCEVDivision::divide(SE, Term, Step, &Q, &R);
7677 
7678     // Bail out when GCD does not evenly divide one of the terms.
7679     if (!R->isZero())
7680       return false;
7681 
7682     Term = Q;
7683   }
7684 
7685   // Remove all SCEVConstants.
7686   Terms.erase(std::remove_if(Terms.begin(), Terms.end(), [](const SCEV *E) {
7687                 return isa<SCEVConstant>(E);
7688               }),
7689               Terms.end());
7690 
7691   if (!Terms.empty())
7692     if (!findArrayDimensionsRec(SE, Terms, Sizes))
7693       return false;
7694 
7695   Sizes.push_back(Step);
7696   return true;
7697 }
7698 
7699 namespace {
7700 struct FindParameter {
7701   bool FoundParameter;
7702   FindParameter() : FoundParameter(false) {}
7703 
7704   bool follow(const SCEV *S) {
7705     if (isa<SCEVUnknown>(S)) {
7706       FoundParameter = true;
7707       // Stop recursion: we found a parameter.
7708       return false;
7709     }
7710     // Keep looking.
7711     return true;
7712   }
7713   bool isDone() const {
7714     // Stop recursion if we have found a parameter.
7715     return FoundParameter;
7716   }
7717 };
7718 }
7719 
7720 // Returns true when S contains at least one SCEVUnknown parameter.
7721 static inline bool
7722 containsParameters(const SCEV *S) {
7723   FindParameter F;
7724   SCEVTraversal<FindParameter> ST(F);
7725   ST.visitAll(S);
7726 
7727   return F.FoundParameter;
7728 }
7729 
7730 // Returns true when one of the SCEVs in Terms contains a SCEVUnknown parameter.
7731 static inline bool
7732 containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
7733   for (const SCEV *T : Terms)
7734     if (containsParameters(T))
7735       return true;
7736   return false;
7737 }
7738 
7739 // Return the number of product terms in S.
7740 static inline int numberOfTerms(const SCEV *S) {
7741   if (const SCEVMulExpr *Expr = dyn_cast<SCEVMulExpr>(S))
7742     return Expr->getNumOperands();
7743   return 1;
7744 }
7745 
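// Strip the constant factors of T: a lone constant yields null because it
// carries no parametric information, and constant operands of a product are
// dropped, e.g. (8 * %m * %o) becomes (%m * %o).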
7746 static const SCEV *removeConstantFactors(ScalarEvolution &SE, const SCEV *T) {
7747   if (isa<SCEVConstant>(T))
7748     return nullptr;
7749 
7750   if (isa<SCEVUnknown>(T))
7751     return T;
7752 
7753   if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(T)) {
7754     SmallVector<const SCEV *, 2> Factors;
7755     for (const SCEV *Op : M->operands())
7756       if (!isa<SCEVConstant>(Op))
7757         Factors.push_back(Op);
7758 
7759     return SE.getMulExpr(Factors);
7760   }
7761 
7762   return T;
7763 }
7764 
7765 /// Return the size of an element read or written by Inst.
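/// For example (illustrative), for a "store double %v, double* %p" this
/// returns sizeof(double), i.e. the constant 8 expressed in the effective
/// (pointer-width) SCEV type.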
7766 const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
7767   Type *Ty;
7768   if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
7769     Ty = Store->getValueOperand()->getType();
7770   else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
7771     Ty = Load->getType();
7772   else
7773     return nullptr;
7774 
7775   Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
7776   return getSizeOfExpr(ETy, Ty);
7777 }
7778 
7779 /// Second step of delinearization: compute the array dimensions Sizes from the
7780 /// set of Terms extracted from the memory access function of a SCEVAddRecExpr.
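///
/// Continuing the running example: Terms = {(8 * %m * %o), (8 * %o)} with
/// ElementSize = 8 are divided by the element size and stripped of constant
/// factors, giving {%m * %o, %o}; the GCD recursion of findArrayDimensionsRec
/// then produces Sizes = {%m, %o, 8}, where the trailing entry is the element
/// size itself.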
7781 void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
7782                                           SmallVectorImpl<const SCEV *> &Sizes,
7783                                           const SCEV *ElementSize) const {
7784 
7785   if (Terms.empty() || !ElementSize)
7786     return;
7787 
7788   // Early return when Terms do not contain parameters: we do not delinearize
7789   // non-parametric SCEVs.
7790   if (!containsParameters(Terms))
7791     return;
7792 
7793   DEBUG({
7794       dbgs() << "Terms:\n";
7795       for (const SCEV *T : Terms)
7796         dbgs() << *T << "\n";
7797     });
7798 
7799   // Remove duplicates.
7800   std::sort(Terms.begin(), Terms.end());
7801   Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end());
7802 
7803   // Put larger terms first.
7804   std::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) {
7805     return numberOfTerms(LHS) > numberOfTerms(RHS);
7806   });
7807 
7808   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
7809 
7810   // Divide all terms by the element size.
7811   for (const SCEV *&Term : Terms) {
7812     const SCEV *Q, *R;
7813     SCEVDivision::divide(SE, Term, ElementSize, &Q, &R);
7814     Term = Q;
7815   }
7816 
7817   SmallVector<const SCEV *, 4> NewTerms;
7818 
7819   // Remove constant factors.
7820   for (const SCEV *T : Terms)
7821     if (const SCEV *NewT = removeConstantFactors(SE, T))
7822       NewTerms.push_back(NewT);
7823 
7824   DEBUG({
7825       dbgs() << "Terms after sorting:\n";
7826       for (const SCEV *T : NewTerms)
7827         dbgs() << *T << "\n";
7828     });
7829 
7830   if (NewTerms.empty() ||
7831       !findArrayDimensionsRec(SE, NewTerms, Sizes)) {
7832     Sizes.clear();
7833     return;
7834   }
7835 
7836   // The last element to be pushed into Sizes is the size of an element.
7837   Sizes.push_back(ElementSize);
7838 
7839   DEBUG({
7840       dbgs() << "Sizes:\n";
7841       for (const SCEV *S : Sizes)
7842         dbgs() << *S << "\n";
7843     });
7844 }
7845 
7846 /// Third step of delinearization: compute the access functions for the
7847 /// Subscripts based on the dimensions in Sizes.
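///
/// Continuing the running example with Sizes = {%m, %o, 8}: the first
/// division strips the element size, the remainders of the divisions by %o
/// and %m become the subscripts {0,+,1}<%for.k> and {0,+,1}<%for.j>, and the
/// last quotient contributes the outermost subscript {0,+,1}<%for.i>; the
/// vector is then reversed into source order.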
7848 void SCEVAddRecExpr::computeAccessFunctions(
7849     ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &Subscripts,
7850     SmallVectorImpl<const SCEV *> &Sizes) const {
7851 
7852   // Early exit in case this SCEV is not an affine multivariate function.
7853   if (Sizes.empty() || !this->isAffine())
7854     return;
7855 
7856   const SCEV *Res = this;
7857   int Last = Sizes.size() - 1;
7858   for (int i = Last; i >= 0; i--) {
7859     const SCEV *Q, *R;
7860     SCEVDivision::divide(SE, Res, Sizes[i], &Q, &R);
7861 
7862     DEBUG({
7863         dbgs() << "Res: " << *Res << "\n";
7864         dbgs() << "Sizes[i]: " << *Sizes[i] << "\n";
7865         dbgs() << "Res divided by Sizes[i]:\n";
7866         dbgs() << "Quotient: " << *Q << "\n";
7867         dbgs() << "Remainder: " << *R << "\n";
7868       });
7869 
7870     Res = Q;
7871 
7872     // Do not record the last subscript, which corresponds to the size of the
7873     // elements in the array.
7874     if (i == Last) {
7875 
7876       // Bail out if the remainder is too complex.
7877       if (isa<SCEVAddRecExpr>(R)) {
7878         Subscripts.clear();
7879         Sizes.clear();
7880         return;
7881       }
7882 
7883       continue;
7884     }
7885 
7886     // Record the access function for the current subscript.
7887     Subscripts.push_back(R);
7888   }
7889 
7890   // Also push the quotient of the last division in last position: after the
7891   // reversal below, it becomes the access function of the outermost dimension.
7892   Subscripts.push_back(Res);
7893 
7894   std::reverse(Subscripts.begin(), Subscripts.end());
7895 
7896   DEBUG({
7897       dbgs() << "Subscripts:\n";
7898       for (const SCEV *S : Subscripts)
7899         dbgs() << *S << "\n";
7900     });
7901 }
7902 
7903 /// Splits the SCEV into two vectors of SCEVs representing the subscripts and
7904 /// sizes of an array access; the remainder of the delinearization is the
7905 /// offset of the start of the array.  The SCEV->delinearize algorithm computes
7906 /// the multiples of SCEV coefficients: that is a pattern matching of sub-
7907 /// expressions in the stride and base of a SCEV corresponding to the
7908 /// computation of a GCD (greatest common divisor) of base and stride.  When
7909 /// SCEV->delinearize fails, Subscripts and Sizes are left empty.
7910 ///
7911 /// For example: when analyzing the memory access A[i][j][k] in this loop nest
7912 ///
7913 ///  void foo(long n, long m, long o, double A[n][m][o]) {
7914 ///
7915 ///    for (long i = 0; i < n; i++)
7916 ///      for (long j = 0; j < m; j++)
7917 ///        for (long k = 0; k < o; k++)
7918 ///          A[i][j][k] = 1.0;
7919 ///  }
7920 ///
7921 /// the delinearization input is the following AddRec SCEV:
7922 ///
7923 ///  AddRec: {{{%A,+,(8 * %m * %o)}<%for.i>,+,(8 * %o)}<%for.j>,+,8}<%for.k>
7924 ///
7925 /// From this SCEV, we are able to say that the base offset of the access is %A
7926 /// because it appears as an offset that does not divide any of the strides in
7927 /// the loops:
7928 ///
7929 ///  CHECK: Base offset: %A
7930 ///
7931 /// and then SCEV->delinearize determines the size of some of the dimensions of
7932 /// the array as these are the multiples by which the strides are happening:
7933 ///
7934 ///  CHECK: ArrayDecl[UnknownSize][%m][%o] with elements of sizeof(double) bytes.
7935 ///
7936 /// Note that the outermost dimension remains UnknownSize because there are
7937 /// no strides that would help identify its size: had the array been
7938 /// statically allocated, one could compute the size of that dimension by
7939 /// dividing the overall size of the array by the size of the known
7940 /// dimensions: %m * %o * 8.
7941 ///
7942 /// Finally, delinearize provides the access functions for the array reference
7943 /// that corresponds to A[i][j][k] in the above C testcase:
7944 ///
7945 ///  CHECK: ArrayRef[{0,+,1}<%for.i>][{0,+,1}<%for.j>][{0,+,1}<%for.k>]
7946 ///
7947 /// The testcases check the output of a function pass,
7948 /// DelinearizationPass, which walks through all loads and stores of a function
7949 /// asking for the SCEV of the memory access with respect to all enclosing
7950 /// loops, calling SCEV->delinearize on that and printing the results.
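///
/// A minimal usage sketch (illustrative only; it mirrors what a client such
/// as the Delinearization pass does and is not code from this file):
///
///   const SCEV *AccessFn = SE->getSCEVAtScope(getPointerOperand(Inst), L);
///   if (const auto *AR = dyn_cast<SCEVAddRecExpr>(AccessFn)) {
///     SmallVector<const SCEV *, 3> Subscripts, Sizes;
///     AR->delinearize(*SE, Subscripts, Sizes, SE->getElementSize(Inst));
///     // On success, Subscripts.size() == Sizes.size() and the pairs can be
///     // fed to dependence testing.
///   }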
7951 
7952 void SCEVAddRecExpr::delinearize(ScalarEvolution &SE,
7953                                  SmallVectorImpl<const SCEV *> &Subscripts,
7954                                  SmallVectorImpl<const SCEV *> &Sizes,
7955                                  const SCEV *ElementSize) const {
7956   // First step: collect parametric terms.
7957   SmallVector<const SCEV *, 4> Terms;
7958   collectParametricTerms(SE, Terms);
7959 
7960   if (Terms.empty())
7961     return;
7962 
7963   // Second step: find subscript sizes.
7964   SE.findArrayDimensions(Terms, Sizes, ElementSize);
7965 
7966   if (Sizes.empty())
7967     return;
7968 
7969   // Third step: compute the access functions for each subscript.
7970   computeAccessFunctions(SE, Subscripts, Sizes);
7971 
7972   if (Subscripts.empty())
7973     return;
7974 
7975   DEBUG({
7976       dbgs() << "succeeded to delinearize " << *this << "\n";
7977       dbgs() << "ArrayDecl[UnknownSize]";
7978       for (const SCEV *S : Sizes)
7979         dbgs() << "[" << *S << "]";
7980 
7981       dbgs() << "\nArrayRef";
7982       for (const SCEV *S : Subscripts)
7983         dbgs() << "[" << *S << "]";
7984       dbgs() << "\n";
7985     });
7986 }
7987 
7988 //===----------------------------------------------------------------------===//
7989 //                   SCEVCallbackVH Class Implementation
7990 //===----------------------------------------------------------------------===//
7991 
7992 void ScalarEvolution::SCEVCallbackVH::deleted() {
7993   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
7994   if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
7995     SE->ConstantEvolutionLoopExitValue.erase(PN);
7996   SE->ValueExprMap.erase(getValPtr());
7997   // this now dangles!
7998 }
7999 
8000 void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
8001   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
8002 
8003   // Forget all the expressions associated with users of the old value,
8004   // so that future queries will recompute the expressions using the new
8005   // value.
8006   Value *Old = getValPtr();
8007   SmallVector<User *, 16> Worklist(Old->user_begin(), Old->user_end());
8008   SmallPtrSet<User *, 8> Visited;
8009   while (!Worklist.empty()) {
8010     User *U = Worklist.pop_back_val();
8011     // Deleting the Old value will cause this to dangle. Postpone
8012     // that until everything else is done.
8013     if (U == Old)
8014       continue;
8015     if (!Visited.insert(U).second)
8016       continue;
8017     if (PHINode *PN = dyn_cast<PHINode>(U))
8018       SE->ConstantEvolutionLoopExitValue.erase(PN);
8019     SE->ValueExprMap.erase(U);
8020     Worklist.insert(Worklist.end(), U->user_begin(), U->user_end());
8021   }
8022   // Delete the Old value.
8023   if (PHINode *PN = dyn_cast<PHINode>(Old))
8024     SE->ConstantEvolutionLoopExitValue.erase(PN);
8025   SE->ValueExprMap.erase(Old);
8026   // this now dangles!
8027 }
8028 
8029 ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
8030   : CallbackVH(V), SE(se) {}
8031 
8032 //===----------------------------------------------------------------------===//
8033 //                   ScalarEvolution Class Implementation
8034 //===----------------------------------------------------------------------===//
8035 
8036 ScalarEvolution::ScalarEvolution()
8037     : FunctionPass(ID), WalkingBEDominatingConds(false), ValuesAtScopes(64),
8038       LoopDispositions(64), BlockDispositions(64), FirstUnknown(nullptr) {
8039   initializeScalarEvolutionPass(*PassRegistry::getPassRegistry());
8040 }
8041 
8042 bool ScalarEvolution::runOnFunction(Function &F) {
8043   this->F = &F;
8044   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
8045   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
8046   TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
8047   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
8048   return false;
8049 }
8050 
8051 void ScalarEvolution::releaseMemory() {
8052   // Iterate through all the SCEVUnknown instances and call their
8053   // destructors, so that they release their references to their values.
8054   for (SCEVUnknown *U = FirstUnknown; U; U = U->Next)
8055     U->~SCEVUnknown();
8056   FirstUnknown = nullptr;
8057 
8058   ValueExprMap.clear();
8059 
8060   // Free any extra memory created for ExitNotTakenInfo in the unlikely event
8061   // that a loop had multiple computable exits.
8062   for (DenseMap<const Loop*, BackedgeTakenInfo>::iterator I =
8063          BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end();
8064        I != E; ++I) {
8065     I->second.clear();
8066   }
8067 
8068   assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
8069   assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
8070 
8071   BackedgeTakenCounts.clear();
8072   ConstantEvolutionLoopExitValue.clear();
8073   ValuesAtScopes.clear();
8074   LoopDispositions.clear();
8075   BlockDispositions.clear();
8076   UnsignedRanges.clear();
8077   SignedRanges.clear();
8078   UniqueSCEVs.clear();
8079   SCEVAllocator.Reset();
8080 }
8081 
8082 void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
8083   AU.setPreservesAll();
8084   AU.addRequired<AssumptionCacheTracker>();
8085   AU.addRequiredTransitive<LoopInfoWrapperPass>();
8086   AU.addRequiredTransitive<DominatorTreeWrapperPass>();
8087   AU.addRequired<TargetLibraryInfoWrapperPass>();
8088 }
8089 
8090 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
8091   return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
8092 }
8093 
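// Print the backedge-taken counts for L and for all of the loops nested
// inside it, producing lines of the form (illustrative):
//   Loop %for.body: backedge-taken count is (-1 + %n)
//   Loop %for.body: max backedge-taken count is -1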
8094 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
8095                           const Loop *L) {
8096   // Print all inner loops first
8097   for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
8098     PrintLoopInfo(OS, SE, *I);
8099 
8100   OS << "Loop ";
8101   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
8102   OS << ": ";
8103 
8104   SmallVector<BasicBlock *, 8> ExitBlocks;
8105   L->getExitBlocks(ExitBlocks);
8106   if (ExitBlocks.size() != 1)
8107     OS << "<multiple exits> ";
8108 
8109   if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
8110     OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
8111   } else {
8112     OS << "Unpredictable backedge-taken count. ";
8113   }
8114 
8115   OS << "\n"
8116         "Loop ";
8117   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
8118   OS << ": ";
8119 
8120   if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
8121     OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
8122   } else {
8123     OS << "Unpredictable max backedge-taken count. ";
8124   }
8125 
8126   OS << "\n";
8127 }
8128 
8129 void ScalarEvolution::print(raw_ostream &OS, const Module *) const {
8130   // ScalarEvolution's implementation of the print method is to print
8131   // out SCEV values of all instructions that are interesting. Doing
8132   // this potentially causes it to create new SCEV objects though,
8133   // which technically conflicts with the const qualifier. This isn't
8134 // observable from outside the class, so casting away the
8135   // const isn't dangerous.
8136   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
8137 
8138   OS << "Classifying expressions for: ";
8139   F->printAsOperand(OS, /*PrintType=*/false);
8140   OS << "\n";
8141   for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
8142     if (isSCEVable(I->getType()) && !isa<CmpInst>(*I)) {
8143       OS << *I << '\n';
8144       OS << "  -->  ";
8145       const SCEV *SV = SE.getSCEV(&*I);
8146       SV->print(OS);
8147       if (!isa<SCEVCouldNotCompute>(SV)) {
8148         OS << " U: ";
8149         SE.getUnsignedRange(SV).print(OS);
8150         OS << " S: ";
8151         SE.getSignedRange(SV).print(OS);
8152       }
8153 
8154       const Loop *L = LI->getLoopFor((*I).getParent());
8155 
8156       const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
8157       if (AtUse != SV) {
8158         OS << "  -->  ";
8159         AtUse->print(OS);
8160         if (!isa<SCEVCouldNotCompute>(AtUse)) {
8161           OS << " U: ";
8162           SE.getUnsignedRange(AtUse).print(OS);
8163           OS << " S: ";
8164           SE.getSignedRange(AtUse).print(OS);
8165         }
8166       }
8167 
8168       if (L) {
8169         OS << "\t\t" "Exits: ";
8170         const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
8171         if (!SE.isLoopInvariant(ExitValue, L)) {
8172           OS << "<<Unknown>>";
8173         } else {
8174           OS << *ExitValue;
8175         }
8176       }
8177 
8178       OS << "\n";
8179     }
8180 
8181   OS << "Determining loop execution counts for: ";
8182   F->printAsOperand(OS, /*PrintType=*/false);
8183   OS << "\n";
8184   for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
8185     PrintLoopInfo(OS, &SE, *I);
8186 }
8187 
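// Cached wrapper around computeLoopDisposition. A conservative LoopVariant
// placeholder is recorded before computing so that self-referential queries
// terminate; because the recursive computation may grow (and reallocate) the
// per-SCEV vector, the placeholder is looked up again from the back before it
// is overwritten with the real result.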
8188 ScalarEvolution::LoopDisposition
8189 ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
8190   auto &Values = LoopDispositions[S];
8191   for (auto &V : Values) {
8192     if (V.getPointer() == L)
8193       return V.getInt();
8194   }
8195   Values.emplace_back(L, LoopVariant);
8196   LoopDisposition D = computeLoopDisposition(S, L);
8197   auto &Values2 = LoopDispositions[S];
8198   for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
8199     if (V.getPointer() == L) {
8200       V.setInt(D);
8201       break;
8202     }
8203   }
8204   return D;
8205 }
8206 
8207 ScalarEvolution::LoopDisposition
8208 ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
8209   switch (static_cast<SCEVTypes>(S->getSCEVType())) {
8210   case scConstant:
8211     return LoopInvariant;
8212   case scTruncate:
8213   case scZeroExtend:
8214   case scSignExtend:
8215     return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
8216   case scAddRecExpr: {
8217     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
8218 
8219     // If L is the addrec's loop, it's computable.
8220     if (AR->getLoop() == L)
8221       return LoopComputable;
8222 
8223     // Add recurrences are never invariant in the function-body (null loop).
8224     if (!L)
8225       return LoopVariant;
8226 
8227     // This recurrence is variant w.r.t. L if L contains AR's loop.
8228     if (L->contains(AR->getLoop()))
8229       return LoopVariant;
8230 
8231     // This recurrence is invariant w.r.t. L if AR's loop contains L.
8232     if (AR->getLoop()->contains(L))
8233       return LoopInvariant;
8234 
8235     // This recurrence is variant w.r.t. L if any of its operands
8236     // are variant.
8237     for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
8238          I != E; ++I)
8239       if (!isLoopInvariant(*I, L))
8240         return LoopVariant;
8241 
8242     // Otherwise it's loop-invariant.
8243     return LoopInvariant;
8244   }
8245   case scAddExpr:
8246   case scMulExpr:
8247   case scUMaxExpr:
8248   case scSMaxExpr: {
8249     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
8250     bool HasVarying = false;
8251     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
8252          I != E; ++I) {
8253       LoopDisposition D = getLoopDisposition(*I, L);
8254       if (D == LoopVariant)
8255         return LoopVariant;
8256       if (D == LoopComputable)
8257         HasVarying = true;
8258     }
8259     return HasVarying ? LoopComputable : LoopInvariant;
8260   }
8261   case scUDivExpr: {
8262     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
8263     LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
8264     if (LD == LoopVariant)
8265       return LoopVariant;
8266     LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
8267     if (RD == LoopVariant)
8268       return LoopVariant;
8269     return (LD == LoopInvariant && RD == LoopInvariant) ?
8270            LoopInvariant : LoopComputable;
8271   }
8272   case scUnknown:
8273     // All non-instruction values are loop invariant.  All instructions are loop
8274     // invariant if they are not contained in the specified loop.
8275     // Instructions are never considered invariant in the function body
8276     // (null loop) because they are defined within the "loop".
8277     if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
8278       return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
8279     return LoopInvariant;
8280   case scCouldNotCompute:
8281     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
8282   }
8283   llvm_unreachable("Unknown SCEV kind!");
8284 }
8285 
8286 bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
8287   return getLoopDisposition(S, L) == LoopInvariant;
8288 }
8289 
8290 bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
8291   return getLoopDisposition(S, L) == LoopComputable;
8292 }
8293 
8294 ScalarEvolution::BlockDisposition
8295 ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
8296   auto &Values = BlockDispositions[S];
8297   for (auto &V : Values) {
8298     if (V.getPointer() == BB)
8299       return V.getInt();
8300   }
8301   Values.emplace_back(BB, DoesNotDominateBlock);
8302   BlockDisposition D = computeBlockDisposition(S, BB);
8303   auto &Values2 = BlockDispositions[S];
8304   for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
8305     if (V.getPointer() == BB) {
8306       V.setInt(D);
8307       break;
8308     }
8309   }
8310   return D;
8311 }
8312 
8313 ScalarEvolution::BlockDisposition
8314 ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
8315   switch (static_cast<SCEVTypes>(S->getSCEVType())) {
8316   case scConstant:
8317     return ProperlyDominatesBlock;
8318   case scTruncate:
8319   case scZeroExtend:
8320   case scSignExtend:
8321     return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
8322   case scAddRecExpr: {
8323     // This uses a "dominates" query instead of a "properly dominates" query
8324     // to test for proper dominance too, because the instruction which
8325     // produces the addrec's value is a PHI, and a PHI effectively properly
8326     // dominates its entire containing block.
8327     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
8328     if (!DT->dominates(AR->getLoop()->getHeader(), BB))
8329       return DoesNotDominateBlock;
8330   }
8331   // FALL THROUGH into SCEVNAryExpr handling.
8332   case scAddExpr:
8333   case scMulExpr:
8334   case scUMaxExpr:
8335   case scSMaxExpr: {
8336     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
8337     bool Proper = true;
8338     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
8339          I != E; ++I) {
8340       BlockDisposition D = getBlockDisposition(*I, BB);
8341       if (D == DoesNotDominateBlock)
8342         return DoesNotDominateBlock;
8343       if (D == DominatesBlock)
8344         Proper = false;
8345     }
8346     return Proper ? ProperlyDominatesBlock : DominatesBlock;
8347   }
8348   case scUDivExpr: {
8349     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
8350     const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
8351     BlockDisposition LD = getBlockDisposition(LHS, BB);
8352     if (LD == DoesNotDominateBlock)
8353       return DoesNotDominateBlock;
8354     BlockDisposition RD = getBlockDisposition(RHS, BB);
8355     if (RD == DoesNotDominateBlock)
8356       return DoesNotDominateBlock;
8357     return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
8358       ProperlyDominatesBlock : DominatesBlock;
8359   }
8360   case scUnknown:
8361     if (Instruction *I =
8362           dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
8363       if (I->getParent() == BB)
8364         return DominatesBlock;
8365       if (DT->properlyDominates(I->getParent(), BB))
8366         return ProperlyDominatesBlock;
8367       return DoesNotDominateBlock;
8368     }
8369     return ProperlyDominatesBlock;
8370   case scCouldNotCompute:
8371     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
8372   }
8373   llvm_unreachable("Unknown SCEV kind!");
8374 }
8375 
8376 bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
8377   return getBlockDisposition(S, BB) >= DominatesBlock;
8378 }
8379 
8380 bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
8381   return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
8382 }
8383 
8384 namespace {
8385 // Search for a SCEV expression node within an expression tree.
8386 // Implements SCEVTraversal::Visitor.
8387 struct SCEVSearch {
8388   const SCEV *Node;
8389   bool IsFound;
8390 
8391   SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}
8392 
8393   bool follow(const SCEV *S) {
8394     IsFound |= (S == Node);
8395     return !IsFound;
8396   }
8397   bool isDone() const { return IsFound; }
8398 };
8399 }
8400 
8401 bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
8402   SCEVSearch Search(Op);
8403   visitAll(S, Search);
8404   return Search.IsFound;
8405 }
8406 
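// Drop every memoized fact that refers to S. ValuesAtScopes and the range and
// disposition maps are keyed by SCEV and can be erased directly;
// BackedgeTakenCounts is keyed by loop, so each entry must be searched for
// uses of S and cleared when one is found.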
8407 void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
8408   ValuesAtScopes.erase(S);
8409   LoopDispositions.erase(S);
8410   BlockDispositions.erase(S);
8411   UnsignedRanges.erase(S);
8412   SignedRanges.erase(S);
8413 
8414   for (DenseMap<const Loop*, BackedgeTakenInfo>::iterator I =
8415          BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end(); I != E; ) {
8416     BackedgeTakenInfo &BEInfo = I->second;
8417     if (BEInfo.hasOperand(S, this)) {
8418       BEInfo.clear();
8419       BackedgeTakenCounts.erase(I++);
8420     }
8421     else
8422       ++I;
8423   }
8424 }
8425 
8426 typedef DenseMap<const Loop *, std::string> VerifyMap;
8427 
8428 /// replaceSubString - Replaces all occurrences of From in Str with To.
8429 static void replaceSubString(std::string &Str, StringRef From, StringRef To) {
8430   size_t Pos = 0;
8431   while ((Pos = Str.find(From, Pos)) != std::string::npos) {
8432     Str.replace(Pos, From.size(), To.data(), To.size());
8433     Pos += To.size();
8434   }
8435 }
8436 
8437 /// getLoopBackedgeTakenCounts - Helper method for verifyAnalysis.
8438 static void
8439 getLoopBackedgeTakenCounts(Loop *L, VerifyMap &Map, ScalarEvolution &SE) {
8440   for (Loop::reverse_iterator I = L->rbegin(), E = L->rend(); I != E; ++I) {
8441     getLoopBackedgeTakenCounts(*I, Map, SE); // recurse.
8442 
8443     std::string &S = Map[L];
8444     if (S.empty()) {
8445       raw_string_ostream OS(S);
8446       SE.getBackedgeTakenCount(L)->print(OS);
8447 
8448       // false and 0 are semantically equivalent. This can happen in dead loops.
8449       replaceSubString(OS.str(), "false", "0");
8450       // Remove wrap flags; their use in SCEV is highly fragile.
8451       // FIXME: Remove this when SCEV gets smarter about them.
8452       replaceSubString(OS.str(), "<nw>", "");
8453       replaceSubString(OS.str(), "<nsw>", "");
8454       replaceSubString(OS.str(), "<nuw>", "");
8455     }
8456   }
8457 }
8458 
8459 void ScalarEvolution::verifyAnalysis() const {
8460   if (!VerifySCEV)
8461     return;
8462 
8463   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
8464 
8465   // Gather stringified backedge-taken counts for all loops using SCEV's caches.
8466   // FIXME: It would be much better to store actual values instead of strings,
8467   //        but SCEV pointers will change if we drop the caches.
8468   VerifyMap BackedgeDumpsOld, BackedgeDumpsNew;
8469   for (LoopInfo::reverse_iterator I = LI->rbegin(), E = LI->rend(); I != E; ++I)
8470     getLoopBackedgeTakenCounts(*I, BackedgeDumpsOld, SE);
8471 
8472   // Gather stringified backedge-taken counts for all loops without using
8473   // SCEV's caches.
8474   SE.releaseMemory();
8475   for (LoopInfo::reverse_iterator I = LI->rbegin(), E = LI->rend(); I != E; ++I)
8476     getLoopBackedgeTakenCounts(*I, BackedgeDumpsNew, SE);
8477 
8478   // Now compare whether they're the same with and without caches. This allows
8479   // verifying that no pass changed the cache.
8480   assert(BackedgeDumpsOld.size() == BackedgeDumpsNew.size() &&
8481          "New loops suddenly appeared!");
8482 
8483   for (VerifyMap::iterator OldI = BackedgeDumpsOld.begin(),
8484                            OldE = BackedgeDumpsOld.end(),
8485                            NewI = BackedgeDumpsNew.begin();
8486        OldI != OldE; ++OldI, ++NewI) {
8487     assert(OldI->first == NewI->first && "Loop order changed!");
8488 
8489     // Compare the stringified SCEVs. We don't care if undef backedge-taken count
8490     // changes.
8491     // FIXME: We currently ignore SCEV changes from/to CouldNotCompute. Such a
8492     // change means that a pass is buggy or SCEV has to learn a new pattern,
8493     // but it is usually not harmful.
8494     if (OldI->second != NewI->second &&
8495         OldI->second.find("undef") == std::string::npos &&
8496         NewI->second.find("undef") == std::string::npos &&
8497         OldI->second != "***COULDNOTCOMPUTE***" &&
8498         NewI->second != "***COULDNOTCOMPUTE***") {
8499       dbgs() << "SCEVValidator: SCEV for loop '"
8500              << OldI->first->getHeader()->getName()
8501              << "' changed from '" << OldI->second
8502              << "' to '" << NewI->second << "'!\n";
8503       std::abort();
8504     }
8505   }
8506 
8507   // TODO: Verify more things.
8508 }
8509