xref: /llvm-project/llvm/lib/Analysis/ScalarEvolution.cpp (revision b4916918e5219ac25a5b6472c5638450f867d975)
1 //===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the scalar evolution analysis
10 // engine, which is used primarily to analyze expressions involving induction
11 // variables in loops.
12 //
13 // There are several aspects to this library.  First is the representation of
14 // scalar expressions, which are represented as subclasses of the SCEV class.
15 // These classes are used to represent certain types of subexpressions that we
16 // can handle. We only create one SCEV of a particular shape, so
17 // pointer-comparisons for equality are legal.
18 //
19 // One important aspect of the SCEV objects is that they are never cyclic, even
20 // if there is a cycle in the dataflow for an expression (i.e., a PHI node).  If
21 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
22 // recurrence) then we represent it directly as a recurrence node, otherwise we
23 // represent it as a SCEVUnknown node.
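//
// As a small illustration of the representation (using the {Start,+,Step}
// notation that SCEV prints): the induction variable of a loop that starts
// at 0 and is incremented by 1 each iteration is the add recurrence
// {0,+,1}<%loop>, where %loop names the loop's header block, and a derived
// value such as 4*i + 7 becomes {7,+,4}<%loop>.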
24 //
25 // In addition to being able to represent expressions of various types, we also
26 // have folders that are used to build the *canonical* representation for a
27 // particular expression.  These folders are capable of using a variety of
28 // rewrite rules to simplify the expressions.
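//
// For instance, when the expressions (%a + 3) and (%a + 2) are added, the
// add folder combines the constants and the repeated unknown %a, producing
// the canonical form (5 + (2 * %a)) rather than a four-operand sum.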
29 //
30 // Once the folders are defined, we can implement the more interesting
31 // higher-level code, such as the code that recognizes PHI nodes of various
32 // types, computes the execution count of a loop, etc.
33 //
34 // TODO: We should use these routines and value representations to implement
35 // dependence analysis!
36 //
37 //===----------------------------------------------------------------------===//
38 //
39 // There are several good references for the techniques used in this analysis.
40 //
41 //  Chains of recurrences -- a method to expedite the evaluation
42 //  of closed-form functions
43 //  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44 //
45 //  On computational properties of chains of recurrences
46 //  Eugene V. Zima
47 //
48 //  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49 //  Robert A. van Engelen
50 //
51 //  Efficient Symbolic Analysis for Optimizing Compilers
52 //  Robert A. van Engelen
53 //
54 //  Using the chains of recurrences algebra for data dependence testing and
55 //  induction variable substitution
56 //  MS Thesis, Johnie Birch
57 //
58 //===----------------------------------------------------------------------===//
59 
60 #include "llvm/Analysis/ScalarEvolution.h"
61 #include "llvm/ADT/APInt.h"
62 #include "llvm/ADT/ArrayRef.h"
63 #include "llvm/ADT/DenseMap.h"
64 #include "llvm/ADT/DepthFirstIterator.h"
65 #include "llvm/ADT/EquivalenceClasses.h"
66 #include "llvm/ADT/FoldingSet.h"
67 #include "llvm/ADT/None.h"
68 #include "llvm/ADT/Optional.h"
69 #include "llvm/ADT/STLExtras.h"
70 #include "llvm/ADT/ScopeExit.h"
71 #include "llvm/ADT/Sequence.h"
72 #include "llvm/ADT/SetVector.h"
73 #include "llvm/ADT/SmallPtrSet.h"
74 #include "llvm/ADT/SmallSet.h"
75 #include "llvm/ADT/SmallVector.h"
76 #include "llvm/ADT/Statistic.h"
77 #include "llvm/ADT/StringRef.h"
78 #include "llvm/Analysis/AssumptionCache.h"
79 #include "llvm/Analysis/ConstantFolding.h"
80 #include "llvm/Analysis/InstructionSimplify.h"
81 #include "llvm/Analysis/LoopInfo.h"
82 #include "llvm/Analysis/ScalarEvolutionDivision.h"
83 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
84 #include "llvm/Analysis/TargetLibraryInfo.h"
85 #include "llvm/Analysis/ValueTracking.h"
86 #include "llvm/Config/llvm-config.h"
87 #include "llvm/IR/Argument.h"
88 #include "llvm/IR/BasicBlock.h"
89 #include "llvm/IR/CFG.h"
90 #include "llvm/IR/Constant.h"
91 #include "llvm/IR/ConstantRange.h"
92 #include "llvm/IR/Constants.h"
93 #include "llvm/IR/DataLayout.h"
94 #include "llvm/IR/DerivedTypes.h"
95 #include "llvm/IR/Dominators.h"
96 #include "llvm/IR/Function.h"
97 #include "llvm/IR/GlobalAlias.h"
98 #include "llvm/IR/GlobalValue.h"
99 #include "llvm/IR/GlobalVariable.h"
100 #include "llvm/IR/InstIterator.h"
101 #include "llvm/IR/InstrTypes.h"
102 #include "llvm/IR/Instruction.h"
103 #include "llvm/IR/Instructions.h"
104 #include "llvm/IR/IntrinsicInst.h"
105 #include "llvm/IR/Intrinsics.h"
106 #include "llvm/IR/LLVMContext.h"
107 #include "llvm/IR/Metadata.h"
108 #include "llvm/IR/Operator.h"
109 #include "llvm/IR/PatternMatch.h"
110 #include "llvm/IR/Type.h"
111 #include "llvm/IR/Use.h"
112 #include "llvm/IR/User.h"
113 #include "llvm/IR/Value.h"
114 #include "llvm/IR/Verifier.h"
115 #include "llvm/InitializePasses.h"
116 #include "llvm/Pass.h"
117 #include "llvm/Support/Casting.h"
118 #include "llvm/Support/CommandLine.h"
119 #include "llvm/Support/Compiler.h"
120 #include "llvm/Support/Debug.h"
121 #include "llvm/Support/ErrorHandling.h"
122 #include "llvm/Support/KnownBits.h"
123 #include "llvm/Support/SaveAndRestore.h"
124 #include "llvm/Support/raw_ostream.h"
125 #include <algorithm>
126 #include <cassert>
127 #include <climits>
128 #include <cstddef>
129 #include <cstdint>
130 #include <cstdlib>
131 #include <map>
132 #include <memory>
133 #include <tuple>
134 #include <utility>
135 #include <vector>
136 
137 using namespace llvm;
138 
139 #define DEBUG_TYPE "scalar-evolution"
140 
141 STATISTIC(NumArrayLenItCounts,
142           "Number of trip counts computed with array length");
143 STATISTIC(NumTripCountsComputed,
144           "Number of loops with predictable loop counts");
145 STATISTIC(NumTripCountsNotComputed,
146           "Number of loops without predictable loop counts");
147 STATISTIC(NumBruteForceTripCountsComputed,
148           "Number of loops with trip counts computed by force");
149 
150 static cl::opt<unsigned>
151 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
152                         cl::ZeroOrMore,
153                         cl::desc("Maximum number of iterations SCEV will "
154                                  "symbolically execute a constant "
155                                  "derived loop"),
156                         cl::init(100));
157 
158 // FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
159 static cl::opt<bool> VerifySCEV(
160     "verify-scev", cl::Hidden,
161     cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
162 static cl::opt<bool> VerifySCEVStrict(
163     "verify-scev-strict", cl::Hidden,
164     cl::desc("Enable stricter verification when -verify-scev is passed"));
165 static cl::opt<bool>
166     VerifySCEVMap("verify-scev-maps", cl::Hidden,
167                   cl::desc("Verify no dangling value in ScalarEvolution's "
168                            "ExprValueMap (slow)"));
169 
170 static cl::opt<bool> VerifyIR(
171     "scev-verify-ir", cl::Hidden,
172     cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
173     cl::init(false));
174 
175 static cl::opt<unsigned> MulOpsInlineThreshold(
176     "scev-mulops-inline-threshold", cl::Hidden,
177     cl::desc("Threshold for inlining multiplication operands into a SCEV"),
178     cl::init(32));
179 
180 static cl::opt<unsigned> AddOpsInlineThreshold(
181     "scev-addops-inline-threshold", cl::Hidden,
182     cl::desc("Threshold for inlining addition operands into a SCEV"),
183     cl::init(500));
184 
185 static cl::opt<unsigned> MaxSCEVCompareDepth(
186     "scalar-evolution-max-scev-compare-depth", cl::Hidden,
187     cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
188     cl::init(32));
189 
190 static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
191     "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
192     cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
193     cl::init(2));
194 
195 static cl::opt<unsigned> MaxValueCompareDepth(
196     "scalar-evolution-max-value-compare-depth", cl::Hidden,
197     cl::desc("Maximum depth of recursive value complexity comparisons"),
198     cl::init(2));
199 
200 static cl::opt<unsigned>
201     MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
202                   cl::desc("Maximum depth of recursive arithmetics"),
203                   cl::init(32));
204 
205 static cl::opt<unsigned> MaxConstantEvolvingDepth(
206     "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
207     cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
208 
209 static cl::opt<unsigned>
210     MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
211                  cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
212                  cl::init(8));
213 
214 static cl::opt<unsigned>
215     MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
216                   cl::desc("Max coefficients in AddRec during evolving"),
217                   cl::init(8));
218 
219 static cl::opt<unsigned>
220     HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
221                   cl::desc("Size of the expression which is considered huge"),
222                   cl::init(4096));
223 
224 static cl::opt<bool>
225 ClassifyExpressions("scalar-evolution-classify-expressions",
226     cl::Hidden, cl::init(true),
227     cl::desc("When printing analysis, include information on every instruction"));
228 
229 static cl::opt<bool> UseExpensiveRangeSharpening(
230     "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
231     cl::init(false),
232     cl::desc("Use more powerful methods of sharpening expression ranges. May "
233              "be costly in terms of compile time"));
234 
235 //===----------------------------------------------------------------------===//
236 //                           SCEV class definitions
237 //===----------------------------------------------------------------------===//
238 
239 //===----------------------------------------------------------------------===//
240 // Implementation of the SCEV class.
241 //
242 
243 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
244 LLVM_DUMP_METHOD void SCEV::dump() const {
245   print(dbgs());
246   dbgs() << '\n';
247 }
248 #endif
249 
250 void SCEV::print(raw_ostream &OS) const {
251   switch (getSCEVType()) {
252   case scConstant:
253     cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
254     return;
255   case scPtrToInt: {
256     const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
257     const SCEV *Op = PtrToInt->getOperand();
258     OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
259        << *PtrToInt->getType() << ")";
260     return;
261   }
262   case scTruncate: {
263     const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
264     const SCEV *Op = Trunc->getOperand();
265     OS << "(trunc " << *Op->getType() << " " << *Op << " to "
266        << *Trunc->getType() << ")";
267     return;
268   }
269   case scZeroExtend: {
270     const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
271     const SCEV *Op = ZExt->getOperand();
272     OS << "(zext " << *Op->getType() << " " << *Op << " to "
273        << *ZExt->getType() << ")";
274     return;
275   }
276   case scSignExtend: {
277     const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
278     const SCEV *Op = SExt->getOperand();
279     OS << "(sext " << *Op->getType() << " " << *Op << " to "
280        << *SExt->getType() << ")";
281     return;
282   }
283   case scAddRecExpr: {
284     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
285     OS << "{" << *AR->getOperand(0);
286     for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
287       OS << ",+," << *AR->getOperand(i);
288     OS << "}<";
289     if (AR->hasNoUnsignedWrap())
290       OS << "nuw><";
291     if (AR->hasNoSignedWrap())
292       OS << "nsw><";
293     if (AR->hasNoSelfWrap() &&
294         !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
295       OS << "nw><";
296     AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
297     OS << ">";
298     return;
299   }
300   case scAddExpr:
301   case scMulExpr:
302   case scUMaxExpr:
303   case scSMaxExpr:
304   case scUMinExpr:
305   case scSMinExpr: {
306     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
307     const char *OpStr = nullptr;
308     switch (NAry->getSCEVType()) {
309     case scAddExpr: OpStr = " + "; break;
310     case scMulExpr: OpStr = " * "; break;
311     case scUMaxExpr: OpStr = " umax "; break;
312     case scSMaxExpr: OpStr = " smax "; break;
313     case scUMinExpr:
314       OpStr = " umin ";
315       break;
316     case scSMinExpr:
317       OpStr = " smin ";
318       break;
319     default:
320       llvm_unreachable("There are no other nary expression types.");
321     }
322     OS << "(";
323     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
324          I != E; ++I) {
325       OS << **I;
326       if (std::next(I) != E)
327         OS << OpStr;
328     }
329     OS << ")";
330     switch (NAry->getSCEVType()) {
331     case scAddExpr:
332     case scMulExpr:
333       if (NAry->hasNoUnsignedWrap())
334         OS << "<nuw>";
335       if (NAry->hasNoSignedWrap())
336         OS << "<nsw>";
337       break;
338     default:
339       // Nothing to print for other nary expressions.
340       break;
341     }
342     return;
343   }
344   case scUDivExpr: {
345     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
346     OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
347     return;
348   }
349   case scUnknown: {
350     const SCEVUnknown *U = cast<SCEVUnknown>(this);
351     Type *AllocTy;
352     if (U->isSizeOf(AllocTy)) {
353       OS << "sizeof(" << *AllocTy << ")";
354       return;
355     }
356     if (U->isAlignOf(AllocTy)) {
357       OS << "alignof(" << *AllocTy << ")";
358       return;
359     }
360 
361     Type *CTy;
362     Constant *FieldNo;
363     if (U->isOffsetOf(CTy, FieldNo)) {
364       OS << "offsetof(" << *CTy << ", ";
365       FieldNo->printAsOperand(OS, false);
366       OS << ")";
367       return;
368     }
369 
370     // Otherwise just print it normally.
371     U->getValue()->printAsOperand(OS, false);
372     return;
373   }
374   case scCouldNotCompute:
375     OS << "***COULDNOTCOMPUTE***";
376     return;
377   }
378   llvm_unreachable("Unknown SCEV kind!");
379 }
380 
381 Type *SCEV::getType() const {
382   switch (getSCEVType()) {
383   case scConstant:
384     return cast<SCEVConstant>(this)->getType();
385   case scPtrToInt:
386   case scTruncate:
387   case scZeroExtend:
388   case scSignExtend:
389     return cast<SCEVCastExpr>(this)->getType();
390   case scAddRecExpr:
391   case scMulExpr:
392   case scUMaxExpr:
393   case scSMaxExpr:
394   case scUMinExpr:
395   case scSMinExpr:
396     return cast<SCEVNAryExpr>(this)->getType();
397   case scAddExpr:
398     return cast<SCEVAddExpr>(this)->getType();
399   case scUDivExpr:
400     return cast<SCEVUDivExpr>(this)->getType();
401   case scUnknown:
402     return cast<SCEVUnknown>(this)->getType();
403   case scCouldNotCompute:
404     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
405   }
406   llvm_unreachable("Unknown SCEV kind!");
407 }
408 
409 bool SCEV::isZero() const {
410   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
411     return SC->getValue()->isZero();
412   return false;
413 }
414 
415 bool SCEV::isOne() const {
416   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
417     return SC->getValue()->isOne();
418   return false;
419 }
420 
421 bool SCEV::isAllOnesValue() const {
422   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
423     return SC->getValue()->isMinusOne();
424   return false;
425 }
426 
427 bool SCEV::isNonConstantNegative() const {
428   const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
429   if (!Mul) return false;
430 
431   // If there is a constant factor, it will be first.
432   const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
433   if (!SC) return false;
434 
435   // Return true if the value is negative, this matches things like (-42 * V).
436   return SC->getAPInt().isNegative();
437 }
438 
439 SCEVCouldNotCompute::SCEVCouldNotCompute() :
440   SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
441 
442 bool SCEVCouldNotCompute::classof(const SCEV *S) {
443   return S->getSCEVType() == scCouldNotCompute;
444 }
445 
446 const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
447   FoldingSetNodeID ID;
448   ID.AddInteger(scConstant);
449   ID.AddPointer(V);
450   void *IP = nullptr;
451   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
452   SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
453   UniqueSCEVs.InsertNode(S, IP);
454   return S;
455 }
456 
457 const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
458   return getConstant(ConstantInt::get(getContext(), Val));
459 }
460 
461 const SCEV *
462 ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
463   IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
464   return getConstant(ConstantInt::get(ITy, V, isSigned));
465 }
466 
467 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
468                            const SCEV *op, Type *ty)
469     : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
470   Operands[0] = op;
471 }
472 
473 SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
474                                    Type *ITy)
475     : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
476   assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
477          "Must be a non-bit-width-changing pointer-to-integer cast!");
478 }
479 
480 SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
481                                            SCEVTypes SCEVTy, const SCEV *op,
482                                            Type *ty)
483     : SCEVCastExpr(ID, SCEVTy, op, ty) {}
484 
485 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
486                                    Type *ty)
487     : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
488   assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
489          "Cannot truncate non-integer value!");
490 }
491 
492 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
493                                        const SCEV *op, Type *ty)
494     : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
495   assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
496          "Cannot zero extend non-integer value!");
497 }
498 
499 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
500                                        const SCEV *op, Type *ty)
501     : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
502   assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
503          "Cannot sign extend non-integer value!");
504 }
505 
506 void SCEVUnknown::deleted() {
507   // Clear this SCEVUnknown from various maps.
508   SE->forgetMemoizedResults(this);
509 
510   // Remove this SCEVUnknown from the uniquing map.
511   SE->UniqueSCEVs.RemoveNode(this);
512 
513   // Release the value.
514   setValPtr(nullptr);
515 }
516 
517 void SCEVUnknown::allUsesReplacedWith(Value *New) {
518   // Remove this SCEVUnknown from the uniquing map.
519   SE->UniqueSCEVs.RemoveNode(this);
520 
521   // Update this SCEVUnknown to point to the new value. This is needed
522   // because there may still be outstanding SCEVs which still point to
523   // this SCEVUnknown.
524   setValPtr(New);
525 }
526 
527 bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
528   if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
529     if (VCE->getOpcode() == Instruction::PtrToInt)
530       if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
531         if (CE->getOpcode() == Instruction::GetElementPtr &&
532             CE->getOperand(0)->isNullValue() &&
533             CE->getNumOperands() == 2)
534           if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
535             if (CI->isOne()) {
536               AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
537                                  ->getElementType();
538               return true;
539             }
540 
541   return false;
542 }
543 
544 bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
545   if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
546     if (VCE->getOpcode() == Instruction::PtrToInt)
547       if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
548         if (CE->getOpcode() == Instruction::GetElementPtr &&
549             CE->getOperand(0)->isNullValue()) {
550           Type *Ty =
551             cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
552           if (StructType *STy = dyn_cast<StructType>(Ty))
553             if (!STy->isPacked() &&
554                 CE->getNumOperands() == 3 &&
555                 CE->getOperand(1)->isNullValue()) {
556               if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
557                 if (CI->isOne() &&
558                     STy->getNumElements() == 2 &&
559                     STy->getElementType(0)->isIntegerTy(1)) {
560                   AllocTy = STy->getElementType(1);
561                   return true;
562                 }
563             }
564         }
565 
566   return false;
567 }
568 
569 bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
570   if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
571     if (VCE->getOpcode() == Instruction::PtrToInt)
572       if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
573         if (CE->getOpcode() == Instruction::GetElementPtr &&
574             CE->getNumOperands() == 3 &&
575             CE->getOperand(0)->isNullValue() &&
576             CE->getOperand(1)->isNullValue()) {
577           Type *Ty =
578             cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
579           // Ignore vector types here so that ScalarEvolutionExpander doesn't
580           // emit getelementptrs that index into vectors.
581           if (Ty->isStructTy() || Ty->isArrayTy()) {
582             CTy = Ty;
583             FieldNo = CE->getOperand(2);
584             return true;
585           }
586         }
587 
588   return false;
589 }
590 
591 //===----------------------------------------------------------------------===//
592 //                               SCEV Utilities
593 //===----------------------------------------------------------------------===//
594 
595 /// Compare the two values \p LV and \p RV in terms of their "complexity" where
596 /// "complexity" is a partial (and somewhat ad-hoc) relation used to order
597 /// operands in SCEV expressions.  \p EqCacheValue is a set of pairs of values
598 /// have been previously deemed to be "equally complex" by this routine.  It is
599 /// intended to avoid exponential time complexity in cases like:
600 ///
601 ///   %a = f(%x, %y)
602 ///   %b = f(%a, %a)
603 ///   %c = f(%b, %b)
604 ///
605 ///   %d = f(%x, %y)
606 ///   %e = f(%d, %d)
607 ///   %f = f(%e, %e)
608 ///
609 ///   CompareValueComplexity(%f, %c)
610 ///
611 /// Since we do not continue running this routine on expression trees once we
612 /// have seen unequal values, there is no need to track them in the cache.
613 static int
614 CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
615                        const LoopInfo *const LI, Value *LV, Value *RV,
616                        unsigned Depth) {
617   if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
618     return 0;
619 
620   // Order pointer values after integer values. This helps SCEVExpander form
621   // GEPs.
622   bool LIsPointer = LV->getType()->isPointerTy(),
623        RIsPointer = RV->getType()->isPointerTy();
624   if (LIsPointer != RIsPointer)
625     return (int)LIsPointer - (int)RIsPointer;
626 
627   // Compare getValueID values.
628   unsigned LID = LV->getValueID(), RID = RV->getValueID();
629   if (LID != RID)
630     return (int)LID - (int)RID;
631 
632   // Sort arguments by their position.
633   if (const auto *LA = dyn_cast<Argument>(LV)) {
634     const auto *RA = cast<Argument>(RV);
635     unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
636     return (int)LArgNo - (int)RArgNo;
637   }
638 
639   if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
640     const auto *RGV = cast<GlobalValue>(RV);
641 
642     const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
643       auto LT = GV->getLinkage();
644       return !(GlobalValue::isPrivateLinkage(LT) ||
645                GlobalValue::isInternalLinkage(LT));
646     };
647 
648     // Use the names to distinguish the two values, but only if the
649     // names are semantically important.
650     if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
651       return LGV->getName().compare(RGV->getName());
652   }
653 
654   // For instructions, compare their loop depth, and their operand count.  This
655   // is pretty loose.
656   if (const auto *LInst = dyn_cast<Instruction>(LV)) {
657     const auto *RInst = cast<Instruction>(RV);
658 
659     // Compare loop depths.
660     const BasicBlock *LParent = LInst->getParent(),
661                      *RParent = RInst->getParent();
662     if (LParent != RParent) {
663       unsigned LDepth = LI->getLoopDepth(LParent),
664                RDepth = LI->getLoopDepth(RParent);
665       if (LDepth != RDepth)
666         return (int)LDepth - (int)RDepth;
667     }
668 
669     // Compare the number of operands.
670     unsigned LNumOps = LInst->getNumOperands(),
671              RNumOps = RInst->getNumOperands();
672     if (LNumOps != RNumOps)
673       return (int)LNumOps - (int)RNumOps;
674 
675     for (unsigned Idx : seq(0u, LNumOps)) {
676       int Result =
677           CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
678                                  RInst->getOperand(Idx), Depth + 1);
679       if (Result != 0)
680         return Result;
681     }
682   }
683 
684   EqCacheValue.unionSets(LV, RV);
685   return 0;
686 }
687 
688 // Return negative, zero, or positive, if LHS is less than, equal to, or greater
689 // than RHS, respectively. A three-way result allows recursive comparisons to be
690 // more efficient.
691 static int CompareSCEVComplexity(
692     EquivalenceClasses<const SCEV *> &EqCacheSCEV,
693     EquivalenceClasses<const Value *> &EqCacheValue,
694     const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
695     DominatorTree &DT, unsigned Depth = 0) {
696   // Fast-path: SCEVs are uniqued so we can do a quick equality check.
697   if (LHS == RHS)
698     return 0;
699 
700   // Primarily, sort the SCEVs by their getSCEVType().
701   SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
702   if (LType != RType)
703     return (int)LType - (int)RType;
704 
705   if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.isEquivalent(LHS, RHS))
706     return 0;
707   // Aside from the getSCEVType() ordering, the particular ordering
708   // isn't very important except that it's beneficial to be consistent,
709   // so that (a + b) and (b + a) don't end up as different expressions.
710   switch (LType) {
711   case scUnknown: {
712     const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
713     const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
714 
715     int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
716                                    RU->getValue(), Depth + 1);
717     if (X == 0)
718       EqCacheSCEV.unionSets(LHS, RHS);
719     return X;
720   }
721 
722   case scConstant: {
723     const SCEVConstant *LC = cast<SCEVConstant>(LHS);
724     const SCEVConstant *RC = cast<SCEVConstant>(RHS);
725 
726     // Compare constant values.
727     const APInt &LA = LC->getAPInt();
728     const APInt &RA = RC->getAPInt();
729     unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
730     if (LBitWidth != RBitWidth)
731       return (int)LBitWidth - (int)RBitWidth;
732     return LA.ult(RA) ? -1 : 1;
733   }
734 
735   case scAddRecExpr: {
736     const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
737     const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
738 
739     // There is always a dominance relationship between two recs used by one SCEV,
740     // so we can safely sort recs by loop header dominance. We require such
741     // order in getAddExpr.
742     const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
743     if (LLoop != RLoop) {
744       const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
745       assert(LHead != RHead && "Two loops share the same header?");
746       if (DT.dominates(LHead, RHead))
747         return 1;
748       else
749         assert(DT.dominates(RHead, LHead) &&
750                "No dominance between recurrences used by one SCEV?");
751       return -1;
752     }
753 
754     // Addrec complexity grows with operand count.
755     unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
756     if (LNumOps != RNumOps)
757       return (int)LNumOps - (int)RNumOps;
758 
759     // Lexicographically compare.
760     for (unsigned i = 0; i != LNumOps; ++i) {
761       int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
762                                     LA->getOperand(i), RA->getOperand(i), DT,
763                                     Depth + 1);
764       if (X != 0)
765         return X;
766     }
767     EqCacheSCEV.unionSets(LHS, RHS);
768     return 0;
769   }
770 
771   case scAddExpr:
772   case scMulExpr:
773   case scSMaxExpr:
774   case scUMaxExpr:
775   case scSMinExpr:
776   case scUMinExpr: {
777     const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
778     const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
779 
780     // Lexicographically compare n-ary expressions.
781     unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
782     if (LNumOps != RNumOps)
783       return (int)LNumOps - (int)RNumOps;
784 
785     for (unsigned i = 0; i != LNumOps; ++i) {
786       int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
787                                     LC->getOperand(i), RC->getOperand(i), DT,
788                                     Depth + 1);
789       if (X != 0)
790         return X;
791     }
792     EqCacheSCEV.unionSets(LHS, RHS);
793     return 0;
794   }
795 
796   case scUDivExpr: {
797     const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
798     const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
799 
800     // Lexicographically compare udiv expressions.
801     int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
802                                   RC->getLHS(), DT, Depth + 1);
803     if (X != 0)
804       return X;
805     X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
806                               RC->getRHS(), DT, Depth + 1);
807     if (X == 0)
808       EqCacheSCEV.unionSets(LHS, RHS);
809     return X;
810   }
811 
812   case scPtrToInt:
813   case scTruncate:
814   case scZeroExtend:
815   case scSignExtend: {
816     const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
817     const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
818 
819     // Compare cast expressions by operand.
820     int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
821                                   LC->getOperand(), RC->getOperand(), DT,
822                                   Depth + 1);
823     if (X == 0)
824       EqCacheSCEV.unionSets(LHS, RHS);
825     return X;
826   }
827 
828   case scCouldNotCompute:
829     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
830   }
831   llvm_unreachable("Unknown SCEV kind!");
832 }
833 
834 /// Given a list of SCEV objects, order them by their complexity, and group
835 /// objects of the same complexity together by value.  When this routine is
836 /// finished, we know that any duplicates in the vector are consecutive and that
837 /// complexity is monotonically increasing.
838 ///
839 /// Note that we take special precautions to ensure that we get deterministic
840 /// results from this routine.  In other words, we don't want the results of
841 /// this to depend on where the addresses of various SCEV objects happened to
842 /// land in memory.
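///
/// For example, an operand list of (%x, 2, %x) is reordered to (2, %x, %x):
/// the constant has the lowest complexity and therefore sorts first, and the
/// repeated %x operands become adjacent so that callers such as getAddExpr
/// can combine them.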
843 static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
844                               LoopInfo *LI, DominatorTree &DT) {
845   if (Ops.size() < 2) return;  // Noop
846 
847   EquivalenceClasses<const SCEV *> EqCacheSCEV;
848   EquivalenceClasses<const Value *> EqCacheValue;
849   if (Ops.size() == 2) {
850     // This is the common case, which also happens to be trivially simple.
851     // Special case it.
852     const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
853     if (CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, RHS, LHS, DT) < 0)
854       std::swap(LHS, RHS);
855     return;
856   }
857 
858   // Do the rough sort by complexity.
859   llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
860     return CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT) <
861            0;
862   });
863 
864   // Now that we are sorted by complexity, group elements of the same
865   // complexity.  Note that this is, at worst, N^2, but the vector is likely to
866   // be extremely short in practice.  Note that we take this approach because we
867   // do not want to depend on the addresses of the objects we are grouping.
868   for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
869     const SCEV *S = Ops[i];
870     unsigned Complexity = S->getSCEVType();
871 
872     // If there are any objects of the same complexity and same value as this
873     // one, group them.
874     for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
875       if (Ops[j] == S) { // Found a duplicate.
876         // Move it to immediately after i'th element.
877         std::swap(Ops[i+1], Ops[j]);
878         ++i;   // no need to rescan it.
879         if (i == e-2) return;  // Done!
880       }
881     }
882   }
883 }
884 
885 /// Returns true if \p Ops contains a huge SCEV (i.e. an expression whose
886 /// subtree contains at least HugeExprThreshold nodes).
887 static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
888   return any_of(Ops, [](const SCEV *S) {
889     return S->getExpressionSize() >= HugeExprThreshold;
890   });
891 }
892 
893 //===----------------------------------------------------------------------===//
894 //                      Simple SCEV method implementations
895 //===----------------------------------------------------------------------===//
896 
897 /// Compute BC(It, K).  The result has width W.  Assumes K > 0.
898 static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
899                                        ScalarEvolution &SE,
900                                        Type *ResultTy) {
901   // Handle the simplest case efficiently.
902   if (K == 1)
903     return SE.getTruncateOrZeroExtend(It, ResultTy);
904 
905   // We are using the following formula for BC(It, K):
906   //
907   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
908   //
909   // Suppose, W is the bitwidth of the return value.  We must be prepared for
910   // overflow.  Hence, we must assure that the result of our computation is
911   // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
912   // safe in modular arithmetic.
913   //
914   // However, this code doesn't use exactly that formula; the formula it uses
915   // is something like the following, where T is the number of factors of 2 in
916   // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
917   // exponentiation:
918   //
919   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
920   //
921   // This formula is trivially equivalent to the previous formula.  However,
922   // this formula can be implemented much more efficiently.  The trick is that
923   // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
924   // arithmetic.  To do exact division in modular arithmetic, all we have
925   // to do is multiply by the inverse.  Therefore, this step can be done at
926   // width W.
927   //
928   // The next issue is how to safely do the division by 2^T.  The way this
929   // is done is by doing the multiplication step at a width of at least W + T
930   // bits.  This way, the bottom W+T bits of the product are accurate. Then,
931   // when we perform the division by 2^T (which is equivalent to a right shift
932   // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
933   // truncated out after the division by 2^T.
934   //
935   // In comparison to just directly using the first formula, this technique
936   // is much more efficient; using the first formula requires W * K bits,
937   // but this formula uses less than W + K bits. Also, the first formula requires
938   // a division step, whereas this formula only requires multiplies and shifts.
939   //
940   // It doesn't matter whether the subtraction step is done in the calculation
941   // width or the input iteration count's width; if the subtraction overflows,
942   // the result must be zero anyway.  We prefer here to do it in the width of
943   // the induction variable because it helps a lot for certain cases; CodeGen
944   // isn't smart enough to ignore the overflow, which leads to much less
945   // efficient code if the width of the subtraction is wider than the native
946   // register width.
947   //
948   // (It's possible to not widen at all by pulling out factors of 2 before
949   // the multiplication; for example, K=2 can be calculated as
950   // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
951   // extra arithmetic, so it's not an obvious win, and it gets
952   // much more complicated for K > 3.)
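  //
  // As a concrete illustration, take K = 3: K! = 6 = 2^1 * 3, so T = 1 and
  // K! / 2^T = 3.  The product It*(It-1)*(It-2) is computed at width W + 1,
  // shifted right by one bit (the division by 2^T), truncated back to W bits,
  // and multiplied by the multiplicative inverse of 3 modulo 2^W, which
  // performs the remaining exact division by 3.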
953 
954   // Protection from insane SCEVs; this bound is conservative,
955   // but it probably doesn't matter.
956   if (K > 1000)
957     return SE.getCouldNotCompute();
958 
959   unsigned W = SE.getTypeSizeInBits(ResultTy);
960 
961   // Calculate K! / 2^T and T; we divide out the factors of two before
962   // multiplying for calculating K! / 2^T to avoid overflow.
963   // Other overflow doesn't matter because we only care about the bottom
964   // W bits of the result.
965   APInt OddFactorial(W, 1);
966   unsigned T = 1;
967   for (unsigned i = 3; i <= K; ++i) {
968     APInt Mult(W, i);
969     unsigned TwoFactors = Mult.countTrailingZeros();
970     T += TwoFactors;
971     Mult.lshrInPlace(TwoFactors);
972     OddFactorial *= Mult;
973   }
974 
975   // We need at least W + T bits for the multiplication step
976   unsigned CalculationBits = W + T;
977 
978   // Calculate 2^T, at width T+W.
979   APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
980 
981   // Calculate the multiplicative inverse of K! / 2^T;
982   // this multiplication factor will perform the exact division by
983   // K! / 2^T.
984   APInt Mod = APInt::getSignedMinValue(W+1);
985   APInt MultiplyFactor = OddFactorial.zext(W+1);
986   MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
987   MultiplyFactor = MultiplyFactor.trunc(W);
988 
989   // Calculate the product, at width T+W
990   IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
991                                                       CalculationBits);
992   const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
993   for (unsigned i = 1; i != K; ++i) {
994     const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
995     Dividend = SE.getMulExpr(Dividend,
996                              SE.getTruncateOrZeroExtend(S, CalculationTy));
997   }
998 
999   // Divide by 2^T
1000   const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1001 
1002   // Truncate the result, and divide by K! / 2^T.
1003 
1004   return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1005                        SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1006 }
1007 
1008 /// Return the value of this chain of recurrences at the specified iteration
1009 /// number.  We can evaluate this recurrence by multiplying each element in the
1010 /// chain by the binomial coefficient corresponding to it.  In other words, we
1011 /// can evaluate {A,+,B,+,C,+,D} as:
1012 ///
1013 ///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1014 ///
1015 /// where BC(It, k) stands for binomial coefficient.
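///
/// For instance, the recurrence {0,+,1,+,1} (the sequence 0, 1, 3, 6, 10, ...)
/// evaluates at iteration It to 0*BC(It,0) + 1*BC(It,1) + 1*BC(It,2)
/// = It + It*(It-1)/2 = It*(It+1)/2.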
1016 const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
1017                                                 ScalarEvolution &SE) const {
1018   const SCEV *Result = getStart();
1019   for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
1020     // The computation is correct in the face of overflow provided that the
1021     // multiplication is performed _after_ the evaluation of the binomial
1022     // coefficient.
1023     const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
1024     if (isa<SCEVCouldNotCompute>(Coeff))
1025       return Coeff;
1026 
1027     Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
1028   }
1029   return Result;
1030 }
1031 
1032 //===----------------------------------------------------------------------===//
1033 //                    SCEV Expression folder implementations
1034 //===----------------------------------------------------------------------===//
1035 
1036 const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty,
1037                                              unsigned Depth) {
1038   assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1039 
1040   // We could be called with an integer-typed operand during SCEV rewrites.
1041   // Since the operand is an integer already, just perform zext/trunc/self cast.
1042   if (!Op->getType()->isPointerTy())
1043     return getTruncateOrZeroExtend(Op, Ty);
1044 
1045   FoldingSetNodeID ID;
1046   ID.AddInteger(scPtrToInt);
1047   ID.AddPointer(Op);
1048   void *IP = nullptr;
1049   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1050     return getTruncateOrZeroExtend(S, Ty, Depth);
1051 
1052   assert((isa<SCEVNAryExpr>(Op) || isa<SCEVUnknown>(Op)) &&
1053          "We can only gen an nary expression, or an unknown here.");
1054 
1055   Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1056 
1057   // If the input operand is not an unknown (and thus is an n-ary expression),
1058   // sink the cast into its operands, so that the operation is performed on
1059   // integers, and we eventually end up with just a ptrtoint(unknown).
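  //
  // For example, if Op is the pointer-typed expression (4 + %ptr), the cast is
  // pushed onto the pointer operand, yielding the integer expression
  // 4 + ptrtoint(%ptr), which is then zero-extended or truncated to the
  // requested type Ty.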
1060   if (const SCEVNAryExpr *NaryExpr = dyn_cast<SCEVNAryExpr>(Op)) {
1061     SmallVector<const SCEV *, 2> NewOps;
1062     NewOps.reserve(NaryExpr->getNumOperands());
1063     for (const SCEV *Op : NaryExpr->operands())
1064       NewOps.push_back(Op->getType()->isPointerTy()
1065                            ? getPtrToIntExpr(Op, IntPtrTy, Depth + 1)
1066                            : Op);
1067     const SCEV *NewNaryExpr = nullptr;
1068     switch (SCEVTypes SCEVType = NaryExpr->getSCEVType()) {
1069     case scAddExpr:
1070       NewNaryExpr = getAddExpr(NewOps, NaryExpr->getNoWrapFlags(), Depth + 1);
1071       break;
1072     case scAddRecExpr:
1073       NewNaryExpr =
1074           getAddRecExpr(NewOps, cast<SCEVAddRecExpr>(NaryExpr)->getLoop(),
1075                         NaryExpr->getNoWrapFlags());
1076       break;
1077     case scUMaxExpr:
1078     case scSMaxExpr:
1079     case scUMinExpr:
1080     case scSMinExpr:
1081       NewNaryExpr = getMinMaxExpr(SCEVType, NewOps);
1082       break;
1083 
1084     case scMulExpr:
1085       NewNaryExpr = getMulExpr(NewOps, NaryExpr->getNoWrapFlags(), Depth + 1);
1086       break;
1087     case scUDivExpr:
1088       NewNaryExpr = getUDivExpr(NewOps[0], NewOps[1]);
1089       break;
1090     case scConstant:
1091     case scTruncate:
1092     case scZeroExtend:
1093     case scSignExtend:
1094     case scPtrToInt:
1095     case scUnknown:
1096     case scCouldNotCompute:
1097       llvm_unreachable("We can't get these types here.");
1098     }
1099     return getTruncateOrZeroExtend(NewNaryExpr, Ty, Depth);
1100   }
1101 
1102   // The cast wasn't folded; create an explicit cast node. We can reuse
1103   // the existing insert position since if we get here, we won't have
1104   // made any changes which would invalidate it.
1105   assert(getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(
1106              Op->getType())) == getDataLayout().getTypeSizeInBits(IntPtrTy) &&
1107          "We can only model ptrtoint if SCEV's effective (integer) type is "
1108          "sufficiently wide to represent all possible pointer values.");
1109   SCEV *S = new (SCEVAllocator)
1110       SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1111   UniqueSCEVs.InsertNode(S, IP);
1112   addToLoopUseLists(S);
1113   return getTruncateOrZeroExtend(S, Ty, Depth);
1114 }
1115 
1116 const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1117                                              unsigned Depth) {
1118   assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1119          "This is not a truncating conversion!");
1120   assert(isSCEVable(Ty) &&
1121          "This is not a conversion to a SCEVable type!");
1122   Ty = getEffectiveSCEVType(Ty);
1123 
1124   FoldingSetNodeID ID;
1125   ID.AddInteger(scTruncate);
1126   ID.AddPointer(Op);
1127   ID.AddPointer(Ty);
1128   void *IP = nullptr;
1129   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1130 
1131   // Fold if the operand is constant.
1132   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1133     return getConstant(
1134       cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1135 
1136   // trunc(trunc(x)) --> trunc(x)
1137   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1138     return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1139 
1140   // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1141   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1142     return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1143 
1144   // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1145   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1146     return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1147 
1148   if (Depth > MaxCastDepth) {
1149     SCEV *S =
1150         new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1151     UniqueSCEVs.InsertNode(S, IP);
1152     addToLoopUseLists(S);
1153     return S;
1154   }
1155 
1156   // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1157   // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1158   // if after transforming we have at most one truncate, not counting truncates
1159   // that replace other casts.
1160   if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1161     auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1162     SmallVector<const SCEV *, 4> Operands;
1163     unsigned numTruncs = 0;
1164     for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1165          ++i) {
1166       const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1167       if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1168           isa<SCEVTruncateExpr>(S))
1169         numTruncs++;
1170       Operands.push_back(S);
1171     }
1172     if (numTruncs < 2) {
1173       if (isa<SCEVAddExpr>(Op))
1174         return getAddExpr(Operands);
1175       else if (isa<SCEVMulExpr>(Op))
1176         return getMulExpr(Operands);
1177       else
1178         llvm_unreachable("Unexpected SCEV type for Op.");
1179     }
1180     // Although we checked at the beginning that ID is not in the cache, it is
1181     // possible that ID was inserted into the cache by the recursive calls above.
1182     // So if we find it, just return it.
1183     if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1184       return S;
1185   }
1186 
1187   // If the input value is a chrec scev, truncate the chrec's operands.
1188   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1189     SmallVector<const SCEV *, 4> Operands;
1190     for (const SCEV *Op : AddRec->operands())
1191       Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1192     return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1193   }
1194 
1195   // The cast wasn't folded; create an explicit cast node. We can reuse
1196   // the existing insert position since if we get here, we won't have
1197   // made any changes which would invalidate it.
1198   SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1199                                                  Op, Ty);
1200   UniqueSCEVs.InsertNode(S, IP);
1201   addToLoopUseLists(S);
1202   return S;
1203 }
1204 
1205 // Get the limit of a recurrence such that incrementing by Step cannot cause
1206 // signed overflow as long as the value of the recurrence within the
1207 // loop does not exceed this limit before incrementing.
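// For example, for an 8-bit recurrence whose step is known to be 1, the limit
// is -128 - 1 == 127 (mod 2^8) and the predicate is signed-less-than: as long
// as the recurrence's value stays slt 127, adding the step of 1 cannot
// signed-wrap.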
1208 static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1209                                                  ICmpInst::Predicate *Pred,
1210                                                  ScalarEvolution *SE) {
1211   unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1212   if (SE->isKnownPositive(Step)) {
1213     *Pred = ICmpInst::ICMP_SLT;
1214     return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1215                            SE->getSignedRangeMax(Step));
1216   }
1217   if (SE->isKnownNegative(Step)) {
1218     *Pred = ICmpInst::ICMP_SGT;
1219     return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1220                            SE->getSignedRangeMin(Step));
1221   }
1222   return nullptr;
1223 }
1224 
1225 // Get the limit of a recurrence such that incrementing by Step cannot cause
1226 // unsigned overflow as long as the value of the recurrence within the loop does
1227 // not exceed this limit before incrementing.
1228 static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1229                                                    ICmpInst::Predicate *Pred,
1230                                                    ScalarEvolution *SE) {
1231   unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1232   *Pred = ICmpInst::ICMP_ULT;
1233 
1234   return SE->getConstant(APInt::getMinValue(BitWidth) -
1235                          SE->getUnsignedRangeMax(Step));
1236 }
1237 
1238 namespace {
1239 
1240 struct ExtendOpTraitsBase {
1241   typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1242                                                           unsigned);
1243 };
1244 
1245 // Used to make code generic over signed and unsigned overflow.
1246 template <typename ExtendOp> struct ExtendOpTraits {
1247   // Members present:
1248   //
1249   // static const SCEV::NoWrapFlags WrapType;
1250   //
1251   // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1252   //
1253   // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1254   //                                           ICmpInst::Predicate *Pred,
1255   //                                           ScalarEvolution *SE);
1256 };
1257 
1258 template <>
1259 struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1260   static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1261 
1262   static const GetExtendExprTy GetExtendExpr;
1263 
1264   static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1265                                              ICmpInst::Predicate *Pred,
1266                                              ScalarEvolution *SE) {
1267     return getSignedOverflowLimitForStep(Step, Pred, SE);
1268   }
1269 };
1270 
1271 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1272     SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1273 
1274 template <>
1275 struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1276   static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1277 
1278   static const GetExtendExprTy GetExtendExpr;
1279 
1280   static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1281                                              ICmpInst::Predicate *Pred,
1282                                              ScalarEvolution *SE) {
1283     return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1284   }
1285 };
1286 
1287 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1288     SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1289 
1290 } // end anonymous namespace
1291 
1292 // The recurrence AR has been shown to have no signed/unsigned wrap or something
1293 // close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1294 // easily prove NSW/NUW for its preincrement or postincrement sibling. This
1295 // allows normalizing a sign/zero extended AddRec as follows:
1296 //   {sext/zext(Step + Start),+,Step} => {(Step + sext/zext(Start)),+,Step}.
1297 // As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
1298 // "sext/zext(PostIncAR)".
1299 template <typename ExtendOpTy>
1300 static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1301                                         ScalarEvolution *SE, unsigned Depth) {
1302   auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1303   auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1304 
1305   const Loop *L = AR->getLoop();
1306   const SCEV *Start = AR->getStart();
1307   const SCEV *Step = AR->getStepRecurrence(*SE);
1308 
1309   // Check for a simple looking step prior to loop entry.
1310   const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1311   if (!SA)
1312     return nullptr;
1313 
1314   // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1315   // subtraction is expensive. For this purpose, perform a quick and dirty
1316   // difference, by checking for Step in the operand list.
1317   SmallVector<const SCEV *, 4> DiffOps;
1318   for (const SCEV *Op : SA->operands())
1319     if (Op != Step)
1320       DiffOps.push_back(Op);
1321 
1322   if (DiffOps.size() == SA->getNumOperands())
1323     return nullptr;
1324 
1325   // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1326   // `Step`:
1327 
1328   // 1. NSW/NUW flags on the step increment.
1329   auto PreStartFlags =
1330     ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1331   const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1332   const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1333       SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1334 
1335   // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1336   // "S+X does not sign/unsign-overflow".
1337   //
1338 
1339   const SCEV *BECount = SE->getBackedgeTakenCount(L);
1340   if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1341       !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1342     return PreStart;
1343 
1344   // 2. Direct overflow check on the step operation's expression.
1345   unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1346   Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1347   const SCEV *OperandExtendedStart =
1348       SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1349                      (SE->*GetExtendExpr)(Step, WideTy, Depth));
1350   if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1351     if (PreAR && AR->getNoWrapFlags(WrapType)) {
1352       // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1353       // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1354       // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`.  Cache this fact.
1355       const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(WrapType);
1356     }
1357     return PreStart;
1358   }
1359 
1360   // 3. Loop precondition.
1361   ICmpInst::Predicate Pred;
1362   const SCEV *OverflowLimit =
1363       ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1364 
1365   if (OverflowLimit &&
1366       SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1367     return PreStart;
1368 
1369   return nullptr;
1370 }
1371 
1372 // Get the normalized zero or sign extended expression for this AddRec's Start.
1373 template <typename ExtendOpTy>
1374 static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1375                                         ScalarEvolution *SE,
1376                                         unsigned Depth) {
1377   auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1378 
1379   const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1380   if (!PreStart)
1381     return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1382 
1383   return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1384                                              Depth),
1385                         (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1386 }
1387 
1388 // Try to prove away overflow by looking at "nearby" add recurrences.  A
1389 // motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1390 // does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1391 //
1392 // Formally:
1393 //
1394 //     {S,+,X} == {S-T,+,X} + T
1395 //  => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1396 //
1397 // If ({S-T,+,X} + T) does not overflow  ... (1)
1398 //
1399 //  RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1400 //
1401 // If {S-T,+,X} does not overflow  ... (2)
1402 //
1403 //  RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1404 //      == {Ext(S-T)+Ext(T),+,Ext(X)}
1405 //
1406 // If (S-T)+T does not overflow  ... (3)
1407 //
1408 //  RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1409 //      == {Ext(S),+,Ext(X)} == LHS
1410 //
1411 // Thus, if (1), (2) and (3) are true for some T, then
1412 //   Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1413 //
1414 // (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1415 // does not overflow" restricted to the 0th iteration.  Therefore we only need
1416 // to check for (1) and (2).
1417 //
1418 // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1419 // is `Delta` (defined below).
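     //
     // Tying the motivating example to the proof (illustrative): with S = 1,
     // X = 4 and T = 1, {S-T,+,X} is {0,+,4}.  Knowing {0,+,4} is `ult` -1
     // gives (1), since adding 1 to a value below the unsigned maximum cannot
     // wrap, and {0,+,4} not wrapping gives (2).  Hence
     // Ext({1,+,4}) == {Ext(1),+,Ext(4)}, i.e. {1,+,4} extends operand-wise.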
1420 template <typename ExtendOpTy>
1421 bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1422                                                 const SCEV *Step,
1423                                                 const Loop *L) {
1424   auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1425 
1426   // We restrict `Start` to a constant to prevent SCEV from spending too much
1427   // time here.  It is correct (but more expensive) to continue with a
1428   // non-constant `Start` and do a general SCEV subtraction to compute
1429   // `PreStart` below.
1430   const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1431   if (!StartC)
1432     return false;
1433 
1434   APInt StartAI = StartC->getAPInt();
1435 
1436   for (unsigned Delta : {-2, -1, 1, 2}) {
1437     const SCEV *PreStart = getConstant(StartAI - Delta);
1438 
1439     FoldingSetNodeID ID;
1440     ID.AddInteger(scAddRecExpr);
1441     ID.AddPointer(PreStart);
1442     ID.AddPointer(Step);
1443     ID.AddPointer(L);
1444     void *IP = nullptr;
1445     const auto *PreAR =
1446       static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1447 
1448     // Give up if we don't already have the add recurrence we need because
1449     // actually constructing an add recurrence is relatively expensive.
1450     if (PreAR && PreAR->getNoWrapFlags(WrapType)) {  // proves (2)
1451       const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1452       ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1453       const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1454           DeltaS, &Pred, this);
1455       if (Limit && isKnownPredicate(Pred, PreAR, Limit))  // proves (1)
1456         return true;
1457     }
1458   }
1459 
1460   return false;
1461 }
1462 
1463 // Finds an integer D for an expression (C + x + y + ...) such that the top
1464 // level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1465 // unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1466 // maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1467 // the (C + x + y + ...) expression is \p WholeAddExpr.
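     //
     // Worked example (illustrative): for (5 + 8*%x + 4*%y), the non-constant
     // operands have at least two trailing zero bits, so TZ = 2 and D becomes
     // the low two bits of C, i.e. D = 1.  The residual (4 + 8*%x + 4*%y) is
     // then a multiple of 4, and adding D < 4 to a multiple of 4 cannot wrap,
     // signed or unsigned.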
1468 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1469                                             const SCEVConstant *ConstantTerm,
1470                                             const SCEVAddExpr *WholeAddExpr) {
1471   const APInt &C = ConstantTerm->getAPInt();
1472   const unsigned BitWidth = C.getBitWidth();
1473   // Find number of trailing zeros of (x + y + ...) w/o the C first:
1474   uint32_t TZ = BitWidth;
1475   for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1476     TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
1477   if (TZ) {
1478     // Set D to be as many least significant bits of C as possible while still
1479     // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1480     return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1481   }
1482   return APInt(BitWidth, 0);
1483 }
1484 
1485 // Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1486 // level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1487 // number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1488 // ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1489 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1490                                             const APInt &ConstantStart,
1491                                             const SCEV *Step) {
1492   const unsigned BitWidth = ConstantStart.getBitWidth();
1493   const uint32_t TZ = SE.GetMinTrailingZeros(Step);
1494   if (TZ)
1495     return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1496                          : ConstantStart;
1497   return APInt(BitWidth, 0);
1498 }
1499 
1500 const SCEV *
1501 ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1502   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1503          "This is not an extending conversion!");
1504   assert(isSCEVable(Ty) &&
1505          "This is not a conversion to a SCEVable type!");
1506   Ty = getEffectiveSCEVType(Ty);
1507 
1508   // Fold if the operand is constant.
1509   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1510     return getConstant(
1511       cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
1512 
1513   // zext(zext(x)) --> zext(x)
1514   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1515     return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1516 
1517   // Before doing any expensive analysis, check to see if we've already
1518   // computed a SCEV for this Op and Ty.
1519   FoldingSetNodeID ID;
1520   ID.AddInteger(scZeroExtend);
1521   ID.AddPointer(Op);
1522   ID.AddPointer(Ty);
1523   void *IP = nullptr;
1524   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1525   if (Depth > MaxCastDepth) {
1526     SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1527                                                      Op, Ty);
1528     UniqueSCEVs.InsertNode(S, IP);
1529     addToLoopUseLists(S);
1530     return S;
1531   }
1532 
1533   // zext(trunc(x)) --> zext(x) or x or trunc(x)
1534   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1535     // It's possible the bits taken off by the truncate were all zero bits. If
1536     // so, we should be able to simplify this further.
1537     const SCEV *X = ST->getOperand();
1538     ConstantRange CR = getUnsignedRange(X);
1539     unsigned TruncBits = getTypeSizeInBits(ST->getType());
1540     unsigned NewBits = getTypeSizeInBits(Ty);
1541     if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1542             CR.zextOrTrunc(NewBits)))
1543       return getTruncateOrZeroExtend(X, Ty, Depth);
1544   }
1545 
1546   // If the input value is a chrec scev, and we can prove that the value
1547   // did not overflow in the old, smaller type, we can zero extend all of the
1548   // operands (often constants).  This allows analysis of something like
1549   // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1550   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1551     if (AR->isAffine()) {
1552       const SCEV *Start = AR->getStart();
1553       const SCEV *Step = AR->getStepRecurrence(*this);
1554       unsigned BitWidth = getTypeSizeInBits(AR->getType());
1555       const Loop *L = AR->getLoop();
1556 
1557       if (!AR->hasNoUnsignedWrap()) {
1558         auto NewFlags = proveNoWrapViaConstantRanges(AR);
1559         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
1560       }
1561 
1562       // If we have special knowledge that this addrec won't overflow,
1563       // we don't need to do any further analysis.
1564       if (AR->hasNoUnsignedWrap())
1565         return getAddRecExpr(
1566             getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1567             getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1568 
1569       // Check whether the backedge-taken count is SCEVCouldNotCompute.
1570       // Note that this serves two purposes: It filters out loops that are
1571       // simply not analyzable, and it covers the case where this code is
1572       // being called from within backedge-taken count analysis, such that
1573       // attempting to ask for the backedge-taken count would likely result
1574       // in infinite recursion. In the latter case, the analysis code will
1575       // cope with a conservative value, and it will take care to purge
1576       // that value once it has finished.
1577       const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1578       if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1579         // Manually compute the final value for AR, checking for
1580         // overflow.
1581 
1582         // Check whether the backedge-taken count can be losslessly cast to
1583         // the addrec's type. The count is always unsigned.
1584         const SCEV *CastedMaxBECount =
1585             getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1586         const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1587             CastedMaxBECount, MaxBECount->getType(), Depth);
1588         if (MaxBECount == RecastedMaxBECount) {
1589           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1590           // Check whether Start+Step*MaxBECount has no unsigned overflow.
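               // For instance (illustrative), for an i8 IV {0,+,1} with
               // MaxBECount = 99: the final value 0 + 99*1 computed in i8 and
               // then zero extended equals the same sum computed in i16, so no
               // unsigned overflow occurred and the recurrence can be extended
               // operand-wise.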
1591           const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1592                                         SCEV::FlagAnyWrap, Depth + 1);
1593           const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1594                                                           SCEV::FlagAnyWrap,
1595                                                           Depth + 1),
1596                                                WideTy, Depth + 1);
1597           const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1598           const SCEV *WideMaxBECount =
1599             getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1600           const SCEV *OperandExtendedAdd =
1601             getAddExpr(WideStart,
1602                        getMulExpr(WideMaxBECount,
1603                                   getZeroExtendExpr(Step, WideTy, Depth + 1),
1604                                   SCEV::FlagAnyWrap, Depth + 1),
1605                        SCEV::FlagAnyWrap, Depth + 1);
1606           if (ZAdd == OperandExtendedAdd) {
1607             // Cache knowledge of AR NUW, which is propagated to this AddRec.
1608             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1609             // Return the expression with the addrec on the outside.
1610             return getAddRecExpr(
1611                 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1612                                                          Depth + 1),
1613                 getZeroExtendExpr(Step, Ty, Depth + 1), L,
1614                 AR->getNoWrapFlags());
1615           }
1616           // Similar to above, only this time treat the step value as signed.
1617           // This covers loops that count down.
1618           OperandExtendedAdd =
1619             getAddExpr(WideStart,
1620                        getMulExpr(WideMaxBECount,
1621                                   getSignExtendExpr(Step, WideTy, Depth + 1),
1622                                   SCEV::FlagAnyWrap, Depth + 1),
1623                        SCEV::FlagAnyWrap, Depth + 1);
1624           if (ZAdd == OperandExtendedAdd) {
1625             // Cache knowledge of AR NW, which is propagated to this AddRec.
1626             // Negative step causes unsigned wrap, but it still can't self-wrap.
1627             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1628             // Return the expression with the addrec on the outside.
1629             return getAddRecExpr(
1630                 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1631                                                          Depth + 1),
1632                 getSignExtendExpr(Step, Ty, Depth + 1), L,
1633                 AR->getNoWrapFlags());
1634           }
1635         }
1636       }
1637 
1638       // Normally, in the cases we can prove no-overflow via a
1639       // backedge guarding condition, we can also compute a backedge
1640       // taken count for the loop.  The exceptions are assumptions and
1641       // guards present in the loop -- SCEV is not great at exploiting
1642       // these to compute max backedge taken counts, but can still use
1643       // these to prove lack of overflow.  Use this fact to avoid
1644       // doing extra work that may not pay off.
1645       if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1646           !AC.assumptions().empty()) {
1647         // If the backedge is guarded by a comparison with the pre-inc
1648         // value the addrec is safe. Also, if the entry is guarded by
1649         // a comparison with the start value and the backedge is
1650         // guarded by a comparison with the post-inc value, the addrec
1651         // is safe.
1652         if (isKnownPositive(Step)) {
1653           const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
1654                                       getUnsignedRangeMax(Step));
1655           if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
1656               isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
1657             // Cache knowledge of AR NUW, which is propagated to this
1658             // AddRec.
1659             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1660             // Return the expression with the addrec on the outside.
1661             return getAddRecExpr(
1662                 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1663                                                          Depth + 1),
1664                 getZeroExtendExpr(Step, Ty, Depth + 1), L,
1665                 AR->getNoWrapFlags());
1666           }
1667         } else if (isKnownNegative(Step)) {
1668           const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1669                                       getSignedRangeMin(Step));
1670           if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1671               isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1672             // Cache knowledge of AR NW, which is propagated to this
1673             // AddRec.  Negative step causes unsigned wrap, but it
1674             // still can't self-wrap.
1675             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1676             // Return the expression with the addrec on the outside.
1677             return getAddRecExpr(
1678                 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1679                                                          Depth + 1),
1680                 getSignExtendExpr(Step, Ty, Depth + 1), L,
1681                 AR->getNoWrapFlags());
1682           }
1683         }
1684       }
1685 
1686       // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1687       // if D + (C - D + Step * n) could be proven to not unsigned wrap
1688       // where D maximizes the number of trailing zeros of (C - D + Step * n)
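           // For example (illustrative), in i8: zext({5,+,8}) becomes
           // 5 + zext({0,+,8}), since {0,+,8} only takes values that are
           // multiples of 8 and adding 5 to a multiple of 8 cannot wrap in i8.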
1689       if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1690         const APInt &C = SC->getAPInt();
1691         const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1692         if (D != 0) {
1693           const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1694           const SCEV *SResidual =
1695               getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1696           const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1697           return getAddExpr(SZExtD, SZExtR,
1698                             (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1699                             Depth + 1);
1700         }
1701       }
1702 
1703       if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1704         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1705         return getAddRecExpr(
1706             getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1707             getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1708       }
1709     }
1710 
1711   // zext(A % B) --> zext(A) % zext(B)
1712   {
1713     const SCEV *LHS;
1714     const SCEV *RHS;
1715     if (matchURem(Op, LHS, RHS))
1716       return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1717                          getZeroExtendExpr(RHS, Ty, Depth + 1));
1718   }
1719 
1720   // zext(A / B) --> zext(A) / zext(B).
1721   if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1722     return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1723                        getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1724 
1725   if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1726     // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1727     if (SA->hasNoUnsignedWrap()) {
1728       // If the addition cannot overflow in the unsigned sense then we can, by
1729       // definition, commute the zero extension with the addition operation.
1730       SmallVector<const SCEV *, 4> Ops;
1731       for (const auto *Op : SA->operands())
1732         Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1733       return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1734     }
1735 
1736     // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1737     // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1738     // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1739     //
1740     // Address arithmetic often contains expressions like
1741     // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1742     // This transformation is useful while proving that such expressions are
1743     // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
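         // For the instance above (illustrative): C = 5 and (4 * X) has at
         // least two trailing zero bits, so D = 1 and
         // (zext (5 + (4 * X))) becomes 1 + (zext (4 + (4 * X))).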
1744     if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1745       const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1746       if (D != 0) {
1747         const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1748         const SCEV *SResidual =
1749             getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1750         const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1751         return getAddExpr(SZExtD, SZExtR,
1752                           (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1753                           Depth + 1);
1754       }
1755     }
1756   }
1757 
1758   if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1759     // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1760     if (SM->hasNoUnsignedWrap()) {
1761       // If the multiply cannot overflow in the unsigned sense then we can, by
1762       // definition, commute the zero extension with the multiply operation.
1763       SmallVector<const SCEV *, 4> Ops;
1764       for (const auto *Op : SM->operands())
1765         Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1766       return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1767     }
1768 
1769     // zext(2^K * (trunc X to iN)) to iM ->
1770     // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1771     //
1772     // Proof:
1773     //
1774     //     zext(2^K * (trunc X to iN)) to iM
1775     //   = zext((trunc X to iN) << K) to iM
1776     //   = zext((trunc X to i{N-K}) << K)<nuw> to iM
1777     //     (because shl removes the top K bits)
1778     //   = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1779     //   = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1780     //
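         // For instance (illustrative), with K = 2, N = 8 and M = 32:
         //   zext(4 * (trunc i32 %X to i8)) to i32
         //     = 4 * (zext (trunc i32 %X to i6) to i32)<nuw>.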
1781     if (SM->getNumOperands() == 2)
1782       if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1783         if (MulLHS->getAPInt().isPowerOf2())
1784           if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1785             int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1786                                MulLHS->getAPInt().logBase2();
1787             Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1788             return getMulExpr(
1789                 getZeroExtendExpr(MulLHS, Ty),
1790                 getZeroExtendExpr(
1791                     getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1792                 SCEV::FlagNUW, Depth + 1);
1793           }
1794   }
1795 
1796   // The cast wasn't folded; create an explicit cast node.
1797   // Recompute the insert position, as it may have been invalidated.
1798   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1799   SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1800                                                    Op, Ty);
1801   UniqueSCEVs.InsertNode(S, IP);
1802   addToLoopUseLists(S);
1803   return S;
1804 }
1805 
1806 const SCEV *
1807 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1808   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1809          "This is not an extending conversion!");
1810   assert(isSCEVable(Ty) &&
1811          "This is not a conversion to a SCEVable type!");
1812   Ty = getEffectiveSCEVType(Ty);
1813 
1814   // Fold if the operand is constant.
1815   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1816     return getConstant(
1817       cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1818 
1819   // sext(sext(x)) --> sext(x)
1820   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1821     return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1822 
1823   // sext(zext(x)) --> zext(x)
1824   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1825     return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1826 
1827   // Before doing any expensive analysis, check to see if we've already
1828   // computed a SCEV for this Op and Ty.
1829   FoldingSetNodeID ID;
1830   ID.AddInteger(scSignExtend);
1831   ID.AddPointer(Op);
1832   ID.AddPointer(Ty);
1833   void *IP = nullptr;
1834   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1835   // Limit recursion depth.
1836   if (Depth > MaxCastDepth) {
1837     SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1838                                                      Op, Ty);
1839     UniqueSCEVs.InsertNode(S, IP);
1840     addToLoopUseLists(S);
1841     return S;
1842   }
1843 
1844   // sext(trunc(x)) --> sext(x) or x or trunc(x)
1845   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1846     // It's possible the bits taken off by the truncate were all sign bits. If
1847     // so, we should be able to simplify this further.
1848     const SCEV *X = ST->getOperand();
1849     ConstantRange CR = getSignedRange(X);
1850     unsigned TruncBits = getTypeSizeInBits(ST->getType());
1851     unsigned NewBits = getTypeSizeInBits(Ty);
1852     if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1853             CR.sextOrTrunc(NewBits)))
1854       return getTruncateOrSignExtend(X, Ty, Depth);
1855   }
1856 
1857   if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1858     // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1859     if (SA->hasNoSignedWrap()) {
1860       // If the addition cannot overflow in the signed sense then we can, by
1861       // definition, commute the sign extension with the addition operation.
1862       SmallVector<const SCEV *, 4> Ops;
1863       for (const auto *Op : SA->operands())
1864         Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1865       return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1866     }
1867 
1868     // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1869     // if D + (C - D + x + y + ...) could be proven to not signed wrap
1870     // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1871     //
1872     // For instance, this will bring two seemingly different expressions:
1873     //     1 + sext(5 + 20 * %x + 24 * %y)  and
1874     //         sext(6 + 20 * %x + 24 * %y)
1875     // to the same form:
1876     //     2 + sext(4 + 20 * %x + 24 * %y)
1877     if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1878       const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1879       if (D != 0) {
1880         const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1881         const SCEV *SResidual =
1882             getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1883         const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1884         return getAddExpr(SSExtD, SSExtR,
1885                           (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1886                           Depth + 1);
1887       }
1888     }
1889   }
1890   // If the input value is a chrec scev, and we can prove that the value
1891   // did not overflow in the old, smaller type, we can sign extend all of the
1892   // operands (often constants).  This allows analysis of something like
1893   // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
1894   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1895     if (AR->isAffine()) {
1896       const SCEV *Start = AR->getStart();
1897       const SCEV *Step = AR->getStepRecurrence(*this);
1898       unsigned BitWidth = getTypeSizeInBits(AR->getType());
1899       const Loop *L = AR->getLoop();
1900 
1901       if (!AR->hasNoSignedWrap()) {
1902         auto NewFlags = proveNoWrapViaConstantRanges(AR);
1903         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
1904       }
1905 
1906       // If we have special knowledge that this addrec won't overflow,
1907       // we don't need to do any further analysis.
1908       if (AR->hasNoSignedWrap())
1909         return getAddRecExpr(
1910             getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
1911             getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW);
1912 
1913       // Check whether the backedge-taken count is SCEVCouldNotCompute.
1914       // Note that this serves two purposes: It filters out loops that are
1915       // simply not analyzable, and it covers the case where this code is
1916       // being called from within backedge-taken count analysis, such that
1917       // attempting to ask for the backedge-taken count would likely result
1918       // in infinite recursion. In the latter case, the analysis code will
1919       // cope with a conservative value, and it will take care to purge
1920       // that value once it has finished.
1921       const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1922       if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1923         // Manually compute the final value for AR, checking for
1924         // overflow.
1925 
1926         // Check whether the backedge-taken count can be losslessly cast to
1927         // the addrec's type. The count is always unsigned.
1928         const SCEV *CastedMaxBECount =
1929             getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1930         const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1931             CastedMaxBECount, MaxBECount->getType(), Depth);
1932         if (MaxBECount == RecastedMaxBECount) {
1933           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1934           // Check whether Start+Step*MaxBECount has no signed overflow.
1935           const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
1936                                         SCEV::FlagAnyWrap, Depth + 1);
1937           const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
1938                                                           SCEV::FlagAnyWrap,
1939                                                           Depth + 1),
1940                                                WideTy, Depth + 1);
1941           const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
1942           const SCEV *WideMaxBECount =
1943             getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1944           const SCEV *OperandExtendedAdd =
1945             getAddExpr(WideStart,
1946                        getMulExpr(WideMaxBECount,
1947                                   getSignExtendExpr(Step, WideTy, Depth + 1),
1948                                   SCEV::FlagAnyWrap, Depth + 1),
1949                        SCEV::FlagAnyWrap, Depth + 1);
1950           if (SAdd == OperandExtendedAdd) {
1951             // Cache knowledge of AR NSW, which is propagated to this AddRec.
1952             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1953             // Return the expression with the addrec on the outside.
1954             return getAddRecExpr(
1955                 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
1956                                                          Depth + 1),
1957                 getSignExtendExpr(Step, Ty, Depth + 1), L,
1958                 AR->getNoWrapFlags());
1959           }
1960           // Similar to above, only this time treat the step value as unsigned.
1961           // This covers loops that count up with an unsigned step.
1962           OperandExtendedAdd =
1963             getAddExpr(WideStart,
1964                        getMulExpr(WideMaxBECount,
1965                                   getZeroExtendExpr(Step, WideTy, Depth + 1),
1966                                   SCEV::FlagAnyWrap, Depth + 1),
1967                        SCEV::FlagAnyWrap, Depth + 1);
1968           if (SAdd == OperandExtendedAdd) {
1969             // If AR wraps around then
1970             //
1971             //    abs(Step) * MaxBECount > unsigned-max(AR->getType())
1972             // => SAdd != OperandExtendedAdd
1973             //
1974             // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
1975             // (SAdd == OperandExtendedAdd => AR is NW)
1976 
1977             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1978 
1979             // Return the expression with the addrec on the outside.
1980             return getAddRecExpr(
1981                 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
1982                                                          Depth + 1),
1983                 getZeroExtendExpr(Step, Ty, Depth + 1), L,
1984                 AR->getNoWrapFlags());
1985           }
1986         }
1987       }
1988 
1989       // Normally, in the cases we can prove no-overflow via a
1990       // backedge guarding condition, we can also compute a backedge
1991       // taken count for the loop.  The exceptions are assumptions and
1992       // guards present in the loop -- SCEV is not great at exploiting
1993       // these to compute max backedge taken counts, but can still use
1994       // these to prove lack of overflow.  Use this fact to avoid
1995       // doing extra work that may not pay off.
1996 
1997       if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1998           !AC.assumptions().empty()) {
1999         // If the backedge is guarded by a comparison with the pre-inc
2000         // value the addrec is safe. Also, if the entry is guarded by
2001         // a comparison with the start value and the backedge is
2002         // guarded by a comparison with the post-inc value, the addrec
2003         // is safe.
2004         ICmpInst::Predicate Pred;
2005         const SCEV *OverflowLimit =
2006             getSignedOverflowLimitForStep(Step, &Pred, this);
2007         if (OverflowLimit &&
2008             (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
2009              isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
2010           // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
2011           const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
2012           return getAddRecExpr(
2013               getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
2014               getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
2015         }
2016       }
2017 
2018       // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2019       // if D + (C - D + Step * n) could be proven to not signed wrap
2020       // where D maximizes the number of trailing zeros of (C - D + Step * n)
2021       if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2022         const APInt &C = SC->getAPInt();
2023         const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2024         if (D != 0) {
2025           const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2026           const SCEV *SResidual =
2027               getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2028           const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2029           return getAddExpr(SSExtD, SSExtR,
2030                             (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2031                             Depth + 1);
2032         }
2033       }
2034 
2035       if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2036         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
2037         return getAddRecExpr(
2038             getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
2039             getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
2040       }
2041     }
2042 
2043   // If the input value is provably non-negative and we could not simplify
2044   // away the sext, build a zext instead.
2045   if (isKnownNonNegative(Op))
2046     return getZeroExtendExpr(Op, Ty, Depth + 1);
2047 
2048   // The cast wasn't folded; create an explicit cast node.
2049   // Recompute the insert position, as it may have been invalidated.
2050   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2051   SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2052                                                    Op, Ty);
2053   UniqueSCEVs.InsertNode(S, IP);
2054   addToLoopUseLists(S);
2055   return S;
2056 }
2057 
2058 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
2059 /// unspecified bits out to the given type.
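     /// For instance (illustrative): anyext(i8 -5 to i32) folds to the
     /// sign-extended constant -5, while anyext((trunc i32 %x to i8) to i32)
     /// peels the truncate and yields %x itself.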
2060 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2061                                               Type *Ty) {
2062   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2063          "This is not an extending conversion!");
2064   assert(isSCEVable(Ty) &&
2065          "This is not a conversion to a SCEVable type!");
2066   Ty = getEffectiveSCEVType(Ty);
2067 
2068   // Sign-extend negative constants.
2069   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2070     if (SC->getAPInt().isNegative())
2071       return getSignExtendExpr(Op, Ty);
2072 
2073   // Peel off a truncate cast.
2074   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2075     const SCEV *NewOp = T->getOperand();
2076     if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2077       return getAnyExtendExpr(NewOp, Ty);
2078     return getTruncateOrNoop(NewOp, Ty);
2079   }
2080 
2081   // Next try a zext cast. If the cast is folded, use it.
2082   const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2083   if (!isa<SCEVZeroExtendExpr>(ZExt))
2084     return ZExt;
2085 
2086   // Next try a sext cast. If the cast is folded, use it.
2087   const SCEV *SExt = getSignExtendExpr(Op, Ty);
2088   if (!isa<SCEVSignExtendExpr>(SExt))
2089     return SExt;
2090 
2091   // Force the cast to be folded into the operands of an addrec.
2092   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2093     SmallVector<const SCEV *, 4> Ops;
2094     for (const SCEV *Op : AR->operands())
2095       Ops.push_back(getAnyExtendExpr(Op, Ty));
2096     return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2097   }
2098 
2099   // If the expression is obviously signed, use the sext cast value.
2100   if (isa<SCEVSMaxExpr>(Op))
2101     return SExt;
2102 
2103   // Absent any other information, use the zext cast value.
2104   return ZExt;
2105 }
2106 
2107 /// Process the given Ops list, which is a list of operands to be added under
2108 /// the given scale, and update the given map. This is a helper function for
2109 /// getAddExpr. As an example of what it does, given a sequence of operands
2110 /// that would form an add expression like this:
2111 ///
2112 ///    m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2113 ///
2114 /// where A and B are constants, update the map with these values:
2115 ///
2116 ///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2117 ///
2118 /// and add 13 + A*B*29 to AccumulatedConstant.
2119 /// This will allow getAddRecExpr to produce this:
2120 ///
2121 ///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2122 ///
2123 /// This form often exposes folding opportunities that are hidden in
2124 /// the original operand list.
2125 ///
2126 /// Return true iff it appears that any interesting folding opportunities
2127 /// may be exposed. This helps getAddExpr short-circuit extra work in
2128 /// the common case where no interesting opportunities are present, and
2129 /// is also used as a check to avoid infinite recursion.
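     ///
     /// As a concrete illustration (not from the original source), with A = 2
     /// and B = 3 the map becomes (m, 7), (n, 1), (o, 2), (p, 2), (q, 6),
     /// (r, 0) and AccumulatedConstant = 13 + 174 = 187, i.e. the sum is
     /// re-associated as 187 + n + (m * 7) + ((o + p) * 2) + (q * 6).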
2130 static bool
2131 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2132                              SmallVectorImpl<const SCEV *> &NewOps,
2133                              APInt &AccumulatedConstant,
2134                              const SCEV *const *Ops, size_t NumOperands,
2135                              const APInt &Scale,
2136                              ScalarEvolution &SE) {
2137   bool Interesting = false;
2138 
2139   // Iterate over the add operands. They are sorted, with constants first.
2140   unsigned i = 0;
2141   while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2142     ++i;
2143     // Pull a buried constant out to the outside.
2144     if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2145       Interesting = true;
2146     AccumulatedConstant += Scale * C->getAPInt();
2147   }
2148 
2149   // Next comes everything else. We're especially interested in multiplies
2150   // here, but they're in the middle, so just visit the rest with one loop.
2151   for (; i != NumOperands; ++i) {
2152     const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2153     if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2154       APInt NewScale =
2155           Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2156       if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2157         // A multiplication of a constant with another add; recurse.
2158         const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2159         Interesting |=
2160           CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2161                                        Add->op_begin(), Add->getNumOperands(),
2162                                        NewScale, SE);
2163       } else {
2164         // A multiplication of a constant with some other value. Update
2165         // the map.
2166         SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
2167         const SCEV *Key = SE.getMulExpr(MulOps);
2168         auto Pair = M.insert({Key, NewScale});
2169         if (Pair.second) {
2170           NewOps.push_back(Pair.first->first);
2171         } else {
2172           Pair.first->second += NewScale;
2173           // The map already had an entry for this value, which may indicate
2174           // a folding opportunity.
2175           Interesting = true;
2176         }
2177       }
2178     } else {
2179       // An ordinary operand. Update the map.
2180       std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2181           M.insert({Ops[i], Scale});
2182       if (Pair.second) {
2183         NewOps.push_back(Pair.first->first);
2184       } else {
2185         Pair.first->second += Scale;
2186         // The map already had an entry for this value, which may indicate
2187         // a folding opportunity.
2188         Interesting = true;
2189       }
2190     }
2191   }
2192 
2193   return Interesting;
2194 }
2195 
2196 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2197 // `Flags' as can't-wrap behavior.  Infer a more aggressive set of
2198 // can't-overflow flags for the operation if possible.
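     //
     // For example (illustrative): when Ops is {5, %x} for an add and the
     // signed range of %x is known to fit in [0, 100], 5 + %x cannot
     // sign-overflow, so FlagNSW can be set even if the caller did not pass it.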
2199 static SCEV::NoWrapFlags
2200 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2201                       const ArrayRef<const SCEV *> Ops,
2202                       SCEV::NoWrapFlags Flags) {
2203   using namespace std::placeholders;
2204 
2205   using OBO = OverflowingBinaryOperator;
2206 
2207   bool CanAnalyze =
2208       Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2209   (void)CanAnalyze;
2210   assert(CanAnalyze && "don't call from other places!");
2211 
2212   int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2213   SCEV::NoWrapFlags SignOrUnsignWrap =
2214       ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2215 
2216   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2217   auto IsKnownNonNegative = [&](const SCEV *S) {
2218     return SE->isKnownNonNegative(S);
2219   };
2220 
2221   if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2222     Flags =
2223         ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2224 
2225   SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2226 
2227   if (SignOrUnsignWrap != SignOrUnsignMask &&
2228       (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2229       isa<SCEVConstant>(Ops[0])) {
2230 
2231     auto Opcode = [&] {
2232       switch (Type) {
2233       case scAddExpr:
2234         return Instruction::Add;
2235       case scMulExpr:
2236         return Instruction::Mul;
2237       default:
2238         llvm_unreachable("Unexpected SCEV op.");
2239       }
2240     }();
2241 
2242     const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2243 
2244     // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2245     if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2246       auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2247           Opcode, C, OBO::NoSignedWrap);
2248       if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2249         Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2250     }
2251 
2252     // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2253     if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2254       auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2255           Opcode, C, OBO::NoUnsignedWrap);
2256       if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2257         Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2258     }
2259   }
2260 
2261   return Flags;
2262 }
2263 
2264 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2265   return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2266 }
2267 
2268 /// Get a canonical add expression, or something simpler if possible.
2269 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2270                                         SCEV::NoWrapFlags Flags,
2271                                         unsigned Depth) {
2272   assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2273          "only nuw or nsw allowed");
2274   assert(!Ops.empty() && "Cannot get empty add!");
2275   if (Ops.size() == 1) return Ops[0];
2276 #ifndef NDEBUG
2277   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2278   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2279     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2280            "SCEVAddExpr operand types don't match!");
2281 #endif
2282 
2283   // Sort by complexity; this groups all similar expression types together.
2284   GroupByComplexity(Ops, &LI, DT);
2285 
2286   // If there are any constants, fold them together.
2287   unsigned Idx = 0;
2288   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2289     ++Idx;
2290     assert(Idx < Ops.size());
2291     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2292       // We found two constants, fold them together!
2293       Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2294       if (Ops.size() == 2) return Ops[0];
2295       Ops.erase(Ops.begin()+1);  // Erase the folded element
2296       LHSC = cast<SCEVConstant>(Ops[0]);
2297     }
2298 
2299     // If we are left with a constant zero being added, strip it off.
2300     if (LHSC->getValue()->isZero()) {
2301       Ops.erase(Ops.begin());
2302       --Idx;
2303     }
2304 
2305     if (Ops.size() == 1) return Ops[0];
2306   }
2307 
2308   Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
2309 
2310   // Limit the recursion depth.
2311   if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2312     return getOrCreateAddExpr(Ops, Flags);
2313 
2314   if (SCEV *S = std::get<0>(findExistingSCEVInCache(scAddExpr, Ops))) {
2315     static_cast<SCEVAddExpr *>(S)->setNoWrapFlags(Flags);
2316     return S;
2317   }
2318 
2319   // Okay, check to see if the same value occurs in the operand list more than
2320   // once.  If so, merge them together into a multiply expression.  Since we
2321   // sorted the list, these values are required to be adjacent.
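       // For instance (illustrative), (%a + %b + %b + %b) is folded into
       // (%a + (3 * %b)).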
2322   Type *Ty = Ops[0]->getType();
2323   bool FoundMatch = false;
2324   for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2325     if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
2326       // Scan ahead to count how many equal operands there are.
2327       unsigned Count = 2;
2328       while (i+Count != e && Ops[i+Count] == Ops[i])
2329         ++Count;
2330       // Merge the values into a multiply.
2331       const SCEV *Scale = getConstant(Ty, Count);
2332       const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2333       if (Ops.size() == Count)
2334         return Mul;
2335       Ops[i] = Mul;
2336       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2337       --i; e -= Count - 1;
2338       FoundMatch = true;
2339     }
2340   if (FoundMatch)
2341     return getAddExpr(Ops, Flags, Depth + 1);
2342 
2343   // Check for truncates. If all the operands are truncated from the same
2344   // type, see if factoring out the truncate would permit the result to be
2345   // folded. E.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2346   // if the contents of the resulting outer trunc fold to something simple.
2347   auto FindTruncSrcType = [&]() -> Type * {
2348     // We're ultimately looking to fold an add of truncs and muls of only
2349     // constants and truncs, so if we find any other types of SCEV
2350     // as operands of the add then we bail and return nullptr here.
2351     // Otherwise, we return the type of the operand of a trunc that we find.
2352     if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2353       return T->getOperand()->getType();
2354     if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2355       const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2356       if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2357         return T->getOperand()->getType();
2358     }
2359     return nullptr;
2360   };
2361   if (auto *SrcType = FindTruncSrcType()) {
2362     SmallVector<const SCEV *, 8> LargeOps;
2363     bool Ok = true;
2364     // Check all the operands to see if they can be represented in the
2365     // source type of the truncate.
2366     for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2367       if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2368         if (T->getOperand()->getType() != SrcType) {
2369           Ok = false;
2370           break;
2371         }
2372         LargeOps.push_back(T->getOperand());
2373       } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2374         LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2375       } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2376         SmallVector<const SCEV *, 8> LargeMulOps;
2377         for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2378           if (const SCEVTruncateExpr *T =
2379                 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2380             if (T->getOperand()->getType() != SrcType) {
2381               Ok = false;
2382               break;
2383             }
2384             LargeMulOps.push_back(T->getOperand());
2385           } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2386             LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2387           } else {
2388             Ok = false;
2389             break;
2390           }
2391         }
2392         if (Ok)
2393           LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2394       } else {
2395         Ok = false;
2396         break;
2397       }
2398     }
2399     if (Ok) {
2400       // Evaluate the expression in the larger type.
2401       const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2402       // If it folds to something simple, use it. Otherwise, don't.
2403       if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2404         return getTruncateExpr(Fold, Ty);
2405     }
2406   }
2407 
2408   // Skip past any other cast SCEVs.
2409   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2410     ++Idx;
2411 
2412   // If there are add operands they would be next.
2413   if (Idx < Ops.size()) {
2414     bool DeletedAdd = false;
2415     while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2416       if (Ops.size() > AddOpsInlineThreshold ||
2417           Add->getNumOperands() > AddOpsInlineThreshold)
2418         break;
2419       // If we have an add, expand the add operands onto the end of the operands
2420       // list.
2421       Ops.erase(Ops.begin()+Idx);
2422       Ops.append(Add->op_begin(), Add->op_end());
2423       DeletedAdd = true;
2424     }
2425 
2426     // If we deleted at least one add, we added operands to the end of the list,
2427     // and they are not necessarily sorted.  Recurse to resort and resimplify
2428     // any operands we just acquired.
2429     if (DeletedAdd)
2430       return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2431   }
2432 
2433   // Skip over the add expression until we get to a multiply.
2434   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2435     ++Idx;
2436 
2437   // Check to see if there are any folding opportunities present with
2438   // operands multiplied by constant values.
2439   if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2440     uint64_t BitWidth = getTypeSizeInBits(Ty);
2441     DenseMap<const SCEV *, APInt> M;
2442     SmallVector<const SCEV *, 8> NewOps;
2443     APInt AccumulatedConstant(BitWidth, 0);
2444     if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2445                                      Ops.data(), Ops.size(),
2446                                      APInt(BitWidth, 1), *this)) {
2447       struct APIntCompare {
2448         bool operator()(const APInt &LHS, const APInt &RHS) const {
2449           return LHS.ult(RHS);
2450         }
2451       };
2452 
2453       // Some interesting folding opportunity is present, so it's worthwhile to
2454       // re-generate the operands list. Group the operands by constant scale,
2455       // to avoid multiplying by the same constant scale multiple times.
2456       std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2457       for (const SCEV *NewOp : NewOps)
2458         MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2459       // Re-generate the operands list.
2460       Ops.clear();
2461       if (AccumulatedConstant != 0)
2462         Ops.push_back(getConstant(AccumulatedConstant));
2463       for (auto &MulOp : MulOpLists)
2464         if (MulOp.first != 0)
2465           Ops.push_back(getMulExpr(
2466               getConstant(MulOp.first),
2467               getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2468               SCEV::FlagAnyWrap, Depth + 1));
2469       if (Ops.empty())
2470         return getZero(Ty);
2471       if (Ops.size() == 1)
2472         return Ops[0];
2473       return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2474     }
2475   }
2476 
2477   // If we are adding something to a multiply expression, make sure the
2478   // something is not already an operand of the multiply.  If so, merge it into
2479   // the multiply.
2480   for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2481     const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2482     for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2483       const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2484       if (isa<SCEVConstant>(MulOpSCEV))
2485         continue;
2486       for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2487         if (MulOpSCEV == Ops[AddOp]) {
2488           // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
2489           const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2490           if (Mul->getNumOperands() != 2) {
2491             // If the multiply has more than two operands, we must get the
2492             // Y*Z term.
2493             SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2494                                                 Mul->op_begin()+MulOp);
2495             MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2496             InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2497           }
2498           SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2499           const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2500           const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2501                                             SCEV::FlagAnyWrap, Depth + 1);
2502           if (Ops.size() == 2) return OuterMul;
2503           if (AddOp < Idx) {
2504             Ops.erase(Ops.begin()+AddOp);
2505             Ops.erase(Ops.begin()+Idx-1);
2506           } else {
2507             Ops.erase(Ops.begin()+Idx);
2508             Ops.erase(Ops.begin()+AddOp-1);
2509           }
2510           Ops.push_back(OuterMul);
2511           return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2512         }
2513 
2514       // Check this multiply against other multiplies being added together.
2515       for (unsigned OtherMulIdx = Idx+1;
2516            OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2517            ++OtherMulIdx) {
2518         const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2519         // If MulOp occurs in OtherMul, we can fold the two multiplies
2520         // together.
2521         for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2522              OMulOp != e; ++OMulOp)
2523           if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2524             // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2525             const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2526             if (Mul->getNumOperands() != 2) {
2527               SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2528                                                   Mul->op_begin()+MulOp);
2529               MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2530               InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2531             }
2532             const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2533             if (OtherMul->getNumOperands() != 2) {
2534               SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
2535                                                   OtherMul->op_begin()+OMulOp);
2536               MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
2537               InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2538             }
2539             SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2540             const SCEV *InnerMulSum =
2541                 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2542             const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2543                                               SCEV::FlagAnyWrap, Depth + 1);
2544             if (Ops.size() == 2) return OuterMul;
2545             Ops.erase(Ops.begin()+Idx);
2546             Ops.erase(Ops.begin()+OtherMulIdx-1);
2547             Ops.push_back(OuterMul);
2548             return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2549           }
2550       }
2551     }
2552   }
2553 
2554   // If there are any add recurrences in the operands list, see if any other
2555   // added values are loop invariant.  If so, we can fold them into the
2556   // recurrence.
2557   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2558     ++Idx;
2559 
2560   // Scan over all recurrences, trying to fold loop invariants into them.
2561   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2562     // Scan all of the other operands to this add and add them to the vector if
2563     // they are loop invariant w.r.t. the recurrence.
2564     SmallVector<const SCEV *, 8> LIOps;
2565     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2566     const Loop *AddRecLoop = AddRec->getLoop();
2567     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2568       if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2569         LIOps.push_back(Ops[i]);
2570         Ops.erase(Ops.begin()+i);
2571         --i; --e;
2572       }
2573 
2574     // If we found some loop invariants, fold them into the recurrence.
2575     if (!LIOps.empty()) {
2576       //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
2577       LIOps.push_back(AddRec->getStart());
2578 
2579       SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
2580                                              AddRec->op_end());
2581       // This follows from the fact that the no-wrap flags on the outer add
2582       // expression are applicable on the 0th iteration, when the add recurrence
2583       // will be equal to its start value.
2584       AddRecOps[0] = getAddExpr(LIOps, Flags, Depth + 1);
2585 
2586       // Build the new addrec. Propagate the NUW and NSW flags if both the
2587       // outer add and the inner addrec are guaranteed to have no overflow.
2588       // Always propagate NW.
2589       Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2590       const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2591 
2592       // If all of the other operands were loop invariant, we are done.
2593       if (Ops.size() == 1) return NewRec;
2594 
2595       // Otherwise, add the folded AddRec to the non-invariant parts.
2596       for (unsigned i = 0;; ++i)
2597         if (Ops[i] == AddRec) {
2598           Ops[i] = NewRec;
2599           break;
2600         }
2601       return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2602     }
2603 
2604     // Okay, if there weren't any loop invariants to be folded, check to see if
2605     // there are multiple AddRec's with the same loop induction variable being
2606     // added together.  If so, we can fold them.
2607     for (unsigned OtherIdx = Idx+1;
2608          OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2609          ++OtherIdx) {
2610       // We expect the AddRecExprs to be sorted in reverse dominance order,
2611       // so that the 1st found AddRecExpr is dominated by all others.
2612       assert(DT.dominates(
2613            cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2614            AddRec->getLoop()->getHeader()) &&
2615         "AddRecExprs are not sorted in reverse dominance order?");
2616       if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2617         // Other + {A,+,B}<L> + {C,+,D}<L>  -->  Other + {A+C,+,B+D}<L>
2618         SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
2619                                                AddRec->op_end());
2620         for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2621              ++OtherIdx) {
2622           const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2623           if (OtherAddRec->getLoop() == AddRecLoop) {
2624             for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2625                  i != e; ++i) {
2626               if (i >= AddRecOps.size()) {
2627                 AddRecOps.append(OtherAddRec->op_begin()+i,
2628                                  OtherAddRec->op_end());
2629                 break;
2630               }
2631               SmallVector<const SCEV *, 2> TwoOps = {
2632                   AddRecOps[i], OtherAddRec->getOperand(i)};
2633               AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2634             }
2635             Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2636           }
2637         }
2638         // Step size has changed, so we cannot guarantee no self-wraparound.
2639         Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2640         return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2641       }
2642     }
2643 
2644     // Otherwise couldn't fold anything into this recurrence.  Move onto the
2645     // next one.
2646   }
2647 
2648   // Okay, it looks like we really DO need an add expr.  Check to see if we
2649   // already have one, otherwise create a new one.
2650   return getOrCreateAddExpr(Ops, Flags);
2651 }
2652 
2653 const SCEV *
2654 ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2655                                     SCEV::NoWrapFlags Flags) {
2656   FoldingSetNodeID ID;
2657   ID.AddInteger(scAddExpr);
2658   for (const SCEV *Op : Ops)
2659     ID.AddPointer(Op);
2660   void *IP = nullptr;
2661   SCEVAddExpr *S =
2662       static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2663   if (!S) {
2664     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2665     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2666     S = new (SCEVAllocator)
2667         SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2668     UniqueSCEVs.InsertNode(S, IP);
2669     addToLoopUseLists(S);
2670   }
2671   S->setNoWrapFlags(Flags);
2672   return S;
2673 }
2674 
2675 const SCEV *
2676 ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
2677                                        const Loop *L, SCEV::NoWrapFlags Flags) {
2678   FoldingSetNodeID ID;
2679   ID.AddInteger(scAddRecExpr);
2680   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2681     ID.AddPointer(Ops[i]);
2682   ID.AddPointer(L);
2683   void *IP = nullptr;
2684   SCEVAddRecExpr *S =
2685       static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2686   if (!S) {
2687     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2688     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2689     S = new (SCEVAllocator)
2690         SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
2691     UniqueSCEVs.InsertNode(S, IP);
2692     addToLoopUseLists(S);
2693   }
2694   S->setNoWrapFlags(Flags);
2695   return S;
2696 }
2697 
2698 const SCEV *
2699 ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
2700                                     SCEV::NoWrapFlags Flags) {
2701   FoldingSetNodeID ID;
2702   ID.AddInteger(scMulExpr);
2703   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2704     ID.AddPointer(Ops[i]);
2705   void *IP = nullptr;
2706   SCEVMulExpr *S =
2707     static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2708   if (!S) {
2709     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2710     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2711     S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2712                                         O, Ops.size());
2713     UniqueSCEVs.InsertNode(S, IP);
2714     addToLoopUseLists(S);
2715   }
2716   S->setNoWrapFlags(Flags);
2717   return S;
2718 }
2719 
2720 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
2721   uint64_t k = i*j;
2722   if (j > 1 && k / j != i) Overflow = true;
2723   return k;
2724 }
2725 
2726 /// Compute the result of "n choose k", the binomial coefficient.  If an
2727 /// intermediate computation overflows, Overflow will be set and the return will
2728 /// be garbage. Overflow is not cleared on absence of overflow.
2729 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
2730   // We use the multiplicative formula:
2731   //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
2732   // At each iteration, we take the n-th term of the numerator and divide by the
2733   // (k-n)th term of the denominator.  This division will always produce an
2734   // integral result, and helps reduce the chance of overflow in the
2735   // intermediate computations. However, we can still overflow even when the
2736   // final result would fit.
2737 
2738   if (n == 0 || n == k) return 1;
2739   if (k > n) return 0;
2740 
2741   if (k > n/2)
2742     k = n-k;
2743 
2744   uint64_t r = 1;
2745   for (uint64_t i = 1; i <= k; ++i) {
2746     r = umul_ov(r, n-(i-1), Overflow);
2747     r /= i;
2748   }
2749   return r;
2750 }
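// Worked example of the loop above (illustrative): Choose(5, 2).
//   k = 2 (already <= n/2), r = 1
//   i = 1: r = 1 * (5 - 0) = 5,  r /= 1  ->  5
//   i = 2: r = 5 * (5 - 1) = 20, r /= 2  ->  10
// giving C(5, 2) = 10, with Overflow left untouched.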
2751 
2752 /// Determine if any of the operands in this SCEV are a constant or if
2753 /// any of the add or multiply expressions in this SCEV contain a constant.
2754 static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
2755   struct FindConstantInAddMulChain {
2756     bool FoundConstant = false;
2757 
2758     bool follow(const SCEV *S) {
2759       FoundConstant |= isa<SCEVConstant>(S);
2760       return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
2761     }
2762 
2763     bool isDone() const {
2764       return FoundConstant;
2765     }
2766   };
2767 
2768   FindConstantInAddMulChain F;
2769   SCEVTraversal<FindConstantInAddMulChain> ST(F);
2770   ST.visitAll(StartExpr);
2771   return F.FoundConstant;
2772 }
2773 
2774 /// Get a canonical multiply expression, or something simpler if possible.
2775 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
2776                                         SCEV::NoWrapFlags Flags,
2777                                         unsigned Depth) {
2778   assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
2779          "only nuw or nsw allowed");
2780   assert(!Ops.empty() && "Cannot get empty mul!");
2781   if (Ops.size() == 1) return Ops[0];
2782 #ifndef NDEBUG
2783   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2784   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2785     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2786            "SCEVMulExpr operand types don't match!");
2787 #endif
2788 
2789   // Sort by complexity, this groups all similar expression types together.
2790   GroupByComplexity(Ops, &LI, DT);
2791 
2792   // If there are any constants, fold them together.
2793   unsigned Idx = 0;
2794   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2795     ++Idx;
2796     assert(Idx < Ops.size());
2797     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2798       // We found two constants, fold them together!
2799       Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
2800       if (Ops.size() == 2) return Ops[0];
2801       Ops.erase(Ops.begin()+1);  // Erase the folded element
2802       LHSC = cast<SCEVConstant>(Ops[0]);
2803     }
2804 
2805     // If we have a multiply of zero, it will always be zero.
2806     if (LHSC->getValue()->isZero())
2807       return LHSC;
2808 
2809     // If we are left with a constant one being multiplied, strip it off.
2810     if (LHSC->getValue()->isOne()) {
2811       Ops.erase(Ops.begin());
2812       --Idx;
2813     }
2814 
2815     if (Ops.size() == 1)
2816       return Ops[0];
2817   }
2818 
2819   Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags);
2820 
2821   // Limit recursion calls depth.
2822   if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2823     return getOrCreateMulExpr(Ops, Flags);
2824 
2825   if (SCEV *S = std::get<0>(findExistingSCEVInCache(scMulExpr, Ops))) {
2826     static_cast<SCEVMulExpr *>(S)->setNoWrapFlags(Flags);
2827     return S;
2828   }
2829 
2830   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2831     if (Ops.size() == 2) {
2832       // C1*(C2+V) -> C1*C2 + C1*V
2833       if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
2834         // If any of Add's ops are Adds or Muls with a constant, apply this
2835         // transformation as well.
2836         //
2837         // TODO: There are some cases where this transformation is not
2838         // profitable; for example, Add = (C0 + X) * Y + Z.  Maybe the scope of
2839         // this transformation should be narrowed down.
2840         if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add))
2841           return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
2842                                        SCEV::FlagAnyWrap, Depth + 1),
2843                             getMulExpr(LHSC, Add->getOperand(1),
2844                                        SCEV::FlagAnyWrap, Depth + 1),
2845                             SCEV::FlagAnyWrap, Depth + 1);
2846 
2847       if (Ops[0]->isAllOnesValue()) {
2848         // If we have a mul by -1 of an add, try distributing the -1 among the
2849         // add operands.
2850         if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
2851           SmallVector<const SCEV *, 4> NewOps;
2852           bool AnyFolded = false;
2853           for (const SCEV *AddOp : Add->operands()) {
2854             const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
2855                                          Depth + 1);
2856             if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
2857             NewOps.push_back(Mul);
2858           }
2859           if (AnyFolded)
2860             return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
2861         } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
2862           // Negation preserves a recurrence's no self-wrap property.
2863           SmallVector<const SCEV *, 4> Operands;
2864           for (const SCEV *AddRecOp : AddRec->operands())
2865             Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
2866                                           Depth + 1));
2867 
2868           return getAddRecExpr(Operands, AddRec->getLoop(),
2869                                AddRec->getNoWrapFlags(SCEV::FlagNW));
2870         }
2871       }
2872     }
2873   }
2874 
2875   // Skip over the add expression until we get to a multiply.
2876   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2877     ++Idx;
2878 
2879   // If there are mul operands inline them all into this expression.
2880   if (Idx < Ops.size()) {
2881     bool DeletedMul = false;
2882     while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2883       if (Ops.size() > MulOpsInlineThreshold)
2884         break;
2885       // If we have a mul, expand the mul operands onto the end of the
2886       // operands list.
2887       Ops.erase(Ops.begin()+Idx);
2888       Ops.append(Mul->op_begin(), Mul->op_end());
2889       DeletedMul = true;
2890     }
2891 
2892     // If we deleted at least one mul, we added operands to the end of the
2893     // list, and they are not necessarily sorted.  Recurse to resort and
2894     // resimplify any operands we just acquired.
2895     if (DeletedMul)
2896       return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2897   }
2898 
2899   // If there are any add recurrences in the operands list, see if any other
2900   // multiplied values are loop invariant.  If so, we can fold them into the
2901   // recurrence.
2902   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2903     ++Idx;
2904 
2905   // Scan over all recurrences, trying to fold loop invariants into them.
2906   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2907     // Scan all of the other operands to this mul and add them to the vector
2908     // if they are loop invariant w.r.t. the recurrence.
2909     SmallVector<const SCEV *, 8> LIOps;
2910     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2911     const Loop *AddRecLoop = AddRec->getLoop();
2912     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2913       if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2914         LIOps.push_back(Ops[i]);
2915         Ops.erase(Ops.begin()+i);
2916         --i; --e;
2917       }
2918 
2919     // If we found some loop invariants, fold them into the recurrence.
2920     if (!LIOps.empty()) {
2921       //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
2922       SmallVector<const SCEV *, 4> NewOps;
2923       NewOps.reserve(AddRec->getNumOperands());
2924       const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
2925       for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
2926         NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
2927                                     SCEV::FlagAnyWrap, Depth + 1));
2928 
2929       // Build the new addrec. Propagate the NUW and NSW flags if both the
2930       // outer mul and the inner addrec are guaranteed to have no overflow.
2931       //
2932   // No-self-wrap cannot be guaranteed after changing the step size, but
2933       // will be inferred if either NUW or NSW is true.
2934       Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW));
2935       const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags);
2936 
2937       // If all of the other operands were loop invariant, we are done.
2938       if (Ops.size() == 1) return NewRec;
2939 
2940       // Otherwise, multiply the folded AddRec by the non-invariant parts.
2941       for (unsigned i = 0;; ++i)
2942         if (Ops[i] == AddRec) {
2943           Ops[i] = NewRec;
2944           break;
2945         }
2946       return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2947     }
2948 
2949     // Okay, if there weren't any loop invariants to be folded, check to see
2950     // if there are multiple AddRec's with the same loop induction variable
2951     // being multiplied together.  If so, we can fold them.
2952 
2953     // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
2954     // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
2955     //       choose(x, 2x-y)*choose(2x-y, x-z)*A_{y-z}*B_z
2956     //   ]]],+,...up to x=2n}.
2957     // Note that the arguments to choose() are always integers with values
2958     // known at compile time, never SCEV objects.
2959     //
2960     // The implementation avoids pointless extra computations when the two
2961     // addrecs are of different length (mathematically, it's equivalent to
2962     // an infinite stream of zeros on the right).
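    // Illustrative special case of the formula above (it follows from the
    // code below, for two affine recurrences in the same loop):
    //   {A,+,B}<L> * {C,+,D}<L> = {A*C,+,(A*D + B*C + B*D),+,(2*B*D)}<L>
    // e.g. {0,+,1}<L> * {0,+,1}<L> = {0,+,1,+,2}<L>, the recurrence for i*i.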
2963     bool OpsModified = false;
2964     for (unsigned OtherIdx = Idx+1;
2965          OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2966          ++OtherIdx) {
2967       const SCEVAddRecExpr *OtherAddRec =
2968         dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2969       if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
2970         continue;
2971 
2972       // Limit max number of arguments to avoid creation of unreasonably big
2973       // SCEVAddRecs with very complex operands.
2974       if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
2975           MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
2976         continue;
2977 
2978       bool Overflow = false;
2979       Type *Ty = AddRec->getType();
2980       bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
2981       SmallVector<const SCEV*, 7> AddRecOps;
2982       for (int x = 0, xe = AddRec->getNumOperands() +
2983              OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
2984         SmallVector <const SCEV *, 7> SumOps;
2985         for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
2986           uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
2987           for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
2988                  ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
2989                z < ze && !Overflow; ++z) {
2990             uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
2991             uint64_t Coeff;
2992             if (LargerThan64Bits)
2993               Coeff = umul_ov(Coeff1, Coeff2, Overflow);
2994             else
2995               Coeff = Coeff1*Coeff2;
2996             const SCEV *CoeffTerm = getConstant(Ty, Coeff);
2997             const SCEV *Term1 = AddRec->getOperand(y-z);
2998             const SCEV *Term2 = OtherAddRec->getOperand(z);
2999             SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3000                                         SCEV::FlagAnyWrap, Depth + 1));
3001           }
3002         }
3003         if (SumOps.empty())
3004           SumOps.push_back(getZero(Ty));
3005         AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3006       }
3007       if (!Overflow) {
3008         const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop,
3009                                               SCEV::FlagAnyWrap);
3010         if (Ops.size() == 2) return NewAddRec;
3011         Ops[Idx] = NewAddRec;
3012         Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3013         OpsModified = true;
3014         AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3015         if (!AddRec)
3016           break;
3017       }
3018     }
3019     if (OpsModified)
3020       return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3021 
3022     // Otherwise couldn't fold anything into this recurrence.  Move onto the
3023     // next one.
3024   }
3025 
3026   // Okay, it looks like we really DO need a mul expr.  Check to see if we
3027   // already have one, otherwise create a new one.
3028   return getOrCreateMulExpr(Ops, Flags);
3029 }
3030 
3031 /// Represents an unsigned remainder expression based on unsigned division.
3032 const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3033                                          const SCEV *RHS) {
3034   assert(getEffectiveSCEVType(LHS->getType()) ==
3035          getEffectiveSCEVType(RHS->getType()) &&
3036          "SCEVURemExpr operand types don't match!");
3037 
3038   // Short-circuit easy cases
3039   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3040     // If constant is one, the result is trivial
3041     if (RHSC->getValue()->isOne())
3042       return getZero(LHS->getType()); // X urem 1 --> 0
3043 
3044     // If constant is a power of two, fold into a zext(trunc(LHS)).
3045     if (RHSC->getAPInt().isPowerOf2()) {
3046       Type *FullTy = LHS->getType();
3047       Type *TruncTy =
3048           IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3049       return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3050     }
3051   }
3052 
3053   // Fall back to computing %x urem %y as %x -<nuw> ((%x udiv %y) *<nuw> %y).
3054   const SCEV *UDiv = getUDivExpr(LHS, RHS);
3055   const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3056   return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3057 }
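// Illustrative examples of the folds above (assuming %x is an i32 value):
//   %x urem 1  -->  0
//   %x urem 8  -->  (zext i3 (trunc i32 %x to i3) to i32)
//   %x urem 7  -->  %x - 7 * (%x /u 7)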
3058 
3059 /// Get a canonical unsigned division expression, or something simpler if
3060 /// possible.
3061 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3062                                          const SCEV *RHS) {
3063   assert(getEffectiveSCEVType(LHS->getType()) ==
3064          getEffectiveSCEVType(RHS->getType()) &&
3065          "SCEVUDivExpr operand types don't match!");
3066 
3067   FoldingSetNodeID ID;
3068   ID.AddInteger(scUDivExpr);
3069   ID.AddPointer(LHS);
3070   ID.AddPointer(RHS);
3071   void *IP = nullptr;
3072   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3073     return S;
3074 
3075   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3076     if (RHSC->getValue()->isOne())
3077       return LHS;                               // X udiv 1 --> x
3078     // If the denominator is zero, the result of the udiv is undefined. Don't
3079     // try to analyze it, because the resolution chosen here may differ from
3080     // the resolution chosen in other parts of the compiler.
3081     if (!RHSC->getValue()->isZero()) {
3082       // Determine if the division can be folded into the operands of
3083       // its LHS.
3084       // TODO: Generalize this to non-constants by using known-bits information.
3085       Type *Ty = LHS->getType();
3086       unsigned LZ = RHSC->getAPInt().countLeadingZeros();
3087       unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3088       // For non-power-of-two values, effectively round the value up to the
3089       // nearest power of two.
3090       if (!RHSC->getAPInt().isPowerOf2())
3091         ++MaxShiftAmt;
3092       IntegerType *ExtTy =
3093         IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3094       if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3095         if (const SCEVConstant *Step =
3096             dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3097           // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3098           const APInt &StepInt = Step->getAPInt();
3099           const APInt &DivInt = RHSC->getAPInt();
3100           if (!StepInt.urem(DivInt) &&
3101               getZeroExtendExpr(AR, ExtTy) ==
3102               getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3103                             getZeroExtendExpr(Step, ExtTy),
3104                             AR->getLoop(), SCEV::FlagAnyWrap)) {
3105             SmallVector<const SCEV *, 4> Operands;
3106             for (const SCEV *Op : AR->operands())
3107               Operands.push_back(getUDivExpr(Op, RHS));
3108             return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3109           }
3110           // Get a canonical UDivExpr for a recurrence.
3111           // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3112           // We can currently only fold X%N if X is constant.
3113           const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3114           if (StartC && !DivInt.urem(StepInt) &&
3115               getZeroExtendExpr(AR, ExtTy) ==
3116               getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3117                             getZeroExtendExpr(Step, ExtTy),
3118                             AR->getLoop(), SCEV::FlagAnyWrap)) {
3119             const APInt &StartInt = StartC->getAPInt();
3120             const APInt &StartRem = StartInt.urem(StepInt);
3121             if (StartRem != 0) {
3122               const SCEV *NewLHS =
3123                   getAddRecExpr(getConstant(StartInt - StartRem), Step,
3124                                 AR->getLoop(), SCEV::FlagNW);
3125               if (LHS != NewLHS) {
3126                 LHS = NewLHS;
3127 
3128                 // Reset the ID to include the new LHS, and check if it is
3129                 // already cached.
3130                 ID.clear();
3131                 ID.AddInteger(scUDivExpr);
3132                 ID.AddPointer(LHS);
3133                 ID.AddPointer(RHS);
3134                 IP = nullptr;
3135                 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3136                   return S;
3137               }
3138             }
3139           }
3140         }
3141       // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3142       if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3143         SmallVector<const SCEV *, 4> Operands;
3144         for (const SCEV *Op : M->operands())
3145           Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3146         if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3147           // Find an operand that's safely divisible.
3148           for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3149             const SCEV *Op = M->getOperand(i);
3150             const SCEV *Div = getUDivExpr(Op, RHSC);
3151             if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3152               Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
3153                                                       M->op_end());
3154               Operands[i] = Div;
3155               return getMulExpr(Operands);
3156             }
3157           }
3158       }
3159 
3160       // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3161       if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3162         if (auto *DivisorConstant =
3163                 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3164           bool Overflow = false;
3165           APInt NewRHS =
3166               DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3167           if (Overflow) {
3168             return getConstant(RHSC->getType(), 0, false);
3169           }
3170           return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3171         }
3172       }
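      // Illustrative: (%x /u 6) /u 3 becomes %x /u 18, while (%x /u 64) /u 8
      // on i8 overflows the combined divisor and therefore folds to 0, since
      // the quotient can never be nonzero.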
3173 
3174       // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3175       if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3176         SmallVector<const SCEV *, 4> Operands;
3177         for (const SCEV *Op : A->operands())
3178           Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3179         if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3180           Operands.clear();
3181           for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3182             const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3183             if (isa<SCEVUDivExpr>(Op) ||
3184                 getMulExpr(Op, RHS) != A->getOperand(i))
3185               break;
3186             Operands.push_back(Op);
3187           }
3188           if (Operands.size() == A->getNumOperands())
3189             return getAddExpr(Operands);
3190         }
3191       }
3192 
3193       // Fold if both operands are constant.
3194       if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
3195         Constant *LHSCV = LHSC->getValue();
3196         Constant *RHSCV = RHSC->getValue();
3197         return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
3198                                                                    RHSCV)));
3199       }
3200     }
3201   }
3202 
3203   // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3204   // changes). Make sure we get a new one.
3205   IP = nullptr;
3206   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3207   SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3208                                              LHS, RHS);
3209   UniqueSCEVs.InsertNode(S, IP);
3210   addToLoopUseLists(S);
3211   return S;
3212 }
3213 
3214 static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3215   APInt A = C1->getAPInt().abs();
3216   APInt B = C2->getAPInt().abs();
3217   uint32_t ABW = A.getBitWidth();
3218   uint32_t BBW = B.getBitWidth();
3219 
3220   if (ABW > BBW)
3221     B = B.zext(ABW);
3222   else if (ABW < BBW)
3223     A = A.zext(BBW);
3224 
3225   return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3226 }
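// For example (illustrative): gcd of the constants 12 (i32 wide) and 18
// (i64 wide) first zero-extends 12 to 64 bits and then returns 6 as a
// 64-bit APInt.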
3227 
3228 /// Get a canonical unsigned division expression, or something simpler if
3229 /// possible. There is no representation for an exact udiv in SCEV IR, but we
3230 /// can attempt to remove factors from the LHS and RHS.  We can't do this when
3231 /// it's not exact because the udiv may be clearing bits.
3232 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3233                                               const SCEV *RHS) {
3234   // TODO: we could try to find factors in all sorts of things, but for now we
3235   // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3236   // end of this file for inspiration.
3237 
3238   const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3239   if (!Mul || !Mul->hasNoUnsignedWrap())
3240     return getUDivExpr(LHS, RHS);
3241 
3242   if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3243     // If the mulexpr multiplies by a constant, then that constant must be the
3244     // first element of the mulexpr.
3245     if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3246       if (LHSCst == RHSCst) {
3247         SmallVector<const SCEV *, 2> Operands;
3248         Operands.append(Mul->op_begin() + 1, Mul->op_end());
3249         return getMulExpr(Operands);
3250       }
3251 
3252       // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3253       // that there's a factor provided by one of the other terms. We need to
3254       // check.
3255       APInt Factor = gcd(LHSCst, RHSCst);
3256       if (!Factor.isIntN(1)) {
3257         LHSCst =
3258             cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3259         RHSCst =
3260             cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3261         SmallVector<const SCEV *, 2> Operands;
3262         Operands.push_back(LHSCst);
3263         Operands.append(Mul->op_begin() + 1, Mul->op_end());
3264         LHS = getMulExpr(Operands);
3265         RHS = RHSCst;
3266         Mul = dyn_cast<SCEVMulExpr>(LHS);
3267         if (!Mul)
3268           return getUDivExactExpr(LHS, RHS);
3269       }
3270     }
3271   }
3272 
3273   for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3274     if (Mul->getOperand(i) == RHS) {
3275       SmallVector<const SCEV *, 2> Operands;
3276       Operands.append(Mul->op_begin(), Mul->op_begin() + i);
3277       Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
3278       return getMulExpr(Operands);
3279     }
3280   }
3281 
3282   return getUDivExpr(LHS, RHS);
3283 }
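// Illustrative examples of the factoring above (assuming the multiplies
// carry <nuw> and %x is a SCEVUnknown value):
//   (4 * %x) /u 4  -->  %x
//   (6 * %x) /u 4  -->  the common factor gcd(6, 4) = 2 cancels, reducing
//                       the query to (3 * %x) /u 2 before falling back to a
//                       plain udiv.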
3284 
3285 /// Get an add recurrence expression for the specified loop.  Simplify the
3286 /// expression as much as possible.
3287 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3288                                            const Loop *L,
3289                                            SCEV::NoWrapFlags Flags) {
3290   SmallVector<const SCEV *, 4> Operands;
3291   Operands.push_back(Start);
3292   if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3293     if (StepChrec->getLoop() == L) {
3294       Operands.append(StepChrec->op_begin(), StepChrec->op_end());
3295       return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3296     }
3297 
3298   Operands.push_back(Step);
3299   return getAddRecExpr(Operands, L, Flags);
3300 }
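// Illustrative: getAddRecExpr(%start, {0,+,1}<L>, L, Flags) flattens the
// nested step into the single recurrence {%start,+,0,+,1}<L> (keeping only
// the NW bit of Flags) instead of nesting one addrec inside another's step.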
3301 
3302 /// Get an add recurrence expression for the specified loop.  Simplify the
3303 /// expression as much as possible.
3304 const SCEV *
3305 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3306                                const Loop *L, SCEV::NoWrapFlags Flags) {
3307   if (Operands.size() == 1) return Operands[0];
3308 #ifndef NDEBUG
3309   Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3310   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
3311     assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
3312            "SCEVAddRecExpr operand types don't match!");
3313   for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3314     assert(isLoopInvariant(Operands[i], L) &&
3315            "SCEVAddRecExpr operand is not loop-invariant!");
3316 #endif
3317 
3318   if (Operands.back()->isZero()) {
3319     Operands.pop_back();
3320     return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0}  -->  X
3321   }
3322 
3323   // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3324   // use that information to infer NUW and NSW flags. However, computing a
3325   // BE count requires calling getAddRecExpr, so we may not yet have a
3326   // meaningful BE count at this point (and if we don't, we'd be stuck
3327   // with a SCEVCouldNotCompute as the cached BE count).
3328 
3329   Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3330 
3331   // Canonicalize nested AddRecs by nesting them in order of loop depth.
3332   if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3333     const Loop *NestedLoop = NestedAR->getLoop();
3334     if (L->contains(NestedLoop)
3335             ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3336             : (!NestedLoop->contains(L) &&
3337                DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3338       SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
3339                                                   NestedAR->op_end());
3340       Operands[0] = NestedAR->getStart();
3341       // AddRecs require their operands be loop-invariant with respect to their
3342       // loops. Don't perform this transformation if it would break this
3343       // requirement.
3344       bool AllInvariant = all_of(
3345           Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3346 
3347       if (AllInvariant) {
3348         // Create a recurrence for the outer loop with the same step size.
3349         //
3350         // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3351         // inner recurrence has the same property.
3352         SCEV::NoWrapFlags OuterFlags =
3353           maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3354 
3355         NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3356         AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3357           return isLoopInvariant(Op, NestedLoop);
3358         });
3359 
3360         if (AllInvariant) {
3361           // Ok, both add recurrences are valid after the transformation.
3362           //
3363           // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3364           // the outer recurrence has the same property.
3365           SCEV::NoWrapFlags InnerFlags =
3366             maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3367           return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3368         }
3369       }
3370       // Reset Operands to its original state.
3371       Operands[0] = NestedAR;
3372     }
3373   }
3374 
3375   // Okay, it looks like we really DO need an addrec expr.  Check to see if we
3376   // already have one, otherwise create a new one.
3377   return getOrCreateAddRecExpr(Operands, L, Flags);
3378 }
3379 
3380 const SCEV *
3381 ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3382                             const SmallVectorImpl<const SCEV *> &IndexExprs) {
3383   const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3384   // getSCEV(Base)->getType() has the same address space as Base->getType()
3385   // because SCEV::getType() preserves the address space.
3386   Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3387   // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP
3388   // instruction to its SCEV, because the Instruction may be guarded by control
3389   // flow and the no-overflow bits may not be valid for the expression in any
3390   // context. This can be fixed similarly to how these flags are handled for
3391   // adds.
3392   SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW
3393                                              : SCEV::FlagAnyWrap;
3394 
3395   const SCEV *TotalOffset = getZero(IntIdxTy);
3396   Type *CurTy = GEP->getType();
3397   bool FirstIter = true;
3398   for (const SCEV *IndexExpr : IndexExprs) {
3399     // Compute the (potentially symbolic) offset in bytes for this index.
3400     if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3401       // For a struct, add the member offset.
3402       ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3403       unsigned FieldNo = Index->getZExtValue();
3404       const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3405 
3406       // Add the field offset to the running total offset.
3407       TotalOffset = getAddExpr(TotalOffset, FieldOffset);
3408 
3409       // Update CurTy to the type of the field at Index.
3410       CurTy = STy->getTypeAtIndex(Index);
3411     } else {
3412       // Update CurTy to its element type.
3413       if (FirstIter) {
3414         assert(isa<PointerType>(CurTy) &&
3415                "The first index of a GEP indexes a pointer");
3416         CurTy = GEP->getSourceElementType();
3417         FirstIter = false;
3418       } else {
3419         CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3420       }
3421       // For an array, add the element offset, explicitly scaled.
3422       const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3423       // Getelementptr indices are signed.
3424       IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3425 
3426       // Multiply the index by the element size to compute the element offset.
3427       const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, Wrap);
3428 
3429       // Add the element offset to the running total offset.
3430       TotalOffset = getAddExpr(TotalOffset, LocalOffset);
3431     }
3432   }
3433 
3434   // Add the total offset from all the GEP indices to the base.
3435   return getAddExpr(BaseExpr, TotalOffset, Wrap);
3436 }
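// Illustrative (assuming a typical data layout where %struct.S = { i32, i32 }
// occupies 8 bytes): for "getelementptr inbounds %struct.S, %struct.S* %p,
// i64 %i, i32 1" the offset ((8 * %i) + 4) is added to %p, and because the
// GEP is inbounds the index multiply is built with <nsw>.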
3437 
3438 std::tuple<SCEV *, FoldingSetNodeID, void *>
3439 ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3440                                          ArrayRef<const SCEV *> Ops) {
3441   FoldingSetNodeID ID;
3442   void *IP = nullptr;
3443   ID.AddInteger(SCEVType);
3444   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3445     ID.AddPointer(Ops[i]);
3446   return std::tuple<SCEV *, FoldingSetNodeID, void *>(
3447       UniqueSCEVs.FindNodeOrInsertPos(ID, IP), std::move(ID), IP);
3448 }
3449 
3450 const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3451   SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3452   return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3453 }
3454 
3455 const SCEV *ScalarEvolution::getSignumExpr(const SCEV *Op) {
3456   Type *Ty = Op->getType();
3457   return getSMinExpr(getSMaxExpr(Op, getMinusOne(Ty)), getOne(Ty));
3458 }
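// Illustrative: signum(%x) is built as smin(smax(%x, -1), 1), which evaluates
// to -1, 0 or 1 for negative, zero and positive %x respectively.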
3459 
3460 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3461                                            SmallVectorImpl<const SCEV *> &Ops) {
3462   assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3463   if (Ops.size() == 1) return Ops[0];
3464 #ifndef NDEBUG
3465   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3466   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3467     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3468            "Operand types don't match!");
3469 #endif
3470 
3471   bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3472   bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3473 
3474   // Sort by complexity, this groups all similar expression types together.
3475   GroupByComplexity(Ops, &LI, DT);
3476 
3477   // Check if we have created the same expression before.
3478   if (const SCEV *S = std::get<0>(findExistingSCEVInCache(Kind, Ops))) {
3479     return S;
3480   }
3481 
3482   // If there are any constants, fold them together.
3483   unsigned Idx = 0;
3484   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3485     ++Idx;
3486     assert(Idx < Ops.size());
3487     auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3488       if (Kind == scSMaxExpr)
3489         return APIntOps::smax(LHS, RHS);
3490       else if (Kind == scSMinExpr)
3491         return APIntOps::smin(LHS, RHS);
3492       else if (Kind == scUMaxExpr)
3493         return APIntOps::umax(LHS, RHS);
3494       else if (Kind == scUMinExpr)
3495         return APIntOps::umin(LHS, RHS);
3496       llvm_unreachable("Unknown SCEV min/max opcode");
3497     };
3498 
3499     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3500       // We found two constants, fold them together!
3501       ConstantInt *Fold = ConstantInt::get(
3502           getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3503       Ops[0] = getConstant(Fold);
3504       Ops.erase(Ops.begin()+1);  // Erase the folded element
3505       if (Ops.size() == 1) return Ops[0];
3506       LHSC = cast<SCEVConstant>(Ops[0]);
3507     }
3508 
3509     bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3510     bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3511 
3512     if (IsMax ? IsMinV : IsMaxV) {
3513       // If we are left with a constant minimum(/maximum)-int, strip it off.
3514       Ops.erase(Ops.begin());
3515       --Idx;
3516     } else if (IsMax ? IsMaxV : IsMinV) {
3517       // If we have a max(/min) with a constant maximum(/minimum)-int,
3518       // it will always be the extremum.
3519       return LHSC;
3520     }
3521 
3522     if (Ops.size() == 1) return Ops[0];
3523   }
3524 
3525   // Find the first operation of the same kind
3526   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3527     ++Idx;
3528 
3529   // Check to see if one of the operands is of the same kind. If so, expand its
3530   // operands onto our operand list, and recurse to simplify.
3531   if (Idx < Ops.size()) {
3532     bool DeletedAny = false;
3533     while (Ops[Idx]->getSCEVType() == Kind) {
3534       const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3535       Ops.erase(Ops.begin()+Idx);
3536       Ops.append(SMME->op_begin(), SMME->op_end());
3537       DeletedAny = true;
3538     }
3539 
3540     if (DeletedAny)
3541       return getMinMaxExpr(Kind, Ops);
3542   }
3543 
3544   // Okay, check to see if the same value occurs in the operand list twice.  If
3545   // so, delete one.  Since we sorted the list, these values are required to
3546   // be adjacent.
3547   llvm::CmpInst::Predicate GEPred =
3548       IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3549   llvm::CmpInst::Predicate LEPred =
3550       IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3551   llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3552   llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3553   for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3554     if (Ops[i] == Ops[i + 1] ||
3555         isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3556       //  X op Y op Y  -->  X op Y
3557       //  X op Y       -->  X, if we know X, Y are ordered appropriately
3558       Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3559       --i;
3560       --e;
3561     } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3562                                                Ops[i + 1])) {
3563       //  X op Y       -->  Y, if we know X, Y are ordered appropriately
3564       Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3565       --i;
3566       --e;
3567     }
3568   }
3569 
3570   if (Ops.size() == 1) return Ops[0];
3571 
3572   assert(!Ops.empty() && "Reduced min/max down to nothing!");
3573 
3574   // Okay, it looks like we really DO need an expr.  Check to see if we
3575   // already have one, otherwise create a new one.
3576   const SCEV *ExistingSCEV;
3577   FoldingSetNodeID ID;
3578   void *IP;
3579   std::tie(ExistingSCEV, ID, IP) = findExistingSCEVInCache(Kind, Ops);
3580   if (ExistingSCEV)
3581     return ExistingSCEV;
3582   const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3583   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3584   SCEV *S = new (SCEVAllocator)
3585       SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3586 
3587   UniqueSCEVs.InsertNode(S, IP);
3588   addToLoopUseLists(S);
3589   return S;
3590 }
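// Illustrative behaviour of getMinMaxExpr (assuming %x is a SCEVUnknown):
//   smax(%x, smax(%x, 5)) is flattened, deduplicated and returned as
//   smax(5, %x); umin(%x, 0) folds directly to 0 because 0 is the unsigned
//   minimum.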
3591 
3592 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
3593   SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
3594   return getSMaxExpr(Ops);
3595 }
3596 
3597 const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
3598   return getMinMaxExpr(scSMaxExpr, Ops);
3599 }
3600 
3601 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
3602   SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
3603   return getUMaxExpr(Ops);
3604 }
3605 
3606 const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
3607   return getMinMaxExpr(scUMaxExpr, Ops);
3608 }
3609 
3610 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
3611                                          const SCEV *RHS) {
3612   SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
3613   return getSMinExpr(Ops);
3614 }
3615 
3616 const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
3617   return getMinMaxExpr(scSMinExpr, Ops);
3618 }
3619 
3620 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
3621                                          const SCEV *RHS) {
3622   SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
3623   return getUMinExpr(Ops);
3624 }
3625 
3626 const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
3627   return getMinMaxExpr(scUMinExpr, Ops);
3628 }
3629 
3630 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
3631   if (isa<ScalableVectorType>(AllocTy)) {
3632     Constant *NullPtr = Constant::getNullValue(AllocTy->getPointerTo());
3633     Constant *One = ConstantInt::get(IntTy, 1);
3634     Constant *GEP = ConstantExpr::getGetElementPtr(AllocTy, NullPtr, One);
3635     // Note that the expression we created is the final expression; we don't
3636     // want to simplify it any further. Also, if we call a normal getSCEV(),
3637     // we'll end up in an endless recursion. So just create an SCEVUnknown.
3638     return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
3639   }
3640   // We can bypass creating a target-independent
3641   // constant expression and then folding it back into a ConstantInt.
3642   // This is just a compile-time optimization.
3643   return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
3644 }
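// For example (illustrative, with a typical data layout): sizeof([4 x i32])
// folds straight to the constant 16, while sizeof(<vscale x 4 x i32>) becomes
// the SCEVUnknown ptrtoint of a GEP with index 1 off a null pointer, since
// the scalable size is not a compile-time constant.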
3645 
3646 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
3647                                              StructType *STy,
3648                                              unsigned FieldNo) {
3649   // We can bypass creating a target-independent
3650   // constant expression and then folding it back into a ConstantInt.
3651   // This is just a compile-time optimization.
3652   return getConstant(
3653       IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
3654 }
3655 
3656 const SCEV *ScalarEvolution::getUnknown(Value *V) {
3657   // Don't attempt to do anything other than create a SCEVUnknown object
3658   // here.  createSCEV only calls getUnknown after checking for all other
3659   // interesting possibilities, and any other code that calls getUnknown
3660   // is doing so in order to hide a value from SCEV canonicalization.
3661 
3662   FoldingSetNodeID ID;
3663   ID.AddInteger(scUnknown);
3664   ID.AddPointer(V);
3665   void *IP = nullptr;
3666   if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
3667     assert(cast<SCEVUnknown>(S)->getValue() == V &&
3668            "Stale SCEVUnknown in uniquing map!");
3669     return S;
3670   }
3671   SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
3672                                             FirstUnknown);
3673   FirstUnknown = cast<SCEVUnknown>(S);
3674   UniqueSCEVs.InsertNode(S, IP);
3675   return S;
3676 }
3677 
3678 //===----------------------------------------------------------------------===//
3679 //            Basic SCEV Analysis and PHI Idiom Recognition Code
3680 //
3681 
3682 /// Test if values of the given type are analyzable within the SCEV
3683 /// framework. This primarily includes integer types, and it can optionally
3684 /// include pointer types if the ScalarEvolution class has access to
3685 /// target-specific information.
3686 bool ScalarEvolution::isSCEVable(Type *Ty) const {
3687   // Integers and pointers are always SCEVable.
3688   return Ty->isIntOrPtrTy();
3689 }
3690 
3691 /// Return the size in bits of the specified type, for which isSCEVable must
3692 /// return true.
3693 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
3694   assert(isSCEVable(Ty) && "Type is not SCEVable!");
3695   if (Ty->isPointerTy())
3696     return getDataLayout().getIndexTypeSizeInBits(Ty);
3697   return getDataLayout().getTypeSizeInBits(Ty);
3698 }
3699 
3700 /// Return a type with the same bitwidth as the given type and which represents
3701 /// how SCEV will treat the given type, for which isSCEVable must return
3702 /// true. For pointer types, this is the pointer index sized integer type.
3703 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
3704   assert(isSCEVable(Ty) && "Type is not SCEVable!");
3705 
3706   if (Ty->isIntegerTy())
3707     return Ty;
3708 
3709   // The only other supported type is pointer.
3710   assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
3711   return getDataLayout().getIndexType(Ty);
3712 }
3713 
3714 Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
3715   return  getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
3716 }
3717 
3718 const SCEV *ScalarEvolution::getCouldNotCompute() {
3719   return CouldNotCompute.get();
3720 }
3721 
3722 bool ScalarEvolution::checkValidity(const SCEV *S) const {
3723   bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
3724     auto *SU = dyn_cast<SCEVUnknown>(S);
3725     return SU && SU->getValue() == nullptr;
3726   });
3727 
3728   return !ContainsNulls;
3729 }
3730 
3731 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
3732   HasRecMapType::iterator I = HasRecMap.find(S);
3733   if (I != HasRecMap.end())
3734     return I->second;
3735 
3736   bool FoundAddRec =
3737       SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
3738   HasRecMap.insert({S, FoundAddRec});
3739   return FoundAddRec;
3740 }
3741 
3742 /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
3743 /// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an
3744 /// offset I, then return {S', I}, else return {\p S, nullptr}.
3745 static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) {
3746   const auto *Add = dyn_cast<SCEVAddExpr>(S);
3747   if (!Add)
3748     return {S, nullptr};
3749 
3750   if (Add->getNumOperands() != 2)
3751     return {S, nullptr};
3752 
3753   auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0));
3754   if (!ConstOp)
3755     return {S, nullptr};
3756 
3757   return {Add->getOperand(1), ConstOp->getValue()};
3758 }
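// For instance (illustrative): (5 + %x) splits into {%x, 5}, while (%x + %y)
// and a lone %x both return {the original SCEV, nullptr}.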
3759 
3760 /// Return the ValueOffsetPair set for \p S. \p S can be represented
3761 /// by the value and offset from any ValueOffsetPair in the set.
3762 SetVector<ScalarEvolution::ValueOffsetPair> *
3763 ScalarEvolution::getSCEVValues(const SCEV *S) {
3764   ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
3765   if (SI == ExprValueMap.end())
3766     return nullptr;
3767 #ifndef NDEBUG
3768   if (VerifySCEVMap) {
3769     // Check there is no dangling Value in the set returned.
3770     for (const auto &VE : SI->second)
3771       assert(ValueExprMap.count(VE.first));
3772   }
3773 #endif
3774   return &SI->second;
3775 }
3776 
3777 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
3778 /// cannot be used separately. eraseValueFromMap should be used to remove
3779 /// V from ValueExprMap and ExprValueMap at the same time.
3780 void ScalarEvolution::eraseValueFromMap(Value *V) {
3781   ValueExprMapType::iterator I = ValueExprMap.find_as(V);
3782   if (I != ValueExprMap.end()) {
3783     const SCEV *S = I->second;
3784     // Remove {V, 0} from the set of ExprValueMap[S]
3785     if (SetVector<ValueOffsetPair> *SV = getSCEVValues(S))
3786       SV->remove({V, nullptr});
3787 
3788     // Remove {V, Offset} from the set of ExprValueMap[Stripped]
3789     const SCEV *Stripped;
3790     ConstantInt *Offset;
3791     std::tie(Stripped, Offset) = splitAddExpr(S);
3792     if (Offset != nullptr) {
3793       if (SetVector<ValueOffsetPair> *SV = getSCEVValues(Stripped))
3794         SV->remove({V, Offset});
3795     }
3796     ValueExprMap.erase(V);
3797   }
3798 }
3799 
3800 /// Check whether the value has nuw/nsw/exact set but the SCEV does not.
3801 /// TODO: In reality it would be better to check for poison recursively,
3802 /// but this is better than nothing.
3803 static bool SCEVLostPoisonFlags(const SCEV *S, const Value *V) {
3804   if (auto *I = dyn_cast<Instruction>(V)) {
3805     if (isa<OverflowingBinaryOperator>(I)) {
3806       if (auto *NS = dyn_cast<SCEVNAryExpr>(S)) {
3807         if (I->hasNoSignedWrap() && !NS->hasNoSignedWrap())
3808           return true;
3809         if (I->hasNoUnsignedWrap() && !NS->hasNoUnsignedWrap())
3810           return true;
3811       }
3812     } else if (isa<PossiblyExactOperator>(I) && I->isExact())
3813       return true;
3814   }
3815   return false;
3816 }
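
// For example, for "%a = add nsw i32 %x, %y" (illustrative IR) whose SCEV was
// built as (%x + %y) without the <nsw> flag, this returns true; getSCEV below
// then skips recording S -> {V, 0} in ExprValueMap, so SCEV expansion will not
// reuse V for S in a context where V's poison-generating flags may not hold.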
3817 
3818 /// Return an existing SCEV if it exists, otherwise analyze the expression and
3819 /// create a new one.
3820 const SCEV *ScalarEvolution::getSCEV(Value *V) {
3821   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
3822 
3823   const SCEV *S = getExistingSCEV(V);
3824   if (S == nullptr) {
3825     S = createSCEV(V);
3826     // During PHI resolution, it is possible to create two SCEVs for the same
3827     // V, so we need to double-check whether V->S has been inserted into
3828     // ValueExprMap before inserting S->{V, 0} into ExprValueMap.
3829     std::pair<ValueExprMapType::iterator, bool> Pair =
3830         ValueExprMap.insert({SCEVCallbackVH(V, this), S});
3831     if (Pair.second && !SCEVLostPoisonFlags(S, V)) {
3832       ExprValueMap[S].insert({V, nullptr});
3833 
3834       // If S == Stripped + Offset, add Stripped -> {V, Offset} into
3835       // ExprValueMap.
3836       const SCEV *Stripped = S;
3837       ConstantInt *Offset = nullptr;
3838       std::tie(Stripped, Offset) = splitAddExpr(S);
3839       // If Stripped is a SCEVUnknown, don't bother to save
3840       // Stripped -> {V, offset}. It doesn't simplify anything and sometimes
3841       // even increases the complexity of the expansion code.
3842       // If V is a GetElementPtrInst, don't save Stripped -> {V, offset}
3843       // because it may generate add/sub instead of a GEP in SCEV expansion.
3844       if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) &&
3845           !isa<GetElementPtrInst>(V))
3846         ExprValueMap[Stripped].insert({V, Offset});
3847     }
3848   }
3849   return S;
3850 }
3851 
3852 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
3853   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
3854 
3855   ValueExprMapType::iterator I = ValueExprMap.find_as(V);
3856   if (I != ValueExprMap.end()) {
3857     const SCEV *S = I->second;
3858     if (checkValidity(S))
3859       return S;
3860     eraseValueFromMap(V);
3861     forgetMemoizedResults(S);
3862   }
3863   return nullptr;
3864 }
3865 
3866 /// Return a SCEV corresponding to -V = -1*V
3867 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
3868                                              SCEV::NoWrapFlags Flags) {
3869   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
3870     return getConstant(
3871                cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
3872 
3873   Type *Ty = V->getType();
3874   Ty = getEffectiveSCEVType(Ty);
3875   return getMulExpr(V, getMinusOne(Ty), Flags);
3876 }
3877 
3878 /// If Expr computes ~A, return A, else return nullptr.
3879 static const SCEV *MatchNotExpr(const SCEV *Expr) {
3880   const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
3881   if (!Add || Add->getNumOperands() != 2 ||
3882       !Add->getOperand(0)->isAllOnesValue())
3883     return nullptr;
3884 
3885   const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
3886   if (!AddRHS || AddRHS->getNumOperands() != 2 ||
3887       !AddRHS->getOperand(0)->isAllOnesValue())
3888     return nullptr;
3889 
3890   return AddRHS->getOperand(1);
3891 }
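
// For example, ~%x is built by getNotSCEV below as (-1 + (-1 * %x)), i.e. an
// add whose first operand is all-ones and whose second operand is an all-ones
// multiple; MatchNotExpr recognizes exactly that shape and returns %x.  Any
// other shape yields nullptr.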
3892 
3893 /// Return a SCEV corresponding to ~V = -1-V
3894 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
3895   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
3896     return getConstant(
3897                 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
3898 
3899   // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
3900   if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
3901     auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
3902       SmallVector<const SCEV *, 2> MatchedOperands;
3903       for (const SCEV *Operand : MME->operands()) {
3904         const SCEV *Matched = MatchNotExpr(Operand);
3905         if (!Matched)
3906           return (const SCEV *)nullptr;
3907         MatchedOperands.push_back(Matched);
3908       }
3909       return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
3910                            MatchedOperands);
3911     };
3912     if (const SCEV *Replaced = MatchMinMaxNegation(MME))
3913       return Replaced;
3914   }
3915 
3916   Type *Ty = V->getType();
3917   Ty = getEffectiveSCEVType(Ty);
3918   return getMinusSCEV(getMinusOne(Ty), V);
3919 }
3920 
3921 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
3922                                           SCEV::NoWrapFlags Flags,
3923                                           unsigned Depth) {
3924   // Fast path: X - X --> 0.
3925   if (LHS == RHS)
3926     return getZero(LHS->getType());
3927 
3928   // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
3929   // makes it so that we cannot make much use of NUW.
3930   auto AddFlags = SCEV::FlagAnyWrap;
3931   const bool RHSIsNotMinSigned =
3932       !getSignedRangeMin(RHS).isMinSignedValue();
3933   if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) {
3934     // Let M be the minimum representable signed value. Then (-1)*RHS
3935     // signed-wraps if and only if RHS is M. That can happen even for
3936     // a NSW subtraction because e.g. (-1)*M signed-wraps even though
3937     // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
3938     // (-1)*RHS, we need to prove that RHS != M.
3939     //
3940     // If LHS is non-negative and we know that LHS - RHS does not
3941     // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
3942     // either by proving that RHS > M or that LHS >= 0.
3943     if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
3944       AddFlags = SCEV::FlagNSW;
3945     }
3946   }
3947 
3948   // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
3949   // RHS is NSW and LHS >= 0.
3950   //
3951   // The difficulty here is that the NSW flag may have been proven
3952   // relative to a loop that is to be found in a recurrence in LHS and
3953   // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
3954   // larger scope than intended.
3955   auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3956 
3957   return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
3958 }
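
// Worked example for the NSW reasoning above: in i8, M = -128.  The
// subtraction (-1) - (-128) = 127 does not signed-wrap, but after rewriting it
// as (-1) + (-1)*(-128), the intermediate (-1)*(-128) = 128 does.  That is why
// nsw is only transferred when RHS is provably not M, or when LHS is known to
// be non-negative.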
3959 
3960 const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
3961                                                      unsigned Depth) {
3962   Type *SrcTy = V->getType();
3963   assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
3964          "Cannot truncate or zero extend with non-integer arguments!");
3965   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3966     return V;  // No conversion
3967   if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
3968     return getTruncateExpr(V, Ty, Depth);
3969   return getZeroExtendExpr(V, Ty, Depth);
3970 }
3971 
3972 const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
3973                                                      unsigned Depth) {
3974   Type *SrcTy = V->getType();
3975   assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
3976          "Cannot truncate or zero extend with non-integer arguments!");
3977   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3978     return V;  // No conversion
3979   if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
3980     return getTruncateExpr(V, Ty, Depth);
3981   return getSignExtendExpr(V, Ty, Depth);
3982 }
3983 
3984 const SCEV *
3985 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
3986   Type *SrcTy = V->getType();
3987   assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
3988          "Cannot noop or zero extend with non-integer arguments!");
3989   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
3990          "getNoopOrZeroExtend cannot truncate!");
3991   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
3992     return V;  // No conversion
3993   return getZeroExtendExpr(V, Ty);
3994 }
3995 
3996 const SCEV *
3997 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
3998   Type *SrcTy = V->getType();
3999   assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4000          "Cannot noop or sign extend with non-integer arguments!");
4001   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4002          "getNoopOrSignExtend cannot truncate!");
4003   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4004     return V;  // No conversion
4005   return getSignExtendExpr(V, Ty);
4006 }
4007 
4008 const SCEV *
4009 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4010   Type *SrcTy = V->getType();
4011   assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4012          "Cannot noop or any extend with non-integer arguments!");
4013   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4014          "getNoopOrAnyExtend cannot truncate!");
4015   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4016     return V;  // No conversion
4017   return getAnyExtendExpr(V, Ty);
4018 }
4019 
4020 const SCEV *
4021 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4022   Type *SrcTy = V->getType();
4023   assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4024          "Cannot truncate or noop with non-integer arguments!");
4025   assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4026          "getTruncateOrNoop cannot extend!");
4027   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4028     return V;  // No conversion
4029   return getTruncateExpr(V, Ty);
4030 }
4031 
4032 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4033                                                         const SCEV *RHS) {
4034   const SCEV *PromotedLHS = LHS;
4035   const SCEV *PromotedRHS = RHS;
4036 
4037   if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4038     PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4039   else
4040     PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4041 
4042   return getUMaxExpr(PromotedLHS, PromotedRHS);
4043 }
4044 
4045 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4046                                                         const SCEV *RHS) {
4047   SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4048   return getUMinFromMismatchedTypes(Ops);
4049 }
4050 
4051 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(
4052     SmallVectorImpl<const SCEV *> &Ops) {
4053   assert(!Ops.empty() && "At least one operand must be!");
4054   // Trivial case.
4055   if (Ops.size() == 1)
4056     return Ops[0];
4057 
4058   // Find the max type first.
4059   Type *MaxType = nullptr;
4060   for (auto *S : Ops)
4061     if (MaxType)
4062       MaxType = getWiderType(MaxType, S->getType());
4063     else
4064       MaxType = S->getType();
4065   assert(MaxType && "Failed to find maximum type!");
4066 
4067   // Extend all ops to max type.
4068   SmallVector<const SCEV *, 2> PromotedOps;
4069   for (auto *S : Ops)
4070     PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4071 
4072   // Generate umin.
4073   return getUMinExpr(PromotedOps);
4074 }
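
// For example, for operands {%a : i32, %b : i64} the widest type is i64, so %a
// is zero-extended and the result is (umin (zext i32 %a to i64), %b); operands
// already of the widest type pass through getNoopOrZeroExtend unchanged.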
4075 
4076 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4077   // A pointer operand may evaluate to a nonpointer expression, such as null.
4078   if (!V->getType()->isPointerTy())
4079     return V;
4080 
4081   while (true) {
4082     if (const SCEVIntegralCastExpr *Cast = dyn_cast<SCEVIntegralCastExpr>(V)) {
4083       V = Cast->getOperand();
4084     } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
4085       const SCEV *PtrOp = nullptr;
4086       for (const SCEV *NAryOp : NAry->operands()) {
4087         if (NAryOp->getType()->isPointerTy()) {
4088           // Cannot find the base of an expression with multiple pointer ops.
4089           if (PtrOp)
4090             return V;
4091           PtrOp = NAryOp;
4092         }
4093       }
4094       if (!PtrOp) // All operands were non-pointer.
4095         return V;
4096       V = PtrOp;
4097     } else // Not something we can look further into.
4098       return V;
4099   }
4100 }
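
// For example, for the SCEV ((4 * %i) + %ptr) the loop above descends into the
// single pointer operand %ptr and returns it; an expression with two pointer
// operands, or with no pointer operand at all, is returned unchanged.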
4101 
4102 /// Push users of the given Instruction onto the given Worklist.
4103 static void
4104 PushDefUseChildren(Instruction *I,
4105                    SmallVectorImpl<Instruction *> &Worklist) {
4106   // Push the def-use children onto the Worklist stack.
4107   for (User *U : I->users())
4108     Worklist.push_back(cast<Instruction>(U));
4109 }
4110 
4111 void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {
4112   SmallVector<Instruction *, 16> Worklist;
4113   PushDefUseChildren(PN, Worklist);
4114 
4115   SmallPtrSet<Instruction *, 8> Visited;
4116   Visited.insert(PN);
4117   while (!Worklist.empty()) {
4118     Instruction *I = Worklist.pop_back_val();
4119     if (!Visited.insert(I).second)
4120       continue;
4121 
4122     auto It = ValueExprMap.find_as(static_cast<Value *>(I));
4123     if (It != ValueExprMap.end()) {
4124       const SCEV *Old = It->second;
4125 
4126       // Short-circuit the def-use traversal if the symbolic name
4127       // ceases to appear in expressions.
4128       if (Old != SymName && !hasOperand(Old, SymName))
4129         continue;
4130 
4131       // SCEVUnknown for a PHI either means that it has an unrecognized
4132       // structure, it's a PHI that's in the process of being computed
4133       // by createNodeForPHI, or it's a single-value PHI. In the first case,
4134       // additional loop trip count information isn't going to change anything.
4135       // In the second case, createNodeForPHI will perform the necessary
4136       // updates on its own when it gets to that point. In the third, we do
4137       // want to forget the SCEVUnknown.
4138       if (!isa<PHINode>(I) ||
4139           !isa<SCEVUnknown>(Old) ||
4140           (I != PN && Old == SymName)) {
4141         eraseValueFromMap(It->first);
4142         forgetMemoizedResults(Old);
4143       }
4144     }
4145 
4146     PushDefUseChildren(I, Worklist);
4147   }
4148 }
4149 
4150 namespace {
4151 
4152 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4153 /// expression if its Loop is L. If its loop is not L and
4154 /// IgnoreOtherLoops is true, use the AddRec itself;
4155 /// otherwise the rewrite cannot be done.
4156 /// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
4157 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4158 public:
4159   static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4160                              bool IgnoreOtherLoops = true) {
4161     SCEVInitRewriter Rewriter(L, SE);
4162     const SCEV *Result = Rewriter.visit(S);
4163     if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4164       return SE.getCouldNotCompute();
4165     return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4166                ? SE.getCouldNotCompute()
4167                : Result;
4168   }
4169 
4170   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4171     if (!SE.isLoopInvariant(Expr, L))
4172       SeenLoopVariantSCEVUnknown = true;
4173     return Expr;
4174   }
4175 
4176   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4177     // Only re-write AddRecExprs for this loop.
4178     if (Expr->getLoop() == L)
4179       return Expr->getStart();
4180     SeenOtherLoops = true;
4181     return Expr;
4182   }
4183 
4184   bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4185 
4186   bool hasSeenOtherLoops() { return SeenOtherLoops; }
4187 
4188 private:
4189   explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4190       : SCEVRewriteVisitor(SE), L(L) {}
4191 
4192   const Loop *L;
4193   bool SeenLoopVariantSCEVUnknown = false;
4194   bool SeenOtherLoops = false;
4195 };
4196 
4197 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4198 /// increment expression if its Loop is L. If its loop is not L,
4199 /// use the AddRec itself.
4200 /// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
4201 class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4202 public:
4203   static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4204     SCEVPostIncRewriter Rewriter(L, SE);
4205     const SCEV *Result = Rewriter.visit(S);
4206     return Rewriter.hasSeenLoopVariantSCEVUnknown()
4207         ? SE.getCouldNotCompute()
4208         : Result;
4209   }
4210 
4211   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4212     if (!SE.isLoopInvariant(Expr, L))
4213       SeenLoopVariantSCEVUnknown = true;
4214     return Expr;
4215   }
4216 
4217   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4218     // Only re-write AddRecExprs for this loop.
4219     if (Expr->getLoop() == L)
4220       return Expr->getPostIncExpr(SE);
4221     SeenOtherLoops = true;
4222     return Expr;
4223   }
4224 
4225   bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4226 
4227   bool hasSeenOtherLoops() { return SeenOtherLoops; }
4228 
4229 private:
4230   explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4231       : SCEVRewriteVisitor(SE), L(L) {}
4232 
4233   const Loop *L;
4234   bool SeenLoopVariantSCEVUnknown = false;
4235   bool SeenOtherLoops = false;
4236 };
4237 
4238 /// This class evaluates the compare condition by matching it against the
4239 /// condition of the loop latch. If there is a match, we assume the value the
4240 /// condition must have for the backedge to be taken while building SCEV nodes.
4241 class SCEVBackedgeConditionFolder
4242     : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4243 public:
4244   static const SCEV *rewrite(const SCEV *S, const Loop *L,
4245                              ScalarEvolution &SE) {
4246     bool IsPosBECond = false;
4247     Value *BECond = nullptr;
4248     if (BasicBlock *Latch = L->getLoopLatch()) {
4249       BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4250       if (BI && BI->isConditional()) {
4251         assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4252                "Both outgoing branches should not target same header!");
4253         BECond = BI->getCondition();
4254         IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4255       } else {
4256         return S;
4257       }
4258     }
4259     SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4260     return Rewriter.visit(S);
4261   }
4262 
4263   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4264     const SCEV *Result = Expr;
4265     bool InvariantF = SE.isLoopInvariant(Expr, L);
4266 
4267     if (!InvariantF) {
4268       Instruction *I = cast<Instruction>(Expr->getValue());
4269       switch (I->getOpcode()) {
4270       case Instruction::Select: {
4271         SelectInst *SI = cast<SelectInst>(I);
4272         Optional<const SCEV *> Res =
4273             compareWithBackedgeCondition(SI->getCondition());
4274         if (Res.hasValue()) {
4275           bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne();
4276           Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
4277         }
4278         break;
4279       }
4280       default: {
4281         Optional<const SCEV *> Res = compareWithBackedgeCondition(I);
4282         if (Res.hasValue())
4283           Result = Res.getValue();
4284         break;
4285       }
4286       }
4287     }
4288     return Result;
4289   }
4290 
4291 private:
4292   explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
4293                                        bool IsPosBECond, ScalarEvolution &SE)
4294       : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
4295         IsPositiveBECond(IsPosBECond) {}
4296 
4297   Optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
4298 
4299   const Loop *L;
4300   /// The loop backedge condition.
4301   Value *BackedgeCond = nullptr;
4302   /// Set to true if the backedge is taken on the positive branch condition.
4303   bool IsPositiveBECond;
4304 };
4305 
4306 Optional<const SCEV *>
4307 SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
4308 
4309   // If the value matches the backedge condition of the loop latch,
4310   // then return a constant evolution node based on whether the
4311   // backedge branch is taken.
4312   if (BackedgeCond == IC)
4313     return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
4314                             : SE.getZero(Type::getInt1Ty(SE.getContext()));
4315   return None;
4316 }
4317 
4318 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
4319 public:
4320   static const SCEV *rewrite(const SCEV *S, const Loop *L,
4321                              ScalarEvolution &SE) {
4322     SCEVShiftRewriter Rewriter(L, SE);
4323     const SCEV *Result = Rewriter.visit(S);
4324     return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
4325   }
4326 
4327   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4328     // Unknowns must be loop-invariant in L; anything else invalidates the rewrite.
4329     if (!SE.isLoopInvariant(Expr, L))
4330       Valid = false;
4331     return Expr;
4332   }
4333 
4334   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4335     if (Expr->getLoop() == L && Expr->isAffine())
4336       return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
4337     Valid = false;
4338     return Expr;
4339   }
4340 
4341   bool isValid() { return Valid; }
4342 
4343 private:
4344   explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
4345       : SCEVRewriteVisitor(SE), L(L) {}
4346 
4347   const Loop *L;
4348   bool Valid = true;
4349 };
4350 
4351 } // end anonymous namespace
4352 
4353 SCEV::NoWrapFlags
4354 ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
4355   if (!AR->isAffine())
4356     return SCEV::FlagAnyWrap;
4357 
4358   using OBO = OverflowingBinaryOperator;
4359 
4360   SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
4361 
4362   if (!AR->hasNoSignedWrap()) {
4363     ConstantRange AddRecRange = getSignedRange(AR);
4364     ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
4365 
4366     auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4367         Instruction::Add, IncRange, OBO::NoSignedWrap);
4368     if (NSWRegion.contains(AddRecRange))
4369       Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
4370   }
4371 
4372   if (!AR->hasNoUnsignedWrap()) {
4373     ConstantRange AddRecRange = getUnsignedRange(AR);
4374     ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
4375 
4376     auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4377         Instruction::Add, IncRange, OBO::NoUnsignedWrap);
4378     if (NUWRegion.contains(AddRecRange))
4379       Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
4380   }
4381 
4382   return Result;
4383 }
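
// For example, for an i8 add recurrence whose step is the constant 1, the
// guaranteed no-signed-wrap region for "x + 1" is every i8 value except 127;
// if the signed range computed for the whole recurrence stays inside that
// region, FlagNSW can be set here even though the IR never carried an nsw
// flag.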
4384 
4385 namespace {
4386 
4387 /// Represents an abstract binary operation.  This may exist as a
4388 /// normal instruction or constant expression, or may have been
4389 /// derived from an expression tree.
4390 struct BinaryOp {
4391   unsigned Opcode;
4392   Value *LHS;
4393   Value *RHS;
4394   bool IsNSW = false;
4395   bool IsNUW = false;
4396   bool IsExact = false;
4397 
4398   /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
4399   /// constant expression.
4400   Operator *Op = nullptr;
4401 
4402   explicit BinaryOp(Operator *Op)
4403       : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
4404         Op(Op) {
4405     if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
4406       IsNSW = OBO->hasNoSignedWrap();
4407       IsNUW = OBO->hasNoUnsignedWrap();
4408     }
4409     if (auto *PEO = dyn_cast<PossiblyExactOperator>(Op))
4410       IsExact = PEO->isExact();
4411   }
4412 
4413   explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
4414                     bool IsNUW = false, bool IsExact = false)
4415       : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW),
4416         IsExact(IsExact) {}
4417 };
4418 
4419 } // end anonymous namespace
4420 
4421 /// Try to map \p V into a BinaryOp, and return \c None on failure.
4422 static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
4423   auto *Op = dyn_cast<Operator>(V);
4424   if (!Op)
4425     return None;
4426 
4427   // Implementation detail: all the cleverness here should happen without
4428   // creating new SCEV expressions -- our caller knows tricks to avoid creating
4429   // SCEV expressions when possible, and we should not break that.
4430 
4431   switch (Op->getOpcode()) {
4432   case Instruction::Add:
4433   case Instruction::Sub:
4434   case Instruction::Mul:
4435   case Instruction::UDiv:
4436   case Instruction::URem:
4437   case Instruction::And:
4438   case Instruction::Or:
4439   case Instruction::AShr:
4440   case Instruction::Shl:
4441     return BinaryOp(Op);
4442 
4443   case Instruction::Xor:
4444     if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
4445       // If the RHS of the xor is a signmask, then this is just an add.
4446       // Instcombine turns add of signmask into xor as a strength reduction step.
4447       if (RHSC->getValue().isSignMask())
4448         return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
4449     return BinaryOp(Op);
4450 
4451   case Instruction::LShr:
4452     // Turn a logical shift right by a constant into an unsigned divide.
4453     if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
4454       uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
4455 
4456       // If the shift count is not less than the bitwidth, the result of
4457       // the shift is undefined. Don't try to analyze it, because the
4458       // resolution chosen here may differ from the resolution chosen in
4459       // other parts of the compiler.
4460       if (SA->getValue().ult(BitWidth)) {
4461         Constant *X =
4462             ConstantInt::get(SA->getContext(),
4463                              APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
4464         return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
4465       }
4466     }
4467     return BinaryOp(Op);
4468 
4469   case Instruction::ExtractValue: {
4470     auto *EVI = cast<ExtractValueInst>(Op);
4471     if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
4472       break;
4473 
4474     auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
4475     if (!WO)
4476       break;
4477 
4478     Instruction::BinaryOps BinOp = WO->getBinaryOp();
4479     bool Signed = WO->isSigned();
4480     // TODO: Should add nuw/nsw flags for mul as well.
4481     if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
4482       return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
4483 
4484     // Now that we know that all uses of the arithmetic-result component of
4485     // WO are guarded by the overflow check, we can go ahead and pretend
4486     // that the arithmetic is non-overflowing.
4487     return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
4488                     /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
4489   }
4490 
4491   default:
4492     break;
4493   }
4494 
4495   // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
4496   // semantics as a Sub, return a binary sub expression.
4497   if (auto *II = dyn_cast<IntrinsicInst>(V))
4498     if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
4499       return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
4500 
4501   return None;
4502 }
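
// For example, "%r = lshr i32 %x, 3" is mapped to BinaryOp(UDiv, %x, 8) by the
// LShr case above, and "%r = xor i32 %x, -2147483648" (xor with the sign mask)
// is mapped to BinaryOp(Add, %x, -2147483648), undoing the strength reduction
// instcombine performs on such adds.  (IR names are illustrative only.)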
4503 
4504 /// Helper function to createAddRecFromPHIWithCasts. We have a phi
4505 /// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
4506 /// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
4507 /// way. This function checks if \p Op, an operand of this SCEVAddExpr,
4508 /// follows one of the following patterns:
4509 /// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
4510 /// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
4511 /// If the SCEV expression of \p Op conforms with one of the expected patterns
4512 /// we return the type of the truncation operation, and indicate whether the
4513 /// truncated type should be treated as signed/unsigned by setting
4514 /// \p Signed to true/false, respectively.
4515 static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
4516                                bool &Signed, ScalarEvolution &SE) {
4517   // The case where Op == SymbolicPHI (that is, with no type conversions on
4518   // the way) is handled by the regular add recurrence creating logic and
4519   // would have already been triggered in createAddRecForPHI. Reaching it here
4520   // means that createAddRecFromPHI had failed for this PHI before (e.g.,
4521   // because one of the other operands of the SCEVAddExpr updating this PHI is
4522   // not invariant).
4523   //
4524   // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
4525   // this case predicates that allow us to prove that Op == SymbolicPHI will
4526   // be added.
4527   if (Op == SymbolicPHI)
4528     return nullptr;
4529 
4530   unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
4531   unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
4532   if (SourceBits != NewBits)
4533     return nullptr;
4534 
4535   const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
4536   const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
4537   if (!SExt && !ZExt)
4538     return nullptr;
4539   const SCEVTruncateExpr *Trunc =
4540       SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
4541            : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
4542   if (!Trunc)
4543     return nullptr;
4544   const SCEV *X = Trunc->getOperand();
4545   if (X != SymbolicPHI)
4546     return nullptr;
4547   Signed = SExt != nullptr;
4548   return Trunc->getType();
4549 }
4550 
4551 static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
4552   if (!PN->getType()->isIntegerTy())
4553     return nullptr;
4554   const Loop *L = LI.getLoopFor(PN->getParent());
4555   if (!L || L->getHeader() != PN->getParent())
4556     return nullptr;
4557   return L;
4558 }
4559 
4560 // Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
4561 // computation that updates the phi follows the following pattern:
4562 //   (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
4563 // which correspond to a phi->trunc->sext/zext->add->phi update chain.
4564 // If so, try to see if it can be rewritten as an AddRecExpr under some
4565 // Predicates. If successful, return them as a pair. Also cache the results
4566 // of the analysis.
4567 //
4568 // Example usage scenario:
4569 //    Say the Rewriter is called for the following SCEV:
4570 //         8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
4571 //    where:
4572 //         %X = phi i64 (%Start, %BEValue)
4573 //    It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
4574 //    and call this function with %SymbolicPHI = %X.
4575 //
4576 //    The analysis will find that the value coming around the backedge has
4577 //    the following SCEV:
4578 //         BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
4579 //    Upon concluding that this matches the desired pattern, the function
4580 //    will return the pair {NewAddRec, SmallPredsVec} where:
4581 //         NewAddRec = {%Start,+,%Step}
4582 //         SmallPredsVec = {P1, P2, P3} as follows:
4583 //           P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
4584 //           P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
4585 //           P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
4586 //    The returned pair means that SymbolicPHI can be rewritten into NewAddRec
4587 //    under the predicates {P1,P2,P3}.
4588 //    This predicated rewrite will be cached in PredicatedSCEVRewrites:
4589 //         PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
4590 //
4591 // TODO's:
4592 //
4593 // 1) Extend the Induction descriptor to also support inductions that involve
4594 //    casts: When needed (namely, when we are called in the context of the
4595 //    vectorizer induction analysis), a Set of cast instructions will be
4596 //    populated by this method, and provided back to isInductionPHI. This is
4597 //    needed to allow the vectorizer to properly record them to be ignored by
4598 //    the cost model and to avoid vectorizing them (otherwise these casts,
4599 //    which are redundant under the runtime overflow checks, will be
4600 //    vectorized, which can be costly).
4601 //
4602 // 2) Support additional induction/PHISCEV patterns: We also want to support
4603 //    inductions where the sext-trunc / zext-trunc operations (partly) occur
4604 //    after the induction update operation (the induction increment):
4605 //
4606 //      (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
4607 //    which corresponds to a phi->add->trunc->sext/zext->phi update chain.
4608 //
4609 //      (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
4610 //    which corresponds to a phi->trunc->add->sext/zext->phi update chain.
4611 //
4612 // 3) Outline common code with createAddRecFromPHI to avoid duplication.
4613 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
4614 ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
4615   SmallVector<const SCEVPredicate *, 3> Predicates;
4616 
4617   // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
4618   // return an AddRec expression under some predicate.
4619 
4620   auto *PN = cast<PHINode>(SymbolicPHI->getValue());
4621   const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
4622   assert(L && "Expecting an integer loop header phi");
4623 
4624   // The loop may have multiple entrances or multiple exits; we can analyze
4625   // this phi as an addrec if it has a unique entry value and a unique
4626   // backedge value.
4627   Value *BEValueV = nullptr, *StartValueV = nullptr;
4628   for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
4629     Value *V = PN->getIncomingValue(i);
4630     if (L->contains(PN->getIncomingBlock(i))) {
4631       if (!BEValueV) {
4632         BEValueV = V;
4633       } else if (BEValueV != V) {
4634         BEValueV = nullptr;
4635         break;
4636       }
4637     } else if (!StartValueV) {
4638       StartValueV = V;
4639     } else if (StartValueV != V) {
4640       StartValueV = nullptr;
4641       break;
4642     }
4643   }
4644   if (!BEValueV || !StartValueV)
4645     return None;
4646 
4647   const SCEV *BEValue = getSCEV(BEValueV);
4648 
4649   // If the value coming around the backedge is an add with the symbolic
4650   // value we just inserted, possibly with casts that we can ignore under
4651   // an appropriate runtime guard, then we found a simple induction variable!
4652   const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
4653   if (!Add)
4654     return None;
4655 
4656   // If there is a single occurrence of the symbolic value, possibly
4657   // casted, replace it with a recurrence.
4658   unsigned FoundIndex = Add->getNumOperands();
4659   Type *TruncTy = nullptr;
4660   bool Signed;
4661   for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
4662     if ((TruncTy =
4663              isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
4664       if (FoundIndex == e) {
4665         FoundIndex = i;
4666         break;
4667       }
4668 
4669   if (FoundIndex == Add->getNumOperands())
4670     return None;
4671 
4672   // Create an add with everything but the specified operand.
4673   SmallVector<const SCEV *, 8> Ops;
4674   for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
4675     if (i != FoundIndex)
4676       Ops.push_back(Add->getOperand(i));
4677   const SCEV *Accum = getAddExpr(Ops);
4678 
4679   // The runtime checks will not be valid if the step amount is
4680   // varying inside the loop.
4681   if (!isLoopInvariant(Accum, L))
4682     return None;
4683 
4684   // *** Part2: Create the predicates
4685 
4686   // Analysis was successful: we have a phi-with-cast pattern for which we
4687   // can return an AddRec expression under the following predicates:
4688   //
4689   // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
4690   //     fits within the truncated type (does not overflow) for i = 0 to n-1.
4691   // P2: An Equal predicate that guarantees that
4692   //     Start = (Ext ix (Trunc iy (Start) to ix) to iy)
4693   // P3: An Equal predicate that guarantees that
4694   //     Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
4695   //
4696   // As we next prove, the above predicates guarantee that:
4697   //     Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
4698   //
4699   //
4700   // More formally, we want to prove that:
4701   //     Expr(i+1) = Start + (i+1) * Accum
4702   //               = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
4703   //
4704   // Given that:
4705   // 1) Expr(0) = Start
4706   // 2) Expr(1) = Start + Accum
4707   //            = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
4708   // 3) Induction hypothesis (step i):
4709   //    Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
4710   //
4711   // Proof:
4712   //  Expr(i+1) =
4713   //   = Start + (i+1)*Accum
4714   //   = (Start + i*Accum) + Accum
4715   //   = Expr(i) + Accum
4716   //   = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
4717   //                                                             :: from step i
4718   //
4719   //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
4720   //
4721   //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
4722   //     + (Ext ix (Trunc iy (Accum) to ix) to iy)
4723   //     + Accum                                                     :: from P3
4724   //
4725   //   = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
4726   //     + Accum                            :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
4727   //
4728   //   = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
4729   //   = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
4730   //
4731   // By induction, the same applies to all iterations 1<=i<n:
4732   //
4733 
4734   // Create a truncated addrec for which we will add a no overflow check (P1).
4735   const SCEV *StartVal = getSCEV(StartValueV);
4736   const SCEV *PHISCEV =
4737       getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
4738                     getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
4739 
4740   // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
4741   // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
4742   // will be constant.
4743   //
4744   //  If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
4745   // add P1.
4746   if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
4747     SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
4748         Signed ? SCEVWrapPredicate::IncrementNSSW
4749                : SCEVWrapPredicate::IncrementNUSW;
4750     const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
4751     Predicates.push_back(AddRecPred);
4752   }
4753 
4754   // Create the Equal Predicates P2,P3:
4755 
4756   // It is possible that the predicates P2 and/or P3 are computable at
4757   // compile time due to StartVal and/or Accum being constants.
4758   // If either one is, then we can check that now and escape if either P2
4759   // or P3 is false.
4760 
4761   // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
4762   // for each of StartVal and Accum
4763   auto getExtendedExpr = [&](const SCEV *Expr,
4764                              bool CreateSignExtend) -> const SCEV * {
4765     assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
4766     const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
4767     const SCEV *ExtendedExpr =
4768         CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
4769                          : getZeroExtendExpr(TruncatedExpr, Expr->getType());
4770     return ExtendedExpr;
4771   };
4772 
4773   // Given:
4774   //  ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy)
4775   //               = getExtendedExpr(Expr)
4776   // Determine whether the predicate P: Expr == ExtendedExpr
4777   // is known to be false at compile time
4778   auto PredIsKnownFalse = [&](const SCEV *Expr,
4779                               const SCEV *ExtendedExpr) -> bool {
4780     return Expr != ExtendedExpr &&
4781            isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
4782   };
4783 
4784   const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
4785   if (PredIsKnownFalse(StartVal, StartExtended)) {
4786     LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
4787     return None;
4788   }
4789 
4790   // The Step is always Signed (because the overflow checks are either
4791   // NSSW or NUSW)
4792   const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
4793   if (PredIsKnownFalse(Accum, AccumExtended)) {
4794     LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
4795     return None;
4796   }
4797 
4798   auto AppendPredicate = [&](const SCEV *Expr,
4799                              const SCEV *ExtendedExpr) -> void {
4800     if (Expr != ExtendedExpr &&
4801         !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
4802       const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
4803       LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
4804       Predicates.push_back(Pred);
4805     }
4806   };
4807 
4808   AppendPredicate(StartVal, StartExtended);
4809   AppendPredicate(Accum, AccumExtended);
4810 
4811   // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
4812   // which the casts had been folded away. The caller can rewrite SymbolicPHI
4813   // into NewAR if it will also add the runtime overflow checks specified in
4814   // Predicates.
4815   auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
4816 
4817   std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
4818       std::make_pair(NewAR, Predicates);
4819   // Remember the result of the analysis for this SCEV at this location.
4820   PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
4821   return PredRewrite;
4822 }
4823 
4824 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
4825 ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
4826   auto *PN = cast<PHINode>(SymbolicPHI->getValue());
4827   const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
4828   if (!L)
4829     return None;
4830 
4831   // Check to see if we already analyzed this PHI.
4832   auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
4833   if (I != PredicatedSCEVRewrites.end()) {
4834     std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
4835         I->second;
4836     // Analysis was done before and failed to create an AddRec:
4837     if (Rewrite.first == SymbolicPHI)
4838       return None;
4839     // Analysis was done before and succeeded to create an AddRec under
4840     // a predicate:
4841     assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
4842     assert(!(Rewrite.second).empty() && "Expected to find Predicates");
4843     return Rewrite;
4844   }
4845 
4846   Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
4847     Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
4848 
4849   // Record in the cache that the analysis failed
4850   if (!Rewrite) {
4851     SmallVector<const SCEVPredicate *, 3> Predicates;
4852     PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
4853     return None;
4854   }
4855 
4856   return Rewrite;
4857 }
4858 
4859 // FIXME: This utility is currently required because the Rewriter currently
4860 // does not rewrite this expression:
4861 // {0, +, (sext ix (trunc iy to ix) to iy)}
4862 // into {0, +, %step},
4863 // even when the following Equal predicate exists:
4864 // "%step == (sext ix (trunc iy to ix) to iy)".
4865 bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
4866     const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
4867   if (AR1 == AR2)
4868     return true;
4869 
4870   auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
4871     if (Expr1 != Expr2 && !Preds.implies(SE.getEqualPredicate(Expr1, Expr2)) &&
4872         !Preds.implies(SE.getEqualPredicate(Expr2, Expr1)))
4873       return false;
4874     return true;
4875   };
4876 
4877   if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
4878       !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
4879     return false;
4880   return true;
4881 }
4882 
4883 /// A helper function for createAddRecFromPHI to handle simple cases.
4884 ///
4885 /// This function tries to find an AddRec expression for the simplest (yet most
4886 /// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
4887 /// If it fails, createAddRecFromPHI will use a more general, but slow,
4888 /// technique for finding the AddRec expression.
4889 const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
4890                                                       Value *BEValueV,
4891                                                       Value *StartValueV) {
4892   const Loop *L = LI.getLoopFor(PN->getParent());
4893   assert(L && L->getHeader() == PN->getParent());
4894   assert(BEValueV && StartValueV);
4895 
4896   auto BO = MatchBinaryOp(BEValueV, DT);
4897   if (!BO)
4898     return nullptr;
4899 
4900   if (BO->Opcode != Instruction::Add)
4901     return nullptr;
4902 
4903   const SCEV *Accum = nullptr;
4904   if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
4905     Accum = getSCEV(BO->RHS);
4906   else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
4907     Accum = getSCEV(BO->LHS);
4908 
4909   if (!Accum)
4910     return nullptr;
4911 
4912   SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
4913   if (BO->IsNUW)
4914     Flags = setFlags(Flags, SCEV::FlagNUW);
4915   if (BO->IsNSW)
4916     Flags = setFlags(Flags, SCEV::FlagNSW);
4917 
4918   const SCEV *StartVal = getSCEV(StartValueV);
4919   const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
4920 
4921   ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
4922 
4923   // We can add Flags to the post-inc expression only if we
4924   // know that it is *undefined behavior* for BEValueV to
4925   // overflow.
4926   if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
4927     if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
4928       (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
4929 
4930   return PHISCEV;
4931 }
4932 
4933 const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
4934   const Loop *L = LI.getLoopFor(PN->getParent());
4935   if (!L || L->getHeader() != PN->getParent())
4936     return nullptr;
4937 
4938   // The loop may have multiple entrances or multiple exits; we can analyze
4939   // this phi as an addrec if it has a unique entry value and a unique
4940   // backedge value.
4941   Value *BEValueV = nullptr, *StartValueV = nullptr;
4942   for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
4943     Value *V = PN->getIncomingValue(i);
4944     if (L->contains(PN->getIncomingBlock(i))) {
4945       if (!BEValueV) {
4946         BEValueV = V;
4947       } else if (BEValueV != V) {
4948         BEValueV = nullptr;
4949         break;
4950       }
4951     } else if (!StartValueV) {
4952       StartValueV = V;
4953     } else if (StartValueV != V) {
4954       StartValueV = nullptr;
4955       break;
4956     }
4957   }
4958   if (!BEValueV || !StartValueV)
4959     return nullptr;
4960 
4961   assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
4962          "PHI node already processed?");
4963 
4964   // First, try to find an AddRec expression without creating a fictitious symbolic
4965   // value for PN.
4966   if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
4967     return S;
4968 
4969   // Handle PHI node value symbolically.
4970   const SCEV *SymbolicName = getUnknown(PN);
4971   ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName});
4972 
4973   // Using this symbolic name for the PHI, analyze the value coming around
4974   // the back-edge.
4975   const SCEV *BEValue = getSCEV(BEValueV);
4976 
4977   // NOTE: If BEValue is loop invariant, we know that the PHI node just
4978   // has a special value for the first iteration of the loop.
4979 
4980   // If the value coming around the backedge is an add with the symbolic
4981   // value we just inserted, then we found a simple induction variable!
4982   if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
4983     // If there is a single occurrence of the symbolic value, replace it
4984     // with a recurrence.
4985     unsigned FoundIndex = Add->getNumOperands();
4986     for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
4987       if (Add->getOperand(i) == SymbolicName)
4988         if (FoundIndex == e) {
4989           FoundIndex = i;
4990           break;
4991         }
4992 
4993     if (FoundIndex != Add->getNumOperands()) {
4994       // Create an add with everything but the specified operand.
4995       SmallVector<const SCEV *, 8> Ops;
4996       for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
4997         if (i != FoundIndex)
4998           Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
4999                                                              L, *this));
5000       const SCEV *Accum = getAddExpr(Ops);
5001 
5002       // This is not a valid addrec if the step amount is varying each
5003       // loop iteration, but is not itself an addrec in this loop.
5004       if (isLoopInvariant(Accum, L) ||
5005           (isa<SCEVAddRecExpr>(Accum) &&
5006            cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5007         SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5008 
5009         if (auto BO = MatchBinaryOp(BEValueV, DT)) {
5010           if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5011             if (BO->IsNUW)
5012               Flags = setFlags(Flags, SCEV::FlagNUW);
5013             if (BO->IsNSW)
5014               Flags = setFlags(Flags, SCEV::FlagNSW);
5015           }
5016         } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5017           // If the increment is an inbounds GEP, then we know the address
5018           // space cannot be wrapped around. We cannot make any guarantee
5019           // about signed or unsigned overflow because pointers are
5020           // unsigned but we may have a negative index from the base
5021           // pointer. We can guarantee that no unsigned wrap occurs if the
5022           // indices form a positive value.
5023           if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
5024             Flags = setFlags(Flags, SCEV::FlagNW);
5025 
5026             const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
5027             if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
5028               Flags = setFlags(Flags, SCEV::FlagNUW);
5029           }
5030 
5031           // We cannot transfer nuw and nsw flags from subtraction
5032           // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5033           // for instance.
5034         }
5035 
5036         const SCEV *StartVal = getSCEV(StartValueV);
5037         const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5038 
5039         // Okay, for the entire analysis of this edge we assumed the PHI
5040         // to be symbolic.  We now need to go back and purge all of the
5041         // entries for the scalars that use the symbolic expression.
5042         forgetSymbolicName(PN, SymbolicName);
5043         ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
5044 
5045         // We can add Flags to the post-inc expression only if we
5046         // know that it is *undefined behavior* for BEValueV to
5047         // overflow.
5048         if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5049           if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5050             (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5051 
5052         return PHISCEV;
5053       }
5054     }
5055   } else {
5056     // Otherwise, this could be a loop like this:
5057     //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
5058     // In this case, j = {1,+,1}  and BEValue is j.
5059     // Because the other in-value of i (0) fits the evolution of BEValue
5060     // i really is an addrec evolution.
5061     //
5062     // We can generalize this by saying that i is the shifted value of BEValue
5063     // by one iteration:
5064     //   PHI(f(0), f({1,+,1})) --> f({0,+,1})
5065     const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5066     const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5067     if (Shifted != getCouldNotCompute() &&
5068         Start != getCouldNotCompute()) {
5069       const SCEV *StartVal = getSCEV(StartValueV);
5070       if (Start == StartVal) {
5071         // Okay, for the entire analysis of this edge we assumed the PHI
5072         // to be symbolic.  We now need to go back and purge all of the
5073         // entries for the scalars that use the symbolic expression.
5074         forgetSymbolicName(PN, SymbolicName);
5075         ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted;
5076         return Shifted;
5077       }
5078     }
5079   }
5080 
5081   // Remove the temporary PHI node SCEV that has been inserted while intending
5082   // to create an AddRecExpr for this PHI node. We cannot keep this temporary
5083   // as it would prevent later (possibly simpler) SCEV expressions from being added
5084   // to the ValueExprMap.
5085   eraseValueFromMap(PN);
5086 
5087   return nullptr;
5088 }
5089 
5090 // Checks if the SCEV S is available at BB.  S is considered available at BB
5091 // if S can be materialized at BB without introducing a fault.
5092 static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
5093                                BasicBlock *BB) {
5094   struct CheckAvailable {
5095     bool TraversalDone = false;
5096     bool Available = true;
5097 
5098     const Loop *L = nullptr;  // The loop BB is in (can be nullptr)
5099     BasicBlock *BB = nullptr;
5100     DominatorTree &DT;
5101 
5102     CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT)
5103       : L(L), BB(BB), DT(DT) {}
5104 
5105     bool setUnavailable() {
5106       TraversalDone = true;
5107       Available = false;
5108       return false;
5109     }
5110 
5111     bool follow(const SCEV *S) {
5112       switch (S->getSCEVType()) {
5113       case scConstant:
5114       case scPtrToInt:
5115       case scTruncate:
5116       case scZeroExtend:
5117       case scSignExtend:
5118       case scAddExpr:
5119       case scMulExpr:
5120       case scUMaxExpr:
5121       case scSMaxExpr:
5122       case scUMinExpr:
5123       case scSMinExpr:
5124         // These expressions are available if all of their operands are.
5125         return true;
5126 
5127       case scAddRecExpr: {
5128         // We allow add recurrences on the loop that BB is in, or on some
5129         // outer loop.  This guarantees availability because the value of the
5130         // add recurrence at BB is simply the "current" value of the induction
5131         // variable.  We can relax this in the future; for instance an add
5132         // recurrence on a sibling dominating loop is also available at BB.
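        // For example, an add recurrence {0,+,1}<%L> is available at any block
        // inside %L (or inside one of %L's subloops), where it simply evaluates
        // to the current value of %L's induction variable.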
5133         const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop();
5134         if (L && (ARLoop == L || ARLoop->contains(L)))
5135           return true;
5136 
5137         return setUnavailable();
5138       }
5139 
5140       case scUnknown: {
5141         // For SCEVUnknown, we check for simple dominance.
5142         const auto *SU = cast<SCEVUnknown>(S);
5143         Value *V = SU->getValue();
5144 
5145         if (isa<Argument>(V))
5146           return false;
5147 
5148         if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB))
5149           return false;
5150 
5151         return setUnavailable();
5152       }
5153 
5154       case scUDivExpr:
5155       case scCouldNotCompute:
5156         // We do not try to be smart about these at all.
5157         return setUnavailable();
5158       }
5159       llvm_unreachable("Unknown SCEV kind!");
5160     }
5161 
5162     bool isDone() { return TraversalDone; }
5163   };
5164 
5165   CheckAvailable CA(L, BB, DT);
5166   SCEVTraversal<CheckAvailable> ST(CA);
5167 
5168   ST.visitAll(S);
5169   return CA.Available;
5170 }
5171 
5172 // Try to match a control flow sequence that branches out at BI and merges back
5173 // at Merge into a "C ? LHS : RHS" select pattern.  Return true on a successful
5174 // match.
5175 static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5176                           Value *&C, Value *&LHS, Value *&RHS) {
5177   C = BI->getCondition();
5178 
5179   BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5180   BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5181 
5182   if (!LeftEdge.isSingleEdge())
5183     return false;
5184 
5185   assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5186 
5187   Use &LeftUse = Merge->getOperandUse(0);
5188   Use &RightUse = Merge->getOperandUse(1);
5189 
5190   if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5191     LHS = LeftUse;
5192     RHS = RightUse;
5193     return true;
5194   }
5195 
5196   if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5197     LHS = RightUse;
5198     RHS = LeftUse;
5199     return true;
5200   }
5201 
5202   return false;
5203 }
5204 
5205 const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5206   auto IsReachable =
5207       [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5208   if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5209     const Loop *L = LI.getLoopFor(PN->getParent());
5210 
5211     // We don't want to break LCSSA, even in a SCEV expression tree.
5212     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
5213       if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
5214         return nullptr;
5215 
5216     // Try to match
5217     //
5218     //  br %cond, label %left, label %right
5219     // left:
5220     //  br label %merge
5221     // right:
5222     //  br label %merge
5223     // merge:
5224     //  V = phi [ %x, %left ], [ %y, %right ]
5225     //
5226     // as "select %cond, %x, %y"
5227 
5228     BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
5229     assert(IDom && "At least the entry block should dominate PN");
5230 
5231     auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
5232     Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
5233 
5234     if (BI && BI->isConditional() &&
5235         BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
5236         IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) &&
5237         IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent()))
5238       return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
5239   }
5240 
5241   return nullptr;
5242 }
5243 
5244 const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
5245   if (const SCEV *S = createAddRecFromPHI(PN))
5246     return S;
5247 
5248   if (const SCEV *S = createNodeFromSelectLikePHI(PN))
5249     return S;
5250 
5251   // If the PHI has a single incoming value, follow that value, unless the
5252   // PHI's incoming blocks are in a different loop, in which case doing so
5253   // risks breaking LCSSA form. Instcombine would normally zap these, but
5254   // it doesn't have DominatorTree information, so it may miss cases.
5255   if (Value *V = SimplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
5256     if (LI.replacementPreservesLCSSAForm(PN, V))
5257       return getSCEV(V);
5258 
5259   // If it's not a loop phi, we can't handle it yet.
5260   return getUnknown(PN);
5261 }
5262 
5263 const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
5264                                                       Value *Cond,
5265                                                       Value *TrueVal,
5266                                                       Value *FalseVal) {
5267   // Handle "constant" branch or select. This can occur for instance when a
5268   // loop pass transforms an inner loop and moves on to process the outer loop.
5269   if (auto *CI = dyn_cast<ConstantInt>(Cond))
5270     return getSCEV(CI->isOne() ? TrueVal : FalseVal);
5271 
5272   // Try to match some simple smax or umax patterns.
5273   auto *ICI = dyn_cast<ICmpInst>(Cond);
5274   if (!ICI)
5275     return getUnknown(I);
5276 
5277   Value *LHS = ICI->getOperand(0);
5278   Value *RHS = ICI->getOperand(1);
5279 
5280   switch (ICI->getPredicate()) {
5281   case ICmpInst::ICMP_SLT:
5282   case ICmpInst::ICMP_SLE:
5283     std::swap(LHS, RHS);
5284     LLVM_FALLTHROUGH;
5285   case ICmpInst::ICMP_SGT:
5286   case ICmpInst::ICMP_SGE:
5287     // a >s b ? a+x : b+x  ->  smax(a, b)+x
5288     // a >s b ? b+x : a+x  ->  smin(a, b)+x
5289     if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
5290       const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType());
5291       const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType());
5292       const SCEV *LA = getSCEV(TrueVal);
5293       const SCEV *RA = getSCEV(FalseVal);
5294       const SCEV *LDiff = getMinusSCEV(LA, LS);
5295       const SCEV *RDiff = getMinusSCEV(RA, RS);
5296       if (LDiff == RDiff)
5297         return getAddExpr(getSMaxExpr(LS, RS), LDiff);
5298       LDiff = getMinusSCEV(LA, RS);
5299       RDiff = getMinusSCEV(RA, LS);
5300       if (LDiff == RDiff)
5301         return getAddExpr(getSMinExpr(LS, RS), LDiff);
5302     }
5303     break;
5304   case ICmpInst::ICMP_ULT:
5305   case ICmpInst::ICMP_ULE:
5306     std::swap(LHS, RHS);
5307     LLVM_FALLTHROUGH;
5308   case ICmpInst::ICMP_UGT:
5309   case ICmpInst::ICMP_UGE:
5310     // a >u b ? a+x : b+x  ->  umax(a, b)+x
5311     // a >u b ? b+x : a+x  ->  umin(a, b)+x
5312     if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
5313       const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
5314       const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType());
5315       const SCEV *LA = getSCEV(TrueVal);
5316       const SCEV *RA = getSCEV(FalseVal);
5317       const SCEV *LDiff = getMinusSCEV(LA, LS);
5318       const SCEV *RDiff = getMinusSCEV(RA, RS);
5319       if (LDiff == RDiff)
5320         return getAddExpr(getUMaxExpr(LS, RS), LDiff);
5321       LDiff = getMinusSCEV(LA, RS);
5322       RDiff = getMinusSCEV(RA, LS);
5323       if (LDiff == RDiff)
5324         return getAddExpr(getUMinExpr(LS, RS), LDiff);
5325     }
5326     break;
5327   case ICmpInst::ICMP_NE:
5328     // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
5329     if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
5330         isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
5331       const SCEV *One = getOne(I->getType());
5332       const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
5333       const SCEV *LA = getSCEV(TrueVal);
5334       const SCEV *RA = getSCEV(FalseVal);
5335       const SCEV *LDiff = getMinusSCEV(LA, LS);
5336       const SCEV *RDiff = getMinusSCEV(RA, One);
5337       if (LDiff == RDiff)
5338         return getAddExpr(getUMaxExpr(One, LS), LDiff);
5339     }
5340     break;
5341   case ICmpInst::ICMP_EQ:
5342     // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
5343     if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
5344         isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
5345       const SCEV *One = getOne(I->getType());
5346       const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
5347       const SCEV *LA = getSCEV(TrueVal);
5348       const SCEV *RA = getSCEV(FalseVal);
5349       const SCEV *LDiff = getMinusSCEV(LA, One);
5350       const SCEV *RDiff = getMinusSCEV(RA, LS);
5351       if (LDiff == RDiff)
5352         return getAddExpr(getUMaxExpr(One, LS), LDiff);
5353     }
5354     break;
5355   default:
5356     break;
5357   }
5358 
5359   return getUnknown(I);
5360 }
5361 
5362 /// Expand GEP instructions into add and multiply operations. This allows them
5363 /// to be analyzed by regular SCEV code.
5364 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
5365   // Don't attempt to analyze GEPs over unsized objects.
5366   if (!GEP->getSourceElementType()->isSized())
5367     return getUnknown(GEP);
5368 
5369   SmallVector<const SCEV *, 4> IndexExprs;
5370   for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index)
5371     IndexExprs.push_back(getSCEV(*Index));
5372   return getGEPExpr(GEP, IndexExprs);
5373 }
5374 
5375 uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
5376   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
5377     return C->getAPInt().countTrailingZeros();
5378 
5379   if (const SCEVPtrToIntExpr *I = dyn_cast<SCEVPtrToIntExpr>(S))
5380     return GetMinTrailingZeros(I->getOperand());
5381 
5382   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
5383     return std::min(GetMinTrailingZeros(T->getOperand()),
5384                     (uint32_t)getTypeSizeInBits(T->getType()));
5385 
5386   if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
5387     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
5388     return OpRes == getTypeSizeInBits(E->getOperand()->getType())
5389                ? getTypeSizeInBits(E->getType())
5390                : OpRes;
5391   }
5392 
5393   if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
5394     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
5395     return OpRes == getTypeSizeInBits(E->getOperand()->getType())
5396                ? getTypeSizeInBits(E->getType())
5397                : OpRes;
5398   }
5399 
5400   if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
5401     // The result is the min of all operands' results.
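    // For example, if one operand is a multiple of 8 (3 trailing zeros) and
    // another is only a multiple of 4 (2 trailing zeros), the sum is known to
    // be a multiple of 4, i.e. min(3, 2) == 2 trailing zeros.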
5402     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
5403     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
5404       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
5405     return MinOpRes;
5406   }
5407 
5408   if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
5409     // The result is the sum of all operands' results.
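    // For example, a multiple of 8 (3 trailing zeros) times a multiple of 4
    // (2 trailing zeros) is a multiple of 32, i.e. 3 + 2 == 5 trailing zeros
    // (capped at the bit width).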
5410     uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
5411     uint32_t BitWidth = getTypeSizeInBits(M->getType());
5412     for (unsigned i = 1, e = M->getNumOperands();
5413          SumOpRes != BitWidth && i != e; ++i)
5414       SumOpRes =
5415           std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth);
5416     return SumOpRes;
5417   }
5418 
5419   if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
5420     // The result is the min of all operands' results.
5421     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
5422     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
5423       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
5424     return MinOpRes;
5425   }
5426 
5427   if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
5428     // The result is the min of all operands' results.
5429     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
5430     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
5431       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
5432     return MinOpRes;
5433   }
5434 
5435   if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
5436     // The result is the min of all operands' results.
5437     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
5438     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
5439       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
5440     return MinOpRes;
5441   }
5442 
5443   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
5444     // For a SCEVUnknown, ask ValueTracking.
5445     KnownBits Known = computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT);
5446     return Known.countMinTrailingZeros();
5447   }
5448 
5449   // SCEVUDivExpr
5450   return 0;
5451 }
5452 
5453 uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
5454   auto I = MinTrailingZerosCache.find(S);
5455   if (I != MinTrailingZerosCache.end())
5456     return I->second;
5457 
5458   uint32_t Result = GetMinTrailingZerosImpl(S);
5459   auto InsertPair = MinTrailingZerosCache.insert({S, Result});
5460   assert(InsertPair.second && "Should insert a new key");
5461   return InsertPair.first->second;
5462 }
5463 
5464 /// Helper method to assign a range to V from metadata present in the IR.
5465 static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
5466   if (Instruction *I = dyn_cast<Instruction>(V))
5467     if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
5468       return getConstantRangeFromMetadata(*MD);
5469 
5470   return None;
5471 }
5472 
5473 /// Determine the range for a particular SCEV.  If SignHint is
5474 /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
5475 /// with a "cleaner" unsigned (resp. signed) representation.
5476 const ConstantRange &
5477 ScalarEvolution::getRangeRef(const SCEV *S,
5478                              ScalarEvolution::RangeSignHint SignHint) {
5479   DenseMap<const SCEV *, ConstantRange> &Cache =
5480       SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
5481                                                        : SignedRanges;
5482   ConstantRange::PreferredRangeType RangeType =
5483       SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED
5484           ? ConstantRange::Unsigned : ConstantRange::Signed;
5485 
5486   // See if we've computed this range already.
5487   DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
5488   if (I != Cache.end())
5489     return I->second;
5490 
5491   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
5492     return setRange(C, SignHint, ConstantRange(C->getAPInt()));
5493 
5494   unsigned BitWidth = getTypeSizeInBits(S->getType());
5495   ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
5496   using OBO = OverflowingBinaryOperator;
5497 
5498   // If the value has known zeros, the maximum value will have those known zeros
5499   // as well.
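  // For example, with BitWidth == 8 and TZ == 2, the unsigned case below yields
  // [0, 253), i.e. the largest value the expression can take is 252.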
5500   uint32_t TZ = GetMinTrailingZeros(S);
5501   if (TZ != 0) {
5502     if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
5503       ConservativeResult =
5504           ConstantRange(APInt::getMinValue(BitWidth),
5505                         APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
5506     else
5507       ConservativeResult = ConstantRange(
5508           APInt::getSignedMinValue(BitWidth),
5509           APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
5510   }
5511 
5512   if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
5513     ConstantRange X = getRangeRef(Add->getOperand(0), SignHint);
5514     unsigned WrapType = OBO::AnyWrap;
5515     if (Add->hasNoSignedWrap())
5516       WrapType |= OBO::NoSignedWrap;
5517     if (Add->hasNoUnsignedWrap())
5518       WrapType |= OBO::NoUnsignedWrap;
5519     for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
5520       X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint),
5521                           WrapType, RangeType);
5522     return setRange(Add, SignHint,
5523                     ConservativeResult.intersectWith(X, RangeType));
5524   }
5525 
5526   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
5527     ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint);
5528     for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
5529       X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint));
5530     return setRange(Mul, SignHint,
5531                     ConservativeResult.intersectWith(X, RangeType));
5532   }
5533 
5534   if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
5535     ConstantRange X = getRangeRef(SMax->getOperand(0), SignHint);
5536     for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
5537       X = X.smax(getRangeRef(SMax->getOperand(i), SignHint));
5538     return setRange(SMax, SignHint,
5539                     ConservativeResult.intersectWith(X, RangeType));
5540   }
5541 
5542   if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
5543     ConstantRange X = getRangeRef(UMax->getOperand(0), SignHint);
5544     for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
5545       X = X.umax(getRangeRef(UMax->getOperand(i), SignHint));
5546     return setRange(UMax, SignHint,
5547                     ConservativeResult.intersectWith(X, RangeType));
5548   }
5549 
5550   if (const SCEVSMinExpr *SMin = dyn_cast<SCEVSMinExpr>(S)) {
5551     ConstantRange X = getRangeRef(SMin->getOperand(0), SignHint);
5552     for (unsigned i = 1, e = SMin->getNumOperands(); i != e; ++i)
5553       X = X.smin(getRangeRef(SMin->getOperand(i), SignHint));
5554     return setRange(SMin, SignHint,
5555                     ConservativeResult.intersectWith(X, RangeType));
5556   }
5557 
5558   if (const SCEVUMinExpr *UMin = dyn_cast<SCEVUMinExpr>(S)) {
5559     ConstantRange X = getRangeRef(UMin->getOperand(0), SignHint);
5560     for (unsigned i = 1, e = UMin->getNumOperands(); i != e; ++i)
5561       X = X.umin(getRangeRef(UMin->getOperand(i), SignHint));
5562     return setRange(UMin, SignHint,
5563                     ConservativeResult.intersectWith(X, RangeType));
5564   }
5565 
5566   if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
5567     ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint);
5568     ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint);
5569     return setRange(UDiv, SignHint,
5570                     ConservativeResult.intersectWith(X.udiv(Y), RangeType));
5571   }
5572 
5573   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
5574     ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint);
5575     return setRange(ZExt, SignHint,
5576                     ConservativeResult.intersectWith(X.zeroExtend(BitWidth),
5577                                                      RangeType));
5578   }
5579 
5580   if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
5581     ConstantRange X = getRangeRef(SExt->getOperand(), SignHint);
5582     return setRange(SExt, SignHint,
5583                     ConservativeResult.intersectWith(X.signExtend(BitWidth),
5584                                                      RangeType));
5585   }
5586 
5587   if (const SCEVPtrToIntExpr *PtrToInt = dyn_cast<SCEVPtrToIntExpr>(S)) {
5588     ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint);
5589     return setRange(PtrToInt, SignHint, X);
5590   }
5591 
5592   if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
5593     ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint);
5594     return setRange(Trunc, SignHint,
5595                     ConservativeResult.intersectWith(X.truncate(BitWidth),
5596                                                      RangeType));
5597   }
5598 
5599   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
5600     // If there's no unsigned wrap, the value will never be less than its
5601     // initial value.
5602     if (AddRec->hasNoUnsignedWrap()) {
5603       APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
5604       if (!UnsignedMinValue.isNullValue())
5605         ConservativeResult = ConservativeResult.intersectWith(
5606             ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
5607     }
5608 
5609     // If there's no signed wrap, and all the operands except the initial value
5610     // have the same sign or are zero, the value won't ever be:
5611     // 1: smaller than the initial value if the operands are non-negative,
5612     // 2: bigger than the initial value if the operands are non-positive.
5613     // In both cases, the value cannot cross the signed min/max boundary.
5614     if (AddRec->hasNoSignedWrap()) {
5615       bool AllNonNeg = true;
5616       bool AllNonPos = true;
5617       for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
5618         if (!isKnownNonNegative(AddRec->getOperand(i)))
5619           AllNonNeg = false;
5620         if (!isKnownNonPositive(AddRec->getOperand(i)))
5621           AllNonPos = false;
5622       }
5623       if (AllNonNeg)
5624         ConservativeResult = ConservativeResult.intersectWith(
5625             ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
5626                                        APInt::getSignedMinValue(BitWidth)),
5627             RangeType);
5628       else if (AllNonPos)
5629         ConservativeResult = ConservativeResult.intersectWith(
5630             ConstantRange::getNonEmpty(
5631                 APInt::getSignedMinValue(BitWidth),
5632                 getSignedRangeMax(AddRec->getStart()) + 1),
5633             RangeType);
5634     }
5635 
5636     // TODO: non-affine addrec
5637     if (AddRec->isAffine()) {
5638       const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(AddRec->getLoop());
5639       if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
5640           getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
5641         auto RangeFromAffine = getRangeForAffineAR(
5642             AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
5643             BitWidth);
5644         ConservativeResult =
5645             ConservativeResult.intersectWith(RangeFromAffine, RangeType);
5646 
5647         auto RangeFromFactoring = getRangeViaFactoring(
5648             AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
5649             BitWidth);
5650         ConservativeResult =
5651             ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
5652       }
5653 
5654       // Now try symbolic BE count and more powerful methods.
5655       if (UseExpensiveRangeSharpening) {
5656         const SCEV *SymbolicMaxBECount =
5657             getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
5658         if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
5659             getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
5660             AddRec->hasNoSelfWrap()) {
5661           auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
5662               AddRec, SymbolicMaxBECount, BitWidth, SignHint);
5663           ConservativeResult =
5664               ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
5665         }
5666       }
5667     }
5668 
5669     return setRange(AddRec, SignHint, std::move(ConservativeResult));
5670   }
5671 
5672   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
5673     // Check if the IR explicitly contains !range metadata.
5674     Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
5675     if (MDRange.hasValue())
5676       ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue(),
5677                                                             RangeType);
5678 
5679     // Split here to avoid paying the compile-time cost of calling both
5680     // computeKnownBits and ComputeNumSignBits.  This restriction can be lifted
5681     // if needed.
5682     const DataLayout &DL = getDataLayout();
5683     if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
5684       // For a SCEVUnknown, ask ValueTracking.
5685       KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
5686       if (Known.getBitWidth() != BitWidth)
5687         Known = Known.zextOrTrunc(BitWidth);
5688       // If Known does not result in full-set, intersect with it.
5689       if (Known.getMinValue() != Known.getMaxValue() + 1)
5690         ConservativeResult = ConservativeResult.intersectWith(
5691             ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
5692             RangeType);
5693     } else {
5694       assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED &&
5695              "generalize as needed!");
5696       unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
5697       // If the pointer size is larger than the index type size, this can cause
5698       // NS to be larger than BitWidth. So compensate for this.
5699       if (U->getType()->isPointerTy()) {
5700         unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
5701         int ptrIdxDiff = ptrSize - BitWidth;
5702         if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
5703           NS -= ptrIdxDiff;
5704       }
5705 
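      // For example, NS == 3 with BitWidth == 8 means the top three bits are
      // all copies of the sign bit, so the value lies in [-32, 32), which is
      // exactly the range computed below.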
5706       if (NS > 1)
5707         ConservativeResult = ConservativeResult.intersectWith(
5708             ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
5709                           APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
5710             RangeType);
5711     }
5712 
5713     // The range of a Phi is a subset of the union of the ranges of its inputs.
5714     if (const PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
5715       // Make sure that we do not recurse endlessly over cyclic Phis.
5716       if (PendingPhiRanges.insert(Phi).second) {
5717         ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
5718         for (auto &Op : Phi->operands()) {
5719           auto OpRange = getRangeRef(getSCEV(Op), SignHint);
5720           RangeFromOps = RangeFromOps.unionWith(OpRange);
5721           // No point in continuing if we already have a full set.
5722           if (RangeFromOps.isFullSet())
5723             break;
5724         }
5725         ConservativeResult =
5726             ConservativeResult.intersectWith(RangeFromOps, RangeType);
5727         bool Erased = PendingPhiRanges.erase(Phi);
5728         assert(Erased && "Failed to erase Phi properly?");
5729         (void) Erased;
5730       }
5731     }
5732 
5733     return setRange(U, SignHint, std::move(ConservativeResult));
5734   }
5735 
5736   return setRange(S, SignHint, std::move(ConservativeResult));
5737 }
5738 
5739 // Given a StartRange, Step and MaxBECount for an expression, compute a range of
5740 // values that the expression can take. Initially, the expression has a value
5741 // from StartRange and is then changed by Step up to MaxBECount times. The
5742 // Signed argument defines whether we treat Step as signed or unsigned.
5743 static ConstantRange getRangeForAffineARHelper(APInt Step,
5744                                                const ConstantRange &StartRange,
5745                                                const APInt &MaxBECount,
5746                                                unsigned BitWidth, bool Signed) {
5747   // If either Step or MaxBECount is 0, then the expression won't change, and we
5748   // just need to return the initial range.
5749   if (Step == 0 || MaxBECount == 0)
5750     return StartRange;
5751 
5752   // If we don't know anything about the initial value (i.e. StartRange is
5753   // FullRange), then we don't know anything about the final range either.
5754   // Return FullRange.
5755   if (StartRange.isFullSet())
5756     return ConstantRange::getFull(BitWidth);
5757 
5758   // If Step is signed and negative, then we use its absolute value, but we also
5759   // note that we're moving in the opposite direction.
5760   bool Descending = Signed && Step.isNegative();
5761 
5762   if (Signed)
5763     // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
5764     // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
5765     // These equations hold due to the well-defined wrap-around behavior of
5766     // APInt.
5767     Step = Step.abs();
5768 
5769   // Check if Offset (Step * MaxBECount) exceeds the full span of BitWidth. If
5770   // it does, the expression is guaranteed to overflow.
5771   if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
5772     return ConstantRange::getFull(BitWidth);
5773 
5774   // Offset is by how much the expression can change. Checks above guarantee no
5775   // overflow here.
5776   APInt Offset = Step * MaxBECount;
5777 
5778   // Minimum value of the final range will match the minimal value of StartRange
5779   // if the expression is increasing and will be decreased by Offset otherwise.
5780   // Maximum value of the final range will match the maximal value of StartRange
5781   // if the expression is decreasing and will be increased by Offset otherwise.
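  // For example, with StartRange == [10, 20), Step == 3 and MaxBECount == 5,
  // Offset == 15 and the ascending result computed below is [10, 35).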
5782   APInt StartLower = StartRange.getLower();
5783   APInt StartUpper = StartRange.getUpper() - 1;
5784   APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
5785                                    : (StartUpper + std::move(Offset));
5786 
5787   // It's possible that the new minimum/maximum value will fall into the initial
5788   // range (due to wrap around). This means that the expression can take any
5789   // value in this bitwidth, and we have to return full range.
5790   if (StartRange.contains(MovedBoundary))
5791     return ConstantRange::getFull(BitWidth);
5792 
5793   APInt NewLower =
5794       Descending ? std::move(MovedBoundary) : std::move(StartLower);
5795   APInt NewUpper =
5796       Descending ? std::move(StartUpper) : std::move(MovedBoundary);
5797   NewUpper += 1;
5798 
5799   // No overflow detected, return the resulting [NewLower, NewUpper) range.
5800   return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
5801 }
5802 
5803 ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
5804                                                    const SCEV *Step,
5805                                                    const SCEV *MaxBECount,
5806                                                    unsigned BitWidth) {
5807   assert(!isa<SCEVCouldNotCompute>(MaxBECount) &&
5808          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
5809          "Precondition!");
5810 
5811   MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());
5812   APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount);
5813 
5814   // First, consider step signed.
5815   ConstantRange StartSRange = getSignedRange(Start);
5816   ConstantRange StepSRange = getSignedRange(Step);
5817 
5818   // If Step can be both positive and negative, we need to find ranges for the
5819   // maximum absolute step values in both directions and union them.
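  // For example, if the signed range of Step is [-2, 4), we union the range
  // obtained for a step of -2 (descending) with the one obtained for a step
  // of 3 (ascending).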
5820   ConstantRange SR =
5821       getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange,
5822                                 MaxBECountValue, BitWidth, /* Signed = */ true);
5823   SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
5824                                               StartSRange, MaxBECountValue,
5825                                               BitWidth, /* Signed = */ true));
5826 
5827   // Next, consider step unsigned.
5828   ConstantRange UR = getRangeForAffineARHelper(
5829       getUnsignedRangeMax(Step), getUnsignedRange(Start),
5830       MaxBECountValue, BitWidth, /* Signed = */ false);
5831 
5832   // Finally, intersect signed and unsigned ranges.
5833   return SR.intersectWith(UR, ConstantRange::Smallest);
5834 }
5835 
5836 ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
5837     const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
5838     ScalarEvolution::RangeSignHint SignHint) {
5839   assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
5840   assert(AddRec->hasNoSelfWrap() &&
5841          "This only works for non-self-wrapping AddRecs!");
5842   const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
5843   const SCEV *Step = AddRec->getStepRecurrence(*this);
5844   // Only deal with constant step to save compile time.
5845   if (!isa<SCEVConstant>(Step))
5846     return ConstantRange::getFull(BitWidth);
5847   // Let's make sure that we can prove that we do not self-wrap during
5848   // MaxBECount iterations. We need this because MaxBECount is a maximum
5849   // iteration count estimate, and we might have inferred nw from some exit for
5850   // which we do not know the max exit count (or from some other side reasoning).
5851   // TODO: Turn into assert at some point.
5852   MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
5853   const SCEV *RangeWidth = getMinusOne(AddRec->getType());
5854   const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
5855   const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
5856   if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
5857                                          MaxItersWithoutWrap))
5858     return ConstantRange::getFull(BitWidth);
5859 
5860   ICmpInst::Predicate LEPred =
5861       IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
5862   ICmpInst::Predicate GEPred =
5863       IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
5864   const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
5865 
5866   // We know that there is no self-wrap. Let's take Start and End values and
5867   // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
5868   // the iteration. They either lie inside the range [Min(Start, End),
5869   // Max(Start, End)] or outside it:
5870   //
5871   // Case 1:   RangeMin    ...    Start V1 ... VN End ...           RangeMax;
5872   // Case 2:   RangeMin Vk ... V1 Start    ...    End Vn ... Vk + 1 RangeMax;
5873   //
5874   // The no-self-wrap flag guarantees that the intermediate values cannot be BOTH
5875   // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
5876   // knowledge, let's try to prove that we are dealing with Case 1. It is so if
5877   // Start <= End and step is positive, or Start >= End and step is negative.
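  // For example, for {10,+,1} with MaxBECount == 5 we get End == 15; if we can
  // prove 10 <= 15 via constant ranges, we are in Case 1 and the addrec is
  // bounded by the union of the ranges of Start and End.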
5878   const SCEV *Start = AddRec->getStart();
5879   ConstantRange StartRange = getRangeRef(Start, SignHint);
5880   ConstantRange EndRange = getRangeRef(End, SignHint);
5881   ConstantRange RangeBetween = StartRange.unionWith(EndRange);
5882   // If they already cover the full iteration space, we will know nothing useful
5883   // even if we prove what we want to prove.
5884   if (RangeBetween.isFullSet())
5885     return RangeBetween;
5886   // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
5887   bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
5888                                : RangeBetween.isWrappedSet();
5889   if (IsWrappedSet)
5890     return ConstantRange::getFull(BitWidth);
5891 
5892   if (isKnownPositive(Step) &&
5893       isKnownPredicateViaConstantRanges(LEPred, Start, End))
5894     return RangeBetween;
5895   else if (isKnownNegative(Step) &&
5896            isKnownPredicateViaConstantRanges(GEPred, Start, End))
5897     return RangeBetween;
5898   return ConstantRange::getFull(BitWidth);
5899 }
5900 
5901 ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
5902                                                     const SCEV *Step,
5903                                                     const SCEV *MaxBECount,
5904                                                     unsigned BitWidth) {
5905   //    RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
5906   // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
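  // For example, RangeOf({c ? 0 : 10,+,c ? 1 : 2}) is contained in
  // RangeOf({0,+,1}) union RangeOf({10,+,2}).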
5907 
5908   struct SelectPattern {
5909     Value *Condition = nullptr;
5910     APInt TrueValue;
5911     APInt FalseValue;
5912 
5913     explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
5914                            const SCEV *S) {
5915       Optional<unsigned> CastOp;
5916       APInt Offset(BitWidth, 0);
5917 
5918       assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
5919              "Should be!");
5920 
5921       // Peel off a constant offset:
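      // For example, if S is (5 + %sel), we record Offset == 5 and continue the
      // match on the SCEVUnknown %sel.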
5922       if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
5923         // In the future we could consider being smarter here and handle
5924         // {Start+Step,+,Step} too.
5925         if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
5926           return;
5927 
5928         Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
5929         S = SA->getOperand(1);
5930       }
5931 
5932       // Peel off a cast operation
5933       if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
5934         CastOp = SCast->getSCEVType();
5935         S = SCast->getOperand();
5936       }
5937 
5938       using namespace llvm::PatternMatch;
5939 
5940       auto *SU = dyn_cast<SCEVUnknown>(S);
5941       const APInt *TrueVal, *FalseVal;
5942       if (!SU ||
5943           !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
5944                                           m_APInt(FalseVal)))) {
5945         Condition = nullptr;
5946         return;
5947       }
5948 
5949       TrueValue = *TrueVal;
5950       FalseValue = *FalseVal;
5951 
5952       // Re-apply the cast we peeled off earlier
5953       if (CastOp.hasValue())
5954         switch (*CastOp) {
5955         default:
5956           llvm_unreachable("Unknown SCEV cast type!");
5957 
5958         case scTruncate:
5959           TrueValue = TrueValue.trunc(BitWidth);
5960           FalseValue = FalseValue.trunc(BitWidth);
5961           break;
5962         case scZeroExtend:
5963           TrueValue = TrueValue.zext(BitWidth);
5964           FalseValue = FalseValue.zext(BitWidth);
5965           break;
5966         case scSignExtend:
5967           TrueValue = TrueValue.sext(BitWidth);
5968           FalseValue = FalseValue.sext(BitWidth);
5969           break;
5970         }
5971 
5972       // Re-apply the constant offset we peeled off earlier
5973       TrueValue += Offset;
5974       FalseValue += Offset;
5975     }
5976 
5977     bool isRecognized() { return Condition != nullptr; }
5978   };
5979 
5980   SelectPattern StartPattern(*this, BitWidth, Start);
5981   if (!StartPattern.isRecognized())
5982     return ConstantRange::getFull(BitWidth);
5983 
5984   SelectPattern StepPattern(*this, BitWidth, Step);
5985   if (!StepPattern.isRecognized())
5986     return ConstantRange::getFull(BitWidth);
5987 
5988   if (StartPattern.Condition != StepPattern.Condition) {
5989     // We don't handle this case today, but we could, by considering four
5990     // possibilities below instead of two. I'm not sure if there are cases where
5991     // that will help over what getRange already does, though.
5992     return ConstantRange::getFull(BitWidth);
5993   }
5994 
5995   // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
5996   // construct arbitrary general SCEV expressions here.  This function is called
5997   // from deep in the call stack, and calling getSCEV (on a sext instruction,
5998   // say) can end up caching a suboptimal value.
5999 
6000   // FIXME: without the explicit `this` receiver below, MSVC errors out with
6001   // C2352 and C2512 (otherwise it isn't needed).
6002 
6003   const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
6004   const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
6005   const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
6006   const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
6007 
6008   ConstantRange TrueRange =
6009       this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth);
6010   ConstantRange FalseRange =
6011       this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth);
6012 
6013   return TrueRange.unionWith(FalseRange);
6014 }
6015 
6016 SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
6017   if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
6018   const BinaryOperator *BinOp = cast<BinaryOperator>(V);
6019 
6020   // Return early if there are no flags to propagate to the SCEV.
6021   SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
6022   if (BinOp->hasNoUnsignedWrap())
6023     Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
6024   if (BinOp->hasNoSignedWrap())
6025     Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
6026   if (Flags == SCEV::FlagAnyWrap)
6027     return SCEV::FlagAnyWrap;
6028 
6029   return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
6030 }
6031 
6032 bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
6033   // Here we check that I is in the header of the innermost loop containing I,
6034   // since we only deal with instructions in the loop header. The actual loop we
6035   // need to check later will come from an add recurrence, but getting that
6036   // requires computing the SCEV of the operands, which can be expensive. We can
6037   // do this check cheaply to rule out some cases early.
6038   Loop *InnermostContainingLoop = LI.getLoopFor(I->getParent());
6039   if (InnermostContainingLoop == nullptr ||
6040       InnermostContainingLoop->getHeader() != I->getParent())
6041     return false;
6042 
6043   // Only proceed if we can prove that I does not yield poison.
6044   if (!programUndefinedIfPoison(I))
6045     return false;
6046 
6047   // At this point we know that if I is executed, then it does not wrap
6048   // according to at least one of NSW or NUW. If I is not executed, then we do
6049   // not know if the calculation that I represents would wrap. Multiple
6050   // instructions can map to the same SCEV. If we apply NSW or NUW from I to
6051   // the SCEV, we must guarantee no wrapping for that SCEV also when it is
6052   // derived from other instructions that map to the same SCEV. We cannot make
6053   // that guarantee for cases where I is not executed. So we need to find the
6054   // loop that I is considered in relation to and prove that I is executed for
6055   // every iteration of that loop. That implies that the value that I
6056   // calculates does not wrap anywhere in the loop, so then we can apply the
6057   // flags to the SCEV.
6058   //
6059   // We check isLoopInvariant to disambiguate in case we are adding recurrences
6060   // from different loops, so that we know which loop to prove that I is
6061   // executed in.
6062   for (unsigned OpIndex = 0; OpIndex < I->getNumOperands(); ++OpIndex) {
6063     // I could be an extractvalue from a call to an overflow intrinsic.
6064     // TODO: We can do better here in some cases.
6065     if (!isSCEVable(I->getOperand(OpIndex)->getType()))
6066       return false;
6067     const SCEV *Op = getSCEV(I->getOperand(OpIndex));
6068     if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
6069       bool AllOtherOpsLoopInvariant = true;
6070       for (unsigned OtherOpIndex = 0; OtherOpIndex < I->getNumOperands();
6071            ++OtherOpIndex) {
6072         if (OtherOpIndex != OpIndex) {
6073           const SCEV *OtherOp = getSCEV(I->getOperand(OtherOpIndex));
6074           if (!isLoopInvariant(OtherOp, AddRec->getLoop())) {
6075             AllOtherOpsLoopInvariant = false;
6076             break;
6077           }
6078         }
6079       }
6080       if (AllOtherOpsLoopInvariant &&
6081           isGuaranteedToExecuteForEveryIteration(I, AddRec->getLoop()))
6082         return true;
6083     }
6084   }
6085   return false;
6086 }
6087 
6088 bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
6089   // If we know that \c I can never be poison period, then that's enough.
6090   if (isSCEVExprNeverPoison(I))
6091     return true;
6092 
6093   // For an add recurrence specifically, we assume that infinite loops without
6094   // side effects are undefined behavior, and then reason as follows:
6095   //
6096   // If the add recurrence is poison in any iteration, it is poison on all
6097   // future iterations (since incrementing poison yields poison). If the result
6098   // of the add recurrence is fed into the loop latch condition and the loop
6099   // does not contain any throws or exiting blocks other than the latch, we now
6100   // have the ability to "choose" whether the backedge is taken or not (by
6101   // choosing a sufficiently evil value for the poison feeding into the branch)
6102   // for every iteration including and after the one in which \p I first became
6103   // poison.  There are two possibilities (let's call the iteration in which \p
6104   // I first became poison as K):
6105   //
6106   //  1. In the set of iterations including and after K, the loop body executes
6107   //     no side effects.  In this case executing the backedge an infinite number
6108   //     of times will yield undefined behavior.
6109   //
6110   //  2. In the set of iterations including and after K, the loop body executes
6111   //     at least one side effect.  In this case, that specific instance of side
6112   //     effect is control dependent on poison, which also yields undefined
6113   //     behavior.
6114 
6115   auto *ExitingBB = L->getExitingBlock();
6116   auto *LatchBB = L->getLoopLatch();
6117   if (!ExitingBB || !LatchBB || ExitingBB != LatchBB)
6118     return false;
6119 
6120   SmallPtrSet<const Instruction *, 16> Pushed;
6121   SmallVector<const Instruction *, 8> PoisonStack;
6122 
6123   // We start by assuming \c I, the post-inc add recurrence, is poison.  Only
6124   // things that are known to be poison under that assumption go on the
6125   // PoisonStack.
6126   Pushed.insert(I);
6127   PoisonStack.push_back(I);
6128 
6129   bool LatchControlDependentOnPoison = false;
6130   while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
6131     const Instruction *Poison = PoisonStack.pop_back_val();
6132 
6133     for (auto *PoisonUser : Poison->users()) {
6134       if (propagatesPoison(cast<Operator>(PoisonUser))) {
6135         if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
6136           PoisonStack.push_back(cast<Instruction>(PoisonUser));
6137       } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
6138         assert(BI->isConditional() && "Only possibility!");
6139         if (BI->getParent() == LatchBB) {
6140           LatchControlDependentOnPoison = true;
6141           break;
6142         }
6143       }
6144     }
6145   }
6146 
6147   return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);
6148 }
6149 
6150 ScalarEvolution::LoopProperties
6151 ScalarEvolution::getLoopProperties(const Loop *L) {
6152   using LoopProperties = ScalarEvolution::LoopProperties;
6153 
6154   auto Itr = LoopPropertiesCache.find(L);
6155   if (Itr == LoopPropertiesCache.end()) {
6156     auto HasSideEffects = [](Instruction *I) {
6157       if (auto *SI = dyn_cast<StoreInst>(I))
6158         return !SI->isSimple();
6159 
6160       return I->mayHaveSideEffects();
6161     };
6162 
6163     LoopProperties LP = {/* HasNoAbnormalExits */ true,
6164                          /*HasNoSideEffects*/ true};
6165 
6166     for (auto *BB : L->getBlocks())
6167       for (auto &I : *BB) {
6168         if (!isGuaranteedToTransferExecutionToSuccessor(&I))
6169           LP.HasNoAbnormalExits = false;
6170         if (HasSideEffects(&I))
6171           LP.HasNoSideEffects = false;
6172         if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
6173           break; // We're already as pessimistic as we can get.
6174       }
6175 
6176     auto InsertPair = LoopPropertiesCache.insert({L, LP});
6177     assert(InsertPair.second && "We just checked!");
6178     Itr = InsertPair.first;
6179   }
6180 
6181   return Itr->second;
6182 }
6183 
6184 const SCEV *ScalarEvolution::createSCEV(Value *V) {
6185   if (!isSCEVable(V->getType()))
6186     return getUnknown(V);
6187 
6188   if (Instruction *I = dyn_cast<Instruction>(V)) {
6189     // Don't attempt to analyze instructions in blocks that aren't
6190     // reachable. Such instructions don't matter, and they aren't required
6191     // to obey basic rules for definitions dominating uses which this
6192     // analysis depends on.
6193     if (!DT.isReachableFromEntry(I->getParent()))
6194       return getUnknown(UndefValue::get(V->getType()));
6195   } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
6196     return getConstant(CI);
6197   else if (isa<ConstantPointerNull>(V))
6198     // FIXME: we shouldn't special-case null pointer constant.
6199     return getZero(V->getType());
6200   else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
6201     return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee());
6202   else if (!isa<ConstantExpr>(V))
6203     return getUnknown(V);
6204 
6205   Operator *U = cast<Operator>(V);
6206   if (auto BO = MatchBinaryOp(U, DT)) {
6207     switch (BO->Opcode) {
6208     case Instruction::Add: {
6209       // The simple thing to do would be to just call getSCEV on both operands
6210       // and call getAddExpr with the result. However if we're looking at a
6211       // bunch of things all added together, this can be quite inefficient,
6212       // because it leads to N-1 getAddExpr calls for N ultimate operands.
6213       // Instead, gather up all the operands and make a single getAddExpr call.
6214       // LLVM IR canonical form means we need only traverse the left operands.
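      // For example, for ((a + b) + c) + d we push d, then c, then b and a,
      // and emit a single getAddExpr over all four operands.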
6215       SmallVector<const SCEV *, 4> AddOps;
6216       do {
6217         if (BO->Op) {
6218           if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
6219             AddOps.push_back(OpSCEV);
6220             break;
6221           }
6222 
6223           // If a NUW or NSW flag can be applied to the SCEV for this
6224           // addition, then compute the SCEV for this addition by itself
6225           // with a separate call to getAddExpr. We need to do that
6226           // instead of pushing the operands of the addition onto AddOps,
6227           // since the flags are only known to apply to this particular
6228           // addition - they may not apply to other additions that can be
6229           // formed with operands from AddOps.
6230           const SCEV *RHS = getSCEV(BO->RHS);
6231           SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
6232           if (Flags != SCEV::FlagAnyWrap) {
6233             const SCEV *LHS = getSCEV(BO->LHS);
6234             if (BO->Opcode == Instruction::Sub)
6235               AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
6236             else
6237               AddOps.push_back(getAddExpr(LHS, RHS, Flags));
6238             break;
6239           }
6240         }
6241 
6242         if (BO->Opcode == Instruction::Sub)
6243           AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
6244         else
6245           AddOps.push_back(getSCEV(BO->RHS));
6246 
6247         auto NewBO = MatchBinaryOp(BO->LHS, DT);
6248         if (!NewBO || (NewBO->Opcode != Instruction::Add &&
6249                        NewBO->Opcode != Instruction::Sub)) {
6250           AddOps.push_back(getSCEV(BO->LHS));
6251           break;
6252         }
6253         BO = NewBO;
6254       } while (true);
6255 
6256       return getAddExpr(AddOps);
6257     }
6258 
6259     case Instruction::Mul: {
6260       SmallVector<const SCEV *, 4> MulOps;
6261       do {
6262         if (BO->Op) {
6263           if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
6264             MulOps.push_back(OpSCEV);
6265             break;
6266           }
6267 
6268           SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
6269           if (Flags != SCEV::FlagAnyWrap) {
6270             MulOps.push_back(
6271                 getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags));
6272             break;
6273           }
6274         }
6275 
6276         MulOps.push_back(getSCEV(BO->RHS));
6277         auto NewBO = MatchBinaryOp(BO->LHS, DT);
6278         if (!NewBO || NewBO->Opcode != Instruction::Mul) {
6279           MulOps.push_back(getSCEV(BO->LHS));
6280           break;
6281         }
6282         BO = NewBO;
6283       } while (true);
6284 
6285       return getMulExpr(MulOps);
6286     }
6287     case Instruction::UDiv:
6288       return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
6289     case Instruction::URem:
6290       return getURemExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
6291     case Instruction::Sub: {
6292       SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
6293       if (BO->Op)
6294         Flags = getNoWrapFlagsFromUB(BO->Op);
6295       return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags);
6296     }
6297     case Instruction::And:
6298       // For an expression like x&255 that merely masks off the high bits,
6299       // use zext(trunc(x)) as the SCEV expression.
6300       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
6301         if (CI->isZero())
6302           return getSCEV(BO->RHS);
6303         if (CI->isMinusOne())
6304           return getSCEV(BO->LHS);
6305         const APInt &A = CI->getValue();
6306 
6307         // Instcombine's ShrinkDemandedConstant may strip bits out of
6308         // constants, obscuring what would otherwise be a low-bits mask.
6309         // Use computeKnownBits to compute what ShrinkDemandedConstant
6310         // knew about to reconstruct a low-bits mask value.
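        // For instance, for x & 0x78 (LZ == 1, TZ == 3) this produces
        // (zext (trunc (x /u 8) to i4) to i8) * 8.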
6311         unsigned LZ = A.countLeadingZeros();
6312         unsigned TZ = A.countTrailingZeros();
6313         unsigned BitWidth = A.getBitWidth();
6314         KnownBits Known(BitWidth);
6315         computeKnownBits(BO->LHS, Known, getDataLayout(),
6316                          0, &AC, nullptr, &DT);
6317 
6318         APInt EffectiveMask =
6319             APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
6320         if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
6321           const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
6322           const SCEV *LHS = getSCEV(BO->LHS);
6323           const SCEV *ShiftedLHS = nullptr;
6324           if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
6325             if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
6326               // For an expression like (x * 8) & 8, simplify the multiply.
6327               unsigned MulZeros = OpC->getAPInt().countTrailingZeros();
6328               unsigned GCD = std::min(MulZeros, TZ);
6329               APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
6330               SmallVector<const SCEV*, 4> MulOps;
6331               MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
6332               MulOps.append(LHSMul->op_begin() + 1, LHSMul->op_end());
6333               auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
6334               ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
6335             }
6336           }
6337           if (!ShiftedLHS)
6338             ShiftedLHS = getUDivExpr(LHS, MulCount);
6339           return getMulExpr(
6340               getZeroExtendExpr(
6341                   getTruncateExpr(ShiftedLHS,
6342                       IntegerType::get(getContext(), BitWidth - LZ - TZ)),
6343                   BO->LHS->getType()),
6344               MulCount);
6345         }
6346       }
6347       break;
6348 
6349     case Instruction::Or:
6350       // If the RHS of the Or is a constant, we may have something like:
6351       // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
6352       // optimizations will transparently handle this case.
6353       //
6354       // In order for this transformation to be safe, the LHS must be of the
6355       // form X*(2^n) and the Or constant must be less than 2^n.
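      // For example, (%x * 4) | 1 is treated as (%x * 4) + 1, since %x * 4
      // always has its low two bits clear and the constant 1 fits within them.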
6356       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
6357         const SCEV *LHS = getSCEV(BO->LHS);
6358         const APInt &CIVal = CI->getValue();
6359         if (GetMinTrailingZeros(LHS) >=
6360             (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
6361           // Build a plain add SCEV.
6362           return getAddExpr(LHS, getSCEV(CI),
6363                             (SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNSW));
6364         }
6365       }
6366       break;
6367 
6368     case Instruction::Xor:
6369       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
6370         // If the RHS of xor is -1, then this is a not operation.
6371         if (CI->isMinusOne())
6372           return getNotSCEV(getSCEV(BO->LHS));
6373 
6374         // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
6375         // This is a variant of the check for xor with -1, and it handles
6376         // the case where instcombine has trimmed non-demanded bits out
6377         // of an xor with -1.
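        // For example, ((%x & 255) ^ 255) on i32, where the 'and' is modeled as
        // (zext i8 (trunc %x) to i32), becomes (zext i8 (~(trunc %x)) to i32):
        // the xor acts as a 'not' on the low 8 bits.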
6378         if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
6379           if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
6380             if (LBO->getOpcode() == Instruction::And &&
6381                 LCI->getValue() == CI->getValue())
6382               if (const SCEVZeroExtendExpr *Z =
6383                       dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
6384                 Type *UTy = BO->LHS->getType();
6385                 const SCEV *Z0 = Z->getOperand();
6386                 Type *Z0Ty = Z0->getType();
6387                 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
6388 
6389                 // If C is a low-bits mask, the zero extend is serving to
6390                 // mask off the high bits. Complement the operand and
6391                 // re-apply the zext.
6392                 if (CI->getValue().isMask(Z0TySize))
6393                   return getZeroExtendExpr(getNotSCEV(Z0), UTy);
6394 
6395                 // If C is a single bit, it may be in the sign-bit position
6396                 // before the zero-extend. In this case, represent the xor
6397                 // using an add, which is equivalent, and re-apply the zext.
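                // For example, with i8 Z0 and C == 128, (zext Z0) ^ 128 equals
                // zext(Z0 + 128), since adding 128 in i8 simply flips the sign
                // bit.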
6398                 APInt Trunc = CI->getValue().trunc(Z0TySize);
6399                 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
6400                     Trunc.isSignMask())
6401                   return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
6402                                            UTy);
6403               }
6404       }
6405       break;
6406 
6407     case Instruction::Shl:
6408       // Turn shift left of a constant amount into a multiply.
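      // For example, (%x << 3) is modeled as (%x * 8).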
6409       if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
6410         uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
6411 
6412         // If the shift count is not less than the bitwidth, the result of
6413         // the shift is undefined. Don't try to analyze it, because the
6414         // resolution chosen here may differ from the resolution chosen in
6415         // other parts of the compiler.
6416         if (SA->getValue().uge(BitWidth))
6417           break;
6418 
6419         // We can safely preserve the nuw flag in all cases. It's also safe to
6420         // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
6421         // requires special handling. It can be preserved as long as we're not
6422         // left shifting by bitwidth - 1.
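        // For example, "shl nsw i8 -1, 7" yields -128 with no signed overflow,
        // but the equivalent multiply uses the i8 constant 1 << 7 == -128, and
        // -1 * -128 == 128 does overflow, so nsw can't be carried over when
        // shifting by bitwidth - 1.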
6423         auto Flags = SCEV::FlagAnyWrap;
6424         if (BO->Op) {
6425           auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
6426           if ((MulFlags & SCEV::FlagNSW) &&
6427               ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
6428             Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW);
6429           if (MulFlags & SCEV::FlagNUW)
6430             Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW);
6431         }
6432 
6433         Constant *X = ConstantInt::get(
6434             getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
6435         return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags);
6436       }
6437       break;
6438 
6439     case Instruction::AShr: {
6440       // AShr X, C, where C is a constant.
6441       ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
6442       if (!CI)
6443         break;
6444 
6445       Type *OuterTy = BO->LHS->getType();
6446       uint64_t BitWidth = getTypeSizeInBits(OuterTy);
6447       // If the shift count is not less than the bitwidth, the result of
6448       // the shift is undefined. Don't try to analyze it, because the
6449       // resolution chosen here may differ from the resolution chosen in
6450       // other parts of the compiler.
6451       if (CI->getValue().uge(BitWidth))
6452         break;
6453 
6454       if (CI->isZero())
6455         return getSCEV(BO->LHS); // shift by zero --> noop
6456 
6457       uint64_t AShrAmt = CI->getZExtValue();
6458       Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
6459 
6460       Operator *L = dyn_cast<Operator>(BO->LHS);
6461       if (L && L->getOpcode() == Instruction::Shl) {
6462         // X = Shl A, n
6463         // Y = AShr X, m
6464         // Both n and m are constant.
6465 
6466         const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
6467         if (L->getOperand(1) == BO->RHS)
6468           // For a two-shift sext-inreg, i.e. n = m,
6469           // use sext(trunc(x)) as the SCEV expression.
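          // For example, on i32 with n == m == 24, ((A << 24) ashr 24) is
          // modeled as (sext i8 (trunc A to i8) to i32).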
6470           return getSignExtendExpr(
6471               getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
6472 
6473         ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
6474         if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
6475           uint64_t ShlAmt = ShlAmtCI->getZExtValue();
6476           if (ShlAmt > AShrAmt) {
6477             // When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
6478             // expression. We already checked that ShlAmt < BitWidth, so
6479             // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
6480             // ShlAmt - AShrAmt < Amt.
6481             APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
6482                                             ShlAmt - AShrAmt);
6483             return getSignExtendExpr(
6484                 getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
6485                 getConstant(Mul)), OuterTy);
6486           }
6487         }
6488       }
6489       if (BO->IsExact) {
6490         // Given an exact arithmetic right-shift by an in-range constant,
6491         // we can lower it into:  (abs(x) EXACT/u (1<<C)) * signum(x)
6492         const SCEV *X = getSCEV(BO->LHS);
6493         const SCEV *AbsX = getAbsExpr(X, /*IsNSW=*/false);
6494         APInt Mult = APInt::getOneBitSet(BitWidth, AShrAmt);
6495         const SCEV *Div = getUDivExactExpr(AbsX, getConstant(Mult));
6496         return getMulExpr(Div, getSignumExpr(X), SCEV::FlagNSW);
6497       }
6498       break;
6499     }
6500     }
6501   }
6502 
6503   switch (U->getOpcode()) {
6504   case Instruction::Trunc:
6505     return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
6506 
6507   case Instruction::ZExt:
6508     return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
6509 
6510   case Instruction::SExt:
6511     if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) {
6512       // The NSW flag of a subtract does not always survive the conversion to
6513       // A + (-1)*B.  By pushing sign extension onto its operands we are much
6514       // more likely to preserve NSW and allow later AddRec optimisations.
6515       //
6516       // NOTE: This is effectively duplicating this logic from getSignExtend:
6517       //   sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
6518       // but by that point the NSW information has potentially been lost.
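      // For example, (sext i32 (sub nsw %a, %b) to i64) is modeled as
      // (sext %a) - (sext %b) with NSW preserved, rather than sign-extending
      // the whole subtraction.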
6519       if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
6520         Type *Ty = U->getType();
6521         auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
6522         auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
6523         return getMinusSCEV(V1, V2, SCEV::FlagNSW);
6524       }
6525     }
6526     return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
6527 
6528   case Instruction::BitCast:
6529     // BitCasts are no-op casts so we just eliminate the cast.
6530     if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
6531       return getSCEV(U->getOperand(0));
6532     break;
6533 
6534   case Instruction::PtrToInt: {
6535     // Pointer to integer cast is straightforward, so do model it.
6536     Value *Ptr = U->getOperand(0);
6537     const SCEV *Op = getSCEV(Ptr);
6538     Type *DstIntTy = U->getType();
6539     // SCEV doesn't have a constant pointer expression type, but it supports the
6540     // nullptr constant (and only that one), which is modelled in SCEV as a
6541     // zero integer constant. So just skip the ptrtoint cast for constants.
6542     if (isa<SCEVConstant>(Op))
6543       return getTruncateOrZeroExtend(Op, DstIntTy);
6544     Type *PtrTy = Ptr->getType();
6545     Type *IntPtrTy = getDataLayout().getIntPtrType(PtrTy);
6546     // But only if effective SCEV (integer) type is wide enough to represent
6547     // all possible pointer values.
6548     if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(PtrTy)) !=
6549         getDataLayout().getTypeSizeInBits(IntPtrTy))
6550       return getUnknown(V);
6551     return getPtrToIntExpr(Op, DstIntTy);
6552   }
6553   case Instruction::IntToPtr:
6554     // Just don't deal with inttoptr casts.
6555     return getUnknown(V);
6556 
6557   case Instruction::SDiv:
6558     // If both operands are non-negative, this is just an udiv.
6559     if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
6560         isKnownNonNegative(getSCEV(U->getOperand(1))))
6561       return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
6562     break;
6563 
6564   case Instruction::SRem:
6565     // If both operands are non-negative, this is just an urem.
6566     if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
6567         isKnownNonNegative(getSCEV(U->getOperand(1))))
6568       return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
6569     break;
6570 
6571   case Instruction::GetElementPtr:
6572     return createNodeForGEP(cast<GEPOperator>(U));
6573 
6574   case Instruction::PHI:
6575     return createNodeForPHI(cast<PHINode>(U));
6576 
6577   case Instruction::Select:
6578     // U can also be a select constant expr, which we let fall through.  Since
6579     // createNodeForSelectOrPHI only works for a condition that is an `ICmpInst`,
6580     // and constant expressions cannot have instructions as operands, we'd have
6581     // returned getUnknown for a select constant expression anyway.
6582     if (isa<Instruction>(U))
6583       return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0),
6584                                       U->getOperand(1), U->getOperand(2));
6585     break;
6586 
6587   case Instruction::Call:
6588   case Instruction::Invoke:
6589     if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
6590       return getSCEV(RV);
6591 
6592     if (auto *II = dyn_cast<IntrinsicInst>(U)) {
6593       switch (II->getIntrinsicID()) {
6594       case Intrinsic::abs:
6595         return getAbsExpr(
6596             getSCEV(II->getArgOperand(0)),
6597             /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
6598       case Intrinsic::umax:
6599         return getUMaxExpr(getSCEV(II->getArgOperand(0)),
6600                            getSCEV(II->getArgOperand(1)));
6601       case Intrinsic::umin:
6602         return getUMinExpr(getSCEV(II->getArgOperand(0)),
6603                            getSCEV(II->getArgOperand(1)));
6604       case Intrinsic::smax:
6605         return getSMaxExpr(getSCEV(II->getArgOperand(0)),
6606                            getSCEV(II->getArgOperand(1)));
6607       case Intrinsic::smin:
6608         return getSMinExpr(getSCEV(II->getArgOperand(0)),
6609                            getSCEV(II->getArgOperand(1)));
6610       case Intrinsic::usub_sat: {
6611         const SCEV *X = getSCEV(II->getArgOperand(0));
6612         const SCEV *Y = getSCEV(II->getArgOperand(1));
6613         const SCEV *ClampedY = getUMinExpr(X, Y);
6614         return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
6615       }
6616       case Intrinsic::uadd_sat: {
6617         const SCEV *X = getSCEV(II->getArgOperand(0));
6618         const SCEV *Y = getSCEV(II->getArgOperand(1));
6619         const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
6620         return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
6621       }
6622       default:
6623         break;
6624       }
6625     }
6626     break;
6627   }
6628 
6629   return getUnknown(V);
6630 }
6631 
6632 //===----------------------------------------------------------------------===//
6633 //                   Iteration Count Computation Code
6634 //
6635 
6636 static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
6637   if (!ExitCount)
6638     return 0;
6639 
6640   ConstantInt *ExitConst = ExitCount->getValue();
6641 
6642   // Guard against huge trip counts.
6643   if (ExitConst->getValue().getActiveBits() > 32)
6644     return 0;
6645 
6646   // In case of integer overflow, this returns 0, which is correct.
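  // For example, an exit count of UINT32_MAX would wrap to a trip count of 0,
  // which callers treat as "unknown trip count".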
6647   return ((unsigned)ExitConst->getZExtValue()) + 1;
6648 }
6649 
6650 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
6651   if (BasicBlock *ExitingBB = L->getExitingBlock())
6652     return getSmallConstantTripCount(L, ExitingBB);
6653 
6654   // No trip count information for multiple exits.
6655   return 0;
6656 }
6657 
6658 unsigned
6659 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
6660                                            const BasicBlock *ExitingBlock) {
6661   assert(ExitingBlock && "Must pass a non-null exiting block!");
6662   assert(L->isLoopExiting(ExitingBlock) &&
6663          "Exiting block must actually branch out of the loop!");
6664   const SCEVConstant *ExitCount =
6665       dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
6666   return getConstantTripCount(ExitCount);
6667 }
6668 
6669 unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
6670   const auto *MaxExitCount =
6671       dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
6672   return getConstantTripCount(MaxExitCount);
6673 }
6674 
6675 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
6676   if (BasicBlock *ExitingBB = L->getExitingBlock())
6677     return getSmallConstantTripMultiple(L, ExitingBB);
6678 
6679   // No trip multiple information for multiple exits.
6680   return 0;
6681 }
6682 
6683 /// Returns the largest constant divisor of the trip count of this loop as a
6684 /// normal unsigned value, if possible. This means that the actual trip count is
6685 /// always a multiple of the returned value (don't forget the trip count could
6686 /// very well be zero as well!).
6687 ///
6688 /// Returns 1 if the trip count is unknown or not guaranteed to be the
6689 /// multiple of a constant (which is also the case if the trip count is simply
6690 /// constant; use getSmallConstantTripCount for that case). It will also
6691 /// return 1 if the trip count is very large (>= 2^32).
6692 ///
6693 /// As explained in the comments for getSmallConstantTripCount, this assumes
6694 /// that control exits the loop via ExitingBlock.
6695 unsigned
6696 ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
6697                                               const BasicBlock *ExitingBlock) {
6698   assert(ExitingBlock && "Must pass a non-null exiting block!");
6699   assert(L->isLoopExiting(ExitingBlock) &&
6700          "Exiting block must actually branch out of the loop!");
6701   const SCEV *ExitCount = getExitCount(L, ExitingBlock);
6702   if (ExitCount == getCouldNotCompute())
6703     return 1;
6704 
6705   // Get the trip count from the BE count by adding 1.
6706   const SCEV *TCExpr = getAddExpr(ExitCount, getOne(ExitCount->getType()));
6707 
6708   const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr);
6709   if (!TC)
6710     // Attempt to factor more general cases. Returns the greatest power of
6711     // two divisor. If overflow happens, the trip count expression is still
6712     // divisible by the greatest power of 2 divisor returned.
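    // For example, a trip count expression of (4 * %n + 8) has at least two
    // trailing zero bits, so this returns 4 when nothing more is known about
    // %n.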
6713     return 1U << std::min((uint32_t)31, GetMinTrailingZeros(TCExpr));
6714 
6715   ConstantInt *Result = TC->getValue();
6716 
6717   // Guard against huge trip counts (this requires checking
6718   // for zero to handle the case where the trip count == -1 and the
6719   // addition wraps).
6720   if (!Result || Result->getValue().getActiveBits() > 32 ||
6721       Result->getValue().getActiveBits() == 0)
6722     return 1;
6723 
6724   return (unsigned)Result->getZExtValue();
6725 }
6726 
6727 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
6728                                           const BasicBlock *ExitingBlock,
6729                                           ExitCountKind Kind) {
6730   switch (Kind) {
6731   case Exact:
6732   case SymbolicMaximum:
6733     return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
6734   case ConstantMaximum:
6735     return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
6736   };
6737   llvm_unreachable("Invalid ExitCountKind!");
6738 }
6739 
6740 const SCEV *
6741 ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
6742                                                  SCEVUnionPredicate &Preds) {
6743   return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
6744 }
6745 
6746 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
6747                                                    ExitCountKind Kind) {
6748   switch (Kind) {
6749   case Exact:
6750     return getBackedgeTakenInfo(L).getExact(L, this);
6751   case ConstantMaximum:
6752     return getBackedgeTakenInfo(L).getConstantMax(this);
6753   case SymbolicMaximum:
6754     return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
6755   };
6756   llvm_unreachable("Invalid ExitCountKind!");
6757 }
6758 
6759 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
6760   return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
6761 }
6762 
6763 /// Push PHI nodes in the header of the given loop onto the given Worklist.
6764 static void
6765 PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
6766   BasicBlock *Header = L->getHeader();
6767 
6768   // Push all Loop-header PHIs onto the Worklist stack.
6769   for (PHINode &PN : Header->phis())
6770     Worklist.push_back(&PN);
6771 }
6772 
6773 const ScalarEvolution::BackedgeTakenInfo &
6774 ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
6775   auto &BTI = getBackedgeTakenInfo(L);
6776   if (BTI.hasFullInfo())
6777     return BTI;
6778 
6779   auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
6780 
6781   if (!Pair.second)
6782     return Pair.first->second;
6783 
6784   BackedgeTakenInfo Result =
6785       computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
6786 
6787   return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
6788 }
6789 
6790 ScalarEvolution::BackedgeTakenInfo &
6791 ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
6792   // Initially insert an invalid entry for this loop. If the insertion
6793   // succeeds, proceed to actually compute a backedge-taken count and
6794   // update the value. The temporary CouldNotCompute value tells SCEV
6795   // code elsewhere that it shouldn't attempt to request a new
6796   // backedge-taken count, which could result in infinite recursion.
6797   std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
6798       BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
6799   if (!Pair.second)
6800     return Pair.first->second;
6801 
6802   // computeBackedgeTakenCount may allocate memory for its result. Inserting it
6803   // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
6804   // must be cleared in this scope.
6805   BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
6806 
6807   // In product builds, statistics are unused; cast to void to avoid warnings.
6808   (void)NumTripCountsComputed;
6809   (void)NumTripCountsNotComputed;
6810 #if LLVM_ENABLE_STATS || !defined(NDEBUG)
6811   const SCEV *BEExact = Result.getExact(L, this);
6812   if (BEExact != getCouldNotCompute()) {
6813     assert(isLoopInvariant(BEExact, L) &&
6814            isLoopInvariant(Result.getConstantMax(this), L) &&
6815            "Computed backedge-taken count isn't loop invariant for loop!");
6816     ++NumTripCountsComputed;
6817   } else if (Result.getConstantMax(this) == getCouldNotCompute() &&
6818              isa<PHINode>(L->getHeader()->begin())) {
6819     // Only count loops that have phi nodes as not being computable.
6820     ++NumTripCountsNotComputed;
6821   }
6822 #endif // LLVM_ENABLE_STATS || !defined(NDEBUG)
6823 
6824   // Now that we know more about the trip count for this loop, forget any
6825   // existing SCEV values for PHI nodes in this loop since they are only
6826   // conservative estimates made without the benefit of trip count
6827   // information. This is similar to the code in forgetLoop, except that
6828   // it handles SCEVUnknown PHI nodes specially.
6829   if (Result.hasAnyInfo()) {
6830     SmallVector<Instruction *, 16> Worklist;
6831     PushLoopPHIs(L, Worklist);
6832 
6833     SmallPtrSet<Instruction *, 8> Discovered;
6834     while (!Worklist.empty()) {
6835       Instruction *I = Worklist.pop_back_val();
6836 
6837       ValueExprMapType::iterator It =
6838         ValueExprMap.find_as(static_cast<Value *>(I));
6839       if (It != ValueExprMap.end()) {
6840         const SCEV *Old = It->second;
6841 
6842         // SCEVUnknown for a PHI either means that it has an unrecognized
6843         // structure, or it's a PHI that's in the process of being computed
6844         // by createNodeForPHI.  In the former case, additional loop trip
6845         // count information isn't going to change anything. In the latter
6846         // case, createNodeForPHI will perform the necessary updates on its
6847         // own when it gets to that point.
6848         if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
6849           eraseValueFromMap(It->first);
6850           forgetMemoizedResults(Old);
6851         }
6852         if (PHINode *PN = dyn_cast<PHINode>(I))
6853           ConstantEvolutionLoopExitValue.erase(PN);
6854       }
6855 
6856       // Since we don't need to invalidate anything for correctness and we're
6857       // only invalidating to make SCEV's results more precise, we get to stop
6858       // early to avoid invalidating too much.  This is especially important in
6859       // cases like:
6860       //
6861       //   %v = f(pn0, pn1) // pn0 and pn1 used through some other phi node
6862       // loop0:
6863       //   %pn0 = phi
6864       //   ...
6865       // loop1:
6866       //   %pn1 = phi
6867       //   ...
6868       //
6869       // where both loop0's and loop1's backedge taken counts use the SCEV
6870       // expression for %v.  If we don't have the early stop below then in cases
6871       // like the above, getBackedgeTakenInfo(loop1) will clear out the trip
6872       // count for loop0 and getBackedgeTakenInfo(loop0) will clear out the trip
6873       // count for loop1, effectively nullifying SCEV's trip count cache.
6874       for (auto *U : I->users())
6875         if (auto *I = dyn_cast<Instruction>(U)) {
6876           auto *LoopForUser = LI.getLoopFor(I->getParent());
6877           if (LoopForUser && L->contains(LoopForUser) &&
6878               Discovered.insert(I).second)
6879             Worklist.push_back(I);
6880         }
6881     }
6882   }
6883 
6884   // Re-lookup the insert position, since the call to
6885   // computeBackedgeTakenCount above could result in a
6886   // recursive call to getBackedgeTakenInfo (on a different
6887   // loop), which would invalidate the iterator computed
6888   // earlier.
6889   return BackedgeTakenCounts.find(L)->second = std::move(Result);
6890 }
6891 
6892 void ScalarEvolution::forgetAllLoops() {
6893   // This method is intended to forget all info about loops. It should
6894   // invalidate caches as if the following happened:
6895   // - The trip counts of all loops have changed arbitrarily
6896   // - Every llvm::Value has been updated in place to produce a different
6897   // result.
6898   BackedgeTakenCounts.clear();
6899   PredicatedBackedgeTakenCounts.clear();
6900   LoopPropertiesCache.clear();
6901   ConstantEvolutionLoopExitValue.clear();
6902   ValueExprMap.clear();
6903   ValuesAtScopes.clear();
6904   LoopDispositions.clear();
6905   BlockDispositions.clear();
6906   UnsignedRanges.clear();
6907   SignedRanges.clear();
6908   ExprValueMap.clear();
6909   HasRecMap.clear();
6910   MinTrailingZerosCache.clear();
6911   PredicatedSCEVRewrites.clear();
6912 }
6913 
6914 void ScalarEvolution::forgetLoop(const Loop *L) {
6915   // Drop any stored trip count value.
6916   auto RemoveLoopFromBackedgeMap =
6917       [](DenseMap<const Loop *, BackedgeTakenInfo> &Map, const Loop *L) {
6918         auto BTCPos = Map.find(L);
6919         if (BTCPos != Map.end()) {
6920           BTCPos->second.clear();
6921           Map.erase(BTCPos);
6922         }
6923       };
6924 
6925   SmallVector<const Loop *, 16> LoopWorklist(1, L);
6926   SmallVector<Instruction *, 32> Worklist;
6927   SmallPtrSet<Instruction *, 16> Visited;
6928 
6929   // Iterate over all the loops and sub-loops to drop SCEV information.
6930   while (!LoopWorklist.empty()) {
6931     auto *CurrL = LoopWorklist.pop_back_val();
6932 
6933     RemoveLoopFromBackedgeMap(BackedgeTakenCounts, CurrL);
6934     RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts, CurrL);
6935 
6936     // Drop information about predicated SCEV rewrites for this loop.
6937     for (auto I = PredicatedSCEVRewrites.begin();
6938          I != PredicatedSCEVRewrites.end();) {
6939       std::pair<const SCEV *, const Loop *> Entry = I->first;
6940       if (Entry.second == CurrL)
6941         PredicatedSCEVRewrites.erase(I++);
6942       else
6943         ++I;
6944     }
6945 
6946     auto LoopUsersItr = LoopUsers.find(CurrL);
6947     if (LoopUsersItr != LoopUsers.end()) {
6948       for (auto *S : LoopUsersItr->second)
6949         forgetMemoizedResults(S);
6950       LoopUsers.erase(LoopUsersItr);
6951     }
6952 
6953     // Drop information about expressions based on loop-header PHIs.
6954     PushLoopPHIs(CurrL, Worklist);
6955 
6956     while (!Worklist.empty()) {
6957       Instruction *I = Worklist.pop_back_val();
6958       if (!Visited.insert(I).second)
6959         continue;
6960 
6961       ValueExprMapType::iterator It =
6962           ValueExprMap.find_as(static_cast<Value *>(I));
6963       if (It != ValueExprMap.end()) {
6964         eraseValueFromMap(It->first);
6965         forgetMemoizedResults(It->second);
6966         if (PHINode *PN = dyn_cast<PHINode>(I))
6967           ConstantEvolutionLoopExitValue.erase(PN);
6968       }
6969 
6970       PushDefUseChildren(I, Worklist);
6971     }
6972 
6973     LoopPropertiesCache.erase(CurrL);
6974     // Forget all contained loops too, to avoid dangling entries in the
6975     // ValuesAtScopes map.
6976     LoopWorklist.append(CurrL->begin(), CurrL->end());
6977   }
6978 }
6979 
6980 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
6981   while (Loop *Parent = L->getParentLoop())
6982     L = Parent;
6983   forgetLoop(L);
6984 }
6985 
6986 void ScalarEvolution::forgetValue(Value *V) {
6987   Instruction *I = dyn_cast<Instruction>(V);
6988   if (!I) return;
6989 
6990   // Drop information about expressions based on loop-header PHIs.
6991   SmallVector<Instruction *, 16> Worklist;
6992   Worklist.push_back(I);
6993 
6994   SmallPtrSet<Instruction *, 8> Visited;
6995   while (!Worklist.empty()) {
6996     I = Worklist.pop_back_val();
6997     if (!Visited.insert(I).second)
6998       continue;
6999 
7000     ValueExprMapType::iterator It =
7001       ValueExprMap.find_as(static_cast<Value *>(I));
7002     if (It != ValueExprMap.end()) {
7003       eraseValueFromMap(It->first);
7004       forgetMemoizedResults(It->second);
7005       if (PHINode *PN = dyn_cast<PHINode>(I))
7006         ConstantEvolutionLoopExitValue.erase(PN);
7007     }
7008 
7009     PushDefUseChildren(I, Worklist);
7010   }
7011 }
7012 
7013 void ScalarEvolution::forgetLoopDispositions(const Loop *L) {
7014   LoopDispositions.clear();
7015 }
7016 
7017 /// Get the exact loop backedge taken count considering all loop exits. A
7018 /// computable result can only be returned for loops with all exiting blocks
7019 /// dominating the latch. howFarToZero assumes that the limit of each loop test
7020 /// is never skipped. This is a valid assumption as long as the loop exits via
7021 /// that test. For precise results, it is the caller's responsibility to specify
7022 /// the relevant loop exiting block using getExact(ExitingBlock, SE).
7023 const SCEV *
7024 ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
7025                                              SCEVUnionPredicate *Preds) const {
7026   // If any exits were not computable, the loop is not computable.
7027   if (!isComplete() || ExitNotTaken.empty())
7028     return SE->getCouldNotCompute();
7029 
7030   const BasicBlock *Latch = L->getLoopLatch();
7031   // All exiting blocks we have collected must dominate the only backedge.
7032   if (!Latch)
7033     return SE->getCouldNotCompute();
7034 
7035   // All exiting blocks we have gathered dominate the loop's latch, so the exact
7036   // trip count is simply the minimum of all these calculated exit counts.
7037   SmallVector<const SCEV *, 2> Ops;
7038   for (auto &ENT : ExitNotTaken) {
7039     const SCEV *BECount = ENT.ExactNotTaken;
7040     assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
7041     assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
7042            "We should only have known counts for exiting blocks that dominate "
7043            "latch!");
7044 
7045     Ops.push_back(BECount);
7046 
7047     if (Preds && !ENT.hasAlwaysTruePredicate())
7048       Preds->add(ENT.Predicate.get());
7049 
7050     assert((Preds || ENT.hasAlwaysTruePredicate()) &&
7051            "Predicate should be always true!");
7052   }
7053 
7054   return SE->getUMinFromMismatchedTypes(Ops);
7055 }
7056 
7057 /// Get the exact not taken count for this loop exit.
7058 const SCEV *
7059 ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
7060                                              ScalarEvolution *SE) const {
7061   for (auto &ENT : ExitNotTaken)
7062     if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
7063       return ENT.ExactNotTaken;
7064 
7065   return SE->getCouldNotCompute();
7066 }
7067 
7068 const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
7069     const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
7070   for (auto &ENT : ExitNotTaken)
7071     if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
7072       return ENT.MaxNotTaken;
7073 
7074   return SE->getCouldNotCompute();
7075 }
7076 
7077 /// getConstantMax - Get the constant max backedge taken count for the loop.
7078 const SCEV *
7079 ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
7080   auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
7081     return !ENT.hasAlwaysTruePredicate();
7082   };
7083 
7084   if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getConstantMax())
7085     return SE->getCouldNotCompute();
7086 
7087   assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
7088           isa<SCEVConstant>(getConstantMax())) &&
7089          "No point in having a non-constant max backedge taken count!");
7090   return getConstantMax();
7091 }
7092 
7093 const SCEV *
7094 ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
7095                                                    ScalarEvolution *SE) {
7096   if (!SymbolicMax)
7097     SymbolicMax = SE->computeSymbolicMaxBackedgeTakenCount(L);
7098   return SymbolicMax;
7099 }
7100 
7101 bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
7102     ScalarEvolution *SE) const {
7103   auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
7104     return !ENT.hasAlwaysTruePredicate();
7105   };
7106   return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
7107 }
7108 
7109 bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
7110                                                     ScalarEvolution *SE) const {
7111   if (getConstantMax() && getConstantMax() != SE->getCouldNotCompute() &&
7112       SE->hasOperand(getConstantMax(), S))
7113     return true;
7114 
7115   for (auto &ENT : ExitNotTaken)
7116     if (ENT.ExactNotTaken != SE->getCouldNotCompute() &&
7117         SE->hasOperand(ENT.ExactNotTaken, S))
7118       return true;
7119 
7120   return false;
7121 }
7122 
7123 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
7124     : ExactNotTaken(E), MaxNotTaken(E) {
7125   assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
7126           isa<SCEVConstant>(MaxNotTaken)) &&
7127          "No point in having a non-constant max backedge taken count!");
7128 }
7129 
7130 ScalarEvolution::ExitLimit::ExitLimit(
7131     const SCEV *E, const SCEV *M, bool MaxOrZero,
7132     ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
7133     : ExactNotTaken(E), MaxNotTaken(M), MaxOrZero(MaxOrZero) {
7134   assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
7135           !isa<SCEVCouldNotCompute>(MaxNotTaken)) &&
7136          "Exact is not allowed to be less precise than Max");
7137   assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
7138           isa<SCEVConstant>(MaxNotTaken)) &&
7139          "No point in having a non-constant max backedge taken count!");
7140   for (auto *PredSet : PredSetList)
7141     for (auto *P : *PredSet)
7142       addPredicate(P);
7143 }
7144 
7145 ScalarEvolution::ExitLimit::ExitLimit(
7146     const SCEV *E, const SCEV *M, bool MaxOrZero,
7147     const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
7148     : ExitLimit(E, M, MaxOrZero, {&PredSet}) {
7149   assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
7150           isa<SCEVConstant>(MaxNotTaken)) &&
7151          "No point in having a non-constant max backedge taken count!");
7152 }
7153 
7154 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M,
7155                                       bool MaxOrZero)
7156     : ExitLimit(E, M, MaxOrZero, None) {
7157   assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
7158           isa<SCEVConstant>(MaxNotTaken)) &&
7159          "No point in having a non-constant max backedge taken count!");
7160 }
7161 
7162 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
7163 /// computable exit into a persistent ExitNotTakenInfo array.
7164 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
7165     ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts,
7166     bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
7167     : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
7168   using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
7169 
7170   ExitNotTaken.reserve(ExitCounts.size());
7171   std::transform(
7172       ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken),
7173       [&](const EdgeExitInfo &EEI) {
7174         BasicBlock *ExitBB = EEI.first;
7175         const ExitLimit &EL = EEI.second;
7176         if (EL.Predicates.empty())
7177           return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, EL.MaxNotTaken,
7178                                   nullptr);
7179 
7180         std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate);
7181         for (auto *Pred : EL.Predicates)
7182           Predicate->add(Pred);
7183 
7184         return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, EL.MaxNotTaken,
7185                                 std::move(Predicate));
7186       });
7187   assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
7188           isa<SCEVConstant>(ConstantMax)) &&
7189          "No point in having a non-constant max backedge taken count!");
7190 }
7191 
7192 /// Invalidate this result and free the ExitNotTakenInfo array.
7193 void ScalarEvolution::BackedgeTakenInfo::clear() {
7194   ExitNotTaken.clear();
7195 }
7196 
7197 /// Compute the number of times the backedge of the specified loop will execute.
7198 ScalarEvolution::BackedgeTakenInfo
7199 ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
7200                                            bool AllowPredicates) {
7201   SmallVector<BasicBlock *, 8> ExitingBlocks;
7202   L->getExitingBlocks(ExitingBlocks);
7203 
7204   using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
7205 
7206   SmallVector<EdgeExitInfo, 4> ExitCounts;
7207   bool CouldComputeBECount = true;
7208   BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
7209   const SCEV *MustExitMaxBECount = nullptr;
7210   const SCEV *MayExitMaxBECount = nullptr;
7211   bool MustExitMaxOrZero = false;
7212 
7213   // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
7214   // and compute maxBECount.
7215   // Do a union of all the predicates here.
7216   for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
7217     BasicBlock *ExitBB = ExitingBlocks[i];
7218 
7219     // We canonicalize untaken exits to br (constant); ignore them so that
7220     // proving an exit untaken doesn't negatively impact our ability to reason
7221     // about the loop as a whole.
7222     if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
7223       if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
7224         bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
7225         if ((ExitIfTrue && CI->isZero()) || (!ExitIfTrue && CI->isOne()))
7226           continue;
7227       }
7228 
7229     ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
7230 
7231     assert((AllowPredicates || EL.Predicates.empty()) &&
7232            "Predicated exit limit when predicates are not allowed!");
7233 
7234     // 1. For each exit that can be computed, add an entry to ExitCounts.
7235     // CouldComputeBECount is true only if all exits can be computed.
7236     if (EL.ExactNotTaken == getCouldNotCompute())
7237       // We couldn't compute an exact value for this exit, so
7238       // we won't be able to compute an exact value for the loop.
7239       CouldComputeBECount = false;
7240     else
7241       ExitCounts.emplace_back(ExitBB, EL);
7242 
7243     // 2. Derive the loop's MaxBECount from each exit's max number of
7244     // non-exiting iterations. Partition the loop exits into two kinds:
7245     // LoopMustExits and LoopMayExits.
7246     //
7247     // If the exit dominates the loop latch, it is a LoopMustExit; otherwise it
7248     // is a LoopMayExit.  If any computable LoopMustExit is found, then
7249     // MaxBECount is the minimum EL.MaxNotTaken of computable
7250     // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
7251     // EL.MaxNotTaken, where CouldNotCompute is considered greater than any
7252     // computable EL.MaxNotTaken.
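    // For example, two must-exits with max counts 10 and 20 give MaxBECount ==
    // 10, whereas two may-exits with those counts give MaxBECount == 20.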
7253     if (EL.MaxNotTaken != getCouldNotCompute() && Latch &&
7254         DT.dominates(ExitBB, Latch)) {
7255       if (!MustExitMaxBECount) {
7256         MustExitMaxBECount = EL.MaxNotTaken;
7257         MustExitMaxOrZero = EL.MaxOrZero;
7258       } else {
7259         MustExitMaxBECount =
7260             getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken);
7261       }
7262     } else if (MayExitMaxBECount != getCouldNotCompute()) {
7263       if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute())
7264         MayExitMaxBECount = EL.MaxNotTaken;
7265       else {
7266         MayExitMaxBECount =
7267             getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken);
7268       }
7269     }
7270   }
7271   const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
7272     (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
7273   // The loop backedge will be taken the maximum or zero times if there's
7274   // a single exit that must be taken the maximum or zero times.
7275   bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
7276   return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
7277                            MaxBECount, MaxOrZero);
7278 }
7279 
7280 ScalarEvolution::ExitLimit
7281 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
7282                                       bool AllowPredicates) {
7283   assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
7284   // If our exiting block does not dominate the latch, then its connection with
7285   // loop's exit limit may be far from trivial.
7286   const BasicBlock *Latch = L->getLoopLatch();
7287   if (!Latch || !DT.dominates(ExitingBlock, Latch))
7288     return getCouldNotCompute();
7289 
7290   bool IsOnlyExit = (L->getExitingBlock() != nullptr);
7291   Instruction *Term = ExitingBlock->getTerminator();
7292   if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
7293     assert(BI->isConditional() && "If unconditional, it can't be in loop!");
7294     bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
7295     assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
7296            "It should have one successor in loop and one exit block!");
7297     // Proceed to the next level to examine the exit condition expression.
7298     return computeExitLimitFromCond(
7299         L, BI->getCondition(), ExitIfTrue,
7300         /*ControlsExit=*/IsOnlyExit, AllowPredicates);
7301   }
7302 
7303   if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
7304     // For switch, make sure that there is a single exit from the loop.
7305     BasicBlock *Exit = nullptr;
7306     for (auto *SBB : successors(ExitingBlock))
7307       if (!L->contains(SBB)) {
7308         if (Exit) // Multiple exit successors.
7309           return getCouldNotCompute();
7310         Exit = SBB;
7311       }
7312     assert(Exit && "Exiting block must have at least one exit");
7313     return computeExitLimitFromSingleExitSwitch(L, SI, Exit,
7314                                                 /*ControlsExit=*/IsOnlyExit);
7315   }
7316 
7317   return getCouldNotCompute();
7318 }
7319 
7320 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
7321     const Loop *L, Value *ExitCond, bool ExitIfTrue,
7322     bool ControlsExit, bool AllowPredicates) {
7323   ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
7324   return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
7325                                         ControlsExit, AllowPredicates);
7326 }
7327 
7328 Optional<ScalarEvolution::ExitLimit>
7329 ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
7330                                       bool ExitIfTrue, bool ControlsExit,
7331                                       bool AllowPredicates) {
7332   (void)this->L;
7333   (void)this->ExitIfTrue;
7334   (void)this->AllowPredicates;
7335 
7336   assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
7337          this->AllowPredicates == AllowPredicates &&
7338          "Variance in assumed invariant key components!");
7339   auto Itr = TripCountMap.find({ExitCond, ControlsExit});
7340   if (Itr == TripCountMap.end())
7341     return None;
7342   return Itr->second;
7343 }
7344 
7345 void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
7346                                              bool ExitIfTrue,
7347                                              bool ControlsExit,
7348                                              bool AllowPredicates,
7349                                              const ExitLimit &EL) {
7350   assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
7351          this->AllowPredicates == AllowPredicates &&
7352          "Variance in assumed invariant key components!");
7353 
7354   auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL});
7355   assert(InsertResult.second && "Expected successful insertion!");
7356   (void)InsertResult;
7357   (void)ExitIfTrue;
7358 }
7359 
7360 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
7361     ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
7362     bool ControlsExit, bool AllowPredicates) {
7363 
7364   if (auto MaybeEL =
7365           Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
7366     return *MaybeEL;
7367 
7368   ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue,
7369                                               ControlsExit, AllowPredicates);
7370   Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL);
7371   return EL;
7372 }
7373 
7374 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
7375     ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
7376     bool ControlsExit, bool AllowPredicates) {
7377   // Check if the controlling expression for this loop is an And or Or.
7378   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
7379     if (BO->getOpcode() == Instruction::And) {
7380       // Recurse on the operands of the and.
7381       bool EitherMayExit = !ExitIfTrue;
7382       ExitLimit EL0 = computeExitLimitFromCondCached(
7383           Cache, L, BO->getOperand(0), ExitIfTrue,
7384           ControlsExit && !EitherMayExit, AllowPredicates);
7385       ExitLimit EL1 = computeExitLimitFromCondCached(
7386           Cache, L, BO->getOperand(1), ExitIfTrue,
7387           ControlsExit && !EitherMayExit, AllowPredicates);
7388       // Be robust against unsimplified IR for the form "and i1 X, true"
7389       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1)))
7390         return CI->isOne() ? EL0 : EL1;
7391       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(0)))
7392         return CI->isOne() ? EL1 : EL0;
7393       const SCEV *BECount = getCouldNotCompute();
7394       const SCEV *MaxBECount = getCouldNotCompute();
7395       if (EitherMayExit) {
7396         // Both conditions must be true for the loop to continue executing.
7397         // Choose the less conservative count.
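        // For example, for "while (A && B)" the loop exits as soon as either
        // condition fails, so the exact count is the umin of the two operand
        // counts.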
7398         if (EL0.ExactNotTaken == getCouldNotCompute() ||
7399             EL1.ExactNotTaken == getCouldNotCompute())
7400           BECount = getCouldNotCompute();
7401         else
7402           BECount =
7403               getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
7404         if (EL0.MaxNotTaken == getCouldNotCompute())
7405           MaxBECount = EL1.MaxNotTaken;
7406         else if (EL1.MaxNotTaken == getCouldNotCompute())
7407           MaxBECount = EL0.MaxNotTaken;
7408         else
7409           MaxBECount =
7410               getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
7411       } else {
7412         // Both conditions must be true at the same time for the loop to exit.
7413         // For now, be conservative.
7414         if (EL0.MaxNotTaken == EL1.MaxNotTaken)
7415           MaxBECount = EL0.MaxNotTaken;
7416         if (EL0.ExactNotTaken == EL1.ExactNotTaken)
7417           BECount = EL0.ExactNotTaken;
7418       }
7419 
7420       // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
7421       // to be more aggressive when computing BECount than when computing
7422       // MaxBECount.  In these cases it is possible for EL0.ExactNotTaken and
7423       // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken
7424       // to not.
7425       if (isa<SCEVCouldNotCompute>(MaxBECount) &&
7426           !isa<SCEVCouldNotCompute>(BECount))
7427         MaxBECount = getConstant(getUnsignedRangeMax(BECount));
7428 
7429       return ExitLimit(BECount, MaxBECount, false,
7430                        {&EL0.Predicates, &EL1.Predicates});
7431     }
7432     if (BO->getOpcode() == Instruction::Or) {
7433       // Recurse on the operands of the or.
7434       bool EitherMayExit = ExitIfTrue;
7435       ExitLimit EL0 = computeExitLimitFromCondCached(
7436           Cache, L, BO->getOperand(0), ExitIfTrue,
7437           ControlsExit && !EitherMayExit, AllowPredicates);
7438       ExitLimit EL1 = computeExitLimitFromCondCached(
7439           Cache, L, BO->getOperand(1), ExitIfTrue,
7440           ControlsExit && !EitherMayExit, AllowPredicates);
7441       // Be robust against unsimplified IR for the form "or i1 X, true"
7442       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1)))
7443         return CI->isZero() ? EL0 : EL1;
7444       if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(0)))
7445         return CI->isZero() ? EL1 : EL0;
7446       const SCEV *BECount = getCouldNotCompute();
7447       const SCEV *MaxBECount = getCouldNotCompute();
7448       if (EitherMayExit) {
7449         // Both conditions must be false for the loop to continue executing.
7450         // Choose the less conservative count.
7451         if (EL0.ExactNotTaken == getCouldNotCompute() ||
7452             EL1.ExactNotTaken == getCouldNotCompute())
7453           BECount = getCouldNotCompute();
7454         else
7455           BECount =
7456               getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
7457         if (EL0.MaxNotTaken == getCouldNotCompute())
7458           MaxBECount = EL1.MaxNotTaken;
7459         else if (EL1.MaxNotTaken == getCouldNotCompute())
7460           MaxBECount = EL0.MaxNotTaken;
7461         else
7462           MaxBECount =
7463               getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
7464       } else {
7465         // Both conditions must be false at the same time for the loop to exit.
7466         // For now, be conservative.
7467         if (EL0.MaxNotTaken == EL1.MaxNotTaken)
7468           MaxBECount = EL0.MaxNotTaken;
7469         if (EL0.ExactNotTaken == EL1.ExactNotTaken)
7470           BECount = EL0.ExactNotTaken;
7471       }
7472       // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
7473       // to be more aggressive when computing BECount than when computing
7474       // MaxBECount.  In these cases it is possible for EL0.ExactNotTaken and
7475       // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken
7476       // to not.
7477       if (isa<SCEVCouldNotCompute>(MaxBECount) &&
7478           !isa<SCEVCouldNotCompute>(BECount))
7479         MaxBECount = getConstant(getUnsignedRangeMax(BECount));
7480 
7481       return ExitLimit(BECount, MaxBECount, false,
7482                        {&EL0.Predicates, &EL1.Predicates});
7483     }
7484   }
7485 
7486   // With an icmp, it may be feasible to compute an exact backedge-taken count.
7487   // Proceed to the next level to examine the icmp.
7488   if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
7489     ExitLimit EL =
7490         computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit);
7491     if (EL.hasFullInfo() || !AllowPredicates)
7492       return EL;
7493 
7494     // Try again, but use SCEV predicates this time.
7495     return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit,
7496                                     /*AllowPredicates=*/true);
7497   }
7498 
7499   // Check for a constant condition. These are normally stripped out by
7500   // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
7501   // preserve the CFG and is temporarily leaving constant conditions
7502   // in place.
7503   if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
7504     if (ExitIfTrue == !CI->getZExtValue())
7505       // The backedge is always taken.
7506       return getCouldNotCompute();
7507     else
7508       // The backedge is never taken.
7509       return getZero(CI->getType());
7510   }
7511 
7512   // If it's not an integer or pointer comparison then compute it the hard way.
7513   return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
7514 }
7515 
7516 ScalarEvolution::ExitLimit
7517 ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
7518                                           ICmpInst *ExitCond,
7519                                           bool ExitIfTrue,
7520                                           bool ControlsExit,
7521                                           bool AllowPredicates) {
7522   // If the condition was exit on true, convert the condition to exit on false
7523   ICmpInst::Predicate Pred;
7524   if (!ExitIfTrue)
7525     Pred = ExitCond->getPredicate();
7526   else
7527     Pred = ExitCond->getInversePredicate();
7528   const ICmpInst::Predicate OriginalPred = Pred;
7529 
7530   // Handle common loops like: for (X = "string"; *X; ++X)
7531   if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
7532     if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
7533       ExitLimit ItCnt =
7534         computeLoadConstantCompareExitLimit(LI, RHS, L, Pred);
7535       if (ItCnt.hasAnyInfo())
7536         return ItCnt;
7537     }
7538 
7539   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
7540   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
7541 
7542   // Try to evaluate any dependencies out of the loop.
7543   LHS = getSCEVAtScope(LHS, L);
7544   RHS = getSCEVAtScope(RHS, L);
7545 
7546   // At this point, we would like to compute how many iterations of the
7547   // loop the predicate will return true for these inputs.
7548   if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
7549     // If there is a loop-invariant, force it into the RHS.
7550     std::swap(LHS, RHS);
7551     Pred = ICmpInst::getSwappedPredicate(Pred);
7552   }
7553 
7554   // Simplify the operands before analyzing them.
7555   (void)SimplifyICmpOperands(Pred, LHS, RHS);
7556 
7557   // If we have a comparison of a chrec against a constant, try to use value
7558   // ranges to answer this query.
7559   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
7560     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
7561       if (AddRec->getLoop() == L) {
7562         // Form the constant range.
7563         ConstantRange CompRange =
7564             ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
7565 
7566         const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
7567         if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
7568       }
7569 
7570   switch (Pred) {
7571   case ICmpInst::ICMP_NE: {                     // while (X != Y)
7572     // Convert to: while (X-Y != 0)
7573     ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit,
7574                                 AllowPredicates);
7575     if (EL.hasAnyInfo()) return EL;
7576     break;
7577   }
7578   case ICmpInst::ICMP_EQ: {                     // while (X == Y)
7579     // Convert to: while (X-Y == 0)
7580     ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
7581     if (EL.hasAnyInfo()) return EL;
7582     break;
7583   }
7584   case ICmpInst::ICMP_SLT:
7585   case ICmpInst::ICMP_ULT: {                    // while (X < Y)
7586     bool IsSigned = Pred == ICmpInst::ICMP_SLT;
7587     ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit,
7588                                     AllowPredicates);
7589     if (EL.hasAnyInfo()) return EL;
7590     break;
7591   }
7592   case ICmpInst::ICMP_SGT:
7593   case ICmpInst::ICMP_UGT: {                    // while (X > Y)
7594     bool IsSigned = Pred == ICmpInst::ICMP_SGT;
7595     ExitLimit EL =
7596         howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit,
7597                             AllowPredicates);
7598     if (EL.hasAnyInfo()) return EL;
7599     break;
7600   }
7601   default:
7602     break;
7603   }
7604 
7605   auto *ExhaustiveCount =
7606       computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
7607 
7608   if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
7609     return ExhaustiveCount;
7610 
7611   return computeShiftCompareExitLimit(ExitCond->getOperand(0),
7612                                       ExitCond->getOperand(1), L, OriginalPred);
7613 }
7614 
7615 ScalarEvolution::ExitLimit
7616 ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
7617                                                       SwitchInst *Switch,
7618                                                       BasicBlock *ExitingBlock,
7619                                                       bool ControlsExit) {
7620   assert(!L->contains(ExitingBlock) && "Not an exiting block!");
7621 
7622   // Give up if the exit is the default dest of a switch.
7623   if (Switch->getDefaultDest() == ExitingBlock)
7624     return getCouldNotCompute();
7625 
7626   assert(L->contains(Switch->getDefaultDest()) &&
7627          "Default case must not exit the loop!");
7628   const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
7629   const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
7630 
7631   // while (X != Y) --> while (X-Y != 0)
7632   ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
7633   if (EL.hasAnyInfo())
7634     return EL;
7635 
7636   return getCouldNotCompute();
7637 }
7638 
7639 static ConstantInt *
7640 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
7641                                 ScalarEvolution &SE) {
7642   const SCEV *InVal = SE.getConstant(C);
7643   const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
7644   assert(isa<SCEVConstant>(Val) &&
7645          "Evaluation of SCEV at constant didn't fold correctly?");
7646   return cast<SCEVConstant>(Val)->getValue();
7647 }
7648 
7649 /// Given an exit condition of 'icmp op load X, cst', try to see if we can
7650 /// compute the backedge execution count.
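     /// For illustration (an editorial sketch of the idiom recognized here, with
     /// hypothetical names): a loop such as
     ///   for (i = 0; Table[i] != 0; ++i) ...
     /// over a constant global array produces an exit test of the form
     /// 'icmp ne (load (gep @Table, 0, {0,+,1})), 0'; the code below folds the
     /// load for successive iteration numbers and returns the first iteration
     /// for which the compare folds to false.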
7651 ScalarEvolution::ExitLimit
7652 ScalarEvolution::computeLoadConstantCompareExitLimit(
7653   LoadInst *LI,
7654   Constant *RHS,
7655   const Loop *L,
7656   ICmpInst::Predicate predicate) {
7657   if (LI->isVolatile()) return getCouldNotCompute();
7658 
7659   // Check to see if the loaded pointer is a getelementptr of a global.
7660   // TODO: Use SCEV instead of manually grubbing with GEPs.
7661   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
7662   if (!GEP) return getCouldNotCompute();
7663 
7664   // Make sure that it is really a constant global we are gepping, with an
7665   // initializer, and make sure the first IDX is really 0.
7666   GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
7667   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() ||
7668       GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
7669       !cast<Constant>(GEP->getOperand(1))->isNullValue())
7670     return getCouldNotCompute();
7671 
7672   // Okay, we allow one non-constant index into the GEP instruction.
7673   Value *VarIdx = nullptr;
7674   std::vector<Constant*> Indexes;
7675   unsigned VarIdxNum = 0;
7676   for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
7677     if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
7678       Indexes.push_back(CI);
7679     } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
7680       if (VarIdx) return getCouldNotCompute();  // Multiple non-constant idx's.
7681       VarIdx = GEP->getOperand(i);
7682       VarIdxNum = i-2;
7683       Indexes.push_back(nullptr);
7684     }
7685 
7686   // Loop-invariant loads may be a byproduct of loop optimization. Skip them.
7687   if (!VarIdx)
7688     return getCouldNotCompute();
7689 
7690   // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
7691   // Check to see if X is a loop variant variable value now.
7692   const SCEV *Idx = getSCEV(VarIdx);
7693   Idx = getSCEVAtScope(Idx, L);
7694 
7695   // We can only recognize very limited forms of loop index expressions, in
7696   // particular, only affine AddRec's like {C1,+,C2}.
7697   const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
7698   if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) ||
7699       !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
7700       !isa<SCEVConstant>(IdxExpr->getOperand(1)))
7701     return getCouldNotCompute();
7702 
7703   unsigned MaxSteps = MaxBruteForceIterations;
7704   for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
7705     ConstantInt *ItCst = ConstantInt::get(
7706                            cast<IntegerType>(IdxExpr->getType()), IterationNum);
7707     ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
7708 
7709     // Form the GEP offset.
7710     Indexes[VarIdxNum] = Val;
7711 
7712     Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(),
7713                                                          Indexes);
7714     if (!Result) break;  // Cannot compute!
7715 
7716     // Evaluate the condition for this iteration.
7717     Result = ConstantExpr::getICmp(predicate, Result, RHS);
7718     if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
7719     if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
7720       ++NumArrayLenItCounts;
7721       return getConstant(ItCst);   // Found terminating iteration!
7722     }
7723   }
7724   return getCouldNotCompute();
7725 }
7726 
7727 ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
7728     Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
7729   ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
7730   if (!RHS)
7731     return getCouldNotCompute();
7732 
7733   const BasicBlock *Latch = L->getLoopLatch();
7734   if (!Latch)
7735     return getCouldNotCompute();
7736 
7737   const BasicBlock *Predecessor = L->getLoopPredecessor();
7738   if (!Predecessor)
7739     return getCouldNotCompute();
7740 
7741   // Return true if V is of the form "LHS `shift_op` <positive constant>".
7742   // Return LHS in OutLHS and shift_op in OutOpCode.
7743   auto MatchPositiveShift =
7744       [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
7745 
7746     using namespace PatternMatch;
7747 
7748     ConstantInt *ShiftAmt;
7749     if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
7750       OutOpCode = Instruction::LShr;
7751     else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
7752       OutOpCode = Instruction::AShr;
7753     else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
7754       OutOpCode = Instruction::Shl;
7755     else
7756       return false;
7757 
7758     return ShiftAmt->getValue().isStrictlyPositive();
7759   };
7760 
7761   // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
7762   //
7763   // loop:
7764   //   %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
7765   //   %iv.shifted = lshr i32 %iv, <positive constant>
7766   //
7767   // Return true on a successful match.  Return the corresponding PHI node (%iv
7768   // above) in PNOut and the opcode of the shift operation in OpCodeOut.
7769   auto MatchShiftRecurrence =
7770       [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
7771     Optional<Instruction::BinaryOps> PostShiftOpCode;
7772 
7773     {
7774       Instruction::BinaryOps OpC;
7775       Value *V;
7776 
7777       // If we encounter a shift instruction, "peel off" the shift operation,
7778       // and remember that we did so.  Later when we inspect %iv's backedge
7779       // value, we will make sure that the backedge value uses the same
7780       // operation.
7781       //
7782       // Note: the peeled shift operation does not have to be the same
7783       // instruction as the one feeding into the PHI's backedge value.  We only
7784       // really care about it being the same *kind* of shift instruction --
7785       // that's all that is required for our later inferences to hold.
7786       if (MatchPositiveShift(LHS, V, OpC)) {
7787         PostShiftOpCode = OpC;
7788         LHS = V;
7789       }
7790     }
7791 
7792     PNOut = dyn_cast<PHINode>(LHS);
7793     if (!PNOut || PNOut->getParent() != L->getHeader())
7794       return false;
7795 
7796     Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
7797     Value *OpLHS;
7798 
7799     return
7800         // The backedge value for the PHI node must be a shift by a positive
7801         // amount
7802         MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
7803 
7804         // of the PHI node itself
7805         OpLHS == PNOut &&
7806 
7807         // and the kind of shift should match the kind of shift we peeled
7808         // off, if any.
7809         (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut);
7810   };
7811 
7812   PHINode *PN;
7813   Instruction::BinaryOps OpCode;
7814   if (!MatchShiftRecurrence(LHS, PN, OpCode))
7815     return getCouldNotCompute();
7816 
7817   const DataLayout &DL = getDataLayout();
7818 
7819   // The key rationale for this optimization is that for some kinds of shift
7820   // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
7821   // within a finite number of iterations.  If the condition guarding the
7822   // backedge (in the sense that the backedge is taken if the condition is true)
7823   // is false for the value the shift recurrence stabilizes to, then we know
7824   // that the backedge is taken only a finite number of times.
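     //
     // Illustrative example (editorial, not exhaustive): for the recurrence
     //   %iv = phi i32 [ %start, %preheader ], [ %iv.shifted, %loop ]
     //   %iv.shifted = lshr i32 %iv, 1
     // the value of %iv stabilizes to 0 within at most 32 iterations.  If the
     // backedge is guarded by "%iv != 0" (taken while true), the folded compare
     // of the stable value 0 against the RHS 0 is false, so the backedge can be
     // taken at most bitwidth(i32) = 32 times; that bound is what we return as
     // the max backedge-taken count below.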
7825 
7826   ConstantInt *StableValue = nullptr;
7827   switch (OpCode) {
7828   default:
7829     llvm_unreachable("Impossible case!");
7830 
7831   case Instruction::AShr: {
7832     // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
7833     // bitwidth(K) iterations.
7834     Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
7835     KnownBits Known = computeKnownBits(FirstValue, DL, 0, nullptr,
7836                                        Predecessor->getTerminator(), &DT);
7837     auto *Ty = cast<IntegerType>(RHS->getType());
7838     if (Known.isNonNegative())
7839       StableValue = ConstantInt::get(Ty, 0);
7840     else if (Known.isNegative())
7841       StableValue = ConstantInt::get(Ty, -1, true);
7842     else
7843       return getCouldNotCompute();
7844 
7845     break;
7846   }
7847   case Instruction::LShr:
7848   case Instruction::Shl:
7849     // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
7850     // stabilize to 0 in at most bitwidth(K) iterations.
7851     StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
7852     break;
7853   }
7854 
7855   auto *Result =
7856       ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
7857   assert(Result->getType()->isIntegerTy(1) &&
7858          "Otherwise cannot be an operand to a branch instruction");
7859 
7860   if (Result->isZeroValue()) {
7861     unsigned BitWidth = getTypeSizeInBits(RHS->getType());
7862     const SCEV *UpperBound =
7863         getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
7864     return ExitLimit(getCouldNotCompute(), UpperBound, false);
7865   }
7866 
7867   return getCouldNotCompute();
7868 }
7869 
7870 /// Return true if we can constant fold an instruction of the specified type,
7871 /// assuming that all operands were constants.
7872 static bool CanConstantFold(const Instruction *I) {
7873   if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
7874       isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
7875       isa<LoadInst>(I) || isa<ExtractValueInst>(I))
7876     return true;
7877 
7878   if (const CallInst *CI = dyn_cast<CallInst>(I))
7879     if (const Function *F = CI->getCalledFunction())
7880       return canConstantFoldCallTo(CI, F);
7881   return false;
7882 }
7883 
7884 /// Determine whether this instruction can constant evolve within this loop
7885 /// assuming its operands can all constant evolve.
7886 static bool canConstantEvolve(Instruction *I, const Loop *L) {
7887   // An instruction outside of the loop can't be derived from a loop PHI.
7888   if (!L->contains(I)) return false;
7889 
7890   if (isa<PHINode>(I)) {
7891     // We don't currently keep track of the control flow needed to evaluate
7892     // PHIs, so we cannot handle PHIs inside of loops.
7893     return L->getHeader() == I->getParent();
7894   }
7895 
7896   // If we won't be able to constant fold this expression even if the operands
7897   // are constants, bail early.
7898   return CanConstantFold(I);
7899 }
7900 
7901 /// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
7902 /// recursing through each instruction operand until reaching a loop header phi.
7903 static PHINode *
7904 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
7905                                DenseMap<Instruction *, PHINode *> &PHIMap,
7906                                unsigned Depth) {
7907   if (Depth > MaxConstantEvolvingDepth)
7908     return nullptr;
7909 
7910   // Otherwise, we can evaluate this instruction if all of its operands are
7911   // constant or derived from a PHI node themselves.
7912   PHINode *PHI = nullptr;
7913   for (Value *Op : UseInst->operands()) {
7914     if (isa<Constant>(Op)) continue;
7915 
7916     Instruction *OpInst = dyn_cast<Instruction>(Op);
7917     if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
7918 
7919     PHINode *P = dyn_cast<PHINode>(OpInst);
7920     if (!P)
7921       // If this operand is already visited, reuse the prior result.
7922       // We may have P != PHI if this is the deepest point at which the
7923       // inconsistent paths meet.
7924       P = PHIMap.lookup(OpInst);
7925     if (!P) {
7926       // Recurse and memoize the results, whether a phi is found or not.
7927       // This recursive call invalidates pointers into PHIMap.
7928       P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
7929       PHIMap[OpInst] = P;
7930     }
7931     if (!P)
7932       return nullptr;  // Not evolving from PHI
7933     if (PHI && PHI != P)
7934       return nullptr;  // Evolving from multiple different PHIs.
7935     PHI = P;
7936   }
7937   // This is an expression evolving from a constant PHI!
7938   return PHI;
7939 }
7940 
7941 /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
7942 /// in the loop that V is derived from.  We allow arbitrary operations along the
7943 /// way, but the operands of an operation must either be constants or a value
7944 /// derived from a constant PHI.  If this expression does not fit with these
7945 /// constraints, return null.
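     ///
     /// For illustration (editorial example, hypothetical IR names): in
     ///   %iv   = phi i32 [ 0, %entry ], [ %next, %loop ]
     ///   %next = add i32 %iv, 3
     /// both %iv and %next evolve from the single header PHI %iv, so this returns
     /// %iv for either value; an expression mixing two different header PHIs (or
     /// one that cannot be constant folded) returns null.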
7946 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
7947   Instruction *I = dyn_cast<Instruction>(V);
7948   if (!I || !canConstantEvolve(I, L)) return nullptr;
7949 
7950   if (PHINode *PN = dyn_cast<PHINode>(I))
7951     return PN;
7952 
7953   // Record non-constant instructions contained by the loop.
7954   DenseMap<Instruction *, PHINode *> PHIMap;
7955   return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
7956 }
7957 
7958 /// EvaluateExpression - Given an expression that passes the
7959 /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
7960 /// in the loop has the value PHIVal.  If we can't fold this expression for some
7961 /// reason, return null.
7962 static Constant *EvaluateExpression(Value *V, const Loop *L,
7963                                     DenseMap<Instruction *, Constant *> &Vals,
7964                                     const DataLayout &DL,
7965                                     const TargetLibraryInfo *TLI) {
7966   // Convenient constant check, but redundant for recursive calls.
7967   if (Constant *C = dyn_cast<Constant>(V)) return C;
7968   Instruction *I = dyn_cast<Instruction>(V);
7969   if (!I) return nullptr;
7970 
7971   if (Constant *C = Vals.lookup(I)) return C;
7972 
7973   // An instruction inside the loop depends on a value outside the loop that we
7974   // weren't given a mapping for, or a value such as a call inside the loop.
7975   if (!canConstantEvolve(I, L)) return nullptr;
7976 
7977   // An unmapped PHI can be due to a branch or another loop inside this loop,
7978   // or due to this not being the initial iteration through a loop where we
7979   // couldn't compute the evolution of this particular PHI last time.
7980   if (isa<PHINode>(I)) return nullptr;
7981 
7982   std::vector<Constant*> Operands(I->getNumOperands());
7983 
7984   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
7985     Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
7986     if (!Operand) {
7987       Operands[i] = dyn_cast<Constant>(I->getOperand(i));
7988       if (!Operands[i]) return nullptr;
7989       continue;
7990     }
7991     Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
7992     Vals[Operand] = C;
7993     if (!C) return nullptr;
7994     Operands[i] = C;
7995   }
7996 
7997   if (CmpInst *CI = dyn_cast<CmpInst>(I))
7998     return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
7999                                            Operands[1], DL, TLI);
8000   if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
8001     if (!LI->isVolatile())
8002       return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL);
8003   }
8004   return ConstantFoldInstOperands(I, Operands, DL, TLI);
8005 }
8006 
8007 
8008 // If every incoming value to PN except the one for BB is a specific Constant,
8009 // return that, else return nullptr.
8010 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
8011   Constant *IncomingVal = nullptr;
8012 
8013   for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
8014     if (PN->getIncomingBlock(i) == BB)
8015       continue;
8016 
8017     auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
8018     if (!CurrentVal)
8019       return nullptr;
8020 
8021     if (IncomingVal != CurrentVal) {
8022       if (IncomingVal)
8023         return nullptr;
8024       IncomingVal = CurrentVal;
8025     }
8026   }
8027 
8028   return IncomingVal;
8029 }
8030 
8031 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
8032 /// in the header of its containing loop, we know the loop executes a
8033 /// constant number of times, and the PHI node is just a recurrence
8034 /// involving constants, fold it.
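     ///
     /// For illustration (editorial example): a header PHI with start value 0 and
     /// backedge value "%phi + 3", in a loop whose backedge is taken 4 times,
     /// folds to the constant 12 (the PHI's value on the exiting iteration),
     /// computed below by symbolically executing the recurrence for 4 iterations.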
8035 Constant *
8036 ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
8037                                                    const APInt &BEs,
8038                                                    const Loop *L) {
8039   auto I = ConstantEvolutionLoopExitValue.find(PN);
8040   if (I != ConstantEvolutionLoopExitValue.end())
8041     return I->second;
8042 
8043   if (BEs.ugt(MaxBruteForceIterations))
8044     return ConstantEvolutionLoopExitValue[PN] = nullptr;  // Not going to evaluate it.
8045 
8046   Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
8047 
8048   DenseMap<Instruction *, Constant *> CurrentIterVals;
8049   BasicBlock *Header = L->getHeader();
8050   assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
8051 
8052   BasicBlock *Latch = L->getLoopLatch();
8053   if (!Latch)
8054     return nullptr;
8055 
8056   for (PHINode &PHI : Header->phis()) {
8057     if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
8058       CurrentIterVals[&PHI] = StartCST;
8059   }
8060   if (!CurrentIterVals.count(PN))
8061     return RetVal = nullptr;
8062 
8063   Value *BEValue = PN->getIncomingValueForBlock(Latch);
8064 
8065   // Execute the loop symbolically to determine the exit value.
8066   assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
8067          "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
8068 
8069   unsigned NumIterations = BEs.getZExtValue(); // must be in range
8070   unsigned IterationNum = 0;
8071   const DataLayout &DL = getDataLayout();
8072   for (; ; ++IterationNum) {
8073     if (IterationNum == NumIterations)
8074       return RetVal = CurrentIterVals[PN];  // Got exit value!
8075 
8076     // Compute the value of the PHIs for the next iteration.
8077     // EvaluateExpression adds non-phi values to the CurrentIterVals map.
8078     DenseMap<Instruction *, Constant *> NextIterVals;
8079     Constant *NextPHI =
8080         EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
8081     if (!NextPHI)
8082       return nullptr;        // Couldn't evaluate!
8083     NextIterVals[PN] = NextPHI;
8084 
8085     bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
8086 
8087     // Also evaluate the other PHI nodes.  However, we don't get to stop if we
8088     // cease to be able to evaluate one of them or if they stop evolving,
8089     // because that doesn't necessarily prevent us from computing PN.
8090     SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
8091     for (const auto &I : CurrentIterVals) {
8092       PHINode *PHI = dyn_cast<PHINode>(I.first);
8093       if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
8094       PHIsToCompute.emplace_back(PHI, I.second);
8095     }
8096     // We use two distinct loops because EvaluateExpression may invalidate any
8097     // iterators into CurrentIterVals.
8098     for (const auto &I : PHIsToCompute) {
8099       PHINode *PHI = I.first;
8100       Constant *&NextPHI = NextIterVals[PHI];
8101       if (!NextPHI) {   // Not already computed.
8102         Value *BEValue = PHI->getIncomingValueForBlock(Latch);
8103         NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
8104       }
8105       if (NextPHI != I.second)
8106         StoppedEvolving = false;
8107     }
8108 
8109     // If all entries in CurrentIterVals == NextIterVals then we can stop
8110     // iterating, the loop can't continue to change.
8111     if (StoppedEvolving)
8112       return RetVal = CurrentIterVals[PN];
8113 
8114     CurrentIterVals.swap(NextIterVals);
8115   }
8116 }
8117 
8118 const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
8119                                                           Value *Cond,
8120                                                           bool ExitWhen) {
8121   PHINode *PN = getConstantEvolvingPHI(Cond, L);
8122   if (!PN) return getCouldNotCompute();
8123 
8124   // If the loop is canonicalized, the PHI will have exactly two entries.
8125   // That's the only form we support here.
8126   if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
8127 
8128   DenseMap<Instruction *, Constant *> CurrentIterVals;
8129   BasicBlock *Header = L->getHeader();
8130   assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
8131 
8132   BasicBlock *Latch = L->getLoopLatch();
8133   assert(Latch && "Should follow from NumIncomingValues == 2!");
8134 
8135   for (PHINode &PHI : Header->phis()) {
8136     if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
8137       CurrentIterVals[&PHI] = StartCST;
8138   }
8139   if (!CurrentIterVals.count(PN))
8140     return getCouldNotCompute();
8141 
8142   // Okay, we found a PHI node that defines the trip count of this loop.  Execute
8143   // the loop symbolically to determine when the condition gets a value of
8144   // "ExitWhen".
8145   unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
8146   const DataLayout &DL = getDataLayout();
8147   for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
8148     auto *CondVal = dyn_cast_or_null<ConstantInt>(
8149         EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
8150 
8151     // Couldn't symbolically evaluate.
8152     if (!CondVal) return getCouldNotCompute();
8153 
8154     if (CondVal->getValue() == uint64_t(ExitWhen)) {
8155       ++NumBruteForceTripCountsComputed;
8156       return getConstant(Type::getInt32Ty(getContext()), IterationNum);
8157     }
8158 
8159     // Update all the PHI nodes for the next iteration.
8160     DenseMap<Instruction *, Constant *> NextIterVals;
8161 
8162     // Create a list of which PHIs we need to compute. We want to do this before
8163     // calling EvaluateExpression on them because that may invalidate iterators
8164     // into CurrentIterVals.
8165     SmallVector<PHINode *, 8> PHIsToCompute;
8166     for (const auto &I : CurrentIterVals) {
8167       PHINode *PHI = dyn_cast<PHINode>(I.first);
8168       if (!PHI || PHI->getParent() != Header) continue;
8169       PHIsToCompute.push_back(PHI);
8170     }
8171     for (PHINode *PHI : PHIsToCompute) {
8172       Constant *&NextPHI = NextIterVals[PHI];
8173       if (NextPHI) continue;    // Already computed!
8174 
8175       Value *BEValue = PHI->getIncomingValueForBlock(Latch);
8176       NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
8177     }
8178     CurrentIterVals.swap(NextIterVals);
8179   }
8180 
8181   // Too many iterations were needed to evaluate.
8182   return getCouldNotCompute();
8183 }
8184 
8185 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
8186   SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
8187       ValuesAtScopes[V];
8188   // Check to see if we've folded this expression at this loop before.
8189   for (auto &LS : Values)
8190     if (LS.first == L)
8191       return LS.second ? LS.second : V;
8192 
8193   Values.emplace_back(L, nullptr);
8194 
8195   // Otherwise compute it.
8196   const SCEV *C = computeSCEVAtScope(V, L);
8197   for (auto &LS : reverse(ValuesAtScopes[V]))
8198     if (LS.first == L) {
8199       LS.second = C;
8200       break;
8201     }
8202   return C;
8203 }
8204 
8205 /// This builds up a Constant using the ConstantExpr interface.  That way, we
8206 /// will return Constants for objects which aren't represented by a
8207 /// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
8208 /// Returns NULL if the SCEV isn't representable as a Constant.
8209 static Constant *BuildConstantFromSCEV(const SCEV *V) {
8210   switch (V->getSCEVType()) {
8211   case scCouldNotCompute:
8212   case scAddRecExpr:
8213     return nullptr;
8214   case scConstant:
8215     return cast<SCEVConstant>(V)->getValue();
8216   case scUnknown:
8217     return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
8218   case scSignExtend: {
8219     const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
8220     if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
8221       return ConstantExpr::getSExt(CastOp, SS->getType());
8222     return nullptr;
8223   }
8224   case scZeroExtend: {
8225     const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
8226     if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
8227       return ConstantExpr::getZExt(CastOp, SZ->getType());
8228     return nullptr;
8229   }
8230   case scPtrToInt: {
8231     const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
8232     if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
8233       return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
8234 
8235     return nullptr;
8236   }
8237   case scTruncate: {
8238     const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
8239     if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
8240       return ConstantExpr::getTrunc(CastOp, ST->getType());
8241     return nullptr;
8242   }
8243   case scAddExpr: {
8244     const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
8245     if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
8246       if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
8247         unsigned AS = PTy->getAddressSpace();
8248         Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
8249         C = ConstantExpr::getBitCast(C, DestPtrTy);
8250       }
8251       for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
8252         Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
8253         if (!C2)
8254           return nullptr;
8255 
8256         // First pointer!
8257         if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) {
8258           unsigned AS = C2->getType()->getPointerAddressSpace();
8259           std::swap(C, C2);
8260           Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
8261           // The offsets have been converted to bytes.  We can add bytes to an
8262           // i8* by GEP with the byte count in the first index.
8263           C = ConstantExpr::getBitCast(C, DestPtrTy);
8264         }
8265 
8266         // Don't bother trying to sum two pointers. We probably can't
8267         // statically compute a load that results from it anyway.
8268         if (C2->getType()->isPointerTy())
8269           return nullptr;
8270 
8271         if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
8272           if (PTy->getElementType()->isStructTy())
8273             C2 = ConstantExpr::getIntegerCast(
8274                 C2, Type::getInt32Ty(C->getContext()), true);
8275           C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2);
8276         } else
8277           C = ConstantExpr::getAdd(C, C2);
8278       }
8279       return C;
8280     }
8281     return nullptr;
8282   }
8283   case scMulExpr: {
8284     const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
8285     if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
8286       // Don't bother with pointers at all.
8287       if (C->getType()->isPointerTy())
8288         return nullptr;
8289       for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
8290         Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
8291         if (!C2 || C2->getType()->isPointerTy())
8292           return nullptr;
8293         C = ConstantExpr::getMul(C, C2);
8294       }
8295       return C;
8296     }
8297     return nullptr;
8298   }
8299   case scUDivExpr: {
8300     const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
8301     if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
8302       if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
8303         if (LHS->getType() == RHS->getType())
8304           return ConstantExpr::getUDiv(LHS, RHS);
8305     return nullptr;
8306   }
8307   case scSMaxExpr:
8308   case scUMaxExpr:
8309   case scSMinExpr:
8310   case scUMinExpr:
8311     return nullptr; // TODO: smax, umax, smin, umin.
8312   }
8313   llvm_unreachable("Unknown SCEV kind!");
8314 }
8315 
8316 const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
8317   if (isa<SCEVConstant>(V)) return V;
8318 
8319   // If this instruction is evolved from a constant-evolving PHI, compute the
8320   // exit value from the loop without using SCEVs.
8321   if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
8322     if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
8323       if (PHINode *PN = dyn_cast<PHINode>(I)) {
8324         const Loop *CurrLoop = this->LI[I->getParent()];
8325         // Looking for loop exit value.
8326         if (CurrLoop && CurrLoop->getParentLoop() == L &&
8327             PN->getParent() == CurrLoop->getHeader()) {
8328           // Okay, there is no closed form solution for the PHI node.  Check
8329           // to see if the loop that contains it has a known backedge-taken
8330           // count.  If so, we may be able to force computation of the exit
8331           // value.
8332           const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
8333           // This trivial case can show up in some degenerate cases where
8334           // the incoming IR has not yet been fully simplified.
8335           if (BackedgeTakenCount->isZero()) {
8336             Value *InitValue = nullptr;
8337             bool MultipleInitValues = false;
8338             for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
8339               if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
8340                 if (!InitValue)
8341                   InitValue = PN->getIncomingValue(i);
8342                 else if (InitValue != PN->getIncomingValue(i)) {
8343                   MultipleInitValues = true;
8344                   break;
8345                 }
8346               }
8347             }
8348             if (!MultipleInitValues && InitValue)
8349               return getSCEV(InitValue);
8350           }
8351           // Do we have a loop invariant value flowing around the backedge
8352           // for a loop which must execute the backedge?
8353           if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
8354               isKnownPositive(BackedgeTakenCount) &&
8355               PN->getNumIncomingValues() == 2) {
8356 
8357             unsigned InLoopPred =
8358                 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
8359             Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
8360             if (CurrLoop->isLoopInvariant(BackedgeVal))
8361               return getSCEV(BackedgeVal);
8362           }
8363           if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
8364             // Okay, we know how many times the containing loop executes.  If
8365             // this is a constant evolving PHI node, get the final value at
8366             // the specified iteration number.
8367             Constant *RV = getConstantEvolutionLoopExitValue(
8368                 PN, BTCC->getAPInt(), CurrLoop);
8369             if (RV) return getSCEV(RV);
8370           }
8371         }
8372 
8373         // If there is a single-input Phi, evaluate it at our scope. If we can
8374         // prove that this replacement does not break LCSSA form, use new value.
8375         if (PN->getNumOperands() == 1) {
8376           const SCEV *Input = getSCEV(PN->getOperand(0));
8377           const SCEV *InputAtScope = getSCEVAtScope(Input, L);
8378           // TODO: We can generalize it using LI.replacementPreservesLCSSAForm,
8379           // for the simplest case just support constants.
8380           if (isa<SCEVConstant>(InputAtScope)) return InputAtScope;
8381         }
8382       }
8383 
8384       // Okay, this is an expression that we cannot symbolically evaluate
8385       // into a SCEV.  Check to see if it's possible to symbolically evaluate
8386       // the arguments into constants, and if so, try to constant propagate the
8387       // result.  This is particularly useful for computing loop exit values.
8388       if (CanConstantFold(I)) {
8389         SmallVector<Constant *, 4> Operands;
8390         bool MadeImprovement = false;
8391         for (Value *Op : I->operands()) {
8392           if (Constant *C = dyn_cast<Constant>(Op)) {
8393             Operands.push_back(C);
8394             continue;
8395           }
8396 
8397           // If any of the operands is non-constant and if they are
8398           // non-integer and non-pointer, don't even try to analyze them
8399           // with scev techniques.
8400           if (!isSCEVable(Op->getType()))
8401             return V;
8402 
8403           const SCEV *OrigV = getSCEV(Op);
8404           const SCEV *OpV = getSCEVAtScope(OrigV, L);
8405           MadeImprovement |= OrigV != OpV;
8406 
8407           Constant *C = BuildConstantFromSCEV(OpV);
8408           if (!C) return V;
8409           if (C->getType() != Op->getType())
8410             C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
8411                                                               Op->getType(),
8412                                                               false),
8413                                       C, Op->getType());
8414           Operands.push_back(C);
8415         }
8416 
8417         // Check to see if getSCEVAtScope actually made an improvement.
8418         if (MadeImprovement) {
8419           Constant *C = nullptr;
8420           const DataLayout &DL = getDataLayout();
8421           if (const CmpInst *CI = dyn_cast<CmpInst>(I))
8422             C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
8423                                                 Operands[1], DL, &TLI);
8424           else if (const LoadInst *Load = dyn_cast<LoadInst>(I)) {
8425             if (!Load->isVolatile())
8426               C = ConstantFoldLoadFromConstPtr(Operands[0], Load->getType(),
8427                                                DL);
8428           } else
8429             C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
8430           if (!C) return V;
8431           return getSCEV(C);
8432         }
8433       }
8434     }
8435 
8436     // This is some other type of SCEVUnknown, just return it.
8437     return V;
8438   }
8439 
8440   if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
8441     // Avoid performing the look-up in the common case where the specified
8442     // expression has no loop-variant portions.
8443     for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
8444       const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
8445       if (OpAtScope != Comm->getOperand(i)) {
8446         // Okay, at least one of these operands is loop variant but might be
8447         // foldable.  Build a new instance of the folded commutative expression.
8448         SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
8449                                             Comm->op_begin()+i);
8450         NewOps.push_back(OpAtScope);
8451 
8452         for (++i; i != e; ++i) {
8453           OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
8454           NewOps.push_back(OpAtScope);
8455         }
8456         if (isa<SCEVAddExpr>(Comm))
8457           return getAddExpr(NewOps, Comm->getNoWrapFlags());
8458         if (isa<SCEVMulExpr>(Comm))
8459           return getMulExpr(NewOps, Comm->getNoWrapFlags());
8460         if (isa<SCEVMinMaxExpr>(Comm))
8461           return getMinMaxExpr(Comm->getSCEVType(), NewOps);
8462         llvm_unreachable("Unknown commutative SCEV type!");
8463       }
8464     }
8465     // If we got here, all operands are loop invariant.
8466     return Comm;
8467   }
8468 
8469   if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
8470     const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
8471     const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
8472     if (LHS == Div->getLHS() && RHS == Div->getRHS())
8473       return Div;   // must be loop invariant
8474     return getUDivExpr(LHS, RHS);
8475   }
8476 
8477   // If this is a loop recurrence for a loop that does not contain L, then we
8478   // are dealing with the final value computed by the loop.
8479   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
8480     // First, attempt to evaluate each operand.
8481     // Avoid performing the look-up in the common case where the specified
8482     // expression has no loop-variant portions.
8483     for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
8484       const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
8485       if (OpAtScope == AddRec->getOperand(i))
8486         continue;
8487 
8488       // Okay, at least one of these operands is loop variant but might be
8489       // foldable.  Build a new instance of the folded add recurrence.
8490       SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
8491                                           AddRec->op_begin()+i);
8492       NewOps.push_back(OpAtScope);
8493       for (++i; i != e; ++i)
8494         NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
8495 
8496       const SCEV *FoldedRec =
8497         getAddRecExpr(NewOps, AddRec->getLoop(),
8498                       AddRec->getNoWrapFlags(SCEV::FlagNW));
8499       AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
8500       // The addrec may be folded to a nonrecurrence, for example, if the
8501       // induction variable is multiplied by zero after constant folding. Go
8502       // ahead and return the folded value.
8503       if (!AddRec)
8504         return FoldedRec;
8505       break;
8506     }
8507 
8508     // If the scope is outside the addrec's loop, evaluate it by using the
8509     // loop exit value of the addrec.
8510     if (!AddRec->getLoop()->contains(L)) {
8511       // To evaluate this recurrence, we need to know how many times the AddRec
8512       // loop iterates.  Compute this now.
8513       const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
8514       if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
8515 
8516       // Then, evaluate the AddRec.
8517       return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
8518     }
8519 
8520     return AddRec;
8521   }
8522 
8523   if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
8524     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
8525     if (Op == Cast->getOperand())
8526       return Cast;  // must be loop invariant
8527     return getZeroExtendExpr(Op, Cast->getType());
8528   }
8529 
8530   if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
8531     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
8532     if (Op == Cast->getOperand())
8533       return Cast;  // must be loop invariant
8534     return getSignExtendExpr(Op, Cast->getType());
8535   }
8536 
8537   if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
8538     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
8539     if (Op == Cast->getOperand())
8540       return Cast;  // must be loop invariant
8541     return getTruncateExpr(Op, Cast->getType());
8542   }
8543 
8544   if (const SCEVPtrToIntExpr *Cast = dyn_cast<SCEVPtrToIntExpr>(V)) {
8545     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
8546     if (Op == Cast->getOperand())
8547       return Cast; // must be loop invariant
8548     return getPtrToIntExpr(Op, Cast->getType());
8549   }
8550 
8551   llvm_unreachable("Unknown SCEV type!");
8552 }
8553 
8554 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
8555   return getSCEVAtScope(getSCEV(V), L);
8556 }
8557 
8558 const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
8559   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
8560     return stripInjectiveFunctions(ZExt->getOperand());
8561   if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
8562     return stripInjectiveFunctions(SExt->getOperand());
8563   return S;
8564 }
8565 
8566 /// Finds the minimum unsigned root of the following equation:
8567 ///
8568 ///     A * X = B (mod N)
8569 ///
8570 /// where N = 2^BW and BW is the common bit width of A and B. The signedness of
8571 /// A and B isn't important.
8572 ///
8573 /// If the equation does not have a solution, SCEVCouldNotCompute is returned.
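     ///
     /// Illustrative worked example (editorial, arbitrary values): with BW = 4
     /// (N = 16), A = 6 and B = 10: D = gcd(6, 16) = 2 (Mult2 = 1), and B has at
     /// least one trailing zero, so a solution exists.  A/D = 3, N/D = 8, and the
     /// multiplicative inverse of 3 modulo 8 is 3, giving
     /// X = (3 * 10 mod 16) / 2 = 14 / 2 = 7.  Indeed 6 * 7 = 42 = 10 (mod 16),
     /// and no smaller X satisfies the equation.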
8574 static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
8575                                                ScalarEvolution &SE) {
8576   uint32_t BW = A.getBitWidth();
8577   assert(BW == SE.getTypeSizeInBits(B->getType()));
8578   assert(A != 0 && "A must be non-zero.");
8579 
8580   // 1. D = gcd(A, N)
8581   //
8582   // The gcd of A and N may have only one prime factor: 2. The number of
8583   // trailing zeros in A is its multiplicity.
8584   uint32_t Mult2 = A.countTrailingZeros();
8585   // D = 2^Mult2
8586 
8587   // 2. Check if B is divisible by D.
8588   //
8589   // B is divisible by D if and only if the multiplicity of prime factor 2 for B
8590   // is not less than multiplicity of this prime factor for D.
8591   if (SE.GetMinTrailingZeros(B) < Mult2)
8592     return SE.getCouldNotCompute();
8593 
8594   // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
8595   // modulo (N / D).
8596   //
8597   // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
8598   // (N / D) in general. The inverse itself always fits into BW bits, though,
8599   // so we immediately truncate it.
8600   APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
8601   APInt Mod(BW + 1, 0);
8602   Mod.setBit(BW - Mult2);  // Mod = N / D
8603   APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
8604 
8605   // 4. Compute the minimum unsigned root of the equation:
8606   // I * (B / D) mod (N / D)
8607   // To simplify the computation, we factor out the divide by D:
8608   // (I * B mod N) / D
8609   const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
8610   return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
8611 }
8612 
8613 /// For a given quadratic addrec, generate coefficients of the corresponding
8614 /// quadratic equation, multiplied by a common value to ensure that they are
8615 /// integers.
8616 /// The returned value is a tuple { A, B, C, M, BitWidth }, where
8617 /// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
8618 /// were multiplied by, and BitWidth is the bit width of the original addrec
8619 /// coefficients.
8620 /// This function returns None if the addrec coefficients are not compile-
8621 /// time constants.
8622 static Optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
8623 GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
8624   assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
8625   const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
8626   const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
8627   const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
8628   LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
8629                     << *AddRec << '\n');
8630 
8631   // We currently can only solve this if the coefficients are constants.
8632   if (!LC || !MC || !NC) {
8633     LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
8634     return None;
8635   }
8636 
8637   APInt L = LC->getAPInt();
8638   APInt M = MC->getAPInt();
8639   APInt N = NC->getAPInt();
8640   assert(!N.isNullValue() && "This is not a quadratic addrec");
8641 
8642   unsigned BitWidth = LC->getAPInt().getBitWidth();
8643   unsigned NewWidth = BitWidth + 1;
8644   LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
8645                     << BitWidth << '\n');
8646   // The sign-extension (as opposed to a zero-extension) here matches the
8647   // extension used in SolveQuadraticEquationWrap (with the same motivation).
8648   N = N.sext(NewWidth);
8649   M = M.sext(NewWidth);
8650   L = L.sext(NewWidth);
8651 
8652   // The increments are M, M+N, M+2N, ..., so the accumulated values are
8653   //   L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
8654   //   L+M, L+2M+N, L+3M+3N, ...
8655   // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
8656   //
8657   // The equation Acc = 0 is then
8658   //   L + nM + n(n-1)/2 N = 0,  or  2L + 2M n + n(n-1) N = 0.
8659   // In a quadratic form it becomes:
8660   //   N n^2 + (2M-N) n + 2L = 0.
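     //
     // Illustrative instance (editorial, arbitrary values): for the chrec
     // {1,+,2,+,2}, i.e. L = 1, M = 2, N = 2, the accumulated values are
     // 1, 3, 7, 13, ... and the returned equation is 2n^2 + 2n + 2 = 0 with the
     // common multiplier T = 2.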
8661 
8662   APInt A = N;
8663   APInt B = 2 * M - A;
8664   APInt C = 2 * L;
8665   APInt T = APInt(NewWidth, 2);
8666   LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
8667                     << "x + " << C << ", coeff bw: " << NewWidth
8668                     << ", multiplied by " << T << '\n');
8669   return std::make_tuple(A, B, C, T, BitWidth);
8670 }
8671 
8672 /// Helper function to compare optional APInts:
8673 /// (a) if X and Y both exist, return min(X, Y),
8674 /// (b) if neither X nor Y exist, return None,
8675 /// (c) if exactly one of X and Y exists, return that value.
8676 static Optional<APInt> MinOptional(Optional<APInt> X, Optional<APInt> Y) {
8677   if (X.hasValue() && Y.hasValue()) {
8678     unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
8679     APInt XW = X->sextOrSelf(W);
8680     APInt YW = Y->sextOrSelf(W);
8681     return XW.slt(YW) ? *X : *Y;
8682   }
8683   if (!X.hasValue() && !Y.hasValue())
8684     return None;
8685   return X.hasValue() ? *X : *Y;
8686 }
8687 
8688 /// Helper function to truncate an optional APInt to a given BitWidth.
8689 /// When solving addrec-related equations, it is preferable to return a value
8690 /// that has the same bit width as the original addrec's coefficients. If the
8691 /// solution fits in the original bit width, truncate it (except for i1).
8692 /// Returning a value of a different bit width may inhibit some optimizations.
8693 ///
8694 /// In general, a solution to a quadratic equation generated from an addrec
8695 /// may require BW+1 bits, where BW is the bit width of the addrec's
8696 /// coefficients. The reason is that the coefficients of the quadratic
8697 /// equation are BW+1 bits wide (to avoid truncation when converting from
8698 /// the addrec to the equation).
8699 static Optional<APInt> TruncIfPossible(Optional<APInt> X, unsigned BitWidth) {
8700   if (!X.hasValue())
8701     return None;
8702   unsigned W = X->getBitWidth();
8703   if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
8704     return X->trunc(BitWidth);
8705   return X;
8706 }
8707 
8708 /// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
8709 /// iterations. The values L, M, N are assumed to be signed, and they
8710 /// should all have the same bit widths.
8711 /// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
8712 /// where BW is the bit width of the addrec's coefficients.
8713 /// If the calculated value is a BW-bit integer (for BW > 1), it will be
8714 /// returned as such, otherwise the bit width of the returned value may
8715 /// be greater than BW.
8716 ///
8717 /// This function returns None if
8718 /// (a) the addrec coefficients are not constant, or
8719 /// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
8720 ///     like x^2 = 5, no integer solutions exist, in other cases an integer
8721 ///     solution may exist, but SolveQuadraticEquationWrap may fail to find it.
8722 static Optional<APInt>
8723 SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
8724   APInt A, B, C, M;
8725   unsigned BitWidth;
8726   auto T = GetQuadraticEquation(AddRec);
8727   if (!T.hasValue())
8728     return None;
8729 
8730   std::tie(A, B, C, M, BitWidth) = *T;
8731   LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
8732   Optional<APInt> X = APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth+1);
8733   if (!X.hasValue())
8734     return None;
8735 
8736   ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
8737   ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
8738   if (!V->isZero())
8739     return None;
8740 
8741   return TruncIfPossible(X, BitWidth);
8742 }
8743 
8744 /// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
8745 /// iterations. The values M, N are assumed to be signed, and they
8746 /// should all have the same bit widths.
8747 /// Find the least n such that c(n) does not belong to the given range,
8748 /// while c(n-1) does.
8749 ///
8750 /// This function returns None if
8751 /// (a) the addrec coefficients are not constant, or
8752 /// (b) SolveQuadraticEquationWrap was unable to find a solution for the
8753 ///     bounds of the range.
8754 static Optional<APInt>
8755 SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
8756                           const ConstantRange &Range, ScalarEvolution &SE) {
8757   assert(AddRec->getOperand(0)->isZero() &&
8758          "Starting value of addrec should be 0");
8759   LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
8760                     << Range << ", addrec " << *AddRec << '\n');
8761   // This case is handled in getNumIterationsInRange. Here we can assume that
8762   // we start in the range.
8763   assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
8764          "Addrec's initial value should be in range");
8765 
8766   APInt A, B, C, M;
8767   unsigned BitWidth;
8768   auto T = GetQuadraticEquation(AddRec);
8769   if (!T.hasValue())
8770     return None;
8771 
8772   // Be careful about the return value: there can be two reasons for not
8773   // returning an actual number. First, if no solutions to the equations
8774   // were found, and second, if the solutions don't leave the given range.
8775   // The first case means that the actual solution is "unknown", the second
8776   // means that it's known, but not valid. If the solution is unknown, we
8777   // cannot make any conclusions.
8778   // Return a pair: the optional solution and a flag indicating if the
8779   // solution was found.
8780   auto SolveForBoundary = [&](APInt Bound) -> std::pair<Optional<APInt>,bool> {
8781     // Solve for signed overflow and unsigned overflow, pick the lower
8782     // solution.
8783     LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
8784                       << Bound << " (before multiplying by " << M << ")\n");
8785     Bound *= M; // The quadratic equation multiplier.
8786 
8787     Optional<APInt> SO = None;
8788     if (BitWidth > 1) {
8789       LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
8790                            "signed overflow\n");
8791       SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
8792     }
8793     LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
8794                          "unsigned overflow\n");
8795     Optional<APInt> UO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound,
8796                                                               BitWidth+1);
8797 
8798     auto LeavesRange = [&] (const APInt &X) {
8799       ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
8800       ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
8801       if (Range.contains(V0->getValue()))
8802         return false;
8803       // X should be at least 1, so X-1 is non-negative.
8804       ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
8805       ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
8806       if (Range.contains(V1->getValue()))
8807         return true;
8808       return false;
8809     };
8810 
8811     // If SolveQuadraticEquationWrap returns None, it means that there can
8812     // be a solution, but the function failed to find it. We cannot treat it
8813     // as "no solution".
8814     if (!SO.hasValue() || !UO.hasValue())
8815       return { None, false };
8816 
8817     // Check the smaller value first to see if it leaves the range.
8818     // At this point, both SO and UO must have values.
8819     Optional<APInt> Min = MinOptional(SO, UO);
8820     if (LeavesRange(*Min))
8821       return { Min, true };
8822     Optional<APInt> Max = Min == SO ? UO : SO;
8823     if (LeavesRange(*Max))
8824       return { Max, true };
8825 
8826     // Solutions were found, but were eliminated, hence the "true".
8827     return { None, true };
8828   };
8829 
8830   std::tie(A, B, C, M, BitWidth) = *T;
8831   // Lower bound is inclusive, subtract 1 to represent the exiting value.
8832   APInt Lower = Range.getLower().sextOrSelf(A.getBitWidth()) - 1;
8833   APInt Upper = Range.getUpper().sextOrSelf(A.getBitWidth());
8834   auto SL = SolveForBoundary(Lower);
8835   auto SU = SolveForBoundary(Upper);
8836   // If any of the solutions was unknown, no meaningful conclusions can
8837   // be made.
8838   if (!SL.second || !SU.second)
8839     return None;
8840 
8841   // Claim: The correct solution is not some value between Min and Max.
8842   //
8843   // Justification: Assuming that Min and Max are different values, one of
8844   // them is when the first signed overflow happens, the other is when the
8845   // first unsigned overflow happens. Crossing the range boundary is only
8846   // possible via an overflow (treating 0 as a special case of it, modeling
8847   // an overflow as crossing k*2^W for some k).
8848   //
8849   // The interesting case here is when Min was eliminated as an invalid
8850   // solution, but Max was not. The argument is that if there was another
8851   // overflow between Min and Max, it would also have been eliminated if
8852   // it was considered.
8853   //
8854   // For a given boundary, it is possible to have two overflows of the same
8855   // type (signed/unsigned) without having the other type in between: this
8856   // can happen when the vertex of the parabola is between the iterations
8857   // corresponding to the overflows. This is only possible when the two
8858   // overflows cross k*2^W for the same k. In such case, if the second one
8859   // left the range (and was the first one to do so), the first overflow
8860   // would have to enter the range, which would mean that either we had left
8861   // the range before or that we started outside of it. Both of these cases
8862   // are contradictions.
8863   //
8864   // Claim: In the case where SolveForBoundary returns None, the correct
8865   // solution is not some value between the Max for this boundary and the
8866   // Min of the other boundary.
8867   //
8868   // Justification: Assume that we had such Max_A and Min_B corresponding
8869   // to range boundaries A and B and such that Max_A < Min_B. If there was
8870   // a solution between Max_A and Min_B, it would have to be caused by an
8871   // overflow corresponding to either A or B. It cannot correspond to B,
8872   // since Min_B is the first occurrence of such an overflow. If it
8873   // corresponded to A, it would have to be either a signed or an unsigned
8874   // overflow that is larger than both eliminated overflows for A. But
8875   // between the eliminated overflows and this overflow, the values would
8876   // cover the entire value space, thus crossing the other boundary, which
8877   // is a contradiction.
8878 
8879   return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
8880 }
8881 
8882 ScalarEvolution::ExitLimit
8883 ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
8884                               bool AllowPredicates) {
8885 
8886   // This is only used for loops with an "x != y" exit test. The exit condition
8887   // is now expressed as a single expression, V = x-y. So the exit test is
8888   // effectively V != 0.  We know, and take advantage of, the fact that this
8889   // expression is only ever used in a comparison-with-zero context.
8890 
8891   SmallPtrSet<const SCEVPredicate *, 4> Predicates;
8892   // If the value is a constant
8893   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
8894     // If the value is already zero, the branch will execute zero times.
8895     if (C->getValue()->isZero()) return C;
8896     return getCouldNotCompute();  // Otherwise it will loop infinitely.
8897   }
8898 
8899   const SCEVAddRecExpr *AddRec =
8900       dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
8901 
8902   if (!AddRec && AllowPredicates)
8903     // Try to make this an AddRec using runtime tests, in the first X
8904     // iterations of this loop, where X is the SCEV expression found by the
8905     // algorithm below.
8906     AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
8907 
8908   if (!AddRec || AddRec->getLoop() != L)
8909     return getCouldNotCompute();
8910 
8911   // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
8912   // the quadratic equation to solve it.
8913   if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
8914     // We can only use this value if the chrec ends up with an exact zero
8915     // value at this index.  When solving for "X*X != 5", for example, we
8916     // should not accept a root of 2.
8917     if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
8918       const auto *R = cast<SCEVConstant>(getConstant(S.getValue()));
8919       return ExitLimit(R, R, false, Predicates);
8920     }
8921     return getCouldNotCompute();
8922   }
8923 
8924   // Otherwise we can only handle this if it is affine.
8925   if (!AddRec->isAffine())
8926     return getCouldNotCompute();
8927 
8928   // If this is an affine expression, the execution count of this branch is
8929   // the minimum unsigned root of the following equation:
8930   //
8931   //     Start + Step*N = 0 (mod 2^BW)
8932   //
8933   // equivalent to:
8934   //
8935   //             Step*N = -Start (mod 2^BW)
8936   //
8937   // where BW is the common bit width of Start and Step.
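       //
       // Illustrative example (hypothetical values): with BW = 8, Start = 6 and
       // Step = -2 (i.e. 254 as unsigned), the equation becomes
       // 254*N = 250 (mod 256), whose minimum unsigned root is N = 3, matching
       // the chrec {6,+,-2} reaching zero after three iterations: 6, 4, 2, 0.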
8938 
8939   // Get the initial value for the loop.
8940   const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
8941   const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
8942 
8943   // For now we handle only constant steps.
8944   //
8945   // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
8946   // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
8947   // to 0; it must be counting down to equal 0. Consequently, N = Start / -Step.
8948   // We have not yet seen any such cases.
8949   const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
8950   if (!StepC || StepC->getValue()->isZero())
8951     return getCouldNotCompute();
8952 
8953   // For positive steps (counting up until unsigned overflow):
8954   //   N = -Start/Step (as unsigned)
8955   // For negative steps (counting down to zero):
8956   //   N = Start/-Step
8957   // First compute the unsigned distance from zero in the direction of Step.
8958   bool CountDown = StepC->getAPInt().isNegative();
8959   const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
8960 
8961   // Handle unitary steps, which cannot wrap around.
8962   // 1*N = -Start; -1*N = Start (mod 2^BW), so:
8963   //   N = Distance (as unsigned)
8964   if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
8965     APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
8966     APInt MaxBECountBase = getUnsignedRangeMax(Distance);
8967     if (MaxBECountBase.ult(MaxBECount))
8968       MaxBECount = MaxBECountBase;
8969 
8970     // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
8971     // we end up with a loop whose backedge-taken count is n - 1.  Detect this
8972     // case, and see if we can improve the bound.
8973     //
8974     // Explicitly handling this here is necessary because getUnsignedRange
8975     // isn't context-sensitive; it doesn't know that we only care about the
8976     // range inside the loop.
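         //
         // Illustrative example (hypothetical values): if Distance is the SCEV
         // for n - 1 and nothing is known about n, getUnsignedRangeMax(Distance)
         // is UINT_MAX, because n could be 0 outside the loop.  But the rotated
         // loop is only entered when n != 0, i.e. Distance + 1 != 0, so the
         // maximum backedge-taken count can be sharpened to unsigned_max(n) - 1.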
8977     const SCEV *Zero = getZero(Distance->getType());
8978     const SCEV *One = getOne(Distance->getType());
8979     const SCEV *DistancePlusOne = getAddExpr(Distance, One);
8980     if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
8981       // If Distance + 1 doesn't overflow, we can compute the maximum distance
8982       // as "unsigned_max(Distance + 1) - 1".
8983       ConstantRange CR = getUnsignedRange(DistancePlusOne);
8984       MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
8985     }
8986     return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates);
8987   }
8988 
8989   // If the condition controls loop exit (the loop exits only if the expression
8990   // is true) and the addition is no-wrap we can use unsigned divide to
8991   // compute the backedge count.  In this case, the step may not divide the
8992   // distance, but we don't care because if the condition is "missed" the loop
8993   // will have undefined behavior due to wrapping.
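       //
       // Illustrative example (hypothetical values): with Distance = 10 and a
       // step of magnitude 3, the IV takes the values 10, 7, 4, 1, ... and never
       // hits zero exactly; 10 udiv 3 = 3 is still a sound answer here, because
       // actually missing the exit would force the no-self-wrap chrec to wrap,
       // i.e. trigger undefined behavior.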
8994   if (ControlsExit && AddRec->hasNoSelfWrap() &&
8995       loopHasNoAbnormalExits(AddRec->getLoop())) {
8996     const SCEV *Exact =
8997         getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
8998     const SCEV *Max =
8999         Exact == getCouldNotCompute()
9000             ? Exact
9001             : getConstant(getUnsignedRangeMax(Exact));
9002     return ExitLimit(Exact, Max, false, Predicates);
9003   }
9004 
9005   // Solve the general equation.
9006   const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
9007                                                getNegativeSCEV(Start), *this);
9008   const SCEV *M = E == getCouldNotCompute()
9009                       ? E
9010                       : getConstant(getUnsignedRangeMax(E));
9011   return ExitLimit(E, M, false, Predicates);
9012 }
9013 
9014 ScalarEvolution::ExitLimit
9015 ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
9016   // Loops that look like: while (X == 0) are very strange indeed.  We don't
9017   // handle them yet except for the trivial case.  This could be expanded in the
9018   // future as needed.
9019 
9020   // If the value is a constant, check to see if it is known to be non-zero
9021   // already.  If so, the backedge will execute zero times.
9022   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
9023     if (!C->getValue()->isZero())
9024       return getZero(C->getType());
9025     return getCouldNotCompute();  // Otherwise it will loop infinitely.
9026   }
9027 
9028   // We could implement others, but I really doubt anyone writes loops like
9029   // this, and if they did, they would already be constant folded.
9030   return getCouldNotCompute();
9031 }
9032 
9033 std::pair<const BasicBlock *, const BasicBlock *>
9034 ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
9035     const {
9036   // If the block has a unique predecessor, then there is no path from the
9037   // predecessor to the block that does not go through the direct edge
9038   // from the predecessor to the block.
9039   if (const BasicBlock *Pred = BB->getSinglePredecessor())
9040     return {Pred, BB};
9041 
9042   // A loop's header is defined to be a block that dominates the loop.
9043   // If the header has a unique predecessor outside the loop, it must be
9044   // a block that has exactly one successor that can reach the loop.
9045   if (const Loop *L = LI.getLoopFor(BB))
9046     return {L->getLoopPredecessor(), L->getHeader()};
9047 
9048   return {nullptr, nullptr};
9049 }
9050 
9051 /// SCEV structural equivalence is usually sufficient for testing whether two
9052 /// expressions are equal, however for the purposes of looking for a condition
9053 /// guarding a loop, it can be useful to be a little more general, since a
9054 /// front-end may have replicated the controlling expression.
9055 static bool HasSameValue(const SCEV *A, const SCEV *B) {
9056   // Quick check to see if they are the same SCEV.
9057   if (A == B) return true;
9058 
9059   auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
9060     // Not all instructions that are "identical" compute the same value.  For
9061     // instance, two distinct alloca instructions allocating the same type are
9062     // identical and do not read memory, but they compute distinct values.
9063     return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
9064   };
9065 
9066   // Otherwise, if they're both SCEVUnknown, it's possible that they hold
9067   // two different instructions with the same value. Check for this case.
9068   if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
9069     if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
9070       if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
9071         if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
9072           if (ComputesEqualValues(AI, BI))
9073             return true;
9074 
9075   // Otherwise assume they may have a different value.
9076   return false;
9077 }
9078 
9079 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
9080                                            const SCEV *&LHS, const SCEV *&RHS,
9081                                            unsigned Depth) {
9082   bool Changed = false;
9083   // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
9084   // '0 != 0'.
9085   auto TrivialCase = [&](bool TriviallyTrue) {
9086     LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
9087     Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
9088     return true;
9089   };
9090   // If we hit the max recursion limit, bail out.
9091   if (Depth >= 3)
9092     return false;
9093 
9094   // Canonicalize a constant to the right side.
9095   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
9096     // Check for both operands constant.
9097     if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
9098       if (ConstantExpr::getICmp(Pred,
9099                                 LHSC->getValue(),
9100                                 RHSC->getValue())->isNullValue())
9101         return TrivialCase(false);
9102       else
9103         return TrivialCase(true);
9104     }
9105     // Otherwise swap the operands to put the constant on the right.
9106     std::swap(LHS, RHS);
9107     Pred = ICmpInst::getSwappedPredicate(Pred);
9108     Changed = true;
9109   }
9110 
9111   // If we're comparing an addrec with a value which is loop-invariant in the
9112   // addrec's loop, put the addrec on the left. Also make a dominance check,
9113   // as both operands could be addrecs loop-invariant in each other's loop.
9114   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
9115     const Loop *L = AR->getLoop();
9116     if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
9117       std::swap(LHS, RHS);
9118       Pred = ICmpInst::getSwappedPredicate(Pred);
9119       Changed = true;
9120     }
9121   }
9122 
9123   // If there's a constant operand, canonicalize comparisons with boundary
9124   // cases, and canonicalize *-or-equal comparisons to regular comparisons.
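       // For example (illustrative): "x u< 1" is converted to the equality
       // "x == 0", and "x u>= 1" is canonicalized to "x u> 0".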
9125   if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
9126     const APInt &RA = RC->getAPInt();
9127 
9128     bool SimplifiedByConstantRange = false;
9129 
9130     if (!ICmpInst::isEquality(Pred)) {
9131       ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
9132       if (ExactCR.isFullSet())
9133         return TrivialCase(true);
9134       else if (ExactCR.isEmptySet())
9135         return TrivialCase(false);
9136 
9137       APInt NewRHS;
9138       CmpInst::Predicate NewPred;
9139       if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
9140           ICmpInst::isEquality(NewPred)) {
9141         // We were able to convert an inequality to an equality.
9142         Pred = NewPred;
9143         RHS = getConstant(NewRHS);
9144         Changed = SimplifiedByConstantRange = true;
9145       }
9146     }
9147 
9148     if (!SimplifiedByConstantRange) {
9149       switch (Pred) {
9150       default:
9151         break;
9152       case ICmpInst::ICMP_EQ:
9153       case ICmpInst::ICMP_NE:
9154         // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
9155         if (!RA)
9156           if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
9157             if (const SCEVMulExpr *ME =
9158                     dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
9159               if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
9160                   ME->getOperand(0)->isAllOnesValue()) {
9161                 RHS = AE->getOperand(1);
9162                 LHS = ME->getOperand(1);
9163                 Changed = true;
9164               }
9165         break;
9166 
9167 
9168         // The "Should have been caught earlier!" messages refer to the fact
9169         // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
9170         // should have fired on the corresponding cases, and canonicalized the
9171         // check to a trivial case.
9172 
9173       case ICmpInst::ICMP_UGE:
9174         assert(!RA.isMinValue() && "Should have been caught earlier!");
9175         Pred = ICmpInst::ICMP_UGT;
9176         RHS = getConstant(RA - 1);
9177         Changed = true;
9178         break;
9179       case ICmpInst::ICMP_ULE:
9180         assert(!RA.isMaxValue() && "Should have been caught earlier!");
9181         Pred = ICmpInst::ICMP_ULT;
9182         RHS = getConstant(RA + 1);
9183         Changed = true;
9184         break;
9185       case ICmpInst::ICMP_SGE:
9186         assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
9187         Pred = ICmpInst::ICMP_SGT;
9188         RHS = getConstant(RA - 1);
9189         Changed = true;
9190         break;
9191       case ICmpInst::ICMP_SLE:
9192         assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
9193         Pred = ICmpInst::ICMP_SLT;
9194         RHS = getConstant(RA + 1);
9195         Changed = true;
9196         break;
9197       }
9198     }
9199   }
9200 
9201   // Check for obvious equality.
9202   if (HasSameValue(LHS, RHS)) {
9203     if (ICmpInst::isTrueWhenEqual(Pred))
9204       return TrivialCase(true);
9205     if (ICmpInst::isFalseWhenEqual(Pred))
9206       return TrivialCase(false);
9207   }
9208 
9209   // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
9210   // adding or subtracting 1 from one of the operands.
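       // For example (illustrative): if RHS is known not to be SINT_MAX, then
       // "LHS s<= RHS" can be rewritten as "LHS s< RHS + 1" without changing
       // its meaning.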
9211   switch (Pred) {
9212   case ICmpInst::ICMP_SLE:
9213     if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
9214       RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
9215                        SCEV::FlagNSW);
9216       Pred = ICmpInst::ICMP_SLT;
9217       Changed = true;
9218     } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
9219       LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
9220                        SCEV::FlagNSW);
9221       Pred = ICmpInst::ICMP_SLT;
9222       Changed = true;
9223     }
9224     break;
9225   case ICmpInst::ICMP_SGE:
9226     if (!getSignedRangeMin(RHS).isMinSignedValue()) {
9227       RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
9228                        SCEV::FlagNSW);
9229       Pred = ICmpInst::ICMP_SGT;
9230       Changed = true;
9231     } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
9232       LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
9233                        SCEV::FlagNSW);
9234       Pred = ICmpInst::ICMP_SGT;
9235       Changed = true;
9236     }
9237     break;
9238   case ICmpInst::ICMP_ULE:
9239     if (!getUnsignedRangeMax(RHS).isMaxValue()) {
9240       RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
9241                        SCEV::FlagNUW);
9242       Pred = ICmpInst::ICMP_ULT;
9243       Changed = true;
9244     } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
9245       LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
9246       Pred = ICmpInst::ICMP_ULT;
9247       Changed = true;
9248     }
9249     break;
9250   case ICmpInst::ICMP_UGE:
9251     if (!getUnsignedRangeMin(RHS).isMinValue()) {
9252       RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
9253       Pred = ICmpInst::ICMP_UGT;
9254       Changed = true;
9255     } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
9256       LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
9257                        SCEV::FlagNUW);
9258       Pred = ICmpInst::ICMP_UGT;
9259       Changed = true;
9260     }
9261     break;
9262   default:
9263     break;
9264   }
9265 
9266   // TODO: More simplifications are possible here.
9267 
9268   // Recursively simplify until we either hit a recursion limit or nothing
9269   // changes.
9270   if (Changed)
9271     return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1);
9272 
9273   return Changed;
9274 }
9275 
9276 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
9277   return getSignedRangeMax(S).isNegative();
9278 }
9279 
9280 bool ScalarEvolution::isKnownPositive(const SCEV *S) {
9281   return getSignedRangeMin(S).isStrictlyPositive();
9282 }
9283 
9284 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
9285   return !getSignedRangeMin(S).isNegative();
9286 }
9287 
9288 bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
9289   return !getSignedRangeMax(S).isStrictlyPositive();
9290 }
9291 
9292 bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
9293   return isKnownNegative(S) || isKnownPositive(S);
9294 }
9295 
9296 std::pair<const SCEV *, const SCEV *>
9297 ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
9298   // Compute SCEV on entry of loop L.
9299   const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
9300   if (Start == getCouldNotCompute())
9301     return { Start, Start };
9302   // Compute post increment SCEV for loop L.
9303   const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
9304   assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
9305   return { Start, PostInc };
9306 }
9307 
9308 bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
9309                                           const SCEV *LHS, const SCEV *RHS) {
9310   // First collect all loops.
9311   SmallPtrSet<const Loop *, 8> LoopsUsed;
9312   getUsedLoops(LHS, LoopsUsed);
9313   getUsedLoops(RHS, LoopsUsed);
9314 
9315   if (LoopsUsed.empty())
9316     return false;
9317 
9318   // Domination relationship must be a linear order on collected loops.
9319 #ifndef NDEBUG
9320   for (auto *L1 : LoopsUsed)
9321     for (auto *L2 : LoopsUsed)
9322       assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
9323               DT.dominates(L2->getHeader(), L1->getHeader())) &&
9324              "Domination relationship is not a linear order");
9325 #endif
9326 
9327   const Loop *MDL =
9328       *std::max_element(LoopsUsed.begin(), LoopsUsed.end(),
9329                         [&](const Loop *L1, const Loop *L2) {
9330          return DT.properlyDominates(L1->getHeader(), L2->getHeader());
9331        });
9332 
9333   // Get init and post increment value for LHS.
9334   auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
9335   // If LHS contains an unknown non-invariant SCEV then bail out.
9336   if (SplitLHS.first == getCouldNotCompute())
9337     return false;
9338   assert(SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
9339   // Get init and post increment value for RHS.
9340   auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
9341   // If RHS contains an unknown non-invariant SCEV then bail out.
9342   if (SplitRHS.first == getCouldNotCompute())
9343     return false;
9344   assert(SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
9345   // It is possible that init SCEV contains an invariant load but it does
9346   // not dominate MDL and is not available at MDL loop entry, so we should
9347   // check it here.
9348   if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
9349       !isAvailableAtLoopEntry(SplitRHS.first, MDL))
9350     return false;
9351 
9352   // The backedge guard check seems to be faster than the entry one, so in
9353   // some cases it can speed up the whole estimation by short-circuiting.
9354   return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
9355                                      SplitRHS.second) &&
9356          isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
9357 }
9358 
9359 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
9360                                        const SCEV *LHS, const SCEV *RHS) {
9361   // Canonicalize the inputs first.
9362   (void)SimplifyICmpOperands(Pred, LHS, RHS);
9363 
9364   if (isKnownViaInduction(Pred, LHS, RHS))
9365     return true;
9366 
9367   if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
9368     return true;
9369 
9370   // Otherwise see what can be done with some simple reasoning.
9371   return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
9372 }
9373 
9374 bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
9375                                          const SCEV *LHS, const SCEV *RHS,
9376                                          const Instruction *Context) {
9377   // TODO: Analyze guards and assumes from Context's block.
9378   return isKnownPredicate(Pred, LHS, RHS) ||
9379          isBasicBlockEntryGuardedByCond(Context->getParent(), Pred, LHS, RHS);
9380 }
9381 
9382 bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
9383                                               const SCEVAddRecExpr *LHS,
9384                                               const SCEV *RHS) {
9385   const Loop *L = LHS->getLoop();
9386   return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
9387          isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
9388 }
9389 
9390 Optional<ScalarEvolution::MonotonicPredicateType>
9391 ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
9392                                            ICmpInst::Predicate Pred) {
9393   auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
9394 
9395 #ifndef NDEBUG
9396   // Verify an invariant: inverting the predicate should turn a monotonically
9397   // increasing change to a monotonically decreasing one, and vice versa.
9398   if (Result) {
9399     auto ResultSwapped =
9400         getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
9401 
9402     assert(ResultSwapped.hasValue() && "should be able to analyze both!");
9403     assert(ResultSwapped.getValue() != Result.getValue() &&
9404            "monotonicity should flip as we flip the predicate");
9405   }
9406 #endif
9407 
9408   return Result;
9409 }
9410 
9411 Optional<ScalarEvolution::MonotonicPredicateType>
9412 ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
9413                                                ICmpInst::Predicate Pred) {
9414   // A zero step value for LHS means the induction variable is essentially a
9415   // loop invariant value. We don't really depend on the predicate actually
9416   // flipping from false to true (for increasing predicates, and the other way
9417   // around for decreasing predicates), all we care about is that *if* the
9418   // predicate changes then it only changes from false to true.
9419   //
9420   // A zero step value in itself is not very useful, but there may be places
9421   // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
9422   // as general as possible.
9423 
9424   // Only handle LE/LT/GE/GT predicates.
9425   if (!ICmpInst::isRelational(Pred))
9426     return None;
9427 
9428   bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
9429   assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
9430          "Should be greater or less!");
9431 
9432   // Check that AR does not wrap.
9433   if (ICmpInst::isUnsigned(Pred)) {
9434     if (!LHS->hasNoUnsignedWrap())
9435       return None;
9436     return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
9437   } else {
9438     assert(ICmpInst::isSigned(Pred) &&
9439            "Relational predicate is either signed or unsigned!");
9440     if (!LHS->hasNoSignedWrap())
9441       return None;
9442 
9443     const SCEV *Step = LHS->getStepRecurrence(*this);
9444 
9445     if (isKnownNonNegative(Step))
9446       return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
9447 
9448     if (isKnownNonPositive(Step))
9449       return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
9450 
9451     return None;
9452   }
9453 }
9454 
9455 bool ScalarEvolution::isLoopInvariantPredicate(
9456     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
9457     ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS,
9458     const SCEV *&InvariantRHS) {
9459 
9460   // If there is a loop-invariant, force it into the RHS, otherwise bail out.
9461   if (!isLoopInvariant(RHS, L)) {
9462     if (!isLoopInvariant(LHS, L))
9463       return false;
9464 
9465     std::swap(LHS, RHS);
9466     Pred = ICmpInst::getSwappedPredicate(Pred);
9467   }
9468 
9469   const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
9470   if (!ArLHS || ArLHS->getLoop() != L)
9471     return false;
9472 
9473   auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
9474   if (!MonotonicType)
9475     return false;
9476   // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
9477   // true as the loop iterates, and the backedge is control dependent on
9478   // "ArLHS `Pred` RHS" == true then we can reason as follows:
9479   //
9480   //   * if the predicate was false in the first iteration then the predicate
9481   //     is never evaluated again, since the loop exits without taking the
9482   //     backedge.
9483   //   * if the predicate was true in the first iteration then it will
9484   //     continue to be true for all future iterations since it is
9485   //     monotonically increasing.
9486   //
9487   // For both the above possibilities, we can replace the loop varying
9488   // predicate with its value on the first iteration of the loop (which is
9489   // loop invariant).
9490   //
9491   // A similar reasoning applies for a monotonically decreasing predicate, by
9492   // replacing true with false and false with true in the above two bullets.
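       //
       // Concrete instance (illustrative): for ArLHS = {0,+,1}<nuw> and
       // Pred = u>=, the predicate "{0,+,1} u>= N" can only switch from false
       // to true as the IV grows.  If the backedge is taken only while it is
       // true, then on every executed iteration it has the same value as on
       // the first iteration, namely the loop-invariant "0 u>= N".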
9493   bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
9494   auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
9495 
9496   if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
9497     return false;
9498 
9499   InvariantPred = Pred;
9500   InvariantLHS = ArLHS->getStart();
9501   InvariantRHS = RHS;
9502   return true;
9503 }
9504 
9505 bool ScalarEvolution::isLoopInvariantExitCondDuringFirstIterations(
9506     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
9507     const Instruction *Context, const SCEV *MaxIter,
9508     ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS,
9509     const SCEV *&InvariantRHS) {
9510   // Try to prove the following set of facts:
9511   // - The predicate is monotonic.
9512   // - If the check does not fail on the 1st iteration:
9513   //   - No overflow will happen during first MaxIter iterations;
9514   //   - It will not fail on the MaxIter'th iteration.
9515   // If the check does fail on the 1st iteration, we leave the loop and no
9516   // other checks matter.
9517 
9518   // If there is a loop-invariant, force it into the RHS, otherwise bail out.
9519   if (!isLoopInvariant(RHS, L)) {
9520     if (!isLoopInvariant(LHS, L))
9521       return false;
9522 
9523     std::swap(LHS, RHS);
9524     Pred = ICmpInst::getSwappedPredicate(Pred);
9525   }
9526 
9527   auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9528   // TODO: Lift affinity limitation in the future.
9529   if (!AR || AR->getLoop() != L || !AR->isAffine())
9530     return false;
9531 
9532   // The predicate must be relational (i.e. <, <=, >=, >).
9533   if (!ICmpInst::isRelational(Pred))
9534     return false;
9535 
9536   // TODO: Support steps other than +/- 1.
9537   const SCEV *Step = AR->getOperand(1);
9538   auto *One = getOne(Step->getType());
9539   auto *MinusOne = getNegativeSCEV(One);
9540   if (Step != One && Step != MinusOne)
9541     return false;
9542 
9543   // A type mismatch here means that MaxIter is potentially larger than the
9544   // max unsigned value in the start type, which means we cannot prove no
9545   // wrap for the indvar.
9546   if (AR->getType() != MaxIter->getType())
9547     return false;
9548 
9549   // Value of IV on suggested last iteration.
9550   const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
9551   // Does it still meet the requirement?
9552   if (!isKnownPredicateAt(Pred, Last, RHS, Context))
9553     return false;
9554   // Because the step is +/- 1 and MaxIter has the same type as Start (i.e. it
9555   // does not exceed the max unsigned value of this type), this effectively
9556   // proves that there is no wrap during the iteration. To prove that there is
9557   // no signed/unsigned wrap, we need to check that
9558   // Start <= Last for step = 1 or Start >= Last for step = -1.
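       //
       // Illustrative example (hypothetical values): for AR = {5,+,1},
       // RHS = 100, Pred = s< and MaxIter = 20, Last evaluates to 25.  Proving
       // "25 s< 100" and "5 s<= 25" shows the IV walks from 5 up to at most 25
       // without wrapping, so during the first MaxIter iterations the varying
       // check "AR s< 100" can be replaced by the invariant "5 s< 100".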
9559   ICmpInst::Predicate NoOverflowPred =
9560       CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
9561   if (Step == MinusOne)
9562     NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
9563   const SCEV *Start = AR->getStart();
9564   if (!isKnownPredicateAt(NoOverflowPred, Start, Last, Context))
9565     return false;
9566 
9567   // Everything is fine.
9568   InvariantPred = Pred;
9569   InvariantLHS = Start;
9570   InvariantRHS = RHS;
9571   return true;
9572 }
9573 
9574 bool ScalarEvolution::isKnownPredicateViaConstantRanges(
9575     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
9576   if (HasSameValue(LHS, RHS))
9577     return ICmpInst::isTrueWhenEqual(Pred);
9578 
9579   // This code is split out from isKnownPredicate because it is called from
9580   // within isLoopEntryGuardedByCond.
9581 
9582   auto CheckRanges =
9583       [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) {
9584     return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS)
9585         .contains(RangeLHS);
9586   };
9587 
9588   // The check at the top of the function catches the case where the values are
9589   // known to be equal.
9590   if (Pred == CmpInst::ICMP_EQ)
9591     return false;
9592 
9593   if (Pred == CmpInst::ICMP_NE)
9594     return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) ||
9595            CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)) ||
9596            isKnownNonZero(getMinusSCEV(LHS, RHS));
9597 
9598   if (CmpInst::isSigned(Pred))
9599     return CheckRanges(getSignedRange(LHS), getSignedRange(RHS));
9600 
9601   return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS));
9602 }
9603 
9604 bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
9605                                                     const SCEV *LHS,
9606                                                     const SCEV *RHS) {
9607   // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
9608   // Return Y via OutY.
9609   auto MatchBinaryAddToConst =
9610       [this](const SCEV *Result, const SCEV *X, APInt &OutY,
9611              SCEV::NoWrapFlags ExpectedFlags) {
9612     const SCEV *NonConstOp, *ConstOp;
9613     SCEV::NoWrapFlags FlagsPresent;
9614 
9615     if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
9616         !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
9617       return false;
9618 
9619     OutY = cast<SCEVConstant>(ConstOp)->getAPInt();
9620     return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
9621   };
9622 
9623   APInt C;
9624 
9625   switch (Pred) {
9626   default:
9627     break;
9628 
9629   case ICmpInst::ICMP_SGE:
9630     std::swap(LHS, RHS);
9631     LLVM_FALLTHROUGH;
9632   case ICmpInst::ICMP_SLE:
9633     // X s<= (X + C)<nsw> if C >= 0
9634     if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
9635       return true;
9636 
9637     // (X + C)<nsw> s<= X if C <= 0
9638     if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
9639         !C.isStrictlyPositive())
9640       return true;
9641     break;
9642 
9643   case ICmpInst::ICMP_SGT:
9644     std::swap(LHS, RHS);
9645     LLVM_FALLTHROUGH;
9646   case ICmpInst::ICMP_SLT:
9647     // X s< (X + C)<nsw> if C > 0
9648     if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
9649         C.isStrictlyPositive())
9650       return true;
9651 
9652     // (X + C)<nsw> s< X if C < 0
9653     if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
9654       return true;
9655     break;
9656 
9657   case ICmpInst::ICMP_UGE:
9658     std::swap(LHS, RHS);
9659     LLVM_FALLTHROUGH;
9660   case ICmpInst::ICMP_ULE:
9661     // X u<= (X + C)<nuw> for any C
9662     if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW))
9663       return true;
9664     break;
9665 
9666   case ICmpInst::ICMP_UGT:
9667     std::swap(LHS, RHS);
9668     LLVM_FALLTHROUGH;
9669   case ICmpInst::ICMP_ULT:
9670     // X u< (X + C)<nuw> if C != 0
9671     if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW) && !C.isNullValue())
9672       return true;
9673     break;
9674   }
9675 
9676   return false;
9677 }
9678 
9679 bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
9680                                                    const SCEV *LHS,
9681                                                    const SCEV *RHS) {
9682   if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
9683     return false;
9684 
9685   // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
9686   // on the stack can result in exponential time complexity.
9687   SaveAndRestore<bool> Restore(ProvingSplitPredicate, true);
9688 
9689   // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
9690   //
9691   // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
9692   // isKnownPredicate.  isKnownPredicate is more powerful, but also more
9693   // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
9694   // interesting cases seen in practice.  We can consider "upgrading" L >= 0 to
9695   // use isKnownPredicate later if needed.
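       //
       // Worked instance (illustrative): if L is known to be 100, "I u< 100"
       // holds exactly when I is in [0, 100); any negative I would be a huge
       // unsigned value and fail the u< test.  So proving "I s>= 0" and
       // "I s< 100" together is enough to conclude "I u< 100".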
9696   return isKnownNonNegative(RHS) &&
9697          isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
9698          isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
9699 }
9700 
9701 bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
9702                                         ICmpInst::Predicate Pred,
9703                                         const SCEV *LHS, const SCEV *RHS) {
9704   // No need to even try if we know the module has no guards.
9705   if (!HasGuards)
9706     return false;
9707 
9708   return any_of(*BB, [&](const Instruction &I) {
9709     using namespace llvm::PatternMatch;
9710 
9711     Value *Condition;
9712     return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
9713                          m_Value(Condition))) &&
9714            isImpliedCond(Pred, LHS, RHS, Condition, false);
9715   });
9716 }
9717 
9718 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
9719 /// protected by a conditional between LHS and RHS.  This is used to
9720 /// eliminate casts.
9721 bool
9722 ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
9723                                              ICmpInst::Predicate Pred,
9724                                              const SCEV *LHS, const SCEV *RHS) {
9725   // Interpret a null as meaning no loop, where there is obviously no guard
9726   // (interprocedural conditions notwithstanding).
9727   if (!L) return true;
9728 
9729   if (VerifyIR)
9730     assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
9731            "This cannot be done on broken IR!");
9732 
9733 
9734   if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
9735     return true;
9736 
9737   BasicBlock *Latch = L->getLoopLatch();
9738   if (!Latch)
9739     return false;
9740 
9741   BranchInst *LoopContinuePredicate =
9742     dyn_cast<BranchInst>(Latch->getTerminator());
9743   if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
9744       isImpliedCond(Pred, LHS, RHS,
9745                     LoopContinuePredicate->getCondition(),
9746                     LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
9747     return true;
9748 
9749   // We don't want more than one activation of the following loops on the stack
9750   // -- that can lead to O(n!) time complexity.
9751   if (WalkingBEDominatingConds)
9752     return false;
9753 
9754   SaveAndRestore<bool> ClearOnExit(WalkingBEDominatingConds, true);
9755 
9756   // See if we can exploit a trip count to prove the predicate.
9757   const auto &BETakenInfo = getBackedgeTakenInfo(L);
9758   const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
9759   if (LatchBECount != getCouldNotCompute()) {
9760     // We know that Latch branches back to the loop header exactly
9761     // LatchBECount times.  This means the backedge condition at Latch is
9762     // equivalent to "{0,+,1} u< LatchBECount".
9763     Type *Ty = LatchBECount->getType();
9764     auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
9765     const SCEV *LoopCounter =
9766       getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
9767     if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
9768                       LatchBECount))
9769       return true;
9770   }
9771 
9772   // Check conditions due to any @llvm.assume intrinsics.
9773   for (auto &AssumeVH : AC.assumptions()) {
9774     if (!AssumeVH)
9775       continue;
9776     auto *CI = cast<CallInst>(AssumeVH);
9777     if (!DT.dominates(CI, Latch->getTerminator()))
9778       continue;
9779 
9780     if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
9781       return true;
9782   }
9783 
9784   // If the loop is not reachable from the entry block, we risk running into an
9785   // infinite loop as we walk up into the dom tree.  These loops do not matter
9786   // anyway, so we just return a conservative answer when we see them.
9787   if (!DT.isReachableFromEntry(L->getHeader()))
9788     return false;
9789 
9790   if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
9791     return true;
9792 
9793   for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
9794        DTN != HeaderDTN; DTN = DTN->getIDom()) {
9795     assert(DTN && "should reach the loop header before reaching the root!");
9796 
9797     BasicBlock *BB = DTN->getBlock();
9798     if (isImpliedViaGuard(BB, Pred, LHS, RHS))
9799       return true;
9800 
9801     BasicBlock *PBB = BB->getSinglePredecessor();
9802     if (!PBB)
9803       continue;
9804 
9805     BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
9806     if (!ContinuePredicate || !ContinuePredicate->isConditional())
9807       continue;
9808 
9809     Value *Condition = ContinuePredicate->getCondition();
9810 
9811     // If we have an edge `E` within the loop body that dominates the only
9812     // latch, the condition guarding `E` also guards the backedge.  This
9813     // reasoning works only for loops with a single latch.
9814 
9815     BasicBlockEdge DominatingEdge(PBB, BB);
9816     if (DominatingEdge.isSingleEdge()) {
9817       // We're constructively (and conservatively) enumerating edges within the
9818       // loop body that dominate the latch.  The dominator tree better agree
9819       // with us on this:
9820       assert(DT.dominates(DominatingEdge, Latch) && "should be!");
9821 
9822       if (isImpliedCond(Pred, LHS, RHS, Condition,
9823                         BB != ContinuePredicate->getSuccessor(0)))
9824         return true;
9825     }
9826   }
9827 
9828   return false;
9829 }
9830 
9831 bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
9832                                                      ICmpInst::Predicate Pred,
9833                                                      const SCEV *LHS,
9834                                                      const SCEV *RHS) {
9835   if (VerifyIR)
9836     assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
9837            "This cannot be done on broken IR!");
9838 
9839   if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
9840     return true;
9841 
9842   // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
9843   // the facts (a >= b && a != b) separately. A typical situation is when the
9844   // non-strict comparison is known from ranges and non-equality is known from
9845   // dominating predicates. If we are proving strict comparison, we always try
9846   // to prove non-equality and non-strict comparison separately.
9847   auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
9848   const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
9849   bool ProvedNonStrictComparison = false;
9850   bool ProvedNonEquality = false;
9851 
9852   if (ProvingStrictComparison) {
9853     ProvedNonStrictComparison =
9854         isKnownViaNonRecursiveReasoning(NonStrictPredicate, LHS, RHS);
9855     ProvedNonEquality =
9856         isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, LHS, RHS);
9857     if (ProvedNonStrictComparison && ProvedNonEquality)
9858       return true;
9859   }
9860 
9861   // Try to prove (Pred, LHS, RHS) using isImpliedViaGuard.
9862   auto ProveViaGuard = [&](const BasicBlock *Block) {
9863     if (isImpliedViaGuard(Block, Pred, LHS, RHS))
9864       return true;
9865     if (ProvingStrictComparison) {
9866       if (!ProvedNonStrictComparison)
9867         ProvedNonStrictComparison =
9868             isImpliedViaGuard(Block, NonStrictPredicate, LHS, RHS);
9869       if (!ProvedNonEquality)
9870         ProvedNonEquality =
9871             isImpliedViaGuard(Block, ICmpInst::ICMP_NE, LHS, RHS);
9872       if (ProvedNonStrictComparison && ProvedNonEquality)
9873         return true;
9874     }
9875     return false;
9876   };
9877 
9878   // Try to prove (Pred, LHS, RHS) using isImpliedCond.
9879   auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
9880     const Instruction *Context = &BB->front();
9881     if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, Context))
9882       return true;
9883     if (ProvingStrictComparison) {
9884       if (!ProvedNonStrictComparison)
9885         ProvedNonStrictComparison = isImpliedCond(NonStrictPredicate, LHS, RHS,
9886                                                   Condition, Inverse, Context);
9887       if (!ProvedNonEquality)
9888         ProvedNonEquality = isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS,
9889                                           Condition, Inverse, Context);
9890       if (ProvedNonStrictComparison && ProvedNonEquality)
9891         return true;
9892     }
9893     return false;
9894   };
9895 
9896   // Starting at the block's predecessor, climb up the predecessor chain for as
9897   // long as we can find predecessors that have unique successors leading to the
9898   // original block.
9899   const Loop *ContainingLoop = LI.getLoopFor(BB);
9900   const BasicBlock *PredBB;
9901   if (ContainingLoop && ContainingLoop->getHeader() == BB)
9902     PredBB = ContainingLoop->getLoopPredecessor();
9903   else
9904     PredBB = BB->getSinglePredecessor();
9905   for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
9906        Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
9907     if (ProveViaGuard(Pair.first))
9908       return true;
9909 
9910     const BranchInst *LoopEntryPredicate =
9911         dyn_cast<BranchInst>(Pair.first->getTerminator());
9912     if (!LoopEntryPredicate ||
9913         LoopEntryPredicate->isUnconditional())
9914       continue;
9915 
9916     if (ProveViaCond(LoopEntryPredicate->getCondition(),
9917                      LoopEntryPredicate->getSuccessor(0) != Pair.second))
9918       return true;
9919   }
9920 
9921   // Check conditions due to any @llvm.assume intrinsics.
9922   for (auto &AssumeVH : AC.assumptions()) {
9923     if (!AssumeVH)
9924       continue;
9925     auto *CI = cast<CallInst>(AssumeVH);
9926     if (!DT.dominates(CI, BB))
9927       continue;
9928 
9929     if (ProveViaCond(CI->getArgOperand(0), false))
9930       return true;
9931   }
9932 
9933   return false;
9934 }
9935 
9936 bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
9937                                                ICmpInst::Predicate Pred,
9938                                                const SCEV *LHS,
9939                                                const SCEV *RHS) {
9940   // Interpret a null as meaning no loop, where there is obviously no guard
9941   // (interprocedural conditions notwithstanding).
9942   if (!L)
9943     return false;
9944 
9945   // Both LHS and RHS must be available at loop entry.
9946   assert(isAvailableAtLoopEntry(LHS, L) &&
9947          "LHS is not available at Loop Entry");
9948   assert(isAvailableAtLoopEntry(RHS, L) &&
9949          "RHS is not available at Loop Entry");
9950   return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
9951 }
9952 
9953 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
9954                                     const SCEV *RHS,
9955                                     const Value *FoundCondValue, bool Inverse,
9956                                     const Instruction *Context) {
9957   if (!PendingLoopPredicates.insert(FoundCondValue).second)
9958     return false;
9959 
9960   auto ClearOnExit =
9961       make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
9962 
9963   // Recursively handle And and Or conditions.
9964   if (const BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
9965     if (BO->getOpcode() == Instruction::And) {
9966       if (!Inverse)
9967         return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse,
9968                              Context) ||
9969                isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse,
9970                              Context);
9971     } else if (BO->getOpcode() == Instruction::Or) {
9972       if (Inverse)
9973         return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse,
9974                              Context) ||
9975                isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse,
9976                              Context);
9977     }
9978   }
9979 
9980   const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
9981   if (!ICI) return false;
9982 
9983   // We have found a conditional branch that dominates the loop or controls the
9984   // loop latch. Check to see if it is the comparison we are looking for.
9985   ICmpInst::Predicate FoundPred;
9986   if (Inverse)
9987     FoundPred = ICI->getInversePredicate();
9988   else
9989     FoundPred = ICI->getPredicate();
9990 
9991   const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
9992   const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
9993 
9994   return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, Context);
9995 }
9996 
9997 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
9998                                     const SCEV *RHS,
9999                                     ICmpInst::Predicate FoundPred,
10000                                     const SCEV *FoundLHS, const SCEV *FoundRHS,
10001                                     const Instruction *Context) {
10002   // Balance the types.
10003   if (getTypeSizeInBits(LHS->getType()) <
10004       getTypeSizeInBits(FoundLHS->getType())) {
10005     // For unsigned and equality predicates, try to prove that both found
10006     // operands fit into a narrow unsigned range. If so, try to prove the
10007     // facts in the narrow type.
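          //
          // Illustrative example (hypothetical values): if the desired fact is
          // about i32 values and the found condition is
          // "(zext i32 %a to i64) u< (zext i32 %b to i64)", both found operands
          // are provably u<= UINT32_MAX, so the found fact can be truncated back
          // to i32 and the implication checked entirely in the narrow type.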
10008     if (!CmpInst::isSigned(FoundPred)) {
10009       auto *NarrowType = LHS->getType();
10010       auto *WideType = FoundLHS->getType();
10011       auto BitWidth = getTypeSizeInBits(NarrowType);
10012       const SCEV *MaxValue = getZeroExtendExpr(
10013           getConstant(APInt::getMaxValue(BitWidth)), WideType);
10014       if (isKnownPredicate(ICmpInst::ICMP_ULE, FoundLHS, MaxValue) &&
10015           isKnownPredicate(ICmpInst::ICMP_ULE, FoundRHS, MaxValue)) {
10016         const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
10017         const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
10018         if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
10019                                        TruncFoundRHS, Context))
10020           return true;
10021       }
10022     }
10023 
10024     if (CmpInst::isSigned(Pred)) {
10025       LHS = getSignExtendExpr(LHS, FoundLHS->getType());
10026       RHS = getSignExtendExpr(RHS, FoundLHS->getType());
10027     } else {
10028       LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
10029       RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
10030     }
10031   } else if (getTypeSizeInBits(LHS->getType()) >
10032       getTypeSizeInBits(FoundLHS->getType())) {
10033     if (CmpInst::isSigned(FoundPred)) {
10034       FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
10035       FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
10036     } else {
10037       FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
10038       FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
10039     }
10040   }
10041   return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
10042                                     FoundRHS, Context);
10043 }
10044 
10045 bool ScalarEvolution::isImpliedCondBalancedTypes(
10046     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
10047     ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
10048     const Instruction *Context) {
10049   assert(getTypeSizeInBits(LHS->getType()) ==
10050              getTypeSizeInBits(FoundLHS->getType()) &&
10051          "Types should be balanced!");
10052   // Canonicalize the query to match the way instcombine will have
10053   // canonicalized the comparison.
10054   if (SimplifyICmpOperands(Pred, LHS, RHS))
10055     if (LHS == RHS)
10056       return CmpInst::isTrueWhenEqual(Pred);
10057   if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
10058     if (FoundLHS == FoundRHS)
10059       return CmpInst::isFalseWhenEqual(FoundPred);
10060 
10061   // Check to see if we can make the LHS or RHS match.
10062   if (LHS == FoundRHS || RHS == FoundLHS) {
10063     if (isa<SCEVConstant>(RHS)) {
10064       std::swap(FoundLHS, FoundRHS);
10065       FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
10066     } else {
10067       std::swap(LHS, RHS);
10068       Pred = ICmpInst::getSwappedPredicate(Pred);
10069     }
10070   }
10071 
10072   // Check whether the found predicate is the same as the desired predicate.
10073   if (FoundPred == Pred)
10074     return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context);
10075 
10076   // Check whether swapping the found predicate makes it the same as the
10077   // desired predicate.
10078   if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
10079     if (isa<SCEVConstant>(RHS))
10080       return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, Context);
10081     else
10082       return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), RHS,
10083                                    LHS, FoundLHS, FoundRHS, Context);
10084   }
10085 
10086   // Unsigned comparison is the same as signed comparison when both the operands
10087   // are non-negative.
10088   if (CmpInst::isUnsigned(FoundPred) &&
10089       CmpInst::getSignedPredicate(FoundPred) == Pred &&
10090       isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS))
10091     return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context);
10092 
10093   // Check if we can make progress by sharpening ranges.
10094   if (FoundPred == ICmpInst::ICMP_NE &&
10095       (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
10096 
10097     const SCEVConstant *C = nullptr;
10098     const SCEV *V = nullptr;
10099 
10100     if (isa<SCEVConstant>(FoundLHS)) {
10101       C = cast<SCEVConstant>(FoundLHS);
10102       V = FoundRHS;
10103     } else {
10104       C = cast<SCEVConstant>(FoundRHS);
10105       V = FoundLHS;
10106     }
10107 
10108     // The guarding predicate tells us that C != V. If the known range
10109     // of V is [C, t), we can sharpen the range to [C + 1, t).  The
10110     // range we consider has to correspond to the same signedness as the
10111     // predicate we're interested in folding.
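          //
          // For example (illustrative): if the unsigned minimum of V is 3 and
          // the dominating condition established V != 3, then V u>= 4, which
          // may be exactly what is needed to prove an unsigned predicate
          // about V.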
10112 
10113     APInt Min = ICmpInst::isSigned(Pred) ?
10114         getSignedRangeMin(V) : getUnsignedRangeMin(V);
10115 
10116     if (Min == C->getAPInt()) {
10117       // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
10118       // This is true even if (Min + 1) wraps around -- in case of
10119       // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
10120 
10121       APInt SharperMin = Min + 1;
10122 
10123       switch (Pred) {
10124         case ICmpInst::ICMP_SGE:
10125         case ICmpInst::ICMP_UGE:
10126           // We know V `Pred` SharperMin.  If this implies LHS `Pred`
10127           // RHS, we're done.
10128           if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
10129                                     Context))
10130             return true;
10131           LLVM_FALLTHROUGH;
10132 
10133         case ICmpInst::ICMP_SGT:
10134         case ICmpInst::ICMP_UGT:
10135           // We know from the range information that (V `Pred` Min ||
10136           // V == Min).  We know from the guarding condition that !(V
10137           // == Min).  This gives us
10138           //
10139           //       V `Pred` Min || V == Min && !(V == Min)
10140           //   =>  V `Pred` Min
10141           //
10142           // If V `Pred` Min implies LHS `Pred` RHS, we're done.
10143 
10144           if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min),
10145                                     Context))
10146             return true;
10147           break;
10148 
10149         // `LHS < RHS` and `LHS <= RHS` are handled in the same way as
        // `RHS > LHS` and `RHS >= LHS` respectively.
10150         case ICmpInst::ICMP_SLE:
10151         case ICmpInst::ICMP_ULE:
10152           if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
10153                                     LHS, V, getConstant(SharperMin), Context))
10154             return true;
10155           LLVM_FALLTHROUGH;
10156 
10157         case ICmpInst::ICMP_SLT:
10158         case ICmpInst::ICMP_ULT:
10159           if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
10160                                     LHS, V, getConstant(Min), Context))
10161             return true;
10162           break;
10163 
10164         default:
10165           // No change
10166           break;
10167       }
10168     }
10169   }
10170 
10171   // Check whether the condition that actually holds is stronger than needed.
10172   if (FoundPred == ICmpInst::ICMP_EQ)
10173     if (ICmpInst::isTrueWhenEqual(Pred))
10174       if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context))
10175         return true;
10176   if (Pred == ICmpInst::ICMP_NE)
10177     if (!ICmpInst::isTrueWhenEqual(FoundPred))
10178       if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS,
10179                                 Context))
10180         return true;
10181 
10182   // Otherwise assume the worst.
10183   return false;
10184 }
10185 
10186 bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
10187                                      const SCEV *&L, const SCEV *&R,
10188                                      SCEV::NoWrapFlags &Flags) {
10189   const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
10190   if (!AE || AE->getNumOperands() != 2)
10191     return false;
10192 
10193   L = AE->getOperand(0);
10194   R = AE->getOperand(1);
10195   Flags = AE->getNoWrapFlags();
10196   return true;
10197 }
10198 
10199 Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
10200                                                            const SCEV *Less) {
10201   // We avoid subtracting expressions here because this function is usually
10202   // fairly deep in the call stack (i.e. is called many times).
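  // Illustrative examples (assuming the obvious SCEVs exist): the difference
  // of (%x + 3) and %x is 3, and the difference of {5,+,1}<L> and {2,+,1}<L>
  // over the same loop is 3, since matching affine addrecs are reduced to
  // their starts below.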
10203 
10204   // X - X = 0.
10205   if (More == Less)
10206     return APInt(getTypeSizeInBits(More->getType()), 0);
10207 
10208   if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
10209     const auto *LAR = cast<SCEVAddRecExpr>(Less);
10210     const auto *MAR = cast<SCEVAddRecExpr>(More);
10211 
10212     if (LAR->getLoop() != MAR->getLoop())
10213       return None;
10214 
10215     // We look at affine expressions only; not for correctness but to keep
10216     // getStepRecurrence cheap.
10217     if (!LAR->isAffine() || !MAR->isAffine())
10218       return None;
10219 
10220     if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
10221       return None;
10222 
10223     Less = LAR->getStart();
10224     More = MAR->getStart();
10225 
10226     // fall through
10227   }
10228 
10229   if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
10230     const auto &M = cast<SCEVConstant>(More)->getAPInt();
10231     const auto &L = cast<SCEVConstant>(Less)->getAPInt();
10232     return M - L;
10233   }
10234 
10235   SCEV::NoWrapFlags Flags;
10236   const SCEV *LLess = nullptr, *RLess = nullptr;
10237   const SCEV *LMore = nullptr, *RMore = nullptr;
10238   const SCEVConstant *C1 = nullptr, *C2 = nullptr;
10239   // Compare (X + C1) vs X.
10240   if (splitBinaryAdd(Less, LLess, RLess, Flags))
10241     if ((C1 = dyn_cast<SCEVConstant>(LLess)))
10242       if (RLess == More)
10243         return -(C1->getAPInt());
10244 
10245   // Compare X vs (X + C2).
10246   if (splitBinaryAdd(More, LMore, RMore, Flags))
10247     if ((C2 = dyn_cast<SCEVConstant>(LMore)))
10248       if (RMore == Less)
10249         return C2->getAPInt();
10250 
10251   // Compare (X + C1) vs (X + C2).
10252   if (C1 && C2 && RLess == RMore)
10253     return C2->getAPInt() - C1->getAPInt();
10254 
10255   return None;
10256 }
10257 
10258 bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
10259     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
10260     const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *Context) {
10261   // Try to recognize the following pattern:
10262   //
10263   //   FoundRHS = ...
10264   // ...
10265   // loop:
10266   //   FoundLHS = {Start,+,W}
10267   // context_bb: // Basic block from the same loop
10268   //   known(Pred, FoundLHS, FoundRHS)
10269   //
10270   // If some predicate is known in the context of a loop, it is also known on
10271   // each iteration of this loop, including the first iteration. Therefore, in
10272   // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
10273   // prove the original pred using this fact.
10274   if (!Context)
10275     return false;
10276   const BasicBlock *ContextBB = Context->getParent();
10277   // Make sure AR varies in the context block.
10278   if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
10279     const Loop *L = AR->getLoop();
10280     // Make sure that context belongs to the loop and executes on 1st iteration
10281     // (if it ever executes at all).
10282     if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
10283       return false;
10284     if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
10285       return false;
10286     return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
10287   }
10288 
10289   if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
10290     const Loop *L = AR->getLoop();
10291     // Make sure that context belongs to the loop and executes on 1st iteration
10292     // (if it ever executes at all).
10293     if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
10294       return false;
10295     if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
10296       return false;
10297     return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
10298   }
10299 
10300   return false;
10301 }
10302 
10303 bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
10304     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
10305     const SCEV *FoundLHS, const SCEV *FoundRHS) {
10306   if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
10307     return false;
10308 
10309   const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
10310   if (!AddRecLHS)
10311     return false;
10312 
10313   const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
10314   if (!AddRecFoundLHS)
10315     return false;
10316 
10317   // We'd like to let SCEV reason about control dependencies, so we constrain
10318   // both the inequalities to be about add recurrences on the same loop.  This
10319   // way we can use isLoopEntryGuardedByCond later.
10320 
10321   const Loop *L = AddRecFoundLHS->getLoop();
10322   if (L != AddRecLHS->getLoop())
10323     return false;
10324 
10325   //  FoundLHS u< FoundRHS u< -C =>  (FoundLHS + C) u< (FoundRHS + C) ... (1)
10326   //
10327   //  FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
10328   //                                                                  ... (2)
10329   //
10330   // Informal proof for (2), assuming (1) [*]:
10331   //
10332   // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
10333   //
10334   // Then
10335   //
10336   //       FoundLHS s< FoundRHS s< INT_MIN - C
10337   // <=>  (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C   [ using (3) ]
10338   // <=>  (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
10339   // <=>  (FoundLHS + INT_MIN + C + INT_MIN) s<
10340   //                        (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
10341   // <=>  FoundLHS + C s< FoundRHS + C
10342   //
10343   // [*]: (1) can be proved by ruling out overflow.
10344   //
10345   // [**]: This can be proved by analyzing all the four possibilities:
10346   //    (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
10347   //    (A s>= 0, B s>= 0).
10348   //
10349   // Note:
10350   // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
10351   // will not sign underflow.  For instance, say FoundLHS = (i8 -128), FoundRHS
10352   // = (i8 -127) and C = (i8 -100).  Then INT_MIN - C = (i8 -28), and FoundRHS
10353   // s< (INT_MIN - C).  Lack of sign overflow / underflow in "FoundRHS + C" is
10354   // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
10355   // C)".
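  // As an illustration of how the checks below are used: to prove
  // LHS u< RHS where LHS = FoundLHS + 2 and RHS = FoundRHS + 2, both constant
  // differences equal 2, and by rule (1) the problem reduces to showing
  // FoundRHS u< -2 at loop entry.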
10356 
10357   Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
10358   Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
10359   if (!LDiff || !RDiff || *LDiff != *RDiff)
10360     return false;
10361 
10362   if (LDiff->isMinValue())
10363     return true;
10364 
10365   APInt FoundRHSLimit;
10366 
10367   if (Pred == CmpInst::ICMP_ULT) {
10368     FoundRHSLimit = -(*RDiff);
10369   } else {
10370     assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
10371     FoundRHSLimit =
        APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
10372   }
10373 
10374   // Try to prove (1) or (2), as needed.
10375   return isAvailableAtLoopEntry(FoundRHS, L) &&
10376          isLoopEntryGuardedByCond(L, Pred, FoundRHS,
10377                                   getConstant(FoundRHSLimit));
10378 }
10379 
10380 bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
10381                                         const SCEV *LHS, const SCEV *RHS,
10382                                         const SCEV *FoundLHS,
10383                                         const SCEV *FoundRHS, unsigned Depth) {
10384   const PHINode *LPhi = nullptr, *RPhi = nullptr;
10385 
10386   auto ClearOnExit = make_scope_exit([&]() {
10387     if (LPhi) {
10388       bool Erased = PendingMerges.erase(LPhi);
10389       assert(Erased && "Failed to erase LPhi!");
10390       (void)Erased;
10391     }
10392     if (RPhi) {
10393       bool Erased = PendingMerges.erase(RPhi);
10394       assert(Erased && "Failed to erase RPhi!");
10395       (void)Erased;
10396     }
10397   });
10398 
10399   // Find the respective Phis and check that they are not already pending.
10400   if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
10401     if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
10402       if (!PendingMerges.insert(Phi).second)
10403         return false;
10404       LPhi = Phi;
10405     }
10406   if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
10407     if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
10408       // If we detect a loop of Phi nodes being processed by this method, for
10409       // example:
10410       //
10411       //   %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
10412       //   %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
10413       //
10414       // we don't want to deal with a case that complex, so return conservative
10415       // answer false.
10416       if (!PendingMerges.insert(Phi).second)
10417         return false;
10418       RPhi = Phi;
10419     }
10420 
10421   // If neither LHS nor RHS is a Phi, there is nothing to do here.
10422   if (!LPhi && !RPhi)
10423     return false;
10424 
10425   // If there is a SCEVUnknown Phi we are interested in, make it left.
10426   if (!LPhi) {
10427     std::swap(LHS, RHS);
10428     std::swap(FoundLHS, FoundRHS);
10429     std::swap(LPhi, RPhi);
10430     Pred = ICmpInst::getSwappedPredicate(Pred);
10431   }
10432 
10433   assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
10434   const BasicBlock *LBB = LPhi->getParent();
10435   const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
10436 
10437   auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
10438     return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
10439            isImpliedCondOperandsViaRanges(Pred, S1, S2, FoundLHS, FoundRHS) ||
10440            isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
10441   };
10442 
10443   if (RPhi && RPhi->getParent() == LBB) {
10444     // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
10445     // If we compare two Phis from the same block, and for each incoming block
10446     // the predicate is true for the incoming values from that block, then the
10447     // predicate is also true for the Phis.
10448     for (const BasicBlock *IncBB : predecessors(LBB)) {
10449       const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
10450       const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
10451       if (!ProvedEasily(L, R))
10452         return false;
10453     }
10454   } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
10455     // Case two: RHS is an AddRec whose loop's header is LBB. This means that
10456     // the loop has both an AddRec and an Unknown PHI, so we can compare the
10457     // incoming values of the AddRec from above the loop and from the latch
10458     // with the respective incoming values of LPhi.
10459     // TODO: Generalize to handle loops with many inputs in a header.
10460     if (LPhi->getNumIncomingValues() != 2) return false;
10461 
10462     auto *RLoop = RAR->getLoop();
10463     auto *Predecessor = RLoop->getLoopPredecessor();
10464     assert(Predecessor && "Loop with AddRec with no predecessor?");
10465     const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
10466     if (!ProvedEasily(L1, RAR->getStart()))
10467       return false;
10468     auto *Latch = RLoop->getLoopLatch();
10469     assert(Latch && "Loop with AddRec with no latch?");
10470     const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
10471     if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
10472       return false;
10473   } else {
10474     // In all other cases go over the inputs of LHS and compare each to RHS;
10475     // the predicate is true for (LHS, RHS) if it is true for all such pairs.
10476     // At this point RHS is either a non-Phi, or it is a Phi from some block
10477     // different from LBB.
10478     for (const BasicBlock *IncBB : predecessors(LBB)) {
10479       // Check that RHS is available in this block.
10480       if (!dominates(RHS, IncBB))
10481         return false;
10482       const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
10483       if (!ProvedEasily(L, RHS))
10484         return false;
10485     }
10486   }
10487   return true;
10488 }
10489 
10490 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
10491                                             const SCEV *LHS, const SCEV *RHS,
10492                                             const SCEV *FoundLHS,
10493                                             const SCEV *FoundRHS,
10494                                             const Instruction *Context) {
10495   if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
10496     return true;
10497 
10498   if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
10499     return true;
10500 
10501   if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
10502                                           Context))
10503     return true;
10504 
10505   return isImpliedCondOperandsHelper(Pred, LHS, RHS,
10506                                      FoundLHS, FoundRHS) ||
10507          // ~x < ~y --> x > y
10508          isImpliedCondOperandsHelper(Pred, LHS, RHS,
10509                                      getNotSCEV(FoundRHS),
10510                                      getNotSCEV(FoundLHS));
10511 }
10512 
10513 /// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
10514 template <typename MinMaxExprType>
10515 static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
10516                                  const SCEV *Candidate) {
10517   const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
10518   if (!MinMaxExpr)
10519     return false;
10520 
10521   return find(MinMaxExpr->operands(), Candidate) != MinMaxExpr->op_end();
10522 }
10523 
10524 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
10525                                            ICmpInst::Predicate Pred,
10526                                            const SCEV *LHS, const SCEV *RHS) {
10527   // If both sides are affine addrecs for the same loop, with equal
10528   // steps, and we know the recurrences don't wrap, then we only
10529   // need to check the predicate on the starting values.
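  // For example (illustrative): if LHS = {A,+,S}<nsw> and RHS = {B,+,S}<nsw>
  // over the same loop, then knowing A s< B is enough to conclude
  // LHS s< RHS on every iteration.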
10530 
10531   if (!ICmpInst::isRelational(Pred))
10532     return false;
10533 
10534   const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
10535   if (!LAR)
10536     return false;
10537   const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
10538   if (!RAR)
10539     return false;
10540   if (LAR->getLoop() != RAR->getLoop())
10541     return false;
10542   if (!LAR->isAffine() || !RAR->isAffine())
10543     return false;
10544 
10545   if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
10546     return false;
10547 
10548   SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
10549                          SCEV::FlagNSW : SCEV::FlagNUW;
10550   if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
10551     return false;
10552 
10553   return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
10554 }
10555 
10556 /// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
10557 /// expression?
10558 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
10559                                         ICmpInst::Predicate Pred,
10560                                         const SCEV *LHS, const SCEV *RHS) {
10561   switch (Pred) {
10562   default:
10563     return false;
10564 
10565   case ICmpInst::ICMP_SGE:
10566     std::swap(LHS, RHS);
10567     LLVM_FALLTHROUGH;
10568   case ICmpInst::ICMP_SLE:
10569     return
10570         // min(A, ...) <= A
10571         IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
10572         // A <= max(A, ...)
10573         IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
10574 
10575   case ICmpInst::ICMP_UGE:
10576     std::swap(LHS, RHS);
10577     LLVM_FALLTHROUGH;
10578   case ICmpInst::ICMP_ULE:
10579     return
10580         // min(A, ...) <= A
10581         IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
10582         // A <= max(A, ...)
10583         IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
10584   }
10585 
10586   llvm_unreachable("covered switch fell through?!");
10587 }
10588 
10589 bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
10590                                              const SCEV *LHS, const SCEV *RHS,
10591                                              const SCEV *FoundLHS,
10592                                              const SCEV *FoundRHS,
10593                                              unsigned Depth) {
10594   assert(getTypeSizeInBits(LHS->getType()) ==
10595              getTypeSizeInBits(RHS->getType()) &&
10596          "LHS and RHS have different sizes?");
10597   assert(getTypeSizeInBits(FoundLHS->getType()) ==
10598              getTypeSizeInBits(FoundRHS->getType()) &&
10599          "FoundLHS and FoundRHS have different sizes?");
10600   // We want to avoid hurting compile time by analyzing overly deep trees.
10601   if (Depth > MaxSCEVOperationsImplicationDepth)
10602     return false;
10603 
10604   // So far we only want to work with GT comparisons.
10605   if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
10606     Pred = CmpInst::getSwappedPredicate(Pred);
10607     std::swap(LHS, RHS);
10608     std::swap(FoundLHS, FoundRHS);
10609   }
10610 
10611   // For unsigned, try to reduce it to the corresponding signed comparison.
10612   if (Pred == ICmpInst::ICMP_UGT)
10613     // We can replace unsigned predicate with its signed counterpart if all
10614     // involved values are non-negative.
10615     // TODO: We could have better support for unsigned.
10616     if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
10617       // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
10618       // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
10619       // use this fact to prove that LHS and RHS are non-negative.
10620       const SCEV *MinusOne = getMinusOne(LHS->getType());
10621       if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
10622                                 FoundRHS) &&
10623           isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
10624                                 FoundRHS))
10625         Pred = ICmpInst::ICMP_SGT;
10626     }
10627 
10628   if (Pred != ICmpInst::ICMP_SGT)
10629     return false;
10630 
10631   auto GetOpFromSExt = [&](const SCEV *S) {
10632     if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
10633       return Ext->getOperand();
10634     // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
10635     // the constant in some cases.
10636     return S;
10637   };
10638 
10639   // Acquire values from extensions.
10640   auto *OrigLHS = LHS;
10641   auto *OrigFoundLHS = FoundLHS;
10642   LHS = GetOpFromSExt(LHS);
10643   FoundLHS = GetOpFromSExt(FoundLHS);
10644 
10645   // Can the SGT predicate be proved trivially or using the found context?
10646   auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
10647     return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
10648            isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
10649                                   FoundRHS, Depth + 1);
10650   };
10651 
10652   if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
10653     // We want to avoid creation of any new non-constant SCEV. Since we are
10654     // going to compare the operands to RHS, we should be certain that we don't
10655     // need any size extensions for this. So let's decline all cases when the
10656     // sizes of types of LHS and RHS do not match.
10657     // TODO: Maybe try to get RHS from sext to catch more cases?
10658     if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
10659       return false;
10660 
10661     // Should not overflow.
10662     if (!LHSAddExpr->hasNoSignedWrap())
10663       return false;
10664 
10665     auto *LL = LHSAddExpr->getOperand(0);
10666     auto *LR = LHSAddExpr->getOperand(1);
10667     auto *MinusOne = getMinusOne(RHS->getType());
10668 
10669     // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
10670     auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
10671       return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
10672     };
10673     // Try to prove the following rule:
10674     // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
10675     // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
10676     if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
10677       return true;
10678   } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
10679     Value *LL, *LR;
10680     // FIXME: Once we have SDiv implemented, we can get rid of this matching.
10681 
10682     using namespace llvm::PatternMatch;
10683 
10684     if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
10685       // Rules for division.
10686       // We are going to perform some comparisons with Denominator and its
10687       // derivative expressions. In general, creating a SCEV for it may lead
10688       // to a complex analysis of the entire graph, and in particular it can
10689       // request trip count recalculation for the same loop. This would be
10690       // cached as SCEVCouldNotCompute to avoid infinite recursion. To avoid
10691       // this, we only want to create SCEVs that are constants in this section.
10692       // So we bail if Denominator is not a constant.
10693       if (!isa<ConstantInt>(LR))
10694         return false;
10695 
10696       auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
10697 
10698       // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
10699       // then a SCEV for the numerator already exists and matches with FoundLHS.
10700       auto *Numerator = getExistingSCEV(LL);
10701       if (!Numerator || Numerator->getType() != FoundLHS->getType())
10702         return false;
10703 
10704       // Make sure that the numerator matches with FoundLHS and the denominator
10705       // is positive.
10706       if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
10707         return false;
10708 
10709       auto *DTy = Denominator->getType();
10710       auto *FRHSTy = FoundRHS->getType();
10711       if (DTy->isPointerTy() != FRHSTy->isPointerTy())
10712         // One of the types is a pointer and the other is not. We cannot extend
10713         // them properly to a wider type, so let us just reject this case.
10714         // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
10715         // to avoid this check.
10716         return false;
10717 
10718       // Given that:
10719       // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
10720       auto *WTy = getWiderType(DTy, FRHSTy);
10721       auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
10722       auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
10723 
10724       // Try to prove the following rule:
10725       // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
10726       // For example, suppose that FoundLHS > 2; then FoundLHS is at
10727       // least 3. If we divide it by Denominator < 4, we will have at least 1.
10728       auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
10729       if (isKnownNonPositive(RHS) &&
10730           IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
10731         return true;
10732 
10733       // Try to prove the following rule:
10734       // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
10735       // For example, suppose FoundLHS > -3; then FoundLHS is at least -2.
10736       // If we divide it by Denominator > 2, then:
10737       // 1. If FoundLHS is negative, then the result is 0.
10738       // 2. If FoundLHS is non-negative, then the result is non-negative.
10739       // Either way, the result is non-negative.
10740       auto *MinusOne = getMinusOne(WTy);
10741       auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
10742       if (isKnownNegative(RHS) &&
10743           IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
10744         return true;
10745     }
10746   }
10747 
10748   // If our expression contained SCEVUnknown Phis, and we split it down and now
10749   // need to prove something for them, try to prove the predicate for all
10750   // possible incoming values of those Phis.
10751   if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
10752     return true;
10753 
10754   return false;
10755 }
10756 
10757 static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
10758                                         const SCEV *LHS, const SCEV *RHS) {
10759   // zext x u<= sext x, sext x s<= zext x
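  // Illustrative example: for i8 x = -1, sext to i16 gives 0xFFFF while zext
  // gives 0x00FF, so zext x u<= sext x (255 u<= 65535) and sext x s<= zext x
  // (-1 s<= 255); for non-negative x the two extensions are equal.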
10760   switch (Pred) {
10761   case ICmpInst::ICMP_SGE:
10762     std::swap(LHS, RHS);
10763     LLVM_FALLTHROUGH;
10764   case ICmpInst::ICMP_SLE: {
10765     // If operand >=s 0 then ZExt == SExt.  If operand <s 0 then SExt <s ZExt.
10766     const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
10767     const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
10768     if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
10769       return true;
10770     break;
10771   }
10772   case ICmpInst::ICMP_UGE:
10773     std::swap(LHS, RHS);
10774     LLVM_FALLTHROUGH;
10775   case ICmpInst::ICMP_ULE: {
10776     // If operand >=s 0 then ZExt == SExt.  If operand <s 0 then ZExt <u SExt.
10777     const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
10778     const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
10779     if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
10780       return true;
10781     break;
10782   }
10783   default:
10784     break;
10785   };
10786   return false;
10787 }
10788 
10789 bool
10790 ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
10791                                            const SCEV *LHS, const SCEV *RHS) {
10792   return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
10793          isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
10794          IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
10795          IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
10796          isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
10797 }
10798 
10799 bool
10800 ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
10801                                              const SCEV *LHS, const SCEV *RHS,
10802                                              const SCEV *FoundLHS,
10803                                              const SCEV *FoundRHS) {
10804   switch (Pred) {
10805   default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
10806   case ICmpInst::ICMP_EQ:
10807   case ICmpInst::ICMP_NE:
10808     if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
10809       return true;
10810     break;
10811   case ICmpInst::ICMP_SLT:
10812   case ICmpInst::ICMP_SLE:
10813     if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
10814         isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
10815       return true;
10816     break;
10817   case ICmpInst::ICMP_SGT:
10818   case ICmpInst::ICMP_SGE:
10819     if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
10820         isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
10821       return true;
10822     break;
10823   case ICmpInst::ICMP_ULT:
10824   case ICmpInst::ICMP_ULE:
10825     if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
10826         isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
10827       return true;
10828     break;
10829   case ICmpInst::ICMP_UGT:
10830   case ICmpInst::ICMP_UGE:
10831     if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
10832         isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
10833       return true;
10834     break;
10835   }
10836 
10837   // Maybe it can be proved via operations?
10838   if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
10839     return true;
10840 
10841   return false;
10842 }
10843 
10844 bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
10845                                                      const SCEV *LHS,
10846                                                      const SCEV *RHS,
10847                                                      const SCEV *FoundLHS,
10848                                                      const SCEV *FoundRHS) {
10849   if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
10850     // The restriction on `FoundRHS` can be lifted easily -- it exists only
10851     // to reduce the compile time impact of this optimization.
10852     return false;
10853 
10854   Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
10855   if (!Addend)
10856     return false;
10857 
10858   const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
10859 
10860   // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
10861   // antecedent "`FoundLHS` `Pred` `FoundRHS`".
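  // Worked example (illustrative): from FoundLHS u< 8 we get FoundLHSRange =
  // [0, 8); if LHS = FoundLHS + 2 then LHSRange = [2, 10), which lies inside
  // the region satisfying LHS u< 10, so that consequent is implied.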
10862   ConstantRange FoundLHSRange =
10863       ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS);
10864 
10865   // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
10866   ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
10867 
10868   // We can also compute the range of values for `LHS` that satisfy the
10869   // consequent, "`LHS` `Pred` `RHS`":
10870   const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
10871   ConstantRange SatisfyingLHSRange =
10872       ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS);
10873 
10874   // The antecedent implies the consequent if every value of `LHS` that
10875   // satisfies the antecedent also satisfies the consequent.
10876   return SatisfyingLHSRange.contains(LHSRange);
10877 }
10878 
10879 bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
10880                                          bool IsSigned, bool NoWrap) {
10881   assert(isKnownPositive(Stride) && "Positive stride expected!");
10882 
10883   if (NoWrap) return false;
10884 
10885   unsigned BitWidth = getTypeSizeInBits(RHS->getType());
10886   const SCEV *One = getOne(Stride->getType());
10887 
10888   if (IsSigned) {
10889     APInt MaxRHS = getSignedRangeMax(RHS);
10890     APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
10891     APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
10892 
10893     // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
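    // E.g. (illustrative) for i8: MaxRHS = 120 with Stride = 16 gives
    // 120 + 15 > 127, so the increment may wrap and we report possible
    // overflow.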
10894     return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
10895   }
10896 
10897   APInt MaxRHS = getUnsignedRangeMax(RHS);
10898   APInt MaxValue = APInt::getMaxValue(BitWidth);
10899   APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
10900 
10901   // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
10902   return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
10903 }
10904 
10905 bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
10906                                          bool IsSigned, bool NoWrap) {
10907   if (NoWrap) return false;
10908 
10909   unsigned BitWidth = getTypeSizeInBits(RHS->getType());
10910   const SCEV *One = getOne(Stride->getType());
10911 
10912   if (IsSigned) {
10913     APInt MinRHS = getSignedRangeMin(RHS);
10914     APInt MinValue = APInt::getSignedMinValue(BitWidth);
10915     APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
10916 
10917     // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
10918     return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
10919   }
10920 
10921   APInt MinRHS = getUnsignedRangeMin(RHS);
10922   APInt MinValue = APInt::getMinValue(BitWidth);
10923   APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
10924 
10925   // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
10926   return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
10927 }
10928 
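// Note (illustrative): with Equality == false this computes a rounded-up
// division, e.g. Delta = 7, Step = 4 yields (7 + 3) /u 4 = 2; with
// Equality == true it computes (Delta + Step) /u Step instead.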
10929 const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
10930                                             bool Equality) {
10931   const SCEV *One = getOne(Step->getType());
10932   Delta = Equality ? getAddExpr(Delta, Step)
10933                    : getAddExpr(Delta, getMinusSCEV(Step, One));
10934   return getUDivExpr(Delta, Step);
10935 }
10936 
10937 const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
10938                                                     const SCEV *Stride,
10939                                                     const SCEV *End,
10940                                                     unsigned BitWidth,
10941                                                     bool IsSigned) {
10942 
10943   assert(!isKnownNonPositive(Stride) &&
10944          "Stride is expected strictly positive!");
10945   // Calculate the maximum backedge count based on the range of values
10946   // permitted by Start, End, and Stride.
10947   const SCEV *MaxBECount;
10948   APInt MinStart =
10949       IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
10950 
10951   APInt StrideForMaxBECount =
10952       IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
10953 
10954   // We already know that the stride is positive, so we paper over conservatism
10955   // in our range computation by forcing StrideForMaxBECount to be at least one.
10956   // In theory this is unnecessary, but we expect MaxBECount to be a
10957   // SCEVConstant, and (udiv <constant> 0) is not constant folded by SCEV (there
10958   // is nothing to constant fold it to).
10959   APInt One(BitWidth, 1, IsSigned);
10960   StrideForMaxBECount = APIntOps::smax(One, StrideForMaxBECount);
10961 
10962   APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
10963                             : APInt::getMaxValue(BitWidth);
10964   APInt Limit = MaxValue - (StrideForMaxBECount - 1);
10965 
10966   // Although End can be a MAX expression we estimate MaxEnd considering only
10967   // the case End = RHS of the loop termination condition. This is safe because
10968   // in the other case (End - Start) is zero, leading to a zero maximum backedge
10969   // taken count.
10970   APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
10971                           : APIntOps::umin(getUnsignedRangeMax(End), Limit);
10972 
10973   MaxBECount = computeBECount(getConstant(MaxEnd - MinStart) /* Delta */,
10974                               getConstant(StrideForMaxBECount) /* Step */,
10975                               false /* Equality */);
10976 
10977   return MaxBECount;
10978 }
10979 
10980 ScalarEvolution::ExitLimit
10981 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
10982                                   const Loop *L, bool IsSigned,
10983                                   bool ControlsExit, bool AllowPredicates) {
10984   SmallPtrSet<const SCEVPredicate *, 4> Predicates;
10985 
10986   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
10987   bool PredicatedIV = false;
10988 
10989   if (!IV && AllowPredicates) {
10990     // Try to make this an AddRec using runtime tests, in the first X
10991     // iterations of this loop, where X is the SCEV expression found by the
10992     // algorithm below.
10993     IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
10994     PredicatedIV = true;
10995   }
10996 
10997   // Avoid weird loops
10998   if (!IV || IV->getLoop() != L || !IV->isAffine())
10999     return getCouldNotCompute();
11000 
11001   bool NoWrap = ControlsExit &&
11002                 IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
11003 
11004   const SCEV *Stride = IV->getStepRecurrence(*this);
11005 
11006   bool PositiveStride = isKnownPositive(Stride);
11007 
11008   // Avoid negative or zero stride values.
11009   if (!PositiveStride) {
11010     // We can compute the correct backedge taken count for loops with unknown
11011     // strides if we can prove that the loop is not an infinite loop with side
11012     // effects. Here's the loop structure we are trying to handle -
11013     //
11014     // i = start
11015     // do {
11016     //   A[i] = i;
11017     //   i += s;
11018     // } while (i < end);
11019     //
11020     // The backedge taken count for such loops is evaluated as -
11021     // (max(end, start + stride) - start - 1) /u stride
11022     //
11023     // The additional preconditions that we need to check to prove correctness
11024     // of the above formula are as follows -
11025     //
11026     // a) IV is either nuw or nsw depending upon signedness (indicated by the
11027     //    NoWrap flag).
11028     // b) loop is single exit with no side effects.
11029     //
11030     //
11031     // Precondition a) implies that if the stride is negative, this is a single
11032     // trip loop. The backedge taken count formula reduces to zero in this case.
11033     //
11034     // Precondition b) implies that the unknown stride cannot be zero otherwise
11035     // we have UB.
11036     //
11037     // The positive stride case is the same as isKnownPositive(Stride) returning
11038     // true (original behavior of the function).
11039     //
11040     // We want to make sure that the stride is truly unknown as there are edge
11041     // cases where ScalarEvolution propagates no wrap flags to the
11042     // post-increment/decrement IV even though the increment/decrement operation
11043     // itself is wrapping. The computed backedge taken count may be wrong in
11044     // such cases. This is prevented by checking that the stride is not known to
11045     // be either positive or non-positive. For example, no wrap flags are
11046     // propagated to the post-increment IV of this loop with a trip count of 2 -
11047     //
11048     // unsigned char i;
11049     // for(i=127; i<128; i+=129)
11050     //   A[i] = i;
11051     //
11052     if (PredicatedIV || !NoWrap || isKnownNonPositive(Stride) ||
11053         !loopHasNoSideEffects(L))
11054       return getCouldNotCompute();
11055   } else if (!Stride->isOne() &&
11056              doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
11057     // Avoid proven overflow cases: this will ensure that the backedge taken
11058     // count will not generate any unsigned overflow. Relaxed no-overflow
11059     // conditions exploit NoWrapFlags, allowing us to optimize in the presence
11060     // of undefined behavior, as in the case of C.
11061     return getCouldNotCompute();
11062 
11063   ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT
11064                                       : ICmpInst::ICMP_ULT;
11065   const SCEV *Start = IV->getStart();
11066   const SCEV *End = RHS;
11067   // When the RHS is not invariant, we do not know the end bound of the loop and
11068   // cannot calculate the ExactBECount needed by ExitLimit. However, we can
11069   // calculate the MaxBECount, given the start, stride and max value for the end
11070   // bound of the loop (RHS), and the fact that IV does not overflow (which is
11071   // checked above).
11072   if (!isLoopInvariant(RHS, L)) {
11073     const SCEV *MaxBECount = computeMaxBECountForLT(
11074         Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
11075     return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
11076                      false /*MaxOrZero*/, Predicates);
11077   }
11078   // If the backedge is taken at least once, then it will be taken
11079   // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start
11080   // is the LHS value of the less-than comparison the first time it is evaluated
11081   // and End is the RHS.
11082   const SCEV *BECountIfBackedgeTaken =
11083     computeBECount(getMinusSCEV(End, Start), Stride, false);
11084   // If the loop entry is guarded by the result of the backedge test of the
11085   // first loop iteration, then we know the backedge will be taken at least
11086   // once and so the backedge taken count is as above. If not then we use the
11087   // expression (max(End,Start)-Start)/Stride to describe the backedge count,
11088   // as if the backedge is taken at least once max(End,Start) is End and so the
11089   // result is as above, and if not max(End,Start) is Start so we get a backedge
11090   // count of zero.
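  // For instance (illustrative), with Start = 0, End = 7 and Stride = 4 the
  // backedge, if taken at least once, is taken ceil(7 / 4) = 2 times.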
11091   const SCEV *BECount;
11092   if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS))
11093     BECount = BECountIfBackedgeTaken;
11094   else {
11095     // If we know that RHS >= Start in the context of the loop, then we know
11096     // that max(RHS, Start) = RHS at this point.
11097     if (isLoopEntryGuardedByCond(
11098             L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, RHS, Start))
11099       End = RHS;
11100     else
11101       End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
11102     BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
11103   }
11104 
11105   const SCEV *MaxBECount;
11106   bool MaxOrZero = false;
11107   if (isa<SCEVConstant>(BECount))
11108     MaxBECount = BECount;
11109   else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
11110     // If we know exactly how many times the backedge will be taken if it's
11111     // taken at least once, then the backedge count will either be that or
11112     // zero.
11113     MaxBECount = BECountIfBackedgeTaken;
11114     MaxOrZero = true;
11115   } else {
11116     MaxBECount = computeMaxBECountForLT(
11117         Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
11118   }
11119 
11120   if (isa<SCEVCouldNotCompute>(MaxBECount) &&
11121       !isa<SCEVCouldNotCompute>(BECount))
11122     MaxBECount = getConstant(getUnsignedRangeMax(BECount));
11123 
11124   return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
11125 }
11126 
11127 ScalarEvolution::ExitLimit
11128 ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
11129                                      const Loop *L, bool IsSigned,
11130                                      bool ControlsExit, bool AllowPredicates) {
11131   SmallPtrSet<const SCEVPredicate *, 4> Predicates;
11132   // We handle only IV > Invariant
11133   if (!isLoopInvariant(RHS, L))
11134     return getCouldNotCompute();
11135 
11136   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
11137   if (!IV && AllowPredicates)
11138     // Try to make this an AddRec using runtime tests, in the first X
11139     // iterations of this loop, where X is the SCEV expression found by the
11140     // algorithm below.
11141     IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
11142 
11143   // Avoid weird loops
11144   if (!IV || IV->getLoop() != L || !IV->isAffine())
11145     return getCouldNotCompute();
11146 
11147   bool NoWrap = ControlsExit &&
11148                 IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
11149 
11150   const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
11151 
11152   // Avoid negative or zero stride values
11153   if (!isKnownPositive(Stride))
11154     return getCouldNotCompute();
11155 
11156   // Avoid proven overflow cases: this will ensure that the backedge taken count
11157   // will not generate any unsigned overflow. Relaxed no-overflow conditions
11158   // exploit NoWrapFlags, allowing us to optimize in the presence of undefined
11159   // behavior, as in the case of C.
11160   if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap))
11161     return getCouldNotCompute();
11162 
11163   ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT
11164                                       : ICmpInst::ICMP_UGT;
11165 
11166   const SCEV *Start = IV->getStart();
11167   const SCEV *End = RHS;
11168   if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
11169     // If we know that Start >= RHS in the context of the loop, then we know
11170     // that min(RHS, Start) = RHS at this point.
11171     if (isLoopEntryGuardedByCond(
11172             L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
11173       End = RHS;
11174     else
11175       End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
11176   }
11177 
11178   const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false);
11179 
11180   APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
11181                             : getUnsignedRangeMax(Start);
11182 
11183   APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
11184                              : getUnsignedRangeMin(Stride);
11185 
11186   unsigned BitWidth = getTypeSizeInBits(LHS->getType());
11187   APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
11188                          : APInt::getMinValue(BitWidth) + (MinStride - 1);
11189 
11190   // Although End can be a MIN expression we estimate MinEnd considering only
11191   // the case End = RHS. This is safe because in the other case (Start - End)
11192   // is zero, leading to a zero maximum backedge taken count.
11193   APInt MinEnd =
11194     IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
11195              : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
11196 
11197   const SCEV *MaxBECount = isa<SCEVConstant>(BECount)
11198                                ? BECount
11199                                : computeBECount(getConstant(MaxStart - MinEnd),
11200                                                 getConstant(MinStride), false);
11201 
11202   if (isa<SCEVCouldNotCompute>(MaxBECount))
11203     MaxBECount = BECount;
11204 
11205   return ExitLimit(BECount, MaxBECount, false, Predicates);
11206 }
11207 
11208 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
11209                                                     ScalarEvolution &SE) const {
11210   if (Range.isFullSet())  // Infinite loop.
11211     return SE.getCouldNotCompute();
11212 
11213   // If the start is a non-zero constant, shift the range to simplify things.
11214   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
11215     if (!SC->getValue()->isZero()) {
11216       SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
11217       Operands[0] = SE.getZero(SC->getType());
11218       const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
11219                                              getNoWrapFlags(FlagNW));
11220       if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
11221         return ShiftedAddRec->getNumIterationsInRange(
11222             Range.subtract(SC->getAPInt()), SE);
11223       // This is strange and shouldn't happen.
11224       return SE.getCouldNotCompute();
11225     }
11226 
11227   // The only time we can solve this is when we have all constant indices.
11228   // Otherwise, we cannot determine the overflow conditions.
11229   if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
11230     return SE.getCouldNotCompute();
11231 
11232   // Okay at this point we know that all elements of the chrec are constants and
11233   // that the start element is zero.
11234 
11235   // First check to see if the range contains zero.  If not, the first
11236   // iteration exits.
11237   unsigned BitWidth = SE.getTypeSizeInBits(getType());
11238   if (!Range.contains(APInt(BitWidth, 0)))
11239     return SE.getZero(getType());
11240 
11241   if (isAffine()) {
11242     // If this is an affine expression then we have this situation:
11243     //   Solve {0,+,A} in Range  ===  Ax in Range
11244 
11245     // We know that zero is in the range.  If A is positive then we know that
11246     // the upper value of the range must be the first possible exit value.
11247     // If A is negative then the lower of the range is the last possible loop
11248     // value.  Also note that we already checked for a full range.
11249     APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
11250     APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
11251 
11252     // The exit value should be (End+A)/A.
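    // For example (illustrative), solving {0,+,4} in [0, 10): A = 4, End = 9,
    // so ExitVal = (9 + 4) /u 4 = 3, and indeed 3 * 4 = 12 is the first value
    // outside the range.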
11253     APInt ExitVal = (End + A).udiv(A);
11254     ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
11255 
11256     // Evaluate at the exit value.  If we really did fall out of the valid
11257     // range, then we computed our trip count, otherwise wrap around or other
11258     // things must have happened.
11259     ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
11260     if (Range.contains(Val->getValue()))
11261       return SE.getCouldNotCompute();  // Something strange happened
11262 
11263     // Ensure that the previous value is in the range.  This is a sanity check.
11264     assert(Range.contains(
11265            EvaluateConstantChrecAtConstant(this,
11266            ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
11267            "Linear scev computation is off in a bad way!");
11268     return SE.getConstant(ExitValue);
11269   }
11270 
11271   if (isQuadratic()) {
11272     if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
11273       return SE.getConstant(S.getValue());
11274   }
11275 
11276   return SE.getCouldNotCompute();
11277 }
11278 
11279 const SCEVAddRecExpr *
11280 SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
11281   assert(getNumOperands() > 1 && "AddRec with zero step?");
11282   // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
11283   // but in this case we cannot guarantee that the value returned will be an
11284   // AddRec because SCEV does not have a fixed point where it stops
11285   // simplification: it is legal to return ({rec1} + {rec2}). For example, it
11286   // may happen if we reach arithmetic depth limit while simplifying. So we
11287   // construct the returned value explicitly.
11288   SmallVector<const SCEV *, 3> Ops;
11289   // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
11290   // (this + Step) is {A+B,+,B+C,+...,+,N}.
11291   for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
11292     Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
11293   // We know that the last operand is not a constant zero (otherwise it would
11294   // have been popped out earlier). This guarantees us that if the result has
11295   // the same last operand, then it will also not be popped out, meaning that
11296   // the returned value will be an AddRec.
11297   const SCEV *Last = getOperand(getNumOperands() - 1);
11298   assert(!Last->isZero() && "Recurrence with zero step?");
11299   Ops.push_back(Last);
11300   return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
11301                                                SCEV::FlagAnyWrap));
11302 }
11303 
11304 // Return true when S contains at least one undef value.
11305 static inline bool containsUndefs(const SCEV *S) {
11306   return SCEVExprContains(S, [](const SCEV *S) {
11307     if (const auto *SU = dyn_cast<SCEVUnknown>(S))
11308       return isa<UndefValue>(SU->getValue());
11309     return false;
11310   });
11311 }
11312 
11313 namespace {
11314 
11315 // Collect all steps of SCEV expressions.
11316 struct SCEVCollectStrides {
11317   ScalarEvolution &SE;
11318   SmallVectorImpl<const SCEV *> &Strides;
11319 
11320   SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &S)
11321       : SE(SE), Strides(S) {}
11322 
11323   bool follow(const SCEV *S) {
11324     if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
11325       Strides.push_back(AR->getStepRecurrence(SE));
11326     return true;
11327   }
11328 
11329   bool isDone() const { return false; }
11330 };
11331 
11332 // Collect all SCEVUnknown and SCEVMulExpr expressions.
11333 struct SCEVCollectTerms {
11334   SmallVectorImpl<const SCEV *> &Terms;
11335 
11336   SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T) : Terms(T) {}
11337 
11338   bool follow(const SCEV *S) {
11339     if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) ||
11340         isa<SCEVSignExtendExpr>(S)) {
11341       if (!containsUndefs(S))
11342         Terms.push_back(S);
11343 
11344       // Stop recursion: once we collected a term, do not walk its operands.
11345       return false;
11346     }
11347 
11348     // Keep looking.
11349     return true;
11350   }
11351 
11352   bool isDone() const { return false; }
11353 };
11354 
11355 // Check if a SCEV contains an AddRecExpr.
11356 struct SCEVHasAddRec {
11357   bool &ContainsAddRec;
11358 
11359   SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) {
11360     ContainsAddRec = false;
11361   }
11362 
11363   bool follow(const SCEV *S) {
11364     if (isa<SCEVAddRecExpr>(S)) {
11365       ContainsAddRec = true;
11366 
11367       // Stop recursion: once we found an AddRec, do not walk its operands.
11368       return false;
11369     }
11370 
11371     // Keep looking.
11372     return true;
11373   }
11374 
11375   bool isDone() const { return false; }
11376 };
11377 
11378 // Find factors that are multiplied with an expression that (possibly as a
11379 // subexpression) contains an AddRecExpr. In the expression:
11380 //
11381 //  8 * (100 +  %p * %q * (%a + {0, +, 1}_loop))
11382 //
11383 // "%p * %q" are factors multiplied by the expression "(%a + {0, +, 1}_loop)"
11384 // that contains the AddRec {0, +, 1}_loop. %p * %q are likely to be array size
11385 // parameters as they form a product with an induction variable.
11386 //
11387 // This collector expects all array size parameters to be in the same MulExpr.
11388 // It might be necessary to later add support for collecting parameters that are
11389 // spread over different nested MulExpr.
11390 struct SCEVCollectAddRecMultiplies {
11391   SmallVectorImpl<const SCEV *> &Terms;
11392   ScalarEvolution &SE;
11393 
11394   SCEVCollectAddRecMultiplies(SmallVectorImpl<const SCEV *> &T, ScalarEvolution &SE)
11395       : Terms(T), SE(SE) {}
11396 
11397   bool follow(const SCEV *S) {
11398     if (auto *Mul = dyn_cast<SCEVMulExpr>(S)) {
11399       bool HasAddRec = false;
11400       SmallVector<const SCEV *, 0> Operands;
11401       for (auto Op : Mul->operands()) {
11402         const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(Op);
11403         if (Unknown && !isa<CallInst>(Unknown->getValue())) {
11404           Operands.push_back(Op);
11405         } else if (Unknown) {
11406           HasAddRec = true;
11407         } else {
11408           bool ContainsAddRec = false;
11409           SCEVHasAddRec ContainsAddRecVisitor(ContainsAddRec);
11410           visitAll(Op, ContainsAddRecVisitor);
11411           HasAddRec |= ContainsAddRec;
11412         }
11413       }
11414       if (Operands.empty())
11415         return true;
11416 
11417       if (!HasAddRec)
11418         return false;
11419 
11420       Terms.push_back(SE.getMulExpr(Operands));
11421       // Stop recursion: once we collected a term, do not walk its operands.
11422       return false;
11423     }
11424 
11425     // Keep looking.
11426     return true;
11427   }
11428 
11429   bool isDone() const { return false; }
11430 };
11431 
11432 } // end anonymous namespace
11433 
11434 /// Find parametric terms in this SCEVAddRecExpr. We look for parameters in
11435 /// two places:
11436 ///   1) The strides of AddRec expressions.
11437 ///   2) Unknowns that are multiplied with AddRec expressions.
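/// For illustration (hypothetical SCEVs): given the expression
/// {{%A,+,(8 * %m)}<%loop.i>,+,8}<%loop.j>, the collected strides are
/// (8 * %m) and 8, and the resulting parametric term is (8 * %m).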
11438 void ScalarEvolution::collectParametricTerms(const SCEV *Expr,
11439     SmallVectorImpl<const SCEV *> &Terms) {
11440   SmallVector<const SCEV *, 4> Strides;
11441   SCEVCollectStrides StrideCollector(*this, Strides);
11442   visitAll(Expr, StrideCollector);
11443 
11444   LLVM_DEBUG({
11445     dbgs() << "Strides:\n";
11446     for (const SCEV *S : Strides)
11447       dbgs() << *S << "\n";
11448   });
11449 
11450   for (const SCEV *S : Strides) {
11451     SCEVCollectTerms TermCollector(Terms);
11452     visitAll(S, TermCollector);
11453   }
11454 
11455   LLVM_DEBUG({
11456     dbgs() << "Terms:\n";
11457     for (const SCEV *T : Terms)
11458       dbgs() << *T << "\n";
11459   });
11460 
11461   SCEVCollectAddRecMultiplies MulCollector(Terms, *this);
11462   visitAll(Expr, MulCollector);
11463 }
11464 
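// For illustration (hypothetical terms): given Terms = {%m * %o, %o}, the last
// term %o is used as the divisor; the division leaves Terms = {%m}, and the
// recursion then pushes %m and %o into Sizes (innermost size last).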
11465 static bool findArrayDimensionsRec(ScalarEvolution &SE,
11466                                    SmallVectorImpl<const SCEV *> &Terms,
11467                                    SmallVectorImpl<const SCEV *> &Sizes) {
11468   int Last = Terms.size() - 1;
11469   const SCEV *Step = Terms[Last];
11470 
11471   // End of recursion.
11472   if (Last == 0) {
11473     if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Step)) {
11474       SmallVector<const SCEV *, 2> Qs;
11475       for (const SCEV *Op : M->operands())
11476         if (!isa<SCEVConstant>(Op))
11477           Qs.push_back(Op);
11478 
11479       Step = SE.getMulExpr(Qs);
11480     }
11481 
11482     Sizes.push_back(Step);
11483     return true;
11484   }
11485 
11486   for (const SCEV *&Term : Terms) {
11487     // Normalize the terms before the next call to findArrayDimensionsRec.
11488     const SCEV *Q, *R;
11489     SCEVDivision::divide(SE, Term, Step, &Q, &R);
11490 
11491     // Bail out when GCD does not evenly divide one of the terms.
11492     if (!R->isZero())
11493       return false;
11494 
11495     Term = Q;
11496   }
11497 
11498   // Remove all SCEVConstants.
11499   Terms.erase(
11500       remove_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); }),
11501       Terms.end());
11502 
11503   if (!Terms.empty())
11504     if (!findArrayDimensionsRec(SE, Terms, Sizes))
11505       return false;
11506 
11507   Sizes.push_back(Step);
11508   return true;
11509 }
11510 
11511 // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
11512 static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
11513   for (const SCEV *T : Terms)
11514     if (SCEVExprContains(T, [](const SCEV *S) { return isa<SCEVUnknown>(S); }))
11515       return true;
11516 
11517   return false;
11518 }
11519 
11520 // Return the number of product terms in S.
11521 static inline int numberOfTerms(const SCEV *S) {
11522   if (const SCEVMulExpr *Expr = dyn_cast<SCEVMulExpr>(S))
11523     return Expr->getNumOperands();
11524   return 1;
11525 }
11526 
11527 static const SCEV *removeConstantFactors(ScalarEvolution &SE, const SCEV *T) {
11528   if (isa<SCEVConstant>(T))
11529     return nullptr;
11530 
11531   if (isa<SCEVUnknown>(T))
11532     return T;
11533 
11534   if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(T)) {
11535     SmallVector<const SCEV *, 2> Factors;
11536     for (const SCEV *Op : M->operands())
11537       if (!isa<SCEVConstant>(Op))
11538         Factors.push_back(Op);
11539 
11540     return SE.getMulExpr(Factors);
11541   }
11542 
11543   return T;
11544 }
11545 
11546 /// Return the size of an element read or written by Inst.
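/// For example (illustrative): for a `store double %v, double* %p`, this
/// returns a SCEV constant of sizeof(double), i.e. 8, in the effective
/// pointer-width integer type.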
11547 const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
11548   Type *Ty;
11549   if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
11550     Ty = Store->getValueOperand()->getType();
11551   else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
11552     Ty = Load->getType();
11553   else
11554     return nullptr;
11555 
11556   Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
11557   return getSizeOfExpr(ETy, Ty);
11558 }
11559 
11560 void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
11561                                           SmallVectorImpl<const SCEV *> &Sizes,
11562                                           const SCEV *ElementSize) {
11563   if (Terms.empty() || !ElementSize)
11564     return;
11565 
11566   // Early return when Terms do not contain parameters: we do not delinearize
11567   // non-parametric SCEVs.
11568   if (!containsParameters(Terms))
11569     return;
11570 
11571   LLVM_DEBUG({
11572     dbgs() << "Terms:\n";
11573     for (const SCEV *T : Terms)
11574       dbgs() << *T << "\n";
11575   });
11576 
11577   // Remove duplicates.
11578   array_pod_sort(Terms.begin(), Terms.end());
11579   Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end());
11580 
11581   // Put larger terms first.
11582   llvm::sort(Terms, [](const SCEV *LHS, const SCEV *RHS) {
11583     return numberOfTerms(LHS) > numberOfTerms(RHS);
11584   });
11585 
11586   // Try to divide all terms by the element size. If term is not divisible by
11587   // element size, proceed with the original term.
11588   for (const SCEV *&Term : Terms) {
11589     const SCEV *Q, *R;
11590     SCEVDivision::divide(*this, Term, ElementSize, &Q, &R);
11591     if (!Q->isZero())
11592       Term = Q;
11593   }
11594 
11595   SmallVector<const SCEV *, 4> NewTerms;
11596 
11597   // Remove constant factors.
11598   for (const SCEV *T : Terms)
11599     if (const SCEV *NewT = removeConstantFactors(*this, T))
11600       NewTerms.push_back(NewT);
11601 
11602   LLVM_DEBUG({
11603     dbgs() << "Terms after sorting:\n";
11604     for (const SCEV *T : NewTerms)
11605       dbgs() << *T << "\n";
11606   });
11607 
11608   if (NewTerms.empty() || !findArrayDimensionsRec(*this, NewTerms, Sizes)) {
11609     Sizes.clear();
11610     return;
11611   }
11612 
11613   // The last element to be pushed into Sizes is the size of an element.
11614   Sizes.push_back(ElementSize);
11615 
11616   LLVM_DEBUG({
11617     dbgs() << "Sizes:\n";
11618     for (const SCEV *S : Sizes)
11619       dbgs() << *S << "\n";
11620   });
11621 }
11622 
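// For illustration: for the A[i][j][k] example documented before delinearize
// below, with Sizes = {%m, %o, 8}, the successive divisions (innermost size
// first) yield the subscripts {0,+,1}<%for.i>, {0,+,1}<%for.j> and
// {0,+,1}<%for.k>.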
11623 void ScalarEvolution::computeAccessFunctions(
11624     const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts,
11625     SmallVectorImpl<const SCEV *> &Sizes) {
11626   // Early exit in case this SCEV is not an affine multivariate function.
11627   if (Sizes.empty())
11628     return;
11629 
11630   if (auto *AR = dyn_cast<SCEVAddRecExpr>(Expr))
11631     if (!AR->isAffine())
11632       return;
11633 
11634   const SCEV *Res = Expr;
11635   int Last = Sizes.size() - 1;
11636   for (int i = Last; i >= 0; i--) {
11637     const SCEV *Q, *R;
11638     SCEVDivision::divide(*this, Res, Sizes[i], &Q, &R);
11639 
11640     LLVM_DEBUG({
11641       dbgs() << "Res: " << *Res << "\n";
11642       dbgs() << "Sizes[i]: " << *Sizes[i] << "\n";
11643       dbgs() << "Res divided by Sizes[i]:\n";
11644       dbgs() << "Quotient: " << *Q << "\n";
11645       dbgs() << "Remainder: " << *R << "\n";
11646     });
11647 
11648     Res = Q;
11649 
11650     // Do not record the last subscript corresponding to the size of elements in
11651     // the array.
11652     if (i == Last) {
11653 
11654       // Bail out if the remainder is too complex.
11655       if (isa<SCEVAddRecExpr>(R)) {
11656         Subscripts.clear();
11657         Sizes.clear();
11658         return;
11659       }
11660 
11661       continue;
11662     }
11663 
11664     // Record the access function for the current subscript.
11665     Subscripts.push_back(R);
11666   }
11667 
11668   // Also push in last position the remainder of the last division: it will be
11669   // the access function of the innermost dimension.
11670   Subscripts.push_back(Res);
11671 
11672   std::reverse(Subscripts.begin(), Subscripts.end());
11673 
11674   LLVM_DEBUG({
11675     dbgs() << "Subscripts:\n";
11676     for (const SCEV *S : Subscripts)
11677       dbgs() << *S << "\n";
11678   });
11679 }
11680 
11681 /// Splits the SCEV into two vectors of SCEVs representing the subscripts and
11682 /// sizes of an array access; the remainder of the delinearization is the
11683 /// offset start of the array.  The SCEV->delinearize algorithm computes the
11684 /// multiples of SCEV coefficients: that is a pattern matching of
11685 /// subexpressions in the stride and base of a SCEV corresponding to the
11686 /// computation of a GCD (greatest common divisor) of base and stride.  When
11687 /// SCEV->delinearize fails, it returns the SCEV unchanged.
11688 ///
11689 /// For example: when analyzing the memory access A[i][j][k] in this loop nest
11690 ///
11691 ///  void foo(long n, long m, long o, double A[n][m][o]) {
11692 ///
11693 ///    for (long i = 0; i < n; i++)
11694 ///      for (long j = 0; j < m; j++)
11695 ///        for (long k = 0; k < o; k++)
11696 ///          A[i][j][k] = 1.0;
11697 ///  }
11698 ///
11699 /// the delinearization input is the following AddRec SCEV:
11700 ///
11701 ///  AddRec: {{{%A,+,(8 * %m * %o)}<%for.i>,+,(8 * %o)}<%for.j>,+,8}<%for.k>
11702 ///
11703 /// From this SCEV, we are able to say that the base offset of the access is %A
11704 /// because it appears as an offset that does not divide any of the strides in
11705 /// the loops:
11706 ///
11707 ///  CHECK: Base offset: %A
11708 ///
11709 /// and then SCEV->delinearize determines the size of some of the dimensions of
11710 /// the array as these are the multiples by which the strides are happening:
11711 ///
11712 ///  CHECK: ArrayDecl[UnknownSize][%m][%o] with elements of sizeof(double) bytes.
11713 ///
11714 /// Note that the outermost dimension remains of UnknownSize because there are
11715 /// no strides that would help identify the size of that dimension: when
11716 /// the array has been statically allocated, one could compute the size of that
11717 /// dimension by dividing the overall size of the array by the size of the known
11718 /// dimensions: %m * %o * 8.
11719 ///
11720 /// Finally, delinearize provides the access functions for the array reference
11721 /// that corresponds to A[i][j][k] in the above C testcase:
11722 ///
11723 ///  CHECK: ArrayRef[{0,+,1}<%for.i>][{0,+,1}<%for.j>][{0,+,1}<%for.k>]
11724 ///
11725 /// The testcases are checking the output of a function pass:
11726 /// DelinearizationPass that walks through all loads and stores of a function
11727 /// asking for the SCEV of the memory access with respect to all enclosing
11728 /// loops, calling SCEV->delinearize on that and printing the results.
11729 void ScalarEvolution::delinearize(const SCEV *Expr,
11730                                  SmallVectorImpl<const SCEV *> &Subscripts,
11731                                  SmallVectorImpl<const SCEV *> &Sizes,
11732                                  const SCEV *ElementSize) {
11733   // First step: collect parametric terms.
11734   SmallVector<const SCEV *, 4> Terms;
11735   collectParametricTerms(Expr, Terms);
11736 
11737   if (Terms.empty())
11738     return;
11739 
11740   // Second step: find subscript sizes.
11741   findArrayDimensions(Terms, Sizes, ElementSize);
11742 
11743   if (Sizes.empty())
11744     return;
11745 
11746   // Third step: compute the access functions for each subscript.
11747   computeAccessFunctions(Expr, Subscripts, Sizes);
11748 
11749   if (Subscripts.empty())
11750     return;
11751 
11752   LLVM_DEBUG({
11753     dbgs() << "succeeded to delinearize " << *Expr << "\n";
11754     dbgs() << "ArrayDecl[UnknownSize]";
11755     for (const SCEV *S : Sizes)
11756       dbgs() << "[" << *S << "]";
11757 
11758     dbgs() << "\nArrayRef";
11759     for (const SCEV *S : Subscripts)
11760       dbgs() << "[" << *S << "]";
11761     dbgs() << "\n";
11762   });
11763 }
11764 
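// For illustration (hypothetical IR): for the GEP
//   getelementptr [8 x [16 x i32]], [8 x [16 x i32]]* %A, i64 0, i64 %i, i64 %j
// this returns Subscripts = {%i, %j} and Sizes = {16}: the leading zero index
// is dropped, and the extent of the outermost dimension is not recorded since
// only the inner dimension sizes are needed to linearize the access.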
11765 bool ScalarEvolution::getIndexExpressionsFromGEP(
11766     const GetElementPtrInst *GEP, SmallVectorImpl<const SCEV *> &Subscripts,
11767     SmallVectorImpl<int> &Sizes) {
11768   assert(Subscripts.empty() && Sizes.empty() &&
11769          "Expected output lists to be empty on entry to this function.");
11770   assert(GEP && "getIndexExpressionsFromGEP called with a null GEP");
11771   Type *Ty = GEP->getPointerOperandType();
11772   bool DroppedFirstDim = false;
11773   for (unsigned i = 1; i < GEP->getNumOperands(); i++) {
11774     const SCEV *Expr = getSCEV(GEP->getOperand(i));
11775     if (i == 1) {
11776       if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
11777         Ty = PtrTy->getElementType();
11778       } else if (auto *ArrayTy = dyn_cast<ArrayType>(Ty)) {
11779         Ty = ArrayTy->getElementType();
11780       } else {
11781         Subscripts.clear();
11782         Sizes.clear();
11783         return false;
11784       }
11785       if (auto *Const = dyn_cast<SCEVConstant>(Expr))
11786         if (Const->getValue()->isZero()) {
11787           DroppedFirstDim = true;
11788           continue;
11789         }
11790       Subscripts.push_back(Expr);
11791       continue;
11792     }
11793 
11794     auto *ArrayTy = dyn_cast<ArrayType>(Ty);
11795     if (!ArrayTy) {
11796       Subscripts.clear();
11797       Sizes.clear();
11798       return false;
11799     }
11800 
11801     Subscripts.push_back(Expr);
11802     if (!(DroppedFirstDim && i == 2))
11803       Sizes.push_back(ArrayTy->getNumElements());
11804 
11805     Ty = ArrayTy->getElementType();
11806   }
11807   return !Subscripts.empty();
11808 }
11809 
11810 //===----------------------------------------------------------------------===//
11811 //                   SCEVCallbackVH Class Implementation
11812 //===----------------------------------------------------------------------===//
11813 
11814 void ScalarEvolution::SCEVCallbackVH::deleted() {
11815   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
11816   if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
11817     SE->ConstantEvolutionLoopExitValue.erase(PN);
11818   SE->eraseValueFromMap(getValPtr());
11819   // this now dangles!
11820 }
11821 
11822 void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
11823   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
11824 
11825   // Forget all the expressions associated with users of the old value,
11826   // so that future queries will recompute the expressions using the new
11827   // value.
11828   Value *Old = getValPtr();
11829   SmallVector<User *, 16> Worklist(Old->user_begin(), Old->user_end());
11830   SmallPtrSet<User *, 8> Visited;
11831   while (!Worklist.empty()) {
11832     User *U = Worklist.pop_back_val();
11833     // Deleting the Old value will cause this to dangle. Postpone
11834     // that until everything else is done.
11835     if (U == Old)
11836       continue;
11837     if (!Visited.insert(U).second)
11838       continue;
11839     if (PHINode *PN = dyn_cast<PHINode>(U))
11840       SE->ConstantEvolutionLoopExitValue.erase(PN);
11841     SE->eraseValueFromMap(U);
11842     Worklist.insert(Worklist.end(), U->user_begin(), U->user_end());
11843   }
11844   // Delete the Old value.
11845   if (PHINode *PN = dyn_cast<PHINode>(Old))
11846     SE->ConstantEvolutionLoopExitValue.erase(PN);
11847   SE->eraseValueFromMap(Old);
11848   // this now dangles!
11849 }
11850 
11851 ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
11852   : CallbackVH(V), SE(se) {}
11853 
11854 //===----------------------------------------------------------------------===//
11855 //                   ScalarEvolution Class Implementation
11856 //===----------------------------------------------------------------------===//
11857 
11858 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
11859                                  AssumptionCache &AC, DominatorTree &DT,
11860                                  LoopInfo &LI)
11861     : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
11862       CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
11863       LoopDispositions(64), BlockDispositions(64) {
11864   // To use guards for proving predicates, we need to scan every instruction in
11865   // relevant basic blocks, and not just terminators.  Doing this is a waste of
11866   // time if the IR does not actually contain any calls to
11867   // @llvm.experimental.guard, so do a quick check and remember this beforehand.
11868   //
11869   // This pessimizes the case where a pass that preserves ScalarEvolution wants
11870   // to _add_ guards to the module when there weren't any before, and wants
11871   // ScalarEvolution to optimize based on those guards.  For now we prefer to be
11872   // efficient in lieu of being smart in that rather obscure case.
11873 
11874   auto *GuardDecl = F.getParent()->getFunction(
11875       Intrinsic::getName(Intrinsic::experimental_guard));
11876   HasGuards = GuardDecl && !GuardDecl->use_empty();
11877 }
11878 
11879 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
11880     : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
11881       LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
11882       ValueExprMap(std::move(Arg.ValueExprMap)),
11883       PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
11884       PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
11885       PendingMerges(std::move(Arg.PendingMerges)),
11886       MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
11887       BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
11888       PredicatedBackedgeTakenCounts(
11889           std::move(Arg.PredicatedBackedgeTakenCounts)),
11890       ConstantEvolutionLoopExitValue(
11891           std::move(Arg.ConstantEvolutionLoopExitValue)),
11892       ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
11893       LoopDispositions(std::move(Arg.LoopDispositions)),
11894       LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
11895       BlockDispositions(std::move(Arg.BlockDispositions)),
11896       UnsignedRanges(std::move(Arg.UnsignedRanges)),
11897       SignedRanges(std::move(Arg.SignedRanges)),
11898       UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
11899       UniquePreds(std::move(Arg.UniquePreds)),
11900       SCEVAllocator(std::move(Arg.SCEVAllocator)),
11901       LoopUsers(std::move(Arg.LoopUsers)),
11902       PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
11903       FirstUnknown(Arg.FirstUnknown) {
11904   Arg.FirstUnknown = nullptr;
11905 }
11906 
11907 ScalarEvolution::~ScalarEvolution() {
11908   // Iterate through all the SCEVUnknown instances and call their
11909   // destructors, so that they release their references to their values.
11910   for (SCEVUnknown *U = FirstUnknown; U;) {
11911     SCEVUnknown *Tmp = U;
11912     U = U->Next;
11913     Tmp->~SCEVUnknown();
11914   }
11915   FirstUnknown = nullptr;
11916 
11917   ExprValueMap.clear();
11918   ValueExprMap.clear();
11919   HasRecMap.clear();
11920 
11921   // Free any extra memory created for ExitNotTakenInfo in the unlikely event
11922   // that a loop had multiple computable exits.
11923   for (auto &BTCI : BackedgeTakenCounts)
11924     BTCI.second.clear();
11925   for (auto &BTCI : PredicatedBackedgeTakenCounts)
11926     BTCI.second.clear();
11927 
11928   assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
11929   assert(PendingPhiRanges.empty() && "getRangeRef garbage");
11930   assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
11931   assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
11932   assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
11933 }
11934 
11935 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
11936   return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
11937 }
11938 
11939 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
11940                           const Loop *L) {
11941   // Print all inner loops first
11942   for (Loop *I : *L)
11943     PrintLoopInfo(OS, SE, I);
11944 
11945   OS << "Loop ";
11946   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
11947   OS << ": ";
11948 
11949   SmallVector<BasicBlock *, 8> ExitingBlocks;
11950   L->getExitingBlocks(ExitingBlocks);
11951   if (ExitingBlocks.size() != 1)
11952     OS << "<multiple exits> ";
11953 
11954   if (SE->hasLoopInvariantBackedgeTakenCount(L))
11955     OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L) << "\n";
11956   else
11957     OS << "Unpredictable backedge-taken count.\n";
11958 
11959   if (ExitingBlocks.size() > 1)
11960     for (BasicBlock *ExitingBlock : ExitingBlocks) {
11961       OS << "  exit count for " << ExitingBlock->getName() << ": "
11962          << *SE->getExitCount(L, ExitingBlock) << "\n";
11963     }
11964 
11965   OS << "Loop ";
11966   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
11967   OS << ": ";
11968 
11969   if (!isa<SCEVCouldNotCompute>(SE->getConstantMaxBackedgeTakenCount(L))) {
11970     OS << "max backedge-taken count is " << *SE->getConstantMaxBackedgeTakenCount(L);
11971     if (SE->isBackedgeTakenCountMaxOrZero(L))
11972       OS << ", actual taken count either this or zero.";
11973   } else {
11974     OS << "Unpredictable max backedge-taken count. ";
11975   }
11976 
11977   OS << "\n"
11978         "Loop ";
11979   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
11980   OS << ": ";
11981 
11982   SCEVUnionPredicate Pred;
11983   auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred);
11984   if (!isa<SCEVCouldNotCompute>(PBT)) {
11985     OS << "Predicated backedge-taken count is " << *PBT << "\n";
11986     OS << " Predicates:\n";
11987     Pred.print(OS, 4);
11988   } else {
11989     OS << "Unpredictable predicated backedge-taken count. ";
11990   }
11991   OS << "\n";
11992 
11993   if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
11994     OS << "Loop ";
11995     L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
11996     OS << ": ";
11997     OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
11998   }
11999 }
12000 
12001 static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) {
12002   switch (LD) {
12003   case ScalarEvolution::LoopVariant:
12004     return "Variant";
12005   case ScalarEvolution::LoopInvariant:
12006     return "Invariant";
12007   case ScalarEvolution::LoopComputable:
12008     return "Computable";
12009   }
12010   llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!");
12011 }
12012 
12013 void ScalarEvolution::print(raw_ostream &OS) const {
12014   // ScalarEvolution's implementation of the print method is to print
12015   // out SCEV values of all instructions that are interesting. Doing
12016   // this potentially causes it to create new SCEV objects though,
12017   // which technically conflicts with the const qualifier. This isn't
12018   // observable from outside the class though, so casting away the
12019   // const isn't dangerous.
12020   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
12021 
12022   if (ClassifyExpressions) {
12023     OS << "Classifying expressions for: ";
12024     F.printAsOperand(OS, /*PrintType=*/false);
12025     OS << "\n";
12026     for (Instruction &I : instructions(F))
12027       if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
12028         OS << I << '\n';
12029         OS << "  -->  ";
12030         const SCEV *SV = SE.getSCEV(&I);
12031         SV->print(OS);
12032         if (!isa<SCEVCouldNotCompute>(SV)) {
12033           OS << " U: ";
12034           SE.getUnsignedRange(SV).print(OS);
12035           OS << " S: ";
12036           SE.getSignedRange(SV).print(OS);
12037         }
12038 
12039         const Loop *L = LI.getLoopFor(I.getParent());
12040 
12041         const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
12042         if (AtUse != SV) {
12043           OS << "  -->  ";
12044           AtUse->print(OS);
12045           if (!isa<SCEVCouldNotCompute>(AtUse)) {
12046             OS << " U: ";
12047             SE.getUnsignedRange(AtUse).print(OS);
12048             OS << " S: ";
12049             SE.getSignedRange(AtUse).print(OS);
12050           }
12051         }
12052 
12053         if (L) {
12054           OS << "\t\t" "Exits: ";
12055           const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
12056           if (!SE.isLoopInvariant(ExitValue, L)) {
12057             OS << "<<Unknown>>";
12058           } else {
12059             OS << *ExitValue;
12060           }
12061 
12062           bool First = true;
12063           for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
12064             if (First) {
12065               OS << "\t\t" "LoopDispositions: { ";
12066               First = false;
12067             } else {
12068               OS << ", ";
12069             }
12070 
12071             Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
12072             OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter));
12073           }
12074 
12075           for (auto *InnerL : depth_first(L)) {
12076             if (InnerL == L)
12077               continue;
12078             if (First) {
12079               OS << "\t\t" "LoopDispositions: { ";
12080               First = false;
12081             } else {
12082               OS << ", ";
12083             }
12084 
12085             InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
12086             OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL));
12087           }
12088 
12089           OS << " }";
12090         }
12091 
12092         OS << "\n";
12093       }
12094   }
12095 
12096   OS << "Determining loop execution counts for: ";
12097   F.printAsOperand(OS, /*PrintType=*/false);
12098   OS << "\n";
12099   for (Loop *I : LI)
12100     PrintLoopInfo(OS, &SE, I);
12101 }
12102 
12103 ScalarEvolution::LoopDisposition
12104 ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
12105   auto &Values = LoopDispositions[S];
12106   for (auto &V : Values) {
12107     if (V.getPointer() == L)
12108       return V.getInt();
12109   }
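  // Cache a conservative LoopVariant placeholder before computing the real
  // disposition: computeLoopDisposition may itself add entries to
  // LoopDispositions (invalidating Values), so the entry is re-found and
  // updated afterwards.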
12110   Values.emplace_back(L, LoopVariant);
12111   LoopDisposition D = computeLoopDisposition(S, L);
12112   auto &Values2 = LoopDispositions[S];
12113   for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
12114     if (V.getPointer() == L) {
12115       V.setInt(D);
12116       break;
12117     }
12118   }
12119   return D;
12120 }
12121 
12122 ScalarEvolution::LoopDisposition
12123 ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
12124   switch (S->getSCEVType()) {
12125   case scConstant:
12126     return LoopInvariant;
12127   case scPtrToInt:
12128   case scTruncate:
12129   case scZeroExtend:
12130   case scSignExtend:
12131     return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
12132   case scAddRecExpr: {
12133     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
12134 
12135     // If L is the addrec's loop, it's computable.
12136     if (AR->getLoop() == L)
12137       return LoopComputable;
12138 
12139     // Add recurrences are never invariant in the function-body (null loop).
12140     if (!L)
12141       return LoopVariant;
12142 
12143     // Everything that is not defined at loop entry is variant.
12144     if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
12145       return LoopVariant;
12146     assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
12147            " dominate the contained loop's header?");
12148 
12149     // This recurrence is invariant w.r.t. L if AR's loop contains L.
12150     if (AR->getLoop()->contains(L))
12151       return LoopInvariant;
12152 
12153     // This recurrence is variant w.r.t. L if any of its operands
12154     // are variant.
12155     for (auto *Op : AR->operands())
12156       if (!isLoopInvariant(Op, L))
12157         return LoopVariant;
12158 
12159     // Otherwise it's loop-invariant.
12160     return LoopInvariant;
12161   }
12162   case scAddExpr:
12163   case scMulExpr:
12164   case scUMaxExpr:
12165   case scSMaxExpr:
12166   case scUMinExpr:
12167   case scSMinExpr: {
12168     bool HasVarying = false;
12169     for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
12170       LoopDisposition D = getLoopDisposition(Op, L);
12171       if (D == LoopVariant)
12172         return LoopVariant;
12173       if (D == LoopComputable)
12174         HasVarying = true;
12175     }
12176     return HasVarying ? LoopComputable : LoopInvariant;
12177   }
12178   case scUDivExpr: {
12179     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
12180     LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
12181     if (LD == LoopVariant)
12182       return LoopVariant;
12183     LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
12184     if (RD == LoopVariant)
12185       return LoopVariant;
12186     return (LD == LoopInvariant && RD == LoopInvariant) ?
12187            LoopInvariant : LoopComputable;
12188   }
12189   case scUnknown:
12190     // All non-instruction values are loop invariant.  All instructions are loop
12191     // invariant if they are not contained in the specified loop.
12192     // Instructions are never considered invariant in the function body
12193     // (null loop) because they are defined within the "loop".
12194     if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
12195       return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
12196     return LoopInvariant;
12197   case scCouldNotCompute:
12198     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
12199   }
12200   llvm_unreachable("Unknown SCEV kind!");
12201 }
12202 
12203 bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
12204   return getLoopDisposition(S, L) == LoopInvariant;
12205 }
12206 
12207 bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
12208   return getLoopDisposition(S, L) == LoopComputable;
12209 }
12210 
12211 ScalarEvolution::BlockDisposition
12212 ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
12213   auto &Values = BlockDispositions[S];
12214   for (auto &V : Values) {
12215     if (V.getPointer() == BB)
12216       return V.getInt();
12217   }
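  // Same caching scheme as getLoopDisposition above: seed a conservative
  // entry, then re-find and update it after computeBlockDisposition, which may
  // grow BlockDispositions.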
12218   Values.emplace_back(BB, DoesNotDominateBlock);
12219   BlockDisposition D = computeBlockDisposition(S, BB);
12220   auto &Values2 = BlockDispositions[S];
12221   for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
12222     if (V.getPointer() == BB) {
12223       V.setInt(D);
12224       break;
12225     }
12226   }
12227   return D;
12228 }
12229 
12230 ScalarEvolution::BlockDisposition
12231 ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
12232   switch (S->getSCEVType()) {
12233   case scConstant:
12234     return ProperlyDominatesBlock;
12235   case scPtrToInt:
12236   case scTruncate:
12237   case scZeroExtend:
12238   case scSignExtend:
12239     return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
12240   case scAddRecExpr: {
12241     // This uses a "dominates" query instead of "properly dominates" query
12242     // to test for proper dominance too, because the instruction which
12243     // produces the addrec's value is a PHI, and a PHI effectively properly
12244     // dominates its entire containing block.
12245     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
12246     if (!DT.dominates(AR->getLoop()->getHeader(), BB))
12247       return DoesNotDominateBlock;
12248 
12249     // Fall through into SCEVNAryExpr handling.
12250     LLVM_FALLTHROUGH;
12251   }
12252   case scAddExpr:
12253   case scMulExpr:
12254   case scUMaxExpr:
12255   case scSMaxExpr:
12256   case scUMinExpr:
12257   case scSMinExpr: {
12258     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
12259     bool Proper = true;
12260     for (const SCEV *NAryOp : NAry->operands()) {
12261       BlockDisposition D = getBlockDisposition(NAryOp, BB);
12262       if (D == DoesNotDominateBlock)
12263         return DoesNotDominateBlock;
12264       if (D == DominatesBlock)
12265         Proper = false;
12266     }
12267     return Proper ? ProperlyDominatesBlock : DominatesBlock;
12268   }
12269   case scUDivExpr: {
12270     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
12271     const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
12272     BlockDisposition LD = getBlockDisposition(LHS, BB);
12273     if (LD == DoesNotDominateBlock)
12274       return DoesNotDominateBlock;
12275     BlockDisposition RD = getBlockDisposition(RHS, BB);
12276     if (RD == DoesNotDominateBlock)
12277       return DoesNotDominateBlock;
12278     return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
12279       ProperlyDominatesBlock : DominatesBlock;
12280   }
12281   case scUnknown:
12282     if (Instruction *I =
12283           dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
12284       if (I->getParent() == BB)
12285         return DominatesBlock;
12286       if (DT.properlyDominates(I->getParent(), BB))
12287         return ProperlyDominatesBlock;
12288       return DoesNotDominateBlock;
12289     }
12290     return ProperlyDominatesBlock;
12291   case scCouldNotCompute:
12292     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
12293   }
12294   llvm_unreachable("Unknown SCEV kind!");
12295 }
12296 
12297 bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
12298   return getBlockDisposition(S, BB) >= DominatesBlock;
12299 }
12300 
12301 bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
12302   return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
12303 }
12304 
12305 bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
12306   return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
12307 }
12308 
12309 bool ScalarEvolution::ExitLimit::hasOperand(const SCEV *S) const {
12310   auto IsS = [&](const SCEV *X) { return S == X; };
12311   auto ContainsS = [&](const SCEV *X) {
12312     return !isa<SCEVCouldNotCompute>(X) && SCEVExprContains(X, IsS);
12313   };
12314   return ContainsS(ExactNotTaken) || ContainsS(MaxNotTaken);
12315 }
12316 
12317 void
12318 ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
12319   ValuesAtScopes.erase(S);
12320   LoopDispositions.erase(S);
12321   BlockDispositions.erase(S);
12322   UnsignedRanges.erase(S);
12323   SignedRanges.erase(S);
12324   ExprValueMap.erase(S);
12325   HasRecMap.erase(S);
12326   MinTrailingZerosCache.erase(S);
12327 
12328   for (auto I = PredicatedSCEVRewrites.begin();
12329        I != PredicatedSCEVRewrites.end();) {
12330     std::pair<const SCEV *, const Loop *> Entry = I->first;
12331     if (Entry.first == S)
12332       PredicatedSCEVRewrites.erase(I++);
12333     else
12334       ++I;
12335   }
12336 
12337   auto RemoveSCEVFromBackedgeMap =
12338       [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
12339         for (auto I = Map.begin(), E = Map.end(); I != E;) {
12340           BackedgeTakenInfo &BEInfo = I->second;
12341           if (BEInfo.hasOperand(S, this)) {
12342             BEInfo.clear();
12343             Map.erase(I++);
12344           } else
12345             ++I;
12346         }
12347       };
12348 
12349   RemoveSCEVFromBackedgeMap(BackedgeTakenCounts);
12350   RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
12351 }
12352 
12353 void
12354 ScalarEvolution::getUsedLoops(const SCEV *S,
12355                               SmallPtrSetImpl<const Loop *> &LoopsUsed) {
12356   struct FindUsedLoops {
12357     FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
12358         : LoopsUsed(LoopsUsed) {}
12359     SmallPtrSetImpl<const Loop *> &LoopsUsed;
12360     bool follow(const SCEV *S) {
12361       if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
12362         LoopsUsed.insert(AR->getLoop());
12363       return true;
12364     }
12365 
12366     bool isDone() const { return false; }
12367   };
12368 
12369   FindUsedLoops F(LoopsUsed);
12370   SCEVTraversal<FindUsedLoops>(F).visitAll(S);
12371 }
12372 
12373 void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
12374   SmallPtrSet<const Loop *, 8> LoopsUsed;
12375   getUsedLoops(S, LoopsUsed);
12376   for (auto *L : LoopsUsed)
12377     LoopUsers[L].push_back(S);
12378 }
12379 
12380 void ScalarEvolution::verify() const {
12381   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
12382   ScalarEvolution SE2(F, TLI, AC, DT, LI);
12383 
12384   SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
12385 
12386   // Maps SCEV expressions from one ScalarEvolution "universe" to another.
12387   struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
12388     SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
12389 
12390     const SCEV *visitConstant(const SCEVConstant *Constant) {
12391       return SE.getConstant(Constant->getAPInt());
12392     }
12393 
12394     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
12395       return SE.getUnknown(Expr->getValue());
12396     }
12397 
12398     const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
12399       return SE.getCouldNotCompute();
12400     }
12401   };
12402 
12403   SCEVMapper SCM(SE2);
12404 
12405   while (!LoopStack.empty()) {
12406     auto *L = LoopStack.pop_back_val();
12407     LoopStack.insert(LoopStack.end(), L->begin(), L->end());
12408 
12409     auto *CurBECount = SCM.visit(
12410         const_cast<ScalarEvolution *>(this)->getBackedgeTakenCount(L));
12411     auto *NewBECount = SE2.getBackedgeTakenCount(L);
12412 
12413     if (CurBECount == SE2.getCouldNotCompute() ||
12414         NewBECount == SE2.getCouldNotCompute()) {
12415       // NB! This situation is legal, but is very suspicious -- whatever pass
12416       // changed the loop to make a trip count go from could not compute to
12417       // computable or vice-versa *should have* invalidated SCEV.  However, we
12418       // choose not to assert here (for now) since we don't want false
12419       // positives.
12420       continue;
12421     }
12422 
12423     if (containsUndefs(CurBECount) || containsUndefs(NewBECount)) {
12424       // SCEV treats "undef" as an unknown but consistent value (i.e. it does
12425       // not propagate undef aggressively).  This means we can (and do) fail
12426       // verification in cases where a transform makes the trip count of a loop
12427       // go from "undef" to "undef+1" (say).  The transform is fine, since in
12428       // both cases the loop iterates "undef" times, but SCEV thinks we
12429       // increased the trip count of the loop by 1 incorrectly.
12430       continue;
12431     }
12432 
12433     if (SE.getTypeSizeInBits(CurBECount->getType()) >
12434         SE.getTypeSizeInBits(NewBECount->getType()))
12435       NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
12436     else if (SE.getTypeSizeInBits(CurBECount->getType()) <
12437              SE.getTypeSizeInBits(NewBECount->getType()))
12438       CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
12439 
12440     const SCEV *Delta = SE2.getMinusSCEV(CurBECount, NewBECount);
12441 
12442     // Unless VerifySCEVStrict is set, we only compare constant deltas.
12443     if ((VerifySCEVStrict || isa<SCEVConstant>(Delta)) && !Delta->isZero()) {
12444       dbgs() << "Trip Count for " << *L << " Changed!\n";
12445       dbgs() << "Old: " << *CurBECount << "\n";
12446       dbgs() << "New: " << *NewBECount << "\n";
12447       dbgs() << "Delta: " << *Delta << "\n";
12448       std::abort();
12449     }
12450   }
12451 
12452   // Collect all valid loops currently in LoopInfo.
12453   SmallPtrSet<Loop *, 32> ValidLoops;
12454   SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
12455   while (!Worklist.empty()) {
12456     Loop *L = Worklist.pop_back_val();
12457     if (ValidLoops.contains(L))
12458       continue;
12459     ValidLoops.insert(L);
12460     Worklist.append(L->begin(), L->end());
12461   }
12462   // Check for SCEV expressions referencing invalid/deleted loops.
12463   for (auto &KV : ValueExprMap) {
12464     auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second);
12465     if (!AR)
12466       continue;
12467     assert(ValidLoops.contains(AR->getLoop()) &&
12468            "AddRec references invalid loop");
12469   }
12470 }
12471 
12472 bool ScalarEvolution::invalidate(
12473     Function &F, const PreservedAnalyses &PA,
12474     FunctionAnalysisManager::Invalidator &Inv) {
12475   // Invalidate the ScalarEvolution object whenever it isn't preserved or one
12476   // of its dependencies is invalidated.
12477   auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
12478   return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
12479          Inv.invalidate<AssumptionAnalysis>(F, PA) ||
12480          Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
12481          Inv.invalidate<LoopAnalysis>(F, PA);
12482 }
12483 
12484 AnalysisKey ScalarEvolutionAnalysis::Key;
12485 
12486 ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
12487                                              FunctionAnalysisManager &AM) {
12488   return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F),
12489                          AM.getResult<AssumptionAnalysis>(F),
12490                          AM.getResult<DominatorTreeAnalysis>(F),
12491                          AM.getResult<LoopAnalysis>(F));
12492 }
12493 
12494 PreservedAnalyses
12495 ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
12496   AM.getResult<ScalarEvolutionAnalysis>(F).verify();
12497   return PreservedAnalyses::all();
12498 }
12499 
12500 PreservedAnalyses
12501 ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
12502   // For compatibility with opt's -analyze feature under legacy pass manager
12503   // which was not ported to NPM. This keeps tests using
12504   // update_analyze_test_checks.py working.
12505   OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
12506      << F.getName() << "':\n";
12507   AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
12508   return PreservedAnalyses::all();
12509 }
12510 
12511 INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
12512                       "Scalar Evolution Analysis", false, true)
12513 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
12514 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
12515 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
12516 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
12517 INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
12518                     "Scalar Evolution Analysis", false, true)
12519 
12520 char ScalarEvolutionWrapperPass::ID = 0;
12521 
12522 ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
12523   initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
12524 }
12525 
12526 bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
12527   SE.reset(new ScalarEvolution(
12528       F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
12529       getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
12530       getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
12531       getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
12532   return false;
12533 }
12534 
12535 void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
12536 
12537 void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
12538   SE->print(OS);
12539 }
12540 
12541 void ScalarEvolutionWrapperPass::verifyAnalysis() const {
12542   if (!VerifySCEV)
12543     return;
12544 
12545   SE->verify();
12546 }
12547 
12548 void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
12549   AU.setPreservesAll();
12550   AU.addRequiredTransitive<AssumptionCacheTracker>();
12551   AU.addRequiredTransitive<LoopInfoWrapperPass>();
12552   AU.addRequiredTransitive<DominatorTreeWrapperPass>();
12553   AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
12554 }
12555 
12556 const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
12557                                                         const SCEV *RHS) {
12558   FoldingSetNodeID ID;
12559   assert(LHS->getType() == RHS->getType() &&
12560          "Type mismatch between LHS and RHS");
12561   // Unique this node based on the arguments
12562   ID.AddInteger(SCEVPredicate::P_Equal);
12563   ID.AddPointer(LHS);
12564   ID.AddPointer(RHS);
12565   void *IP = nullptr;
12566   if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
12567     return S;
12568   SCEVEqualPredicate *Eq = new (SCEVAllocator)
12569       SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
12570   UniquePreds.InsertNode(Eq, IP);
12571   return Eq;
12572 }
12573 
12574 const SCEVPredicate *ScalarEvolution::getWrapPredicate(
12575     const SCEVAddRecExpr *AR,
12576     SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
12577   FoldingSetNodeID ID;
12578   // Unique this node based on the arguments
12579   ID.AddInteger(SCEVPredicate::P_Wrap);
12580   ID.AddPointer(AR);
12581   ID.AddInteger(AddedFlags);
12582   void *IP = nullptr;
12583   if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
12584     return S;
12585   auto *OF = new (SCEVAllocator)
12586       SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
12587   UniquePreds.InsertNode(OF, IP);
12588   return OF;
12589 }
12590 
12591 namespace {
12592 
12593 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
12594 public:
12595 
12596   /// Rewrites \p S in the context of a loop L and the SCEV predication
12597   /// infrastructure.
12598   ///
12599   /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
12600   /// equivalences present in \p Pred.
12601   ///
12602   /// If \p NewPreds is non-null, rewrite is free to add further predicates to
12603   /// \p NewPreds such that the result will be an AddRecExpr.
12604   static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
12605                              SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
12606                              SCEVUnionPredicate *Pred) {
12607     SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
12608     return Rewriter.visit(S);
12609   }
12610 
12611   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
12612     if (Pred) {
12613       auto ExprPreds = Pred->getPredicatesForExpr(Expr);
12614       for (auto *Pred : ExprPreds)
12615         if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
12616           if (IPred->getLHS() == Expr)
12617             return IPred->getRHS();
12618     }
12619     return convertToAddRecWithPreds(Expr);
12620   }
12621 
12622   const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
12623     const SCEV *Operand = visit(Expr->getOperand());
12624     const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
12625     if (AR && AR->getLoop() == L && AR->isAffine()) {
12626       // This couldn't be folded because the operand didn't have the nuw
12627       // flag. Add the nusw flag as an assumption that we could make.
12628       const SCEV *Step = AR->getStepRecurrence(SE);
12629       Type *Ty = Expr->getType();
12630       if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
12631         return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
12632                                 SE.getSignExtendExpr(Step, Ty), L,
12633                                 AR->getNoWrapFlags());
12634     }
12635     return SE.getZeroExtendExpr(Operand, Expr->getType());
12636   }
12637 
12638   const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
12639     const SCEV *Operand = visit(Expr->getOperand());
12640     const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
12641     if (AR && AR->getLoop() == L && AR->isAffine()) {
12642       // This couldn't be folded because the operand didn't have the nsw
12643       // flag. Add the nssw flag as an assumption that we could make.
12644       const SCEV *Step = AR->getStepRecurrence(SE);
12645       Type *Ty = Expr->getType();
12646       if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
12647         return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
12648                                 SE.getSignExtendExpr(Step, Ty), L,
12649                                 AR->getNoWrapFlags());
12650     }
12651     return SE.getSignExtendExpr(Operand, Expr->getType());
12652   }
12653 
12654 private:
12655   explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
12656                         SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
12657                         SCEVUnionPredicate *Pred)
12658       : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
12659 
12660   bool addOverflowAssumption(const SCEVPredicate *P) {
12661     if (!NewPreds) {
12662       // Check if we've already made this assumption.
12663       return Pred && Pred->implies(P);
12664     }
12665     NewPreds->insert(P);
12666     return true;
12667   }
12668 
12669   bool addOverflowAssumption(const SCEVAddRecExpr *AR,
12670                              SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
12671     auto *A = SE.getWrapPredicate(AR, AddedFlags);
12672     return addOverflowAssumption(A);
12673   }
12674 
12675   // If \p Expr represents a PHINode, we try to see if it can be represented
12676   // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
12677   // to add this predicate as a runtime overflow check, we return the AddRec.
12678   // If \p Expr does not meet these conditions (is not a PHI node, or we
12679   // couldn't create an AddRec for it, or couldn't add the predicate), we just
12680   // return \p Expr.
12681   const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
12682     if (!isa<PHINode>(Expr->getValue()))
12683       return Expr;
12684     Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
12685     PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
12686     if (!PredicatedRewrite)
12687       return Expr;
12688     for (auto *P : PredicatedRewrite->second){
12689       // Wrap predicates from outer loops are not supported.
12690       if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
12691         auto *AR = cast<const SCEVAddRecExpr>(WP->getExpr());
12692         if (L != AR->getLoop())
12693           return Expr;
12694       }
12695       if (!addOverflowAssumption(P))
12696         return Expr;
12697     }
12698     return PredicatedRewrite->first;
12699   }
12700 
12701   SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
12702   SCEVUnionPredicate *Pred;
12703   const Loop *L;
12704 };
12705 
12706 } // end anonymous namespace
12707 
12708 const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
12709                                                    SCEVUnionPredicate &Preds) {
12710   return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
12711 }
12712 
12713 const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
12714     const SCEV *S, const Loop *L,
12715     SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
12716   SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
12717   S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
12718   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
12719 
12720   if (!AddRec)
12721     return nullptr;
12722 
12723   // Since the transformation was successful, we can now transfer the SCEV
12724   // predicates.
12725   for (auto *P : TransformPreds)
12726     Preds.insert(P);
12727 
12728   return AddRec;
12729 }
12730 
12731 /// SCEV predicates
12732 SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
12733                              SCEVPredicateKind Kind)
12734     : FastID(ID), Kind(Kind) {}
12735 
12736 SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
12737                                        const SCEV *LHS, const SCEV *RHS)
12738     : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {
12739   assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
12740   assert(LHS != RHS && "LHS and RHS are the same SCEV");
12741 }
12742 
12743 bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
12744   const auto *Op = dyn_cast<SCEVEqualPredicate>(N);
12745 
12746   if (!Op)
12747     return false;
12748 
12749   return Op->LHS == LHS && Op->RHS == RHS;
12750 }
12751 
12752 bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
12753 
12754 const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
12755 
12756 void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
12757   OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
12758 }
12759 
12760 SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
12761                                      const SCEVAddRecExpr *AR,
12762                                      IncrementWrapFlags Flags)
12763     : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
12764 
12765 const SCEV *SCEVWrapPredicate::getExpr() const { return AR; }
12766 
12767 bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
12768   const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
12769 
12770   return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
12771 }
12772 
12773 bool SCEVWrapPredicate::isAlwaysTrue() const {
12774   SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
12775   IncrementWrapFlags IFlags = Flags;
12776 
12777   if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
12778     IFlags = clearFlags(IFlags, IncrementNSSW);
12779 
12780   return IFlags == IncrementAnyWrap;
12781 }
12782 
12783 void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
12784   OS.indent(Depth) << *getExpr() << " Added Flags: ";
12785   if (SCEVWrapPredicate::IncrementNUSW & getFlags())
12786     OS << "<nusw>";
12787   if (SCEVWrapPredicate::IncrementNSSW & getFlags())
12788     OS << "<nssw>";
12789   OS << "\n";
12790 }
12791 
12792 SCEVWrapPredicate::IncrementWrapFlags
12793 SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
12794                                    ScalarEvolution &SE) {
12795   IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
12796   SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
12797 
12798   // We can safely transfer the NSW flag as NSSW.
12799   if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
12800     ImpliedFlags = IncrementNSSW;
12801 
12802   if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
12803     // If the increment is non-negative, the SCEV NUW flag will also imply
12804     // the WrapPredicate NUSW flag.
12805     if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
12806       if (Step->getValue()->getValue().isNonNegative())
12807         ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
12808   }
12809 
12810   return ImpliedFlags;
12811 }
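// Illustrative example (hypothetical recurrence): for {0,+,4}<nuw><nsw><%loop>
// the static <nsw> flag yields IncrementNSSW, and because the step (4) is a
// non-negative constant the static <nuw> flag additionally yields
// IncrementNUSW. setNoOverflow (below) uses this to avoid emitting wrap
// predicates for facts that already follow from the static flags.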
12812 
12813 /// Union predicates don't get cached, so create a dummy FoldingSet ID for them.
12814 SCEVUnionPredicate::SCEVUnionPredicate()
12815     : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
12816 
12817 bool SCEVUnionPredicate::isAlwaysTrue() const {
12818   return all_of(Preds,
12819                 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
12820 }
12821 
12822 ArrayRef<const SCEVPredicate *>
12823 SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
12824   auto I = SCEVToPreds.find(Expr);
12825   if (I == SCEVToPreds.end())
12826     return ArrayRef<const SCEVPredicate *>();
12827   return I->second;
12828 }
12829 
12830 bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
12831   if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
12832     return all_of(Set->Preds,
12833                   [this](const SCEVPredicate *I) { return this->implies(I); });
12834 
12835   auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
12836   if (ScevPredsIt == SCEVToPreds.end())
12837     return false;
12838   auto &SCEVPreds = ScevPredsIt->second;
12839 
12840   return any_of(SCEVPreds,
12841                 [N](const SCEVPredicate *I) { return I->implies(N); });
12842 }
12843 
12844 const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
12845 
12846 void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
12847   for (auto Pred : Preds)
12848     Pred->print(OS, Depth);
12849 }
12850 
12851 void SCEVUnionPredicate::add(const SCEVPredicate *N) {
12852   if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
12853     for (auto Pred : Set->Preds)
12854       add(Pred);
12855     return;
12856   }
12857 
12858   if (implies(N))
12859     return;
12860 
12861   const SCEV *Key = N->getExpr();
12862   assert(Key && "Only SCEVUnionPredicate doesn't have an "
12863                 "associated expression!");
12864 
12865   SCEVToPreds[Key].push_back(N);
12866   Preds.push_back(N);
12867 }
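// Illustrative note: the implies() check above makes add() idempotent; adding
// the same predicate twice, or adding e.g. an <nusw> wrap predicate when an
// <nusw><nssw> predicate on the same AddRec is already present, leaves the
// union unchanged.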
12868 
12869 PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
12870                                                      Loop &L)
12871     : SE(SE), L(L) {}
12872 
12873 const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
12874   const SCEV *Expr = SE.getSCEV(V);
12875   RewriteEntry &Entry = RewriteMap[Expr];
12876 
12877   // If we already have an entry and the version matches, return it.
12878   if (Entry.second && Generation == Entry.first)
12879     return Entry.second;
12880 
12881   // If we found an entry but it is stale, rewrite the stale entry
12882   // according to the current predicates.
12883   if (Entry.second)
12884     Expr = Entry.second;
12885 
12886   const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds);
12887   Entry = {Generation, NewSCEV};
12888 
12889   return NewSCEV;
12890 }
12891 
12892 const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
12893   if (!BackedgeCount) {
12894     SCEVUnionPredicate BackedgePred;
12895     BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred);
12896     addPredicate(BackedgePred);
12897   }
12898   return BackedgeCount;
12899 }
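// Hypothetical usage sketch (illustrative client code, names assumed): a
// transform that needs a trip count that is only computable under predicates
// might do
//
//   PredicatedScalarEvolution PSE(SE, *L);
//   const SCEV *BTC = PSE.getBackedgeTakenCount();
//   if (isa<SCEVCouldNotCompute>(BTC))
//     return;          // Not analyzable even with predicates.
//   // ... later, emit runtime checks for PSE.getUnionPredicate().
//
// All predicates gathered this way must be checked at run time (e.g. by
// versioning the loop) before the rewritten SCEVs can be relied upon.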
12900 
12901 void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
12902   if (Preds.implies(&Pred))
12903     return;
12904   Preds.add(&Pred);
12905   updateGeneration();
12906 }
12907 
12908 const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const {
12909   return Preds;
12910 }
12911 
12912 void PredicatedScalarEvolution::updateGeneration() {
12913   // If the generation number wrapped, recompute everything.
12914   if (++Generation == 0) {
12915     for (auto &II : RewriteMap) {
12916       const SCEV *Rewritten = II.second.second;
12917       II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)};
12918     }
12919   }
12920 }
12921 
12922 void PredicatedScalarEvolution::setNoOverflow(
12923     Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
12924   const SCEV *Expr = getSCEV(V);
12925   const auto *AR = cast<SCEVAddRecExpr>(Expr);
12926 
12927   auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
12928 
12929   // Clear the statically implied flags.
12930   Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
12931   addPredicate(*SE.getWrapPredicate(AR, Flags));
12932 
12933   auto II = FlagsMap.insert({V, Flags});
12934   if (!II.second)
12935     II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
12936 }
12937 
12938 bool PredicatedScalarEvolution::hasNoOverflow(
12939     Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
12940   const SCEV *Expr = getSCEV(V);
12941   const auto *AR = cast<SCEVAddRecExpr>(Expr);
12942 
12943   Flags = SCEVWrapPredicate::clearFlags(
12944       Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
12945 
12946   auto II = FlagsMap.find(V);
12947 
12948   if (II != FlagsMap.end())
12949     Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
12950 
12951   return Flags == SCEVWrapPredicate::IncrementAnyWrap;
12952 }
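// Illustrative example (hypothetical PSE and V): after
//
//   PSE.setNoOverflow(V, SCEVWrapPredicate::IncrementNUSW);
//
// a later PSE.hasNoOverflow(V, SCEVWrapPredicate::IncrementNUSW) returns true,
// either because the flag was recorded in FlagsMap or because it was already
// implied by the static no-wrap flags of the add recurrence for V.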
12953 
12954 const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
12955   const SCEV *Expr = this->getSCEV(V);
12956   SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
12957   auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
12958 
12959   if (!New)
12960     return nullptr;
12961 
12962   for (auto *P : NewPreds)
12963     Preds.add(P);
12964 
12965   updateGeneration();
12966   RewriteMap[SE.getSCEV(V)] = {Generation, New};
12967   return New;
12968 }
12969 
12970 PredicatedScalarEvolution::PredicatedScalarEvolution(
12971     const PredicatedScalarEvolution &Init)
12972     : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds),
12973       Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
12974   for (auto I : Init.FlagsMap)
12975     FlagsMap.insert(I);
12976 }
12977 
12978 void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
12979   // For each block.
12980   for (auto *BB : L.getBlocks())
12981     for (auto &I : *BB) {
12982       if (!SE.isSCEVable(I.getType()))
12983         continue;
12984 
12985       auto *Expr = SE.getSCEV(&I);
12986       auto II = RewriteMap.find(Expr);
12987 
12988       if (II == RewriteMap.end())
12989         continue;
12990 
12991       // Don't print things that are not interesting.
12992       if (II->second.second == Expr)
12993         continue;
12994 
12995       OS.indent(Depth) << "[PSE]" << I << ":\n";
12996       OS.indent(Depth + 2) << *Expr << "\n";
12997       OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
12998     }
12999 }
13000 
13001 // Match the mathematical pattern A - (A / B) * B, where A and B can be
13002 // arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
13003 // for URem with constant power-of-2 second operands.
13004 // It's not always easy, as A and B can be folded (imagine A is X / 2 and B is
13005 // 4; then A / B becomes X / 8).
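// Illustrative examples (hypothetical values %a, %b): the expression
//   (zext i2 (trunc i32 %a to i2) to i32)
// is matched as %a urem 4, i.e. LHS = %a and RHS = 4, and
//   ((-1 * (%a /u %b) * %b) + %a)
// can be matched with LHS = %a and RHS = %b.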
13006 bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
13007                                 const SCEV *&RHS) {
13008   // Try to match 'zext (trunc A to iB) to iY', which is used
13009   // for URem with constant power-of-2 second operands. Make sure the size of
13010   // the operand A matches the size of the whole expression.
13011   if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
13012     if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
13013       LHS = Trunc->getOperand();
13014       if (LHS->getType() != Expr->getType())
13015         LHS = getZeroExtendExpr(LHS, Expr->getType());
13016       RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
13017                         << getTypeSizeInBits(Trunc->getType()));
13018       return true;
13019     }
13020   const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
13021   if (Add == nullptr || Add->getNumOperands() != 2)
13022     return false;
13023 
13024   const SCEV *A = Add->getOperand(1);
13025   const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
13026 
13027   if (Mul == nullptr)
13028     return false;
13029 
13030   const auto MatchURemWithDivisor = [&](const SCEV *B) {
13031     // (SomeExpr + (-(SomeExpr / B) * B)).
13032     if (Expr == getURemExpr(A, B)) {
13033       LHS = A;
13034       RHS = B;
13035       return true;
13036     }
13037     return false;
13038   };
13039 
13040   // (SomeExpr + (-1 * (SomeExpr / B) * B)).
13041   if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
13042     return MatchURemWithDivisor(Mul->getOperand(1)) ||
13043            MatchURemWithDivisor(Mul->getOperand(2));
13044 
13045   // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
13046   if (Mul->getNumOperands() == 2)
13047     return MatchURemWithDivisor(Mul->getOperand(1)) ||
13048            MatchURemWithDivisor(Mul->getOperand(0)) ||
13049            MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
13050            MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
13051   return false;
13052 }
13053 
13054 const SCEV *
13055 ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
13056   SmallVector<BasicBlock*, 16> ExitingBlocks;
13057   L->getExitingBlocks(ExitingBlocks);
13058 
13059   // Form an expression for the maximum exit count possible for this loop. We
13060   // merge the max and exact information to approximate a version of
13061   // getConstantMaxBackedgeTakenCount which isn't restricted to just constants.
13062   SmallVector<const SCEV*, 4> ExitCounts;
13063   for (BasicBlock *ExitingBB : ExitingBlocks) {
13064     const SCEV *ExitCount = getExitCount(L, ExitingBB);
13065     if (isa<SCEVCouldNotCompute>(ExitCount))
13066       ExitCount = getExitCount(L, ExitingBB,
13067                                   ScalarEvolution::ConstantMaximum);
13068     if (!isa<SCEVCouldNotCompute>(ExitCount)) {
13069       assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
13070              "We should only have known counts for exiting blocks that "
13071              "dominate latch!");
13072       ExitCounts.push_back(ExitCount);
13073     }
13074   }
13075   if (ExitCounts.empty())
13076     return getCouldNotCompute();
13077   return getUMinFromMismatchedTypes(ExitCounts);
13078 }
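// Illustrative example (hypothetical loop): for a loop with two exiting
// blocks, one with an exact exit count of %n and one for which only a
// constant maximum of 100 is known, the result above is the umin of %n and
// 100, built with getUMinFromMismatchedTypes.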
13079 
13080 /// This rewriter is similar to SCEVParameterRewriter (it replaces SCEVUnknown
13081 /// components according to the Map (Value -> SCEV)), but skips AddRec
13082 /// expressions, because we cannot guarantee that the replacement is loop
13083 /// invariant in the loop of the AddRec.
13084 class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
13085   ValueToSCEVMapTy &Map;
13086 
13087 public:
13088   SCEVLoopGuardRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
13089       : SCEVRewriteVisitor(SE), Map(M) {}
13090 
13091   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
13092 
13093   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
13094     auto I = Map.find(Expr->getValue());
13095     if (I == Map.end())
13096       return Expr;
13097     return I->second;
13098   }
13099 };
13100 
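// Illustrative example (hypothetical value %n): for a loop whose entry is
// guarded by the unsigned comparison
//
//   if (%n u< 8)
//     ... loop body using %n ...
//
// the ICMP_ULT case in CollectCondition below records a rewrite of %n to the
// umin of %n and 7, and SCEVLoopGuardRewriter then applies that rewrite to
// every SCEVUnknown use of %n in Expr, skipping add recurrences.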
13101 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
13102   auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
13103                               const SCEV *RHS, ValueToSCEVMapTy &RewriteMap) {
13104     if (!isa<SCEVUnknown>(LHS)) {
13105       std::swap(LHS, RHS);
13106       Predicate = CmpInst::getSwappedPredicate(Predicate);
13107     }
13108 
13109     // For now, limit to conditions that provide information about unknown
13110     // expressions.
13111     auto *LHSUnknown = dyn_cast<SCEVUnknown>(LHS);
13112     if (!LHSUnknown)
13113       return;
13114 
13115     // TODO: use information from more predicates.
13116     switch (Predicate) {
13117     case CmpInst::ICMP_ULT: {
13118       if (!containsAddRecurrence(RHS)) {
13119         const SCEV *Base = LHS;
13120         auto I = RewriteMap.find(LHSUnknown->getValue());
13121         if (I != RewriteMap.end())
13122           Base = I->second;
13123 
13124         RewriteMap[LHSUnknown->getValue()] =
13125             getUMinExpr(Base, getMinusSCEV(RHS, getOne(RHS->getType())));
13126       }
13127       break;
13128     }
13129     case CmpInst::ICMP_ULE: {
13130       if (!containsAddRecurrence(RHS)) {
13131         const SCEV *Base = LHS;
13132         auto I = RewriteMap.find(LHSUnknown->getValue());
13133         if (I != RewriteMap.end())
13134           Base = I->second;
13135         RewriteMap[LHSUnknown->getValue()] = getUMinExpr(Base, RHS);
13136       }
13137       break;
13138     }
13139     case CmpInst::ICMP_EQ:
13140       if (isa<SCEVConstant>(RHS))
13141         RewriteMap[LHSUnknown->getValue()] = RHS;
13142       break;
13143     case CmpInst::ICMP_NE:
13144       if (isa<SCEVConstant>(RHS) &&
13145           cast<SCEVConstant>(RHS)->getValue()->isNullValue())
13146         RewriteMap[LHSUnknown->getValue()] =
13147             getUMaxExpr(LHS, getOne(RHS->getType()));
13148       break;
13149     default:
13150       break;
13151     }
13152   };
13153   // Starting at the loop predecessor, climb up the predecessor chain, as long
13154   // as we can find predecessors that have a unique successor leading to the
13155   // original header.
13156   // TODO: share this logic with isLoopEntryGuardedByCond.
13157   ValueToSCEVMapTy RewriteMap;
13158   for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
13159            L->getLoopPredecessor(), L->getHeader());
13160        Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
13161 
13162     const BranchInst *LoopEntryPredicate =
13163         dyn_cast<BranchInst>(Pair.first->getTerminator());
13164     if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
13165       continue;
13166 
13167     // TODO: use information from more complex conditions, e.g. AND expressions.
13168     auto *Cmp = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition());
13169     if (!Cmp)
13170       continue;
13171 
13172     auto Predicate = Cmp->getPredicate();
13173     if (LoopEntryPredicate->getSuccessor(1) == Pair.second)
13174       Predicate = CmpInst::getInversePredicate(Predicate);
13175     CollectCondition(Predicate, getSCEV(Cmp->getOperand(0)),
13176                      getSCEV(Cmp->getOperand(1)), RewriteMap);
13177   }
13178 
13179   // Also collect information from assumptions dominating the loop.
13180   for (auto &AssumeVH : AC.assumptions()) {
13181     if (!AssumeVH)
13182       continue;
13183     auto *AssumeI = cast<CallInst>(AssumeVH);
13184     auto *Cmp = dyn_cast<ICmpInst>(AssumeI->getOperand(0));
13185     if (!Cmp || !DT.dominates(AssumeI, L->getHeader()))
13186       continue;
13187     CollectCondition(Cmp->getPredicate(), getSCEV(Cmp->getOperand(0)),
13188                      getSCEV(Cmp->getOperand(1)), RewriteMap);
13189   }
13190 
13191   if (RewriteMap.empty())
13192     return Expr;
13193   SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
13194   return Rewriter.visit(Expr);
13195 }
13196