xref: /llvm-project/llvm/lib/Analysis/ScalarEvolution.cpp (revision 299e67291c49cf28cf42e77b2ecdce1485a60ceb)
1 //===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains the implementation of the scalar evolution analysis
11 // engine, which is used primarily to analyze expressions involving induction
12 // variables in loops.
13 //
14 // There are several aspects to this library.  First is the representation of
15 // scalar expressions, which are represented as subclasses of the SCEV class.
16 // These classes are used to represent certain types of subexpressions that we
17 // can handle. We only create one SCEV of a particular shape, so
18 // pointer-comparisons for equality are legal.
19 //
20 // One important aspect of the SCEV objects is that they are never cyclic, even
21 // if there is a cycle in the dataflow for an expression (i.e., a PHI node).  If
22 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
23 // recurrence) then we represent it directly as a recurrence node, otherwise we
24 // represent it as a SCEVUnknown node.
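//
// As an added illustration (not in the original commentary): in a loop such
// as "for (i = 0; i != n; ++i) s += i;", the PHI for i is the affine
// recurrence {0,+,1} over that loop, and s is the quadratic recurrence
// {0,+,0,+,1}, whose value on iteration k is k*(k-1)/2.  A value the analysis
// cannot model (for example, a volatile load) is wrapped in a SCEVUnknown.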
25 //
26 // In addition to being able to represent expressions of various types, we also
27 // have folders that are used to build the *canonical* representation for a
28 // particular expression.  These folders are capable of using a variety of
29 // rewrite rules to simplify the expressions.
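//
// (Added example: the add folder, for instance, canonicalizes (1 + %x + 2) to
// (3 + %x) and (%x + %x) to (2 * %x), so structurally different inputs reach
// the same uniqued SCEV object.)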
30 //
31 // Once the folders are defined, we can implement the more interesting
32 // higher-level code, such as the code that recognizes PHI nodes of various
33 // types, computes the execution count of a loop, etc.
34 //
35 // TODO: We should use these routines and value representations to implement
36 // dependence analysis!
37 //
38 //===----------------------------------------------------------------------===//
39 //
40 // There are several good references for the techniques used in this analysis.
41 //
42 //  Chains of recurrences -- a method to expedite the evaluation
43 //  of closed-form functions
44 //  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
45 //
46 //  On computational properties of chains of recurrences
47 //  Eugene V. Zima
48 //
49 //  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
50 //  Robert A. van Engelen
51 //
52 //  Efficient Symbolic Analysis for Optimizing Compilers
53 //  Robert A. van Engelen
54 //
55 //  Using the chains of recurrences algebra for data dependence testing and
56 //  induction variable substitution
57 //  MS Thesis, Johnie Birch
58 //
59 //===----------------------------------------------------------------------===//
60 
61 #include "llvm/Analysis/ScalarEvolution.h"
62 #include "llvm/ADT/Optional.h"
63 #include "llvm/ADT/STLExtras.h"
64 #include "llvm/ADT/ScopeExit.h"
65 #include "llvm/ADT/SmallPtrSet.h"
66 #include "llvm/ADT/Statistic.h"
67 #include "llvm/Analysis/AssumptionCache.h"
68 #include "llvm/Analysis/ConstantFolding.h"
69 #include "llvm/Analysis/InstructionSimplify.h"
70 #include "llvm/Analysis/LoopInfo.h"
71 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
72 #include "llvm/Analysis/TargetLibraryInfo.h"
73 #include "llvm/Analysis/ValueTracking.h"
74 #include "llvm/IR/ConstantRange.h"
75 #include "llvm/IR/Constants.h"
76 #include "llvm/IR/DataLayout.h"
77 #include "llvm/IR/DerivedTypes.h"
78 #include "llvm/IR/Dominators.h"
79 #include "llvm/IR/GetElementPtrTypeIterator.h"
80 #include "llvm/IR/GlobalAlias.h"
81 #include "llvm/IR/GlobalVariable.h"
82 #include "llvm/IR/InstIterator.h"
83 #include "llvm/IR/Instructions.h"
84 #include "llvm/IR/LLVMContext.h"
85 #include "llvm/IR/Metadata.h"
86 #include "llvm/IR/Operator.h"
87 #include "llvm/IR/PatternMatch.h"
88 #include "llvm/Support/CommandLine.h"
89 #include "llvm/Support/Debug.h"
90 #include "llvm/Support/ErrorHandling.h"
91 #include "llvm/Support/MathExtras.h"
92 #include "llvm/Support/raw_ostream.h"
93 #include "llvm/Support/SaveAndRestore.h"
94 #include <algorithm>
95 using namespace llvm;
96 
97 #define DEBUG_TYPE "scalar-evolution"
98 
99 STATISTIC(NumArrayLenItCounts,
100           "Number of trip counts computed with array length");
101 STATISTIC(NumTripCountsComputed,
102           "Number of loops with predictable loop counts");
103 STATISTIC(NumTripCountsNotComputed,
104           "Number of loops without predictable loop counts");
105 STATISTIC(NumBruteForceTripCountsComputed,
106           "Number of loops with trip counts computed by force");
107 
108 static cl::opt<unsigned>
109 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
110                         cl::desc("Maximum number of iterations SCEV will "
111                                  "symbolically execute a constant "
112                                  "derived loop"),
113                         cl::init(100));
114 
115 // FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
116 static cl::opt<bool>
117 VerifySCEV("verify-scev",
118            cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
119 static cl::opt<bool>
120     VerifySCEVMap("verify-scev-maps",
121                   cl::desc("Verify no dangling value in ScalarEvolution's "
122                            "ExprValueMap (slow)"));
123 
124 static cl::opt<unsigned> MulOpsInlineThreshold(
125     "scev-mulops-inline-threshold", cl::Hidden,
126     cl::desc("Threshold for inlining multiplication operands into a SCEV"),
127     cl::init(1000));
128 
129 //===----------------------------------------------------------------------===//
130 //                           SCEV class definitions
131 //===----------------------------------------------------------------------===//
132 
133 //===----------------------------------------------------------------------===//
134 // Implementation of the SCEV class.
135 //
136 
137 LLVM_DUMP_METHOD
138 void SCEV::dump() const {
139   print(dbgs());
140   dbgs() << '\n';
141 }
142 
143 void SCEV::print(raw_ostream &OS) const {
144   switch (static_cast<SCEVTypes>(getSCEVType())) {
145   case scConstant:
146     cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
147     return;
148   case scTruncate: {
149     const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
150     const SCEV *Op = Trunc->getOperand();
151     OS << "(trunc " << *Op->getType() << " " << *Op << " to "
152        << *Trunc->getType() << ")";
153     return;
154   }
155   case scZeroExtend: {
156     const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
157     const SCEV *Op = ZExt->getOperand();
158     OS << "(zext " << *Op->getType() << " " << *Op << " to "
159        << *ZExt->getType() << ")";
160     return;
161   }
162   case scSignExtend: {
163     const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
164     const SCEV *Op = SExt->getOperand();
165     OS << "(sext " << *Op->getType() << " " << *Op << " to "
166        << *SExt->getType() << ")";
167     return;
168   }
169   case scAddRecExpr: {
170     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
171     OS << "{" << *AR->getOperand(0);
172     for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
173       OS << ",+," << *AR->getOperand(i);
174     OS << "}<";
175     if (AR->hasNoUnsignedWrap())
176       OS << "nuw><";
177     if (AR->hasNoSignedWrap())
178       OS << "nsw><";
179     if (AR->hasNoSelfWrap() &&
180         !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
181       OS << "nw><";
182     AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
183     OS << ">";
184     return;
185   }
186   case scAddExpr:
187   case scMulExpr:
188   case scUMaxExpr:
189   case scSMaxExpr: {
190     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
191     const char *OpStr = nullptr;
192     switch (NAry->getSCEVType()) {
193     case scAddExpr: OpStr = " + "; break;
194     case scMulExpr: OpStr = " * "; break;
195     case scUMaxExpr: OpStr = " umax "; break;
196     case scSMaxExpr: OpStr = " smax "; break;
197     }
198     OS << "(";
199     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
200          I != E; ++I) {
201       OS << **I;
202       if (std::next(I) != E)
203         OS << OpStr;
204     }
205     OS << ")";
206     switch (NAry->getSCEVType()) {
207     case scAddExpr:
208     case scMulExpr:
209       if (NAry->hasNoUnsignedWrap())
210         OS << "<nuw>";
211       if (NAry->hasNoSignedWrap())
212         OS << "<nsw>";
213     }
214     return;
215   }
216   case scUDivExpr: {
217     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
218     OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
219     return;
220   }
221   case scUnknown: {
222     const SCEVUnknown *U = cast<SCEVUnknown>(this);
223     Type *AllocTy;
224     if (U->isSizeOf(AllocTy)) {
225       OS << "sizeof(" << *AllocTy << ")";
226       return;
227     }
228     if (U->isAlignOf(AllocTy)) {
229       OS << "alignof(" << *AllocTy << ")";
230       return;
231     }
232 
233     Type *CTy;
234     Constant *FieldNo;
235     if (U->isOffsetOf(CTy, FieldNo)) {
236       OS << "offsetof(" << *CTy << ", ";
237       FieldNo->printAsOperand(OS, false);
238       OS << ")";
239       return;
240     }
241 
242     // Otherwise just print it normally.
243     U->getValue()->printAsOperand(OS, false);
244     return;
245   }
246   case scCouldNotCompute:
247     OS << "***COULDNOTCOMPUTE***";
248     return;
249   }
250   llvm_unreachable("Unknown SCEV kind!");
251 }
252 
253 Type *SCEV::getType() const {
254   switch (static_cast<SCEVTypes>(getSCEVType())) {
255   case scConstant:
256     return cast<SCEVConstant>(this)->getType();
257   case scTruncate:
258   case scZeroExtend:
259   case scSignExtend:
260     return cast<SCEVCastExpr>(this)->getType();
261   case scAddRecExpr:
262   case scMulExpr:
263   case scUMaxExpr:
264   case scSMaxExpr:
265     return cast<SCEVNAryExpr>(this)->getType();
266   case scAddExpr:
267     return cast<SCEVAddExpr>(this)->getType();
268   case scUDivExpr:
269     return cast<SCEVUDivExpr>(this)->getType();
270   case scUnknown:
271     return cast<SCEVUnknown>(this)->getType();
272   case scCouldNotCompute:
273     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
274   }
275   llvm_unreachable("Unknown SCEV kind!");
276 }
277 
278 bool SCEV::isZero() const {
279   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
280     return SC->getValue()->isZero();
281   return false;
282 }
283 
284 bool SCEV::isOne() const {
285   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
286     return SC->getValue()->isOne();
287   return false;
288 }
289 
290 bool SCEV::isAllOnesValue() const {
291   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
292     return SC->getValue()->isAllOnesValue();
293   return false;
294 }
295 
296 bool SCEV::isNonConstantNegative() const {
297   const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
298   if (!Mul) return false;
299 
300   // If there is a constant factor, it will be first.
301   const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
302   if (!SC) return false;
303 
304   // Return true if the value is negative; this matches things like (-42 * V).
305   return SC->getAPInt().isNegative();
306 }
307 
308 SCEVCouldNotCompute::SCEVCouldNotCompute() :
309   SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {}
310 
311 bool SCEVCouldNotCompute::classof(const SCEV *S) {
312   return S->getSCEVType() == scCouldNotCompute;
313 }
314 
315 const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
316   FoldingSetNodeID ID;
317   ID.AddInteger(scConstant);
318   ID.AddPointer(V);
319   void *IP = nullptr;
320   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
321   SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
322   UniqueSCEVs.InsertNode(S, IP);
323   return S;
324 }
325 
326 const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
327   return getConstant(ConstantInt::get(getContext(), Val));
328 }
329 
330 const SCEV *
331 ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
332   IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
333   return getConstant(ConstantInt::get(ITy, V, isSigned));
334 }
335 
336 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
337                            unsigned SCEVTy, const SCEV *op, Type *ty)
338   : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}
339 
340 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
341                                    const SCEV *op, Type *ty)
342   : SCEVCastExpr(ID, scTruncate, op, ty) {
343   assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
344          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
345          "Cannot truncate non-integer value!");
346 }
347 
348 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
349                                        const SCEV *op, Type *ty)
350   : SCEVCastExpr(ID, scZeroExtend, op, ty) {
351   assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
352          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
353          "Cannot zero extend non-integer value!");
354 }
355 
356 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
357                                        const SCEV *op, Type *ty)
358   : SCEVCastExpr(ID, scSignExtend, op, ty) {
359   assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
360          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
361          "Cannot sign extend non-integer value!");
362 }
363 
364 void SCEVUnknown::deleted() {
365   // Clear this SCEVUnknown from various maps.
366   SE->forgetMemoizedResults(this);
367 
368   // Remove this SCEVUnknown from the uniquing map.
369   SE->UniqueSCEVs.RemoveNode(this);
370 
371   // Release the value.
372   setValPtr(nullptr);
373 }
374 
375 void SCEVUnknown::allUsesReplacedWith(Value *New) {
376   // Clear this SCEVUnknown from various maps.
377   SE->forgetMemoizedResults(this);
378 
379   // Remove this SCEVUnknown from the uniquing map.
380   SE->UniqueSCEVs.RemoveNode(this);
381 
382   // Update this SCEVUnknown to point to the new value. This is needed
383   // because there may still be outstanding SCEVs which still point to
384   // this SCEVUnknown.
385   setValPtr(New);
386 }
387 
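// Note added for clarity: the three predicates below recognize the canonical
// null-based GEP constant expressions used to encode type sizes, alignments
// and field offsets; for example, sizeof(T) is matched as
// ptrtoint(getelementptr(T, T* null, 1)).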
388 bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
389   if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
390     if (VCE->getOpcode() == Instruction::PtrToInt)
391       if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
392         if (CE->getOpcode() == Instruction::GetElementPtr &&
393             CE->getOperand(0)->isNullValue() &&
394             CE->getNumOperands() == 2)
395           if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
396             if (CI->isOne()) {
397               AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
398                                  ->getElementType();
399               return true;
400             }
401 
402   return false;
403 }
404 
405 bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
406   if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
407     if (VCE->getOpcode() == Instruction::PtrToInt)
408       if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
409         if (CE->getOpcode() == Instruction::GetElementPtr &&
410             CE->getOperand(0)->isNullValue()) {
411           Type *Ty =
412             cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
413           if (StructType *STy = dyn_cast<StructType>(Ty))
414             if (!STy->isPacked() &&
415                 CE->getNumOperands() == 3 &&
416                 CE->getOperand(1)->isNullValue()) {
417               if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
418                 if (CI->isOne() &&
419                     STy->getNumElements() == 2 &&
420                     STy->getElementType(0)->isIntegerTy(1)) {
421                   AllocTy = STy->getElementType(1);
422                   return true;
423                 }
424             }
425         }
426 
427   return false;
428 }
429 
430 bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
431   if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
432     if (VCE->getOpcode() == Instruction::PtrToInt)
433       if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
434         if (CE->getOpcode() == Instruction::GetElementPtr &&
435             CE->getNumOperands() == 3 &&
436             CE->getOperand(0)->isNullValue() &&
437             CE->getOperand(1)->isNullValue()) {
438           Type *Ty =
439             cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
440           // Ignore vector types here so that ScalarEvolutionExpander doesn't
441           // emit getelementptrs that index into vectors.
442           if (Ty->isStructTy() || Ty->isArrayTy()) {
443             CTy = Ty;
444             FieldNo = CE->getOperand(2);
445             return true;
446           }
447         }
448 
449   return false;
450 }
451 
452 //===----------------------------------------------------------------------===//
453 //                               SCEV Utilities
454 //===----------------------------------------------------------------------===//
455 
456 static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
457                                   Value *RV, unsigned DepthLeft = 2) {
458   if (DepthLeft == 0)
459     return 0;
460 
461   // Order pointer values after integer values. This helps SCEVExpander form
462   // GEPs.
463   bool LIsPointer = LV->getType()->isPointerTy(),
464        RIsPointer = RV->getType()->isPointerTy();
465   if (LIsPointer != RIsPointer)
466     return (int)LIsPointer - (int)RIsPointer;
467 
468   // Compare getValueID values.
469   unsigned LID = LV->getValueID(), RID = RV->getValueID();
470   if (LID != RID)
471     return (int)LID - (int)RID;
472 
473   // Sort arguments by their position.
474   if (const auto *LA = dyn_cast<Argument>(LV)) {
475     const auto *RA = cast<Argument>(RV);
476     unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
477     return (int)LArgNo - (int)RArgNo;
478   }
479 
480   if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
481     const auto *RGV = cast<GlobalValue>(RV);
482 
483     const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
484       auto LT = GV->getLinkage();
485       return !(GlobalValue::isPrivateLinkage(LT) ||
486                GlobalValue::isInternalLinkage(LT));
487     };
488 
489     // Use the names to distinguish the two values, but only if the
490     // names are semantically important.
491     if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
492       return LGV->getName().compare(RGV->getName());
493   }
494 
495   // For instructions, compare their loop depth, and their operand count.  This
496   // is pretty loose.
497   if (const auto *LInst = dyn_cast<Instruction>(LV)) {
498     const auto *RInst = cast<Instruction>(RV);
499 
500     // Compare loop depths.
501     const BasicBlock *LParent = LInst->getParent(),
502                      *RParent = RInst->getParent();
503     if (LParent != RParent) {
504       unsigned LDepth = LI->getLoopDepth(LParent),
505                RDepth = LI->getLoopDepth(RParent);
506       if (LDepth != RDepth)
507         return (int)LDepth - (int)RDepth;
508     }
509 
510     // Compare the number of operands.
511     unsigned LNumOps = LInst->getNumOperands(),
512              RNumOps = RInst->getNumOperands();
513     if (LNumOps != RNumOps || LNumOps != 1)
514       return (int)LNumOps - (int)RNumOps;
515 
516     // We only bother "recursing" if we have one operand to look at (so we don't
517     // really recurse as much as we iterate).  We can consider expanding this
518     // logic in the future.
519     return CompareValueComplexity(LI, LInst->getOperand(0),
520                                   RInst->getOperand(0), DepthLeft - 1);
521   }
522 
523   return 0;
524 }
525 
526 // Return negative, zero, or positive if LHS is less than, equal to, or greater
527 // than RHS, respectively. A three-way result allows recursive comparisons to be
528 // more efficient.
529 static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
530                                  const SCEV *RHS) {
531   // Fast-path: SCEVs are uniqued so we can do a quick equality check.
532   if (LHS == RHS)
533     return 0;
534 
535   // Primarily, sort the SCEVs by their getSCEVType().
536   unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
537   if (LType != RType)
538     return (int)LType - (int)RType;
539 
540   // Aside from the getSCEVType() ordering, the particular ordering
541   // isn't very important except that it's beneficial to be consistent,
542   // so that (a + b) and (b + a) don't end up as different expressions.
543   switch (static_cast<SCEVTypes>(LType)) {
544   case scUnknown: {
545     const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
546     const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
547 
548     return CompareValueComplexity(LI, LU->getValue(), RU->getValue());
549   }
550 
551   case scConstant: {
552     const SCEVConstant *LC = cast<SCEVConstant>(LHS);
553     const SCEVConstant *RC = cast<SCEVConstant>(RHS);
554 
555     // Compare constant values.
556     const APInt &LA = LC->getAPInt();
557     const APInt &RA = RC->getAPInt();
558     unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
559     if (LBitWidth != RBitWidth)
560       return (int)LBitWidth - (int)RBitWidth;
561     return LA.ult(RA) ? -1 : 1;
562   }
563 
564   case scAddRecExpr: {
565     const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
566     const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
567 
568     // Compare addrec loop depths.
569     const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
570     if (LLoop != RLoop) {
571       unsigned LDepth = LLoop->getLoopDepth(), RDepth = RLoop->getLoopDepth();
572       if (LDepth != RDepth)
573         return (int)LDepth - (int)RDepth;
574     }
575 
576     // Addrec complexity grows with operand count.
577     unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
578     if (LNumOps != RNumOps)
579       return (int)LNumOps - (int)RNumOps;
580 
581     // Lexicographically compare.
582     for (unsigned i = 0; i != LNumOps; ++i) {
583       long X = CompareSCEVComplexity(LI, LA->getOperand(i), RA->getOperand(i));
584       if (X != 0)
585         return X;
586     }
587 
588     return 0;
589   }
590 
591   case scAddExpr:
592   case scMulExpr:
593   case scSMaxExpr:
594   case scUMaxExpr: {
595     const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
596     const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
597 
598     // Lexicographically compare n-ary expressions.
599     unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
600     if (LNumOps != RNumOps)
601       return (int)LNumOps - (int)RNumOps;
602 
603     for (unsigned i = 0; i != LNumOps; ++i) {
604       if (i >= RNumOps)
605         return 1;
606       long X = CompareSCEVComplexity(LI, LC->getOperand(i), RC->getOperand(i));
607       if (X != 0)
608         return X;
609     }
610     return (int)LNumOps - (int)RNumOps;
611   }
612 
613   case scUDivExpr: {
614     const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
615     const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
616 
617     // Lexicographically compare udiv expressions.
618     long X = CompareSCEVComplexity(LI, LC->getLHS(), RC->getLHS());
619     if (X != 0)
620       return X;
621     return CompareSCEVComplexity(LI, LC->getRHS(), RC->getRHS());
622   }
623 
624   case scTruncate:
625   case scZeroExtend:
626   case scSignExtend: {
627     const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
628     const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
629 
630     // Compare cast expressions by operand.
631     return CompareSCEVComplexity(LI, LC->getOperand(), RC->getOperand());
632   }
633 
634   case scCouldNotCompute:
635     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
636   }
637   llvm_unreachable("Unknown SCEV kind!");
638 }
639 
640 /// Given a list of SCEV objects, order them by their complexity, and group
641 /// objects of the same complexity together by value.  When this routine is
642 /// finished, we know that any duplicates in the vector are consecutive and that
643 /// complexity is monotonically increasing.
644 ///
645 /// Note that we take special precautions to ensure that we get deterministic
646 /// results from this routine.  In other words, we don't want the results of
647 /// this to depend on where the addresses of various SCEV objects happened to
648 /// land in memory.
649 ///
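/// Illustrative example (added): the operand list (%a, 2, %a) is reordered to
/// (2, %a, %a); the constant sorts first because scConstant has the lowest
/// SCEVType value, and the duplicate %a's become adjacent so that callers such
/// as getAddExpr can combine them (here into 2 * %a).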
650 static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
651                               LoopInfo *LI) {
652   if (Ops.size() < 2) return;  // Noop
653   if (Ops.size() == 2) {
654     // This is the common case, which also happens to be trivially simple.
655     // Special case it.
656     const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
657     if (CompareSCEVComplexity(LI, RHS, LHS) < 0)
658       std::swap(LHS, RHS);
659     return;
660   }
661 
662   // Do the rough sort by complexity.
663   std::stable_sort(Ops.begin(), Ops.end(),
664                    [LI](const SCEV *LHS, const SCEV *RHS) {
665                      return CompareSCEVComplexity(LI, LHS, RHS) < 0;
666                    });
667 
668   // Now that we are sorted by complexity, group elements of the same
669   // complexity.  Note that this is, at worst, N^2, but the vector is likely to
670   // be extremely short in practice.  Note that we take this approach because we
671   // do not want to depend on the addresses of the objects we are grouping.
672   for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
673     const SCEV *S = Ops[i];
674     unsigned Complexity = S->getSCEVType();
675 
676     // If there are any objects of the same complexity and same value as this
677     // one, group them.
678     for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
679       if (Ops[j] == S) { // Found a duplicate.
680         // Move it to immediately after i'th element.
681         std::swap(Ops[i+1], Ops[j]);
682         ++i;   // no need to rescan it.
683         if (i == e-2) return;  // Done!
684       }
685     }
686   }
687 }
688 
689 // Returns the size of the SCEV S.
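// (Size here is the number of distinct nodes in the expression DAG; for
// example (added illustration), sizeOfSCEV of ((2 + %a) * %b) is 5: the mul,
// the add, and the three leaves.)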
690 static inline int sizeOfSCEV(const SCEV *S) {
691   struct FindSCEVSize {
692     int Size;
693     FindSCEVSize() : Size(0) {}
694 
695     bool follow(const SCEV *S) {
696       ++Size;
697       // Keep looking at all operands of S.
698       return true;
699     }
700     bool isDone() const {
701       return false;
702     }
703   };
704 
705   FindSCEVSize F;
706   SCEVTraversal<FindSCEVSize> ST(F);
707   ST.visitAll(S);
708   return F.Size;
709 }
710 
711 namespace {
712 
713 struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
714 public:
715   // Computes the Quotient and Remainder of the division of Numerator by
716   // Denominator.
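  // (Added example of the intended semantics: dividing the SCEV (8 + 4 * %n)
  // by 4 yields Quotient = (2 + %n) and Remainder = 0, while dividing
  // (3 + 4 * %n) by 4 yields Quotient = %n and Remainder = 3.)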
717   static void divide(ScalarEvolution &SE, const SCEV *Numerator,
718                      const SCEV *Denominator, const SCEV **Quotient,
719                      const SCEV **Remainder) {
720     assert(Numerator && Denominator && "Uninitialized SCEV");
721 
722     SCEVDivision D(SE, Numerator, Denominator);
723 
724     // Check for the trivial case here to avoid having to check for it in the
725     // rest of the code.
726     if (Numerator == Denominator) {
727       *Quotient = D.One;
728       *Remainder = D.Zero;
729       return;
730     }
731 
732     if (Numerator->isZero()) {
733       *Quotient = D.Zero;
734       *Remainder = D.Zero;
735       return;
736     }
737 
738     // A simple case: when the denominator is 1, the quotient is N.
739     if (Denominator->isOne()) {
740       *Quotient = Numerator;
741       *Remainder = D.Zero;
742       return;
743     }
744 
745     // Split the Denominator when it is a product.
746     if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
747       const SCEV *Q, *R;
748       *Quotient = Numerator;
749       for (const SCEV *Op : T->operands()) {
750         divide(SE, *Quotient, Op, &Q, &R);
751         *Quotient = Q;
752 
753         // Bail out when the Numerator is not divisible by one of the terms of
754         // the Denominator.
755         if (!R->isZero()) {
756           *Quotient = D.Zero;
757           *Remainder = Numerator;
758           return;
759         }
760       }
761       *Remainder = D.Zero;
762       return;
763     }
764 
765     D.visit(Numerator);
766     *Quotient = D.Quotient;
767     *Remainder = D.Remainder;
768   }
769 
770   // Except in the trivial case described above, we do not know how to divide
771   // Expr by Denominator for the following SCEV kinds, so their visitors are empty.
772   void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {}
773   void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {}
774   void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {}
775   void visitUDivExpr(const SCEVUDivExpr *Numerator) {}
776   void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {}
777   void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {}
778   void visitUnknown(const SCEVUnknown *Numerator) {}
779   void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {}
780 
781   void visitConstant(const SCEVConstant *Numerator) {
782     if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
783       APInt NumeratorVal = Numerator->getAPInt();
784       APInt DenominatorVal = D->getAPInt();
785       uint32_t NumeratorBW = NumeratorVal.getBitWidth();
786       uint32_t DenominatorBW = DenominatorVal.getBitWidth();
787 
788       if (NumeratorBW > DenominatorBW)
789         DenominatorVal = DenominatorVal.sext(NumeratorBW);
790       else if (NumeratorBW < DenominatorBW)
791         NumeratorVal = NumeratorVal.sext(DenominatorBW);
792 
793       APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
794       APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
795       APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
796       Quotient = SE.getConstant(QuotientVal);
797       Remainder = SE.getConstant(RemainderVal);
798       return;
799     }
800   }
801 
802   void visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
803     const SCEV *StartQ, *StartR, *StepQ, *StepR;
804     if (!Numerator->isAffine())
805       return cannotDivide(Numerator);
806     divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
807     divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
808     // Bail out if the types do not match.
809     Type *Ty = Denominator->getType();
810     if (Ty != StartQ->getType() || Ty != StartR->getType() ||
811         Ty != StepQ->getType() || Ty != StepR->getType())
812       return cannotDivide(Numerator);
813     Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
814                                 Numerator->getNoWrapFlags());
815     Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
816                                  Numerator->getNoWrapFlags());
817   }
818 
819   void visitAddExpr(const SCEVAddExpr *Numerator) {
820     SmallVector<const SCEV *, 2> Qs, Rs;
821     Type *Ty = Denominator->getType();
822 
823     for (const SCEV *Op : Numerator->operands()) {
824       const SCEV *Q, *R;
825       divide(SE, Op, Denominator, &Q, &R);
826 
827       // Bail out if types do not match.
828       if (Ty != Q->getType() || Ty != R->getType())
829         return cannotDivide(Numerator);
830 
831       Qs.push_back(Q);
832       Rs.push_back(R);
833     }
834 
835     if (Qs.size() == 1) {
836       Quotient = Qs[0];
837       Remainder = Rs[0];
838       return;
839     }
840 
841     Quotient = SE.getAddExpr(Qs);
842     Remainder = SE.getAddExpr(Rs);
843   }
844 
845   void visitMulExpr(const SCEVMulExpr *Numerator) {
846     SmallVector<const SCEV *, 2> Qs;
847     Type *Ty = Denominator->getType();
848 
849     bool FoundDenominatorTerm = false;
850     for (const SCEV *Op : Numerator->operands()) {
851       // Bail out if types do not match.
852       if (Ty != Op->getType())
853         return cannotDivide(Numerator);
854 
855       if (FoundDenominatorTerm) {
856         Qs.push_back(Op);
857         continue;
858       }
859 
860       // Check whether Denominator divides one of the product operands.
861       const SCEV *Q, *R;
862       divide(SE, Op, Denominator, &Q, &R);
863       if (!R->isZero()) {
864         Qs.push_back(Op);
865         continue;
866       }
867 
868       // Bail out if types do not match.
869       if (Ty != Q->getType())
870         return cannotDivide(Numerator);
871 
872       FoundDenominatorTerm = true;
873       Qs.push_back(Q);
874     }
875 
876     if (FoundDenominatorTerm) {
877       Remainder = Zero;
878       if (Qs.size() == 1)
879         Quotient = Qs[0];
880       else
881         Quotient = SE.getMulExpr(Qs);
882       return;
883     }
884 
885     if (!isa<SCEVUnknown>(Denominator))
886       return cannotDivide(Numerator);
887 
888     // The Remainder is obtained by replacing Denominator by 0 in Numerator.
889     ValueToValueMap RewriteMap;
890     RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
891         cast<SCEVConstant>(Zero)->getValue();
892     Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
893 
894     if (Remainder->isZero()) {
895       // The Quotient is obtained by replacing Denominator by 1 in Numerator.
896       RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
897           cast<SCEVConstant>(One)->getValue();
898       Quotient =
899           SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
900       return;
901     }
902 
903     // Quotient is (Numerator - Remainder) divided by Denominator.
904     const SCEV *Q, *R;
905     const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
906     // This SCEV does not seem to simplify: fail the division here.
907     if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
908       return cannotDivide(Numerator);
909     divide(SE, Diff, Denominator, &Q, &R);
910     if (R != Zero)
911       return cannotDivide(Numerator);
912     Quotient = Q;
913   }
914 
915 private:
916   SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
917                const SCEV *Denominator)
918       : SE(S), Denominator(Denominator) {
919     Zero = SE.getZero(Denominator->getType());
920     One = SE.getOne(Denominator->getType());
921 
922     // We generally do not know how to divide Expr by Denominator. We
923     // initialize the division to a "cannot divide" state to simplify the rest
924     // of the code.
925     cannotDivide(Numerator);
926   }
927 
928   // Convenience function for giving up on the division. We set the quotient to
929   // be equal to zero and the remainder to be equal to the numerator.
930   void cannotDivide(const SCEV *Numerator) {
931     Quotient = Zero;
932     Remainder = Numerator;
933   }
934 
935   ScalarEvolution &SE;
936   const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
937 };
938 
939 }
940 
941 //===----------------------------------------------------------------------===//
942 //                      Simple SCEV method implementations
943 //===----------------------------------------------------------------------===//
944 
945 /// Compute BC(It, K).  The result has width W.  Assume K > 0.
946 static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
947                                        ScalarEvolution &SE,
948                                        Type *ResultTy) {
949   // Handle the simplest case efficiently.
950   if (K == 1)
951     return SE.getTruncateOrZeroExtend(It, ResultTy);
952 
953   // We are using the following formula for BC(It, K):
954   //
955   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
956   //
957   // Suppose, W is the bitwidth of the return value.  We must be prepared for
958   // overflow.  Hence, we must assure that the result of our computation is
959   // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
960   // safe in modular arithmetic.
961   //
962   // However, this code doesn't use exactly that formula; the formula it uses
963   // is something like the following, where T is the number of factors of 2 in
964   // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
965   // exponentiation:
966   //
967   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
968   //
969   // This formula is trivially equivalent to the previous formula.  However,
970   // this formula can be implemented much more efficiently.  The trick is that
971   // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
972   // arithmetic.  To do exact division in modular arithmetic, all we have
973   // to do is multiply by the inverse.  Therefore, this step can be done at
974   // width W.
975   //
976   // The next issue is how to safely do the division by 2^T.  The way this
977   // is done is by doing the multiplication step at a width of at least W + T
978   // bits.  This way, the bottom W+T bits of the product are accurate. Then,
979   // when we perform the division by 2^T (which is equivalent to a right shift
980   // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
981   // truncated out after the division by 2^T.
982   //
983   // In comparison to just directly using the first formula, this technique
984   // is much more efficient; using the first formula requires W * K bits,
985 // but this formula requires less than W + K bits. Also, the first formula requires
986   // a division step, whereas this formula only requires multiplies and shifts.
987   //
988   // It doesn't matter whether the subtraction step is done in the calculation
989   // width or the input iteration count's width; if the subtraction overflows,
990   // the result must be zero anyway.  We prefer here to do it in the width of
991   // the induction variable because it helps a lot for certain cases; CodeGen
992   // isn't smart enough to ignore the overflow, which leads to much less
993   // efficient code if the width of the subtraction is wider than the native
994   // register width.
995   //
996   // (It's possible to not widen at all by pulling out factors of 2 before
997   // the multiplication; for example, K=2 can be calculated as
998   // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
999   // extra arithmetic, so it's not an obvious win, and it gets
1000   // much more complicated for K > 3.)
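  //
  // Worked example (added): for K = 3 we have K! = 6, so T = 1 and
  // K!/2^T = 3.  The code below computes It*(It-1)*(It-2) at width W+1 bits,
  // divides by 2^T = 2, truncates back to W bits, and multiplies by the
  // multiplicative inverse of 3 modulo 2^W to complete the exact division by
  // K!/2^T.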
1001 
1002   // Protection from insane SCEVs; this bound is conservative,
1003   // but it probably doesn't matter.
1004   if (K > 1000)
1005     return SE.getCouldNotCompute();
1006 
1007   unsigned W = SE.getTypeSizeInBits(ResultTy);
1008 
1009   // Calculate K! / 2^T and T; we divide out the factors of two before
1010   // multiplying for calculating K! / 2^T to avoid overflow.
1011   // Other overflow doesn't matter because we only care about the bottom
1012   // W bits of the result.
1013   APInt OddFactorial(W, 1);
1014   unsigned T = 1;
1015   for (unsigned i = 3; i <= K; ++i) {
1016     APInt Mult(W, i);
1017     unsigned TwoFactors = Mult.countTrailingZeros();
1018     T += TwoFactors;
1019     Mult = Mult.lshr(TwoFactors);
1020     OddFactorial *= Mult;
1021   }
1022 
1023   // We need at least W + T bits for the multiplication step
1024   unsigned CalculationBits = W + T;
1025 
1026   // Calculate 2^T, at width T+W.
1027   APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1028 
1029   // Calculate the multiplicative inverse of K! / 2^T;
1030   // this multiplication factor will perform the exact division by
1031   // K! / 2^T.
1032   APInt Mod = APInt::getSignedMinValue(W+1);
1033   APInt MultiplyFactor = OddFactorial.zext(W+1);
1034   MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
1035   MultiplyFactor = MultiplyFactor.trunc(W);
1036 
1037   // Calculate the product, at width T+W
1038   IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1039                                                       CalculationBits);
1040   const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1041   for (unsigned i = 1; i != K; ++i) {
1042     const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1043     Dividend = SE.getMulExpr(Dividend,
1044                              SE.getTruncateOrZeroExtend(S, CalculationTy));
1045   }
1046 
1047   // Divide by 2^T
1048   const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1049 
1050   // Truncate the result, and divide by K! / 2^T.
1051 
1052   return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1053                        SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1054 }
1055 
1056 /// Return the value of this chain of recurrences at the specified iteration
1057 /// number.  We can evaluate this recurrence by multiplying each element in the
1058 /// chain by the binomial coefficient corresponding to it.  In other words, we
1059 /// can evaluate {A,+,B,+,C,+,D} as:
1060 ///
1061 ///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1062 ///
1063 /// where BC(It, k) stands for binomial coefficient.
1064 ///
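/// For example (added illustration), {5,+,3} evaluates to 5 + 3*It, and
/// {0,+,0,+,1} evaluates to BC(It, 2) = It*(It-1)/2.
///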
1065 const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
1066                                                 ScalarEvolution &SE) const {
1067   const SCEV *Result = getStart();
1068   for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
1069     // The computation is correct in the face of overflow provided that the
1070     // multiplication is performed _after_ the evaluation of the binomial
1071     // coefficient.
1072     const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
1073     if (isa<SCEVCouldNotCompute>(Coeff))
1074       return Coeff;
1075 
1076     Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
1077   }
1078   return Result;
1079 }
1080 
1081 //===----------------------------------------------------------------------===//
1082 //                    SCEV Expression folder implementations
1083 //===----------------------------------------------------------------------===//
1084 
1085 const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
1086                                              Type *Ty) {
1087   assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1088          "This is not a truncating conversion!");
1089   assert(isSCEVable(Ty) &&
1090          "This is not a conversion to a SCEVable type!");
1091   Ty = getEffectiveSCEVType(Ty);
1092 
1093   FoldingSetNodeID ID;
1094   ID.AddInteger(scTruncate);
1095   ID.AddPointer(Op);
1096   ID.AddPointer(Ty);
1097   void *IP = nullptr;
1098   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1099 
1100   // Fold if the operand is constant.
1101   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1102     return getConstant(
1103       cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1104 
1105   // trunc(trunc(x)) --> trunc(x)
1106   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1107     return getTruncateExpr(ST->getOperand(), Ty);
1108 
1109   // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1110   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1111     return getTruncateOrSignExtend(SS->getOperand(), Ty);
1112 
1113   // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1114   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1115     return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
1116 
1117   // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
1118   // eliminate all the truncates, or we replace other casts with truncates.
1119   if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
1120     SmallVector<const SCEV *, 4> Operands;
1121     bool hasTrunc = false;
1122     for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
1123       const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
1124       if (!isa<SCEVCastExpr>(SA->getOperand(i)))
1125         hasTrunc = isa<SCEVTruncateExpr>(S);
1126       Operands.push_back(S);
1127     }
1128     if (!hasTrunc)
1129       return getAddExpr(Operands);
1130     UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
1131   }
1132 
1133   // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
1134   // eliminate all the truncates, or we replace other casts with truncates.
1135   if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
1136     SmallVector<const SCEV *, 4> Operands;
1137     bool hasTrunc = false;
1138     for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
1139       const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
1140       if (!isa<SCEVCastExpr>(SM->getOperand(i)))
1141         hasTrunc = isa<SCEVTruncateExpr>(S);
1142       Operands.push_back(S);
1143     }
1144     if (!hasTrunc)
1145       return getMulExpr(Operands);
1146     UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
1147   }
1148 
1149   // If the input value is a chrec scev, truncate the chrec's operands.
1150   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1151     SmallVector<const SCEV *, 4> Operands;
1152     for (const SCEV *Op : AddRec->operands())
1153       Operands.push_back(getTruncateExpr(Op, Ty));
1154     return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1155   }
1156 
1157   // The cast wasn't folded; create an explicit cast node. We can reuse
1158   // the existing insert position since if we get here, we won't have
1159   // made any changes which would invalidate it.
1160   SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1161                                                  Op, Ty);
1162   UniqueSCEVs.InsertNode(S, IP);
1163   return S;
1164 }
1165 
1166 // Get the limit of a recurrence such that incrementing by Step cannot cause
1167 // signed overflow as long as the value of the recurrence within the
1168 // loop does not exceed this limit before incrementing.
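// (Added example: for an i8 recurrence with a known-positive Step of 1, this
// returns the limit 127 with predicate ICMP_SLT; while the current value is
// signed-less-than 127, adding 1 cannot overflow.  If the sign of Step is
// unknown, a null limit is returned.)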
1169 static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1170                                                  ICmpInst::Predicate *Pred,
1171                                                  ScalarEvolution *SE) {
1172   unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1173   if (SE->isKnownPositive(Step)) {
1174     *Pred = ICmpInst::ICMP_SLT;
1175     return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1176                            SE->getSignedRange(Step).getSignedMax());
1177   }
1178   if (SE->isKnownNegative(Step)) {
1179     *Pred = ICmpInst::ICMP_SGT;
1180     return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1181                            SE->getSignedRange(Step).getSignedMin());
1182   }
1183   return nullptr;
1184 }
1185 
1186 // Get the limit of a recurrence such that incrementing by Step cannot cause
1187 // unsigned overflow as long as the value of the recurrence within the loop does
1188 // not exceed this limit before incrementing.
1189 static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1190                                                    ICmpInst::Predicate *Pred,
1191                                                    ScalarEvolution *SE) {
1192   unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1193   *Pred = ICmpInst::ICMP_ULT;
1194 
1195   return SE->getConstant(APInt::getMinValue(BitWidth) -
1196                          SE->getUnsignedRange(Step).getUnsignedMax());
1197 }
1198 
1199 namespace {
1200 
1201 struct ExtendOpTraitsBase {
1202   typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *);
1203 };
1204 
1205 // Used to make code generic over signed and unsigned overflow.
1206 template <typename ExtendOp> struct ExtendOpTraits {
1207   // Members present:
1208   //
1209   // static const SCEV::NoWrapFlags WrapType;
1210   //
1211   // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1212   //
1213   // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1214   //                                           ICmpInst::Predicate *Pred,
1215   //                                           ScalarEvolution *SE);
1216 };
1217 
1218 template <>
1219 struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1220   static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1221 
1222   static const GetExtendExprTy GetExtendExpr;
1223 
1224   static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1225                                              ICmpInst::Predicate *Pred,
1226                                              ScalarEvolution *SE) {
1227     return getSignedOverflowLimitForStep(Step, Pred, SE);
1228   }
1229 };
1230 
1231 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1232     SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1233 
1234 template <>
1235 struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1236   static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1237 
1238   static const GetExtendExprTy GetExtendExpr;
1239 
1240   static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1241                                              ICmpInst::Predicate *Pred,
1242                                              ScalarEvolution *SE) {
1243     return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1244   }
1245 };
1246 
1247 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1248     SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1249 }
1250 
1251 // The recurrence AR has been shown to have no signed/unsigned wrap or something
1252 // close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1253 // easily prove NSW/NUW for its preincrement or postincrement sibling. This
1254 // allows normalizing a sign/zero extended AddRec as such:
1255 //   {sext/zext(Step + Start),+,Step} => {Step + sext/zext(Start),+,Step}
1256 // As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
1257 // "sext/zext(PostIncAR)".
1258 template <typename ExtendOpTy>
1259 static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1260                                         ScalarEvolution *SE) {
1261   auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1262   auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1263 
1264   const Loop *L = AR->getLoop();
1265   const SCEV *Start = AR->getStart();
1266   const SCEV *Step = AR->getStepRecurrence(*SE);
1267 
1268   // Check for a simple looking step prior to loop entry.
1269   const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1270   if (!SA)
1271     return nullptr;
1272 
1273   // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1274   // subtraction is expensive. For this purpose, perform a quick and dirty
1275   // difference, by checking for Step in the operand list.
1276   SmallVector<const SCEV *, 4> DiffOps;
1277   for (const SCEV *Op : SA->operands())
1278     if (Op != Step)
1279       DiffOps.push_back(Op);
1280 
1281   if (DiffOps.size() == SA->getNumOperands())
1282     return nullptr;
1283 
1284   // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1285   // `Step`:
1286 
1287   // 1. NSW/NUW flags on the step increment.
1288   auto PreStartFlags =
1289     ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1290   const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1291   const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1292       SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1293 
1294   // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1295   // "S+X does not sign/unsign-overflow".
1296   //
1297 
1298   const SCEV *BECount = SE->getBackedgeTakenCount(L);
1299   if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1300       !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1301     return PreStart;
1302 
1303   // 2. Direct overflow check on the step operation's expression.
1304   unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1305   Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1306   const SCEV *OperandExtendedStart =
1307       SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy),
1308                      (SE->*GetExtendExpr)(Step, WideTy));
1309   if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) {
1310     if (PreAR && AR->getNoWrapFlags(WrapType)) {
1311       // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1312       // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1313       // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`.  Cache this fact.
1314       const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(WrapType);
1315     }
1316     return PreStart;
1317   }
1318 
1319   // 3. Loop precondition.
1320   ICmpInst::Predicate Pred;
1321   const SCEV *OverflowLimit =
1322       ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1323 
1324   if (OverflowLimit &&
1325       SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1326     return PreStart;
1327 
1328   return nullptr;
1329 }
1330 
1331 // Get the normalized zero or sign extended expression for this AddRec's Start.
1332 template <typename ExtendOpTy>
1333 static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1334                                         ScalarEvolution *SE) {
1335   auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1336 
1337   const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE);
1338   if (!PreStart)
1339     return (SE->*GetExtendExpr)(AR->getStart(), Ty);
1340 
1341   return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty),
1342                         (SE->*GetExtendExpr)(PreStart, Ty));
1343 }
1344 
1345 // Try to prove away overflow by looking at "nearby" add recurrences.  A
1346 // motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1347 // does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1348 //
1349 // Formally:
1350 //
1351 //     {S,+,X} == {S-T,+,X} + T
1352 //  => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1353 //
1354 // If ({S-T,+,X} + T) does not overflow  ... (1)
1355 //
1356 //  RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1357 //
1358 // If {S-T,+,X} does not overflow  ... (2)
1359 //
1360 //  RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1361 //      == {Ext(S-T)+Ext(T),+,Ext(X)}
1362 //
1363 // If (S-T)+T does not overflow  ... (3)
1364 //
1365 //  RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1366 //      == {Ext(S),+,Ext(X)} == LHS
1367 //
1368 // Thus, if (1), (2) and (3) are true for some T, then
1369 //   Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1370 //
1371 // (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1372 // does not overflow" restricted to the 0th iteration.  Therefore we only need
1373 // to check for (1) and (2).
1374 //
1375 // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1376 // is `Delta` (defined below).
1377 //
1378 template <typename ExtendOpTy>
1379 bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1380                                                 const SCEV *Step,
1381                                                 const Loop *L) {
1382   auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1383 
1384   // We restrict `Start` to a constant to prevent SCEV from spending too much
1385   // time here.  It is correct (but more expensive) to continue with a
1386   // non-constant `Start` and do a general SCEV subtraction to compute
1387   // `PreStart` below.
1388   //
1389   const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1390   if (!StartC)
1391     return false;
1392 
1393   APInt StartAI = StartC->getAPInt();
1394 
1395   for (unsigned Delta : {-2, -1, 1, 2}) {
1396     const SCEV *PreStart = getConstant(StartAI - Delta);
1397 
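    // Build a lookup key for {`PreStart`,+,`Step`}<L> so we can check whether
    // this add recurrence has already been uniqued, without constructing it.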
1398     FoldingSetNodeID ID;
1399     ID.AddInteger(scAddRecExpr);
1400     ID.AddPointer(PreStart);
1401     ID.AddPointer(Step);
1402     ID.AddPointer(L);
1403     void *IP = nullptr;
1404     const auto *PreAR =
1405       static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1406 
1407     // Give up if we don't already have the add recurrence we need because
1408     // actually constructing an add recurrence is relatively expensive.
1409     if (PreAR && PreAR->getNoWrapFlags(WrapType)) {  // proves (2)
1410       const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1411       ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1412       const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1413           DeltaS, &Pred, this);
1414       if (Limit && isKnownPredicate(Pred, PreAR, Limit))  // proves (1)
1415         return true;
1416     }
1417   }
1418 
1419   return false;
1420 }
1421 
1422 const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
1423                                                Type *Ty) {
1424   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1425          "This is not an extending conversion!");
1426   assert(isSCEVable(Ty) &&
1427          "This is not a conversion to a SCEVable type!");
1428   Ty = getEffectiveSCEVType(Ty);
1429 
1430   // Fold if the operand is constant.
1431   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1432     return getConstant(
1433       cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
1434 
1435   // zext(zext(x)) --> zext(x)
1436   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1437     return getZeroExtendExpr(SZ->getOperand(), Ty);
1438 
1439   // Before doing any expensive analysis, check to see if we've already
1440   // computed a SCEV for this Op and Ty.
1441   FoldingSetNodeID ID;
1442   ID.AddInteger(scZeroExtend);
1443   ID.AddPointer(Op);
1444   ID.AddPointer(Ty);
1445   void *IP = nullptr;
1446   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1447 
1448   // zext(trunc(x)) --> zext(x) or x or trunc(x)
1449   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1450     // It's possible the bits taken off by the truncate were all zero bits. If
1451     // so, we should be able to simplify this further.
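    // For example, if the unsigned range of X (as an i32) lies within [0, 256),
    // then zext(trunc X to i8) to i64 is simply zext(X to i64).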
1452     const SCEV *X = ST->getOperand();
1453     ConstantRange CR = getUnsignedRange(X);
1454     unsigned TruncBits = getTypeSizeInBits(ST->getType());
1455     unsigned NewBits = getTypeSizeInBits(Ty);
1456     if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1457             CR.zextOrTrunc(NewBits)))
1458       return getTruncateOrZeroExtend(X, Ty);
1459   }
1460 
1461   // If the input value is a chrec scev, and we can prove that the value
1462   // did not overflow the old, smaller, value, we can zero extend all of the
1463   // operands (often constants).  This allows analysis of something like
1464   // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1465   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1466     if (AR->isAffine()) {
1467       const SCEV *Start = AR->getStart();
1468       const SCEV *Step = AR->getStepRecurrence(*this);
1469       unsigned BitWidth = getTypeSizeInBits(AR->getType());
1470       const Loop *L = AR->getLoop();
1471 
1472       if (!AR->hasNoUnsignedWrap()) {
1473         auto NewFlags = proveNoWrapViaConstantRanges(AR);
1474         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
1475       }
1476 
1477       // If we have special knowledge that this addrec won't overflow,
1478       // we don't need to do any further analysis.
1479       if (AR->hasNoUnsignedWrap())
1480         return getAddRecExpr(
1481             getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
1482             getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1483 
1484       // Check whether the backedge-taken count is SCEVCouldNotCompute.
1485       // Note that this serves two purposes: It filters out loops that are
1486       // simply not analyzable, and it covers the case where this code is
1487       // being called from within backedge-taken count analysis, such that
1488       // attempting to ask for the backedge-taken count would likely result
1489       // in infinite recursion. In the latter case, the analysis code will
1490       // cope with a conservative value, and it will take care to purge
1491       // that value once it has finished.
1492       const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
1493       if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1494         // Manually compute the final value for AR, checking for
1495         // overflow.
1496 
1497         // Check whether the backedge-taken count can be losslessly cast to
1498         // the addrec's type. The count is always unsigned.
1499         const SCEV *CastedMaxBECount =
1500           getTruncateOrZeroExtend(MaxBECount, Start->getType());
1501         const SCEV *RecastedMaxBECount =
1502           getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
1503         if (MaxBECount == RecastedMaxBECount) {
1504           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1505           // Check whether Start+Step*MaxBECount has no unsigned overflow.
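          // The idea: compute the final value once by extending the narrow sum
          // (ZAdd) and once from individually extended operands
          // (OperandExtendedAdd below); if the two wide expressions coincide,
          // the narrow computation cannot have wrapped.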
1506           const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step);
1507           const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul), WideTy);
1508           const SCEV *WideStart = getZeroExtendExpr(Start, WideTy);
1509           const SCEV *WideMaxBECount =
1510             getZeroExtendExpr(CastedMaxBECount, WideTy);
1511           const SCEV *OperandExtendedAdd =
1512             getAddExpr(WideStart,
1513                        getMulExpr(WideMaxBECount,
1514                                   getZeroExtendExpr(Step, WideTy)));
1515           if (ZAdd == OperandExtendedAdd) {
1516             // Cache knowledge of AR NUW, which is propagated to this AddRec.
1517             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1518             // Return the expression with the addrec on the outside.
1519             return getAddRecExpr(
1520                 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
1521                 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1522           }
1523           // Similar to above, only this time treat the step value as signed.
1524           // This covers loops that count down.
1525           OperandExtendedAdd =
1526             getAddExpr(WideStart,
1527                        getMulExpr(WideMaxBECount,
1528                                   getSignExtendExpr(Step, WideTy)));
1529           if (ZAdd == OperandExtendedAdd) {
1530             // Cache knowledge of AR NW, which is propagated to this AddRec.
1531             // Negative step causes unsigned wrap, but it still can't self-wrap.
1532             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1533             // Return the expression with the addrec on the outside.
1534             return getAddRecExpr(
1535                 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
1536                 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1537           }
1538         }
1539       }
1540 
1541       // Normally, in the cases we can prove no-overflow via a
1542       // backedge guarding condition, we can also compute a backedge
1543       // taken count for the loop.  The exceptions are assumptions and
1544       // guards present in the loop -- SCEV is not great at exploiting
1545       // these to compute max backedge taken counts, but can still use
1546       // these to prove lack of overflow.  Use this fact to avoid
1547       // doing extra work that may not pay off.
1548       if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1549           !AC.assumptions().empty()) {
1550         // If the backedge is guarded by a comparison with the pre-inc
1551         // value the addrec is safe. Also, if the entry is guarded by
1552         // a comparison with the start value and the backedge is
1553         // guarded by a comparison with the post-inc value, the addrec
1554         // is safe.
1555         if (isKnownPositive(Step)) {
1556           const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
1557                                       getUnsignedRange(Step).getUnsignedMax());
1558           if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
1559               (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
1560                isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
1561                                            AR->getPostIncExpr(*this), N))) {
1562             // Cache knowledge of AR NUW, which is propagated to this
1563             // AddRec.
1564             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1565             // Return the expression with the addrec on the outside.
1566             return getAddRecExpr(
1567                 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
1568                 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1569           }
1570         } else if (isKnownNegative(Step)) {
1571           const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1572                                       getSignedRange(Step).getSignedMin());
1573           if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1574               (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
1575                isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
1576                                            AR->getPostIncExpr(*this), N))) {
1577             // Cache knowledge of AR NW, which is propagated to this
1578             // AddRec.  Negative step causes unsigned wrap, but it
1579             // still can't self-wrap.
1580             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1581             // Return the expression with the addrec on the outside.
1582             return getAddRecExpr(
1583                 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
1584                 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1585           }
1586         }
1587       }
1588 
1589       if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1590         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1591         return getAddRecExpr(
1592             getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
1593             getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1594       }
1595     }
1596 
1597   if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1598     // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1599     if (SA->hasNoUnsignedWrap()) {
1600       // If the addition does not overflow in the unsigned sense, then we can,
1601       // by definition, commute the zero extension with the addition operation.
1602       SmallVector<const SCEV *, 4> Ops;
1603       for (const auto *Op : SA->operands())
1604         Ops.push_back(getZeroExtendExpr(Op, Ty));
1605       return getAddExpr(Ops, SCEV::FlagNUW);
1606     }
1607   }
1608 
1609   // The cast wasn't folded; create an explicit cast node.
1610   // Recompute the insert position, as it may have been invalidated.
1611   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1612   SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1613                                                    Op, Ty);
1614   UniqueSCEVs.InsertNode(S, IP);
1615   return S;
1616 }
1617 
1618 const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
1619                                                Type *Ty) {
1620   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1621          "This is not an extending conversion!");
1622   assert(isSCEVable(Ty) &&
1623          "This is not a conversion to a SCEVable type!");
1624   Ty = getEffectiveSCEVType(Ty);
1625 
1626   // Fold if the operand is constant.
1627   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1628     return getConstant(
1629       cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1630 
1631   // sext(sext(x)) --> sext(x)
1632   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1633     return getSignExtendExpr(SS->getOperand(), Ty);
1634 
1635   // sext(zext(x)) --> zext(x)
1636   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1637     return getZeroExtendExpr(SZ->getOperand(), Ty);
1638 
1639   // Before doing any expensive analysis, check to see if we've already
1640   // computed a SCEV for this Op and Ty.
1641   FoldingSetNodeID ID;
1642   ID.AddInteger(scSignExtend);
1643   ID.AddPointer(Op);
1644   ID.AddPointer(Ty);
1645   void *IP = nullptr;
1646   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1647 
1648   // sext(trunc(x)) --> sext(x) or x or trunc(x)
1649   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1650     // It's possible the bits taken off by the truncate were all sign bits. If
1651     // so, we should be able to simplify this further.
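    // For example, if the signed range of X (as an i32) lies within [-128, 128),
    // then sext(trunc X to i8) to i64 is simply sext(X to i64).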
1652     const SCEV *X = ST->getOperand();
1653     ConstantRange CR = getSignedRange(X);
1654     unsigned TruncBits = getTypeSizeInBits(ST->getType());
1655     unsigned NewBits = getTypeSizeInBits(Ty);
1656     if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1657             CR.sextOrTrunc(NewBits)))
1658       return getTruncateOrSignExtend(X, Ty);
1659   }
1660 
1661   // sext(C1 + (C2 * x)) --> C1 + sext(C2 * x) if C1 < C2
1662   if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1663     if (SA->getNumOperands() == 2) {
1664       auto *SC1 = dyn_cast<SCEVConstant>(SA->getOperand(0));
1665       auto *SMul = dyn_cast<SCEVMulExpr>(SA->getOperand(1));
1666       if (SMul && SC1) {
1667         if (auto *SC2 = dyn_cast<SCEVConstant>(SMul->getOperand(0))) {
1668           const APInt &C1 = SC1->getAPInt();
1669           const APInt &C2 = SC2->getAPInt();
1670           if (C1.isStrictlyPositive() && C2.isStrictlyPositive() &&
1671               C2.ugt(C1) && C2.isPowerOf2())
1672             return getAddExpr(getSignExtendExpr(SC1, Ty),
1673                               getSignExtendExpr(SMul, Ty));
1674         }
1675       }
1676     }
1677 
1678     // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1679     if (SA->hasNoSignedWrap()) {
1680       // If the addition does not overflow in the signed sense, then we can, by
1681       // definition, commute the sign extension with the addition operation.
1682       SmallVector<const SCEV *, 4> Ops;
1683       for (const auto *Op : SA->operands())
1684         Ops.push_back(getSignExtendExpr(Op, Ty));
1685       return getAddExpr(Ops, SCEV::FlagNSW);
1686     }
1687   }
1688   // If the input value is a chrec scev, and we can prove that the value
1689   // did not overflow the old, smaller, value, we can sign extend all of the
1690   // operands (often constants).  This allows analysis of something like
1691   // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
1692   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1693     if (AR->isAffine()) {
1694       const SCEV *Start = AR->getStart();
1695       const SCEV *Step = AR->getStepRecurrence(*this);
1696       unsigned BitWidth = getTypeSizeInBits(AR->getType());
1697       const Loop *L = AR->getLoop();
1698 
1699       if (!AR->hasNoSignedWrap()) {
1700         auto NewFlags = proveNoWrapViaConstantRanges(AR);
1701         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
1702       }
1703 
1704       // If we have special knowledge that this addrec won't overflow,
1705       // we don't need to do any further analysis.
1706       if (AR->hasNoSignedWrap())
1707         return getAddRecExpr(
1708             getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
1709             getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW);
1710 
1711       // Check whether the backedge-taken count is SCEVCouldNotCompute.
1712       // Note that this serves two purposes: It filters out loops that are
1713       // simply not analyzable, and it covers the case where this code is
1714       // being called from within backedge-taken count analysis, such that
1715       // attempting to ask for the backedge-taken count would likely result
1716       // in infinite recursion. In the latter case, the analysis code will
1717       // cope with a conservative value, and it will take care to purge
1718       // that value once it has finished.
1719       const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
1720       if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1721         // Manually compute the final value for AR, checking for
1722         // overflow.
1723 
1724         // Check whether the backedge-taken count can be losslessly cast to
1725         // the addrec's type. The count is always unsigned.
1726         const SCEV *CastedMaxBECount =
1727           getTruncateOrZeroExtend(MaxBECount, Start->getType());
1728         const SCEV *RecastedMaxBECount =
1729           getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
1730         if (MaxBECount == RecastedMaxBECount) {
1731           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1732           // Check whether Start+Step*MaxBECount has no signed overflow.
1733           const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
1734           const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul), WideTy);
1735           const SCEV *WideStart = getSignExtendExpr(Start, WideTy);
1736           const SCEV *WideMaxBECount =
1737             getZeroExtendExpr(CastedMaxBECount, WideTy);
1738           const SCEV *OperandExtendedAdd =
1739             getAddExpr(WideStart,
1740                        getMulExpr(WideMaxBECount,
1741                                   getSignExtendExpr(Step, WideTy)));
1742           if (SAdd == OperandExtendedAdd) {
1743             // Cache knowledge of AR NSW, which is propagated to this AddRec.
1744             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1745             // Return the expression with the addrec on the outside.
1746             return getAddRecExpr(
1747                 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
1748                 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1749           }
1750           // Similar to above, only this time treat the step value as unsigned.
1751           // This covers loops that count up with an unsigned step.
1752           OperandExtendedAdd =
1753             getAddExpr(WideStart,
1754                        getMulExpr(WideMaxBECount,
1755                                   getZeroExtendExpr(Step, WideTy)));
1756           if (SAdd == OperandExtendedAdd) {
1757             // If AR wraps around then
1758             //
1759             //    abs(Step) * MaxBECount > unsigned-max(AR->getType())
1760             // => SAdd != OperandExtendedAdd
1761             //
1762             // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
1763             // (SAdd == OperandExtendedAdd => AR is NW)
1764 
1765             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1766 
1767             // Return the expression with the addrec on the outside.
1768             return getAddRecExpr(
1769                 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
1770                 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1771           }
1772         }
1773       }
1774 
1775       // Normally, in the cases we can prove no-overflow via a
1776       // backedge guarding condition, we can also compute a backedge
1777       // taken count for the loop.  The exceptions are assumptions and
1778       // guards present in the loop -- SCEV is not great at exploiting
1779       // these to compute max backedge taken counts, but can still use
1780       // these to prove lack of overflow.  Use this fact to avoid
1781       // doing extra work that may not pay off.
1782 
1783       if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1784           !AC.assumptions().empty()) {
1785         // If the backedge is guarded by a comparison with the pre-inc
1786         // value the addrec is safe. Also, if the entry is guarded by
1787         // a comparison with the start value and the backedge is
1788         // guarded by a comparison with the post-inc value, the addrec
1789         // is safe.
1790         ICmpInst::Predicate Pred;
1791         const SCEV *OverflowLimit =
1792             getSignedOverflowLimitForStep(Step, &Pred, this);
1793         if (OverflowLimit &&
1794             (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
1795              (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) &&
1796               isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this),
1797                                           OverflowLimit)))) {
1798           // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
1799           const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1800           return getAddRecExpr(
1801               getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
1802               getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1803         }
1804       }
1805 
1806       // If Start and Step are constants, check if we can apply this
1807       // transformation:
1808       // sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2
1809       auto *SC1 = dyn_cast<SCEVConstant>(Start);
1810       auto *SC2 = dyn_cast<SCEVConstant>(Step);
1811       if (SC1 && SC2) {
1812         const APInt &C1 = SC1->getAPInt();
1813         const APInt &C2 = SC2->getAPInt();
1814         if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) &&
1815             C2.isPowerOf2()) {
1816           Start = getSignExtendExpr(Start, Ty);
1817           const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L,
1818                                             AR->getNoWrapFlags());
1819           return getAddExpr(Start, getSignExtendExpr(NewAR, Ty));
1820         }
1821       }
1822 
1823       if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
1824         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1825         return getAddRecExpr(
1826             getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
1827             getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
1828       }
1829     }
1830 
1831   // If the input value is provably positive and we could not simplify
1832   // away the sext, build a zext instead.
1833   if (isKnownNonNegative(Op))
1834     return getZeroExtendExpr(Op, Ty);
1835 
1836   // The cast wasn't folded; create an explicit cast node.
1837   // Recompute the insert position, as it may have been invalidated.
1838   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1839   SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1840                                                    Op, Ty);
1841   UniqueSCEVs.InsertNode(S, IP);
1842   return S;
1843 }
1844 
1845 /// Return a SCEV for the given operand extended with
1846 /// unspecified bits out to the given type.
1847 ///
1848 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
1849                                               Type *Ty) {
1850   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1851          "This is not an extending conversion!");
1852   assert(isSCEVable(Ty) &&
1853          "This is not a conversion to a SCEVable type!");
1854   Ty = getEffectiveSCEVType(Ty);
1855 
1856   // Sign-extend negative constants.
1857   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1858     if (SC->getAPInt().isNegative())
1859       return getSignExtendExpr(Op, Ty);
1860 
1861   // Peel off a truncate cast.
1862   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
1863     const SCEV *NewOp = T->getOperand();
1864     if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
1865       return getAnyExtendExpr(NewOp, Ty);
1866     return getTruncateOrNoop(NewOp, Ty);
1867   }
1868 
1869   // Next try a zext cast. If the cast is folded, use it.
1870   const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
1871   if (!isa<SCEVZeroExtendExpr>(ZExt))
1872     return ZExt;
1873 
1874   // Next try a sext cast. If the cast is folded, use it.
1875   const SCEV *SExt = getSignExtendExpr(Op, Ty);
1876   if (!isa<SCEVSignExtendExpr>(SExt))
1877     return SExt;
1878 
1879   // Force the cast to be folded into the operands of an addrec.
1880   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
1881     SmallVector<const SCEV *, 4> Ops;
1882     for (const SCEV *Op : AR->operands())
1883       Ops.push_back(getAnyExtendExpr(Op, Ty));
1884     return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
1885   }
1886 
1887   // If the expression is obviously signed, use the sext cast value.
1888   if (isa<SCEVSMaxExpr>(Op))
1889     return SExt;
1890 
1891   // Absent any other information, use the zext cast value.
1892   return ZExt;
1893 }
1894 
1895 /// Process the given Ops list, which is a list of operands to be added under
1896 /// the given scale, and update the given map. This is a helper function for
1897 /// getAddExpr. As an example of what it does, given a sequence of operands
1898 /// that would form an add expression like this:
1899 ///
1900 ///    m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
1901 ///
1902 /// where A and B are constants, update the map with these values:
1903 ///
1904 ///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
1905 ///
1906 /// and add 13 + A*B*29 to AccumulatedConstant.
1907 /// This will allow getAddExpr to produce this:
1908 ///
1909 ///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
1910 ///
1911 /// This form often exposes folding opportunities that are hidden in
1912 /// the original operand list.
1913 ///
1914 /// Return true iff it appears that any interesting folding opportunities
1915 /// may be exposed. This helps getAddExpr short-circuit extra work in
1916 /// the common case where no interesting opportunities are present, and
1917 /// is also used as a check to avoid infinite recursion.
1918 ///
1919 static bool
1920 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
1921                              SmallVectorImpl<const SCEV *> &NewOps,
1922                              APInt &AccumulatedConstant,
1923                              const SCEV *const *Ops, size_t NumOperands,
1924                              const APInt &Scale,
1925                              ScalarEvolution &SE) {
1926   bool Interesting = false;
1927 
1928   // Iterate over the add operands. They are sorted, with constants first.
1929   unsigned i = 0;
1930   while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1931     ++i;
1932     // Pull a buried constant out to the outside.
1933     if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
1934       Interesting = true;
1935     AccumulatedConstant += Scale * C->getAPInt();
1936   }
1937 
1938   // Next comes everything else. We're especially interested in multiplies
1939   // here, but they're in the middle, so just visit the rest with one loop.
1940   for (; i != NumOperands; ++i) {
1941     const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
1942     if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
1943       APInt NewScale =
1944           Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
1945       if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
1946         // A multiplication of a constant with another add; recurse.
1947         const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
1948         Interesting |=
1949           CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
1950                                        Add->op_begin(), Add->getNumOperands(),
1951                                        NewScale, SE);
1952       } else {
1953         // A multiplication of a constant with some other value. Update
1954         // the map.
1955         SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
1956         const SCEV *Key = SE.getMulExpr(MulOps);
1957         auto Pair = M.insert({Key, NewScale});
1958         if (Pair.second) {
1959           NewOps.push_back(Pair.first->first);
1960         } else {
1961           Pair.first->second += NewScale;
1962           // The map already had an entry for this value, which may indicate
1963           // a folding opportunity.
1964           Interesting = true;
1965         }
1966       }
1967     } else {
1968       // An ordinary operand. Update the map.
1969       std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1970           M.insert({Ops[i], Scale});
1971       if (Pair.second) {
1972         NewOps.push_back(Pair.first->first);
1973       } else {
1974         Pair.first->second += Scale;
1975         // The map already had an entry for this value, which may indicate
1976         // a folding opportunity.
1977         Interesting = true;
1978       }
1979     }
1980   }
1981 
1982   return Interesting;
1983 }
1984 
1985 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and
1986 // `OldFlags' as can't-wrap behavior.  Infer a more aggressive set of
1987 // can't-overflow flags for the operation if possible.
1988 static SCEV::NoWrapFlags
1989 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
1990                       const SmallVectorImpl<const SCEV *> &Ops,
1991                       SCEV::NoWrapFlags Flags) {
1992   using namespace std::placeholders;
1993   typedef OverflowingBinaryOperator OBO;
1994 
1995   bool CanAnalyze =
1996       Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
1997   (void)CanAnalyze;
1998   assert(CanAnalyze && "don't call from other places!");
1999 
2000   int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2001   SCEV::NoWrapFlags SignOrUnsignWrap =
2002       ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2003 
2004   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2005   auto IsKnownNonNegative = [&](const SCEV *S) {
2006     return SE->isKnownNonNegative(S);
2007   };
2008 
2009   if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2010     Flags =
2011         ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2012 
2013   SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2014 
2015   if (SignOrUnsignWrap != SignOrUnsignMask && Type == scAddExpr &&
2016       Ops.size() == 2 && isa<SCEVConstant>(Ops[0])) {
2017 
2018     // (A + C) --> (A + C)<nsw> if the addition does not sign overflow
2019     // (A + C) --> (A + C)<nuw> if the addition does not unsign overflow
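    // For example, with C == 1, X + 1 wraps unsigned exactly when X == UINT_MAX,
    // so a range for Ops[1] that excludes UINT_MAX is enough to set NUW below.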
2020 
2021     const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2022     if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2023       auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2024           Instruction::Add, C, OBO::NoSignedWrap);
2025       if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2026         Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2027     }
2028     if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2029       auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2030           Instruction::Add, C, OBO::NoUnsignedWrap);
2031       if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2032         Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2033     }
2034   }
2035 
2036   return Flags;
2037 }
2038 
2039 /// Get a canonical add expression, or something simpler if possible.
2040 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2041                                         SCEV::NoWrapFlags Flags) {
2042   assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2043          "only nuw or nsw allowed");
2044   assert(!Ops.empty() && "Cannot get empty add!");
2045   if (Ops.size() == 1) return Ops[0];
2046 #ifndef NDEBUG
2047   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2048   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2049     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2050            "SCEVAddExpr operand types don't match!");
2051 #endif
2052 
2053   // Sort by complexity; this groups all similar expression types together.
2054   GroupByComplexity(Ops, &LI);
2055 
2056   Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
2057 
2058   // If there are any constants, fold them together.
2059   unsigned Idx = 0;
2060   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2061     ++Idx;
2062     assert(Idx < Ops.size());
2063     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2064       // We found two constants, fold them together!
2065       Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2066       if (Ops.size() == 2) return Ops[0];
2067       Ops.erase(Ops.begin()+1);  // Erase the folded element
2068       LHSC = cast<SCEVConstant>(Ops[0]);
2069     }
2070 
2071     // If we are left with a constant zero being added, strip it off.
2072     if (LHSC->getValue()->isZero()) {
2073       Ops.erase(Ops.begin());
2074       --Idx;
2075     }
2076 
2077     if (Ops.size() == 1) return Ops[0];
2078   }
2079 
2080   // Okay, check to see if the same value occurs in the operand list more than
2081   // once.  If so, merge them together into a multiply expression.  Since we
2082   // sorted the list, these values are required to be adjacent.
2083   Type *Ty = Ops[0]->getType();
2084   bool FoundMatch = false;
2085   for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2086     if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
2087       // Scan ahead to count how many equal operands there are.
2088       unsigned Count = 2;
2089       while (i+Count != e && Ops[i+Count] == Ops[i])
2090         ++Count;
2091       // Merge the values into a multiply.
2092       const SCEV *Scale = getConstant(Ty, Count);
2093       const SCEV *Mul = getMulExpr(Scale, Ops[i]);
2094       if (Ops.size() == Count)
2095         return Mul;
2096       Ops[i] = Mul;
2097       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2098       --i; e -= Count - 1;
2099       FoundMatch = true;
2100     }
2101   if (FoundMatch)
2102     return getAddExpr(Ops, Flags);
2103 
2104   // Check for truncates. If all the operands are truncated from the same
2105   // type, see if factoring out the truncate would permit the result to be
2106   // folded. E.g., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
2107   // if the contents of the resulting outer trunc fold to something simple.
2108   for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
2109     const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
2110     Type *DstType = Trunc->getType();
2111     Type *SrcType = Trunc->getOperand()->getType();
2112     SmallVector<const SCEV *, 8> LargeOps;
2113     bool Ok = true;
2114     // Check all the operands to see if they can be represented in the
2115     // source type of the truncate.
2116     for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2117       if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2118         if (T->getOperand()->getType() != SrcType) {
2119           Ok = false;
2120           break;
2121         }
2122         LargeOps.push_back(T->getOperand());
2123       } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2124         LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2125       } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2126         SmallVector<const SCEV *, 8> LargeMulOps;
2127         for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2128           if (const SCEVTruncateExpr *T =
2129                 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2130             if (T->getOperand()->getType() != SrcType) {
2131               Ok = false;
2132               break;
2133             }
2134             LargeMulOps.push_back(T->getOperand());
2135           } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2136             LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2137           } else {
2138             Ok = false;
2139             break;
2140           }
2141         }
2142         if (Ok)
2143           LargeOps.push_back(getMulExpr(LargeMulOps));
2144       } else {
2145         Ok = false;
2146         break;
2147       }
2148     }
2149     if (Ok) {
2150       // Evaluate the expression in the larger type.
2151       const SCEV *Fold = getAddExpr(LargeOps, Flags);
2152       // If it folds to something simple, use it. Otherwise, don't.
2153       if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2154         return getTruncateExpr(Fold, DstType);
2155     }
2156   }
2157 
2158   // Skip past any other cast SCEVs.
2159   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2160     ++Idx;
2161 
2162   // If there are add operands they would be next.
2163   if (Idx < Ops.size()) {
2164     bool DeletedAdd = false;
2165     while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2166       // If we have an add, expand the add operands onto the end of the operands
2167       // list.
2168       Ops.erase(Ops.begin()+Idx);
2169       Ops.append(Add->op_begin(), Add->op_end());
2170       DeletedAdd = true;
2171     }
2172 
2173     // If we deleted at least one add, we added operands to the end of the list,
2174     // and they are not necessarily sorted.  Recurse to resort and resimplify
2175     // any operands we just acquired.
2176     if (DeletedAdd)
2177       return getAddExpr(Ops);
2178   }
2179 
2180   // Skip over the add expression until we get to a multiply.
2181   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2182     ++Idx;
2183 
2184   // Check to see if there are any folding opportunities present with
2185   // operands multiplied by constant values.
2186   if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2187     uint64_t BitWidth = getTypeSizeInBits(Ty);
2188     DenseMap<const SCEV *, APInt> M;
2189     SmallVector<const SCEV *, 8> NewOps;
2190     APInt AccumulatedConstant(BitWidth, 0);
2191     if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2192                                      Ops.data(), Ops.size(),
2193                                      APInt(BitWidth, 1), *this)) {
2194       struct APIntCompare {
2195         bool operator()(const APInt &LHS, const APInt &RHS) const {
2196           return LHS.ult(RHS);
2197         }
2198       };
2199 
2200       // Some interesting folding opportunity is present, so it's worthwhile to
2201       // re-generate the operands list. Group the operands by constant scale,
2202       // to avoid multiplying by the same constant scale multiple times.
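      // For instance, collecting 2*x + 2*y + 3*z yields the scale groups
      // {2: [x, y], 3: [z]}, which regenerate as 2*(x + y) + 3*z.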
2203       std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2204       for (const SCEV *NewOp : NewOps)
2205         MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2206       // Re-generate the operands list.
2207       Ops.clear();
2208       if (AccumulatedConstant != 0)
2209         Ops.push_back(getConstant(AccumulatedConstant));
2210       for (auto &MulOp : MulOpLists)
2211         if (MulOp.first != 0)
2212           Ops.push_back(getMulExpr(getConstant(MulOp.first),
2213                                    getAddExpr(MulOp.second)));
2214       if (Ops.empty())
2215         return getZero(Ty);
2216       if (Ops.size() == 1)
2217         return Ops[0];
2218       return getAddExpr(Ops);
2219     }
2220   }
2221 
2222   // If we are adding something to a multiply expression, make sure the
2223   // something is not already an operand of the multiply.  If so, merge it into
2224   // the multiply.
2225   for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2226     const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2227     for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2228       const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2229       if (isa<SCEVConstant>(MulOpSCEV))
2230         continue;
2231       for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2232         if (MulOpSCEV == Ops[AddOp]) {
2233           // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
2234           const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2235           if (Mul->getNumOperands() != 2) {
2236             // If the multiply has more than two operands, we must get the
2237             // Y*Z term.
2238             SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2239                                                 Mul->op_begin()+MulOp);
2240             MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2241             InnerMul = getMulExpr(MulOps);
2242           }
2243           const SCEV *One = getOne(Ty);
2244           const SCEV *AddOne = getAddExpr(One, InnerMul);
2245           const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV);
2246           if (Ops.size() == 2) return OuterMul;
2247           if (AddOp < Idx) {
2248             Ops.erase(Ops.begin()+AddOp);
2249             Ops.erase(Ops.begin()+Idx-1);
2250           } else {
2251             Ops.erase(Ops.begin()+Idx);
2252             Ops.erase(Ops.begin()+AddOp-1);
2253           }
2254           Ops.push_back(OuterMul);
2255           return getAddExpr(Ops);
2256         }
2257 
2258       // Check this multiply against other multiplies being added together.
2259       for (unsigned OtherMulIdx = Idx+1;
2260            OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2261            ++OtherMulIdx) {
2262         const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2263         // If MulOp occurs in OtherMul, we can fold the two multiplies
2264         // together.
2265         for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2266              OMulOp != e; ++OMulOp)
2267           if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2268             // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2269             const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2270             if (Mul->getNumOperands() != 2) {
2271               SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2272                                                   Mul->op_begin()+MulOp);
2273               MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2274               InnerMul1 = getMulExpr(MulOps);
2275             }
2276             const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2277             if (OtherMul->getNumOperands() != 2) {
2278               SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
2279                                                   OtherMul->op_begin()+OMulOp);
2280               MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
2281               InnerMul2 = getMulExpr(MulOps);
2282             }
2283             const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
2284             const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
2285             if (Ops.size() == 2) return OuterMul;
2286             Ops.erase(Ops.begin()+Idx);
2287             Ops.erase(Ops.begin()+OtherMulIdx-1);
2288             Ops.push_back(OuterMul);
2289             return getAddExpr(Ops);
2290           }
2291       }
2292     }
2293   }
2294 
2295   // If there are any add recurrences in the operands list, see if any other
2296   // added values are loop invariant.  If so, we can fold them into the
2297   // recurrence.
2298   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2299     ++Idx;
2300 
2301   // Scan over all recurrences, trying to fold loop invariants into them.
2302   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2303     // Scan all of the other operands to this add and add them to the vector if
2304     // they are loop invariant w.r.t. the recurrence.
2305     SmallVector<const SCEV *, 8> LIOps;
2306     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2307     const Loop *AddRecLoop = AddRec->getLoop();
2308     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2309       if (isLoopInvariant(Ops[i], AddRecLoop)) {
2310         LIOps.push_back(Ops[i]);
2311         Ops.erase(Ops.begin()+i);
2312         --i; --e;
2313       }
2314 
2315     // If we found some loop invariants, fold them into the recurrence.
2316     if (!LIOps.empty()) {
2317       //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
2318       LIOps.push_back(AddRec->getStart());
2319 
2320       SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
2321                                              AddRec->op_end());
2322       // This follows from the fact that the no-wrap flags on the outer add
2323       // expression are applicable on the 0th iteration, when the add recurrence
2324       // will be equal to its start value.
2325       AddRecOps[0] = getAddExpr(LIOps, Flags);
2326 
2327       // Build the new addrec. Propagate the NUW and NSW flags if both the
2328       // outer add and the inner addrec are guaranteed to have no overflow.
2329       // Always propagate NW.
2330       Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2331       const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2332 
2333       // If all of the other operands were loop invariant, we are done.
2334       if (Ops.size() == 1) return NewRec;
2335 
2336       // Otherwise, add the folded AddRec by the non-invariant parts.
2337       for (unsigned i = 0;; ++i)
2338         if (Ops[i] == AddRec) {
2339           Ops[i] = NewRec;
2340           break;
2341         }
2342       return getAddExpr(Ops);
2343     }
2344 
2345     // Okay, if there weren't any loop invariants to be folded, check to see if
2346     // there are multiple AddRec's with the same loop induction variable being
2347     // added together.  If so, we can fold them.
2348     for (unsigned OtherIdx = Idx+1;
2349          OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2350          ++OtherIdx)
2351       if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2352         // Other + {A,+,B}<L> + {C,+,D}<L>  -->  Other + {A+C,+,B+D}<L>
2353         SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
2354                                                AddRec->op_end());
2355         for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2356              ++OtherIdx)
2357           if (const auto *OtherAddRec = dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
2358             if (OtherAddRec->getLoop() == AddRecLoop) {
2359               for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2360                    i != e; ++i) {
2361                 if (i >= AddRecOps.size()) {
2362                   AddRecOps.append(OtherAddRec->op_begin()+i,
2363                                    OtherAddRec->op_end());
2364                   break;
2365                 }
2366                 AddRecOps[i] = getAddExpr(AddRecOps[i],
2367                                           OtherAddRec->getOperand(i));
2368               }
2369               Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2370             }
2371         // Step size has changed, so we cannot guarantee no self-wraparound.
2372         Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2373         return getAddExpr(Ops);
2374       }
2375 
2376     // Otherwise couldn't fold anything into this recurrence.  Move onto the
2377     // next one.
2378   }
2379 
2380   // Okay, it looks like we really DO need an add expr.  Check to see if we
2381   // already have one, otherwise create a new one.
2382   FoldingSetNodeID ID;
2383   ID.AddInteger(scAddExpr);
2384   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2385     ID.AddPointer(Ops[i]);
2386   void *IP = nullptr;
2387   SCEVAddExpr *S =
2388     static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2389   if (!S) {
2390     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2391     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2392     S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator),
2393                                         O, Ops.size());
2394     UniqueSCEVs.InsertNode(S, IP);
2395   }
2396   S->setNoWrapFlags(Flags);
2397   return S;
2398 }
2399 
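// Multiply i by j, setting Overflow if the product does not fit in 64 bits;
// the returned value is the low 64 bits of the product.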
2400 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
2401   uint64_t k = i*j;
2402   if (j > 1 && k / j != i) Overflow = true;
2403   return k;
2404 }
2405 
2406 /// Compute the result of "n choose k", the binomial coefficient.  If an
2407 /// intermediate computation overflows, Overflow will be set and the return will
2408 /// be garbage. Overflow is not cleared on absence of overflow.
2409 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
2410   // We use the multiplicative formula:
2411   //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
2412   // At each iteration i, we multiply in the i-th term of the numerator and then
2413   // divide by i.  This division will always produce an
2414   // integral result, and helps reduce the chance of overflow in the
2415   // intermediate computations. However, we can still overflow even when the
2416   // final result would fit.
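  //
  // For example, Choose(5, 2) computes r = 1*5/1 = 5, then r = 5*4/2 = 10.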
2417 
2418   if (n == 0 || n == k) return 1;
2419   if (k > n) return 0;
2420 
2421   if (k > n/2)
2422     k = n-k;
2423 
2424   uint64_t r = 1;
2425   for (uint64_t i = 1; i <= k; ++i) {
2426     r = umul_ov(r, n-(i-1), Overflow);
2427     r /= i;
2428   }
2429   return r;
2430 }
2431 
2432 /// Determine if any of the operands in this SCEV are a constant or if
2433 /// any of the add or multiply expressions in this SCEV contain a constant.
2434 static bool containsConstantSomewhere(const SCEV *StartExpr) {
2435   SmallVector<const SCEV *, 4> Ops;
2436   Ops.push_back(StartExpr);
2437   while (!Ops.empty()) {
2438     const SCEV *CurrentExpr = Ops.pop_back_val();
2439     if (isa<SCEVConstant>(*CurrentExpr))
2440       return true;
2441 
2442     if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) {
2443       const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr);
2444       Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end());
2445     }
2446   }
2447   return false;
2448 }
2449 
2450 /// Get a canonical multiply expression, or something simpler if possible.
2451 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
2452                                         SCEV::NoWrapFlags Flags) {
2453   assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
2454          "only nuw or nsw allowed");
2455   assert(!Ops.empty() && "Cannot get empty mul!");
2456   if (Ops.size() == 1) return Ops[0];
2457 #ifndef NDEBUG
2458   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2459   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2460     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2461            "SCEVMulExpr operand types don't match!");
2462 #endif
2463 
2464   // Sort by complexity; this groups all similar expression types together.
2465   GroupByComplexity(Ops, &LI);
2466 
2467   Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags);
2468 
2469   // If there are any constants, fold them together.
2470   unsigned Idx = 0;
2471   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2472 
2473     // C1*(C2+V) -> C1*C2 + C1*V
2474     if (Ops.size() == 2)
2475         if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
2476           // If any of Add's ops are Adds or Muls with a constant,
2477           // apply this transformation as well.
2478           if (Add->getNumOperands() == 2)
2479             if (containsConstantSomewhere(Add))
2480               return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
2481                                 getMulExpr(LHSC, Add->getOperand(1)));
2482 
2483     ++Idx;
2484     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2485       // We found two constants, fold them together!
2486       ConstantInt *Fold =
2487           ConstantInt::get(getContext(), LHSC->getAPInt() * RHSC->getAPInt());
2488       Ops[0] = getConstant(Fold);
2489       Ops.erase(Ops.begin()+1);  // Erase the folded element
2490       if (Ops.size() == 1) return Ops[0];
2491       LHSC = cast<SCEVConstant>(Ops[0]);
2492     }
2493 
2494     // If we are left with a constant one being multiplied, strip it off.
2495     if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
2496       Ops.erase(Ops.begin());
2497       --Idx;
2498     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
2499       // If we have a multiply of zero, it will always be zero.
2500       return Ops[0];
2501     } else if (Ops[0]->isAllOnesValue()) {
2502       // If we have a mul by -1 of an add, try distributing the -1 among the
2503       // add operands.
2504       if (Ops.size() == 2) {
2505         if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
2506           SmallVector<const SCEV *, 4> NewOps;
2507           bool AnyFolded = false;
2508           for (const SCEV *AddOp : Add->operands()) {
2509             const SCEV *Mul = getMulExpr(Ops[0], AddOp);
2510             if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
2511             NewOps.push_back(Mul);
2512           }
2513           if (AnyFolded)
2514             return getAddExpr(NewOps);
2515         } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
2516           // Negation preserves a recurrence's no self-wrap property.
2517           SmallVector<const SCEV *, 4> Operands;
2518           for (const SCEV *AddRecOp : AddRec->operands())
2519             Operands.push_back(getMulExpr(Ops[0], AddRecOp));
2520 
2521           return getAddRecExpr(Operands, AddRec->getLoop(),
2522                                AddRec->getNoWrapFlags(SCEV::FlagNW));
2523         }
2524       }
2525     }
2526 
2527     if (Ops.size() == 1)
2528       return Ops[0];
2529   }
2530 
2531   // Skip over the add expression until we get to a multiply.
2532   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2533     ++Idx;
2534 
2535   // If there are mul operands inline them all into this expression.
2536   if (Idx < Ops.size()) {
2537     bool DeletedMul = false;
2538     while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2539       if (Ops.size() > MulOpsInlineThreshold)
2540         break;
2541       // If we have a mul, expand the mul operands onto the end of the operands
2542       // list.
2543       Ops.erase(Ops.begin()+Idx);
2544       Ops.append(Mul->op_begin(), Mul->op_end());
2545       DeletedMul = true;
2546     }
2547 
2548     // If we deleted at least one mul, we added operands to the end of the list,
2549     // and they are not necessarily sorted.  Recurse to resort and resimplify
2550     // any operands we just acquired.
2551     if (DeletedMul)
2552       return getMulExpr(Ops);
2553   }
2554 
2555   // If there are any add recurrences in the operands list, see if any other
2556   // multiplied values are loop invariant.  If so, we can fold them into the
2557   // recurrence.
2558   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2559     ++Idx;
2560 
2561   // Scan over all recurrences, trying to fold loop invariants into them.
2562   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2563     // Scan all of the other operands to this mul and add them to the vector if
2564     // they are loop invariant w.r.t. the recurrence.
2565     SmallVector<const SCEV *, 8> LIOps;
2566     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2567     const Loop *AddRecLoop = AddRec->getLoop();
2568     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2569       if (isLoopInvariant(Ops[i], AddRecLoop)) {
2570         LIOps.push_back(Ops[i]);
2571         Ops.erase(Ops.begin()+i);
2572         --i; --e;
2573       }
2574 
2575     // If we found some loop invariants, fold them into the recurrence.
2576     if (!LIOps.empty()) {
2577       //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
2578       SmallVector<const SCEV *, 4> NewOps;
2579       NewOps.reserve(AddRec->getNumOperands());
2580       const SCEV *Scale = getMulExpr(LIOps);
2581       for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
2582         NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
2583 
2584       // Build the new addrec. Propagate the NUW and NSW flags if both the
2585       // outer mul and the inner addrec are guaranteed to have no overflow.
2586       //
2587       // No self-wrap cannot be guaranteed after changing the step size, but
2588       // will be inferred if either NUW or NSW is true.
2589       Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW));
2590       const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags);
2591 
2592       // If all of the other operands were loop invariant, we are done.
2593       if (Ops.size() == 1) return NewRec;
2594 
2595       // Otherwise, multiply the folded AddRec by the non-invariant parts.
2596       for (unsigned i = 0;; ++i)
2597         if (Ops[i] == AddRec) {
2598           Ops[i] = NewRec;
2599           break;
2600         }
2601       return getMulExpr(Ops);
2602     }
2603 
2604     // Okay, if there weren't any loop invariants to be folded, check to see if
2605     // there are multiple AddRec's with the same loop induction variable being
2606     // multiplied together.  If so, we can fold them.
2607 
2608     // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
2609     // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
2610     //       choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
2611     //   ]]],+,...up to x=2n}.
2612     // Note that the arguments to choose() are always integers with values
2613     // known at compile time, never SCEV objects.
2614     //
2615     // The implementation avoids pointless extra computations when the two
2616     // addrec's are of different length (mathematically, it's equivalent to
2617     // an infinite stream of zeros on the right).
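    //
    // For the common affine case this reduces to:
    //   {A,+,B}<L> * {C,+,D}<L> = {A*C,+,A*D+B*C+B*D,+,2*B*D}<L>.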
2618     bool OpsModified = false;
2619     for (unsigned OtherIdx = Idx+1;
2620          OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2621          ++OtherIdx) {
2622       const SCEVAddRecExpr *OtherAddRec =
2623         dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2624       if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
2625         continue;
2626 
2627       bool Overflow = false;
2628       Type *Ty = AddRec->getType();
2629       bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
2630       SmallVector<const SCEV*, 7> AddRecOps;
2631       for (int x = 0, xe = AddRec->getNumOperands() +
2632              OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
2633         const SCEV *Term = getZero(Ty);
2634         for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
2635           uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
2636           for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
2637                  ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
2638                z < ze && !Overflow; ++z) {
2639             uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
2640             uint64_t Coeff;
2641             if (LargerThan64Bits)
2642               Coeff = umul_ov(Coeff1, Coeff2, Overflow);
2643             else
2644               Coeff = Coeff1*Coeff2;
2645             const SCEV *CoeffTerm = getConstant(Ty, Coeff);
2646             const SCEV *Term1 = AddRec->getOperand(y-z);
2647             const SCEV *Term2 = OtherAddRec->getOperand(z);
2648             Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2));
2649           }
2650         }
2651         AddRecOps.push_back(Term);
2652       }
2653       if (!Overflow) {
2654         const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
2655                                               SCEV::FlagAnyWrap);
2656         if (Ops.size() == 2) return NewAddRec;
2657         Ops[Idx] = NewAddRec;
2658         Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2659         OpsModified = true;
2660         AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
2661         if (!AddRec)
2662           break;
2663       }
2664     }
2665     if (OpsModified)
2666       return getMulExpr(Ops);
2667 
2668     // Otherwise couldn't fold anything into this recurrence.  Move onto the
2669     // next one.
2670   }
2671 
2672   // Okay, it looks like we really DO need a mul expr.  Check to see if we
2673   // already have one, otherwise create a new one.
2674   FoldingSetNodeID ID;
2675   ID.AddInteger(scMulExpr);
2676   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2677     ID.AddPointer(Ops[i]);
2678   void *IP = nullptr;
2679   SCEVMulExpr *S =
2680     static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2681   if (!S) {
2682     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2683     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2684     S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2685                                         O, Ops.size());
2686     UniqueSCEVs.InsertNode(S, IP);
2687   }
2688   S->setNoWrapFlags(Flags);
2689   return S;
2690 }
2691 
2692 /// Get a canonical unsigned division expression, or something simpler if
2693 /// possible.
2694 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
2695                                          const SCEV *RHS) {
2696   assert(getEffectiveSCEVType(LHS->getType()) ==
2697          getEffectiveSCEVType(RHS->getType()) &&
2698          "SCEVUDivExpr operand types don't match!");
2699 
2700   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
2701     if (RHSC->getValue()->equalsInt(1))
2702       return LHS;                               // X udiv 1 --> x
2703     // If the denominator is zero, the result of the udiv is undefined. Don't
2704     // try to analyze it, because the resolution chosen here may differ from
2705     // the resolution chosen in other parts of the compiler.
2706     if (!RHSC->getValue()->isZero()) {
2707       // Determine if the division can be folded into the operands of
2708       // the dividend (LHS).
2709       // TODO: Generalize this to non-constants by using known-bits information.
2710       Type *Ty = LHS->getType();
2711       unsigned LZ = RHSC->getAPInt().countLeadingZeros();
2712       unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
2713       // For non-power-of-two values, effectively round the value up to the
2714       // nearest power of two.
2715       if (!RHSC->getAPInt().isPowerOf2())
2716         ++MaxShiftAmt;
2717       IntegerType *ExtTy =
2718         IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
2719       if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
2720         if (const SCEVConstant *Step =
2721             dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
2722           // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
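          // For example, {8,+,4} /u 2 becomes {4,+,2}, provided the
          // zero-extension check below shows the recurrence does not wrap.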
2723           const APInt &StepInt = Step->getAPInt();
2724           const APInt &DivInt = RHSC->getAPInt();
2725           if (!StepInt.urem(DivInt) &&
2726               getZeroExtendExpr(AR, ExtTy) ==
2727               getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2728                             getZeroExtendExpr(Step, ExtTy),
2729                             AR->getLoop(), SCEV::FlagAnyWrap)) {
2730             SmallVector<const SCEV *, 4> Operands;
2731             for (const SCEV *Op : AR->operands())
2732               Operands.push_back(getUDivExpr(Op, RHS));
2733             return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
2734           }
2735           // Get a canonical UDivExpr for a recurrence.
2736           // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
2737           // We can currently only fold X%N if X is constant.
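          // For example, {5,+,2} /u 4 is rewritten as {4,+,2} /u 4 (both
          // produce 1,1,2,2,...), again guarded by the zero-extension check.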
2738           const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
2739           if (StartC && !DivInt.urem(StepInt) &&
2740               getZeroExtendExpr(AR, ExtTy) ==
2741               getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2742                             getZeroExtendExpr(Step, ExtTy),
2743                             AR->getLoop(), SCEV::FlagAnyWrap)) {
2744             const APInt &StartInt = StartC->getAPInt();
2745             const APInt &StartRem = StartInt.urem(StepInt);
2746             if (StartRem != 0)
2747               LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step,
2748                                   AR->getLoop(), SCEV::FlagNW);
2749           }
2750         }
2751       // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
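      // For example, (A * 8) /u 4 becomes A * 2 when the multiply is known not
      // to overflow (checked via the zero-extended comparison below).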
2752       if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
2753         SmallVector<const SCEV *, 4> Operands;
2754         for (const SCEV *Op : M->operands())
2755           Operands.push_back(getZeroExtendExpr(Op, ExtTy));
2756         if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
2757           // Find an operand that's safely divisible.
2758           for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
2759             const SCEV *Op = M->getOperand(i);
2760             const SCEV *Div = getUDivExpr(Op, RHSC);
2761             if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
2762               Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
2763                                                       M->op_end());
2764               Operands[i] = Div;
2765               return getMulExpr(Operands);
2766             }
2767           }
2768       }
2769       // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
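      // For example, (A * 4 + 8) /u 4 becomes A + 2 when the add does not
      // overflow and every operand divides evenly.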
2770       if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
2771         SmallVector<const SCEV *, 4> Operands;
2772         for (const SCEV *Op : A->operands())
2773           Operands.push_back(getZeroExtendExpr(Op, ExtTy));
2774         if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
2775           Operands.clear();
2776           for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
2777             const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
2778             if (isa<SCEVUDivExpr>(Op) ||
2779                 getMulExpr(Op, RHS) != A->getOperand(i))
2780               break;
2781             Operands.push_back(Op);
2782           }
2783           if (Operands.size() == A->getNumOperands())
2784             return getAddExpr(Operands);
2785         }
2786       }
2787 
2788       // Fold if both operands are constant.
2789       if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
2790         Constant *LHSCV = LHSC->getValue();
2791         Constant *RHSCV = RHSC->getValue();
2792         return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
2793                                                                    RHSCV)));
2794       }
2795     }
2796   }
2797 
2798   FoldingSetNodeID ID;
2799   ID.AddInteger(scUDivExpr);
2800   ID.AddPointer(LHS);
2801   ID.AddPointer(RHS);
2802   void *IP = nullptr;
2803   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2804   SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
2805                                              LHS, RHS);
2806   UniqueSCEVs.InsertNode(S, IP);
2807   return S;
2808 }
2809 
2810 static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
2811   APInt A = C1->getAPInt().abs();
2812   APInt B = C2->getAPInt().abs();
2813   uint32_t ABW = A.getBitWidth();
2814   uint32_t BBW = B.getBitWidth();
2815 
2816   if (ABW > BBW)
2817     B = B.zext(ABW);
2818   else if (ABW < BBW)
2819     A = A.zext(BBW);
2820 
2821   return APIntOps::GreatestCommonDivisor(A, B);
2822 }
2823 
2824 /// Get a canonical unsigned division expression, or something simpler if
2825 /// possible. There is no representation for an exact udiv in SCEV IR, but we
2826 /// can attempt to remove factors from the LHS and RHS.  We can't do this when
2827 /// it's not exact because the udiv may be clearing bits.
2828 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
2829                                               const SCEV *RHS) {
2830   // TODO: we could try to find factors in all sorts of things, but for now we
2831   // just deal with u/exact (multiply, constant). See SCEVDivision towards the
2832   // end of this file for inspiration.
2833 
2834   const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
2835   if (!Mul)
2836     return getUDivExpr(LHS, RHS);
2837 
2838   if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
2839     // If the mulexpr multiplies by a constant, then that constant must be the
2840     // first element of the mulexpr.
2841     if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
2842       if (LHSCst == RHSCst) {
2843         SmallVector<const SCEV *, 2> Operands;
2844         Operands.append(Mul->op_begin() + 1, Mul->op_end());
2845         return getMulExpr(Operands);
2846       }
2847 
2848       // We can't just assume that LHSCst divides RHSCst cleanly; it could be
2849       // that there's a factor provided by one of the other terms. We need to
2850       // check.
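      // For example, (6 * X) /u 4 has gcd(6, 4) = 2, so it is rewritten as
      // (3 * X) /u 2 before looking for an exactly matching factor.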
2851       APInt Factor = gcd(LHSCst, RHSCst);
2852       if (!Factor.isIntN(1)) {
2853         LHSCst =
2854             cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
2855         RHSCst =
2856             cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
2857         SmallVector<const SCEV *, 2> Operands;
2858         Operands.push_back(LHSCst);
2859         Operands.append(Mul->op_begin() + 1, Mul->op_end());
2860         LHS = getMulExpr(Operands);
2861         RHS = RHSCst;
2862         Mul = dyn_cast<SCEVMulExpr>(LHS);
2863         if (!Mul)
2864           return getUDivExactExpr(LHS, RHS);
2865       }
2866     }
2867   }
2868 
2869   for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
2870     if (Mul->getOperand(i) == RHS) {
2871       SmallVector<const SCEV *, 2> Operands;
2872       Operands.append(Mul->op_begin(), Mul->op_begin() + i);
2873       Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
2874       return getMulExpr(Operands);
2875     }
2876   }
2877 
2878   return getUDivExpr(LHS, RHS);
2879 }
2880 
2881 /// Get an add recurrence expression for the specified loop.  Simplify the
2882 /// expression as much as possible.
2883 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
2884                                            const Loop *L,
2885                                            SCEV::NoWrapFlags Flags) {
2886   SmallVector<const SCEV *, 4> Operands;
2887   Operands.push_back(Start);
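  // If the step is itself a recurrence for the same loop, fold it in:
  //   {X,+,{Y,+,Z}<L>}<L>  -->  {X,+,Y,+,Z}<L>.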
2888   if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
2889     if (StepChrec->getLoop() == L) {
2890       Operands.append(StepChrec->op_begin(), StepChrec->op_end());
2891       return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
2892     }
2893 
2894   Operands.push_back(Step);
2895   return getAddRecExpr(Operands, L, Flags);
2896 }
2897 
2898 /// Get an add recurrence expression for the specified loop.  Simplify the
2899 /// expression as much as possible.
2900 const SCEV *
2901 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
2902                                const Loop *L, SCEV::NoWrapFlags Flags) {
2903   if (Operands.size() == 1) return Operands[0];
2904 #ifndef NDEBUG
2905   Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
2906   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
2907     assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
2908            "SCEVAddRecExpr operand types don't match!");
2909   for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2910     assert(isLoopInvariant(Operands[i], L) &&
2911            "SCEVAddRecExpr operand is not loop-invariant!");
2912 #endif
2913 
2914   if (Operands.back()->isZero()) {
2915     Operands.pop_back();
2916     return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0}  -->  X
2917   }
2918 
2919   // It's tempting to call getMaxBackedgeTakenCount here and
2920   // use that information to infer NUW and NSW flags. However, computing a
2921   // BE count requires calling getAddRecExpr, so we may not yet have a
2922   // meaningful BE count at this point (and if we don't, we'd be stuck
2923   // with a SCEVCouldNotCompute as the cached BE count).
2924 
2925   Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
2926 
2927   // Canonicalize nested AddRecs by nesting them in order of loop depth.
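  // For example, {{S,+,A}<Inner>,+,B}<Outer> with Outer containing Inner is
  // rewritten as {{S,+,B}<Outer>,+,A}<Inner>.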
2928   if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
2929     const Loop *NestedLoop = NestedAR->getLoop();
2930     if (L->contains(NestedLoop)
2931             ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
2932             : (!NestedLoop->contains(L) &&
2933                DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
2934       SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
2935                                                   NestedAR->op_end());
2936       Operands[0] = NestedAR->getStart();
2937       // AddRecs require their operands be loop-invariant with respect to their
2938       // loops. Don't perform this transformation if it would break this
2939       // requirement.
2940       bool AllInvariant = all_of(
2941           Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
2942 
2943       if (AllInvariant) {
2944         // Create a recurrence for the outer loop with the same step size.
2945         //
2946         // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
2947         // inner recurrence has the same property.
2948         SCEV::NoWrapFlags OuterFlags =
2949           maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
2950 
2951         NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
2952         AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
2953           return isLoopInvariant(Op, NestedLoop);
2954         });
2955 
2956         if (AllInvariant) {
2957           // Ok, both add recurrences are valid after the transformation.
2958           //
2959           // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
2960           // the outer recurrence has the same property.
2961           SCEV::NoWrapFlags InnerFlags =
2962             maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
2963           return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
2964         }
2965       }
2966       // Reset Operands to its original state.
2967       Operands[0] = NestedAR;
2968     }
2969   }
2970 
2971   // Okay, it looks like we really DO need an addrec expr.  Check to see if we
2972   // already have one, otherwise create a new one.
2973   FoldingSetNodeID ID;
2974   ID.AddInteger(scAddRecExpr);
2975   for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2976     ID.AddPointer(Operands[i]);
2977   ID.AddPointer(L);
2978   void *IP = nullptr;
2979   SCEVAddRecExpr *S =
2980     static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2981   if (!S) {
2982     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size());
2983     std::uninitialized_copy(Operands.begin(), Operands.end(), O);
2984     S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
2985                                            O, Operands.size(), L);
2986     UniqueSCEVs.InsertNode(S, IP);
2987   }
2988   S->setNoWrapFlags(Flags);
2989   return S;
2990 }
2991 
2992 const SCEV *
2993 ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr,
2994                             const SmallVectorImpl<const SCEV *> &IndexExprs,
2995                             bool InBounds) {
2996   // getSCEV(Base)->getType() has the same address space as Base->getType()
2997   // because SCEV::getType() preserves the address space.
2998   Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType());
2999   // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP
3000   // instruction to its SCEV, because the Instruction may be guarded by control
3001   // flow and the no-overflow bits may not be valid for the expression in any
3002   // context. This can be fixed similarly to how these flags are handled for
3003   // adds.
3004   SCEV::NoWrapFlags Wrap = InBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3005 
3006   const SCEV *TotalOffset = getZero(IntPtrTy);
3007   // The address space is unimportant; the first thing we do with CurTy is get
3008   // its element type.
3009   Type *CurTy = PointerType::getUnqual(PointeeType);
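  // For example, "getelementptr %struct.S, %struct.S* %p, i64 %i, i32 1"
  // yields %p + %i * sizeof(%struct.S) + offsetof(%struct.S, field 1).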
3010   for (const SCEV *IndexExpr : IndexExprs) {
3011     // Compute the (potentially symbolic) offset in bytes for this index.
3012     if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3013       // For a struct, add the member offset.
3014       ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3015       unsigned FieldNo = Index->getZExtValue();
3016       const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo);
3017 
3018       // Add the field offset to the running total offset.
3019       TotalOffset = getAddExpr(TotalOffset, FieldOffset);
3020 
3021       // Update CurTy to the type of the field at Index.
3022       CurTy = STy->getTypeAtIndex(Index);
3023     } else {
3024       // Update CurTy to its element type.
3025       CurTy = cast<SequentialType>(CurTy)->getElementType();
3026       // For an array, add the element offset, explicitly scaled.
3027       const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, CurTy);
3028       // Getelementptr indices are signed.
3029       IndexExpr = getTruncateOrSignExtend(IndexExpr, IntPtrTy);
3030 
3031       // Multiply the index by the element size to compute the element offset.
3032       const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, Wrap);
3033 
3034       // Add the element offset to the running total offset.
3035       TotalOffset = getAddExpr(TotalOffset, LocalOffset);
3036     }
3037   }
3038 
3039   // Add the total offset from all the GEP indices to the base.
3040   return getAddExpr(BaseExpr, TotalOffset, Wrap);
3041 }
3042 
3043 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
3044                                          const SCEV *RHS) {
3045   SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
3046   return getSMaxExpr(Ops);
3047 }
3048 
3049 const SCEV *
3050 ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
3051   assert(!Ops.empty() && "Cannot get empty smax!");
3052   if (Ops.size() == 1) return Ops[0];
3053 #ifndef NDEBUG
3054   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3055   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3056     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3057            "SCEVSMaxExpr operand types don't match!");
3058 #endif
3059 
3060   // Sort by complexity; this groups all similar expression types together.
3061   GroupByComplexity(Ops, &LI);
3062 
3063   // If there are any constants, fold them together.
3064   unsigned Idx = 0;
3065   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3066     ++Idx;
3067     assert(Idx < Ops.size());
3068     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3069       // We found two constants, fold them together!
3070       ConstantInt *Fold = ConstantInt::get(
3071           getContext(), APIntOps::smax(LHSC->getAPInt(), RHSC->getAPInt()));
3072       Ops[0] = getConstant(Fold);
3073       Ops.erase(Ops.begin()+1);  // Erase the folded element
3074       if (Ops.size() == 1) return Ops[0];
3075       LHSC = cast<SCEVConstant>(Ops[0]);
3076     }
3077 
3078     // If we are left with a constant minimum-int, strip it off.
3079     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
3080       Ops.erase(Ops.begin());
3081       --Idx;
3082     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) {
3083       // If we have an smax with a constant maximum-int, it will always be
3084       // maximum-int.
3085       return Ops[0];
3086     }
3087 
3088     if (Ops.size() == 1) return Ops[0];
3089   }
3090 
3091   // Find the first SMax
3092   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
3093     ++Idx;
3094 
3095   // Check to see if one of the operands is an SMax. If so, expand its operands
3096   // onto our operand list, and recurse to simplify.
3097   if (Idx < Ops.size()) {
3098     bool DeletedSMax = false;
3099     while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
3100       Ops.erase(Ops.begin()+Idx);
3101       Ops.append(SMax->op_begin(), SMax->op_end());
3102       DeletedSMax = true;
3103     }
3104 
3105     if (DeletedSMax)
3106       return getSMaxExpr(Ops);
3107   }
3108 
3109   // Okay, check to see if the same value occurs in the operand list twice.  If
3110   // so, delete one.  Since we sorted the list, these values are required to
3111   // be adjacent.
3112   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
3113     //  X smax Y smax Y  -->  X smax Y
3114     //  X smax Y         -->  X, if X is always greater than or equal to Y
3115     if (Ops[i] == Ops[i+1] ||
3116         isKnownPredicate(ICmpInst::ICMP_SGE, Ops[i], Ops[i+1])) {
3117       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
3118       --i; --e;
3119     } else if (isKnownPredicate(ICmpInst::ICMP_SLE, Ops[i], Ops[i+1])) {
3120       Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
3121       --i; --e;
3122     }
3123 
3124   if (Ops.size() == 1) return Ops[0];
3125 
3126   assert(!Ops.empty() && "Reduced smax down to nothing!");
3127 
3128   // Okay, it looks like we really DO need an smax expr.  Check to see if we
3129   // already have one, otherwise create a new one.
3130   FoldingSetNodeID ID;
3131   ID.AddInteger(scSMaxExpr);
3132   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3133     ID.AddPointer(Ops[i]);
3134   void *IP = nullptr;
3135   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3136   const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3137   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3138   SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator),
3139                                              O, Ops.size());
3140   UniqueSCEVs.InsertNode(S, IP);
3141   return S;
3142 }
3143 
3144 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
3145                                          const SCEV *RHS) {
3146   SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
3147   return getUMaxExpr(Ops);
3148 }
3149 
3150 const SCEV *
3151 ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
3152   assert(!Ops.empty() && "Cannot get empty umax!");
3153   if (Ops.size() == 1) return Ops[0];
3154 #ifndef NDEBUG
3155   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3156   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3157     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3158            "SCEVUMaxExpr operand types don't match!");
3159 #endif
3160 
3161   // Sort by complexity; this groups all similar expression types together.
3162   GroupByComplexity(Ops, &LI);
3163 
3164   // If there are any constants, fold them together.
3165   unsigned Idx = 0;
3166   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3167     ++Idx;
3168     assert(Idx < Ops.size());
3169     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3170       // We found two constants, fold them together!
3171       ConstantInt *Fold = ConstantInt::get(
3172           getContext(), APIntOps::umax(LHSC->getAPInt(), RHSC->getAPInt()));
3173       Ops[0] = getConstant(Fold);
3174       Ops.erase(Ops.begin()+1);  // Erase the folded element
3175       if (Ops.size() == 1) return Ops[0];
3176       LHSC = cast<SCEVConstant>(Ops[0]);
3177     }
3178 
3179     // If we are left with a constant minimum-int, strip it off.
3180     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
3181       Ops.erase(Ops.begin());
3182       --Idx;
3183     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) {
3184       // If we have an umax with a constant maximum-int, it will always be
3185       // maximum-int.
3186       return Ops[0];
3187     }
3188 
3189     if (Ops.size() == 1) return Ops[0];
3190   }
3191 
3192   // Find the first UMax
3193   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
3194     ++Idx;
3195 
3196   // Check to see if one of the operands is a UMax. If so, expand its operands
3197   // onto our operand list, and recurse to simplify.
3198   if (Idx < Ops.size()) {
3199     bool DeletedUMax = false;
3200     while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
3201       Ops.erase(Ops.begin()+Idx);
3202       Ops.append(UMax->op_begin(), UMax->op_end());
3203       DeletedUMax = true;
3204     }
3205 
3206     if (DeletedUMax)
3207       return getUMaxExpr(Ops);
3208   }
3209 
3210   // Okay, check to see if the same value occurs in the operand list twice.  If
3211   // so, delete one.  Since we sorted the list, these values are required to
3212   // be adjacent.
3213   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
3214     //  X umax Y umax Y  -->  X umax Y
3215     //  X umax Y         -->  X, if X is always greater than or equal to Y
3216     if (Ops[i] == Ops[i+1] ||
3217         isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) {
3218       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
3219       --i; --e;
3220     } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) {
3221       Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
3222       --i; --e;
3223     }
3224 
3225   if (Ops.size() == 1) return Ops[0];
3226 
3227   assert(!Ops.empty() && "Reduced umax down to nothing!");
3228 
3229   // Okay, it looks like we really DO need a umax expr.  Check to see if we
3230   // already have one, otherwise create a new one.
3231   FoldingSetNodeID ID;
3232   ID.AddInteger(scUMaxExpr);
3233   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3234     ID.AddPointer(Ops[i]);
3235   void *IP = nullptr;
3236   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3237   const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3238   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3239   SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator),
3240                                              O, Ops.size());
3241   UniqueSCEVs.InsertNode(S, IP);
3242   return S;
3243 }
3244 
3245 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
3246                                          const SCEV *RHS) {
3247   // ~smax(~x, ~y) == smin(x, y).
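  // (Using ~v = -1 - v: -1 - smax(-1 - x, -1 - y) = smin(x, y).)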
3248   return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
3249 }
3250 
3251 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
3252                                          const SCEV *RHS) {
3253   // ~umax(~x, ~y) == umin(x, y)
3254   return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
3255 }
3256 
3257 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
3258   // We can bypass creating a target-independent
3259   // constant expression and then folding it back into a ConstantInt.
3260   // This is just a compile-time optimization.
3261   return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
3262 }
3263 
3264 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
3265                                              StructType *STy,
3266                                              unsigned FieldNo) {
3267   // We can bypass creating a target-independent
3268   // constant expression and then folding it back into a ConstantInt.
3269   // This is just a compile-time optimization.
3270   return getConstant(
3271       IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
3272 }
3273 
3274 const SCEV *ScalarEvolution::getUnknown(Value *V) {
3275   // Don't attempt to do anything other than create a SCEVUnknown object
3276   // here.  createSCEV only calls getUnknown after checking for all other
3277   // interesting possibilities, and any other code that calls getUnknown
3278   // is doing so in order to hide a value from SCEV canonicalization.
3279 
3280   FoldingSetNodeID ID;
3281   ID.AddInteger(scUnknown);
3282   ID.AddPointer(V);
3283   void *IP = nullptr;
3284   if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
3285     assert(cast<SCEVUnknown>(S)->getValue() == V &&
3286            "Stale SCEVUnknown in uniquing map!");
3287     return S;
3288   }
3289   SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
3290                                             FirstUnknown);
3291   FirstUnknown = cast<SCEVUnknown>(S);
3292   UniqueSCEVs.InsertNode(S, IP);
3293   return S;
3294 }
3295 
3296 //===----------------------------------------------------------------------===//
3297 //            Basic SCEV Analysis and PHI Idiom Recognition Code
3298 //
3299 
3300 /// Test if values of the given type are analyzable within the SCEV
3301 /// framework. This primarily includes integer types, and it can optionally
3302 /// include pointer types if the ScalarEvolution class has access to
3303 /// target-specific information.
3304 bool ScalarEvolution::isSCEVable(Type *Ty) const {
3305   // Integers and pointers are always SCEVable.
3306   return Ty->isIntegerTy() || Ty->isPointerTy();
3307 }
3308 
3309 /// Return the size in bits of the specified type, for which isSCEVable must
3310 /// return true.
3311 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
3312   assert(isSCEVable(Ty) && "Type is not SCEVable!");
3313   return getDataLayout().getTypeSizeInBits(Ty);
3314 }
3315 
3316 /// Return a type with the same bitwidth as the given type and which represents
3317 /// how SCEV will treat the given type, for which isSCEVable must return
3318 /// true. For pointer types, this is the pointer-sized integer type.
3319 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
3320   assert(isSCEVable(Ty) && "Type is not SCEVable!");
3321 
3322   if (Ty->isIntegerTy())
3323     return Ty;
3324 
3325   // The only other supported type is pointer.
3326   assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
3327   return getDataLayout().getIntPtrType(Ty);
3328 }
3329 
3330 const SCEV *ScalarEvolution::getCouldNotCompute() {
3331   return CouldNotCompute.get();
3332 }
3333 
3334 
3335 bool ScalarEvolution::checkValidity(const SCEV *S) const {
3336   // Helper class working with SCEVTraversal to figure out if a SCEV contains
3337   // a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne
3338   // is set iff we find such a SCEVUnknown.
3339   //
3340   struct FindInvalidSCEVUnknown {
3341     bool FindOne;
3342     FindInvalidSCEVUnknown() { FindOne = false; }
3343     bool follow(const SCEV *S) {
3344       switch (static_cast<SCEVTypes>(S->getSCEVType())) {
3345       case scConstant:
3346         return false;
3347       case scUnknown:
3348         if (!cast<SCEVUnknown>(S)->getValue())
3349           FindOne = true;
3350         return false;
3351       default:
3352         return true;
3353       }
3354     }
3355     bool isDone() const { return FindOne; }
3356   };
3357 
3358   FindInvalidSCEVUnknown F;
3359   SCEVTraversal<FindInvalidSCEVUnknown> ST(F);
3360   ST.visitAll(S);
3361 
3362   return !F.FindOne;
3363 }
3364 
3365 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
3366   // Helper class working with SCEVTraversal to figure out if a SCEV contains a
3367   // sub SCEV of scAddRecExpr type.  FindAddRecurrence::FoundOne is set iff
3368   // such a sub-SCEV of scAddRecExpr type is found.
3369   struct FindAddRecurrence {
3370     bool FoundOne;
3371     FindAddRecurrence() : FoundOne(false) {}
3372 
3373     bool follow(const SCEV *S) {
3374       switch (static_cast<SCEVTypes>(S->getSCEVType())) {
3375       case scAddRecExpr:
3376         FoundOne = true;
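        // Fall through: once an addrec is found there is no need to visit the
        // operands of this node.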
3377       case scConstant:
3378       case scUnknown:
3379       case scCouldNotCompute:
3380         return false;
3381       default:
3382         return true;
3383       }
3384     }
3385     bool isDone() const { return FoundOne; }
3386   };
3387 
3388   HasRecMapType::iterator I = HasRecMap.find(S);
3389   if (I != HasRecMap.end())
3390     return I->second;
3391 
3392   FindAddRecurrence F;
3393   SCEVTraversal<FindAddRecurrence> ST(F);
3394   ST.visitAll(S);
3395   HasRecMap.insert({S, F.FoundOne});
3396   return F.FoundOne;
3397 }
3398 
3399 /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
3400 /// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an
3401 /// offset I, then return {S', I}, else return {\p S, nullptr}.
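/// For example, (42 + %x) yields {%x, 42}; any other shape yields {S, nullptr}.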
3402 static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) {
3403   const auto *Add = dyn_cast<SCEVAddExpr>(S);
3404   if (!Add)
3405     return {S, nullptr};
3406 
3407   if (Add->getNumOperands() != 2)
3408     return {S, nullptr};
3409 
3410   auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0));
3411   if (!ConstOp)
3412     return {S, nullptr};
3413 
3414   return {Add->getOperand(1), ConstOp->getValue()};
3415 }
3416 
3417 /// Return the ValueOffsetPair set for \p S. \p S can be represented
3418 /// by the value and offset from any ValueOffsetPair in the set.
3419 SetVector<ScalarEvolution::ValueOffsetPair> *
3420 ScalarEvolution::getSCEVValues(const SCEV *S) {
3421   ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
3422   if (SI == ExprValueMap.end())
3423     return nullptr;
3424 #ifndef NDEBUG
3425   if (VerifySCEVMap) {
3426     // Check there is no dangling Value in the set returned.
3427     for (const auto &VE : SI->second)
3428       assert(ValueExprMap.count(VE.first));
3429   }
3430 #endif
3431   return &SI->second;
3432 }
3433 
3434 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
3435 /// cannot be used separately. eraseValueFromMap should be used to remove
3436 /// V from ValueExprMap and ExprValueMap at the same time.
3437 void ScalarEvolution::eraseValueFromMap(Value *V) {
3438   ValueExprMapType::iterator I = ValueExprMap.find_as(V);
3439   if (I != ValueExprMap.end()) {
3440     const SCEV *S = I->second;
3441     // Remove {V, 0} from the set of ExprValueMap[S]
3442     if (SetVector<ValueOffsetPair> *SV = getSCEVValues(S))
3443       SV->remove({V, nullptr});
3444 
3445     // Remove {V, Offset} from the set of ExprValueMap[Stripped]
3446     const SCEV *Stripped;
3447     ConstantInt *Offset;
3448     std::tie(Stripped, Offset) = splitAddExpr(S);
3449     if (Offset != nullptr) {
3450       if (SetVector<ValueOffsetPair> *SV = getSCEVValues(Stripped))
3451         SV->remove({V, Offset});
3452     }
3453     ValueExprMap.erase(V);
3454   }
3455 }
3456 
3457 /// Return an existing SCEV if it exists, otherwise analyze the expression and
3458 /// create a new one.
3459 const SCEV *ScalarEvolution::getSCEV(Value *V) {
3460   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
3461 
3462   const SCEV *S = getExistingSCEV(V);
3463   if (S == nullptr) {
3464     S = createSCEV(V);
3465     // During PHI resolution, it is possible to create two SCEVs for the same
3466     // V, so we need to double-check whether V->S has already been inserted
3467     // into ValueExprMap before inserting S->{V, 0} into ExprValueMap.
3468     std::pair<ValueExprMapType::iterator, bool> Pair =
3469         ValueExprMap.insert({SCEVCallbackVH(V, this), S});
3470     if (Pair.second) {
3471       ExprValueMap[S].insert({V, nullptr});
3472 
3473       // If S == Stripped + Offset, add Stripped -> {V, Offset} into
3474       // ExprValueMap.
3475       const SCEV *Stripped = S;
3476       ConstantInt *Offset = nullptr;
3477       std::tie(Stripped, Offset) = splitAddExpr(S);
3478       // If Stripped is a SCEVUnknown, don't bother to save
3479       // Stripped -> {V, offset}. It doesn't simplify and sometimes even
3480       // increases the complexity of the expansion code.
3481       // If V is GetElementPtrInst, don't save Stripped -> {V, offset}
3482       // because it may generate add/sub instead of GEP in SCEV expansion.
3483       if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) &&
3484           !isa<GetElementPtrInst>(V))
3485         ExprValueMap[Stripped].insert({V, Offset});
3486     }
3487   }
3488   return S;
3489 }
3490 
3491 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
3492   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
3493 
3494   ValueExprMapType::iterator I = ValueExprMap.find_as(V);
3495   if (I != ValueExprMap.end()) {
3496     const SCEV *S = I->second;
3497     if (checkValidity(S))
3498       return S;
3499     eraseValueFromMap(V);
3500     forgetMemoizedResults(S);
3501   }
3502   return nullptr;
3503 }
3504 
3505 /// Return a SCEV corresponding to -V = -1*V
3506 ///
3507 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
3508                                              SCEV::NoWrapFlags Flags) {
3509   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
3510     return getConstant(
3511                cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
3512 
3513   Type *Ty = V->getType();
3514   Ty = getEffectiveSCEVType(Ty);
3515   return getMulExpr(
3516       V, getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))), Flags);
3517 }
3518 
3519 /// Return a SCEV corresponding to ~V = -1-V
3520 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
3521   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
3522     return getConstant(
3523                 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
3524 
3525   Type *Ty = V->getType();
3526   Ty = getEffectiveSCEVType(Ty);
3527   const SCEV *AllOnes =
3528                    getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)));
3529   return getMinusSCEV(AllOnes, V);
3530 }
3531 
3532 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
3533                                           SCEV::NoWrapFlags Flags) {
3534   // Fast path: X - X --> 0.
3535   if (LHS == RHS)
3536     return getZero(LHS->getType());
3537 
3538   // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
3539   // makes it so that we cannot make much use of NUW.
3540   auto AddFlags = SCEV::FlagAnyWrap;
3541   const bool RHSIsNotMinSigned =
3542       !getSignedRange(RHS).getSignedMin().isMinSignedValue();
3543   if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) {
3544     // Let M be the minimum representable signed value. Then (-1)*RHS
3545     // signed-wraps if and only if RHS is M. That can happen even for
3546     // a NSW subtraction because e.g. (-1)*M signed-wraps even though
3547     // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
3548     // (-1)*RHS, we need to prove that RHS != M.
3549     //
3550     // If LHS is non-negative and we know that LHS - RHS does not
3551     // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
3552     // either by proving that RHS > M or that LHS >= 0.
3553     if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
3554       AddFlags = SCEV::FlagNSW;
3555     }
3556   }
3557 
3558   // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
3559   // RHS is NSW and LHS >= 0.
3560   //
3561   // The difficulty here is that the NSW flag may have been proven
3562   // relative to a loop that is to be found in a recurrence in LHS and
3563   // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
3564   // larger scope than intended.
3565   auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3566 
3567   return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags);
3568 }
3569 
3570 const SCEV *
3571 ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
3572   Type *SrcTy = V->getType();
3573   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3574          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3575          "Cannot truncate or zero extend with non-integer arguments!");
3576   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3577     return V;  // No conversion
3578   if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
3579     return getTruncateExpr(V, Ty);
3580   return getZeroExtendExpr(V, Ty);
3581 }
3582 
3583 const SCEV *
3584 ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
3585                                          Type *Ty) {
3586   Type *SrcTy = V->getType();
3587   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3588          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3589          "Cannot truncate or sign extend with non-integer arguments!");
3590   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3591     return V;  // No conversion
3592   if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
3593     return getTruncateExpr(V, Ty);
3594   return getSignExtendExpr(V, Ty);
3595 }
3596 
3597 const SCEV *
3598 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
3599   Type *SrcTy = V->getType();
3600   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3601          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3602          "Cannot noop or zero extend with non-integer arguments!");
3603   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
3604          "getNoopOrZeroExtend cannot truncate!");
3605   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3606     return V;  // No conversion
3607   return getZeroExtendExpr(V, Ty);
3608 }
3609 
3610 const SCEV *
3611 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
3612   Type *SrcTy = V->getType();
3613   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3614          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3615          "Cannot noop or sign extend with non-integer arguments!");
3616   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
3617          "getNoopOrSignExtend cannot truncate!");
3618   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3619     return V;  // No conversion
3620   return getSignExtendExpr(V, Ty);
3621 }
3622 
3623 const SCEV *
3624 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
3625   Type *SrcTy = V->getType();
3626   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3627          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3628          "Cannot noop or any extend with non-integer arguments!");
3629   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
3630          "getNoopOrAnyExtend cannot truncate!");
3631   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3632     return V;  // No conversion
3633   return getAnyExtendExpr(V, Ty);
3634 }
3635 
3636 const SCEV *
3637 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
3638   Type *SrcTy = V->getType();
3639   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
3640          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
3641          "Cannot truncate or noop with non-integer arguments!");
3642   assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
3643          "getTruncateOrNoop cannot extend!");
3644   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3645     return V;  // No conversion
3646   return getTruncateExpr(V, Ty);
3647 }
3648 
3649 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
3650                                                         const SCEV *RHS) {
3651   const SCEV *PromotedLHS = LHS;
3652   const SCEV *PromotedRHS = RHS;
3653 
3654   if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
3655     PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
3656   else
3657     PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
3658 
3659   return getUMaxExpr(PromotedLHS, PromotedRHS);
3660 }
3661 
3662 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
3663                                                         const SCEV *RHS) {
3664   const SCEV *PromotedLHS = LHS;
3665   const SCEV *PromotedRHS = RHS;
3666 
3667   if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
3668     PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
3669   else
3670     PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
3671 
3672   return getUMinExpr(PromotedLHS, PromotedRHS);
3673 }
3674 
3675 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
3676   // A pointer operand may evaluate to a nonpointer expression, such as null.
3677   if (!V->getType()->isPointerTy())
3678     return V;
3679 
3680   if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
3681     return getPointerBase(Cast->getOperand());
3682   } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
3683     const SCEV *PtrOp = nullptr;
3684     for (const SCEV *NAryOp : NAry->operands()) {
3685       if (NAryOp->getType()->isPointerTy()) {
3686         // Cannot find the base of an expression with multiple pointer operands.
3687         if (PtrOp)
3688           return V;
3689         PtrOp = NAryOp;
3690       }
3691     }
3692     if (!PtrOp)
3693       return V;
3694     return getPointerBase(PtrOp);
3695   }
3696   return V;
3697 }
3698 
3699 /// Push users of the given Instruction onto the given Worklist.
3700 static void
3701 PushDefUseChildren(Instruction *I,
3702                    SmallVectorImpl<Instruction *> &Worklist) {
3703   // Push the def-use children onto the Worklist stack.
3704   for (User *U : I->users())
3705     Worklist.push_back(cast<Instruction>(U));
3706 }
3707 
3708 void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {
3709   SmallVector<Instruction *, 16> Worklist;
3710   PushDefUseChildren(PN, Worklist);
3711 
3712   SmallPtrSet<Instruction *, 8> Visited;
3713   Visited.insert(PN);
3714   while (!Worklist.empty()) {
3715     Instruction *I = Worklist.pop_back_val();
3716     if (!Visited.insert(I).second)
3717       continue;
3718 
3719     auto It = ValueExprMap.find_as(static_cast<Value *>(I));
3720     if (It != ValueExprMap.end()) {
3721       const SCEV *Old = It->second;
3722 
3723       // Short-circuit the def-use traversal if the symbolic name
3724       // ceases to appear in expressions.
3725       if (Old != SymName && !hasOperand(Old, SymName))
3726         continue;
3727 
3728       // SCEVUnknown for a PHI either means that it has an unrecognized
3729       // structure, it's a PHI that's in the progress of being computed
3730       // structure, it's a PHI that's in the process of being computed
3731       // additional loop trip count information isn't going to change anything.
3732       // In the second case, createNodeForPHI will perform the necessary
3733       // updates on its own when it gets to that point. In the third, we do
3734       // want to forget the SCEVUnknown.
3735       if (!isa<PHINode>(I) ||
3736           !isa<SCEVUnknown>(Old) ||
3737           (I != PN && Old == SymName)) {
3738         eraseValueFromMap(It->first);
3739         forgetMemoizedResults(Old);
3740       }
3741     }
3742 
3743     PushDefUseChildren(I, Worklist);
3744   }
3745 }
3746 
3747 namespace {
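/// Folds AddRec expressions for a given loop down to their start value; any
/// SCEVUnknown that is not invariant in that loop marks the result invalid.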
3748 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
3749 public:
3750   static const SCEV *rewrite(const SCEV *S, const Loop *L,
3751                              ScalarEvolution &SE) {
3752     SCEVInitRewriter Rewriter(L, SE);
3753     const SCEV *Result = Rewriter.visit(S);
3754     return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
3755   }
3756 
3757   SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
3758       : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
3759 
3760   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
3761     if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant))
3762       Valid = false;
3763     return Expr;
3764   }
3765 
3766   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
3767     // Only allow AddRecExprs for this loop.
3768     if (Expr->getLoop() == L)
3769       return Expr->getStart();
3770     Valid = false;
3771     return Expr;
3772   }
3773 
3774   bool isValid() { return Valid; }
3775 
3776 private:
3777   const Loop *L;
3778   bool Valid;
3779 };
3780 
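/// Rewrites affine AddRec expressions for a given loop by subtracting their
/// step, i.e. shifting the recurrence back by one iteration; any SCEVUnknown
/// that is not invariant in that loop marks the result invalid.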
3781 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
3782 public:
3783   static const SCEV *rewrite(const SCEV *S, const Loop *L,
3784                              ScalarEvolution &SE) {
3785     SCEVShiftRewriter Rewriter(L, SE);
3786     const SCEV *Result = Rewriter.visit(S);
3787     return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
3788   }
3789 
3790   SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
3791       : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
3792 
3793   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
3794     // Unknown values must be invariant in the loop for this rewrite to hold.
3795     if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant))
3796       Valid = false;
3797     return Expr;
3798   }
3799 
3800   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
3801     if (Expr->getLoop() == L && Expr->isAffine())
3802       return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
3803     Valid = false;
3804     return Expr;
3805   }
3806   bool isValid() { return Valid; }
3807 
3808 private:
3809   const Loop *L;
3810   bool Valid;
3811 };
3812 } // end anonymous namespace
3813 
3814 SCEV::NoWrapFlags
3815 ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
3816   if (!AR->isAffine())
3817     return SCEV::FlagAnyWrap;
3818 
3819   typedef OverflowingBinaryOperator OBO;
3820   SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
3821 
3822   if (!AR->hasNoSignedWrap()) {
3823     ConstantRange AddRecRange = getSignedRange(AR);
3824     ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
3825 
3826     auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3827         Instruction::Add, IncRange, OBO::NoSignedWrap);
3828     if (NSWRegion.contains(AddRecRange))
3829       Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
3830   }
3831 
3832   if (!AR->hasNoUnsignedWrap()) {
3833     ConstantRange AddRecRange = getUnsignedRange(AR);
3834     ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
3835 
3836     auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3837         Instruction::Add, IncRange, OBO::NoUnsignedWrap);
3838     if (NUWRegion.contains(AddRecRange))
3839       Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
3840   }
3841 
3842   return Result;
3843 }
3844 
3845 namespace {
3846 /// Represents an abstract binary operation.  This may exist as a
3847 /// normal instruction or constant expression, or may have been
3848 /// derived from an expression tree.
3849 struct BinaryOp {
3850   unsigned Opcode;
3851   Value *LHS;
3852   Value *RHS;
3853   bool IsNSW;
3854   bool IsNUW;
3855 
3856   /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
3857   /// constant expression.
3858   Operator *Op;
3859 
3860   explicit BinaryOp(Operator *Op)
3861       : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
3862         IsNSW(false), IsNUW(false), Op(Op) {
3863     if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
3864       IsNSW = OBO->hasNoSignedWrap();
3865       IsNUW = OBO->hasNoUnsignedWrap();
3866     }
3867   }
3868 
3869   explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
3870                     bool IsNUW = false)
3871       : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW),
3872         Op(nullptr) {}
3873 };
3874 }
3875 
3876 
3877 /// Try to map \p V into a BinaryOp, and return \c None on failure.
3878 static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
3879   auto *Op = dyn_cast<Operator>(V);
3880   if (!Op)
3881     return None;
3882 
3883   // Implementation detail: all the cleverness here should happen without
3884   // creating new SCEV expressions -- our caller knows tricks to avoid creating
3885   // SCEV expressions when possible, and we should not break that.
3886 
3887   switch (Op->getOpcode()) {
3888   case Instruction::Add:
3889   case Instruction::Sub:
3890   case Instruction::Mul:
3891   case Instruction::UDiv:
3892   case Instruction::And:
3893   case Instruction::Or:
3894   case Instruction::AShr:
3895   case Instruction::Shl:
3896     return BinaryOp(Op);
3897 
3898   case Instruction::Xor:
3899     if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
3900       // If the RHS of the xor is a signbit, then this is just an add.
3901       // Instcombine turns add of signbit into xor as a strength reduction step.
3902       if (RHSC->getValue().isSignBit())
3903         return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
3904     return BinaryOp(Op);
3905 
3906   case Instruction::LShr:
3907     // Turn logical shift right of a constant into an unsigned divide.
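         // For instance, (x lshr 3) is modeled as (x /u 8).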
3908     if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
3909       uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
3910 
3911       // If the shift count is not less than the bitwidth, the result of
3912       // the shift is undefined. Don't try to analyze it, because the
3913       // resolution chosen here may differ from the resolution chosen in
3914       // other parts of the compiler.
3915       if (SA->getValue().ult(BitWidth)) {
3916         Constant *X =
3917             ConstantInt::get(SA->getContext(),
3918                              APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
3919         return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
3920       }
3921     }
3922     return BinaryOp(Op);
3923 
3924   case Instruction::ExtractValue: {
3925     auto *EVI = cast<ExtractValueInst>(Op);
3926     if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
3927       break;
3928 
3929     auto *CI = dyn_cast<CallInst>(EVI->getAggregateOperand());
3930     if (!CI)
3931       break;
3932 
3933     if (auto *F = CI->getCalledFunction())
3934       switch (F->getIntrinsicID()) {
3935       case Intrinsic::sadd_with_overflow:
3936       case Intrinsic::uadd_with_overflow: {
3937         if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT))
3938           return BinaryOp(Instruction::Add, CI->getArgOperand(0),
3939                           CI->getArgOperand(1));
3940 
3941         // Now that we know that all uses of the arithmetic-result component of
3942         // CI are guarded by the overflow check, we can go ahead and pretend
3943         // that the arithmetic is non-overflowing.
3944         if (F->getIntrinsicID() == Intrinsic::sadd_with_overflow)
3945           return BinaryOp(Instruction::Add, CI->getArgOperand(0),
3946                           CI->getArgOperand(1), /* IsNSW = */ true,
3947                           /* IsNUW = */ false);
3948         else
3949           return BinaryOp(Instruction::Add, CI->getArgOperand(0),
3950                           CI->getArgOperand(1), /* IsNSW = */ false,
3951                           /* IsNUW*/ true);
3952       }
3953 
3954       case Intrinsic::ssub_with_overflow:
3955       case Intrinsic::usub_with_overflow:
3956         return BinaryOp(Instruction::Sub, CI->getArgOperand(0),
3957                         CI->getArgOperand(1));
3958 
3959       case Intrinsic::smul_with_overflow:
3960       case Intrinsic::umul_with_overflow:
3961         return BinaryOp(Instruction::Mul, CI->getArgOperand(0),
3962                         CI->getArgOperand(1));
3963       default:
3964         break;
3965       }
3966   }
3967 
3968   default:
3969     break;
3970   }
3971 
3972   return None;
3973 }
3974 
3975 const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
3976   const Loop *L = LI.getLoopFor(PN->getParent());
3977   if (!L || L->getHeader() != PN->getParent())
3978     return nullptr;
3979 
3980   // The loop may have multiple entrances or multiple exits; we can analyze
3981   // this phi as an addrec if it has a unique entry value and a unique
3982   // backedge value.
3983   Value *BEValueV = nullptr, *StartValueV = nullptr;
3984   for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
3985     Value *V = PN->getIncomingValue(i);
3986     if (L->contains(PN->getIncomingBlock(i))) {
3987       if (!BEValueV) {
3988         BEValueV = V;
3989       } else if (BEValueV != V) {
3990         BEValueV = nullptr;
3991         break;
3992       }
3993     } else if (!StartValueV) {
3994       StartValueV = V;
3995     } else if (StartValueV != V) {
3996       StartValueV = nullptr;
3997       break;
3998     }
3999   }
4000   if (BEValueV && StartValueV) {
4001     // While we are analyzing this PHI node, handle its value symbolically.
4002     const SCEV *SymbolicName = getUnknown(PN);
4003     assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
4004            "PHI node already processed?");
4005     ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName});
4006 
4007     // Using this symbolic name for the PHI, analyze the value coming around
4008     // the back-edge.
4009     const SCEV *BEValue = getSCEV(BEValueV);
4010 
4011     // NOTE: If BEValue is loop invariant, we know that the PHI node just
4012     // has a special value for the first iteration of the loop.
4013 
4014     // If the value coming around the backedge is an add with the symbolic
4015     // value we just inserted, then we found a simple induction variable!
4016     if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
4017       // If there is a single occurrence of the symbolic value, replace it
4018       // with a recurrence.
4019       unsigned FoundIndex = Add->getNumOperands();
4020       for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
4021         if (Add->getOperand(i) == SymbolicName)
4022           if (FoundIndex == e) {
4023             FoundIndex = i;
4024             break;
4025           }
4026 
4027       if (FoundIndex != Add->getNumOperands()) {
4028         // Create an add with everything but the specified operand.
4029         SmallVector<const SCEV *, 8> Ops;
4030         for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
4031           if (i != FoundIndex)
4032             Ops.push_back(Add->getOperand(i));
4033         const SCEV *Accum = getAddExpr(Ops);
4034 
4035           // This is not a valid addrec if the step amount varies on each loop
4036           // iteration and is not itself an addrec in this loop.
4037         if (isLoopInvariant(Accum, L) ||
4038             (isa<SCEVAddRecExpr>(Accum) &&
4039              cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
4040           SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
4041 
4042           if (auto BO = MatchBinaryOp(BEValueV, DT)) {
4043             if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
4044               if (BO->IsNUW)
4045                 Flags = setFlags(Flags, SCEV::FlagNUW);
4046               if (BO->IsNSW)
4047                 Flags = setFlags(Flags, SCEV::FlagNSW);
4048             }
4049           } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
4050             // If the increment is an inbounds GEP, then we know the address
4051             // space cannot be wrapped around. We cannot make any guarantee
4052             // about signed or unsigned overflow because pointers are
4053             // unsigned but we may have a negative index from the base
4054             // pointer. We can guarantee that no unsigned wrap occurs if the
4055             // indices form a positive value.
4056             if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
4057               Flags = setFlags(Flags, SCEV::FlagNW);
4058 
4059               const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
4060               if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
4061                 Flags = setFlags(Flags, SCEV::FlagNUW);
4062             }
4063 
4064             // We cannot transfer nuw and nsw flags from subtraction
4065             // operations -- sub nuw X, Y is not the same as add nuw X, -Y
4066             // for instance.
4067           }
4068 
4069           const SCEV *StartVal = getSCEV(StartValueV);
4070           const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
4071 
4072           // Okay, for the entire analysis of this edge we assumed the PHI
4073           // to be symbolic.  We now need to go back and purge all of the
4074           // entries for the scalars that use the symbolic expression.
4075           forgetSymbolicName(PN, SymbolicName);
4076           ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
4077 
4078           // We can add Flags to the post-inc expression only if we
4079           // know that it is *undefined behavior* for BEValueV to
4080           // overflow.
4081           if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
4082             if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
4083               (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
4084 
4085           return PHISCEV;
4086         }
4087       }
4088     } else {
4089       // Otherwise, this could be a loop like this:
4090       //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
4091       // In this case, j = {1,+,1}  and BEValue is j.
4092       // Because the other incoming value of i (0) fits the evolution of BEValue,
4093       // i really is an addrec evolution.
4094       //
4095       // We can generalize this saying that i is the shifted value of BEValue
4096       // by one iteration:
4097       //   PHI(f(0), f({1,+,1})) --> f({0,+,1})
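           // For example, for IR along the lines of
           //     %i = phi i32 [ 0, %entry ], [ %j, %loop ]
           //     %j = phi i32 [ 1, %entry ], [ %j.next, %loop ]
           //     %j.next = add i32 %j, 1
           // BEValue for %i is {1,+,1}; shifting it back one iteration gives
           // {0,+,1}, whose start matches the entry value 0, so %i is {0,+,1}.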
4098       const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
4099       const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this);
4100       if (Shifted != getCouldNotCompute() &&
4101           Start != getCouldNotCompute()) {
4102         const SCEV *StartVal = getSCEV(StartValueV);
4103         if (Start == StartVal) {
4104           // Okay, for the entire analysis of this edge we assumed the PHI
4105           // to be symbolic.  We now need to go back and purge all of the
4106           // entries for the scalars that use the symbolic expression.
4107           forgetSymbolicName(PN, SymbolicName);
4108           ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted;
4109           return Shifted;
4110         }
4111       }
4112     }
4113 
4114     // Remove the temporary PHI node SCEV that has been inserted while intending
4115     // to create an AddRecExpr for this PHI node. We cannot keep this temporary
4116     // entry, as it would prevent later (possibly simpler) SCEV expressions from
4117     // being added to the ValueExprMap.
4118     eraseValueFromMap(PN);
4119   }
4120 
4121   return nullptr;
4122 }
4123 
4124 // Checks if the SCEV S is available at BB.  S is considered available at BB
4125 // if S can be materialized at BB without introducing a fault.
4126 static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
4127                                BasicBlock *BB) {
4128   struct CheckAvailable {
4129     bool TraversalDone = false;
4130     bool Available = true;
4131 
4132     const Loop *L = nullptr;  // The loop BB is in (can be nullptr)
4133     BasicBlock *BB = nullptr;
4134     DominatorTree &DT;
4135 
4136     CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT)
4137       : L(L), BB(BB), DT(DT) {}
4138 
4139     bool setUnavailable() {
4140       TraversalDone = true;
4141       Available = false;
4142       return false;
4143     }
4144 
4145     bool follow(const SCEV *S) {
4146       switch (S->getSCEVType()) {
4147       case scConstant: case scTruncate: case scZeroExtend: case scSignExtend:
4148       case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr:
4149         // These expressions are available if their operand(s) is/are.
4150         return true;
4151 
4152       case scAddRecExpr: {
4153         // We allow add recurrences for the loop BB is in, or some
4154         // outer loop.  This guarantees availability because the value of the
4155         // add recurrence at BB is simply the "current" value of the induction
4156         // variable.  We can relax this in the future; for instance an add
4157         // recurrence on a sibling dominating loop is also available at BB.
4158         const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop();
4159         if (L && (ARLoop == L || ARLoop->contains(L)))
4160           return true;
4161 
4162         return setUnavailable();
4163       }
4164 
4165       case scUnknown: {
4166         // For SCEVUnknown, we check for simple dominance.
4167         const auto *SU = cast<SCEVUnknown>(S);
4168         Value *V = SU->getValue();
4169 
4170         if (isa<Argument>(V))
4171           return false;
4172 
4173         if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB))
4174           return false;
4175 
4176         return setUnavailable();
4177       }
4178 
4179       case scUDivExpr:
4180       case scCouldNotCompute:
4181         // We do not try to be smart about these at all.
4182         return setUnavailable();
4183       }
4184       llvm_unreachable("switch should be fully covered!");
4185     }
4186 
4187     bool isDone() { return TraversalDone; }
4188   };
4189 
4190   CheckAvailable CA(L, BB, DT);
4191   SCEVTraversal<CheckAvailable> ST(CA);
4192 
4193   ST.visitAll(S);
4194   return CA.Available;
4195 }
4196 
4197 // Try to match a control flow sequence that branches out at BI and merges back
4198 // at Merge into a "C ? LHS : RHS" select pattern.  Return true on a successful
4199 // match.
4200 static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
4201                           Value *&C, Value *&LHS, Value *&RHS) {
4202   C = BI->getCondition();
4203 
4204   BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
4205   BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
4206 
4207   if (!LeftEdge.isSingleEdge())
4208     return false;
4209 
4210   assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
4211 
4212   Use &LeftUse = Merge->getOperandUse(0);
4213   Use &RightUse = Merge->getOperandUse(1);
4214 
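       // The phi's incoming values may be listed in either order relative to
       // the branch successors, so try both pairings of edge and use.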
4215   if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
4216     LHS = LeftUse;
4217     RHS = RightUse;
4218     return true;
4219   }
4220 
4221   if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
4222     LHS = RightUse;
4223     RHS = LeftUse;
4224     return true;
4225   }
4226 
4227   return false;
4228 }
4229 
4230 const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
4231   auto IsReachable =
4232       [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
4233   if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
4234     const Loop *L = LI.getLoopFor(PN->getParent());
4235 
4236     // We don't want to break LCSSA, even in a SCEV expression tree.
4237     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
4238       if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
4239         return nullptr;
4240 
4241     // Try to match
4242     //
4243     //  br %cond, label %left, label %right
4244     // left:
4245     //  br label %merge
4246     // right:
4247     //  br label %merge
4248     // merge:
4249     //  V = phi [ %x, %left ], [ %y, %right ]
4250     //
4251     // as "select %cond, %x, %y"
4252 
4253     BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
4254     assert(IDom && "At least the entry block should dominate PN");
4255 
4256     auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
4257     Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
4258 
4259     if (BI && BI->isConditional() &&
4260         BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
4261         IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) &&
4262         IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent()))
4263       return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
4264   }
4265 
4266   return nullptr;
4267 }
4268 
4269 const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
4270   if (const SCEV *S = createAddRecFromPHI(PN))
4271     return S;
4272 
4273   if (const SCEV *S = createNodeFromSelectLikePHI(PN))
4274     return S;
4275 
4276   // If the PHI has a single incoming value, follow that value, unless the
4277   // PHI's incoming blocks are in a different loop, in which case doing so
4278   // risks breaking LCSSA form. Instcombine would normally zap these, but
4279   // it doesn't have DominatorTree information, so it may miss cases.
4280   if (Value *V = SimplifyInstruction(PN, getDataLayout(), &TLI, &DT, &AC))
4281     if (LI.replacementPreservesLCSSAForm(PN, V))
4282       return getSCEV(V);
4283 
4284   // If it's not a loop phi, we can't handle it yet.
4285   return getUnknown(PN);
4286 }
4287 
4288 const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
4289                                                       Value *Cond,
4290                                                       Value *TrueVal,
4291                                                       Value *FalseVal) {
4292   // Handle "constant" branch or select. This can occur for instance when a
4293   // loop pass transforms an inner loop and moves on to process the outer loop.
4294   if (auto *CI = dyn_cast<ConstantInt>(Cond))
4295     return getSCEV(CI->isOne() ? TrueVal : FalseVal);
4296 
4297   // Try to match some simple smax or umax patterns.
4298   auto *ICI = dyn_cast<ICmpInst>(Cond);
4299   if (!ICI)
4300     return getUnknown(I);
4301 
4302   Value *LHS = ICI->getOperand(0);
4303   Value *RHS = ICI->getOperand(1);
4304 
4305   switch (ICI->getPredicate()) {
4306   case ICmpInst::ICMP_SLT:
4307   case ICmpInst::ICMP_SLE:
4308     std::swap(LHS, RHS);
4309     LLVM_FALLTHROUGH;
4310   case ICmpInst::ICMP_SGT:
4311   case ICmpInst::ICMP_SGE:
4312     // a >s b ? a+x : b+x  ->  smax(a, b)+x
4313     // a >s b ? b+x : a+x  ->  smin(a, b)+x
4314     if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
4315       const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType());
4316       const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType());
4317       const SCEV *LA = getSCEV(TrueVal);
4318       const SCEV *RA = getSCEV(FalseVal);
4319       const SCEV *LDiff = getMinusSCEV(LA, LS);
4320       const SCEV *RDiff = getMinusSCEV(RA, RS);
4321       if (LDiff == RDiff)
4322         return getAddExpr(getSMaxExpr(LS, RS), LDiff);
4323       LDiff = getMinusSCEV(LA, RS);
4324       RDiff = getMinusSCEV(RA, LS);
4325       if (LDiff == RDiff)
4326         return getAddExpr(getSMinExpr(LS, RS), LDiff);
4327     }
4328     break;
4329   case ICmpInst::ICMP_ULT:
4330   case ICmpInst::ICMP_ULE:
4331     std::swap(LHS, RHS);
4332     LLVM_FALLTHROUGH;
4333   case ICmpInst::ICMP_UGT:
4334   case ICmpInst::ICMP_UGE:
4335     // a >u b ? a+x : b+x  ->  umax(a, b)+x
4336     // a >u b ? b+x : a+x  ->  umin(a, b)+x
4337     if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
4338       const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
4339       const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType());
4340       const SCEV *LA = getSCEV(TrueVal);
4341       const SCEV *RA = getSCEV(FalseVal);
4342       const SCEV *LDiff = getMinusSCEV(LA, LS);
4343       const SCEV *RDiff = getMinusSCEV(RA, RS);
4344       if (LDiff == RDiff)
4345         return getAddExpr(getUMaxExpr(LS, RS), LDiff);
4346       LDiff = getMinusSCEV(LA, RS);
4347       RDiff = getMinusSCEV(RA, LS);
4348       if (LDiff == RDiff)
4349         return getAddExpr(getUMinExpr(LS, RS), LDiff);
4350     }
4351     break;
4352   case ICmpInst::ICMP_NE:
4353     // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
4354     if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
4355         isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
4356       const SCEV *One = getOne(I->getType());
4357       const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
4358       const SCEV *LA = getSCEV(TrueVal);
4359       const SCEV *RA = getSCEV(FalseVal);
4360       const SCEV *LDiff = getMinusSCEV(LA, LS);
4361       const SCEV *RDiff = getMinusSCEV(RA, One);
4362       if (LDiff == RDiff)
4363         return getAddExpr(getUMaxExpr(One, LS), LDiff);
4364     }
4365     break;
4366   case ICmpInst::ICMP_EQ:
4367     // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
4368     if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
4369         isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
4370       const SCEV *One = getOne(I->getType());
4371       const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
4372       const SCEV *LA = getSCEV(TrueVal);
4373       const SCEV *RA = getSCEV(FalseVal);
4374       const SCEV *LDiff = getMinusSCEV(LA, One);
4375       const SCEV *RDiff = getMinusSCEV(RA, LS);
4376       if (LDiff == RDiff)
4377         return getAddExpr(getUMaxExpr(One, LS), LDiff);
4378     }
4379     break;
4380   default:
4381     break;
4382   }
4383 
4384   return getUnknown(I);
4385 }
4386 
4387 /// Expand GEP instructions into add and multiply operations. This allows them
4388 /// to be analyzed by regular SCEV code.
4389 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
4390   // Don't attempt to analyze GEPs over unsized objects.
4391   if (!GEP->getSourceElementType()->isSized())
4392     return getUnknown(GEP);
4393 
4394   SmallVector<const SCEV *, 4> IndexExprs;
4395   for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index)
4396     IndexExprs.push_back(getSCEV(*Index));
4397   return getGEPExpr(GEP->getSourceElementType(),
4398                     getSCEV(GEP->getPointerOperand()),
4399                     IndexExprs, GEP->isInBounds());
4400 }
4401 
4402 uint32_t
4403 ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
4404   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
4405     return C->getAPInt().countTrailingZeros();
4406 
4407   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
4408     return std::min(GetMinTrailingZeros(T->getOperand()),
4409                     (uint32_t)getTypeSizeInBits(T->getType()));
4410 
4411   if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
4412     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
4413     return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
4414              getTypeSizeInBits(E->getType()) : OpRes;
4415   }
4416 
4417   if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
4418     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
4419     return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
4420              getTypeSizeInBits(E->getType()) : OpRes;
4421   }
4422 
4423   if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
4424     // The result is the min of all operands' results.
4425     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
4426     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
4427       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
4428     return MinOpRes;
4429   }
4430 
4431   if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
4432     // The result is the sum of all operands' results.
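         // For instance, (8 * x) has at least 3 trailing zero bits plus
         // however many x is known to have, capped at the bit width.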
4433     uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
4434     uint32_t BitWidth = getTypeSizeInBits(M->getType());
4435     for (unsigned i = 1, e = M->getNumOperands();
4436          SumOpRes != BitWidth && i != e; ++i)
4437       SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
4438                           BitWidth);
4439     return SumOpRes;
4440   }
4441 
4442   if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
4443     // The result is the min of all operands' results.
4444     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
4445     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
4446       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
4447     return MinOpRes;
4448   }
4449 
4450   if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
4451     // The result is the min of all operands' results.
4452     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
4453     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
4454       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
4455     return MinOpRes;
4456   }
4457 
4458   if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
4459     // The result is the min of all operands' results.
4460     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
4461     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
4462       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
4463     return MinOpRes;
4464   }
4465 
4466   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
4467     // For a SCEVUnknown, ask ValueTracking.
4468     unsigned BitWidth = getTypeSizeInBits(U->getType());
4469     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
4470     computeKnownBits(U->getValue(), Zeros, Ones, getDataLayout(), 0, &AC,
4471                      nullptr, &DT);
4472     return Zeros.countTrailingOnes();
4473   }
4474 
4475   // SCEVUDivExpr
4476   return 0;
4477 }
4478 
4479 /// Helper method to assign a range to V from metadata present in the IR.
4480 static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
4481   if (Instruction *I = dyn_cast<Instruction>(V))
4482     if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
4483       return getConstantRangeFromMetadata(*MD);
4484 
4485   return None;
4486 }
4487 
4488 /// Determine the range for a particular SCEV.  If SignHint is
4489 /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
4490 /// with a "cleaner" unsigned (resp. signed) representation.
4491 ConstantRange
4492 ScalarEvolution::getRange(const SCEV *S,
4493                           ScalarEvolution::RangeSignHint SignHint) {
4494   DenseMap<const SCEV *, ConstantRange> &Cache =
4495       SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
4496                                                        : SignedRanges;
4497 
4498   // See if we've computed this range already.
4499   DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
4500   if (I != Cache.end())
4501     return I->second;
4502 
4503   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
4504     return setRange(C, SignHint, ConstantRange(C->getAPInt()));
4505 
4506   unsigned BitWidth = getTypeSizeInBits(S->getType());
4507   ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
4508 
4509   // If the value has known zeros, the maximum value will have those known zeros
4510   // as well.
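       // For example, with BitWidth == 8 and TZ == 2 the unsigned range becomes
       // [0, 252], since 252 is the largest 8-bit value with two trailing zero
       // bits; the signed case is narrowed analogously.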
4511   uint32_t TZ = GetMinTrailingZeros(S);
4512   if (TZ != 0) {
4513     if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
4514       ConservativeResult =
4515           ConstantRange(APInt::getMinValue(BitWidth),
4516                         APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
4517     else
4518       ConservativeResult = ConstantRange(
4519           APInt::getSignedMinValue(BitWidth),
4520           APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
4521   }
4522 
4523   if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
4524     ConstantRange X = getRange(Add->getOperand(0), SignHint);
4525     for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
4526       X = X.add(getRange(Add->getOperand(i), SignHint));
4527     return setRange(Add, SignHint, ConservativeResult.intersectWith(X));
4528   }
4529 
4530   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
4531     ConstantRange X = getRange(Mul->getOperand(0), SignHint);
4532     for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
4533       X = X.multiply(getRange(Mul->getOperand(i), SignHint));
4534     return setRange(Mul, SignHint, ConservativeResult.intersectWith(X));
4535   }
4536 
4537   if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
4538     ConstantRange X = getRange(SMax->getOperand(0), SignHint);
4539     for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
4540       X = X.smax(getRange(SMax->getOperand(i), SignHint));
4541     return setRange(SMax, SignHint, ConservativeResult.intersectWith(X));
4542   }
4543 
4544   if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
4545     ConstantRange X = getRange(UMax->getOperand(0), SignHint);
4546     for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
4547       X = X.umax(getRange(UMax->getOperand(i), SignHint));
4548     return setRange(UMax, SignHint, ConservativeResult.intersectWith(X));
4549   }
4550 
4551   if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
4552     ConstantRange X = getRange(UDiv->getLHS(), SignHint);
4553     ConstantRange Y = getRange(UDiv->getRHS(), SignHint);
4554     return setRange(UDiv, SignHint,
4555                     ConservativeResult.intersectWith(X.udiv(Y)));
4556   }
4557 
4558   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
4559     ConstantRange X = getRange(ZExt->getOperand(), SignHint);
4560     return setRange(ZExt, SignHint,
4561                     ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
4562   }
4563 
4564   if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
4565     ConstantRange X = getRange(SExt->getOperand(), SignHint);
4566     return setRange(SExt, SignHint,
4567                     ConservativeResult.intersectWith(X.signExtend(BitWidth)));
4568   }
4569 
4570   if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
4571     ConstantRange X = getRange(Trunc->getOperand(), SignHint);
4572     return setRange(Trunc, SignHint,
4573                     ConservativeResult.intersectWith(X.truncate(BitWidth)));
4574   }
4575 
4576   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
4577     // If there's no unsigned wrap, the value will never be less than its
4578     // initial value.
4579     if (AddRec->hasNoUnsignedWrap())
4580       if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
4581         if (!C->getValue()->isZero())
4582           ConservativeResult = ConservativeResult.intersectWith(
4583               ConstantRange(C->getAPInt(), APInt(BitWidth, 0)));
4584 
4585     // If there's no signed wrap, and all the operands have the same sign or
4586     // zero, the value won't ever change sign.
4587     if (AddRec->hasNoSignedWrap()) {
4588       bool AllNonNeg = true;
4589       bool AllNonPos = true;
4590       for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
4591         if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false;
4592         if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false;
4593       }
4594       if (AllNonNeg)
4595         ConservativeResult = ConservativeResult.intersectWith(
4596           ConstantRange(APInt(BitWidth, 0),
4597                         APInt::getSignedMinValue(BitWidth)));
4598       else if (AllNonPos)
4599         ConservativeResult = ConservativeResult.intersectWith(
4600           ConstantRange(APInt::getSignedMinValue(BitWidth),
4601                         APInt(BitWidth, 1)));
4602     }
4603 
4604     // TODO: non-affine addrec
4605     if (AddRec->isAffine()) {
4606       const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
4607       if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
4608           getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
4609         auto RangeFromAffine = getRangeForAffineAR(
4610             AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
4611             BitWidth);
4612         if (!RangeFromAffine.isFullSet())
4613           ConservativeResult =
4614               ConservativeResult.intersectWith(RangeFromAffine);
4615 
4616         auto RangeFromFactoring = getRangeViaFactoring(
4617             AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
4618             BitWidth);
4619         if (!RangeFromFactoring.isFullSet())
4620           ConservativeResult =
4621               ConservativeResult.intersectWith(RangeFromFactoring);
4622       }
4623     }
4624 
4625     return setRange(AddRec, SignHint, ConservativeResult);
4626   }
4627 
4628   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
4629     // Check if the IR explicitly contains !range metadata.
4630     Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
4631     if (MDRange.hasValue())
4632       ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue());
4633 
4634     // Split here to avoid paying the compile-time cost of calling both
4635     // computeKnownBits and ComputeNumSignBits.  This restriction can be lifted
4636     // if needed.
4637     const DataLayout &DL = getDataLayout();
4638     if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
4639       // For a SCEVUnknown, ask ValueTracking.
4640       APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
4641       computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, &AC, nullptr, &DT);
4642       if (Ones != ~Zeros + 1)
4643         ConservativeResult =
4644             ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1));
4645     } else {
4646       assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED &&
4647              "generalize as needed!");
4648       unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
4649       if (NS > 1)
4650         ConservativeResult = ConservativeResult.intersectWith(
4651             ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
4652                           APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1));
4653     }
4654 
4655     return setRange(U, SignHint, ConservativeResult);
4656   }
4657 
4658   return setRange(S, SignHint, ConservativeResult);
4659 }
4660 
4661 ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
4662                                                    const SCEV *Step,
4663                                                    const SCEV *MaxBECount,
4664                                                    unsigned BitWidth) {
4665   assert(!isa<SCEVCouldNotCompute>(MaxBECount) &&
4666          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
4667          "Precondition!");
4668 
4669   ConstantRange Result(BitWidth, /* isFullSet = */ true);
4670 
4671   // Check for overflow.  This must be done with ConstantRange arithmetic
4672   // because we could be called from within the ScalarEvolution overflow
4673   // checking code.
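       // The strategy is to recompute Start + MaxBECount * Step in a width of
       // 2 * BitWidth + 1 bits, where the arithmetic cannot wrap, and compare
       // it with the same computation done at the original width; if the two
       // agree, the narrow computation did not overflow and its bounds are
       // meaningful.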
4674 
4675   MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());
4676   ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
4677   ConstantRange ZExtMaxBECountRange =
4678       MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1);
4679 
4680   ConstantRange StepSRange = getSignedRange(Step);
4681   ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1);
4682 
4683   ConstantRange StartURange = getUnsignedRange(Start);
4684   ConstantRange EndURange =
4685       StartURange.add(MaxBECountRange.multiply(StepSRange));
4686 
4687   // Check for unsigned overflow.
4688   ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2 + 1);
4689   ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1);
4690   if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
4691       ZExtEndURange) {
4692     APInt Min = APIntOps::umin(StartURange.getUnsignedMin(),
4693                                EndURange.getUnsignedMin());
4694     APInt Max = APIntOps::umax(StartURange.getUnsignedMax(),
4695                                EndURange.getUnsignedMax());
4696     bool IsFullRange = Min.isMinValue() && Max.isMaxValue();
4697     if (!IsFullRange)
4698       Result =
4699           Result.intersectWith(ConstantRange(Min, Max + 1));
4700   }
4701 
4702   ConstantRange StartSRange = getSignedRange(Start);
4703   ConstantRange EndSRange =
4704       StartSRange.add(MaxBECountRange.multiply(StepSRange));
4705 
4706   // Check for signed overflow. This must be done with ConstantRange
4707   // arithmetic because we could be called from within the ScalarEvolution
4708   // overflow checking code.
4709   ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2 + 1);
4710   ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1);
4711   if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==
4712       SExtEndSRange) {
4713     APInt Min =
4714         APIntOps::smin(StartSRange.getSignedMin(), EndSRange.getSignedMin());
4715     APInt Max =
4716         APIntOps::smax(StartSRange.getSignedMax(), EndSRange.getSignedMax());
4717     bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue();
4718     if (!IsFullRange)
4719       Result =
4720           Result.intersectWith(ConstantRange(Min, Max + 1));
4721   }
4722 
4723   return Result;
4724 }
4725 
4726 ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
4727                                                     const SCEV *Step,
4728                                                     const SCEV *MaxBECount,
4729                                                     unsigned BitWidth) {
4730   //    RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
4731   // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
4732 
4733   struct SelectPattern {
4734     Value *Condition = nullptr;
4735     APInt TrueValue;
4736     APInt FalseValue;
4737 
4738     explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
4739                            const SCEV *S) {
4740       Optional<unsigned> CastOp;
4741       APInt Offset(BitWidth, 0);
4742 
4743       assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
4744              "Should be!");
4745 
4746       // Peel off a constant offset:
4747       if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
4748         // In the future we could consider being smarter here and handle
4749         // {Start+Step,+,Step} too.
4750         if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
4751           return;
4752 
4753         Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
4754         S = SA->getOperand(1);
4755       }
4756 
4757       // Peel off a cast operation
4758       if (auto *SCast = dyn_cast<SCEVCastExpr>(S)) {
4759         CastOp = SCast->getSCEVType();
4760         S = SCast->getOperand();
4761       }
4762 
4763       using namespace llvm::PatternMatch;
4764 
4765       auto *SU = dyn_cast<SCEVUnknown>(S);
4766       const APInt *TrueVal, *FalseVal;
4767       if (!SU ||
4768           !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
4769                                           m_APInt(FalseVal)))) {
4770         Condition = nullptr;
4771         return;
4772       }
4773 
4774       TrueValue = *TrueVal;
4775       FalseValue = *FalseVal;
4776 
4777       // Re-apply the cast we peeled off earlier
4778       if (CastOp.hasValue())
4779         switch (*CastOp) {
4780         default:
4781           llvm_unreachable("Unknown SCEV cast type!");
4782 
4783         case scTruncate:
4784           TrueValue = TrueValue.trunc(BitWidth);
4785           FalseValue = FalseValue.trunc(BitWidth);
4786           break;
4787         case scZeroExtend:
4788           TrueValue = TrueValue.zext(BitWidth);
4789           FalseValue = FalseValue.zext(BitWidth);
4790           break;
4791         case scSignExtend:
4792           TrueValue = TrueValue.sext(BitWidth);
4793           FalseValue = FalseValue.sext(BitWidth);
4794           break;
4795         }
4796 
4797       // Re-apply the constant offset we peeled off earlier
4798       TrueValue += Offset;
4799       FalseValue += Offset;
4800     }
4801 
4802     bool isRecognized() { return Condition != nullptr; }
4803   };
4804 
4805   SelectPattern StartPattern(*this, BitWidth, Start);
4806   if (!StartPattern.isRecognized())
4807     return ConstantRange(BitWidth, /* isFullSet = */ true);
4808 
4809   SelectPattern StepPattern(*this, BitWidth, Step);
4810   if (!StepPattern.isRecognized())
4811     return ConstantRange(BitWidth, /* isFullSet = */ true);
4812 
4813   if (StartPattern.Condition != StepPattern.Condition) {
4814     // We don't handle this case today, but we could, by considering four
4815     // possibilities below instead of two. I'm not sure if there are cases where
4816     // that will help over what getRange already does, though.
4817     return ConstantRange(BitWidth, /* isFullSet = */ true);
4818   }
4819 
4820   // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
4821   // construct arbitrary general SCEV expressions here.  This function is called
4822   // from deep in the call stack, and calling getSCEV (on a sext instruction,
4823   // say) can end up caching a suboptimal value.
4824 
4825   // FIXME: without the explicit `this` receiver below, MSVC errors out with
4826   // C2352 and C2512 (otherwise it isn't needed).
4827 
4828   const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
4829   const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
4830   const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
4831   const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
4832 
4833   ConstantRange TrueRange =
4834       this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth);
4835   ConstantRange FalseRange =
4836       this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth);
4837 
4838   return TrueRange.unionWith(FalseRange);
4839 }
4840 
4841 SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
4842   if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
4843   const BinaryOperator *BinOp = cast<BinaryOperator>(V);
4844 
4845   // Return early if there are no flags to propagate to the SCEV.
4846   SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
4847   if (BinOp->hasNoUnsignedWrap())
4848     Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
4849   if (BinOp->hasNoSignedWrap())
4850     Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
4851   if (Flags == SCEV::FlagAnyWrap)
4852     return SCEV::FlagAnyWrap;
4853 
4854   return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
4855 }
4856 
4857 bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
4858   // Here we check that I is in the header of the innermost loop containing I,
4859   // since we only deal with instructions in the loop header. The actual loop we
4860   // need to check later will come from an add recurrence, but getting that
4861   // requires computing the SCEV of the operands, which can be expensive. This
4862   // check we can do cheaply to rule out some cases early.
4863   Loop *InnermostContainingLoop = LI.getLoopFor(I->getParent());
4864   if (InnermostContainingLoop == nullptr ||
4865       InnermostContainingLoop->getHeader() != I->getParent())
4866     return false;
4867 
4868   // Only proceed if we can prove that I does not yield poison.
4869   if (!isKnownNotFullPoison(I)) return false;
4870 
4871   // At this point we know that if I is executed, then it does not wrap
4872   // according to at least one of NSW or NUW. If I is not executed, then we do
4873   // not know if the calculation that I represents would wrap. Multiple
4874   // instructions can map to the same SCEV. If we apply NSW or NUW from I to
4875   // the SCEV, we must guarantee no wrapping for that SCEV also when it is
4876   // derived from other instructions that map to the same SCEV. We cannot make
4877   // that guarantee for cases where I is not executed. So we need to find the
4878   // loop that I is considered in relation to and prove that I is executed for
4879   // every iteration of that loop. That implies that the value that I
4880   // calculates does not wrap anywhere in the loop, so then we can apply the
4881   // flags to the SCEV.
4882   //
4883   // We check isLoopInvariant to disambiguate in case we are adding recurrences
4884   // from different loops, so that we know which loop to prove that I is
4885   // executed in.
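       // For example, given hypothetical IR such as
       //   %iv = phi i32 [ 0, %preheader ], [ %iv.next, %header ]
       //   %iv.next = add nsw i32 %iv, 1
       // if %iv.next is known not to yield poison and executes on every
       // iteration of its loop, the nsw flag may be applied to the {0,+,1}
       // add recurrence.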
4886   for (unsigned OpIndex = 0; OpIndex < I->getNumOperands(); ++OpIndex) {
4887     // I could be an extractvalue from a call to an overflow intrinsic.
4888     // TODO: We can do better here in some cases.
4889     if (!isSCEVable(I->getOperand(OpIndex)->getType()))
4890       return false;
4891     const SCEV *Op = getSCEV(I->getOperand(OpIndex));
4892     if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
4893       bool AllOtherOpsLoopInvariant = true;
4894       for (unsigned OtherOpIndex = 0; OtherOpIndex < I->getNumOperands();
4895            ++OtherOpIndex) {
4896         if (OtherOpIndex != OpIndex) {
4897           const SCEV *OtherOp = getSCEV(I->getOperand(OtherOpIndex));
4898           if (!isLoopInvariant(OtherOp, AddRec->getLoop())) {
4899             AllOtherOpsLoopInvariant = false;
4900             break;
4901           }
4902         }
4903       }
4904       if (AllOtherOpsLoopInvariant &&
4905           isGuaranteedToExecuteForEveryIteration(I, AddRec->getLoop()))
4906         return true;
4907     }
4908   }
4909   return false;
4910 }
4911 
4912 bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
4913   // If we know that \c I can never be poison period, then that's enough.
4914   if (isSCEVExprNeverPoison(I))
4915     return true;
4916 
4917   // For an add recurrence specifically, we assume that infinite loops without
4918   // side effects are undefined behavior, and then reason as follows:
4919   //
4920   // If the add recurrence is poison in any iteration, it is poison on all
4921   // future iterations (since incrementing poison yields poison). If the result
4922   // of the add recurrence is fed into the loop latch condition and the loop
4923   // does not contain any throws or exiting blocks other than the latch, we now
4924   // have the ability to "choose" whether the backedge is taken or not (by
4925   // choosing a sufficiently evil value for the poison feeding into the branch)
4926   // for every iteration including and after the one in which \p I first became
4927   // poison.  There are two possibilities (let's call the iteration in which \p
4928   // I first became poison as K):
4929   //
4930   //  1. In the set of iterations including and after K, the loop body executes
4931   //     no side effects.  In this case executing the backedge an infinite number
4932   //     of times will yield undefined behavior.
4933   //
4934   //  2. In the set of iterations including and after K, the loop body executes
4935   //     at least one side effect.  In this case, that specific instance of side
4936   //     effect is control dependent on poison, which also yields undefined
4937   //     behavior.
4938 
4939   auto *ExitingBB = L->getExitingBlock();
4940   auto *LatchBB = L->getLoopLatch();
4941   if (!ExitingBB || !LatchBB || ExitingBB != LatchBB)
4942     return false;
4943 
4944   SmallPtrSet<const Instruction *, 16> Pushed;
4945   SmallVector<const Instruction *, 8> PoisonStack;
4946 
4947   // We start by assuming \c I, the post-inc add recurrence, is poison.  Only
4948   // things that are known to be fully poison under that assumption go on the
4949   // PoisonStack.
4950   Pushed.insert(I);
4951   PoisonStack.push_back(I);
4952 
4953   bool LatchControlDependentOnPoison = false;
4954   while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
4955     const Instruction *Poison = PoisonStack.pop_back_val();
4956 
4957     for (auto *PoisonUser : Poison->users()) {
4958       if (propagatesFullPoison(cast<Instruction>(PoisonUser))) {
4959         if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
4960           PoisonStack.push_back(cast<Instruction>(PoisonUser));
4961       } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
4962         assert(BI->isConditional() && "Only possibility!");
4963         if (BI->getParent() == LatchBB) {
4964           LatchControlDependentOnPoison = true;
4965           break;
4966         }
4967       }
4968     }
4969   }
4970 
4971   return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);
4972 }
4973 
4974 ScalarEvolution::LoopProperties
4975 ScalarEvolution::getLoopProperties(const Loop *L) {
4976   typedef ScalarEvolution::LoopProperties LoopProperties;
4977 
4978   auto Itr = LoopPropertiesCache.find(L);
4979   if (Itr == LoopPropertiesCache.end()) {
4980     auto HasSideEffects = [](Instruction *I) {
4981       if (auto *SI = dyn_cast<StoreInst>(I))
4982         return !SI->isSimple();
4983 
4984       return I->mayHaveSideEffects();
4985     };
4986 
4987     LoopProperties LP = {/* HasNoAbnormalExits */ true,
4988                          /*HasNoSideEffects*/ true};
4989 
4990     for (auto *BB : L->getBlocks())
4991       for (auto &I : *BB) {
4992         if (!isGuaranteedToTransferExecutionToSuccessor(&I))
4993           LP.HasNoAbnormalExits = false;
4994         if (HasSideEffects(&I))
4995           LP.HasNoSideEffects = false;
4996         if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
4997           break; // We're already as pessimistic as we can get.
4998       }
4999 
5000     auto InsertPair = LoopPropertiesCache.insert({L, LP});
5001     assert(InsertPair.second && "We just checked!");
5002     Itr = InsertPair.first;
5003   }
5004 
5005   return Itr->second;
5006 }
5007 
5008 const SCEV *ScalarEvolution::createSCEV(Value *V) {
5009   if (!isSCEVable(V->getType()))
5010     return getUnknown(V);
5011 
5012   if (Instruction *I = dyn_cast<Instruction>(V)) {
5013     // Don't attempt to analyze instructions in blocks that aren't
5014     // reachable. Such instructions don't matter, and they aren't required
5015     // to obey basic rules for definitions dominating uses which this
5016     // analysis depends on.
5017     if (!DT.isReachableFromEntry(I->getParent()))
5018       return getUnknown(V);
5019   } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
5020     return getConstant(CI);
5021   else if (isa<ConstantPointerNull>(V))
5022     return getZero(V->getType());
5023   else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
5024     return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee());
5025   else if (!isa<ConstantExpr>(V))
5026     return getUnknown(V);
5027 
5028   Operator *U = cast<Operator>(V);
5029   if (auto BO = MatchBinaryOp(U, DT)) {
5030     switch (BO->Opcode) {
5031     case Instruction::Add: {
5032       // The simple thing to do would be to just call getSCEV on both operands
5033       // and call getAddExpr with the result. However if we're looking at a
5034       // bunch of things all added together, this can be quite inefficient,
5035       // because it leads to N-1 getAddExpr calls for N ultimate operands.
5036       // Instead, gather up all the operands and make a single getAddExpr call.
5037       // LLVM IR canonical form means we need only traverse the left operands.
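           // For instance, for ((a + b) + c) + d we walk down the chain of left
           // operands, collecting d, c, b and finally a, and then make a single
           // getAddExpr call on all four operands.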
5038       SmallVector<const SCEV *, 4> AddOps;
5039       do {
5040         if (BO->Op) {
5041           if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
5042             AddOps.push_back(OpSCEV);
5043             break;
5044           }
5045 
5046           // If a NUW or NSW flag can be applied to the SCEV for this
5047           // addition, then compute the SCEV for this addition by itself
5048           // with a separate call to getAddExpr. We need to do that
5049           // instead of pushing the operands of the addition onto AddOps,
5050           // since the flags are only known to apply to this particular
5051           // addition - they may not apply to other additions that can be
5052           // formed with operands from AddOps.
5053           const SCEV *RHS = getSCEV(BO->RHS);
5054           SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
5055           if (Flags != SCEV::FlagAnyWrap) {
5056             const SCEV *LHS = getSCEV(BO->LHS);
5057             if (BO->Opcode == Instruction::Sub)
5058               AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
5059             else
5060               AddOps.push_back(getAddExpr(LHS, RHS, Flags));
5061             break;
5062           }
5063         }
5064 
5065         if (BO->Opcode == Instruction::Sub)
5066           AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
5067         else
5068           AddOps.push_back(getSCEV(BO->RHS));
5069 
5070         auto NewBO = MatchBinaryOp(BO->LHS, DT);
5071         if (!NewBO || (NewBO->Opcode != Instruction::Add &&
5072                        NewBO->Opcode != Instruction::Sub)) {
5073           AddOps.push_back(getSCEV(BO->LHS));
5074           break;
5075         }
5076         BO = NewBO;
5077       } while (true);
5078 
5079       return getAddExpr(AddOps);
5080     }
5081 
5082     case Instruction::Mul: {
5083       SmallVector<const SCEV *, 4> MulOps;
5084       do {
5085         if (BO->Op) {
5086           if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
5087             MulOps.push_back(OpSCEV);
5088             break;
5089           }
5090 
5091           SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
5092           if (Flags != SCEV::FlagAnyWrap) {
5093             MulOps.push_back(
5094                 getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags));
5095             break;
5096           }
5097         }
5098 
5099         MulOps.push_back(getSCEV(BO->RHS));
5100         auto NewBO = MatchBinaryOp(BO->LHS, DT);
5101         if (!NewBO || NewBO->Opcode != Instruction::Mul) {
5102           MulOps.push_back(getSCEV(BO->LHS));
5103           break;
5104         }
5105         BO = NewBO;
5106       } while (true);
5107 
5108       return getMulExpr(MulOps);
5109     }
5110     case Instruction::UDiv:
5111       return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
5112     case Instruction::Sub: {
5113       SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5114       if (BO->Op)
5115         Flags = getNoWrapFlagsFromUB(BO->Op);
5116       return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags);
5117     }
5118     case Instruction::And:
5119       // For an expression like x&255 that merely masks off the high bits,
5120       // use zext(trunc(x)) as the SCEV expression.
5121       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
5122         if (CI->isNullValue())
5123           return getSCEV(BO->RHS);
5124         if (CI->isAllOnesValue())
5125           return getSCEV(BO->LHS);
5126         const APInt &A = CI->getValue();
5127 
5128         // Instcombine's ShrinkDemandedConstant may strip bits out of
5129         // constants, obscuring what would otherwise be a low-bits mask.
5130         // Use computeKnownBits to compute what ShrinkDemandedConstant
5131         // knew about to reconstruct a low-bits mask value.
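             // For example, for an i32 "x & 0xF0" (LZ == 24, TZ == 4), the
             // expression built below is (zext i4 (trunc (x /u 16) to i4) to
             // i32) * 16.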
5132         unsigned LZ = A.countLeadingZeros();
5133         unsigned TZ = A.countTrailingZeros();
5134         unsigned BitWidth = A.getBitWidth();
5135         APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
5136         computeKnownBits(BO->LHS, KnownZero, KnownOne, getDataLayout(),
5137                          0, &AC, nullptr, &DT);
5138 
5139         APInt EffectiveMask =
5140             APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
5141         if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) {
5142           const SCEV *MulCount = getConstant(ConstantInt::get(
5143               getContext(), APInt::getOneBitSet(BitWidth, TZ)));
5144           return getMulExpr(
5145               getZeroExtendExpr(
5146                   getTruncateExpr(
5147                       getUDivExactExpr(getSCEV(BO->LHS), MulCount),
5148                       IntegerType::get(getContext(), BitWidth - LZ - TZ)),
5149                   BO->LHS->getType()),
5150               MulCount);
5151         }
5152       }
5153       break;
5154 
5155     case Instruction::Or:
5156       // If the RHS of the Or is a constant, we may have something like:
5157       // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
5158       // optimizations will transparently handle this case.
5159       //
5160       // In order for this transformation to be safe, the LHS must be of the
5161       // form X*(2^n) and the Or constant must be less than 2^n.
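           // Concretely, (X*4) | 1 behaves exactly like (X*4) + 1 because the
           // constant only occupies bit positions known to be zero in X*4.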
5162       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
5163         const SCEV *LHS = getSCEV(BO->LHS);
5164         const APInt &CIVal = CI->getValue();
5165         if (GetMinTrailingZeros(LHS) >=
5166             (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
5167           // Build a plain add SCEV.
5168           const SCEV *S = getAddExpr(LHS, getSCEV(CI));
5169           // If the LHS of the add was an addrec and it has no-wrap flags,
5170           // transfer the no-wrap flags, since an or won't introduce a wrap.
5171           if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
5172             const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
5173             const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags(
5174                 OldAR->getNoWrapFlags());
5175           }
5176           return S;
5177         }
5178       }
5179       break;
5180 
5181     case Instruction::Xor:
5182       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
5183         // If the RHS of xor is -1, then this is a not operation.
5184         if (CI->isAllOnesValue())
5185           return getNotSCEV(getSCEV(BO->LHS));
5186 
5187         // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
5188         // This is a variant of the check for xor with -1, and it handles
5189         // the case where instcombine has trimmed non-demanded bits out
5190         // of an xor with -1.
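        // For example, in (x & 255) ^ 255 the inner and is modeled as
        // zext(trunc(x)), so the whole expression becomes zext(~trunc(x)).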
5191         if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
5192           if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
5193             if (LBO->getOpcode() == Instruction::And &&
5194                 LCI->getValue() == CI->getValue())
5195               if (const SCEVZeroExtendExpr *Z =
5196                       dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
5197                 Type *UTy = BO->LHS->getType();
5198                 const SCEV *Z0 = Z->getOperand();
5199                 Type *Z0Ty = Z0->getType();
5200                 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
5201 
5202                 // If C is a low-bits mask, the zero extend is serving to
5203                 // mask off the high bits. Complement the operand and
5204                 // re-apply the zext.
5205                 if (APIntOps::isMask(Z0TySize, CI->getValue()))
5206                   return getZeroExtendExpr(getNotSCEV(Z0), UTy);
5207 
5208                 // If C is a single bit, it may be in the sign-bit position
5209                 // before the zero-extend. In this case, represent the xor
5210                 // using an add, which is equivalent, and re-apply the zext.
5211                 APInt Trunc = CI->getValue().trunc(Z0TySize);
5212                 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
5213                     Trunc.isSignBit())
5214                   return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
5215                                            UTy);
5216               }
5217       }
5218       break;
5219 
5220     case Instruction::Shl:
5221       // Turn shift left of a constant amount into a multiply.
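      // For example, (X << 3) is represented as X * 8.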
5222       if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
5223         uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
5224 
5225         // If the shift count is not less than the bitwidth, the result of
5226         // the shift is undefined. Don't try to analyze it, because the
5227         // resolution chosen here may differ from the resolution chosen in
5228         // other parts of the compiler.
5229         if (SA->getValue().uge(BitWidth))
5230           break;
5231 
5232         // It is currently not resolved how to interpret NSW for left
5233         // shift by BitWidth - 1, so we avoid applying flags in that
5234         // case. Remove this check (or this comment) once the situation
5235         // is resolved. See
5236         // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html
5237         // and http://reviews.llvm.org/D8890 .
5238         auto Flags = SCEV::FlagAnyWrap;
5239         if (BO->Op && SA->getValue().ult(BitWidth - 1))
5240           Flags = getNoWrapFlagsFromUB(BO->Op);
5241 
5242         Constant *X = ConstantInt::get(getContext(),
5243           APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5244         return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags);
5245       }
5246       break;
5247 
5248     case Instruction::AShr:
5249       // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
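      // For example, with i32 operands, (ashr (shl X, 24), 24) becomes
      // sext(trunc(X to i8) to i32).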
5250       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS))
5251         if (Operator *L = dyn_cast<Operator>(BO->LHS))
5252           if (L->getOpcode() == Instruction::Shl &&
5253               L->getOperand(1) == BO->RHS) {
5254             uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType());
5255 
5256             // If the shift count is not less than the bitwidth, the result of
5257             // the shift is undefined. Don't try to analyze it, because the
5258             // resolution chosen here may differ from the resolution chosen in
5259             // other parts of the compiler.
5260             if (CI->getValue().uge(BitWidth))
5261               break;
5262 
5263             uint64_t Amt = BitWidth - CI->getZExtValue();
5264             if (Amt == BitWidth)
5265               return getSCEV(L->getOperand(0)); // shift by zero --> noop
5266             return getSignExtendExpr(
5267                 getTruncateExpr(getSCEV(L->getOperand(0)),
5268                                 IntegerType::get(getContext(), Amt)),
5269                 BO->LHS->getType());
5270           }
5271       break;
5272     }
5273   }
5274 
5275   switch (U->getOpcode()) {
5276   case Instruction::Trunc:
5277     return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
5278 
5279   case Instruction::ZExt:
5280     return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
5281 
5282   case Instruction::SExt:
5283     return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
5284 
5285   case Instruction::BitCast:
5286     // BitCasts are no-op casts so we just eliminate the cast.
5287     if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
5288       return getSCEV(U->getOperand(0));
5289     break;
5290 
5291   // It's tempting to handle inttoptr and ptrtoint as no-ops; however, this can
5292   // lead to pointer expressions which cannot safely be expanded to GEPs,
5293   // because ScalarEvolution doesn't respect the GEP aliasing rules when
5294   // simplifying integer expressions.
5295 
5296   case Instruction::GetElementPtr:
5297     return createNodeForGEP(cast<GEPOperator>(U));
5298 
5299   case Instruction::PHI:
5300     return createNodeForPHI(cast<PHINode>(U));
5301 
5302   case Instruction::Select:
5303     // U can also be a select constant expr, which we let fall through.  Since
5304     // createNodeForSelectOrPHI only works for a condition that is an `ICmpInst`,
5305     // and constant expressions cannot have instructions as operands, we'd have
5306     // returned getUnknown for a select constant expression anyway.
5307     if (isa<Instruction>(U))
5308       return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0),
5309                                       U->getOperand(1), U->getOperand(2));
5310     break;
5311 
5312   case Instruction::Call:
5313   case Instruction::Invoke:
5314     if (Value *RV = CallSite(U).getReturnedArgOperand())
5315       return getSCEV(RV);
5316     break;
5317   }
5318 
5319   return getUnknown(V);
5320 }
5321 
5322 
5323 
5324 //===----------------------------------------------------------------------===//
5325 //                   Iteration Count Computation Code
5326 //
5327 
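/// Return the trip count (the given exit count plus one) as an unsigned value,
/// or 0 if the count is unknown or does not fit in 32 bits; e.g. an exit count
/// of 9 yields a trip count of 10.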
5328 static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
5329   if (!ExitCount)
5330     return 0;
5331 
5332   ConstantInt *ExitConst = ExitCount->getValue();
5333 
5334   // Guard against huge trip counts.
5335   if (ExitConst->getValue().getActiveBits() > 32)
5336     return 0;
5337 
5338   // In case of integer overflow, this returns 0, which is correct.
5339   return ((unsigned)ExitConst->getZExtValue()) + 1;
5340 }
5341 
5342 unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) {
5343   if (BasicBlock *ExitingBB = L->getExitingBlock())
5344     return getSmallConstantTripCount(L, ExitingBB);
5345 
5346   // No trip count information for multiple exits.
5347   return 0;
5348 }
5349 
5350 unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L,
5351                                                     BasicBlock *ExitingBlock) {
5352   assert(ExitingBlock && "Must pass a non-null exiting block!");
5353   assert(L->isLoopExiting(ExitingBlock) &&
5354          "Exiting block must actually branch out of the loop!");
5355   const SCEVConstant *ExitCount =
5356       dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
5357   return getConstantTripCount(ExitCount);
5358 }
5359 
5360 unsigned ScalarEvolution::getSmallConstantMaxTripCount(Loop *L) {
5361   const auto *MaxExitCount =
5362       dyn_cast<SCEVConstant>(getMaxBackedgeTakenCount(L));
5363   return getConstantTripCount(MaxExitCount);
5364 }
5365 
5366 unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) {
5367   if (BasicBlock *ExitingBB = L->getExitingBlock())
5368     return getSmallConstantTripMultiple(L, ExitingBB);
5369 
5370   // No trip multiple information for multiple exits.
5371   return 0;
5372 }
5373 
5374 /// Returns the largest constant divisor of the trip count of this loop as a
5375 /// normal unsigned value, if possible. This means that the actual trip count is
5376 /// always a multiple of the returned value (don't forget the trip count could
5377 /// very well be zero as well!).
5378 ///
5379 /// Returns 1 if the trip count is unknown or not guaranteed to be a
5380 /// multiple of a constant (which is also the case if the trip count is simply
5381 /// constant; use getSmallConstantTripCount for that case). Also returns 1
5382 /// if the trip count is very large (>= 2^32).
5383 ///
5384 /// As explained in the comments for getSmallConstantTripCount, this assumes
5385 /// that control exits the loop via ExitingBlock.
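///
/// For example, a backedge-taken count of (4 * %n) - 1 gives a trip count of
/// 4 * %n, whose largest known constant divisor is 4.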
5386 unsigned
5387 ScalarEvolution::getSmallConstantTripMultiple(Loop *L,
5388                                               BasicBlock *ExitingBlock) {
5389   assert(ExitingBlock && "Must pass a non-null exiting block!");
5390   assert(L->isLoopExiting(ExitingBlock) &&
5391          "Exiting block must actually branch out of the loop!");
5392   const SCEV *ExitCount = getExitCount(L, ExitingBlock);
5393   if (ExitCount == getCouldNotCompute())
5394     return 1;
5395 
5396   // Get the trip count from the BE count by adding 1.
5397   const SCEV *TCMul = getAddExpr(ExitCount, getOne(ExitCount->getType()));
5398   // FIXME: SCEV distributes (V1 + V2) * C1 as V1*C1 + V2*C1, so the constant
5399   // factor may be hidden inside an add; we could try to factor simple cases.
5400   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(TCMul))
5401     TCMul = Mul->getOperand(0);
5402 
5403   const SCEVConstant *MulC = dyn_cast<SCEVConstant>(TCMul);
5404   if (!MulC)
5405     return 1;
5406 
5407   ConstantInt *Result = MulC->getValue();
5408 
5409   // Guard against huge trip counts (this requires checking
5410   // for zero to handle the case where the trip count == -1 and the
5411   // addition wraps).
5412   if (!Result || Result->getValue().getActiveBits() > 32 ||
5413       Result->getValue().getActiveBits() == 0)
5414     return 1;
5415 
5416   return (unsigned)Result->getZExtValue();
5417 }
5418 
5419 /// Get the expression for the number of loop iterations for which this loop is
5420 /// guaranteed not to exit via ExitingBlock. Otherwise return
5421 /// SCEVCouldNotCompute.
5422 const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) {
5423   return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
5424 }
5425 
5426 const SCEV *
5427 ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
5428                                                  SCEVUnionPredicate &Preds) {
5429   return getPredicatedBackedgeTakenInfo(L).getExact(this, &Preds);
5430 }
5431 
5432 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
5433   return getBackedgeTakenInfo(L).getExact(this);
5434 }
5435 
5436 /// Similar to getBackedgeTakenCount, except return the least SCEV value that is
5437 /// known never to be less than the actual backedge taken count.
5438 const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
5439   return getBackedgeTakenInfo(L).getMax(this);
5440 }
5441 
5442 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
5443   return getBackedgeTakenInfo(L).isMaxOrZero(this);
5444 }
5445 
5446 /// Push PHI nodes in the header of the given loop onto the given Worklist.
5447 static void
5448 PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
5449   BasicBlock *Header = L->getHeader();
5450 
5451   // Push all Loop-header PHIs onto the Worklist stack.
5452   for (BasicBlock::iterator I = Header->begin();
5453        PHINode *PN = dyn_cast<PHINode>(I); ++I)
5454     Worklist.push_back(PN);
5455 }
5456 
5457 const ScalarEvolution::BackedgeTakenInfo &
5458 ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
5459   auto &BTI = getBackedgeTakenInfo(L);
5460   if (BTI.hasFullInfo())
5461     return BTI;
5462 
5463   auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
5464 
5465   if (!Pair.second)
5466     return Pair.first->second;
5467 
5468   BackedgeTakenInfo Result =
5469       computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
5470 
5471   return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
5472 }
5473 
5474 const ScalarEvolution::BackedgeTakenInfo &
5475 ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
5476   // Initially insert an invalid entry for this loop. If the insertion
5477   // succeeds, proceed to actually compute a backedge-taken count and
5478   // update the value. The temporary CouldNotCompute value tells SCEV
5479   // code elsewhere that it shouldn't attempt to request a new
5480   // backedge-taken count, which could result in infinite recursion.
5481   std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
5482       BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
5483   if (!Pair.second)
5484     return Pair.first->second;
5485 
5486   // computeBackedgeTakenCount may allocate memory for its result. Inserting it
5487   // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
5488   // must be cleared in this scope.
5489   BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
5490 
5491   if (Result.getExact(this) != getCouldNotCompute()) {
5492     assert(isLoopInvariant(Result.getExact(this), L) &&
5493            isLoopInvariant(Result.getMax(this), L) &&
5494            "Computed backedge-taken count isn't loop invariant for loop!");
5495     ++NumTripCountsComputed;
5496   }
5497   else if (Result.getMax(this) == getCouldNotCompute() &&
5498            isa<PHINode>(L->getHeader()->begin())) {
5499     // Only count loops that have phi nodes as not being computable.
5500     ++NumTripCountsNotComputed;
5501   }
5502 
5503   // Now that we know more about the trip count for this loop, forget any
5504   // existing SCEV values for PHI nodes in this loop since they are only
5505   // conservative estimates made without the benefit of trip count
5506   // information. This is similar to the code in forgetLoop, except that
5507   // it handles SCEVUnknown PHI nodes specially.
5508   if (Result.hasAnyInfo()) {
5509     SmallVector<Instruction *, 16> Worklist;
5510     PushLoopPHIs(L, Worklist);
5511 
5512     SmallPtrSet<Instruction *, 8> Visited;
5513     while (!Worklist.empty()) {
5514       Instruction *I = Worklist.pop_back_val();
5515       if (!Visited.insert(I).second)
5516         continue;
5517 
5518       ValueExprMapType::iterator It =
5519         ValueExprMap.find_as(static_cast<Value *>(I));
5520       if (It != ValueExprMap.end()) {
5521         const SCEV *Old = It->second;
5522 
5523         // SCEVUnknown for a PHI either means that it has an unrecognized
5524         // structure, or it's a PHI that's in the process of being computed
5525         // by createNodeForPHI.  In the former case, additional loop trip
5526         // count information isn't going to change anything. In the latter
5527         // case, createNodeForPHI will perform the necessary updates on its
5528         // own when it gets to that point.
5529         if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
5530           eraseValueFromMap(It->first);
5531           forgetMemoizedResults(Old);
5532         }
5533         if (PHINode *PN = dyn_cast<PHINode>(I))
5534           ConstantEvolutionLoopExitValue.erase(PN);
5535       }
5536 
5537       PushDefUseChildren(I, Worklist);
5538     }
5539   }
5540 
5541   // Re-lookup the insert position, since the call to
5542   // computeBackedgeTakenCount above could result in a
5543   // recursive call to getBackedgeTakenInfo (on a different
5544   // loop), which would invalidate the iterator computed
5545   // earlier.
5546   return BackedgeTakenCounts.find(L)->second = std::move(Result);
5547 }
5548 
5549 void ScalarEvolution::forgetLoop(const Loop *L) {
5550   // Drop any stored trip count value.
5551   auto RemoveLoopFromBackedgeMap =
5552       [L](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
5553         auto BTCPos = Map.find(L);
5554         if (BTCPos != Map.end()) {
5555           BTCPos->second.clear();
5556           Map.erase(BTCPos);
5557         }
5558       };
5559 
5560   RemoveLoopFromBackedgeMap(BackedgeTakenCounts);
5561   RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts);
5562 
5563   // Drop information about expressions based on loop-header PHIs.
5564   SmallVector<Instruction *, 16> Worklist;
5565   PushLoopPHIs(L, Worklist);
5566 
5567   SmallPtrSet<Instruction *, 8> Visited;
5568   while (!Worklist.empty()) {
5569     Instruction *I = Worklist.pop_back_val();
5570     if (!Visited.insert(I).second)
5571       continue;
5572 
5573     ValueExprMapType::iterator It =
5574       ValueExprMap.find_as(static_cast<Value *>(I));
5575     if (It != ValueExprMap.end()) {
5576       eraseValueFromMap(It->first);
5577       forgetMemoizedResults(It->second);
5578       if (PHINode *PN = dyn_cast<PHINode>(I))
5579         ConstantEvolutionLoopExitValue.erase(PN);
5580     }
5581 
5582     PushDefUseChildren(I, Worklist);
5583   }
5584 
5585   // Forget all contained loops too, to avoid dangling entries in the
5586   // ValuesAtScopes map.
5587   for (Loop *I : *L)
5588     forgetLoop(I);
5589 
5590   LoopPropertiesCache.erase(L);
5591 }
5592 
5593 void ScalarEvolution::forgetValue(Value *V) {
5594   Instruction *I = dyn_cast<Instruction>(V);
5595   if (!I) return;
5596 
5597   // Drop information about expressions based on loop-header PHIs.
5598   SmallVector<Instruction *, 16> Worklist;
5599   Worklist.push_back(I);
5600 
5601   SmallPtrSet<Instruction *, 8> Visited;
5602   while (!Worklist.empty()) {
5603     I = Worklist.pop_back_val();
5604     if (!Visited.insert(I).second)
5605       continue;
5606 
5607     ValueExprMapType::iterator It =
5608       ValueExprMap.find_as(static_cast<Value *>(I));
5609     if (It != ValueExprMap.end()) {
5610       eraseValueFromMap(It->first);
5611       forgetMemoizedResults(It->second);
5612       if (PHINode *PN = dyn_cast<PHINode>(I))
5613         ConstantEvolutionLoopExitValue.erase(PN);
5614     }
5615 
5616     PushDefUseChildren(I, Worklist);
5617   }
5618 }
5619 
5620 /// Get the exact loop backedge taken count considering all loop exits. A
5621 /// computable result can only be returned for loops with a single exit.
5622 /// Returning the minimum taken count among all exits is incorrect because one
5623 /// of the loop's exit limits may have been skipped. howFarToZero assumes that
5624 /// the limit of each loop test is never skipped. This is a valid assumption as
5625 /// long as the loop exits via that test. For precise results, it is the
5626 /// caller's responsibility to specify the relevant loop exit using
5627 /// getExact(ExitingBlock, SE).
5628 const SCEV *
5629 ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE,
5630                                              SCEVUnionPredicate *Preds) const {
5631   // If any exits were not computable, the loop is not computable.
5632   if (!isComplete() || ExitNotTaken.empty())
5633     return SE->getCouldNotCompute();
5634 
5635   const SCEV *BECount = nullptr;
5636   for (auto &ENT : ExitNotTaken) {
5637     assert(ENT.ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV");
5638 
5639     if (!BECount)
5640       BECount = ENT.ExactNotTaken;
5641     else if (BECount != ENT.ExactNotTaken)
5642       return SE->getCouldNotCompute();
5643     if (Preds && !ENT.hasAlwaysTruePredicate())
5644       Preds->add(ENT.Predicate.get());
5645 
5646     assert((Preds || ENT.hasAlwaysTruePredicate()) &&
5647            "Predicate should be always true!");
5648   }
5649 
5650   assert(BECount && "Invalid not taken count for loop exit");
5651   return BECount;
5652 }
5653 
5654 /// Get the exact not taken count for this loop exit.
5655 const SCEV *
5656 ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
5657                                              ScalarEvolution *SE) const {
5658   for (auto &ENT : ExitNotTaken)
5659     if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
5660       return ENT.ExactNotTaken;
5661 
5662   return SE->getCouldNotCompute();
5663 }
5664 
5665 /// Get the max backedge taken count for the loop.
5666 const SCEV *
5667 ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
5668   auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
5669     return !ENT.hasAlwaysTruePredicate();
5670   };
5671 
5672   if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getMax())
5673     return SE->getCouldNotCompute();
5674 
5675   return getMax();
5676 }
5677 
5678 bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const {
5679   auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
5680     return !ENT.hasAlwaysTruePredicate();
5681   };
5682   return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
5683 }
5684 
5685 bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
5686                                                     ScalarEvolution *SE) const {
5687   if (getMax() && getMax() != SE->getCouldNotCompute() &&
5688       SE->hasOperand(getMax(), S))
5689     return true;
5690 
5691   for (auto &ENT : ExitNotTaken)
5692     if (ENT.ExactNotTaken != SE->getCouldNotCompute() &&
5693         SE->hasOperand(ENT.ExactNotTaken, S))
5694       return true;
5695 
5696   return false;
5697 }
5698 
5699 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
5700 /// computable exit into a persistent ExitNotTakenInfo array.
5701 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
5702     SmallVectorImpl<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo>
5703         &&ExitCounts,
5704     bool Complete, const SCEV *MaxCount, bool MaxOrZero)
5705     : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) {
5706   typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo;
5707   ExitNotTaken.reserve(ExitCounts.size());
5708   std::transform(
5709       ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken),
5710       [&](const EdgeExitInfo &EEI) {
5711         BasicBlock *ExitBB = EEI.first;
5712         const ExitLimit &EL = EEI.second;
5713         if (EL.Predicates.empty())
5714           return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr);
5715 
5716         std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate);
5717         for (auto *Pred : EL.Predicates)
5718           Predicate->add(Pred);
5719 
5720         return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate));
5721       });
5722 }
5723 
5724 /// Invalidate this result and free the ExitNotTakenInfo array.
5725 void ScalarEvolution::BackedgeTakenInfo::clear() {
5726   ExitNotTaken.clear();
5727 }
5728 
5729 /// Compute the number of times the backedge of the specified loop will execute.
5730 ScalarEvolution::BackedgeTakenInfo
5731 ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
5732                                            bool AllowPredicates) {
5733   SmallVector<BasicBlock *, 8> ExitingBlocks;
5734   L->getExitingBlocks(ExitingBlocks);
5735 
5736   typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo;
5737 
5738   SmallVector<EdgeExitInfo, 4> ExitCounts;
5739   bool CouldComputeBECount = true;
5740   BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
5741   const SCEV *MustExitMaxBECount = nullptr;
5742   const SCEV *MayExitMaxBECount = nullptr;
5743   bool MustExitMaxOrZero = false;
5744 
5745   // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
5746   // and compute maxBECount.
5747   // Do a union of all the predicates here.
5748   for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
5749     BasicBlock *ExitBB = ExitingBlocks[i];
5750     ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
5751 
5752     assert((AllowPredicates || EL.Predicates.empty()) &&
5753            "Predicated exit limit when predicates are not allowed!");
5754 
5755     // 1. For each exit that can be computed, add an entry to ExitCounts.
5756     // CouldComputeBECount is true only if all exits can be computed.
5757     if (EL.ExactNotTaken == getCouldNotCompute())
5758       // We couldn't compute an exact value for this exit, so
5759       // we won't be able to compute an exact value for the loop.
5760       CouldComputeBECount = false;
5761     else
5762       ExitCounts.emplace_back(ExitBB, EL);
5763 
5764     // 2. Derive the loop's MaxBECount from each exit's max number of
5765     // non-exiting iterations. Partition the loop exits into two kinds:
5766     // LoopMustExits and LoopMayExits.
5767     //
5768     // If the exit dominates the loop latch, it is a LoopMustExit; otherwise it
5769     // is a LoopMayExit.  If any computable LoopMustExit is found, then
5770     // MaxBECount is the minimum EL.MaxNotTaken of computable
5771     // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
5772     // EL.MaxNotTaken, where CouldNotCompute is considered greater than any
5773     // computable EL.MaxNotTaken.
5774     if (EL.MaxNotTaken != getCouldNotCompute() && Latch &&
5775         DT.dominates(ExitBB, Latch)) {
5776       if (!MustExitMaxBECount) {
5777         MustExitMaxBECount = EL.MaxNotTaken;
5778         MustExitMaxOrZero = EL.MaxOrZero;
5779       } else {
5780         MustExitMaxBECount =
5781             getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken);
5782       }
5783     } else if (MayExitMaxBECount != getCouldNotCompute()) {
5784       if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute())
5785         MayExitMaxBECount = EL.MaxNotTaken;
5786       else {
5787         MayExitMaxBECount =
5788             getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken);
5789       }
5790     }
5791   }
5792   const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
5793     (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
5794   // The loop backedge will be taken the maximum or zero times if there's
5795   // a single exit that must be taken the maximum or zero times.
5796   bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
5797   return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
5798                            MaxBECount, MaxOrZero);
5799 }
5800 
5801 ScalarEvolution::ExitLimit
5802 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
5803                                   bool AllowPredicates) {
5804 
5805   // Okay, we've chosen an exiting block.  See what condition causes us to exit
5806   // at this block and remember the exit block and whether all other targets
5807   // lead to the loop header.
5808   bool MustExecuteLoopHeader = true;
5809   BasicBlock *Exit = nullptr;
5810   for (auto *SBB : successors(ExitingBlock))
5811     if (!L->contains(SBB)) {
5812       if (Exit) // Multiple exit successors.
5813         return getCouldNotCompute();
5814       Exit = SBB;
5815     } else if (SBB != L->getHeader()) {
5816       MustExecuteLoopHeader = false;
5817     }
5818 
5819   // At this point, we know we have a conditional branch that determines whether
5820   // the loop is exited.  However, we don't know if the branch is executed each
5821   // time through the loop.  If not, then the execution count of the branch will
5822   // not be equal to the trip count of the loop.
5823   //
5824   // Currently we check for this by checking to see if the Exit branch goes to
5825   // the loop header.  If so, we know it will always execute the same number of
5826   // times as the loop.  We also handle the case where the exit block *is* the
5827   // loop header.  This is common for un-rotated loops.
5828   //
5829   // If both of those tests fail, walk up the unique predecessor chain to the
5830   // header, stopping if there is an edge that doesn't exit the loop. If the
5831   // header is reached, the execution count of the branch will be equal to the
5832   // trip count of the loop.
5833   //
5834   //  More extensive analysis could be done to handle more cases here.
5835   //
5836   if (!MustExecuteLoopHeader && ExitingBlock != L->getHeader()) {
5837     // The simple checks failed, try climbing the unique predecessor chain
5838     // up to the header.
5839     bool Ok = false;
5840     for (BasicBlock *BB = ExitingBlock; BB; ) {
5841       BasicBlock *Pred = BB->getUniquePredecessor();
5842       if (!Pred)
5843         return getCouldNotCompute();
5844       TerminatorInst *PredTerm = Pred->getTerminator();
5845       for (const BasicBlock *PredSucc : PredTerm->successors()) {
5846         if (PredSucc == BB)
5847           continue;
5848         // If the predecessor has a successor that isn't BB and isn't
5849         // outside the loop, assume the worst.
5850         if (L->contains(PredSucc))
5851           return getCouldNotCompute();
5852       }
5853       if (Pred == L->getHeader()) {
5854         Ok = true;
5855         break;
5856       }
5857       BB = Pred;
5858     }
5859     if (!Ok)
5860       return getCouldNotCompute();
5861   }
5862 
5863   bool IsOnlyExit = (L->getExitingBlock() != nullptr);
5864   TerminatorInst *Term = ExitingBlock->getTerminator();
5865   if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
5866     assert(BI->isConditional() && "If unconditional, it can't be in loop!");
5867     // Proceed to the next level to examine the exit condition expression.
5868     return computeExitLimitFromCond(
5869         L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1),
5870         /*ControlsExit=*/IsOnlyExit, AllowPredicates);
5871   }
5872 
5873   if (SwitchInst *SI = dyn_cast<SwitchInst>(Term))
5874     return computeExitLimitFromSingleExitSwitch(L, SI, Exit,
5875                                                 /*ControlsExit=*/IsOnlyExit);
5876 
5877   return getCouldNotCompute();
5878 }
5879 
5880 ScalarEvolution::ExitLimit
5881 ScalarEvolution::computeExitLimitFromCond(const Loop *L,
5882                                           Value *ExitCond,
5883                                           BasicBlock *TBB,
5884                                           BasicBlock *FBB,
5885                                           bool ControlsExit,
5886                                           bool AllowPredicates) {
5887   // Check if the controlling expression for this loop is an And or Or.
5888   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
5889     if (BO->getOpcode() == Instruction::And) {
5890       // Recurse on the operands of the and.
5891       bool EitherMayExit = L->contains(TBB);
5892       ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB,
5893                                                ControlsExit && !EitherMayExit,
5894                                                AllowPredicates);
5895       ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB,
5896                                                ControlsExit && !EitherMayExit,
5897                                                AllowPredicates);
5898       const SCEV *BECount = getCouldNotCompute();
5899       const SCEV *MaxBECount = getCouldNotCompute();
5900       if (EitherMayExit) {
5901         // Both conditions must be true for the loop to continue executing.
5902         // Choose the less conservative count.
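        // For a branch like 'br (%c0 && %c1), %loop, %exit', the loop exits as
        // soon as either operand becomes false, so the backedge-taken count is
        // the umin of the operands' counts.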
5903         if (EL0.ExactNotTaken == getCouldNotCompute() ||
5904             EL1.ExactNotTaken == getCouldNotCompute())
5905           BECount = getCouldNotCompute();
5906         else
5907           BECount =
5908               getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
5909         if (EL0.MaxNotTaken == getCouldNotCompute())
5910           MaxBECount = EL1.MaxNotTaken;
5911         else if (EL1.MaxNotTaken == getCouldNotCompute())
5912           MaxBECount = EL0.MaxNotTaken;
5913         else
5914           MaxBECount =
5915               getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
5916       } else {
5917         // Both conditions must be true at the same time for the loop to exit.
5918         // For now, be conservative.
5919         assert(L->contains(FBB) && "Loop block has no successor in loop!");
5920         if (EL0.MaxNotTaken == EL1.MaxNotTaken)
5921           MaxBECount = EL0.MaxNotTaken;
5922         if (EL0.ExactNotTaken == EL1.ExactNotTaken)
5923           BECount = EL0.ExactNotTaken;
5924       }
5925 
5926       // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
5927       // to be more aggressive when computing BECount than when computing
5928       // MaxBECount.  In these cases it is possible for EL0.ExactNotTaken and
5929       // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken
5930       // to not.
5931       if (isa<SCEVCouldNotCompute>(MaxBECount) &&
5932           !isa<SCEVCouldNotCompute>(BECount))
5933         MaxBECount = BECount;
5934 
5935       return ExitLimit(BECount, MaxBECount, false,
5936                        {&EL0.Predicates, &EL1.Predicates});
5937     }
5938     if (BO->getOpcode() == Instruction::Or) {
5939       // Recurse on the operands of the or.
5940       bool EitherMayExit = L->contains(FBB);
5941       ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB,
5942                                                ControlsExit && !EitherMayExit,
5943                                                AllowPredicates);
5944       ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB,
5945                                                ControlsExit && !EitherMayExit,
5946                                                AllowPredicates);
5947       const SCEV *BECount = getCouldNotCompute();
5948       const SCEV *MaxBECount = getCouldNotCompute();
5949       if (EitherMayExit) {
5950         // Both conditions must be false for the loop to continue executing.
5951         // Choose the less conservative count.
5952         if (EL0.ExactNotTaken == getCouldNotCompute() ||
5953             EL1.ExactNotTaken == getCouldNotCompute())
5954           BECount = getCouldNotCompute();
5955         else
5956           BECount =
5957               getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
5958         if (EL0.MaxNotTaken == getCouldNotCompute())
5959           MaxBECount = EL1.MaxNotTaken;
5960         else if (EL1.MaxNotTaken == getCouldNotCompute())
5961           MaxBECount = EL0.MaxNotTaken;
5962         else
5963           MaxBECount =
5964               getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
5965       } else {
5966         // Both conditions must be false at the same time for the loop to exit.
5967         // For now, be conservative.
5968         assert(L->contains(TBB) && "Loop block has no successor in loop!");
5969         if (EL0.MaxNotTaken == EL1.MaxNotTaken)
5970           MaxBECount = EL0.MaxNotTaken;
5971         if (EL0.ExactNotTaken == EL1.ExactNotTaken)
5972           BECount = EL0.ExactNotTaken;
5973       }
5974 
5975       return ExitLimit(BECount, MaxBECount, false,
5976                        {&EL0.Predicates, &EL1.Predicates});
5977     }
5978   }
5979 
5980   // With an icmp, it may be feasible to compute an exact backedge-taken count.
5981   // Proceed to the next level to examine the icmp.
5982   if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
5983     ExitLimit EL =
5984         computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit);
5985     if (EL.hasFullInfo() || !AllowPredicates)
5986       return EL;
5987 
5988     // Try again, but use SCEV predicates this time.
5989     return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit,
5990                                     /*AllowPredicates=*/true);
5991   }
5992 
5993   // Check for a constant condition. These are normally stripped out by
5994   // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
5995   // preserve the CFG and is temporarily leaving constant conditions
5996   // in place.
5997   if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
5998     if (L->contains(FBB) == !CI->getZExtValue())
5999       // The backedge is always taken.
6000       return getCouldNotCompute();
6001     else
6002       // The backedge is never taken.
6003       return getZero(CI->getType());
6004   }
6005 
6006   // If it's not an integer or pointer comparison then compute it the hard way.
6007   return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
6008 }
6009 
6010 ScalarEvolution::ExitLimit
6011 ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
6012                                           ICmpInst *ExitCond,
6013                                           BasicBlock *TBB,
6014                                           BasicBlock *FBB,
6015                                           bool ControlsExit,
6016                                           bool AllowPredicates) {
6017 
6018   // If the condition was exit on true, convert the condition to exit on false
6019   ICmpInst::Predicate Cond;
6020   if (!L->contains(FBB))
6021     Cond = ExitCond->getPredicate();
6022   else
6023     Cond = ExitCond->getInversePredicate();
6024 
6025   // Handle common loops like: for (X = "string"; *X; ++X)
6026   if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
6027     if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
6028       ExitLimit ItCnt =
6029         computeLoadConstantCompareExitLimit(LI, RHS, L, Cond);
6030       if (ItCnt.hasAnyInfo())
6031         return ItCnt;
6032     }
6033 
6034   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
6035   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
6036 
6037   // Try to evaluate any dependencies out of the loop.
6038   LHS = getSCEVAtScope(LHS, L);
6039   RHS = getSCEVAtScope(RHS, L);
6040 
6041   // At this point, we would like to compute how many iterations of the
6042   // loop the predicate will return true for these inputs.
6043   if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
6044     // If there is a loop-invariant, force it into the RHS.
6045     std::swap(LHS, RHS);
6046     Cond = ICmpInst::getSwappedPredicate(Cond);
6047   }
6048 
6049   // Simplify the operands before analyzing them.
6050   (void)SimplifyICmpOperands(Cond, LHS, RHS);
6051 
6052   // If we have a comparison of a chrec against a constant, try to use value
6053   // ranges to answer this query.
6054   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
6055     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
6056       if (AddRec->getLoop() == L) {
6057         // Form the constant range.
6058         ConstantRange CompRange =
6059             ConstantRange::makeExactICmpRegion(Cond, RHSC->getAPInt());
6060 
6061         const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
6062         if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
6063       }
6064 
6065   switch (Cond) {
6066   case ICmpInst::ICMP_NE: {                     // while (X != Y)
6067     // Convert to: while (X-Y != 0)
6068     ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit,
6069                                 AllowPredicates);
6070     if (EL.hasAnyInfo()) return EL;
6071     break;
6072   }
6073   case ICmpInst::ICMP_EQ: {                     // while (X == Y)
6074     // Convert to: while (X-Y == 0)
6075     ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
6076     if (EL.hasAnyInfo()) return EL;
6077     break;
6078   }
6079   case ICmpInst::ICMP_SLT:
6080   case ICmpInst::ICMP_ULT: {                    // while (X < Y)
6081     bool IsSigned = Cond == ICmpInst::ICMP_SLT;
6082     ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit,
6083                                     AllowPredicates);
6084     if (EL.hasAnyInfo()) return EL;
6085     break;
6086   }
6087   case ICmpInst::ICMP_SGT:
6088   case ICmpInst::ICMP_UGT: {                    // while (X > Y)
6089     bool IsSigned = Cond == ICmpInst::ICMP_SGT;
6090     ExitLimit EL =
6091         howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit,
6092                             AllowPredicates);
6093     if (EL.hasAnyInfo()) return EL;
6094     break;
6095   }
6096   default:
6097     break;
6098   }
6099 
6100   auto *ExhaustiveCount =
6101       computeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
6102 
6103   if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
6104     return ExhaustiveCount;
6105 
6106   return computeShiftCompareExitLimit(ExitCond->getOperand(0),
6107                                       ExitCond->getOperand(1), L, Cond);
6108 }
6109 
6110 ScalarEvolution::ExitLimit
6111 ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
6112                                                       SwitchInst *Switch,
6113                                                       BasicBlock *ExitingBlock,
6114                                                       bool ControlsExit) {
6115   assert(!L->contains(ExitingBlock) && "Not an exiting block!");
6116 
6117   // Give up if the exit is the default dest of a switch.
6118   if (Switch->getDefaultDest() == ExitingBlock)
6119     return getCouldNotCompute();
6120 
6121   assert(L->contains(Switch->getDefaultDest()) &&
6122          "Default case must not exit the loop!");
6123   const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
6124   const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
6125 
6126   // while (X != Y) --> while (X-Y != 0)
6127   ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
6128   if (EL.hasAnyInfo())
6129     return EL;
6130 
6131   return getCouldNotCompute();
6132 }
6133 
6134 static ConstantInt *
6135 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
6136                                 ScalarEvolution &SE) {
6137   const SCEV *InVal = SE.getConstant(C);
6138   const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
6139   assert(isa<SCEVConstant>(Val) &&
6140          "Evaluation of SCEV at constant didn't fold correctly?");
6141   return cast<SCEVConstant>(Val)->getValue();
6142 }
6143 
6144 /// Given an exit condition of 'icmp op load X, cst', try to see if we can
6145 /// compute the backedge execution count.
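/// For example, this handles loops such as
///   for (i = 0; constant_table[i] != 0; ++i)
/// where constant_table is a small constant global array.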
6146 ScalarEvolution::ExitLimit
6147 ScalarEvolution::computeLoadConstantCompareExitLimit(
6148   LoadInst *LI,
6149   Constant *RHS,
6150   const Loop *L,
6151   ICmpInst::Predicate predicate) {
6152 
6153   if (LI->isVolatile()) return getCouldNotCompute();
6154 
6155   // Check to see if the loaded pointer is a getelementptr of a global.
6156   // TODO: Use SCEV instead of manually grubbing with GEPs.
6157   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
6158   if (!GEP) return getCouldNotCompute();
6159 
6160   // Make sure that it is really a constant global we are gepping, with an
6161   // initializer, and make sure the first IDX is really 0.
6162   GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
6163   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() ||
6164       GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
6165       !cast<Constant>(GEP->getOperand(1))->isNullValue())
6166     return getCouldNotCompute();
6167 
6168   // Okay, we allow one non-constant index into the GEP instruction.
6169   Value *VarIdx = nullptr;
6170   std::vector<Constant*> Indexes;
6171   unsigned VarIdxNum = 0;
6172   for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
6173     if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
6174       Indexes.push_back(CI);
6175     } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
6176       if (VarIdx) return getCouldNotCompute();  // Multiple non-constant idx's.
6177       VarIdx = GEP->getOperand(i);
6178       VarIdxNum = i-2;
6179       Indexes.push_back(nullptr);
6180     }
6181 
6182   // Loop-invariant loads may be a byproduct of loop optimization. Skip them.
6183   if (!VarIdx)
6184     return getCouldNotCompute();
6185 
6186   // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
6187   // Check to see if X is a loop variant variable value now.
6188   const SCEV *Idx = getSCEV(VarIdx);
6189   Idx = getSCEVAtScope(Idx, L);
6190 
6191   // We can only recognize very limited forms of loop index expressions, in
6192   // particular, only affine AddRec's like {C1,+,C2}.
6193   const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
6194   if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) ||
6195       !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
6196       !isa<SCEVConstant>(IdxExpr->getOperand(1)))
6197     return getCouldNotCompute();
6198 
6199   unsigned MaxSteps = MaxBruteForceIterations;
6200   for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
6201     ConstantInt *ItCst = ConstantInt::get(
6202                            cast<IntegerType>(IdxExpr->getType()), IterationNum);
6203     ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
6204 
6205     // Form the GEP offset.
6206     Indexes[VarIdxNum] = Val;
6207 
6208     Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(),
6209                                                          Indexes);
6210     if (!Result) break;  // Cannot compute!
6211 
6212     // Evaluate the condition for this iteration.
6213     Result = ConstantExpr::getICmp(predicate, Result, RHS);
6214     if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
6215     if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
6216       ++NumArrayLenItCounts;
6217       return getConstant(ItCst);   // Found terminating iteration!
6218     }
6219   }
6220   return getCouldNotCompute();
6221 }
6222 
6223 ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
6224     Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
6225   ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
6226   if (!RHS)
6227     return getCouldNotCompute();
6228 
6229   const BasicBlock *Latch = L->getLoopLatch();
6230   if (!Latch)
6231     return getCouldNotCompute();
6232 
6233   const BasicBlock *Predecessor = L->getLoopPredecessor();
6234   if (!Predecessor)
6235     return getCouldNotCompute();
6236 
6237   // Return true if V is of the form "LHS `shift_op` <positive constant>".
6238   // Return LHS in OutLHS and shift_op in OutOpCode.
6239   auto MatchPositiveShift =
6240       [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
6241 
6242     using namespace PatternMatch;
6243 
6244     ConstantInt *ShiftAmt;
6245     if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
6246       OutOpCode = Instruction::LShr;
6247     else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
6248       OutOpCode = Instruction::AShr;
6249     else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
6250       OutOpCode = Instruction::Shl;
6251     else
6252       return false;
6253 
6254     return ShiftAmt->getValue().isStrictlyPositive();
6255   };
6256 
6257   // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
6258   //
6259   // loop:
6260   //   %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
6261   //   %iv.shifted = lshr i32 %iv, <positive constant>
6262   //
6263   // Return true on a successful match.  Return the corresponding PHI node (%iv
6264   // above) in PNOut and the opcode of the shift operation in OpCodeOut.
6265   auto MatchShiftRecurrence =
6266       [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
6267     Optional<Instruction::BinaryOps> PostShiftOpCode;
6268 
6269     {
6270       Instruction::BinaryOps OpC;
6271       Value *V;
6272 
6273       // If we encounter a shift instruction, "peel off" the shift operation,
6274       // and remember that we did so.  Later when we inspect %iv's backedge
6275       // value, we will make sure that the backedge value uses the same
6276       // operation.
6277       //
6278       // Note: the peeled shift operation does not have to be the same
6279       // instruction as the one feeding into the PHI's backedge value.  We only
6280       // really care about it being the same *kind* of shift instruction --
6281       // that's all that is required for our later inferences to hold.
6282       if (MatchPositiveShift(LHS, V, OpC)) {
6283         PostShiftOpCode = OpC;
6284         LHS = V;
6285       }
6286     }
6287 
6288     PNOut = dyn_cast<PHINode>(LHS);
6289     if (!PNOut || PNOut->getParent() != L->getHeader())
6290       return false;
6291 
6292     Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
6293     Value *OpLHS;
6294 
6295     return
6296         // The backedge value for the PHI node must be a shift by a positive
6297         // amount
6298         MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
6299 
6300         // of the PHI node itself
6301         OpLHS == PNOut &&
6302 
6303         // and the kind of shift should match the kind of shift we peeled
6304         // off, if any.
6305         (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut);
6306   };
6307 
6308   PHINode *PN;
6309   Instruction::BinaryOps OpCode;
6310   if (!MatchShiftRecurrence(LHS, PN, OpCode))
6311     return getCouldNotCompute();
6312 
6313   const DataLayout &DL = getDataLayout();
6314 
6315   // The key rationale for this optimization is that for some kinds of shift
6316   // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
6317   // within a finite number of iterations.  If the condition guarding the
6318   // backedge (in the sense that the backedge is taken if the condition is true)
6319   // is false for the value the shift recurrence stabilizes to, then we know
6320   // that the backedge is taken only a finite number of times.
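  // For example, {K,lshr,1} stabilizes to 0 within bitwidth(K) iterations, so
  // if the backedge is guarded by (%iv != 0) it can be taken at most
  // bitwidth(K) times.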
6321 
6322   ConstantInt *StableValue = nullptr;
6323   switch (OpCode) {
6324   default:
6325     llvm_unreachable("Impossible case!");
6326 
6327   case Instruction::AShr: {
6328     // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
6329     // bitwidth(K) iterations.
6330     Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
6331     bool KnownZero, KnownOne;
6332     ComputeSignBit(FirstValue, KnownZero, KnownOne, DL, 0, nullptr,
6333                    Predecessor->getTerminator(), &DT);
6334     auto *Ty = cast<IntegerType>(RHS->getType());
6335     if (KnownZero)
6336       StableValue = ConstantInt::get(Ty, 0);
6337     else if (KnownOne)
6338       StableValue = ConstantInt::get(Ty, -1, true);
6339     else
6340       return getCouldNotCompute();
6341 
6342     break;
6343   }
6344   case Instruction::LShr:
6345   case Instruction::Shl:
6346     // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
6347     // stabilize to 0 in at most bitwidth(K) iterations.
6348     StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
6349     break;
6350   }
6351 
6352   auto *Result =
6353       ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
6354   assert(Result->getType()->isIntegerTy(1) &&
6355          "Otherwise cannot be an operand to a branch instruction");
6356 
6357   if (Result->isZeroValue()) {
6358     unsigned BitWidth = getTypeSizeInBits(RHS->getType());
6359     const SCEV *UpperBound =
6360         getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
6361     return ExitLimit(getCouldNotCompute(), UpperBound, false);
6362   }
6363 
6364   return getCouldNotCompute();
6365 }
6366 
6367 /// Return true if we can constant fold an instruction of the specified type,
6368 /// assuming that all operands were constants.
6369 static bool CanConstantFold(const Instruction *I) {
6370   if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
6371       isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
6372       isa<LoadInst>(I))
6373     return true;
6374 
6375   if (const CallInst *CI = dyn_cast<CallInst>(I))
6376     if (const Function *F = CI->getCalledFunction())
6377       return canConstantFoldCallTo(F);
6378   return false;
6379 }
6380 
6381 /// Determine whether this instruction can constant evolve within this loop
6382 /// assuming its operands can all constant evolve.
6383 static bool canConstantEvolve(Instruction *I, const Loop *L) {
6384   // An instruction outside of the loop can't be derived from a loop PHI.
6385   if (!L->contains(I)) return false;
6386 
6387   if (isa<PHINode>(I)) {
6388     // We don't currently keep track of the control flow needed to evaluate
6389     // PHIs, so we cannot handle PHIs inside of loops.
6390     return L->getHeader() == I->getParent();
6391   }
6392 
6393   // If we won't be able to constant fold this expression even if the operands
6394   // are constants, bail early.
6395   return CanConstantFold(I);
6396 }
6397 
6398 /// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
6399 /// recursing through each instruction operand until reaching a loop header phi.
6400 static PHINode *
6401 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
6402                                DenseMap<Instruction *, PHINode *> &PHIMap) {
6403 
6404   // We can evaluate this instruction if all of its operands are constant or
6405   // derived from a PHI node themselves.
6406   PHINode *PHI = nullptr;
6407   for (Value *Op : UseInst->operands()) {
6408     if (isa<Constant>(Op)) continue;
6409 
6410     Instruction *OpInst = dyn_cast<Instruction>(Op);
6411     if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
6412 
6413     PHINode *P = dyn_cast<PHINode>(OpInst);
6414     if (!P)
6415       // If this operand is already visited, reuse the prior result.
6416       // We may have P != PHI if this is the deepest point at which the
6417       // inconsistent paths meet.
6418       P = PHIMap.lookup(OpInst);
6419     if (!P) {
6420       // Recurse and memoize the results, whether a phi is found or not.
6421       // This recursive call invalidates pointers into PHIMap.
6422       P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap);
6423       PHIMap[OpInst] = P;
6424     }
6425     if (!P)
6426       return nullptr;  // Not evolving from PHI
6427     if (PHI && PHI != P)
6428       return nullptr;  // Evolving from multiple different PHIs.
6429     PHI = P;
6430   }
6431   // This is an expression evolving from a constant PHI!
6432   return PHI;
6433 }
6434 
6435 /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
6436 /// in the loop that V is derived from.  We allow arbitrary operations along the
6437 /// way, but the operands of an operation must either be constants or a value
6438 /// derived from a constant PHI.  If this expression does not fit with these
6439 /// constraints, return null.
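/// For example, given 'loop: %iv = phi [ 0, %entry ], [ %iv.next, %loop ]' and
/// '%iv.next = add %iv, 1', the value %iv.next evolves from the single header
/// PHI %iv.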
6440 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
6441   Instruction *I = dyn_cast<Instruction>(V);
6442   if (!I || !canConstantEvolve(I, L)) return nullptr;
6443 
6444   if (PHINode *PN = dyn_cast<PHINode>(I))
6445     return PN;
6446 
6447   // Record non-constant instructions contained by the loop.
6448   DenseMap<Instruction *, PHINode *> PHIMap;
6449   return getConstantEvolvingPHIOperands(I, L, PHIMap);
6450 }
6451 
6452 /// EvaluateExpression - Given an expression that passes the
6453 /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI nodes
6454 /// in the loop have the constant values given in Vals.  If we can't fold this
6455 /// expression for some reason, return null.
6456 static Constant *EvaluateExpression(Value *V, const Loop *L,
6457                                     DenseMap<Instruction *, Constant *> &Vals,
6458                                     const DataLayout &DL,
6459                                     const TargetLibraryInfo *TLI) {
6460   // Convenient constant check, but redundant for recursive calls.
6461   if (Constant *C = dyn_cast<Constant>(V)) return C;
6462   Instruction *I = dyn_cast<Instruction>(V);
6463   if (!I) return nullptr;
6464 
6465   if (Constant *C = Vals.lookup(I)) return C;
6466 
6467   // An instruction inside the loop depends on a value outside the loop that we
6468   // weren't given a mapping for, or a value such as a call inside the loop.
6469   if (!canConstantEvolve(I, L)) return nullptr;
6470 
6471   // An unmapped PHI can be due to a branch or another loop inside this loop,
6472   // or due to this not being the initial iteration through a loop where we
6473   // couldn't compute the evolution of this particular PHI last time.
6474   if (isa<PHINode>(I)) return nullptr;
6475 
6476   std::vector<Constant*> Operands(I->getNumOperands());
6477 
6478   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
6479     Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
6480     if (!Operand) {
6481       Operands[i] = dyn_cast<Constant>(I->getOperand(i));
6482       if (!Operands[i]) return nullptr;
6483       continue;
6484     }
6485     Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
6486     Vals[Operand] = C;
6487     if (!C) return nullptr;
6488     Operands[i] = C;
6489   }
6490 
6491   if (CmpInst *CI = dyn_cast<CmpInst>(I))
6492     return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
6493                                            Operands[1], DL, TLI);
6494   if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
6495     if (!LI->isVolatile())
6496       return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL);
6497   }
6498   return ConstantFoldInstOperands(I, Operands, DL, TLI);
6499 }
6500 
6501 
6502 // If every incoming value to PN except the one for BB is a specific Constant,
6503 // return that, else return nullptr.
6504 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
6505   Constant *IncomingVal = nullptr;
6506 
6507   for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
6508     if (PN->getIncomingBlock(i) == BB)
6509       continue;
6510 
6511     auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
6512     if (!CurrentVal)
6513       return nullptr;
6514 
6515     if (IncomingVal != CurrentVal) {
6516       if (IncomingVal)
6517         return nullptr;
6518       IncomingVal = CurrentVal;
6519     }
6520   }
6521 
6522   return IncomingVal;
6523 }
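
// For example (hypothetical IR): given the header PHI
//   %x = phi i32 [ 7, %entry ], [ %x.next, %latch ]
// getOtherIncomingValue(%x, %latch) returns the ConstantInt 7, since the only
// incoming value not coming from %latch is the constant 7.  If any non-latch
// edge carried a non-constant value, or two such edges disagreed, nullptr
// would be returned.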
6524 
6525 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
6526 /// in the header of its containing loop, that the loop executes a constant
6527 /// number of times, and that the PHI node is just a recurrence involving
6528 /// constants, fold it.
6529 Constant *
6530 ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
6531                                                    const APInt &BEs,
6532                                                    const Loop *L) {
6533   auto I = ConstantEvolutionLoopExitValue.find(PN);
6534   if (I != ConstantEvolutionLoopExitValue.end())
6535     return I->second;
6536 
6537   if (BEs.ugt(MaxBruteForceIterations))
6538     return ConstantEvolutionLoopExitValue[PN] = nullptr;  // Not going to evaluate it.
6539 
6540   Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
6541 
6542   DenseMap<Instruction *, Constant *> CurrentIterVals;
6543   BasicBlock *Header = L->getHeader();
6544   assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
6545 
6546   BasicBlock *Latch = L->getLoopLatch();
6547   if (!Latch)
6548     return nullptr;
6549 
6550   for (auto &I : *Header) {
6551     PHINode *PHI = dyn_cast<PHINode>(&I);
6552     if (!PHI) break;
6553     auto *StartCST = getOtherIncomingValue(PHI, Latch);
6554     if (!StartCST) continue;
6555     CurrentIterVals[PHI] = StartCST;
6556   }
6557   if (!CurrentIterVals.count(PN))
6558     return RetVal = nullptr;
6559 
6560   Value *BEValue = PN->getIncomingValueForBlock(Latch);
6561 
6562   // Execute the loop symbolically to determine the exit value.
6563   if (BEs.getActiveBits() >= 32)
6564     return RetVal = nullptr; // More than 2^32-1 iterations?? Not doing it!
6565 
6566   unsigned NumIterations = BEs.getZExtValue(); // must be in range
6567   unsigned IterationNum = 0;
6568   const DataLayout &DL = getDataLayout();
6569   for (; ; ++IterationNum) {
6570     if (IterationNum == NumIterations)
6571       return RetVal = CurrentIterVals[PN];  // Got exit value!
6572 
6573     // Compute the value of the PHIs for the next iteration.
6574     // EvaluateExpression adds non-phi values to the CurrentIterVals map.
6575     DenseMap<Instruction *, Constant *> NextIterVals;
6576     Constant *NextPHI =
6577         EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
6578     if (!NextPHI)
6579       return nullptr;        // Couldn't evaluate!
6580     NextIterVals[PN] = NextPHI;
6581 
6582     bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
6583 
6584     // Also evaluate the other PHI nodes.  However, we don't get to stop if we
6585     // cease to be able to evaluate one of them or if they stop evolving,
6586     // because that doesn't necessarily prevent us from computing PN.
6587     SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
6588     for (const auto &I : CurrentIterVals) {
6589       PHINode *PHI = dyn_cast<PHINode>(I.first);
6590       if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
6591       PHIsToCompute.emplace_back(PHI, I.second);
6592     }
6593     // We use two distinct loops because EvaluateExpression may invalidate any
6594     // iterators into CurrentIterVals.
6595     for (const auto &I : PHIsToCompute) {
6596       PHINode *PHI = I.first;
6597       Constant *&NextPHI = NextIterVals[PHI];
6598       if (!NextPHI) {   // Not already computed.
6599         Value *BEValue = PHI->getIncomingValueForBlock(Latch);
6600         NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
6601       }
6602       if (NextPHI != I.second)
6603         StoppedEvolving = false;
6604     }
6605 
6606     // If all entries in CurrentIterVals == NextIterVals then we can stop
6607     // iterating; the loop can't continue to change.
6608     if (StoppedEvolving)
6609       return RetVal = CurrentIterVals[PN];
6610 
6611     CurrentIterVals.swap(NextIterVals);
6612   }
6613 }
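
// Worked example (hypothetical IR): for the header PHI
//   %i      = phi i32 [ 0, %preheader ], [ %i.next, %latch ]
//   %i.next = add i32 %i, 2
// and a backedge-taken count of 3, the brute-force evaluation above steps %i
// through 0, 2, 4, 6 and returns the i32 constant 6, the value %i holds on
// the final trip through the loop.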
6614 
6615 const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
6616                                                           Value *Cond,
6617                                                           bool ExitWhen) {
6618   PHINode *PN = getConstantEvolvingPHI(Cond, L);
6619   if (!PN) return getCouldNotCompute();
6620 
6621   // If the loop is canonicalized, the PHI will have exactly two entries.
6622   // That's the only form we support here.
6623   if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
6624 
6625   DenseMap<Instruction *, Constant *> CurrentIterVals;
6626   BasicBlock *Header = L->getHeader();
6627   assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
6628 
6629   BasicBlock *Latch = L->getLoopLatch();
6630   assert(Latch && "Should follow from NumIncomingValues == 2!");
6631 
6632   for (auto &I : *Header) {
6633     PHINode *PHI = dyn_cast<PHINode>(&I);
6634     if (!PHI)
6635       break;
6636     auto *StartCST = getOtherIncomingValue(PHI, Latch);
6637     if (!StartCST) continue;
6638     CurrentIterVals[PHI] = StartCST;
6639   }
6640   if (!CurrentIterVals.count(PN))
6641     return getCouldNotCompute();
6642 
6643   // Okay, we found a PHI node that defines the trip count of this loop.  Execute
6644   // the loop symbolically to determine when the condition gets a value of
6645   // "ExitWhen".
6646   unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
6647   const DataLayout &DL = getDataLayout();
6648   for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
6649     auto *CondVal = dyn_cast_or_null<ConstantInt>(
6650         EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
6651 
6652     // Couldn't symbolically evaluate.
6653     if (!CondVal) return getCouldNotCompute();
6654 
6655     if (CondVal->getValue() == uint64_t(ExitWhen)) {
6656       ++NumBruteForceTripCountsComputed;
6657       return getConstant(Type::getInt32Ty(getContext()), IterationNum);
6658     }
6659 
6660     // Update all the PHI nodes for the next iteration.
6661     DenseMap<Instruction *, Constant *> NextIterVals;
6662 
6663     // Create a list of which PHIs we need to compute. We want to do this before
6664     // calling EvaluateExpression on them because that may invalidate iterators
6665     // into CurrentIterVals.
6666     SmallVector<PHINode *, 8> PHIsToCompute;
6667     for (const auto &I : CurrentIterVals) {
6668       PHINode *PHI = dyn_cast<PHINode>(I.first);
6669       if (!PHI || PHI->getParent() != Header) continue;
6670       PHIsToCompute.push_back(PHI);
6671     }
6672     for (PHINode *PHI : PHIsToCompute) {
6673       Constant *&NextPHI = NextIterVals[PHI];
6674       if (NextPHI) continue;    // Already computed!
6675 
6676       Value *BEValue = PHI->getIncomingValueForBlock(Latch);
6677       NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
6678     }
6679     CurrentIterVals.swap(NextIterVals);
6680   }
6681 
6682   // Too many iterations were needed to evaluate.
6683   return getCouldNotCompute();
6684 }
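
// Worked example (hypothetical IR, with ExitWhen == true): for
//   %i      = phi i32 [ 1, %preheader ], [ %i.next, %latch ]
//   %i.next = mul i32 %i, 3
//   %cond   = icmp ugt i32 %i, 20
// SCEV has no closed form for the geometric sequence 1, 3, 9, 27, ..., but
// the loop above evaluates %cond one iteration at a time: it is false for
// %i = 1, 3 and 9 and first becomes true for %i = 27, so the exhaustively
// computed exit count is the i32 constant 3.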
6685 
6686 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
6687   SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
6688       ValuesAtScopes[V];
6689   // Check to see if we've folded this expression at this loop before.
6690   for (auto &LS : Values)
6691     if (LS.first == L)
6692       return LS.second ? LS.second : V;
6693 
6694   Values.emplace_back(L, nullptr);
6695 
6696   // Otherwise compute it.
6697   const SCEV *C = computeSCEVAtScope(V, L);
6698   for (auto &LS : reverse(ValuesAtScopes[V]))
6699     if (LS.first == L) {
6700       LS.second = C;
6701       break;
6702     }
6703   return C;
6704 }
6705 
6706 /// This builds up a Constant using the ConstantExpr interface.  That way, we
6707 /// will return Constants for objects which aren't represented by a
6708 /// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
6709 /// Returns NULL if the SCEV isn't representable as a Constant.
6710 static Constant *BuildConstantFromSCEV(const SCEV *V) {
6711   switch (static_cast<SCEVTypes>(V->getSCEVType())) {
6712     case scCouldNotCompute:
6713     case scAddRecExpr:
6714       break;
6715     case scConstant:
6716       return cast<SCEVConstant>(V)->getValue();
6717     case scUnknown:
6718       return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
6719     case scSignExtend: {
6720       const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
6721       if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
6722         return ConstantExpr::getSExt(CastOp, SS->getType());
6723       break;
6724     }
6725     case scZeroExtend: {
6726       const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
6727       if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
6728         return ConstantExpr::getZExt(CastOp, SZ->getType());
6729       break;
6730     }
6731     case scTruncate: {
6732       const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
6733       if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
6734         return ConstantExpr::getTrunc(CastOp, ST->getType());
6735       break;
6736     }
6737     case scAddExpr: {
6738       const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
6739       if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
6740         if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
6741           unsigned AS = PTy->getAddressSpace();
6742           Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
6743           C = ConstantExpr::getBitCast(C, DestPtrTy);
6744         }
6745         for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
6746           Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
6747           if (!C2) return nullptr;
6748 
6749           // First pointer!
6750           if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) {
6751             unsigned AS = C2->getType()->getPointerAddressSpace();
6752             std::swap(C, C2);
6753             Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
6754             // The offsets have been converted to bytes.  We can add bytes to an
6755             // i8* by GEP with the byte count in the first index.
6756             C = ConstantExpr::getBitCast(C, DestPtrTy);
6757           }
6758 
6759           // Don't bother trying to sum two pointers. We probably can't
6760           // statically compute a load that results from it anyway.
6761           if (C2->getType()->isPointerTy())
6762             return nullptr;
6763 
6764           if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
6765             if (PTy->getElementType()->isStructTy())
6766               C2 = ConstantExpr::getIntegerCast(
6767                   C2, Type::getInt32Ty(C->getContext()), true);
6768             C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2);
6769           } else
6770             C = ConstantExpr::getAdd(C, C2);
6771         }
6772         return C;
6773       }
6774       break;
6775     }
6776     case scMulExpr: {
6777       const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
6778       if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
6779         // Don't bother with pointers at all.
6780         if (C->getType()->isPointerTy()) return nullptr;
6781         for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
6782           Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
6783           if (!C2 || C2->getType()->isPointerTy()) return nullptr;
6784           C = ConstantExpr::getMul(C, C2);
6785         }
6786         return C;
6787       }
6788       break;
6789     }
6790     case scUDivExpr: {
6791       const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
6792       if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
6793         if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
6794           if (LHS->getType() == RHS->getType())
6795             return ConstantExpr::getUDiv(LHS, RHS);
6796       break;
6797     }
6798     case scSMaxExpr:
6799     case scUMaxExpr:
6800       break; // TODO: smax, umax.
6801   }
6802   return nullptr;
6803 }
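
// For illustration (hypothetical, following the scAddExpr case above): given
// the SCEV (4 + @gv), where @gv is the address of a global variable (i.e. a
// SCEVUnknown wrapping a Constant), the integer operand is kept as-is, @gv is
// bitcast to i8*, and the sum is emitted as a constant GEP that advances the
// i8* by 4 bytes.  A SCEVUnknown wrapping a non-constant Value, or any add
// recurrence, makes this function return nullptr.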
6804 
6805 const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
6806   if (isa<SCEVConstant>(V)) return V;
6807 
6808   // If this instruction is evolved from a constant-evolving PHI, compute the
6809   // exit value from the loop without using SCEVs.
6810   if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
6811     if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
6812       const Loop *LI = this->LI[I->getParent()];
6813       if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
6814         if (PHINode *PN = dyn_cast<PHINode>(I))
6815           if (PN->getParent() == LI->getHeader()) {
6816             // Okay, there is no closed form solution for the PHI node.  Check
6817             // to see if the loop that contains it has a known backedge-taken
6818             // count.  If so, we may be able to force computation of the exit
6819             // value.
6820             const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
6821             if (const SCEVConstant *BTCC =
6822                   dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
6823               // Okay, we know how many times the containing loop executes.  If
6824               // this is a constant evolving PHI node, get the final value at
6825               // the specified iteration number.
6826               Constant *RV =
6827                   getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), LI);
6828               if (RV) return getSCEV(RV);
6829             }
6830           }
6831 
6832       // Okay, this is an expression that we cannot symbolically evaluate
6833       // into a SCEV.  Check to see if it's possible to symbolically evaluate
6834       // the arguments into constants, and if so, try to constant propagate the
6835       // result.  This is particularly useful for computing loop exit values.
6836       if (CanConstantFold(I)) {
6837         SmallVector<Constant *, 4> Operands;
6838         bool MadeImprovement = false;
6839         for (Value *Op : I->operands()) {
6840           if (Constant *C = dyn_cast<Constant>(Op)) {
6841             Operands.push_back(C);
6842             continue;
6843           }
6844 
6845           // If any of the operands is non-constant and its type is not
6846           // SCEVable (i.e., neither integer nor pointer), don't even try
6847           // to analyze it with SCEV techniques.
6848           if (!isSCEVable(Op->getType()))
6849             return V;
6850 
6851           const SCEV *OrigV = getSCEV(Op);
6852           const SCEV *OpV = getSCEVAtScope(OrigV, L);
6853           MadeImprovement |= OrigV != OpV;
6854 
6855           Constant *C = BuildConstantFromSCEV(OpV);
6856           if (!C) return V;
6857           if (C->getType() != Op->getType())
6858             C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
6859                                                               Op->getType(),
6860                                                               false),
6861                                       C, Op->getType());
6862           Operands.push_back(C);
6863         }
6864 
6865         // Check to see if getSCEVAtScope actually made an improvement.
6866         if (MadeImprovement) {
6867           Constant *C = nullptr;
6868           const DataLayout &DL = getDataLayout();
6869           if (const CmpInst *CI = dyn_cast<CmpInst>(I))
6870             C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
6871                                                 Operands[1], DL, &TLI);
6872           else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) {
6873             if (!LI->isVolatile())
6874               C = ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL);
6875           } else
6876             C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
6877           if (!C) return V;
6878           return getSCEV(C);
6879         }
6880       }
6881     }
6882 
6883     // This is some other type of SCEVUnknown, just return it.
6884     return V;
6885   }
6886 
6887   if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
6888     // Avoid performing the look-up in the common case where the specified
6889     // expression has no loop-variant portions.
6890     for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
6891       const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
6892       if (OpAtScope != Comm->getOperand(i)) {
6893         // Okay, at least one of these operands is loop variant but might be
6894         // foldable.  Build a new instance of the folded commutative expression.
6895         SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
6896                                             Comm->op_begin()+i);
6897         NewOps.push_back(OpAtScope);
6898 
6899         for (++i; i != e; ++i) {
6900           OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
6901           NewOps.push_back(OpAtScope);
6902         }
6903         if (isa<SCEVAddExpr>(Comm))
6904           return getAddExpr(NewOps);
6905         if (isa<SCEVMulExpr>(Comm))
6906           return getMulExpr(NewOps);
6907         if (isa<SCEVSMaxExpr>(Comm))
6908           return getSMaxExpr(NewOps);
6909         if (isa<SCEVUMaxExpr>(Comm))
6910           return getUMaxExpr(NewOps);
6911         llvm_unreachable("Unknown commutative SCEV type!");
6912       }
6913     }
6914     // If we got here, all operands are loop invariant.
6915     return Comm;
6916   }
6917 
6918   if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
6919     const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
6920     const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
6921     if (LHS == Div->getLHS() && RHS == Div->getRHS())
6922       return Div;   // must be loop invariant
6923     return getUDivExpr(LHS, RHS);
6924   }
6925 
6926   // If this is a loop recurrence for a loop that does not contain L, then we
6927   // are dealing with the final value computed by the loop.
6928   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
6929     // First, attempt to evaluate each operand.
6930     // Avoid performing the look-up in the common case where the specified
6931     // expression has no loop-variant portions.
6932     for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
6933       const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
6934       if (OpAtScope == AddRec->getOperand(i))
6935         continue;
6936 
6937       // Okay, at least one of these operands is loop variant but might be
6938       // foldable.  Build a new instance of the folded add recurrence.
6939       SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
6940                                           AddRec->op_begin()+i);
6941       NewOps.push_back(OpAtScope);
6942       for (++i; i != e; ++i)
6943         NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
6944 
6945       const SCEV *FoldedRec =
6946         getAddRecExpr(NewOps, AddRec->getLoop(),
6947                       AddRec->getNoWrapFlags(SCEV::FlagNW));
6948       AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
6949       // The addrec may be folded to a nonrecurrence, for example, if the
6950       // induction variable is multiplied by zero after constant folding. Go
6951       // ahead and return the folded value.
6952       if (!AddRec)
6953         return FoldedRec;
6954       break;
6955     }
6956 
6957     // If the scope is outside the addrec's loop, evaluate it by using the
6958     // loop exit value of the addrec.
6959     if (!AddRec->getLoop()->contains(L)) {
6960       // To evaluate this recurrence, we need to know how many times the AddRec
6961       // loop iterates.  Compute this now.
6962       const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
6963       if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
6964 
6965       // Then, evaluate the AddRec.
6966       return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
6967     }
6968 
6969     return AddRec;
6970   }
6971 
6972   if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
6973     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
6974     if (Op == Cast->getOperand())
6975       return Cast;  // must be loop invariant
6976     return getZeroExtendExpr(Op, Cast->getType());
6977   }
6978 
6979   if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
6980     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
6981     if (Op == Cast->getOperand())
6982       return Cast;  // must be loop invariant
6983     return getSignExtendExpr(Op, Cast->getType());
6984   }
6985 
6986   if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
6987     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
6988     if (Op == Cast->getOperand())
6989       return Cast;  // must be loop invariant
6990     return getTruncateExpr(Op, Cast->getType());
6991   }
6992 
6993   llvm_unreachable("Unknown SCEV type!");
6994 }
6995 
6996 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
6997   return getSCEVAtScope(getSCEV(V), L);
6998 }
6999 
7000 /// Finds the minimum unsigned root of the following equation:
7001 ///
7002 ///     A * X = B (mod N)
7003 ///
7004 /// where N = 2^BW and BW is the common bit width of A and B. The signedness of
7005 /// A and B isn't important.
7006 ///
7007 /// If the equation does not have a solution, SCEVCouldNotCompute is returned.
7008 static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
7009                                                ScalarEvolution &SE) {
7010   uint32_t BW = A.getBitWidth();
7011   assert(BW == B.getBitWidth() && "Bit widths must be the same.");
7012   assert(A != 0 && "A must be non-zero.");
7013 
7014   // 1. D = gcd(A, N)
7015   //
7016   // The gcd of A and N may have only one prime factor: 2. The number of
7017   // trailing zeros in A is its multiplicity.
7018   uint32_t Mult2 = A.countTrailingZeros();
7019   // D = 2^Mult2
7020 
7021   // 2. Check if B is divisible by D.
7022   //
7023   // B is divisible by D if and only if the multiplicity of prime factor 2 for B
7024   // is not less than multiplicity of this prime factor for D.
7025   if (B.countTrailingZeros() < Mult2)
7026     return SE.getCouldNotCompute();
7027 
7028   // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
7029   // modulo (N / D).
7030   //
7031   // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
7032   // bit width during computations.
7033   APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
7034   APInt Mod(BW + 1, 0);
7035   Mod.setBit(BW - Mult2);  // Mod = N / D
7036   APInt I = AD.multiplicativeInverse(Mod);
7037 
7038   // 4. Compute the minimum unsigned root of the equation:
7039   // I * (B / D) mod (N / D)
7040   APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
7041 
7042   // The result is guaranteed to be less than 2^BW so we may truncate it to BW
7043   // bits.
7044   return SE.getConstant(Result.trunc(BW));
7045 }
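
// Worked example of the steps above, with BW = 8 (so N = 256), A = 4, B = 12:
//   1. Mult2 = 2, so D = 4.
//   2. B also has two trailing zeros, so B is divisible by D.
//   3. AD = 1, Mod = N / D = 64, and I = 1 (the inverse of 1 modulo 64).
//   4. Result = (1 * (12 / 4)) mod 64 = 3, and indeed 4 * 3 == 12 (mod 256).
// If instead B = 6, B.countTrailingZeros() == 1 < Mult2 and the equation has
// no solution, so SCEVCouldNotCompute is returned.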
7046 
7047 /// Find the roots of the quadratic equation for the given quadratic chrec
7048 /// {L,+,M,+,N}.  This returns either the two roots (which might be the same)
7049 /// or None if the roots cannot be computed.
7050 ///
7051 static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>>
7052 SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
7053   assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
7054   const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
7055   const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
7056   const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
7057 
7058   // We currently can only solve this if the coefficients are constants.
7059   if (!LC || !MC || !NC)
7060     return None;
7061 
7062   uint32_t BitWidth = LC->getAPInt().getBitWidth();
7063   const APInt &L = LC->getAPInt();
7064   const APInt &M = MC->getAPInt();
7065   const APInt &N = NC->getAPInt();
7066   APInt Two(BitWidth, 2);
7067   APInt Four(BitWidth, 4);
7068 
7069   {
7070     using namespace APIntOps;
7071     const APInt& C = L;
7072     // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
7073     // The B coefficient is M-N/2
7074     APInt B(M);
7075     B -= sdiv(N,Two);
7076 
7077     // The A coefficient is N/2
7078     APInt A(N.sdiv(Two));
7079 
7080     // Compute the B^2-4ac term.
7081     APInt SqrtTerm(B);
7082     SqrtTerm *= B;
7083     SqrtTerm -= Four * (A * C);
7084 
7085     if (SqrtTerm.isNegative()) {
7086       // The loop is provably infinite.
7087       return None;
7088     }
7089 
7090     // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
7091     // integer value or else APInt::sqrt() will assert.
7092     APInt SqrtVal(SqrtTerm.sqrt());
7093 
7094     // Compute the two solutions for the quadratic formula.
7095     // The divisions must be performed as signed divisions.
7096     APInt NegB(-B);
7097     APInt TwoA(A << 1);
7098     if (TwoA.isMinValue())
7099       return None;
7100 
7101     LLVMContext &Context = SE.getContext();
7102 
7103     ConstantInt *Solution1 =
7104       ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA));
7105     ConstantInt *Solution2 =
7106       ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA));
7107 
7108     return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)),
7109                           cast<SCEVConstant>(SE.getConstant(Solution2)));
7110   } // end APIntOps namespace
7111 }
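
// Worked example: for the chrec {-4,+,1,+,2}, whose value at iteration X is
// X^2 - 4, the conversion above yields A = N/2 = 1, B = M - N/2 = 0, C = -4.
// The discriminant B^2 - 4AC is 16, SqrtVal is 4, and the two solutions are
// (0 + 4) / 2 = 2 and (0 - 4) / 2 = -2.  The caller (howFarToZero) then picks
// the unsigned-smaller root and re-checks that the chrec is exactly zero at
// that iteration before using it.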
7112 
7113 ScalarEvolution::ExitLimit
7114 ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
7115                               bool AllowPredicates) {
7116 
7117   // This is only used for loops with an "x != y" exit test. The exit condition
7118   // is now expressed as a single expression, V = x-y. So the exit test is
7119   // effectively V != 0.  We know and take advantage of the fact that this
7120   // expression is only used in a comparison-with-zero context.
7121 
7122   SmallPtrSet<const SCEVPredicate *, 4> Predicates;
7123   // If the value is a constant
7124   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
7125     // If the value is already zero, the branch will execute zero times.
7126     if (C->getValue()->isZero()) return C;
7127     return getCouldNotCompute();  // Otherwise it will loop infinitely.
7128   }
7129 
7130   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
7131   if (!AddRec && AllowPredicates)
7132     // Try to make this an AddRec using runtime tests, in the first X
7133     // iterations of this loop, where X is the SCEV expression found by the
7134     // algorithm below.
7135     AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
7136 
7137   if (!AddRec || AddRec->getLoop() != L)
7138     return getCouldNotCompute();
7139 
7140   // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
7141   // the quadratic equation to solve it.
7142   if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
7143     if (auto Roots = SolveQuadraticEquation(AddRec, *this)) {
7144       const SCEVConstant *R1 = Roots->first;
7145       const SCEVConstant *R2 = Roots->second;
7146       // Pick the smallest positive root value.
7147       if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
7148               CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
7149         if (!CB->getZExtValue())
7150           std::swap(R1, R2); // R1 is the minimum root now.
7151 
7152         // We can only use this value if the chrec ends up with an exact zero
7153         // value at this index.  When solving for "X*X != 5", for example, we
7154         // should not accept a root of 2.
7155         const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
7156         if (Val->isZero())
7157           // We found a quadratic root!
7158           return ExitLimit(R1, R1, false, Predicates);
7159       }
7160     }
7161     return getCouldNotCompute();
7162   }
7163 
7164   // Otherwise we can only handle this if it is affine.
7165   if (!AddRec->isAffine())
7166     return getCouldNotCompute();
7167 
7168   // If this is an affine expression, the execution count of this branch is
7169   // the minimum unsigned root of the following equation:
7170   //
7171   //     Start + Step*N = 0 (mod 2^BW)
7172   //
7173   // equivalent to:
7174   //
7175   //             Step*N = -Start (mod 2^BW)
7176   //
7177   // where BW is the common bit width of Start and Step.
7178 
7179   // Get the initial value for the loop.
7180   const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
7181   const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
7182 
7183   // For now we handle only constant steps.
7184   //
7185   // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
7186   // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
7187   // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
7188   // We have not yet seen any such cases.
7189   const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
7190   if (!StepC || StepC->getValue()->equalsInt(0))
7191     return getCouldNotCompute();
7192 
7193   // For positive steps (counting up until unsigned overflow):
7194   //   N = -Start/Step (as unsigned)
7195   // For negative steps (counting down to zero):
7196   //   N = Start/-Step
7197   // First compute the unsigned distance from zero in the direction of Step.
7198   bool CountDown = StepC->getAPInt().isNegative();
7199   const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
7200 
7201   // Handle unitary steps, which cannot wrap around.
7202   // 1*N = -Start; -1*N = Start (mod 2^BW), so:
7203   //   N = Distance (as unsigned)
7204   if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) {
7205     ConstantRange CR = getUnsignedRange(Start);
7206     const SCEV *MaxBECount;
7207     if (!CountDown && CR.getUnsignedMin().isMinValue())
7208       // When counting up, the worst starting value is 1, not 0.
7209       MaxBECount = CR.getUnsignedMax().isMinValue()
7210         ? getConstant(APInt::getMinValue(CR.getBitWidth()))
7211         : getConstant(APInt::getMaxValue(CR.getBitWidth()));
7212     else
7213       MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
7214                                          : -CR.getUnsignedMin());
7215     return ExitLimit(Distance, MaxBECount, false, Predicates);
7216   }
7217 
7218   // As a special case, handle the instance where Step is a positive power of
7219   // two. In this case, determining whether Step divides Distance evenly can be
7220   // done by counting and comparing the number of trailing zeros of Step and
7221   // Distance.
7222   if (!CountDown) {
7223     const APInt &StepV = StepC->getAPInt();
7224     // StepV.isPowerOf2() returns true if StepV is a positive power of two.  It
7225     // also returns true if StepV is maximally negative (e.g., INT_MIN), but that
7226     // case is not handled as this code is guarded by !CountDown.
7227     if (StepV.isPowerOf2() &&
7228         GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros()) {
7229       // Here we've constrained the equation to be of the form
7230       //
7231       //   2^(N + k) * Distance' = (StepV == 2^N) * X (mod 2^W)  ... (0)
7232       //
7233       // where we're operating on a W bit wide integer domain and k is
7234       // non-negative.  The smallest unsigned solution for X is the trip count.
7235       //
7236       // (0) is equivalent to:
7237       //
7238       //      2^(N + k) * Distance' - 2^N * X = L * 2^W
7239       // <=>  2^N(2^k * Distance' - X) = L * 2^(W - N) * 2^N
7240       // <=>  2^k * Distance' - X = L * 2^(W - N)
7241       // <=>  2^k * Distance'     = L * 2^(W - N) + X    ... (1)
7242       //
7243       // The smallest X satisfying (1) is unsigned remainder of dividing the LHS
7244       // by 2^(W - N).
7245       //
7246       // <=>  X = 2^k * Distance' URem 2^(W - N)   ... (2)
7247       //
7248       // E.g. say we're solving
7249       //
7250       //   2 * Val = 2 * X  (in i8)   ... (3)
7251       //
7252       // then from (2), we get X = Val URem i8 128 (k = 0 in this case).
7253       //
7254       // Note: It is tempting to solve (3) by setting X = Val, but Val is not
7255       // necessarily the smallest unsigned value of X that satisfies (3).
7256       // E.g. if Val is i8 -127 then the smallest value of X that satisfies (3)
7257       // is i8 1, not i8 -127
7258 
7259       const auto *ModuloResult = getUDivExactExpr(Distance, Step);
7260 
7261       // Since SCEV does not have a URem node, we construct one using a truncate
7262       // and a zero extend.
7263 
7264       unsigned NarrowWidth = StepV.getBitWidth() - StepV.countTrailingZeros();
7265       auto *NarrowTy = IntegerType::get(getContext(), NarrowWidth);
7266       auto *WideTy = Distance->getType();
7267 
7268       const SCEV *Limit =
7269           getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy);
7270       return ExitLimit(Limit, Limit, false, Predicates);
7271     }
7272   }
7273 
7274   // If the condition controls loop exit (the loop exits only if the expression
7275   // is true) and the addition is no-wrap we can use unsigned divide to
7276   // compute the backedge count.  In this case, the step may not divide the
7277   // distance, but we don't care because if the condition is "missed" the loop
7278   // will have undefined behavior due to wrapping.
7279   if (ControlsExit && AddRec->hasNoSelfWrap() &&
7280       loopHasNoAbnormalExits(AddRec->getLoop())) {
7281     const SCEV *Exact =
7282         getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
7283     return ExitLimit(Exact, Exact, false, Predicates);
7284   }
7285 
7286   // Then, try to solve the above equation provided that Start is constant.
7287   if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
7288     const SCEV *E = SolveLinEquationWithOverflow(
7289         StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this);
7290     return ExitLimit(E, E, false, Predicates);
7291   }
7292   return getCouldNotCompute();
7293 }
7294 
7295 ScalarEvolution::ExitLimit
7296 ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
7297   // Loops that look like: while (X == 0) are very strange indeed.  We don't
7298   // handle them yet except for the trivial case.  This could be expanded in the
7299   // future as needed.
7300 
7301   // If the value is a constant, check to see if it is known to be non-zero
7302   // already.  If so, the backedge will execute zero times.
7303   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
7304     if (!C->getValue()->isNullValue())
7305       return getZero(C->getType());
7306     return getCouldNotCompute();  // Otherwise it will loop infinitely.
7307   }
7308 
7309   // We could implement others, but I really doubt anyone writes loops like
7310   // this, and if they did, they would already be constant folded.
7311   return getCouldNotCompute();
7312 }
7313 
7314 std::pair<BasicBlock *, BasicBlock *>
7315 ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
7316   // If the block has a unique predecessor, then there is no path from the
7317   // predecessor to the block that does not go through the direct edge
7318   // from the predecessor to the block.
7319   if (BasicBlock *Pred = BB->getSinglePredecessor())
7320     return {Pred, BB};
7321 
7322   // A loop's header is defined to be a block that dominates the loop.
7323   // If the header has a unique predecessor outside the loop, it must be
7324   // a block that has exactly one successor that can reach the loop.
7325   if (Loop *L = LI.getLoopFor(BB))
7326     return {L->getLoopPredecessor(), L->getHeader()};
7327 
7328   return {nullptr, nullptr};
7329 }
7330 
7331 /// SCEV structural equivalence is usually sufficient for testing whether two
7332 /// expressions are equal, however for the purposes of looking for a condition
7333 /// guarding a loop, it can be useful to be a little more general, since a
7334 /// front-end may have replicated the controlling expression.
7335 ///
7336 static bool HasSameValue(const SCEV *A, const SCEV *B) {
7337   // Quick check to see if they are the same SCEV.
7338   if (A == B) return true;
7339 
7340   auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
7341     // Not all instructions that are "identical" compute the same value.  For
7342     // instance, two distinct alloca instructions allocating the same type are
7343     // identical and do not read memory, but compute distinct values.
7344     return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
7345   };
7346 
7347   // Otherwise, if they're both SCEVUnknown, it's possible that they hold
7348   // two different instructions with the same value. Check for this case.
7349   if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
7350     if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
7351       if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
7352         if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
7353           if (ComputesEqualValues(AI, BI))
7354             return true;
7355 
7356   // Otherwise assume they may have a different value.
7357   return false;
7358 }
7359 
7360 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
7361                                            const SCEV *&LHS, const SCEV *&RHS,
7362                                            unsigned Depth) {
7363   bool Changed = false;
7364 
7365   // If we hit the max recursion limit bail out.
7366   if (Depth >= 3)
7367     return false;
7368 
7369   // Canonicalize a constant to the right side.
7370   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
7371     // Check for both operands constant.
7372     if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
7373       if (ConstantExpr::getICmp(Pred,
7374                                 LHSC->getValue(),
7375                                 RHSC->getValue())->isNullValue())
7376         goto trivially_false;
7377       else
7378         goto trivially_true;
7379     }
7380     // Otherwise swap the operands to put the constant on the right.
7381     std::swap(LHS, RHS);
7382     Pred = ICmpInst::getSwappedPredicate(Pred);
7383     Changed = true;
7384   }
7385 
7386   // If we're comparing an addrec with a value which is loop-invariant in the
7387   // addrec's loop, put the addrec on the left. Also make a dominance check,
7388   // as both operands could be addrecs loop-invariant in each other's loop.
7389   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
7390     const Loop *L = AR->getLoop();
7391     if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
7392       std::swap(LHS, RHS);
7393       Pred = ICmpInst::getSwappedPredicate(Pred);
7394       Changed = true;
7395     }
7396   }
7397 
7398   // If there's a constant operand, canonicalize comparisons with boundary
7399   // cases, and canonicalize *-or-equal comparisons to regular comparisons.
7400   if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
7401     const APInt &RA = RC->getAPInt();
7402 
7403     bool SimplifiedByConstantRange = false;
7404 
7405     if (!ICmpInst::isEquality(Pred)) {
7406       ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
7407       if (ExactCR.isFullSet())
7408         goto trivially_true;
7409       else if (ExactCR.isEmptySet())
7410         goto trivially_false;
7411 
7412       APInt NewRHS;
7413       CmpInst::Predicate NewPred;
7414       if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
7415           ICmpInst::isEquality(NewPred)) {
7416         // We were able to convert an inequality to an equality.
7417         Pred = NewPred;
7418         RHS = getConstant(NewRHS);
7419         Changed = SimplifiedByConstantRange = true;
7420       }
7421     }
7422 
7423     if (!SimplifiedByConstantRange) {
7424       switch (Pred) {
7425       default:
7426         break;
7427       case ICmpInst::ICMP_EQ:
7428       case ICmpInst::ICMP_NE:
7429         // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
7430         if (!RA)
7431           if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
7432             if (const SCEVMulExpr *ME =
7433                     dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
7434               if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
7435                   ME->getOperand(0)->isAllOnesValue()) {
7436                 RHS = AE->getOperand(1);
7437                 LHS = ME->getOperand(1);
7438                 Changed = true;
7439               }
7440         break;
7441 
7442 
7443         // The "Should have been caught earlier!" messages refer to the fact
7444         // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
7445         // should have fired on the corresponding cases, and canonicalized the
7446         // check to trivially_true or trivially_false.
7447 
7448       case ICmpInst::ICMP_UGE:
7449         assert(!RA.isMinValue() && "Should have been caught earlier!");
7450         Pred = ICmpInst::ICMP_UGT;
7451         RHS = getConstant(RA - 1);
7452         Changed = true;
7453         break;
7454       case ICmpInst::ICMP_ULE:
7455         assert(!RA.isMaxValue() && "Should have been caught earlier!");
7456         Pred = ICmpInst::ICMP_ULT;
7457         RHS = getConstant(RA + 1);
7458         Changed = true;
7459         break;
7460       case ICmpInst::ICMP_SGE:
7461         assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
7462         Pred = ICmpInst::ICMP_SGT;
7463         RHS = getConstant(RA - 1);
7464         Changed = true;
7465         break;
7466       case ICmpInst::ICMP_SLE:
7467         assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
7468         Pred = ICmpInst::ICMP_SLT;
7469         RHS = getConstant(RA + 1);
7470         Changed = true;
7471         break;
7472       }
7473     }
7474   }
7475 
7476   // Check for obvious equality.
7477   if (HasSameValue(LHS, RHS)) {
7478     if (ICmpInst::isTrueWhenEqual(Pred))
7479       goto trivially_true;
7480     if (ICmpInst::isFalseWhenEqual(Pred))
7481       goto trivially_false;
7482   }
7483 
7484   // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
7485   // adding or subtracting 1 from one of the operands.
7486   switch (Pred) {
7487   case ICmpInst::ICMP_SLE:
7488     if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) {
7489       RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
7490                        SCEV::FlagNSW);
7491       Pred = ICmpInst::ICMP_SLT;
7492       Changed = true;
7493     } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) {
7494       LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
7495                        SCEV::FlagNSW);
7496       Pred = ICmpInst::ICMP_SLT;
7497       Changed = true;
7498     }
7499     break;
7500   case ICmpInst::ICMP_SGE:
7501     if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) {
7502       RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
7503                        SCEV::FlagNSW);
7504       Pred = ICmpInst::ICMP_SGT;
7505       Changed = true;
7506     } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) {
7507       LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
7508                        SCEV::FlagNSW);
7509       Pred = ICmpInst::ICMP_SGT;
7510       Changed = true;
7511     }
7512     break;
7513   case ICmpInst::ICMP_ULE:
7514     if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) {
7515       RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
7516                        SCEV::FlagNUW);
7517       Pred = ICmpInst::ICMP_ULT;
7518       Changed = true;
7519     } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
7520       LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
7521       Pred = ICmpInst::ICMP_ULT;
7522       Changed = true;
7523     }
7524     break;
7525   case ICmpInst::ICMP_UGE:
7526     if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
7527       RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
7528       Pred = ICmpInst::ICMP_UGT;
7529       Changed = true;
7530     } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
7531       LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
7532                        SCEV::FlagNUW);
7533       Pred = ICmpInst::ICMP_UGT;
7534       Changed = true;
7535     }
7536     break;
7537   default:
7538     break;
7539   }
7540 
7541   // TODO: More simplifications are possible here.
7542 
7543   // Recursively simplify until we either hit a recursion limit or nothing
7544   // changes.
7545   if (Changed)
7546     return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1);
7547 
7548   return Changed;
7549 
7550 trivially_true:
7551   // Return 0 == 0.
7552   LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
7553   Pred = ICmpInst::ICMP_EQ;
7554   return true;
7555 
7556 trivially_false:
7557   // Return 0 != 0.
7558   LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
7559   Pred = ICmpInst::ICMP_NE;
7560   return true;
7561 }
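
// A few examples of the canonicalizations performed above (with a
// hypothetical SCEV %x):
//   (5 s< %x)    becomes  (%x s> 5)   -- constant moved to the right-hand side
//   (%x u<= 7)   becomes  (%x u< 8)   -- *-or-equal turned into a strict compare
//   (%x u>= %x)  becomes  (0 == 0)    -- trivially true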
7562 
7563 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
7564   return getSignedRange(S).getSignedMax().isNegative();
7565 }
7566 
7567 bool ScalarEvolution::isKnownPositive(const SCEV *S) {
7568   return getSignedRange(S).getSignedMin().isStrictlyPositive();
7569 }
7570 
7571 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
7572   return !getSignedRange(S).getSignedMin().isNegative();
7573 }
7574 
7575 bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
7576   return !getSignedRange(S).getSignedMax().isStrictlyPositive();
7577 }
7578 
7579 bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
7580   return isKnownNegative(S) || isKnownPositive(S);
7581 }
7582 
7583 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
7584                                        const SCEV *LHS, const SCEV *RHS) {
7585   // Canonicalize the inputs first.
7586   (void)SimplifyICmpOperands(Pred, LHS, RHS);
7587 
7588   // If LHS or RHS is an addrec, check to see if the condition is true in
7589   // every iteration of the loop.
7590   // If LHS and RHS are both addrec, both conditions must be true in
7591   // every iteration of the loop.
7592   const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
7593   const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
7594   bool LeftGuarded = false;
7595   bool RightGuarded = false;
7596   if (LAR) {
7597     const Loop *L = LAR->getLoop();
7598     if (isLoopEntryGuardedByCond(L, Pred, LAR->getStart(), RHS) &&
7599         isLoopBackedgeGuardedByCond(L, Pred, LAR->getPostIncExpr(*this), RHS)) {
7600       if (!RAR) return true;
7601       LeftGuarded = true;
7602     }
7603   }
7604   if (RAR) {
7605     const Loop *L = RAR->getLoop();
7606     if (isLoopEntryGuardedByCond(L, Pred, LHS, RAR->getStart()) &&
7607         isLoopBackedgeGuardedByCond(L, Pred, LHS, RAR->getPostIncExpr(*this))) {
7608       if (!LAR) return true;
7609       RightGuarded = true;
7610     }
7611   }
7612   if (LeftGuarded && RightGuarded)
7613     return true;
7614 
7615   if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
7616     return true;
7617 
7618   // Otherwise see what can be done with known constant ranges.
7619   return isKnownPredicateViaConstantRanges(Pred, LHS, RHS);
7620 }
7621 
7622 bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS,
7623                                            ICmpInst::Predicate Pred,
7624                                            bool &Increasing) {
7625   bool Result = isMonotonicPredicateImpl(LHS, Pred, Increasing);
7626 
7627 #ifndef NDEBUG
7628   // Verify an invariant: inverting the predicate should turn a monotonically
7629   // increasing change to a monotonically decreasing one, and vice versa.
7630   bool IncreasingSwapped;
7631   bool ResultSwapped = isMonotonicPredicateImpl(
7632       LHS, ICmpInst::getSwappedPredicate(Pred), IncreasingSwapped);
7633 
7634   assert(Result == ResultSwapped && "should be able to analyze both!");
7635   if (ResultSwapped)
7636     assert(Increasing == !IncreasingSwapped &&
7637            "monotonicity should flip as we flip the predicate");
7638 #endif
7639 
7640   return Result;
7641 }
7642 
7643 bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS,
7644                                                ICmpInst::Predicate Pred,
7645                                                bool &Increasing) {
7646 
7647   // A zero step value for LHS means the induction variable is essentially a
7648   // loop invariant value. We don't really depend on the predicate actually
7649   // flipping from false to true (for increasing predicates, and the other way
7650   // around for decreasing predicates); all we care about is that *if* the
7651   // predicate changes then it only changes from false to true.
7652   //
7653   // A zero step value in itself is not very useful, but there may be places
7654   // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
7655   // as general as possible.
7656 
7657   switch (Pred) {
7658   default:
7659     return false; // Conservative answer
7660 
7661   case ICmpInst::ICMP_UGT:
7662   case ICmpInst::ICMP_UGE:
7663   case ICmpInst::ICMP_ULT:
7664   case ICmpInst::ICMP_ULE:
7665     if (!LHS->hasNoUnsignedWrap())
7666       return false;
7667 
7668     Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE;
7669     return true;
7670 
7671   case ICmpInst::ICMP_SGT:
7672   case ICmpInst::ICMP_SGE:
7673   case ICmpInst::ICMP_SLT:
7674   case ICmpInst::ICMP_SLE: {
7675     if (!LHS->hasNoSignedWrap())
7676       return false;
7677 
7678     const SCEV *Step = LHS->getStepRecurrence(*this);
7679 
7680     if (isKnownNonNegative(Step)) {
7681       Increasing = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE;
7682       return true;
7683     }
7684 
7685     if (isKnownNonPositive(Step)) {
7686       Increasing = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE;
7687       return true;
7688     }
7689 
7690     return false;
7691   }
7692 
7693   }
7694 
7695   llvm_unreachable("switch has default clause!");
7696 }
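
// For example, an addrec {0,+,1}<nuw> compared against a loop-invariant value
// with ICMP_UGT can only change from false to true as the loop iterates
// (Increasing == true), while the same addrec compared with ICMP_ULT can only
// change from true to false (Increasing == false).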
7697 
7698 bool ScalarEvolution::isLoopInvariantPredicate(
7699     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
7700     ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS,
7701     const SCEV *&InvariantRHS) {
7702 
7703   // If there is a loop-invariant, force it into the RHS, otherwise bail out.
7704   if (!isLoopInvariant(RHS, L)) {
7705     if (!isLoopInvariant(LHS, L))
7706       return false;
7707 
7708     std::swap(LHS, RHS);
7709     Pred = ICmpInst::getSwappedPredicate(Pred);
7710   }
7711 
7712   const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
7713   if (!ArLHS || ArLHS->getLoop() != L)
7714     return false;
7715 
7716   bool Increasing;
7717   if (!isMonotonicPredicate(ArLHS, Pred, Increasing))
7718     return false;
7719 
7720   // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
7721   // true as the loop iterates, and the backedge is control dependent on
7722   // "ArLHS `Pred` RHS" == true then we can reason as follows:
7723   //
7724   //   * if the predicate was false in the first iteration then the predicate
7725   //     is never evaluated again, since the loop exits without taking the
7726   //     backedge.
7727   //   * if the predicate was true in the first iteration then it will
7728   //     continue to be true for all future iterations since it is
7729   //     monotonically increasing.
7730   //
7731   // For both the above possibilities, we can replace the loop varying
7732   // predicate with its value on the first iteration of the loop (which is
7733   // loop invariant).
7734   //
7735   // A similar reasoning applies for a monotonically decreasing predicate, by
7736   // replacing true with false and false with true in the above two bullets.
7737 
7738   auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
7739 
7740   if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
7741     return false;
7742 
7743   InvariantPred = Pred;
7744   InvariantLHS = ArLHS->getStart();
7745   InvariantRHS = RHS;
7746   return true;
7747 }
7748 
7749 bool ScalarEvolution::isKnownPredicateViaConstantRanges(
7750     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
7751   if (HasSameValue(LHS, RHS))
7752     return ICmpInst::isTrueWhenEqual(Pred);
7753 
7754   // This code is split out from isKnownPredicate because it is called from
7755   // within isLoopEntryGuardedByCond.
7756 
7757   auto CheckRanges =
7758       [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) {
7759     return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS)
7760         .contains(RangeLHS);
7761   };
7762 
7763   // The check at the top of the function catches the case where the values are
7764   // known to be equal.
7765   if (Pred == CmpInst::ICMP_EQ)
7766     return false;
7767 
7768   if (Pred == CmpInst::ICMP_NE)
7769     return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) ||
7770            CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)) ||
7771            isKnownNonZero(getMinusSCEV(LHS, RHS));
7772 
7773   if (CmpInst::isSigned(Pred))
7774     return CheckRanges(getSignedRange(LHS), getSignedRange(RHS));
7775 
7776   return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS));
7777 }
7778 
7779 bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
7780                                                     const SCEV *LHS,
7781                                                     const SCEV *RHS) {
7782 
7783   // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
7784   // Return Y via OutY.
7785   auto MatchBinaryAddToConst =
7786       [this](const SCEV *Result, const SCEV *X, APInt &OutY,
7787              SCEV::NoWrapFlags ExpectedFlags) {
7788     const SCEV *NonConstOp, *ConstOp;
7789     SCEV::NoWrapFlags FlagsPresent;
7790 
7791     if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
7792         !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
7793       return false;
7794 
7795     OutY = cast<SCEVConstant>(ConstOp)->getAPInt();
7796     return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
7797   };
7798 
7799   APInt C;
7800 
7801   switch (Pred) {
7802   default:
7803     break;
7804 
7805   case ICmpInst::ICMP_SGE:
7806     std::swap(LHS, RHS);
7807   case ICmpInst::ICMP_SLE:
7808     // X s<= (X + C)<nsw> if C >= 0
7809     if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
7810       return true;
7811 
7812     // (X + C)<nsw> s<= X if C <= 0
7813     if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
7814         !C.isStrictlyPositive())
7815       return true;
7816     break;
7817 
7818   case ICmpInst::ICMP_SGT:
7819     std::swap(LHS, RHS);
7820   case ICmpInst::ICMP_SLT:
7821     // X s< (X + C)<nsw> if C > 0
7822     if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
7823         C.isStrictlyPositive())
7824       return true;
7825 
7826     // (X + C)<nsw> s< X if C < 0
7827     if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
7828       return true;
7829     break;
7830   }
7831 
7832   return false;
7833 }
7834 
7835 bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
7836                                                    const SCEV *LHS,
7837                                                    const SCEV *RHS) {
7838   if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
7839     return false;
7840 
7841   // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
7842   // on the stack can result in exponential time complexity.
7843   SaveAndRestore<bool> Restore(ProvingSplitPredicate, true);
7844 
7845   // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
7846   //
7847   // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
7848   // isKnownPredicate.  isKnownPredicate is more powerful, but also more
7849   // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
7850   // interesting cases seen in practice.  We can consider "upgrading" L >= 0 to
7851   // use isKnownPredicate later if needed.
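  //
  // Why the equivalence holds (a sketch, not tied to any particular caller):
  // when L s>= 0, any I with I s< 0 has an unsigned value of at least
  // 2^(BitWidth - 1), which is strictly greater than L, so "I u< L" already
  // forces I s>= 0; and once both values are non-negative, unsigned and signed
  // comparison agree.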
7852   return isKnownNonNegative(RHS) &&
7853          isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
7854          isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
7855 }
7856 
7857 bool ScalarEvolution::isImpliedViaGuard(BasicBlock *BB,
7858                                         ICmpInst::Predicate Pred,
7859                                         const SCEV *LHS, const SCEV *RHS) {
7860   // No need to even try if we know the module has no guards.
7861   if (!HasGuards)
7862     return false;
7863 
7864   return any_of(*BB, [&](Instruction &I) {
7865     using namespace llvm::PatternMatch;
7866 
7867     Value *Condition;
7868     return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
7869                          m_Value(Condition))) &&
7870            isImpliedCond(Pred, LHS, RHS, Condition, false);
7871   });
7872 }
7873 
7874 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
7875 /// protected by a conditional between LHS and RHS.  This is used to
7876 /// eliminate casts.
7877 bool
7878 ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
7879                                              ICmpInst::Predicate Pred,
7880                                              const SCEV *LHS, const SCEV *RHS) {
7881   // Interpret a null as meaning no loop, where there is obviously no guard
7882   // (interprocedural conditions notwithstanding).
7883   if (!L) return true;
7884 
7885   if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
7886     return true;
7887 
7888   BasicBlock *Latch = L->getLoopLatch();
7889   if (!Latch)
7890     return false;
7891 
7892   BranchInst *LoopContinuePredicate =
7893     dyn_cast<BranchInst>(Latch->getTerminator());
7894   if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
7895       isImpliedCond(Pred, LHS, RHS,
7896                     LoopContinuePredicate->getCondition(),
7897                     LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
7898     return true;
7899 
7900   // We don't want more than one activation of the following loops on the stack
7901   // -- that can lead to O(n!) time complexity.
7902   if (WalkingBEDominatingConds)
7903     return false;
7904 
7905   SaveAndRestore<bool> ClearOnExit(WalkingBEDominatingConds, true);
7906 
7907   // See if we can exploit a trip count to prove the predicate.
7908   const auto &BETakenInfo = getBackedgeTakenInfo(L);
7909   const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
7910   if (LatchBECount != getCouldNotCompute()) {
7911     // We know that Latch branches back to the loop header exactly
7912     // LatchBECount times.  This means the backedge condition at Latch is
7913     // equivalent to  "{0,+,1} u< LatchBECount".
7914     Type *Ty = LatchBECount->getType();
7915     auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
7916     const SCEV *LoopCounter =
7917       getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
7918     if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
7919                       LatchBECount))
7920       return true;
7921   }
7922 
7923   // Check conditions due to any @llvm.assume intrinsics.
7924   for (auto &AssumeVH : AC.assumptions()) {
7925     if (!AssumeVH)
7926       continue;
7927     auto *CI = cast<CallInst>(AssumeVH);
7928     if (!DT.dominates(CI, Latch->getTerminator()))
7929       continue;
7930 
7931     if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
7932       return true;
7933   }
7934 
7935   // If the loop is not reachable from the entry block, we risk running into an
7936   // infinite loop as we walk up into the dom tree.  These loops do not matter
7937   // anyway, so we just return a conservative answer when we see them.
7938   if (!DT.isReachableFromEntry(L->getHeader()))
7939     return false;
7940 
7941   if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
7942     return true;
7943 
7944   for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
7945        DTN != HeaderDTN; DTN = DTN->getIDom()) {
7946 
7947     assert(DTN && "should reach the loop header before reaching the root!");
7948 
7949     BasicBlock *BB = DTN->getBlock();
7950     if (isImpliedViaGuard(BB, Pred, LHS, RHS))
7951       return true;
7952 
7953     BasicBlock *PBB = BB->getSinglePredecessor();
7954     if (!PBB)
7955       continue;
7956 
7957     BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
7958     if (!ContinuePredicate || !ContinuePredicate->isConditional())
7959       continue;
7960 
7961     Value *Condition = ContinuePredicate->getCondition();
7962 
7963     // If we have an edge `E` within the loop body that dominates the only
7964     // latch, the condition guarding `E` also guards the backedge.  This
7965     // reasoning works only for loops with a single latch.
7966 
7967     BasicBlockEdge DominatingEdge(PBB, BB);
7968     if (DominatingEdge.isSingleEdge()) {
7969       // We're constructively (and conservatively) enumerating edges within the
7970       // loop body that dominate the latch.  The dominator tree had better agree
7971       // with us on this:
7972       assert(DT.dominates(DominatingEdge, Latch) && "should be!");
7973 
7974       if (isImpliedCond(Pred, LHS, RHS, Condition,
7975                         BB != ContinuePredicate->getSuccessor(0)))
7976         return true;
7977     }
7978   }
7979 
7980   return false;
7981 }
7982 
7983 bool
7984 ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
7985                                           ICmpInst::Predicate Pred,
7986                                           const SCEV *LHS, const SCEV *RHS) {
7987   // Interpret a null as meaning no loop, where there is obviously no guard
7988   // (interprocedural conditions notwithstanding).
7989   if (!L) return false;
7990 
7991   if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
7992     return true;
7993 
7994   // Starting at the loop predecessor, climb up the predecessor chain, as long
7995   // as there are predecessors that can be found that have unique successors
7996   // leading to the original header.
7997   for (std::pair<BasicBlock *, BasicBlock *>
7998          Pair(L->getLoopPredecessor(), L->getHeader());
7999        Pair.first;
8000        Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
8001 
8002     if (isImpliedViaGuard(Pair.first, Pred, LHS, RHS))
8003       return true;
8004 
8005     BranchInst *LoopEntryPredicate =
8006       dyn_cast<BranchInst>(Pair.first->getTerminator());
8007     if (!LoopEntryPredicate ||
8008         LoopEntryPredicate->isUnconditional())
8009       continue;
8010 
8011     if (isImpliedCond(Pred, LHS, RHS,
8012                       LoopEntryPredicate->getCondition(),
8013                       LoopEntryPredicate->getSuccessor(0) != Pair.second))
8014       return true;
8015   }
8016 
8017   // Check conditions due to any @llvm.assume intrinsics.
8018   for (auto &AssumeVH : AC.assumptions()) {
8019     if (!AssumeVH)
8020       continue;
8021     auto *CI = cast<CallInst>(AssumeVH);
8022     if (!DT.dominates(CI, L->getHeader()))
8023       continue;
8024 
8025     if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
8026       return true;
8027   }
8028 
8029   return false;
8030 }
8031 
8032 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
8033                                     const SCEV *LHS, const SCEV *RHS,
8034                                     Value *FoundCondValue,
8035                                     bool Inverse) {
8036   if (!PendingLoopPredicates.insert(FoundCondValue).second)
8037     return false;
8038 
8039   auto ClearOnExit =
8040       make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
8041 
8042   // Recursively handle And and Or conditions.
8043   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
8044     if (BO->getOpcode() == Instruction::And) {
8045       if (!Inverse)
8046         return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
8047                isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
8048     } else if (BO->getOpcode() == Instruction::Or) {
8049       if (Inverse)
8050         return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
8051                isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
8052     }
8053   }
8054 
8055   ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
8056   if (!ICI) return false;
8057 
8058   // We have found a conditional branch that dominates the loop or controls
8059   // the loop latch. Check to see if it is the comparison we are looking for.
8060   ICmpInst::Predicate FoundPred;
8061   if (Inverse)
8062     FoundPred = ICI->getInversePredicate();
8063   else
8064     FoundPred = ICI->getPredicate();
8065 
8066   const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
8067   const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
8068 
8069   return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS);
8070 }
8071 
8072 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
8073                                     const SCEV *RHS,
8074                                     ICmpInst::Predicate FoundPred,
8075                                     const SCEV *FoundLHS,
8076                                     const SCEV *FoundRHS) {
8077   // Balance the types.
8078   if (getTypeSizeInBits(LHS->getType()) <
8079       getTypeSizeInBits(FoundLHS->getType())) {
8080     if (CmpInst::isSigned(Pred)) {
8081       LHS = getSignExtendExpr(LHS, FoundLHS->getType());
8082       RHS = getSignExtendExpr(RHS, FoundLHS->getType());
8083     } else {
8084       LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
8085       RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
8086     }
8087   } else if (getTypeSizeInBits(LHS->getType()) >
8088       getTypeSizeInBits(FoundLHS->getType())) {
8089     if (CmpInst::isSigned(FoundPred)) {
8090       FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
8091       FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
8092     } else {
8093       FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
8094       FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
8095     }
8096   }
8097 
8098   // Canonicalize the query to match the way instcombine will have
8099   // canonicalized the comparison.
8100   if (SimplifyICmpOperands(Pred, LHS, RHS))
8101     if (LHS == RHS)
8102       return CmpInst::isTrueWhenEqual(Pred);
8103   if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
8104     if (FoundLHS == FoundRHS)
8105       return CmpInst::isFalseWhenEqual(FoundPred);
8106 
8107   // Check to see if we can make the LHS or RHS match.
8108   if (LHS == FoundRHS || RHS == FoundLHS) {
8109     if (isa<SCEVConstant>(RHS)) {
8110       std::swap(FoundLHS, FoundRHS);
8111       FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
8112     } else {
8113       std::swap(LHS, RHS);
8114       Pred = ICmpInst::getSwappedPredicate(Pred);
8115     }
8116   }
8117 
8118   // Check whether the found predicate is the same as the desired predicate.
8119   if (FoundPred == Pred)
8120     return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
8121 
8122   // Check whether swapping the found predicate makes it the same as the
8123   // desired predicate.
8124   if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
8125     if (isa<SCEVConstant>(RHS))
8126       return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
8127     else
8128       return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
8129                                    RHS, LHS, FoundLHS, FoundRHS);
8130   }
8131 
8132   // Unsigned comparison is the same as signed comparison when both the operands
8133   // are non-negative.
8134   if (CmpInst::isUnsigned(FoundPred) &&
8135       CmpInst::getSignedPredicate(FoundPred) == Pred &&
8136       isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS))
8137     return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
8138 
8139   // Check if we can make progress by sharpening ranges.
8140   if (FoundPred == ICmpInst::ICMP_NE &&
8141       (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
8142 
8143     const SCEVConstant *C = nullptr;
8144     const SCEV *V = nullptr;
8145 
8146     if (isa<SCEVConstant>(FoundLHS)) {
8147       C = cast<SCEVConstant>(FoundLHS);
8148       V = FoundRHS;
8149     } else {
8150       C = cast<SCEVConstant>(FoundRHS);
8151       V = FoundLHS;
8152     }
8153 
8154     // The guarding predicate tells us that C != V. If the known range
8155     // of V is [C, t), we can sharpen the range to [C + 1, t).  The
8156     // range we consider has to correspond to same signedness as the
8157     // predicate we're interested in folding.
8158 
8159     APInt Min = ICmpInst::isSigned(Pred) ?
8160         getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin();
8161 
8162     if (Min == C->getAPInt()) {
8163       // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
8164       // This is true even if (Min + 1) wraps around -- in case of
8165       // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
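      //
      // A concrete illustration (hypothetical numbers): if the unsigned range
      // of V is known to be [5, 100) and the guard tells us V != 5, then
      // Min == C == 5 and we may reason below as if V u>= 6.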
8166 
8167       APInt SharperMin = Min + 1;
8168 
8169       switch (Pred) {
8170         case ICmpInst::ICMP_SGE:
8171         case ICmpInst::ICMP_UGE:
8172           // We know V `Pred` SharperMin.  If this implies LHS `Pred`
8173           // RHS, we're done.
8174           if (isImpliedCondOperands(Pred, LHS, RHS, V,
8175                                     getConstant(SharperMin)))
8176             return true;
8177 
8178         case ICmpInst::ICMP_SGT:
8179         case ICmpInst::ICMP_UGT:
8180           // We know from the range information that (V `Pred` Min ||
8181           // V == Min).  We know from the guarding condition that !(V
8182           // == Min).  This gives us
8183           //
8184           //       V `Pred` Min || V == Min && !(V == Min)
8185           //   =>  V `Pred` Min
8186           //
8187           // If V `Pred` Min implies LHS `Pred` RHS, we're done.
8188 
8189           if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
8190             return true;
8191 
8192         default:
8193           // No change
8194           break;
8195       }
8196     }
8197   }
8198 
8199   // Check whether the actual condition is beyond sufficient.
8200   if (FoundPred == ICmpInst::ICMP_EQ)
8201     if (ICmpInst::isTrueWhenEqual(Pred))
8202       if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
8203         return true;
8204   if (Pred == ICmpInst::ICMP_NE)
8205     if (!ICmpInst::isTrueWhenEqual(FoundPred))
8206       if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
8207         return true;
8208 
8209   // Otherwise assume the worst.
8210   return false;
8211 }
8212 
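// Decompose Expr into exactly two add operands, also returning the add's
// no-wrap flags.  Because of SCEV's canonical operand ordering a constant
// operand, if present, ends up in L, which is what callers looking for an
// "X + constant" shape rely on.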
8213 bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
8214                                      const SCEV *&L, const SCEV *&R,
8215                                      SCEV::NoWrapFlags &Flags) {
8216   const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
8217   if (!AE || AE->getNumOperands() != 2)
8218     return false;
8219 
8220   L = AE->getOperand(0);
8221   R = AE->getOperand(1);
8222   Flags = AE->getNoWrapFlags();
8223   return true;
8224 }
8225 
8226 Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
8227                                                            const SCEV *Less) {
8228   // We avoid subtracting expressions here because this function is usually
8229   // fairly deep in the call stack (i.e. is called many times).
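  //
  // A couple of shapes this handles without building new SCEVs (hypothetical
  // operands): computeConstantDifference(X + 13, X) is 13, and for two affine
  // add recurrences {A + 6,+,S} and {A,+,S} over the same loop the result is
  // 6, since equal strides mean the difference of the starts is the difference
  // on every iteration.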
8230 
8231   if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
8232     const auto *LAR = cast<SCEVAddRecExpr>(Less);
8233     const auto *MAR = cast<SCEVAddRecExpr>(More);
8234 
8235     if (LAR->getLoop() != MAR->getLoop())
8236       return None;
8237 
8238     // We look at affine expressions only; not for correctness but to keep
8239     // getStepRecurrence cheap.
8240     if (!LAR->isAffine() || !MAR->isAffine())
8241       return None;
8242 
8243     if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
8244       return None;
8245 
8246     Less = LAR->getStart();
8247     More = MAR->getStart();
8248 
8249     // fall through
8250   }
8251 
8252   if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
8253     const auto &M = cast<SCEVConstant>(More)->getAPInt();
8254     const auto &L = cast<SCEVConstant>(Less)->getAPInt();
8255     return M - L;
8256   }
8257 
8258   const SCEV *L, *R;
8259   SCEV::NoWrapFlags Flags;
8260   if (splitBinaryAdd(Less, L, R, Flags))
8261     if (const auto *LC = dyn_cast<SCEVConstant>(L))
8262       if (R == More)
8263         return -(LC->getAPInt());
8264 
8265   if (splitBinaryAdd(More, L, R, Flags))
8266     if (const auto *LC = dyn_cast<SCEVConstant>(L))
8267       if (R == Less)
8268         return LC->getAPInt();
8269 
8270   return None;
8271 }
8272 
8273 bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
8274     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
8275     const SCEV *FoundLHS, const SCEV *FoundRHS) {
8276   if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
8277     return false;
8278 
8279   const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
8280   if (!AddRecLHS)
8281     return false;
8282 
8283   const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
8284   if (!AddRecFoundLHS)
8285     return false;
8286 
8287   // We'd like to let SCEV reason about control dependencies, so we constrain
8288   // both the inequalities to be about add recurrences on the same loop.  This
8289   // way we can use isLoopEntryGuardedByCond later.
8290 
8291   const Loop *L = AddRecFoundLHS->getLoop();
8292   if (L != AddRecLHS->getLoop())
8293     return false;
8294 
8295   //  FoundLHS u< FoundRHS u< -C =>  (FoundLHS + C) u< (FoundRHS + C) ... (1)
8296   //
8297   //  FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
8298   //                                                                  ... (2)
8299   //
8300   // Informal proof for (2), assuming (1) [*]:
8301   //
8302   // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
8303   //
8304   // Then
8305   //
8306   //       FoundLHS s< FoundRHS s< INT_MIN - C
8307   // <=>  (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C   [ using (3) ]
8308   // <=>  (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
8309   // <=>  (FoundLHS + INT_MIN + C + INT_MIN) s<
8310   //                        (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
8311   // <=>  FoundLHS + C s< FoundRHS + C
8312   //
8313   // [*]: (1) can be proved by ruling out overflow.
8314   //
8315   // [**]: This can be proved by analyzing all the four possibilities:
8316   //    (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
8317   //    (A s>= 0, B s>= 0).
8318   //
8319   // Note:
8320   // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
8321   // will not sign underflow.  For instance, say FoundLHS = (i8 -128), FoundRHS
8322   // = (i8 -127) and C = (i8 -100).  Then INT_MIN - C = (i8 -28), and FoundRHS
8323   // s< (INT_MIN - C).  Lack of sign overflow / underflow in "FoundRHS + C" is
8324   // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
8325   // C)".
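  //
  // A small worked instance of (1) (hypothetical values): if LHS = FoundLHS + 2
  // and RHS = FoundRHS + 2 (so LDiff == RDiff == 2) and Pred is u<, then
  // "FoundLHS u< FoundRHS" implies "LHS u< RHS" provided FoundRHS u< -2, which
  // is exactly the "FoundRHS u< FoundRHSLimit" fact established below via
  // isLoopEntryGuardedByCond.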
8326 
8327   Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
8328   Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
8329   if (!LDiff || !RDiff || *LDiff != *RDiff)
8330     return false;
8331 
8332   if (LDiff->isMinValue())
8333     return true;
8334 
8335   APInt FoundRHSLimit;
8336 
8337   if (Pred == CmpInst::ICMP_ULT) {
8338     FoundRHSLimit = -(*RDiff);
8339   } else {
8340     assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
8341     FoundRHSLimit =
        APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
8342   }
8343 
8344   // Try to prove (1) or (2), as needed.
8345   return isLoopEntryGuardedByCond(L, Pred, FoundRHS,
8346                                   getConstant(FoundRHSLimit));
8347 }
8348 
8349 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
8350                                             const SCEV *LHS, const SCEV *RHS,
8351                                             const SCEV *FoundLHS,
8352                                             const SCEV *FoundRHS) {
8353   if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
8354     return true;
8355 
8356   if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
8357     return true;
8358 
8359   return isImpliedCondOperandsHelper(Pred, LHS, RHS,
8360                                      FoundLHS, FoundRHS) ||
8361          // ~x < ~y --> x > y
8362          isImpliedCondOperandsHelper(Pred, LHS, RHS,
8363                                      getNotSCEV(FoundRHS),
8364                                      getNotSCEV(FoundLHS));
8365 }
8366 
8367 
8368 /// If Expr computes ~A, return A else return nullptr
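/// SCEV has no dedicated "not" node: getNotSCEV(A) builds (-1) + (-1) * A,
/// and that is the shape matched here.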
8369 static const SCEV *MatchNotExpr(const SCEV *Expr) {
8370   const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
8371   if (!Add || Add->getNumOperands() != 2 ||
8372       !Add->getOperand(0)->isAllOnesValue())
8373     return nullptr;
8374 
8375   const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
8376   if (!AddRHS || AddRHS->getNumOperands() != 2 ||
8377       !AddRHS->getOperand(0)->isAllOnesValue())
8378     return nullptr;
8379 
8380   return AddRHS->getOperand(1);
8381 }
8382 
8383 
8384 /// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values?
8385 template<typename MaxExprType>
8386 static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
8387                               const SCEV *Candidate) {
8388   const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr);
8389   if (!MaxExpr) return false;
8390 
8391   return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end();
8392 }
8393 
8394 
8395 /// Is MaybeMinExpr an SMin or UMin of Candidate and some other values?
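/// There is no dedicated SCEV min expression; ScalarEvolution builds
/// smin(A, B) as ~smax(~A, ~B) (and umin analogously), which is the shape
/// unwrapped here via MatchNotExpr.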
8396 template<typename MaxExprType>
8397 static bool IsMinConsistingOf(ScalarEvolution &SE,
8398                               const SCEV *MaybeMinExpr,
8399                               const SCEV *Candidate) {
8400   const SCEV *MaybeMaxExpr = MatchNotExpr(MaybeMinExpr);
8401   if (!MaybeMaxExpr)
8402     return false;
8403 
8404   return IsMaxConsistingOf<MaxExprType>(MaybeMaxExpr, SE.getNotSCEV(Candidate));
8405 }
8406 
8407 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
8408                                            ICmpInst::Predicate Pred,
8409                                            const SCEV *LHS, const SCEV *RHS) {
8410 
8411   // If both sides are affine addrecs for the same loop, with equal
8412   // steps, and we know the recurrences don't wrap, then we only
8413   // need to check the predicate on the starting values.
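  //
  // Illustration (hypothetical operands): within one loop, {5,+,2}<nuw> u<
  // {9,+,2}<nuw> holds on every iteration, because both sides advance by the
  // same amount and neither wraps, so it is enough that 5 u< 9 holds for the
  // starts.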
8414 
8415   if (!ICmpInst::isRelational(Pred))
8416     return false;
8417 
8418   const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
8419   if (!LAR)
8420     return false;
8421   const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
8422   if (!RAR)
8423     return false;
8424   if (LAR->getLoop() != RAR->getLoop())
8425     return false;
8426   if (!LAR->isAffine() || !RAR->isAffine())
8427     return false;
8428 
8429   if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
8430     return false;
8431 
8432   SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
8433                          SCEV::FlagNSW : SCEV::FlagNUW;
8434   if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
8435     return false;
8436 
8437   return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
8438 }
8439 
8440 /// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
8441 /// expression?
8442 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
8443                                         ICmpInst::Predicate Pred,
8444                                         const SCEV *LHS, const SCEV *RHS) {
8445   switch (Pred) {
8446   default:
8447     return false;
8448 
8449   case ICmpInst::ICMP_SGE:
8450     std::swap(LHS, RHS);
8451     LLVM_FALLTHROUGH;
8452   case ICmpInst::ICMP_SLE:
8453     return
8454       // min(A, ...) <= A
8455       IsMinConsistingOf<SCEVSMaxExpr>(SE, LHS, RHS) ||
8456       // A <= max(A, ...)
8457       IsMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
8458 
8459   case ICmpInst::ICMP_UGE:
8460     std::swap(LHS, RHS);
8461     LLVM_FALLTHROUGH;
8462   case ICmpInst::ICMP_ULE:
8463     return
8464       // min(A, ...) <= A
8465       IsMinConsistingOf<SCEVUMaxExpr>(SE, LHS, RHS) ||
8466       // A <= max(A, ...)
8467       IsMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
8468   }
8469 
8470   llvm_unreachable("covered switch fell through?!");
8471 }
8472 
8473 bool
8474 ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
8475                                              const SCEV *LHS, const SCEV *RHS,
8476                                              const SCEV *FoundLHS,
8477                                              const SCEV *FoundRHS) {
8478   auto IsKnownPredicateFull =
8479       [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
8480     return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
8481            IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
8482            IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
8483            isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
8484   };
8485 
8486   switch (Pred) {
8487   default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
8488   case ICmpInst::ICMP_EQ:
8489   case ICmpInst::ICMP_NE:
8490     if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
8491       return true;
8492     break;
8493   case ICmpInst::ICMP_SLT:
8494   case ICmpInst::ICMP_SLE:
8495     if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
8496         IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS))
8497       return true;
8498     break;
8499   case ICmpInst::ICMP_SGT:
8500   case ICmpInst::ICMP_SGE:
8501     if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
8502         IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS))
8503       return true;
8504     break;
8505   case ICmpInst::ICMP_ULT:
8506   case ICmpInst::ICMP_ULE:
8507     if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
8508         IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS))
8509       return true;
8510     break;
8511   case ICmpInst::ICMP_UGT:
8512   case ICmpInst::ICMP_UGE:
8513     if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
8514         IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS))
8515       return true;
8516     break;
8517   }
8518 
8519   return false;
8520 }
8521 
8522 bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
8523                                                      const SCEV *LHS,
8524                                                      const SCEV *RHS,
8525                                                      const SCEV *FoundLHS,
8526                                                      const SCEV *FoundRHS) {
8527   if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
8528     // The restriction on `FoundRHS` can be lifted easily -- it exists only to
8529     // reduce the compile time impact of this optimization.
8530     return false;
8531 
8532   Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
8533   if (!Addend)
8534     return false;
8535 
8536   APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
8537 
8538   // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
8539   // antecedent "`FoundLHS` `Pred` `FoundRHS`".
8540   ConstantRange FoundLHSRange =
8541       ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS);
8542 
8543   // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
8544   ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
8545 
8546   // We can also compute the range of values for `LHS` that satisfy the
8547   // consequent, "`LHS` `Pred` `RHS`":
8548   APInt ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
8549   ConstantRange SatisfyingLHSRange =
8550       ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS);
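
  // For instance (hypothetical numbers, unsigned): with Pred = u<, FoundRHS = 8
  // and Addend = 4, FoundLHSRange is [0, 8) and LHSRange is [4, 12); for
  // RHS = 20 the satisfying region is [0, 20), which contains [4, 12), so the
  // implication holds.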
8551 
8552   // The antecedent implies the consequent if every value of `LHS` that
8553   // satisfies the antecedent also satisfies the consequent.
8554   return SatisfyingLHSRange.contains(LHSRange);
8555 }
8556 
8557 bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
8558                                          bool IsSigned, bool NoWrap) {
8559   assert(isKnownPositive(Stride) && "Positive stride expected!");
8560 
8561   if (NoWrap) return false;
8562 
8563   unsigned BitWidth = getTypeSizeInBits(RHS->getType());
8564   const SCEV *One = getOne(Stride->getType());
8565 
8566   if (IsSigned) {
8567     APInt MaxRHS = getSignedRange(RHS).getSignedMax();
8568     APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
8569     APInt MaxStrideMinusOne = getSignedRange(getMinusSCEV(Stride, One))
8570                                 .getSignedMax();
8571 
8572     // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
8573     return (MaxValue - MaxStrideMinusOne).slt(MaxRHS);
8574   }
8575 
8576   APInt MaxRHS = getUnsignedRange(RHS).getUnsignedMax();
8577   APInt MaxValue = APInt::getMaxValue(BitWidth);
8578   APInt MaxStrideMinusOne = getUnsignedRange(getMinusSCEV(Stride, One))
8579                               .getUnsignedMax();
8580 
8581   // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
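  // For example (hypothetical i8 values): with BitWidth == 8, MaxRHS == 250 and
  // MaxStrideMinusOne == 9, we get 255 - 9 == 246 u< 250, so the IV could step
  // past RHS and wrap around; we conservatively report possible overflow.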
8582   return (MaxValue - MaxStrideMinusOne).ult(MaxRHS);
8583 }
8584 
8585 bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
8586                                          bool IsSigned, bool NoWrap) {
8587   if (NoWrap) return false;
8588 
8589   unsigned BitWidth = getTypeSizeInBits(RHS->getType());
8590   const SCEV *One = getOne(Stride->getType());
8591 
8592   if (IsSigned) {
8593     APInt MinRHS = getSignedRange(RHS).getSignedMin();
8594     APInt MinValue = APInt::getSignedMinValue(BitWidth);
8595     APInt MaxStrideMinusOne = getSignedRange(getMinusSCEV(Stride, One))
8596                                .getSignedMax();
8597 
8598     // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
8599     return (MinValue + MaxStrideMinusOne).sgt(MinRHS);
8600   }
8601 
8602   APInt MinRHS = getUnsignedRange(RHS).getUnsignedMin();
8603   APInt MinValue = APInt::getMinValue(BitWidth);
8604   APInt MaxStrideMinusOne = getUnsignedRange(getMinusSCEV(Stride, One))
8605                             .getUnsignedMax();
8606 
8607   // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
8608   return (MinValue + MaxStrideMinusOne).ugt(MinRHS);
8609 }
8610 
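// Compute a backedge count from Delta and Step: (Delta + Step - 1) /u Step
// (i.e. Delta /u Step rounded up) when Equality is false, and
// (Delta + Step) /u Step when the exit test is inclusive (Equality is true).
// A hypothetical worked instance: Delta = 10, Step = 3, Equality = false gives
// (10 + 2) /u 3 = 4, i.e. ceil(10 / 3).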
8611 const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
8612                                             bool Equality) {
8613   const SCEV *One = getOne(Step->getType());
8614   Delta = Equality ? getAddExpr(Delta, Step)
8615                    : getAddExpr(Delta, getMinusSCEV(Step, One));
8616   return getUDivExpr(Delta, Step);
8617 }
8618 
8619 ScalarEvolution::ExitLimit
8620 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
8621                                   const Loop *L, bool IsSigned,
8622                                   bool ControlsExit, bool AllowPredicates) {
8623   SmallPtrSet<const SCEVPredicate *, 4> Predicates;
8624   // We handle only IV < Invariant
8625   if (!isLoopInvariant(RHS, L))
8626     return getCouldNotCompute();
8627 
8628   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
8629   bool PredicatedIV = false;
8630 
8631   if (!IV && AllowPredicates) {
8632     // Try to make this an AddRec using runtime tests, in the first X
8633     // iterations of this loop, where X is the SCEV expression found by the
8634     // algorithm below.
8635     IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
8636     PredicatedIV = true;
8637   }
8638 
8639   // Avoid weird loops
8640   if (!IV || IV->getLoop() != L || !IV->isAffine())
8641     return getCouldNotCompute();
8642 
8643   bool NoWrap = ControlsExit &&
8644                 IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
8645 
8646   const SCEV *Stride = IV->getStepRecurrence(*this);
8647 
8648   bool PositiveStride = isKnownPositive(Stride);
8649 
8650   // Avoid negative or zero stride values.
8651   if (!PositiveStride) {
8652     // We can compute the correct backedge taken count for loops with unknown
8653     // strides if we can prove that the loop is not an infinite loop with side
8654     // effects. Here's the loop structure we are trying to handle -
8655     //
8656     // i = start
8657     // do {
8658     //   A[i] = i;
8659     //   i += s;
8660     // } while (i < end);
8661     //
8662     // The backedge taken count for such loops is evaluated as -
8663     // (max(end, start + stride) - start - 1) /u stride
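    //
    // For instance (hypothetical values): with start = 0, end = 10 and
    // stride = 3 this evaluates to (max(10, 3) - 0 - 1) /u 3 = 3 backedges,
    // matching the execution of the body for i = 0, 3, 6 and 9.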
8664     //
8665     // The additional preconditions that we need to check to prove correctness
8666     // of the above formula are as follows -
8667     //
8668     // a) IV is either nuw or nsw depending upon signedness (indicated by the
8669     //    NoWrap flag).
8670     // b) loop is single exit with no side effects.
8671     //
8672     //
8673     // Precondition a) implies that if the stride is negative, this is a single
8674     // trip loop. The backedge taken count formula reduces to zero in this case.
8675     //
8676     // Precondition b) implies that the unknown stride cannot be zero otherwise
8677     // we have UB.
8678     //
8679     // The positive stride case is the same as isKnownPositive(Stride) returning
8680     // true (original behavior of the function).
8681     //
8682     // We want to make sure that the stride is truly unknown as there are edge
8683     // cases where ScalarEvolution propagates no wrap flags to the
8684     // post-increment/decrement IV even though the increment/decrement operation
8685     // itself is wrapping. The computed backedge taken count may be wrong in
8686     // such cases. This is prevented by checking that the stride is not known to
8687     // be either positive or non-positive. For example, no wrap flags are
8688     // propagated to the post-increment IV of this loop with a trip count of 2 -
8689     //
8690     // unsigned char i;
8691     // for(i=127; i<128; i+=129)
8692     //   A[i] = i;
8693     //
8694     if (PredicatedIV || !NoWrap || isKnownNonPositive(Stride) ||
8695         !loopHasNoSideEffects(L))
8696       return getCouldNotCompute();
8697 
8698   } else if (!Stride->isOne() &&
8699              doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
8700     // Avoid proven overflow cases: this will ensure that the backedge taken
8701     // count will not generate any unsigned overflow. Relaxed no-overflow
8702     // conditions exploit NoWrapFlags, allowing us to optimize in the presence
8703     // of undefined behavior, as in C.
8704     return getCouldNotCompute();
8705 
8706   ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT
8707                                       : ICmpInst::ICMP_ULT;
8708   const SCEV *Start = IV->getStart();
8709   const SCEV *End = RHS;
8710   // If the backedge is taken at least once, then it will be taken
8711   // ceil((End - Start) /u Stride) times, where Start
8712   // is the LHS value of the less-than comparison the first time it is evaluated
8713   // and End is the RHS.
8714   const SCEV *BECountIfBackedgeTaken =
8715     computeBECount(getMinusSCEV(End, Start), Stride, false);
8716   // If the loop entry is guarded by the result of the backedge test of the
8717   // first loop iteration, then we know the backedge will be taken at least
8718   // once and so the backedge taken count is as above. If not then we use the
8719   // expression (max(End,Start)-Start)/Stride to describe the backedge count,
8720   // as, if the backedge is taken at least once, max(End,Start) is End and so
8721   // the result is as above, and if not, max(End,Start) is Start so we get a
8722   // backedge count of zero.
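  //
  // For example (hypothetical unsigned values): if Start = 20 and End = 10, the
  // first test "20 u< 10" fails, max(End, Start) is 20, and the computed
  // backedge taken count is (20 - 20) /u Stride = 0.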
8723   const SCEV *BECount;
8724   if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS))
8725     BECount = BECountIfBackedgeTaken;
8726   else {
8727     End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
8728     BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
8729   }
8730 
8731   const SCEV *MaxBECount;
8732   bool MaxOrZero = false;
8733   if (isa<SCEVConstant>(BECount))
8734     MaxBECount = BECount;
8735   else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
8736     // If we know exactly how many times the backedge will be taken if it's
8737     // taken at least once, then the backedge count will either be that or
8738     // zero.
8739     MaxBECount = BECountIfBackedgeTaken;
8740     MaxOrZero = true;
8741   } else {
8742     // Calculate the maximum backedge count based on the range of values
8743     // permitted by Start, End, and Stride.
8744     APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin()
8745                               : getUnsignedRange(Start).getUnsignedMin();
8746 
8747     unsigned BitWidth = getTypeSizeInBits(LHS->getType());
8748 
8749     APInt StrideForMaxBECount;
8750 
8751     if (PositiveStride)
8752       StrideForMaxBECount =
8753         IsSigned ? getSignedRange(Stride).getSignedMin()
8754                  : getUnsignedRange(Stride).getUnsignedMin();
8755     else
8756       // Using a stride of 1 is safe when computing max backedge taken count for
8757       // a loop with unknown stride.
8758       StrideForMaxBECount = APInt(BitWidth, 1, IsSigned);
8759 
8760     APInt Limit =
8761       IsSigned ? APInt::getSignedMaxValue(BitWidth) - (StrideForMaxBECount - 1)
8762                : APInt::getMaxValue(BitWidth) - (StrideForMaxBECount - 1);
8763 
8764     // Although End can be a MAX expression we estimate MaxEnd considering only
8765     // the case End = RHS. This is safe because in the other case (End - Start)
8766     // is zero, leading to a zero maximum backedge taken count.
8767     APInt MaxEnd =
8768       IsSigned ? APIntOps::smin(getSignedRange(RHS).getSignedMax(), Limit)
8769                : APIntOps::umin(getUnsignedRange(RHS).getUnsignedMax(), Limit);
8770 
8771     MaxBECount = computeBECount(getConstant(MaxEnd - MinStart),
8772                                 getConstant(StrideForMaxBECount), false);
8773   }
8774 
8775   if (isa<SCEVCouldNotCompute>(MaxBECount))
8776     MaxBECount = BECount;
8777 
8778   return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
8779 }
8780 
8781 ScalarEvolution::ExitLimit
8782 ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
8783                                      const Loop *L, bool IsSigned,
8784                                      bool ControlsExit, bool AllowPredicates) {
8785   SmallPtrSet<const SCEVPredicate *, 4> Predicates;
8786   // We handle only IV > Invariant
8787   if (!isLoopInvariant(RHS, L))
8788     return getCouldNotCompute();
8789 
8790   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
8791   if (!IV && AllowPredicates)
8792     // Try to make this an AddRec using runtime tests, in the first X
8793     // iterations of this loop, where X is the SCEV expression found by the
8794     // algorithm below.
8795     IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
8796 
8797   // Avoid weird loops
8798   if (!IV || IV->getLoop() != L || !IV->isAffine())
8799     return getCouldNotCompute();
8800 
8801   bool NoWrap = ControlsExit &&
8802                 IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
8803 
8804   const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
8805 
8806   // Avoid negative or zero stride values
8807   if (!isKnownPositive(Stride))
8808     return getCouldNotCompute();
8809 
8810   // Avoid proven overflow cases: this will ensure that the backedge taken count
8811   // will not generate any unsigned overflow. Relaxed no-overflow conditions
8812   // exploit NoWrapFlags, allowing us to optimize in the presence of undefined
8813   // behavior, as in C.
8814   if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap))
8815     return getCouldNotCompute();
8816 
8817   ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT
8818                                       : ICmpInst::ICMP_UGT;
8819 
8820   const SCEV *Start = IV->getStart();
8821   const SCEV *End = RHS;
8822   if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS))
8823     End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
8824 
8825   const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false);
8826 
8827   APInt MaxStart = IsSigned ? getSignedRange(Start).getSignedMax()
8828                             : getUnsignedRange(Start).getUnsignedMax();
8829 
8830   APInt MinStride = IsSigned ? getSignedRange(Stride).getSignedMin()
8831                              : getUnsignedRange(Stride).getUnsignedMin();
8832 
8833   unsigned BitWidth = getTypeSizeInBits(LHS->getType());
8834   APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
8835                          : APInt::getMinValue(BitWidth) + (MinStride - 1);
8836 
8837   // Although End can be a MIN expression we estimate MinEnd considering only
8838   // the case End = RHS. This is safe because in the other case (Start - End)
8839   // is zero, leading to a zero maximum backedge taken count.
8840   APInt MinEnd =
8841     IsSigned ? APIntOps::smax(getSignedRange(RHS).getSignedMin(), Limit)
8842              : APIntOps::umax(getUnsignedRange(RHS).getUnsignedMin(), Limit);
8843 
8844 
8845   const SCEV *MaxBECount = getCouldNotCompute();
8846   if (isa<SCEVConstant>(BECount))
8847     MaxBECount = BECount;
8848   else
8849     MaxBECount = computeBECount(getConstant(MaxStart - MinEnd),
8850                                 getConstant(MinStride), false);
8851 
8852   if (isa<SCEVCouldNotCompute>(MaxBECount))
8853     MaxBECount = BECount;
8854 
8855   return ExitLimit(BECount, MaxBECount, false, Predicates);
8856 }
8857 
8858 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
8859                                                     ScalarEvolution &SE) const {
8860   if (Range.isFullSet())  // Infinite loop.
8861     return SE.getCouldNotCompute();
8862 
8863   // If the start is a non-zero constant, shift the range to simplify things.
8864   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
8865     if (!SC->getValue()->isZero()) {
8866       SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
8867       Operands[0] = SE.getZero(SC->getType());
8868       const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
8869                                              getNoWrapFlags(FlagNW));
8870       if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
8871         return ShiftedAddRec->getNumIterationsInRange(
8872             Range.subtract(SC->getAPInt()), SE);
8873       // This is strange and shouldn't happen.
8874       return SE.getCouldNotCompute();
8875     }
8876 
8877   // The only time we can solve this is when we have all constant indices.
8878   // Otherwise, we cannot determine the overflow conditions.
8879   if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
8880     return SE.getCouldNotCompute();
8881 
8882   // Okay at this point we know that all elements of the chrec are constants and
8883   // that the start element is zero.
8884 
8885   // First check to see if the range contains zero.  If not, the first
8886   // iteration exits.
8887   unsigned BitWidth = SE.getTypeSizeInBits(getType());
8888   if (!Range.contains(APInt(BitWidth, 0)))
8889     return SE.getZero(getType());
8890 
8891   if (isAffine()) {
8892     // If this is an affine expression then we have this situation:
8893     //   Solve {0,+,A} in Range  ===  Ax in Range
8894 
8895     // We know that zero is in the range.  If A is positive then we know that
8896     // the upper value of the range must be the first possible exit value.
8897     // If A is negative then the lower of the range is the last possible loop
8898     // value.  Also note that we already checked for a full range.
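    //
    // For example (hypothetical values): solving {0,+,4} in the range [0, 13)
    // gives End = 12 and ExitVal = (12 + 4) /u 4 = 4; the value at iteration 4
    // is 16, which lies outside the range, while the value at iteration 3 is
    // 12, which is still inside, so 4 is the answer.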
8899     APInt One(BitWidth,1);
8900     APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
8901     APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
8902 
8903     // The exit value should be (End+A)/A.
8904     APInt ExitVal = (End + A).udiv(A);
8905     ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
8906 
8907     // Evaluate at the exit value.  If we really did fall out of the valid
8908     // range, then we computed our trip count, otherwise wrap around or other
8909     // things must have happened.
8910     ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
8911     if (Range.contains(Val->getValue()))
8912       return SE.getCouldNotCompute();  // Something strange happened
8913 
8914     // Ensure that the previous value is in the range.  This is a sanity check.
8915     assert(Range.contains(
8916            EvaluateConstantChrecAtConstant(this,
8917            ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) &&
8918            "Linear scev computation is off in a bad way!");
8919     return SE.getConstant(ExitValue);
8920   } else if (isQuadratic()) {
8921     // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
8922     // quadratic equation to solve it.  To do this, we must frame our problem in
8923     // terms of figuring out when zero is crossed, instead of when
8924     // Range.getUpper() is crossed.
8925     SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());
8926     NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
8927     const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), FlagAnyWrap);
8928 
8929     // Next, solve the constructed addrec
8930     if (auto Roots =
8931             SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE)) {
8932       const SCEVConstant *R1 = Roots->first;
8933       const SCEVConstant *R2 = Roots->second;
8934       // Pick the smallest positive root value.
8935       if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
8936               ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
8937         if (!CB->getZExtValue())
8938           std::swap(R1, R2); // R1 is the minimum root now.
8939 
8940         // Make sure the root is not off by one.  The returned iteration should
8941         // not be in the range, but the previous one should be.  When solving
8942         // for "X*X < 5", for example, we should not return a root of 2.
8943         ConstantInt *R1Val =
8944             EvaluateConstantChrecAtConstant(this, R1->getValue(), SE);
8945         if (Range.contains(R1Val->getValue())) {
8946           // The next iteration must be out of the range...
8947           ConstantInt *NextVal =
8948               ConstantInt::get(SE.getContext(), R1->getAPInt() + 1);
8949 
8950           R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
8951           if (!Range.contains(R1Val->getValue()))
8952             return SE.getConstant(NextVal);
8953           return SE.getCouldNotCompute(); // Something strange happened
8954         }
8955 
8956         // If R1 was not in the range, then it is a good return value.  Make
8957         // sure that R1-1 WAS in the range though, just in case.
8958         ConstantInt *NextVal =
8959             ConstantInt::get(SE.getContext(), R1->getAPInt() - 1);
8960         R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
8961         if (Range.contains(R1Val->getValue()))
8962           return R1;
8963         return SE.getCouldNotCompute(); // Something strange happened
8964       }
8965     }
8966   }
8967 
8968   return SE.getCouldNotCompute();
8969 }
8970 
8971 namespace {
8972 struct FindUndefs {
8973   bool Found;
8974   FindUndefs() : Found(false) {}
8975 
8976   bool follow(const SCEV *S) {
8977     if (const SCEVUnknown *C = dyn_cast<SCEVUnknown>(S)) {
8978       if (isa<UndefValue>(C->getValue()))
8979         Found = true;
8980     } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
8981       if (isa<UndefValue>(C->getValue()))
8982         Found = true;
8983     }
8984 
8985     // Keep looking if we haven't found it yet.
8986     return !Found;
8987   }
8988   bool isDone() const {
8989     // Stop recursion if we have found an undef.
8990     return Found;
8991   }
8992 };
8993 }
8994 
8995 // Return true when S contains at least one undef value.
8996 static inline bool
8997 containsUndefs(const SCEV *S) {
8998   FindUndefs F;
8999   SCEVTraversal<FindUndefs> ST(F);
9000   ST.visitAll(S);
9001 
9002   return F.Found;
9003 }
9004 
9005 namespace {
9006 // Collect all steps of SCEV expressions.
9007 struct SCEVCollectStrides {
9008   ScalarEvolution &SE;
9009   SmallVectorImpl<const SCEV *> &Strides;
9010 
9011   SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &S)
9012       : SE(SE), Strides(S) {}
9013 
9014   bool follow(const SCEV *S) {
9015     if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
9016       Strides.push_back(AR->getStepRecurrence(SE));
9017     return true;
9018   }
9019   bool isDone() const { return false; }
9020 };
9021 
9022 // Collect all SCEVUnknown and SCEVMulExpr expressions.
9023 struct SCEVCollectTerms {
9024   SmallVectorImpl<const SCEV *> &Terms;
9025 
9026   SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T)
9027       : Terms(T) {}
9028 
9029   bool follow(const SCEV *S) {
9030     if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) ||
9031         isa<SCEVSignExtendExpr>(S)) {
9032       if (!containsUndefs(S))
9033         Terms.push_back(S);
9034 
9035       // Stop recursion: once we collected a term, do not walk its operands.
9036       return false;
9037     }
9038 
9039     // Keep looking.
9040     return true;
9041   }
9042   bool isDone() const { return false; }
9043 };
9044 
9045 // Check if a SCEV contains an AddRecExpr.
9046 struct SCEVHasAddRec {
9047   bool &ContainsAddRec;
9048 
9049   SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) {
9050    ContainsAddRec = false;
9051   }
9052 
9053   bool follow(const SCEV *S) {
9054     if (isa<SCEVAddRecExpr>(S)) {
9055       ContainsAddRec = true;
9056 
9057       // Stop recursion: we have found an AddRec, no need to walk its operands.
9058       return false;
9059     }
9060 
9061     // Keep looking.
9062     return true;
9063   }
9064   bool isDone() const { return false; }
9065 };
9066 
9067 // Find factors that are multiplied with an expression that (possibly as a
9068 // subexpression) contains an AddRecExpr. In the expression:
9069 //
9070 //  8 * (100 +  %p * %q * (%a + {0, +, 1}_loop))
9071 //
9072 // "%p * %q" are factors multiplied by the expression "(%a + {0, +, 1}_loop)"
9073 // that contains the AddRec {0, +, 1}_loop. %p * %q are likely to be array size
9074 // parameters as they form a product with an induction variable.
9075 //
9076 // This collector expects all array size parameters to be in the same MulExpr.
9077 // It might be necessary to later add support for collecting parameters that are
9078 // spread over different nested MulExpr.
9079 struct SCEVCollectAddRecMultiplies {
9080   SmallVectorImpl<const SCEV *> &Terms;
9081   ScalarEvolution &SE;
9082 
9083   SCEVCollectAddRecMultiplies(SmallVectorImpl<const SCEV *> &T, ScalarEvolution &SE)
9084       : Terms(T), SE(SE) {}
9085 
9086   bool follow(const SCEV *S) {
9087     if (auto *Mul = dyn_cast<SCEVMulExpr>(S)) {
9088       bool HasAddRec = false;
9089       SmallVector<const SCEV *, 0> Operands;
9090       for (auto Op : Mul->operands()) {
9091         if (isa<SCEVUnknown>(Op)) {
9092           Operands.push_back(Op);
9093         } else {
9094           bool ContainsAddRec;
9095           SCEVHasAddRec ContainsAddRecFinder(ContainsAddRec);
9096           visitAll(Op, ContainsAddRecFinder);
9097           HasAddRec |= ContainsAddRec;
9098         }
9099       }
9100       if (Operands.empty())
9101         return true;
9102 
9103       if (!HasAddRec)
9104         return false;
9105 
9106       Terms.push_back(SE.getMulExpr(Operands));
9107       // Stop recursion: once we collected a term, do not walk its operands.
9108       return false;
9109     }
9110 
9111     // Keep looking.
9112     return true;
9113   }
9114   bool isDone() const { return false; }
9115 };
9116 }
9117 
9118 /// Find parametric terms in this SCEVAddRecExpr. We look for parameters in
9119 /// two places:
9120 ///   1) The strides of AddRec expressions.
9121 ///   2) Unknowns that are multiplied with AddRec expressions.
9122 void ScalarEvolution::collectParametricTerms(const SCEV *Expr,
9123     SmallVectorImpl<const SCEV *> &Terms) {
9124   SmallVector<const SCEV *, 4> Strides;
9125   SCEVCollectStrides StrideCollector(*this, Strides);
9126   visitAll(Expr, StrideCollector);
9127 
9128   DEBUG({
9129       dbgs() << "Strides:\n";
9130       for (const SCEV *S : Strides)
9131         dbgs() << *S << "\n";
9132     });
9133 
9134   for (const SCEV *S : Strides) {
9135     SCEVCollectTerms TermCollector(Terms);
9136     visitAll(S, TermCollector);
9137   }
9138 
9139   DEBUG({
9140       dbgs() << "Terms:\n";
9141       for (const SCEV *T : Terms)
9142         dbgs() << *T << "\n";
9143     });
9144 
9145   SCEVCollectAddRecMultiplies MulCollector(Terms, *this);
9146   visitAll(Expr, MulCollector);
9147 }
9148 
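// Recursive helper for findArrayDimensions: repeatedly divides all terms by
// the last term (the one with the fewest factors after sorting) and records
// that term as a dimension size.
//
// A minimal illustration (assuming the divisions are exact): with
// Terms = {%m * %o, %o}, Step is %o; dividing both terms by %o gives {%m, 1};
// the constant is dropped, the recursion records %m, and finally %o is
// recorded, producing Sizes = {%m, %o}.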
9149 static bool findArrayDimensionsRec(ScalarEvolution &SE,
9150                                    SmallVectorImpl<const SCEV *> &Terms,
9151                                    SmallVectorImpl<const SCEV *> &Sizes) {
9152   int Last = Terms.size() - 1;
9153   const SCEV *Step = Terms[Last];
9154 
9155   // End of recursion.
9156   if (Last == 0) {
9157     if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Step)) {
9158       SmallVector<const SCEV *, 2> Qs;
9159       for (const SCEV *Op : M->operands())
9160         if (!isa<SCEVConstant>(Op))
9161           Qs.push_back(Op);
9162 
9163       Step = SE.getMulExpr(Qs);
9164     }
9165 
9166     Sizes.push_back(Step);
9167     return true;
9168   }
9169 
9170   for (const SCEV *&Term : Terms) {
9171     // Normalize the terms before the next call to findArrayDimensionsRec.
9172     const SCEV *Q, *R;
9173     SCEVDivision::divide(SE, Term, Step, &Q, &R);
9174 
9175     // Bail out when GCD does not evenly divide one of the terms.
9176     if (!R->isZero())
9177       return false;
9178 
9179     Term = Q;
9180   }
9181 
9182   // Remove all SCEVConstants.
9183   Terms.erase(
9184       remove_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); }),
9185       Terms.end());
9186 
9187   if (!Terms.empty())
9188     if (!findArrayDimensionsRec(SE, Terms, Sizes))
9189       return false;
9190 
9191   Sizes.push_back(Step);
9192   return true;
9193 }
9194 
9195 // Returns true when S contains at least one SCEVUnknown parameter.
9196 static inline bool
9197 containsParameters(const SCEV *S) {
9198   struct FindParameter {
9199     bool FoundParameter;
9200     FindParameter() : FoundParameter(false) {}
9201 
9202     bool follow(const SCEV *S) {
9203       if (isa<SCEVUnknown>(S)) {
9204         FoundParameter = true;
9205         // Stop recursion: we found a parameter.
9206         return false;
9207       }
9208       // Keep looking.
9209       return true;
9210     }
9211     bool isDone() const {
9212       // Stop recursion if we have found a parameter.
9213       return FoundParameter;
9214     }
9215   };
9216 
9217   FindParameter F;
9218   SCEVTraversal<FindParameter> ST(F);
9219   ST.visitAll(S);
9220 
9221   return F.FoundParameter;
9222 }
9223 
9224 // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
9225 static inline bool
9226 containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
9227   for (const SCEV *T : Terms)
9228     if (containsParameters(T))
9229       return true;
9230   return false;
9231 }
9232 
9233 // Return the number of product terms in S.
9234 static inline int numberOfTerms(const SCEV *S) {
9235   if (const SCEVMulExpr *Expr = dyn_cast<SCEVMulExpr>(S))
9236     return Expr->getNumOperands();
9237   return 1;
9238 }
9239 
9240 static const SCEV *removeConstantFactors(ScalarEvolution &SE, const SCEV *T) {
9241   if (isa<SCEVConstant>(T))
9242     return nullptr;
9243 
9244   if (isa<SCEVUnknown>(T))
9245     return T;
9246 
9247   if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(T)) {
9248     SmallVector<const SCEV *, 2> Factors;
9249     for (const SCEV *Op : M->operands())
9250       if (!isa<SCEVConstant>(Op))
9251         Factors.push_back(Op);
9252 
9253     return SE.getMulExpr(Factors);
9254   }
9255 
9256   return T;
9257 }
9258 
9259 /// Return the size of an element read or written by Inst.
9260 const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
9261   Type *Ty;
9262   if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
9263     Ty = Store->getValueOperand()->getType();
9264   else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
9265     Ty = Load->getType();
9266   else
9267     return nullptr;
9268 
9269   Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
9270   return getSizeOfExpr(ETy, Ty);
9271 }
9272 
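// Determine the array dimension sizes that explain the collected Terms: the
// terms are deduplicated, sorted, divided by ElementSize and stripped of
// constant factors before findArrayDimensionsRec peels off one dimension size
// at a time.  ElementSize is appended as the last entry of Sizes.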
9273 void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
9274                                           SmallVectorImpl<const SCEV *> &Sizes,
9275                                           const SCEV *ElementSize) const {
9276   if (Terms.empty() || !ElementSize)
9277     return;
9278 
9279   // Early return when Terms do not contain parameters: we do not delinearize
9280   // non-parametric SCEVs.
9281   if (!containsParameters(Terms))
9282     return;
9283 
9284   DEBUG({
9285       dbgs() << "Terms:\n";
9286       for (const SCEV *T : Terms)
9287         dbgs() << *T << "\n";
9288     });
9289 
9290   // Remove duplicates.
9291   std::sort(Terms.begin(), Terms.end());
9292   Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end());
9293 
9294   // Put larger terms first.
9295   std::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) {
9296     return numberOfTerms(LHS) > numberOfTerms(RHS);
9297   });
9298 
9299   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
9300 
9301   // Try to divide all terms by the element size; keep the quotient when it is
9302   // non-zero, otherwise keep the original term.
9303   for (const SCEV *&Term : Terms) {
9304     const SCEV *Q, *R;
9305     SCEVDivision::divide(SE, Term, ElementSize, &Q, &R);
9306     if (!Q->isZero())
9307       Term = Q;
9308   }
9309 
9310   SmallVector<const SCEV *, 4> NewTerms;
9311 
9312   // Remove constant factors.
9313   for (const SCEV *T : Terms)
9314     if (const SCEV *NewT = removeConstantFactors(SE, T))
9315       NewTerms.push_back(NewT);
9316 
9317   DEBUG({
9318       dbgs() << "Terms after sorting:\n";
9319       for (const SCEV *T : NewTerms)
9320         dbgs() << *T << "\n";
9321     });
9322 
9323   if (NewTerms.empty() ||
9324       !findArrayDimensionsRec(SE, NewTerms, Sizes)) {
9325     Sizes.clear();
9326     return;
9327   }
9328 
9329   // The last element to be pushed into Sizes is the size of an element.
9330   Sizes.push_back(ElementSize);
9331 
9332   DEBUG({
9333       dbgs() << "Sizes:\n";
9334       for (const SCEV *S : Sizes)
9335         dbgs() << *S << "\n";
9336     });
9337 }
9338 
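// Compute the access functions (subscripts) of Expr with respect to the given
// dimension Sizes: Expr is divided first by the element size (the last entry
// of Sizes) and then by each dimension size from innermost to outermost; the
// remainder of each dimension division becomes one subscript and the final
// quotient becomes the outermost subscript.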
9339 void ScalarEvolution::computeAccessFunctions(
9340     const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts,
9341     SmallVectorImpl<const SCEV *> &Sizes) {
9342 
9343   // Early exit in case this SCEV is not an affine multivariate function.
9344   if (Sizes.empty())
9345     return;
9346 
9347   if (auto *AR = dyn_cast<SCEVAddRecExpr>(Expr))
9348     if (!AR->isAffine())
9349       return;
9350 
9351   const SCEV *Res = Expr;
9352   int Last = Sizes.size() - 1;
9353   for (int i = Last; i >= 0; i--) {
9354     const SCEV *Q, *R;
9355     SCEVDivision::divide(*this, Res, Sizes[i], &Q, &R);
9356 
9357     DEBUG({
9358         dbgs() << "Res: " << *Res << "\n";
9359         dbgs() << "Sizes[i]: " << *Sizes[i] << "\n";
9360         dbgs() << "Res divided by Sizes[i]:\n";
9361         dbgs() << "Quotient: " << *Q << "\n";
9362         dbgs() << "Remainder: " << *R << "\n";
9363       });
9364 
9365     Res = Q;
9366 
9367     // Do not record the last subscript corresponding to the size of elements in
9368     // the array.
9369     if (i == Last) {
9370 
9371       // Bail out if the remainder is too complex.
9372       if (isa<SCEVAddRecExpr>(R)) {
9373         Subscripts.clear();
9374         Sizes.clear();
9375         return;
9376       }
9377 
9378       continue;
9379     }
9380 
9381     // Record the access function for the current subscript.
9382     Subscripts.push_back(R);
9383   }
9384 
9385   // Also push in last position the quotient of the last division: it will be
9386   // the access function of the outermost dimension.
9387   Subscripts.push_back(Res);
9388 
9389   std::reverse(Subscripts.begin(), Subscripts.end());
9390 
9391   DEBUG({
9392       dbgs() << "Subscripts:\n";
9393       for (const SCEV *S : Subscripts)
9394         dbgs() << *S << "\n";
9395     });
9396 }
9397 
9398 /// Splits the SCEV into two vectors of SCEVs representing the subscripts and
9399 /// sizes of an array access.  The remainder of the delinearization is the
9400 /// offset at the start of the array.  The SCEV->delinearize algorithm computes
9401 /// the multiples of SCEV coefficients: that is, a pattern matching of
9402 /// subexpressions in the stride and base of a SCEV corresponding to the
9403 /// computation of a GCD (greatest common divisor) of base and stride.  When
9404 /// SCEV->delinearize fails, the Subscripts vector is left empty.
9405 ///
9406 /// For example: when analyzing the memory access A[i][j][k] in this loop nest
9407 ///
9408 ///  void foo(long n, long m, long o, double A[n][m][o]) {
9409 ///
9410 ///    for (long i = 0; i < n; i++)
9411 ///      for (long j = 0; j < m; j++)
9412 ///        for (long k = 0; k < o; k++)
9413 ///          A[i][j][k] = 1.0;
9414 ///  }
9415 ///
9416 /// the delinearization input is the following AddRec SCEV:
9417 ///
9418 ///  AddRec: {{{%A,+,(8 * %m * %o)}<%for.i>,+,(8 * %o)}<%for.j>,+,8}<%for.k>
9419 ///
9420 /// From this SCEV, we are able to say that the base offset of the access is %A
9421 /// because it appears as an offset that does not divide any of the strides in
9422 /// the loops:
9423 ///
9424 ///  CHECK: Base offset: %A
9425 ///
9426 /// and then SCEV->delinearize determines the size of some of the dimensions of
9427 /// the array as these are the multiples by which the strides are happening:
9428 ///
9429 ///  CHECK: ArrayDecl[UnknownSize][%m][%o] with elements of sizeof(double) bytes.
9430 ///
9431 /// Note that the outermost dimension remains of UnknownSize because there are
9432 /// no strides that would help identify the size of that dimension: when
9433 /// the array has been statically allocated, one could compute the size of that
9434 /// dimension by dividing the overall size of the array by the size of the known
9435 /// dimensions: %m * %o * 8.
9436 ///
9437 /// Finally delinearize provides the access functions for the array reference
9438 /// that corresponds to A[i][j][k] in the above C testcase:
9439 ///
9440 ///  CHECK: ArrayRef[{0,+,1}<%for.i>][{0,+,1}<%for.j>][{0,+,1}<%for.k>]
9441 ///
9442 /// The testcases check the output of a function pass, DelinearizationPass,
9443 /// which walks through all loads and stores of a function, asks for the SCEV
9444 /// of each memory access with respect to all enclosing loops, calls
9445 /// SCEV->delinearize on it, and prints the results.
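///
/// A minimal usage sketch (illustrative only; SE, L, Ptr and Inst are assumed
/// to be provided by the client pass and are not defined here):
///
///   const SCEV *AccessFn = SE->getSCEVAtScope(SE->getSCEV(Ptr), L);
///   SmallVector<const SCEV *, 3> Subscripts, Sizes;
///   SE->delinearize(AccessFn, Subscripts, Sizes, SE->getElementSize(Inst));
///   // On success, Subscripts.size() == Sizes.size() and Subscripts[i] is the
///   // access function for dimension i.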
9446 
9447 void ScalarEvolution::delinearize(const SCEV *Expr,
9448                                  SmallVectorImpl<const SCEV *> &Subscripts,
9449                                  SmallVectorImpl<const SCEV *> &Sizes,
9450                                  const SCEV *ElementSize) {
9451   // First step: collect parametric terms.
9452   SmallVector<const SCEV *, 4> Terms;
9453   collectParametricTerms(Expr, Terms);
9454 
9455   if (Terms.empty())
9456     return;
9457 
9458   // Second step: find subscript sizes.
9459   findArrayDimensions(Terms, Sizes, ElementSize);
9460 
9461   if (Sizes.empty())
9462     return;
9463 
9464   // Third step: compute the access functions for each subscript.
9465   computeAccessFunctions(Expr, Subscripts, Sizes);
9466 
9467   if (Subscripts.empty())
9468     return;
9469 
9470   DEBUG({
9471       dbgs() << "succeeded to delinearize " << *Expr << "\n";
9472       dbgs() << "ArrayDecl[UnknownSize]";
9473       for (const SCEV *S : Sizes)
9474         dbgs() << "[" << *S << "]";
9475 
9476       dbgs() << "\nArrayRef";
9477       for (const SCEV *S : Subscripts)
9478         dbgs() << "[" << *S << "]";
9479       dbgs() << "\n";
9480     });
9481 }
9482 
9483 //===----------------------------------------------------------------------===//
9484 //                   SCEVCallbackVH Class Implementation
9485 //===----------------------------------------------------------------------===//
9486 
9487 void ScalarEvolution::SCEVCallbackVH::deleted() {
9488   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
9489   if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
9490     SE->ConstantEvolutionLoopExitValue.erase(PN);
9491   SE->eraseValueFromMap(getValPtr());
9492   // this now dangles!
9493 }
9494 
9495 void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
9496   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
9497 
9498   // Forget all the expressions associated with users of the old value,
9499   // so that future queries will recompute the expressions using the new
9500   // value.
9501   Value *Old = getValPtr();
9502   SmallVector<User *, 16> Worklist(Old->user_begin(), Old->user_end());
9503   SmallPtrSet<User *, 8> Visited;
9504   while (!Worklist.empty()) {
9505     User *U = Worklist.pop_back_val();
9506     // Deleting the Old value will cause this to dangle. Postpone
9507     // that until everything else is done.
9508     if (U == Old)
9509       continue;
9510     if (!Visited.insert(U).second)
9511       continue;
9512     if (PHINode *PN = dyn_cast<PHINode>(U))
9513       SE->ConstantEvolutionLoopExitValue.erase(PN);
9514     SE->eraseValueFromMap(U);
9515     Worklist.insert(Worklist.end(), U->user_begin(), U->user_end());
9516   }
9517   // Delete the Old value.
9518   if (PHINode *PN = dyn_cast<PHINode>(Old))
9519     SE->ConstantEvolutionLoopExitValue.erase(PN);
9520   SE->eraseValueFromMap(Old);
9521   // this now dangles!
9522 }
9523 
9524 ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
9525   : CallbackVH(V), SE(se) {}
9526 
9527 //===----------------------------------------------------------------------===//
9528 //                   ScalarEvolution Class Implementation
9529 //===----------------------------------------------------------------------===//
9530 
9531 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
9532                                  AssumptionCache &AC, DominatorTree &DT,
9533                                  LoopInfo &LI)
9534     : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
9535       CouldNotCompute(new SCEVCouldNotCompute()),
9536       WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
9537       ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64),
9538       FirstUnknown(nullptr) {
9539 
9540   // To use guards for proving predicates, we need to scan every instruction in
9541   // relevant basic blocks, and not just terminators.  Doing this is a waste of
9542   // time if the IR does not actually contain any calls to
9543   // @llvm.experimental.guard, so do a quick check and remember this beforehand.
9544   //
9545   // This pessimizes the case where a pass that preserves ScalarEvolution wants
9546   // to _add_ guards to the module when there weren't any before, and wants
9547   // ScalarEvolution to optimize based on those guards.  For now we prefer to be
9548   // efficient in lieu of being smart in that rather obscure case.
9549 
9550   auto *GuardDecl = F.getParent()->getFunction(
9551       Intrinsic::getName(Intrinsic::experimental_guard));
9552   HasGuards = GuardDecl && !GuardDecl->use_empty();
9553 }
9554 
9555 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
9556     : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
9557       LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
9558       ValueExprMap(std::move(Arg.ValueExprMap)),
9559       PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
9560       WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
9561       BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
9562       PredicatedBackedgeTakenCounts(
9563           std::move(Arg.PredicatedBackedgeTakenCounts)),
9564       ConstantEvolutionLoopExitValue(
9565           std::move(Arg.ConstantEvolutionLoopExitValue)),
9566       ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
9567       LoopDispositions(std::move(Arg.LoopDispositions)),
9568       LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
9569       BlockDispositions(std::move(Arg.BlockDispositions)),
9570       UnsignedRanges(std::move(Arg.UnsignedRanges)),
9571       SignedRanges(std::move(Arg.SignedRanges)),
9572       UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
9573       UniquePreds(std::move(Arg.UniquePreds)),
9574       SCEVAllocator(std::move(Arg.SCEVAllocator)),
9575       FirstUnknown(Arg.FirstUnknown) {
9576   Arg.FirstUnknown = nullptr;
9577 }
9578 
9579 ScalarEvolution::~ScalarEvolution() {
9580   // Iterate through all the SCEVUnknown instances and call their
9581   // destructors, so that they release their references to their values.
9582   for (SCEVUnknown *U = FirstUnknown; U;) {
9583     SCEVUnknown *Tmp = U;
9584     U = U->Next;
9585     Tmp->~SCEVUnknown();
9586   }
9587   FirstUnknown = nullptr;
9588 
9589   ExprValueMap.clear();
9590   ValueExprMap.clear();
9591   HasRecMap.clear();
9592 
9593   // Free any extra memory created for ExitNotTakenInfo in the unlikely event
9594   // that a loop had multiple computable exits.
9595   for (auto &BTCI : BackedgeTakenCounts)
9596     BTCI.second.clear();
9597   for (auto &BTCI : PredicatedBackedgeTakenCounts)
9598     BTCI.second.clear();
9599 
9600   assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
9601   assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
9602   assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
9603 }
9604 
9605 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
9606   return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
9607 }
9608 
9609 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
9610                           const Loop *L) {
9611   // Print all inner loops first
9612   for (Loop *I : *L)
9613     PrintLoopInfo(OS, SE, I);
9614 
9615   OS << "Loop ";
9616   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
9617   OS << ": ";
9618 
9619   SmallVector<BasicBlock *, 8> ExitBlocks;
9620   L->getExitBlocks(ExitBlocks);
9621   if (ExitBlocks.size() != 1)
9622     OS << "<multiple exits> ";
9623 
9624   if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
9625     OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
9626   } else {
9627     OS << "Unpredictable backedge-taken count. ";
9628   }
9629 
9630   OS << "\n"
9631         "Loop ";
9632   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
9633   OS << ": ";
9634 
9635   if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
9636     OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
9637     if (SE->isBackedgeTakenCountMaxOrZero(L))
9638       OS << ", actual taken count either this or zero.";
9639   } else {
9640     OS << "Unpredictable max backedge-taken count. ";
9641   }
9642 
9643   OS << "\n"
9644         "Loop ";
9645   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
9646   OS << ": ";
9647 
9648   SCEVUnionPredicate Pred;
9649   auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred);
9650   if (!isa<SCEVCouldNotCompute>(PBT)) {
9651     OS << "Predicated backedge-taken count is " << *PBT << "\n";
9652     OS << " Predicates:\n";
9653     Pred.print(OS, 4);
9654   } else {
9655     OS << "Unpredictable predicated backedge-taken count. ";
9656   }
9657   OS << "\n";
9658 }
9659 
9660 static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) {
9661   switch (LD) {
9662   case ScalarEvolution::LoopVariant:
9663     return "Variant";
9664   case ScalarEvolution::LoopInvariant:
9665     return "Invariant";
9666   case ScalarEvolution::LoopComputable:
9667     return "Computable";
9668   }
9669   llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!");
9670 }
9671 
9672 void ScalarEvolution::print(raw_ostream &OS) const {
9673   // ScalarEvolution's implementation of the print method is to print
9674   // out SCEV values of all instructions that are interesting. Doing
9675   // this potentially causes it to create new SCEV objects though,
9676   // which technically conflicts with the const qualifier. This isn't
9677   // observable from outside the class though, so casting away the
9678   // const isn't dangerous.
9679   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
9680 
9681   OS << "Classifying expressions for: ";
9682   F.printAsOperand(OS, /*PrintType=*/false);
9683   OS << "\n";
9684   for (Instruction &I : instructions(F))
9685     if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
9686       OS << I << '\n';
9687       OS << "  -->  ";
9688       const SCEV *SV = SE.getSCEV(&I);
9689       SV->print(OS);
9690       if (!isa<SCEVCouldNotCompute>(SV)) {
9691         OS << " U: ";
9692         SE.getUnsignedRange(SV).print(OS);
9693         OS << " S: ";
9694         SE.getSignedRange(SV).print(OS);
9695       }
9696 
9697       const Loop *L = LI.getLoopFor(I.getParent());
9698 
9699       const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
9700       if (AtUse != SV) {
9701         OS << "  -->  ";
9702         AtUse->print(OS);
9703         if (!isa<SCEVCouldNotCompute>(AtUse)) {
9704           OS << " U: ";
9705           SE.getUnsignedRange(AtUse).print(OS);
9706           OS << " S: ";
9707           SE.getSignedRange(AtUse).print(OS);
9708         }
9709       }
9710 
9711       if (L) {
9712         OS << "\t\t" "Exits: ";
9713         const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
9714         if (!SE.isLoopInvariant(ExitValue, L)) {
9715           OS << "<<Unknown>>";
9716         } else {
9717           OS << *ExitValue;
9718         }
9719 
9720         bool First = true;
9721         for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
9722           if (First) {
9723             OS << "\t\t" "LoopDispositions: { ";
9724             First = false;
9725           } else {
9726             OS << ", ";
9727           }
9728 
9729           Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
9730           OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter));
9731         }
9732 
9733         for (auto *InnerL : depth_first(L)) {
9734           if (InnerL == L)
9735             continue;
9736           if (First) {
9737             OS << "\t\t" "LoopDispositions: { ";
9738             First = false;
9739           } else {
9740             OS << ", ";
9741           }
9742 
9743           InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
9744           OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL));
9745         }
9746 
9747         OS << " }";
9748       }
9749 
9750       OS << "\n";
9751     }
9752 
9753   OS << "Determining loop execution counts for: ";
9754   F.printAsOperand(OS, /*PrintType=*/false);
9755   OS << "\n";
9756   for (Loop *I : LI)
9757     PrintLoopInfo(OS, &SE, I);
9758 }
9759 
9760 ScalarEvolution::LoopDisposition
9761 ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
9762   auto &Values = LoopDispositions[S];
9763   for (auto &V : Values) {
9764     if (V.getPointer() == L)
9765       return V.getInt();
9766   }
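  // Conservatively cache LoopVariant before recursing: computeLoopDisposition
  // may re-enter getLoopDisposition for sub-expressions of S.  The placeholder
  // is fixed up below once the real disposition is known; the map is looked up
  // again because the recursion may have invalidated the Values reference.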
9767   Values.emplace_back(L, LoopVariant);
9768   LoopDisposition D = computeLoopDisposition(S, L);
9769   auto &Values2 = LoopDispositions[S];
9770   for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
9771     if (V.getPointer() == L) {
9772       V.setInt(D);
9773       break;
9774     }
9775   }
9776   return D;
9777 }
9778 
9779 ScalarEvolution::LoopDisposition
9780 ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
9781   switch (static_cast<SCEVTypes>(S->getSCEVType())) {
9782   case scConstant:
9783     return LoopInvariant;
9784   case scTruncate:
9785   case scZeroExtend:
9786   case scSignExtend:
9787     return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
9788   case scAddRecExpr: {
9789     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
9790 
9791     // If L is the addrec's loop, it's computable.
9792     if (AR->getLoop() == L)
9793       return LoopComputable;
9794 
9795     // Add recurrences are never invariant in the function-body (null loop).
9796     if (!L)
9797       return LoopVariant;
9798 
9799     // This recurrence is variant w.r.t. L if L contains AR's loop.
9800     if (L->contains(AR->getLoop()))
9801       return LoopVariant;
9802 
9803     // This recurrence is invariant w.r.t. L if AR's loop contains L.
9804     if (AR->getLoop()->contains(L))
9805       return LoopInvariant;
9806 
9807     // This recurrence is variant w.r.t. L if any of its operands
9808     // are variant.
9809     for (auto *Op : AR->operands())
9810       if (!isLoopInvariant(Op, L))
9811         return LoopVariant;
9812 
9813     // Otherwise it's loop-invariant.
9814     return LoopInvariant;
9815   }
9816   case scAddExpr:
9817   case scMulExpr:
9818   case scUMaxExpr:
9819   case scSMaxExpr: {
9820     bool HasVarying = false;
9821     for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
9822       LoopDisposition D = getLoopDisposition(Op, L);
9823       if (D == LoopVariant)
9824         return LoopVariant;
9825       if (D == LoopComputable)
9826         HasVarying = true;
9827     }
9828     return HasVarying ? LoopComputable : LoopInvariant;
9829   }
9830   case scUDivExpr: {
9831     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
9832     LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
9833     if (LD == LoopVariant)
9834       return LoopVariant;
9835     LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
9836     if (RD == LoopVariant)
9837       return LoopVariant;
9838     return (LD == LoopInvariant && RD == LoopInvariant) ?
9839            LoopInvariant : LoopComputable;
9840   }
9841   case scUnknown:
9842     // All non-instruction values are loop invariant.  All instructions are loop
9843     // invariant if they are not contained in the specified loop.
9844     // Instructions are never considered invariant in the function body
9845     // (null loop) because they are defined within the "loop".
9846     if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
9847       return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
9848     return LoopInvariant;
9849   case scCouldNotCompute:
9850     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9851   }
9852   llvm_unreachable("Unknown SCEV kind!");
9853 }
9854 
9855 bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
9856   return getLoopDisposition(S, L) == LoopInvariant;
9857 }
9858 
9859 bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
9860   return getLoopDisposition(S, L) == LoopComputable;
9861 }
9862 
9863 ScalarEvolution::BlockDisposition
9864 ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
9865   auto &Values = BlockDispositions[S];
9866   for (auto &V : Values) {
9867     if (V.getPointer() == BB)
9868       return V.getInt();
9869   }
9870   Values.emplace_back(BB, DoesNotDominateBlock);
9871   BlockDisposition D = computeBlockDisposition(S, BB);
9872   auto &Values2 = BlockDispositions[S];
9873   for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
9874     if (V.getPointer() == BB) {
9875       V.setInt(D);
9876       break;
9877     }
9878   }
9879   return D;
9880 }
9881 
9882 ScalarEvolution::BlockDisposition
9883 ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
9884   switch (static_cast<SCEVTypes>(S->getSCEVType())) {
9885   case scConstant:
9886     return ProperlyDominatesBlock;
9887   case scTruncate:
9888   case scZeroExtend:
9889   case scSignExtend:
9890     return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
9891   case scAddRecExpr: {
9892     // This uses a "dominates" query instead of "properly dominates" query
9893     // to test for proper dominance too, because the instruction which
9894     // produces the addrec's value is a PHI, and a PHI effectively properly
9895     // dominates its entire containing block.
9896     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
9897     if (!DT.dominates(AR->getLoop()->getHeader(), BB))
9898       return DoesNotDominateBlock;
9899 
9900     // Fall through into SCEVNAryExpr handling.
9901     LLVM_FALLTHROUGH;
9902   }
9903   case scAddExpr:
9904   case scMulExpr:
9905   case scUMaxExpr:
9906   case scSMaxExpr: {
9907     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
9908     bool Proper = true;
9909     for (const SCEV *NAryOp : NAry->operands()) {
9910       BlockDisposition D = getBlockDisposition(NAryOp, BB);
9911       if (D == DoesNotDominateBlock)
9912         return DoesNotDominateBlock;
9913       if (D == DominatesBlock)
9914         Proper = false;
9915     }
9916     return Proper ? ProperlyDominatesBlock : DominatesBlock;
9917   }
9918   case scUDivExpr: {
9919     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
9920     const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
9921     BlockDisposition LD = getBlockDisposition(LHS, BB);
9922     if (LD == DoesNotDominateBlock)
9923       return DoesNotDominateBlock;
9924     BlockDisposition RD = getBlockDisposition(RHS, BB);
9925     if (RD == DoesNotDominateBlock)
9926       return DoesNotDominateBlock;
9927     return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
9928       ProperlyDominatesBlock : DominatesBlock;
9929   }
9930   case scUnknown:
9931     if (Instruction *I =
9932           dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
9933       if (I->getParent() == BB)
9934         return DominatesBlock;
9935       if (DT.properlyDominates(I->getParent(), BB))
9936         return ProperlyDominatesBlock;
9937       return DoesNotDominateBlock;
9938     }
9939     return ProperlyDominatesBlock;
9940   case scCouldNotCompute:
9941     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9942   }
9943   llvm_unreachable("Unknown SCEV kind!");
9944 }
9945 
9946 bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
9947   return getBlockDisposition(S, BB) >= DominatesBlock;
9948 }
9949 
9950 bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
9951   return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
9952 }
9953 
9954 bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
9955   // Search for a SCEV expression node within an expression tree.
9956   // Implements SCEVTraversal::Visitor.
9957   struct SCEVSearch {
9958     const SCEV *Node;
9959     bool IsFound;
9960 
9961     SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}
9962 
9963     bool follow(const SCEV *S) {
9964       IsFound |= (S == Node);
9965       return !IsFound;
9966     }
9967     bool isDone() const { return IsFound; }
9968   };
9969 
9970   SCEVSearch Search(Op);
9971   visitAll(S, Search);
9972   return Search.IsFound;
9973 }
9974 
9975 void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
9976   ValuesAtScopes.erase(S);
9977   LoopDispositions.erase(S);
9978   BlockDispositions.erase(S);
9979   UnsignedRanges.erase(S);
9980   SignedRanges.erase(S);
9981   ExprValueMap.erase(S);
9982   HasRecMap.erase(S);
9983 
9984   auto RemoveSCEVFromBackedgeMap =
9985       [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
9986         for (auto I = Map.begin(), E = Map.end(); I != E;) {
9987           BackedgeTakenInfo &BEInfo = I->second;
9988           if (BEInfo.hasOperand(S, this)) {
9989             BEInfo.clear();
9990             Map.erase(I++);
9991           } else
9992             ++I;
9993         }
9994       };
9995 
9996   RemoveSCEVFromBackedgeMap(BackedgeTakenCounts);
9997   RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
9998 }
9999 
10000 typedef DenseMap<const Loop *, std::string> VerifyMap;
10001 
10002 /// replaceSubString - Replaces all occurrences of From in Str with To.
10003 static void replaceSubString(std::string &Str, StringRef From, StringRef To) {
10004   size_t Pos = 0;
10005   while ((Pos = Str.find(From, Pos)) != std::string::npos) {
10006     Str.replace(Pos, From.size(), To.data(), To.size());
10007     Pos += To.size();
10008   }
10009 }
10010 
10011 /// getLoopBackedgeTakenCounts - Helper method for verifyAnalysis.
10012 static void
10013 getLoopBackedgeTakenCounts(Loop *L, VerifyMap &Map, ScalarEvolution &SE) {
10014   std::string &S = Map[L];
10015   if (S.empty()) {
10016     raw_string_ostream OS(S);
10017     SE.getBackedgeTakenCount(L)->print(OS);
10018 
10019     // false and 0 are semantically equivalent. This can happen in dead loops.
10020     replaceSubString(OS.str(), "false", "0");
10021     // Remove wrap flags, their use in SCEV is highly fragile.
10022     // FIXME: Remove this when SCEV gets smarter about them.
10023     replaceSubString(OS.str(), "<nw>", "");
10024     replaceSubString(OS.str(), "<nsw>", "");
10025     replaceSubString(OS.str(), "<nuw>", "");
10026   }
10027 
10028   for (auto *R : reverse(*L))
10029     getLoopBackedgeTakenCounts(R, Map, SE); // recurse.
10030 }
10031 
10032 void ScalarEvolution::verify() const {
10033   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
10034 
10035   // Gather stringified backedge taken counts for all loops using SCEV's caches.
10036   // FIXME: It would be much better to store actual values instead of strings,
10037   //        but SCEV pointers will change if we drop the caches.
10038   VerifyMap BackedgeDumpsOld, BackedgeDumpsNew;
10039   for (LoopInfo::reverse_iterator I = LI.rbegin(), E = LI.rend(); I != E; ++I)
10040     getLoopBackedgeTakenCounts(*I, BackedgeDumpsOld, SE);
10041 
10042   // Gather stringified backedge taken counts for all loops using a fresh
10043   // ScalarEvolution object.
10044   ScalarEvolution SE2(F, TLI, AC, DT, LI);
10045   for (LoopInfo::reverse_iterator I = LI.rbegin(), E = LI.rend(); I != E; ++I)
10046     getLoopBackedgeTakenCounts(*I, BackedgeDumpsNew, SE2);
10047 
10048   // Now compare whether they're the same with and without caches. This allows
10049   // verifying that no pass changed the cache.
10050   assert(BackedgeDumpsOld.size() == BackedgeDumpsNew.size() &&
10051          "New loops suddenly appeared!");
10052 
10053   for (VerifyMap::iterator OldI = BackedgeDumpsOld.begin(),
10054                            OldE = BackedgeDumpsOld.end(),
10055                            NewI = BackedgeDumpsNew.begin();
10056        OldI != OldE; ++OldI, ++NewI) {
10057     assert(OldI->first == NewI->first && "Loop order changed!");
10058 
10059     // Compare the stringified SCEVs. We don't care if undef backedge-taken count
10060     // changes.
10061     // FIXME: We currently ignore SCEV changes from/to CouldNotCompute. This
10062     // means that a pass is buggy or SCEV has to learn a new pattern but is
10063     // usually not harmful.
10064     if (OldI->second != NewI->second &&
10065         OldI->second.find("undef") == std::string::npos &&
10066         NewI->second.find("undef") == std::string::npos &&
10067         OldI->second != "***COULDNOTCOMPUTE***" &&
10068         NewI->second != "***COULDNOTCOMPUTE***") {
10069       dbgs() << "SCEVValidator: SCEV for loop '"
10070              << OldI->first->getHeader()->getName()
10071              << "' changed from '" << OldI->second
10072              << "' to '" << NewI->second << "'!\n";
10073       std::abort();
10074     }
10075   }
10076 
10077   // TODO: Verify more things.
10078 }
10079 
10080 char ScalarEvolutionAnalysis::PassID;
10081 
10082 ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
10083                                              FunctionAnalysisManager &AM) {
10084   return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F),
10085                          AM.getResult<AssumptionAnalysis>(F),
10086                          AM.getResult<DominatorTreeAnalysis>(F),
10087                          AM.getResult<LoopAnalysis>(F));
10088 }
10089 
10090 PreservedAnalyses
10091 ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
10092   AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
10093   return PreservedAnalyses::all();
10094 }
10095 
10096 INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
10097                       "Scalar Evolution Analysis", false, true)
10098 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
10099 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
10100 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
10101 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
10102 INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
10103                     "Scalar Evolution Analysis", false, true)
10104 char ScalarEvolutionWrapperPass::ID = 0;
10105 
10106 ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
10107   initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
10108 }
10109 
10110 bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
10111   SE.reset(new ScalarEvolution(
10112       F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(),
10113       getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
10114       getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
10115       getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
10116   return false;
10117 }
10118 
10119 void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
10120 
10121 void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
10122   SE->print(OS);
10123 }
10124 
10125 void ScalarEvolutionWrapperPass::verifyAnalysis() const {
10126   if (!VerifySCEV)
10127     return;
10128 
10129   SE->verify();
10130 }
10131 
10132 void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
10133   AU.setPreservesAll();
10134   AU.addRequiredTransitive<AssumptionCacheTracker>();
10135   AU.addRequiredTransitive<LoopInfoWrapperPass>();
10136   AU.addRequiredTransitive<DominatorTreeWrapperPass>();
10137   AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
10138 }
10139 
10140 const SCEVPredicate *
10141 ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
10142                                    const SCEVConstant *RHS) {
10143   FoldingSetNodeID ID;
10144   // Unique this node based on the arguments
10145   ID.AddInteger(SCEVPredicate::P_Equal);
10146   ID.AddPointer(LHS);
10147   ID.AddPointer(RHS);
10148   void *IP = nullptr;
10149   if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
10150     return S;
10151   SCEVEqualPredicate *Eq = new (SCEVAllocator)
10152       SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
10153   UniquePreds.InsertNode(Eq, IP);
10154   return Eq;
10155 }
10156 
10157 const SCEVPredicate *ScalarEvolution::getWrapPredicate(
10158     const SCEVAddRecExpr *AR,
10159     SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
10160   FoldingSetNodeID ID;
10161   // Unique this node based on the arguments
10162   ID.AddInteger(SCEVPredicate::P_Wrap);
10163   ID.AddPointer(AR);
10164   ID.AddInteger(AddedFlags);
10165   void *IP = nullptr;
10166   if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
10167     return S;
10168   auto *OF = new (SCEVAllocator)
10169       SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
10170   UniquePreds.InsertNode(OF, IP);
10171   return OF;
10172 }
10173 
10174 namespace {
10175 
10176 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
10177 public:
10178   /// Rewrites \p S in the context of a loop L and the SCEV predication
10179   /// infrastructure.
10180   ///
10181   /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
10182   /// equivalences present in \p Pred.
10183   ///
10184   /// If \p NewPreds is non-null, rewrite is free to add further predicates to
10185   /// \p NewPreds such that the result will be an AddRecExpr.
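  ///
  /// For example (illustrative), rewriting zext({%a,+,%s}<%L>) with a non-null
  /// \p NewPreds may record an <nusw> wrap predicate for {%a,+,%s}<%L> and
  /// return the AddRecExpr {zext(%a),+,sext(%s)}<%L>.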
10186   static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
10187                              SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
10188                              SCEVUnionPredicate *Pred) {
10189     SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
10190     return Rewriter.visit(S);
10191   }
10192 
10193   SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
10194                         SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
10195                         SCEVUnionPredicate *Pred)
10196       : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
10197 
10198   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
10199     if (Pred) {
10200       auto ExprPreds = Pred->getPredicatesForExpr(Expr);
10201       for (auto *Pred : ExprPreds)
10202         if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
10203           if (IPred->getLHS() == Expr)
10204             return IPred->getRHS();
10205     }
10206 
10207     return Expr;
10208   }
10209 
10210   const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
10211     const SCEV *Operand = visit(Expr->getOperand());
10212     const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
10213     if (AR && AR->getLoop() == L && AR->isAffine()) {
10214       // This couldn't be folded because the operand didn't have the nuw
10215       // flag. Add the nusw flag as an assumption that we could make.
10216       const SCEV *Step = AR->getStepRecurrence(SE);
10217       Type *Ty = Expr->getType();
10218       if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
10219         return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
10220                                 SE.getSignExtendExpr(Step, Ty), L,
10221                                 AR->getNoWrapFlags());
10222     }
10223     return SE.getZeroExtendExpr(Operand, Expr->getType());
10224   }
10225 
10226   const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
10227     const SCEV *Operand = visit(Expr->getOperand());
10228     const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
10229     if (AR && AR->getLoop() == L && AR->isAffine()) {
10230       // This couldn't be folded because the operand didn't have the nsw
10231       // flag. Add the nssw flag as an assumption that we could make.
10232       const SCEV *Step = AR->getStepRecurrence(SE);
10233       Type *Ty = Expr->getType();
10234       if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
10235         return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
10236                                 SE.getSignExtendExpr(Step, Ty), L,
10237                                 AR->getNoWrapFlags());
10238     }
10239     return SE.getSignExtendExpr(Operand, Expr->getType());
10240   }
10241 
10242 private:
10243   bool addOverflowAssumption(const SCEVAddRecExpr *AR,
10244                              SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
10245     auto *A = SE.getWrapPredicate(AR, AddedFlags);
10246     if (!NewPreds) {
10247       // Check if we've already made this assumption.
10248       return Pred && Pred->implies(A);
10249     }
10250     NewPreds->insert(A);
10251     return true;
10252   }
10253 
10254   SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
10255   SCEVUnionPredicate *Pred;
10256   const Loop *L;
10257 };
10258 } // end anonymous namespace
10259 
10260 const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
10261                                                    SCEVUnionPredicate &Preds) {
10262   return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
10263 }
10264 
10265 const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
10266     const SCEV *S, const Loop *L,
10267     SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
10268 
10269   SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
10270   S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
10271   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
10272 
10273   if (!AddRec)
10274     return nullptr;
10275 
10276   // Since the transformation was successful, we can now transfer the SCEV
10277   // predicates.
10278   for (auto *P : TransformPreds)
10279     Preds.insert(P);
10280 
10281   return AddRec;
10282 }
10283 
10284 /// SCEV predicates
10285 SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
10286                              SCEVPredicateKind Kind)
10287     : FastID(ID), Kind(Kind) {}
10288 
10289 SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
10290                                        const SCEVUnknown *LHS,
10291                                        const SCEVConstant *RHS)
10292     : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
10293 
10294 bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
10295   const auto *Op = dyn_cast<SCEVEqualPredicate>(N);
10296 
10297   if (!Op)
10298     return false;
10299 
10300   return Op->LHS == LHS && Op->RHS == RHS;
10301 }
10302 
10303 bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
10304 
10305 const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
10306 
10307 void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
10308   OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
10309 }
10310 
10311 SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
10312                                      const SCEVAddRecExpr *AR,
10313                                      IncrementWrapFlags Flags)
10314     : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
10315 
10316 const SCEV *SCEVWrapPredicate::getExpr() const { return AR; }
10317 
10318 bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
10319   const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
10320 
10321   return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
10322 }
10323 
10324 bool SCEVWrapPredicate::isAlwaysTrue() const {
10325   SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
10326   IncrementWrapFlags IFlags = Flags;
10327 
10328   if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
10329     IFlags = clearFlags(IFlags, IncrementNSSW);
10330 
10331   return IFlags == IncrementAnyWrap;
10332 }
10333 
10334 void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
10335   OS.indent(Depth) << *getExpr() << " Added Flags: ";
10336   if (SCEVWrapPredicate::IncrementNUSW & getFlags())
10337     OS << "<nusw>";
10338   if (SCEVWrapPredicate::IncrementNSSW & getFlags())
10339     OS << "<nssw>";
10340   OS << "\n";
10341 }
10342 
10343 SCEVWrapPredicate::IncrementWrapFlags
10344 SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
10345                                    ScalarEvolution &SE) {
10346   IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
10347   SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
10348 
10349   // We can safely transfer the NSW flag as NSSW.
10350   if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
10351     ImpliedFlags = IncrementNSSW;
10352 
10353   if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
10354     // If the increment is non-negative, the SCEV NUW flag will also imply the
10355     // WrapPredicate NUSW flag.
10356     if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
10357       if (Step->getValue()->getValue().isNonNegative())
10358         ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
10359   }
10360 
10361   return ImpliedFlags;
10362 }
10363 
10364 /// Union predicates don't get cached, so create a dummy FoldingSet ID for them.
10365 SCEVUnionPredicate::SCEVUnionPredicate()
10366     : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
10367 
10368 bool SCEVUnionPredicate::isAlwaysTrue() const {
10369   return all_of(Preds,
10370                 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
10371 }
10372 
10373 ArrayRef<const SCEVPredicate *>
10374 SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
10375   auto I = SCEVToPreds.find(Expr);
10376   if (I == SCEVToPreds.end())
10377     return ArrayRef<const SCEVPredicate *>();
10378   return I->second;
10379 }
10380 
10381 bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
10382   if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
10383     return all_of(Set->Preds,
10384                   [this](const SCEVPredicate *I) { return this->implies(I); });
10385 
10386   auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
10387   if (ScevPredsIt == SCEVToPreds.end())
10388     return false;
10389   auto &SCEVPreds = ScevPredsIt->second;
10390 
10391   return any_of(SCEVPreds,
10392                 [N](const SCEVPredicate *I) { return I->implies(N); });
10393 }
10394 
10395 const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
10396 
10397 void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
10398   for (auto Pred : Preds)
10399     Pred->print(OS, Depth);
10400 }
10401 
10402 void SCEVUnionPredicate::add(const SCEVPredicate *N) {
10403   if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
10404     for (auto Pred : Set->Preds)
10405       add(Pred);
10406     return;
10407   }
10408 
10409   if (implies(N))
10410     return;
10411 
10412   const SCEV *Key = N->getExpr();
10413   assert(Key && "Only SCEVUnionPredicate doesn't have an "
10414                 "associated expression!");
10415 
10416   SCEVToPreds[Key].push_back(N);
10417   Preds.push_back(N);
10418 }
10419 
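// A minimal usage sketch for PredicatedScalarEvolution (illustrative only; SE,
// L and Ptr are assumed to be provided by the client):
//
//   PredicatedScalarEvolution PSE(SE, *L);
//   if (const SCEVAddRecExpr *AR = PSE.getAsAddRec(Ptr)) {
//     // AR is only valid under PSE.getUnionPredicate(); the client must
//     // ensure these predicates hold (e.g. by versioning the loop) before
//     // relying on AR.
//   }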
10420 PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
10421                                                      Loop &L)
10422     : SE(SE), L(L), Generation(0), BackedgeCount(nullptr) {}
10423 
10424 const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
10425   const SCEV *Expr = SE.getSCEV(V);
10426   RewriteEntry &Entry = RewriteMap[Expr];
10427 
10428   // If we already have an entry and the version matches, return it.
10429   if (Entry.second && Generation == Entry.first)
10430     return Entry.second;
10431 
10432   // We found an entry but it's stale. Rewrite the stale entry
10433   // according to the current predicate.
10434   if (Entry.second)
10435     Expr = Entry.second;
10436 
10437   const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds);
10438   Entry = {Generation, NewSCEV};
10439 
10440   return NewSCEV;
10441 }
10442 
10443 const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
10444   if (!BackedgeCount) {
10445     SCEVUnionPredicate BackedgePred;
10446     BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred);
10447     addPredicate(BackedgePred);
10448   }
10449   return BackedgeCount;
10450 }
10451 
10452 void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
10453   if (Preds.implies(&Pred))
10454     return;
10455   Preds.add(&Pred);
10456   updateGeneration();
10457 }
10458 
10459 const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const {
10460   return Preds;
10461 }
10462 
10463 void PredicatedScalarEvolution::updateGeneration() {
10464   // If the generation number wrapped, recompute everything.
10465   if (++Generation == 0) {
10466     for (auto &II : RewriteMap) {
10467       const SCEV *Rewritten = II.second.second;
10468       II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)};
10469     }
10470   }
10471 }
10472 
10473 void PredicatedScalarEvolution::setNoOverflow(
10474     Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
10475   const SCEV *Expr = getSCEV(V);
10476   const auto *AR = cast<SCEVAddRecExpr>(Expr);
10477 
10478   auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
10479 
10480   // Clear the statically implied flags.
10481   Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
10482   addPredicate(*SE.getWrapPredicate(AR, Flags));
10483 
10484   auto II = FlagsMap.insert({V, Flags});
10485   if (!II.second)
10486     II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
10487 }
10488 
10489 bool PredicatedScalarEvolution::hasNoOverflow(
10490     Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
10491   const SCEV *Expr = getSCEV(V);
10492   const auto *AR = cast<SCEVAddRecExpr>(Expr);
10493 
10494   Flags = SCEVWrapPredicate::clearFlags(
10495       Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
10496 
10497   auto II = FlagsMap.find(V);
10498 
10499   if (II != FlagsMap.end())
10500     Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
10501 
10502   return Flags == SCEVWrapPredicate::IncrementAnyWrap;
10503 }
10504 
10505 const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
10506   const SCEV *Expr = this->getSCEV(V);
10507   SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
10508   auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
10509 
10510   if (!New)
10511     return nullptr;
10512 
10513   for (auto *P : NewPreds)
10514     Preds.add(P);
10515 
10516   updateGeneration();
10517   RewriteMap[SE.getSCEV(V)] = {Generation, New};
10518   return New;
10519 }
10520 
10521 PredicatedScalarEvolution::PredicatedScalarEvolution(
10522     const PredicatedScalarEvolution &Init)
10523     : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds),
10524       Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
10525   for (const auto &I : Init.FlagsMap)
10526     FlagsMap.insert(I);
10527 }
10528 
10529 void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
10530   // For each block.
10531   for (auto *BB : L.getBlocks())
10532     for (auto &I : *BB) {
10533       if (!SE.isSCEVable(I.getType()))
10534         continue;
10535 
10536       auto *Expr = SE.getSCEV(&I);
10537       auto II = RewriteMap.find(Expr);
10538 
10539       if (II == RewriteMap.end())
10540         continue;
10541 
10542       // Don't print things that are not interesting.
10543       if (II->second.second == Expr)
10544         continue;
10545 
10546       OS.indent(Depth) << "[PSE]" << I << ":\n";
10547       OS.indent(Depth + 2) << *Expr << "\n";
10548       OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
10549     }
10550 }
10551