1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the classes used to represent and build scalar expressions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 15 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/SmallPtrSet.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/Analysis/ScalarEvolution.h" 20 #include "llvm/IR/Constants.h" 21 #include "llvm/IR/ValueHandle.h" 22 #include "llvm/Support/Casting.h" 23 #include "llvm/Support/ErrorHandling.h" 24 #include <cassert> 25 #include <cstddef> 26 27 namespace llvm { 28 29 class APInt; 30 class Constant; 31 class ConstantInt; 32 class ConstantRange; 33 class Loop; 34 class Type; 35 class Value; 36 37 enum SCEVTypes : unsigned short { 38 // These should be ordered in terms of increasing complexity to make the 39 // folders simpler. 40 scConstant, 41 scVScale, 42 scTruncate, 43 scZeroExtend, 44 scSignExtend, 45 scAddExpr, 46 scMulExpr, 47 scUDivExpr, 48 scAddRecExpr, 49 scUMaxExpr, 50 scSMaxExpr, 51 scUMinExpr, 52 scSMinExpr, 53 scSequentialUMinExpr, 54 scPtrToInt, 55 scUnknown, 56 scCouldNotCompute 57 }; 58 59 /// This class represents a constant integer value. 60 class SCEVConstant : public SCEV { 61 friend class ScalarEvolution; 62 63 ConstantInt *V; 64 65 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) 66 : SCEV(ID, scConstant, 1), V(v) {} 67 68 public: 69 ConstantInt *getValue() const { return V; } 70 const APInt &getAPInt() const { return getValue()->getValue(); } 71 72 Type *getType() const { return V->getType(); } 73 74 /// Methods for support type inquiry through isa, cast, and dyn_cast: 75 static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; } 76 }; 77 78 /// This class represents the value of vscale, as used when defining the length 79 /// of a scalable vector or returned by the llvm.vscale() intrinsic. 80 class SCEVVScale : public SCEV { 81 friend class ScalarEvolution; 82 83 SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty) 84 : SCEV(ID, scVScale, 0), Ty(ty) {} 85 86 Type *Ty; 87 88 public: 89 Type *getType() const { return Ty; } 90 91 /// Methods for support type inquiry through isa, cast, and dyn_cast: 92 static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; } 93 }; 94 95 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) { 96 APInt Size(16, 1); 97 for (const auto *Arg : Args) 98 Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize())); 99 return (unsigned short)Size.getZExtValue(); 100 } 101 102 /// This is the base class for unary cast operator classes. 103 class SCEVCastExpr : public SCEV { 104 protected: 105 const SCEV *Op; 106 Type *Ty; 107 108 SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, 109 Type *ty); 110 111 public: 112 const SCEV *getOperand() const { return Op; } 113 const SCEV *getOperand(unsigned i) const { 114 assert(i == 0 && "Operand index out of range!"); 115 return Op; 116 } 117 ArrayRef<const SCEV *> operands() const { return Op; } 118 size_t getNumOperands() const { return 1; } 119 Type *getType() const { return Ty; } 120 121 /// Methods for support type inquiry through isa, cast, and dyn_cast: 122 static bool classof(const SCEV *S) { 123 return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate || 124 S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend; 125 } 126 }; 127 128 /// This class represents a cast from a pointer to a pointer-sized integer 129 /// value. 130 class SCEVPtrToIntExpr : public SCEVCastExpr { 131 friend class ScalarEvolution; 132 133 SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy); 134 135 public: 136 /// Methods for support type inquiry through isa, cast, and dyn_cast: 137 static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; } 138 }; 139 140 /// This is the base class for unary integral cast operator classes. 141 class SCEVIntegralCastExpr : public SCEVCastExpr { 142 protected: 143 SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, 144 const SCEV *op, Type *ty); 145 146 public: 147 /// Methods for support type inquiry through isa, cast, and dyn_cast: 148 static bool classof(const SCEV *S) { 149 return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend || 150 S->getSCEVType() == scSignExtend; 151 } 152 }; 153 154 /// This class represents a truncation of an integer value to a 155 /// smaller integer value. 156 class SCEVTruncateExpr : public SCEVIntegralCastExpr { 157 friend class ScalarEvolution; 158 159 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 160 161 public: 162 /// Methods for support type inquiry through isa, cast, and dyn_cast: 163 static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; } 164 }; 165 166 /// This class represents a zero extension of a small integer value 167 /// to a larger integer value. 168 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr { 169 friend class ScalarEvolution; 170 171 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 172 173 public: 174 /// Methods for support type inquiry through isa, cast, and dyn_cast: 175 static bool classof(const SCEV *S) { 176 return S->getSCEVType() == scZeroExtend; 177 } 178 }; 179 180 /// This class represents a sign extension of a small integer value 181 /// to a larger integer value. 182 class SCEVSignExtendExpr : public SCEVIntegralCastExpr { 183 friend class ScalarEvolution; 184 185 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); 186 187 public: 188 /// Methods for support type inquiry through isa, cast, and dyn_cast: 189 static bool classof(const SCEV *S) { 190 return S->getSCEVType() == scSignExtend; 191 } 192 }; 193 194 /// This node is a base class providing common functionality for 195 /// n'ary operators. 196 class SCEVNAryExpr : public SCEV { 197 protected: 198 // Since SCEVs are immutable, ScalarEvolution allocates operand 199 // arrays with its SCEVAllocator, so this class just needs a simple 200 // pointer rather than a more elaborate vector-like data structure. 201 // This also avoids the need for a non-trivial destructor. 202 const SCEV *const *Operands; 203 size_t NumOperands; 204 205 SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 206 const SCEV *const *O, size_t N) 207 : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O), 208 NumOperands(N) {} 209 210 public: 211 size_t getNumOperands() const { return NumOperands; } 212 213 const SCEV *getOperand(unsigned i) const { 214 assert(i < NumOperands && "Operand index out of range!"); 215 return Operands[i]; 216 } 217 218 ArrayRef<const SCEV *> operands() const { 219 return ArrayRef(Operands, NumOperands); 220 } 221 222 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const { 223 return (NoWrapFlags)(SubclassData & Mask); 224 } 225 226 bool hasNoUnsignedWrap() const { 227 return getNoWrapFlags(FlagNUW) != FlagAnyWrap; 228 } 229 230 bool hasNoSignedWrap() const { 231 return getNoWrapFlags(FlagNSW) != FlagAnyWrap; 232 } 233 234 bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; } 235 236 /// Methods for support type inquiry through isa, cast, and dyn_cast: 237 static bool classof(const SCEV *S) { 238 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || 239 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || 240 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr || 241 S->getSCEVType() == scSequentialUMinExpr || 242 S->getSCEVType() == scAddRecExpr; 243 } 244 }; 245 246 /// This node is the base class for n'ary commutative operators. 247 class SCEVCommutativeExpr : public SCEVNAryExpr { 248 protected: 249 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 250 const SCEV *const *O, size_t N) 251 : SCEVNAryExpr(ID, T, O, N) {} 252 253 public: 254 /// Methods for support type inquiry through isa, cast, and dyn_cast: 255 static bool classof(const SCEV *S) { 256 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || 257 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || 258 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr; 259 } 260 261 /// Set flags for a non-recurrence without clearing previously set flags. 262 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } 263 }; 264 265 /// This node represents an addition of some number of SCEVs. 266 class SCEVAddExpr : public SCEVCommutativeExpr { 267 friend class ScalarEvolution; 268 269 Type *Ty; 270 271 SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 272 : SCEVCommutativeExpr(ID, scAddExpr, O, N) { 273 auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) { 274 return Op->getType()->isPointerTy(); 275 }); 276 if (FirstPointerTypedOp != operands().end()) 277 Ty = (*FirstPointerTypedOp)->getType(); 278 else 279 Ty = getOperand(0)->getType(); 280 } 281 282 public: 283 Type *getType() const { return Ty; } 284 285 /// Methods for support type inquiry through isa, cast, and dyn_cast: 286 static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; } 287 }; 288 289 /// This node represents multiplication of some number of SCEVs. 290 class SCEVMulExpr : public SCEVCommutativeExpr { 291 friend class ScalarEvolution; 292 293 SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 294 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {} 295 296 public: 297 Type *getType() const { return getOperand(0)->getType(); } 298 299 /// Methods for support type inquiry through isa, cast, and dyn_cast: 300 static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; } 301 }; 302 303 /// This class represents a binary unsigned division operation. 304 class SCEVUDivExpr : public SCEV { 305 friend class ScalarEvolution; 306 307 std::array<const SCEV *, 2> Operands; 308 309 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs) 310 : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) { 311 Operands[0] = lhs; 312 Operands[1] = rhs; 313 } 314 315 public: 316 const SCEV *getLHS() const { return Operands[0]; } 317 const SCEV *getRHS() const { return Operands[1]; } 318 size_t getNumOperands() const { return 2; } 319 const SCEV *getOperand(unsigned i) const { 320 assert((i == 0 || i == 1) && "Operand index out of range!"); 321 return i == 0 ? getLHS() : getRHS(); 322 } 323 324 ArrayRef<const SCEV *> operands() const { return Operands; } 325 326 Type *getType() const { 327 // In most cases the types of LHS and RHS will be the same, but in some 328 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't 329 // depend on the type for correctness, but handling types carefully can 330 // avoid extra casts in the SCEVExpander. The LHS is more likely to be 331 // a pointer type than the RHS, so use the RHS' type here. 332 return getRHS()->getType(); 333 } 334 335 /// Methods for support type inquiry through isa, cast, and dyn_cast: 336 static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; } 337 }; 338 339 /// This node represents a polynomial recurrence on the trip count 340 /// of the specified loop. This is the primary focus of the 341 /// ScalarEvolution framework; all the other SCEV subclasses are 342 /// mostly just supporting infrastructure to allow SCEVAddRecExpr 343 /// expressions to be created and analyzed. 344 /// 345 /// All operands of an AddRec are required to be loop invariant. 346 /// 347 class SCEVAddRecExpr : public SCEVNAryExpr { 348 friend class ScalarEvolution; 349 350 const Loop *L; 351 352 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N, 353 const Loop *l) 354 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {} 355 356 public: 357 Type *getType() const { return getStart()->getType(); } 358 const SCEV *getStart() const { return Operands[0]; } 359 const Loop *getLoop() const { return L; } 360 361 /// Constructs and returns the recurrence indicating how much this 362 /// expression steps by. If this is a polynomial of degree N, it 363 /// returns a chrec of degree N-1. We cannot determine whether 364 /// the step recurrence has self-wraparound. 365 const SCEV *getStepRecurrence(ScalarEvolution &SE) const { 366 if (isAffine()) 367 return getOperand(1); 368 return SE.getAddRecExpr( 369 SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(), 370 FlagAnyWrap); 371 } 372 373 /// Return true if this represents an expression A + B*x where A 374 /// and B are loop invariant values. 375 bool isAffine() const { 376 // We know that the start value is invariant. This expression is thus 377 // affine iff the step is also invariant. 378 return getNumOperands() == 2; 379 } 380 381 /// Return true if this represents an expression A + B*x + C*x^2 382 /// where A, B and C are loop invariant values. This corresponds 383 /// to an addrec of the form {L,+,M,+,N} 384 bool isQuadratic() const { return getNumOperands() == 3; } 385 386 /// Set flags for a recurrence without clearing any previously set flags. 387 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here 388 /// to make it easier to propagate flags. 389 void setNoWrapFlags(NoWrapFlags Flags) { 390 if (Flags & (FlagNUW | FlagNSW)) 391 Flags = ScalarEvolution::setFlags(Flags, FlagNW); 392 SubclassData |= Flags; 393 } 394 395 /// Return the value of this chain of recurrences at the specified 396 /// iteration number. 397 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const; 398 399 /// Return the value of this chain of recurrences at the specified iteration 400 /// number. Takes an explicit list of operands to represent an AddRec. 401 static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands, 402 const SCEV *It, ScalarEvolution &SE); 403 404 /// Return the number of iterations of this loop that produce 405 /// values in the specified constant range. Another way of 406 /// looking at this is that it returns the first iteration number 407 /// where the value is not in the condition, thus computing the 408 /// exit count. If the iteration count can't be computed, an 409 /// instance of SCEVCouldNotCompute is returned. 410 const SCEV *getNumIterationsInRange(const ConstantRange &Range, 411 ScalarEvolution &SE) const; 412 413 /// Return an expression representing the value of this expression 414 /// one iteration of the loop ahead. 415 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const; 416 417 /// Methods for support type inquiry through isa, cast, and dyn_cast: 418 static bool classof(const SCEV *S) { 419 return S->getSCEVType() == scAddRecExpr; 420 } 421 }; 422 423 /// This node is the base class min/max selections. 424 class SCEVMinMaxExpr : public SCEVCommutativeExpr { 425 friend class ScalarEvolution; 426 427 static bool isMinMaxType(enum SCEVTypes T) { 428 return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr || 429 T == scUMinExpr; 430 } 431 432 protected: 433 /// Note: Constructing subclasses via this constructor is allowed 434 SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 435 const SCEV *const *O, size_t N) 436 : SCEVCommutativeExpr(ID, T, O, N) { 437 assert(isMinMaxType(T)); 438 // Min and max never overflow 439 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 440 } 441 442 public: 443 Type *getType() const { return getOperand(0)->getType(); } 444 445 static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); } 446 447 static enum SCEVTypes negate(enum SCEVTypes T) { 448 switch (T) { 449 case scSMaxExpr: 450 return scSMinExpr; 451 case scSMinExpr: 452 return scSMaxExpr; 453 case scUMaxExpr: 454 return scUMinExpr; 455 case scUMinExpr: 456 return scUMaxExpr; 457 default: 458 llvm_unreachable("Not a min or max SCEV type!"); 459 } 460 } 461 }; 462 463 /// This class represents a signed maximum selection. 464 class SCEVSMaxExpr : public SCEVMinMaxExpr { 465 friend class ScalarEvolution; 466 467 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 468 : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {} 469 470 public: 471 /// Methods for support type inquiry through isa, cast, and dyn_cast: 472 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; } 473 }; 474 475 /// This class represents an unsigned maximum selection. 476 class SCEVUMaxExpr : public SCEVMinMaxExpr { 477 friend class ScalarEvolution; 478 479 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 480 : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {} 481 482 public: 483 /// Methods for support type inquiry through isa, cast, and dyn_cast: 484 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; } 485 }; 486 487 /// This class represents a signed minimum selection. 488 class SCEVSMinExpr : public SCEVMinMaxExpr { 489 friend class ScalarEvolution; 490 491 SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 492 : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {} 493 494 public: 495 /// Methods for support type inquiry through isa, cast, and dyn_cast: 496 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; } 497 }; 498 499 /// This class represents an unsigned minimum selection. 500 class SCEVUMinExpr : public SCEVMinMaxExpr { 501 friend class ScalarEvolution; 502 503 SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) 504 : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {} 505 506 public: 507 /// Methods for support type inquiry through isa, cast, and dyn_cast: 508 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; } 509 }; 510 511 /// This node is the base class for sequential/in-order min/max selections. 512 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they 513 /// are early-returning upon reaching saturation point. 514 /// I.e. given `0 umin_seq poison`, the result will be `0`, while the result of 515 /// `0 umin poison` is `poison`. When returning early, later expressions are not 516 /// executed, so `0 umin_seq (%x u/ 0)` does not result in undefined behavior. 517 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr { 518 friend class ScalarEvolution; 519 520 static bool isSequentialMinMaxType(enum SCEVTypes T) { 521 return T == scSequentialUMinExpr; 522 } 523 524 /// Set flags for a non-recurrence without clearing previously set flags. 525 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } 526 527 protected: 528 /// Note: Constructing subclasses via this constructor is allowed 529 SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, 530 const SCEV *const *O, size_t N) 531 : SCEVNAryExpr(ID, T, O, N) { 532 assert(isSequentialMinMaxType(T)); 533 // Min and max never overflow 534 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 535 } 536 537 public: 538 Type *getType() const { return getOperand(0)->getType(); } 539 540 static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) { 541 assert(isSequentialMinMaxType(Ty)); 542 switch (Ty) { 543 case scSequentialUMinExpr: 544 return scUMinExpr; 545 default: 546 llvm_unreachable("Not a sequential min/max type."); 547 } 548 } 549 550 SCEVTypes getEquivalentNonSequentialSCEVType() const { 551 return getEquivalentNonSequentialSCEVType(getSCEVType()); 552 } 553 554 static bool classof(const SCEV *S) { 555 return isSequentialMinMaxType(S->getSCEVType()); 556 } 557 }; 558 559 /// This class represents a sequential/in-order unsigned minimum selection. 560 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr { 561 friend class ScalarEvolution; 562 563 SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, 564 size_t N) 565 : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {} 566 567 public: 568 /// Methods for support type inquiry through isa, cast, and dyn_cast: 569 static bool classof(const SCEV *S) { 570 return S->getSCEVType() == scSequentialUMinExpr; 571 } 572 }; 573 574 /// This means that we are dealing with an entirely unknown SCEV 575 /// value, and only represent it as its LLVM Value. This is the 576 /// "bottom" value for the analysis. 577 class SCEVUnknown final : public SCEV, private CallbackVH { 578 friend class ScalarEvolution; 579 580 /// The parent ScalarEvolution value. This is used to update the 581 /// parent's maps when the value associated with a SCEVUnknown is 582 /// deleted or RAUW'd. 583 ScalarEvolution *SE; 584 585 /// The next pointer in the linked list of all SCEVUnknown 586 /// instances owned by a ScalarEvolution. 587 SCEVUnknown *Next; 588 589 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se, 590 SCEVUnknown *next) 591 : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {} 592 593 // Implement CallbackVH. 594 void deleted() override; 595 void allUsesReplacedWith(Value *New) override; 596 597 public: 598 Value *getValue() const { return getValPtr(); } 599 600 Type *getType() const { return getValPtr()->getType(); } 601 602 /// Methods for support type inquiry through isa, cast, and dyn_cast: 603 static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; } 604 }; 605 606 /// This class defines a simple visitor class that may be used for 607 /// various SCEV analysis purposes. 608 template <typename SC, typename RetVal = void> struct SCEVVisitor { 609 RetVal visit(const SCEV *S) { 610 switch (S->getSCEVType()) { 611 case scConstant: 612 return ((SC *)this)->visitConstant((const SCEVConstant *)S); 613 case scVScale: 614 return ((SC *)this)->visitVScale((const SCEVVScale *)S); 615 case scPtrToInt: 616 return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S); 617 case scTruncate: 618 return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S); 619 case scZeroExtend: 620 return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S); 621 case scSignExtend: 622 return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S); 623 case scAddExpr: 624 return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S); 625 case scMulExpr: 626 return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S); 627 case scUDivExpr: 628 return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S); 629 case scAddRecExpr: 630 return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S); 631 case scSMaxExpr: 632 return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S); 633 case scUMaxExpr: 634 return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S); 635 case scSMinExpr: 636 return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S); 637 case scUMinExpr: 638 return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S); 639 case scSequentialUMinExpr: 640 return ((SC *)this) 641 ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S); 642 case scUnknown: 643 return ((SC *)this)->visitUnknown((const SCEVUnknown *)S); 644 case scCouldNotCompute: 645 return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S); 646 } 647 llvm_unreachable("Unknown SCEV kind!"); 648 } 649 650 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) { 651 llvm_unreachable("Invalid use of SCEVCouldNotCompute!"); 652 } 653 }; 654 655 /// Visit all nodes in the expression tree using worklist traversal. 656 /// 657 /// Visitor implements: 658 /// // return true to follow this node. 659 /// bool follow(const SCEV *S); 660 /// // return true to terminate the search. 661 /// bool isDone(); 662 template <typename SV> class SCEVTraversal { 663 SV &Visitor; 664 SmallVector<const SCEV *, 8> Worklist; 665 SmallPtrSet<const SCEV *, 8> Visited; 666 667 void push(const SCEV *S) { 668 if (Visited.insert(S).second && Visitor.follow(S)) 669 Worklist.push_back(S); 670 } 671 672 public: 673 SCEVTraversal(SV &V) : Visitor(V) {} 674 675 void visitAll(const SCEV *Root) { 676 push(Root); 677 while (!Worklist.empty() && !Visitor.isDone()) { 678 const SCEV *S = Worklist.pop_back_val(); 679 680 switch (S->getSCEVType()) { 681 case scConstant: 682 case scVScale: 683 case scUnknown: 684 continue; 685 case scPtrToInt: 686 case scTruncate: 687 case scZeroExtend: 688 case scSignExtend: 689 case scAddExpr: 690 case scMulExpr: 691 case scUDivExpr: 692 case scSMaxExpr: 693 case scUMaxExpr: 694 case scSMinExpr: 695 case scUMinExpr: 696 case scSequentialUMinExpr: 697 case scAddRecExpr: 698 for (const auto *Op : S->operands()) { 699 push(Op); 700 if (Visitor.isDone()) 701 break; 702 } 703 continue; 704 case scCouldNotCompute: 705 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 706 } 707 llvm_unreachable("Unknown SCEV kind!"); 708 } 709 } 710 }; 711 712 /// Use SCEVTraversal to visit all nodes in the given expression tree. 713 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) { 714 SCEVTraversal<SV> T(Visitor); 715 T.visitAll(Root); 716 } 717 718 /// Return true if any node in \p Root satisfies the predicate \p Pred. 719 template <typename PredTy> 720 bool SCEVExprContains(const SCEV *Root, PredTy Pred) { 721 struct FindClosure { 722 bool Found = false; 723 PredTy Pred; 724 725 FindClosure(PredTy Pred) : Pred(Pred) {} 726 727 bool follow(const SCEV *S) { 728 if (!Pred(S)) 729 return true; 730 731 Found = true; 732 return false; 733 } 734 735 bool isDone() const { return Found; } 736 }; 737 738 FindClosure FC(Pred); 739 visitAll(Root, FC); 740 return FC.Found; 741 } 742 743 /// This visitor recursively visits a SCEV expression and re-writes it. 744 /// The result from each visit is cached, so it will return the same 745 /// SCEV for the same input. 746 template <typename SC> 747 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> { 748 protected: 749 ScalarEvolution &SE; 750 // Memoize the result of each visit so that we only compute once for 751 // the same input SCEV. This is to avoid redundant computations when 752 // a SCEV is referenced by multiple SCEVs. Without memoization, this 753 // visit algorithm would have exponential time complexity in the worst 754 // case, causing the compiler to hang on certain tests. 755 SmallDenseMap<const SCEV *, const SCEV *> RewriteResults; 756 757 public: 758 SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} 759 760 const SCEV *visit(const SCEV *S) { 761 auto It = RewriteResults.find(S); 762 if (It != RewriteResults.end()) 763 return It->second; 764 auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S); 765 auto Result = RewriteResults.try_emplace(S, Visited); 766 assert(Result.second && "Should insert a new entry"); 767 return Result.first->second; 768 } 769 770 const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } 771 772 const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; } 773 774 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { 775 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 776 return Operand == Expr->getOperand() 777 ? Expr 778 : SE.getPtrToIntExpr(Operand, Expr->getType()); 779 } 780 781 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { 782 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 783 return Operand == Expr->getOperand() 784 ? Expr 785 : SE.getTruncateExpr(Operand, Expr->getType()); 786 } 787 788 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 789 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 790 return Operand == Expr->getOperand() 791 ? Expr 792 : SE.getZeroExtendExpr(Operand, Expr->getType()); 793 } 794 795 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 796 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); 797 return Operand == Expr->getOperand() 798 ? Expr 799 : SE.getSignExtendExpr(Operand, Expr->getType()); 800 } 801 802 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { 803 SmallVector<const SCEV *, 2> Operands; 804 bool Changed = false; 805 for (const auto *Op : Expr->operands()) { 806 Operands.push_back(((SC *)this)->visit(Op)); 807 Changed |= Op != Operands.back(); 808 } 809 return !Changed ? Expr : SE.getAddExpr(Operands); 810 } 811 812 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { 813 SmallVector<const SCEV *, 2> Operands; 814 bool Changed = false; 815 for (const auto *Op : Expr->operands()) { 816 Operands.push_back(((SC *)this)->visit(Op)); 817 Changed |= Op != Operands.back(); 818 } 819 return !Changed ? Expr : SE.getMulExpr(Operands); 820 } 821 822 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { 823 auto *LHS = ((SC *)this)->visit(Expr->getLHS()); 824 auto *RHS = ((SC *)this)->visit(Expr->getRHS()); 825 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS(); 826 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS); 827 } 828 829 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 830 SmallVector<const SCEV *, 2> Operands; 831 bool Changed = false; 832 for (const auto *Op : Expr->operands()) { 833 Operands.push_back(((SC *)this)->visit(Op)); 834 Changed |= Op != Operands.back(); 835 } 836 return !Changed ? Expr 837 : SE.getAddRecExpr(Operands, Expr->getLoop(), 838 Expr->getNoWrapFlags()); 839 } 840 841 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { 842 SmallVector<const SCEV *, 2> Operands; 843 bool Changed = false; 844 for (const auto *Op : Expr->operands()) { 845 Operands.push_back(((SC *)this)->visit(Op)); 846 Changed |= Op != Operands.back(); 847 } 848 return !Changed ? Expr : SE.getSMaxExpr(Operands); 849 } 850 851 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { 852 SmallVector<const SCEV *, 2> Operands; 853 bool Changed = false; 854 for (const auto *Op : Expr->operands()) { 855 Operands.push_back(((SC *)this)->visit(Op)); 856 Changed |= Op != Operands.back(); 857 } 858 return !Changed ? Expr : SE.getUMaxExpr(Operands); 859 } 860 861 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) { 862 SmallVector<const SCEV *, 2> Operands; 863 bool Changed = false; 864 for (const auto *Op : Expr->operands()) { 865 Operands.push_back(((SC *)this)->visit(Op)); 866 Changed |= Op != Operands.back(); 867 } 868 return !Changed ? Expr : SE.getSMinExpr(Operands); 869 } 870 871 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) { 872 SmallVector<const SCEV *, 2> Operands; 873 bool Changed = false; 874 for (const auto *Op : Expr->operands()) { 875 Operands.push_back(((SC *)this)->visit(Op)); 876 Changed |= Op != Operands.back(); 877 } 878 return !Changed ? Expr : SE.getUMinExpr(Operands); 879 } 880 881 const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { 882 SmallVector<const SCEV *, 2> Operands; 883 bool Changed = false; 884 for (const auto *Op : Expr->operands()) { 885 Operands.push_back(((SC *)this)->visit(Op)); 886 Changed |= Op != Operands.back(); 887 } 888 return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true); 889 } 890 891 const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; } 892 893 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { 894 return Expr; 895 } 896 }; 897 898 using ValueToValueMap = DenseMap<const Value *, Value *>; 899 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>; 900 901 /// The SCEVParameterRewriter takes a scalar evolution expression and updates 902 /// the SCEVUnknown components following the Map (Value -> SCEV). 903 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> { 904 public: 905 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, 906 ValueToSCEVMapTy &Map) { 907 SCEVParameterRewriter Rewriter(SE, Map); 908 return Rewriter.visit(Scev); 909 } 910 911 SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M) 912 : SCEVRewriteVisitor(SE), Map(M) {} 913 914 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 915 auto I = Map.find(Expr->getValue()); 916 if (I == Map.end()) 917 return Expr; 918 return I->second; 919 } 920 921 private: 922 ValueToSCEVMapTy ⤅ 923 }; 924 925 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>; 926 927 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies 928 /// the Map (Loop -> SCEV) to all AddRecExprs. 929 class SCEVLoopAddRecRewriter 930 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> { 931 public: 932 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M) 933 : SCEVRewriteVisitor(SE), Map(M) {} 934 935 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, 936 ScalarEvolution &SE) { 937 SCEVLoopAddRecRewriter Rewriter(SE, Map); 938 return Rewriter.visit(Scev); 939 } 940 941 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 942 SmallVector<const SCEV *, 2> Operands; 943 for (const SCEV *Op : Expr->operands()) 944 Operands.push_back(visit(Op)); 945 946 const Loop *L = Expr->getLoop(); 947 if (0 == Map.count(L)) 948 return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags()); 949 950 return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE); 951 } 952 953 private: 954 LoopToScevMapT ⤅ 955 }; 956 957 } // end namespace llvm 958 959 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 960