xref: /llvm-project/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h (revision 7f06d8afb03383dea33379f9c06d010d6ee3f14e)
1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15 
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <array>
#include <cassert>
#include <cstddef>
26 
27 namespace llvm {
28 
29 class APInt;
30 class Constant;
31 class ConstantInt;
32 class ConstantRange;
33 class Loop;
34 class Type;
35 class Value;
36 
/// Kind tag identifying the concrete class of a SCEV node. The classof()
/// predicates of the SCEV subclasses and the visitor/traversal switches in
/// this file compare the result of SCEV::getSCEVType() against these values.
enum SCEVTypes : unsigned short {
  // These should be ordered in terms of increasing complexity to make the
  // folders simpler.
  scConstant,           // SCEVConstant
  scVScale,             // SCEVVScale
  scTruncate,           // SCEVTruncateExpr
  scZeroExtend,         // SCEVZeroExtendExpr
  scSignExtend,         // SCEVSignExtendExpr
  scAddExpr,            // SCEVAddExpr
  scMulExpr,            // SCEVMulExpr
  scUDivExpr,           // SCEVUDivExpr
  scAddRecExpr,         // SCEVAddRecExpr
  scUMaxExpr,           // SCEVUMaxExpr
  scSMaxExpr,           // SCEVSMaxExpr
  scUMinExpr,           // SCEVUMinExpr
  scSMinExpr,           // SCEVSMinExpr
  scSequentialUMinExpr, // SCEVSequentialUMinExpr
  scPtrToInt,           // SCEVPtrToIntExpr
  scUnknown,            // SCEVUnknown
  scCouldNotCompute     // SCEVCouldNotCompute
};
58 
59 /// This class represents a constant integer value.
60 class SCEVConstant : public SCEV {
61   friend class ScalarEvolution;
62 
63   ConstantInt *V;
64 
65   SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
66       : SCEV(ID, scConstant, 1), V(v) {}
67 
68 public:
69   ConstantInt *getValue() const { return V; }
70   const APInt &getAPInt() const { return getValue()->getValue(); }
71 
72   Type *getType() const { return V->getType(); }
73 
74   /// Methods for support type inquiry through isa, cast, and dyn_cast:
75   static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
76 };
77 
78 /// This class represents the value of vscale, as used when defining the length
79 /// of a scalable vector or returned by the llvm.vscale() intrinsic.
80 class SCEVVScale : public SCEV {
81   friend class ScalarEvolution;
82 
83   SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty)
84       : SCEV(ID, scVScale, 0), Ty(ty) {}
85 
86   Type *Ty;
87 
88 public:
89   Type *getType() const { return Ty; }
90 
91   /// Methods for support type inquiry through isa, cast, and dyn_cast:
92   static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; }
93 };
94 
95 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
96   APInt Size(16, 1);
97   for (const auto *Arg : Args)
98     Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
99   return (unsigned short)Size.getZExtValue();
100 }
101 
102 /// This is the base class for unary cast operator classes.
103 class SCEVCastExpr : public SCEV {
104 protected:
105   const SCEV *Op;
106   Type *Ty;
107 
108   SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
109                Type *ty);
110 
111 public:
112   const SCEV *getOperand() const { return Op; }
113   const SCEV *getOperand(unsigned i) const {
114     assert(i == 0 && "Operand index out of range!");
115     return Op;
116   }
117   ArrayRef<const SCEV *> operands() const { return Op; }
118   size_t getNumOperands() const { return 1; }
119   Type *getType() const { return Ty; }
120 
121   /// Methods for support type inquiry through isa, cast, and dyn_cast:
122   static bool classof(const SCEV *S) {
123     return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
124            S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
125   }
126 };
127 
128 /// This class represents a cast from a pointer to a pointer-sized integer
129 /// value.
130 class SCEVPtrToIntExpr : public SCEVCastExpr {
131   friend class ScalarEvolution;
132 
133   SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
134 
135 public:
136   /// Methods for support type inquiry through isa, cast, and dyn_cast:
137   static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
138 };
139 
140 /// This is the base class for unary integral cast operator classes.
141 class SCEVIntegralCastExpr : public SCEVCastExpr {
142 protected:
143   SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
144                        const SCEV *op, Type *ty);
145 
146 public:
147   /// Methods for support type inquiry through isa, cast, and dyn_cast:
148   static bool classof(const SCEV *S) {
149     return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
150            S->getSCEVType() == scSignExtend;
151   }
152 };
153 
154 /// This class represents a truncation of an integer value to a
155 /// smaller integer value.
156 class SCEVTruncateExpr : public SCEVIntegralCastExpr {
157   friend class ScalarEvolution;
158 
159   SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
160 
161 public:
162   /// Methods for support type inquiry through isa, cast, and dyn_cast:
163   static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
164 };
165 
166 /// This class represents a zero extension of a small integer value
167 /// to a larger integer value.
168 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
169   friend class ScalarEvolution;
170 
171   SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
172 
173 public:
174   /// Methods for support type inquiry through isa, cast, and dyn_cast:
175   static bool classof(const SCEV *S) {
176     return S->getSCEVType() == scZeroExtend;
177   }
178 };
179 
180 /// This class represents a sign extension of a small integer value
181 /// to a larger integer value.
182 class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
183   friend class ScalarEvolution;
184 
185   SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
186 
187 public:
188   /// Methods for support type inquiry through isa, cast, and dyn_cast:
189   static bool classof(const SCEV *S) {
190     return S->getSCEVType() == scSignExtend;
191   }
192 };
193 
194 /// This node is a base class providing common functionality for
195 /// n'ary operators.
196 class SCEVNAryExpr : public SCEV {
197 protected:
198   // Since SCEVs are immutable, ScalarEvolution allocates operand
199   // arrays with its SCEVAllocator, so this class just needs a simple
200   // pointer rather than a more elaborate vector-like data structure.
201   // This also avoids the need for a non-trivial destructor.
202   const SCEV *const *Operands;
203   size_t NumOperands;
204 
205   SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
206                const SCEV *const *O, size_t N)
207       : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O),
208         NumOperands(N) {}
209 
210 public:
211   size_t getNumOperands() const { return NumOperands; }
212 
213   const SCEV *getOperand(unsigned i) const {
214     assert(i < NumOperands && "Operand index out of range!");
215     return Operands[i];
216   }
217 
218   ArrayRef<const SCEV *> operands() const {
219     return ArrayRef(Operands, NumOperands);
220   }
221 
222   NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
223     return (NoWrapFlags)(SubclassData & Mask);
224   }
225 
226   bool hasNoUnsignedWrap() const {
227     return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
228   }
229 
230   bool hasNoSignedWrap() const {
231     return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
232   }
233 
234   bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
235 
236   /// Methods for support type inquiry through isa, cast, and dyn_cast:
237   static bool classof(const SCEV *S) {
238     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
239            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
240            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
241            S->getSCEVType() == scSequentialUMinExpr ||
242            S->getSCEVType() == scAddRecExpr;
243   }
244 };
245 
246 /// This node is the base class for n'ary commutative operators.
247 class SCEVCommutativeExpr : public SCEVNAryExpr {
248 protected:
249   SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
250                       const SCEV *const *O, size_t N)
251       : SCEVNAryExpr(ID, T, O, N) {}
252 
253 public:
254   /// Methods for support type inquiry through isa, cast, and dyn_cast:
255   static bool classof(const SCEV *S) {
256     return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
257            S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
258            S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
259   }
260 
261   /// Set flags for a non-recurrence without clearing previously set flags.
262   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
263 };
264 
265 /// This node represents an addition of some number of SCEVs.
266 class SCEVAddExpr : public SCEVCommutativeExpr {
267   friend class ScalarEvolution;
268 
269   Type *Ty;
270 
271   SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
272       : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
273     auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
274       return Op->getType()->isPointerTy();
275     });
276     if (FirstPointerTypedOp != operands().end())
277       Ty = (*FirstPointerTypedOp)->getType();
278     else
279       Ty = getOperand(0)->getType();
280   }
281 
282 public:
283   Type *getType() const { return Ty; }
284 
285   /// Methods for support type inquiry through isa, cast, and dyn_cast:
286   static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
287 };
288 
289 /// This node represents multiplication of some number of SCEVs.
290 class SCEVMulExpr : public SCEVCommutativeExpr {
291   friend class ScalarEvolution;
292 
293   SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
294       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
295 
296 public:
297   Type *getType() const { return getOperand(0)->getType(); }
298 
299   /// Methods for support type inquiry through isa, cast, and dyn_cast:
300   static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
301 };
302 
303 /// This class represents a binary unsigned division operation.
304 class SCEVUDivExpr : public SCEV {
305   friend class ScalarEvolution;
306 
307   std::array<const SCEV *, 2> Operands;
308 
309   SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
310       : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
311     Operands[0] = lhs;
312     Operands[1] = rhs;
313   }
314 
315 public:
316   const SCEV *getLHS() const { return Operands[0]; }
317   const SCEV *getRHS() const { return Operands[1]; }
318   size_t getNumOperands() const { return 2; }
319   const SCEV *getOperand(unsigned i) const {
320     assert((i == 0 || i == 1) && "Operand index out of range!");
321     return i == 0 ? getLHS() : getRHS();
322   }
323 
324   ArrayRef<const SCEV *> operands() const { return Operands; }
325 
326   Type *getType() const {
327     // In most cases the types of LHS and RHS will be the same, but in some
328     // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
329     // depend on the type for correctness, but handling types carefully can
330     // avoid extra casts in the SCEVExpander. The LHS is more likely to be
331     // a pointer type than the RHS, so use the RHS' type here.
332     return getRHS()->getType();
333   }
334 
335   /// Methods for support type inquiry through isa, cast, and dyn_cast:
336   static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
337 };
338 
339 /// This node represents a polynomial recurrence on the trip count
340 /// of the specified loop.  This is the primary focus of the
341 /// ScalarEvolution framework; all the other SCEV subclasses are
342 /// mostly just supporting infrastructure to allow SCEVAddRecExpr
343 /// expressions to be created and analyzed.
344 ///
345 /// All operands of an AddRec are required to be loop invariant.
346 ///
347 class SCEVAddRecExpr : public SCEVNAryExpr {
348   friend class ScalarEvolution;
349 
350   const Loop *L;
351 
352   SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
353                  const Loop *l)
354       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
355 
356 public:
357   Type *getType() const { return getStart()->getType(); }
358   const SCEV *getStart() const { return Operands[0]; }
359   const Loop *getLoop() const { return L; }
360 
361   /// Constructs and returns the recurrence indicating how much this
362   /// expression steps by.  If this is a polynomial of degree N, it
363   /// returns a chrec of degree N-1.  We cannot determine whether
364   /// the step recurrence has self-wraparound.
365   const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
366     if (isAffine())
367       return getOperand(1);
368     return SE.getAddRecExpr(
369         SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(),
370         FlagAnyWrap);
371   }
372 
373   /// Return true if this represents an expression A + B*x where A
374   /// and B are loop invariant values.
375   bool isAffine() const {
376     // We know that the start value is invariant.  This expression is thus
377     // affine iff the step is also invariant.
378     return getNumOperands() == 2;
379   }
380 
381   /// Return true if this represents an expression A + B*x + C*x^2
382   /// where A, B and C are loop invariant values.  This corresponds
383   /// to an addrec of the form {L,+,M,+,N}
384   bool isQuadratic() const { return getNumOperands() == 3; }
385 
386   /// Set flags for a recurrence without clearing any previously set flags.
387   /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
388   /// to make it easier to propagate flags.
389   void setNoWrapFlags(NoWrapFlags Flags) {
390     if (Flags & (FlagNUW | FlagNSW))
391       Flags = ScalarEvolution::setFlags(Flags, FlagNW);
392     SubclassData |= Flags;
393   }
394 
395   /// Return the value of this chain of recurrences at the specified
396   /// iteration number.
397   const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
398 
399   /// Return the value of this chain of recurrences at the specified iteration
400   /// number. Takes an explicit list of operands to represent an AddRec.
401   static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
402                                          const SCEV *It, ScalarEvolution &SE);
403 
404   /// Return the number of iterations of this loop that produce
405   /// values in the specified constant range.  Another way of
406   /// looking at this is that it returns the first iteration number
407   /// where the value is not in the condition, thus computing the
408   /// exit count.  If the iteration count can't be computed, an
409   /// instance of SCEVCouldNotCompute is returned.
410   const SCEV *getNumIterationsInRange(const ConstantRange &Range,
411                                       ScalarEvolution &SE) const;
412 
413   /// Return an expression representing the value of this expression
414   /// one iteration of the loop ahead.
415   const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
416 
417   /// Methods for support type inquiry through isa, cast, and dyn_cast:
418   static bool classof(const SCEV *S) {
419     return S->getSCEVType() == scAddRecExpr;
420   }
421 };
422 
423 /// This node is the base class min/max selections.
424 class SCEVMinMaxExpr : public SCEVCommutativeExpr {
425   friend class ScalarEvolution;
426 
427   static bool isMinMaxType(enum SCEVTypes T) {
428     return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
429            T == scUMinExpr;
430   }
431 
432 protected:
433   /// Note: Constructing subclasses via this constructor is allowed
434   SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
435                  const SCEV *const *O, size_t N)
436       : SCEVCommutativeExpr(ID, T, O, N) {
437     assert(isMinMaxType(T));
438     // Min and max never overflow
439     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
440   }
441 
442 public:
443   Type *getType() const { return getOperand(0)->getType(); }
444 
445   static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
446 
447   static enum SCEVTypes negate(enum SCEVTypes T) {
448     switch (T) {
449     case scSMaxExpr:
450       return scSMinExpr;
451     case scSMinExpr:
452       return scSMaxExpr;
453     case scUMaxExpr:
454       return scUMinExpr;
455     case scUMinExpr:
456       return scUMaxExpr;
457     default:
458       llvm_unreachable("Not a min or max SCEV type!");
459     }
460   }
461 };
462 
463 /// This class represents a signed maximum selection.
464 class SCEVSMaxExpr : public SCEVMinMaxExpr {
465   friend class ScalarEvolution;
466 
467   SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
468       : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
469 
470 public:
471   /// Methods for support type inquiry through isa, cast, and dyn_cast:
472   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
473 };
474 
475 /// This class represents an unsigned maximum selection.
476 class SCEVUMaxExpr : public SCEVMinMaxExpr {
477   friend class ScalarEvolution;
478 
479   SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
480       : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
481 
482 public:
483   /// Methods for support type inquiry through isa, cast, and dyn_cast:
484   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
485 };
486 
487 /// This class represents a signed minimum selection.
488 class SCEVSMinExpr : public SCEVMinMaxExpr {
489   friend class ScalarEvolution;
490 
491   SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
492       : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
493 
494 public:
495   /// Methods for support type inquiry through isa, cast, and dyn_cast:
496   static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
497 };
498 
499 /// This class represents an unsigned minimum selection.
500 class SCEVUMinExpr : public SCEVMinMaxExpr {
501   friend class ScalarEvolution;
502 
503   SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
504       : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
505 
506 public:
507   /// Methods for support type inquiry through isa, cast, and dyn_cast:
508   static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
509 };
510 
511 /// This node is the base class for sequential/in-order min/max selections.
512 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they
513 /// are early-returning upon reaching saturation point.
514 /// I.e. given `0 umin_seq poison`, the result will be `0`, while the result of
515 /// `0 umin poison` is `poison`. When returning early, later expressions are not
516 /// executed, so `0 umin_seq (%x u/ 0)` does not result in undefined behavior.
517 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
518   friend class ScalarEvolution;
519 
520   static bool isSequentialMinMaxType(enum SCEVTypes T) {
521     return T == scSequentialUMinExpr;
522   }
523 
524   /// Set flags for a non-recurrence without clearing previously set flags.
525   void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
526 
527 protected:
528   /// Note: Constructing subclasses via this constructor is allowed
529   SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
530                            const SCEV *const *O, size_t N)
531       : SCEVNAryExpr(ID, T, O, N) {
532     assert(isSequentialMinMaxType(T));
533     // Min and max never overflow
534     setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
535   }
536 
537 public:
538   Type *getType() const { return getOperand(0)->getType(); }
539 
540   static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
541     assert(isSequentialMinMaxType(Ty));
542     switch (Ty) {
543     case scSequentialUMinExpr:
544       return scUMinExpr;
545     default:
546       llvm_unreachable("Not a sequential min/max type.");
547     }
548   }
549 
550   SCEVTypes getEquivalentNonSequentialSCEVType() const {
551     return getEquivalentNonSequentialSCEVType(getSCEVType());
552   }
553 
554   static bool classof(const SCEV *S) {
555     return isSequentialMinMaxType(S->getSCEVType());
556   }
557 };
558 
559 /// This class represents a sequential/in-order unsigned minimum selection.
560 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
561   friend class ScalarEvolution;
562 
563   SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
564                          size_t N)
565       : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
566 
567 public:
568   /// Methods for support type inquiry through isa, cast, and dyn_cast:
569   static bool classof(const SCEV *S) {
570     return S->getSCEVType() == scSequentialUMinExpr;
571   }
572 };
573 
574 /// This means that we are dealing with an entirely unknown SCEV
575 /// value, and only represent it as its LLVM Value.  This is the
576 /// "bottom" value for the analysis.
577 class SCEVUnknown final : public SCEV, private CallbackVH {
578   friend class ScalarEvolution;
579 
580   /// The parent ScalarEvolution value. This is used to update the
581   /// parent's maps when the value associated with a SCEVUnknown is
582   /// deleted or RAUW'd.
583   ScalarEvolution *SE;
584 
585   /// The next pointer in the linked list of all SCEVUnknown
586   /// instances owned by a ScalarEvolution.
587   SCEVUnknown *Next;
588 
589   SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
590               SCEVUnknown *next)
591       : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
592 
593   // Implement CallbackVH.
594   void deleted() override;
595   void allUsesReplacedWith(Value *New) override;
596 
597 public:
598   Value *getValue() const { return getValPtr(); }
599 
600   Type *getType() const { return getValPtr()->getType(); }
601 
602   /// Methods for support type inquiry through isa, cast, and dyn_cast:
603   static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
604 };
605 
606 /// This class defines a simple visitor class that may be used for
607 /// various SCEV analysis purposes.
608 template <typename SC, typename RetVal = void> struct SCEVVisitor {
609   RetVal visit(const SCEV *S) {
610     switch (S->getSCEVType()) {
611     case scConstant:
612       return ((SC *)this)->visitConstant((const SCEVConstant *)S);
613     case scVScale:
614       return ((SC *)this)->visitVScale((const SCEVVScale *)S);
615     case scPtrToInt:
616       return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
617     case scTruncate:
618       return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
619     case scZeroExtend:
620       return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
621     case scSignExtend:
622       return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
623     case scAddExpr:
624       return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
625     case scMulExpr:
626       return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
627     case scUDivExpr:
628       return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
629     case scAddRecExpr:
630       return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
631     case scSMaxExpr:
632       return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
633     case scUMaxExpr:
634       return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
635     case scSMinExpr:
636       return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
637     case scUMinExpr:
638       return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
639     case scSequentialUMinExpr:
640       return ((SC *)this)
641           ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
642     case scUnknown:
643       return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
644     case scCouldNotCompute:
645       return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
646     }
647     llvm_unreachable("Unknown SCEV kind!");
648   }
649 
650   RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
651     llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
652   }
653 };
654 
655 /// Visit all nodes in the expression tree using worklist traversal.
656 ///
657 /// Visitor implements:
658 ///   // return true to follow this node.
659 ///   bool follow(const SCEV *S);
660 ///   // return true to terminate the search.
661 ///   bool isDone();
662 template <typename SV> class SCEVTraversal {
663   SV &Visitor;
664   SmallVector<const SCEV *, 8> Worklist;
665   SmallPtrSet<const SCEV *, 8> Visited;
666 
667   void push(const SCEV *S) {
668     if (Visited.insert(S).second && Visitor.follow(S))
669       Worklist.push_back(S);
670   }
671 
672 public:
673   SCEVTraversal(SV &V) : Visitor(V) {}
674 
675   void visitAll(const SCEV *Root) {
676     push(Root);
677     while (!Worklist.empty() && !Visitor.isDone()) {
678       const SCEV *S = Worklist.pop_back_val();
679 
680       switch (S->getSCEVType()) {
681       case scConstant:
682       case scVScale:
683       case scUnknown:
684         continue;
685       case scPtrToInt:
686       case scTruncate:
687       case scZeroExtend:
688       case scSignExtend:
689       case scAddExpr:
690       case scMulExpr:
691       case scUDivExpr:
692       case scSMaxExpr:
693       case scUMaxExpr:
694       case scSMinExpr:
695       case scUMinExpr:
696       case scSequentialUMinExpr:
697       case scAddRecExpr:
698         for (const auto *Op : S->operands()) {
699           push(Op);
700           if (Visitor.isDone())
701             break;
702         }
703         continue;
704       case scCouldNotCompute:
705         llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
706       }
707       llvm_unreachable("Unknown SCEV kind!");
708     }
709   }
710 };
711 
712 /// Use SCEVTraversal to visit all nodes in the given expression tree.
713 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
714   SCEVTraversal<SV> T(Visitor);
715   T.visitAll(Root);
716 }
717 
718 /// Return true if any node in \p Root satisfies the predicate \p Pred.
719 template <typename PredTy>
720 bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
721   struct FindClosure {
722     bool Found = false;
723     PredTy Pred;
724 
725     FindClosure(PredTy Pred) : Pred(Pred) {}
726 
727     bool follow(const SCEV *S) {
728       if (!Pred(S))
729         return true;
730 
731       Found = true;
732       return false;
733     }
734 
735     bool isDone() const { return Found; }
736   };
737 
738   FindClosure FC(Pred);
739   visitAll(Root, FC);
740   return FC.Found;
741 }
742 
743 /// This visitor recursively visits a SCEV expression and re-writes it.
744 /// The result from each visit is cached, so it will return the same
745 /// SCEV for the same input.
746 template <typename SC>
747 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
748 protected:
749   ScalarEvolution &SE;
750   // Memoize the result of each visit so that we only compute once for
751   // the same input SCEV. This is to avoid redundant computations when
752   // a SCEV is referenced by multiple SCEVs. Without memoization, this
753   // visit algorithm would have exponential time complexity in the worst
754   // case, causing the compiler to hang on certain tests.
755   SmallDenseMap<const SCEV *, const SCEV *> RewriteResults;
756 
757 public:
758   SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
759 
760   const SCEV *visit(const SCEV *S) {
761     auto It = RewriteResults.find(S);
762     if (It != RewriteResults.end())
763       return It->second;
764     auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
765     auto Result = RewriteResults.try_emplace(S, Visited);
766     assert(Result.second && "Should insert a new entry");
767     return Result.first->second;
768   }
769 
770   const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
771 
772   const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }
773 
774   const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
775     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
776     return Operand == Expr->getOperand()
777                ? Expr
778                : SE.getPtrToIntExpr(Operand, Expr->getType());
779   }
780 
781   const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
782     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
783     return Operand == Expr->getOperand()
784                ? Expr
785                : SE.getTruncateExpr(Operand, Expr->getType());
786   }
787 
788   const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
789     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
790     return Operand == Expr->getOperand()
791                ? Expr
792                : SE.getZeroExtendExpr(Operand, Expr->getType());
793   }
794 
795   const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
796     const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
797     return Operand == Expr->getOperand()
798                ? Expr
799                : SE.getSignExtendExpr(Operand, Expr->getType());
800   }
801 
802   const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
803     SmallVector<const SCEV *, 2> Operands;
804     bool Changed = false;
805     for (const auto *Op : Expr->operands()) {
806       Operands.push_back(((SC *)this)->visit(Op));
807       Changed |= Op != Operands.back();
808     }
809     return !Changed ? Expr : SE.getAddExpr(Operands);
810   }
811 
812   const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
813     SmallVector<const SCEV *, 2> Operands;
814     bool Changed = false;
815     for (const auto *Op : Expr->operands()) {
816       Operands.push_back(((SC *)this)->visit(Op));
817       Changed |= Op != Operands.back();
818     }
819     return !Changed ? Expr : SE.getMulExpr(Operands);
820   }
821 
822   const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
823     auto *LHS = ((SC *)this)->visit(Expr->getLHS());
824     auto *RHS = ((SC *)this)->visit(Expr->getRHS());
825     bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
826     return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
827   }
828 
829   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
830     SmallVector<const SCEV *, 2> Operands;
831     bool Changed = false;
832     for (const auto *Op : Expr->operands()) {
833       Operands.push_back(((SC *)this)->visit(Op));
834       Changed |= Op != Operands.back();
835     }
836     return !Changed ? Expr
837                     : SE.getAddRecExpr(Operands, Expr->getLoop(),
838                                        Expr->getNoWrapFlags());
839   }
840 
841   const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
842     SmallVector<const SCEV *, 2> Operands;
843     bool Changed = false;
844     for (const auto *Op : Expr->operands()) {
845       Operands.push_back(((SC *)this)->visit(Op));
846       Changed |= Op != Operands.back();
847     }
848     return !Changed ? Expr : SE.getSMaxExpr(Operands);
849   }
850 
851   const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
852     SmallVector<const SCEV *, 2> Operands;
853     bool Changed = false;
854     for (const auto *Op : Expr->operands()) {
855       Operands.push_back(((SC *)this)->visit(Op));
856       Changed |= Op != Operands.back();
857     }
858     return !Changed ? Expr : SE.getUMaxExpr(Operands);
859   }
860 
861   const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
862     SmallVector<const SCEV *, 2> Operands;
863     bool Changed = false;
864     for (const auto *Op : Expr->operands()) {
865       Operands.push_back(((SC *)this)->visit(Op));
866       Changed |= Op != Operands.back();
867     }
868     return !Changed ? Expr : SE.getSMinExpr(Operands);
869   }
870 
871   const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
872     SmallVector<const SCEV *, 2> Operands;
873     bool Changed = false;
874     for (const auto *Op : Expr->operands()) {
875       Operands.push_back(((SC *)this)->visit(Op));
876       Changed |= Op != Operands.back();
877     }
878     return !Changed ? Expr : SE.getUMinExpr(Operands);
879   }
880 
881   const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
882     SmallVector<const SCEV *, 2> Operands;
883     bool Changed = false;
884     for (const auto *Op : Expr->operands()) {
885       Operands.push_back(((SC *)this)->visit(Op));
886       Changed |= Op != Operands.back();
887     }
888     return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
889   }
890 
891   const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
892 
893   const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
894     return Expr;
895   }
896 };
897 
898 using ValueToValueMap = DenseMap<const Value *, Value *>;
899 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
900 
901 /// The SCEVParameterRewriter takes a scalar evolution expression and updates
902 /// the SCEVUnknown components following the Map (Value -> SCEV).
903 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
904 public:
905   static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
906                              ValueToSCEVMapTy &Map) {
907     SCEVParameterRewriter Rewriter(SE, Map);
908     return Rewriter.visit(Scev);
909   }
910 
911   SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
912       : SCEVRewriteVisitor(SE), Map(M) {}
913 
914   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
915     auto I = Map.find(Expr->getValue());
916     if (I == Map.end())
917       return Expr;
918     return I->second;
919   }
920 
921 private:
922   ValueToSCEVMapTy &Map;
923 };
924 
925 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
926 
927 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
928 /// the Map (Loop -> SCEV) to all AddRecExprs.
929 class SCEVLoopAddRecRewriter
930     : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
931 public:
932   SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
933       : SCEVRewriteVisitor(SE), Map(M) {}
934 
935   static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
936                              ScalarEvolution &SE) {
937     SCEVLoopAddRecRewriter Rewriter(SE, Map);
938     return Rewriter.visit(Scev);
939   }
940 
941   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
942     SmallVector<const SCEV *, 2> Operands;
943     for (const SCEV *Op : Expr->operands())
944       Operands.push_back(visit(Op));
945 
946     const Loop *L = Expr->getLoop();
947     if (0 == Map.count(L))
948       return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
949 
950     return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE);
951   }
952 
953 private:
954   LoopToScevMapT &Map;
955 };
956 
957 } // end namespace llvm
958 
959 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
960