//===-- LoopPredication.cpp - Guard based loop predication pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The LoopPredication pass tries to convert loop variant range checks to loop
// invariant by widening checks across loop iterations. For example, it will
// convert
//
//   for (i = 0; i < n; i++) {
//     guard(i < len);
//     ...
//   }
//
// to
//
//   for (i = 0; i < n; i++) {
//     guard(n - 1 < len);
//     ...
//   }
//
// After this transformation the condition of the guard is loop invariant, so
// loop-unswitch can later unswitch the loop by this condition which basically
// predicates the loop by the widened condition:
//
//   if (n - 1 < len)
//     for (i = 0; i < n; i++) {
//       ...
//     }
//   else
//     deoptimize
//
// It's tempting to rely on SCEV here, but it has proven to be problematic.
// Generally the facts SCEV provides about the increment step of add
// recurrences are true if the backedge of the loop is taken, which implicitly
// assumes that the guard doesn't fail. Using these facts to optimize the
// guard results in a circular logic where the guard is optimized under the
// assumption that it never fails.
//
// For example, in the loop below the induction variable will be marked as nuw
// based on the guard. Based on nuw the guard predicate will be considered
// monotonic. Given a monotonic condition it's tempting to replace the
// induction variable in the condition with its value on the last iteration.
// But this transformation is not correct, e.g. e = 4, b = 5 breaks the loop.
//
//   for (int i = b; i != e; i++)
//     guard(i u< len)
//
// One of the ways to reason about this problem is to use an inductive proof
// approach. Given the loop:
//
//   if (B(0)) {
//     do {
//       I = PHI(0, I.INC)
//       I.INC = I + Step
//       guard(G(I));
//     } while (B(I));
//   }
//
// where B(x) and G(x) are predicates that map integers to booleans, we want a
// loop invariant expression M such that the following program has the same
// semantics as the above:
//
//   if (B(0)) {
//     do {
//       I = PHI(0, I.INC)
//       I.INC = I + Step
//       guard(G(0) && M);
//     } while (B(I));
//   }
//
// One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step)
//
// Informal proof that the transformation above is correct:
//
//   By the definition of guards we can rewrite the guard condition to:
//     G(I) && G(0) && M
//
//   Let's prove that for each iteration of the loop:
//     G(0) && M => G(I)
//   And the condition above can then be simplified to G(0) && M.
//
//   Induction base.
//     G(0) && M => G(0)
//
//   Induction step. Assuming G(0) && M => G(I) holds on the current iteration,
//   let's prove it for the next one, i.e. that G(0) && M => G(I + Step):
//
//     B(I) is true because it's the backedge condition.
//     G(I) is true because the backedge is guarded by this condition.
//
//   So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step).
//
// Note that we can use anything stronger than M, i.e. any condition which
// implies M.
//
// When Step = 1 (i.e. a forward iterating loop), the transformation is
// supported when:
//   * The loop has a single latch with the condition of the form:
//     B(X) = latchStart + X <pred> latchLimit,
//     where <pred> is u<, u<=, s<, or s<=.
//   * The guard condition is of the form
//     G(X) = guardStart + X u< guardLimit
//
//   For the ult latch comparison case M is:
//     forall X . guardStart + X u< guardLimit && latchStart + X u< latchLimit =>
//        guardStart + X + 1 u< guardLimit
//
//   The only way the antecedent can be true and the consequent can be false is
//   if
//     X == guardLimit - 1 - guardStart
//   (and guardLimit is non-zero, but we won't use this latter fact).
//   If X == guardLimit - 1 - guardStart then the second half of the antecedent is
//     latchStart + guardLimit - 1 - guardStart u< latchLimit
//   and its negation is
//     latchStart + guardLimit - 1 - guardStart u>= latchLimit
//
//   In other words, if
//     latchLimit u<= latchStart + guardLimit - 1 - guardStart
//   then:
//   (the ranges below are written in ConstantRange notation, where [A, B) is the
//   set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
//
//      forall X . guardStart + X u< guardLimit &&
//                 latchStart + X u< latchLimit =>
//        guardStart + X + 1 u< guardLimit
//   == forall X . guardStart + X u< guardLimit &&
//                 latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
//        guardStart + X + 1 u< guardLimit
//   == forall X . (guardStart + X) in [0, guardLimit) &&
//                 (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
//        (guardStart + X + 1) in [0, guardLimit)
//   == forall X . X in [-guardStart, guardLimit - guardStart) &&
//                 X in [-latchStart, guardLimit - 1 - guardStart) =>
//        X in [-guardStart - 1, guardLimit - guardStart - 1)
//   == true
//
//   So the widened condition is:
//     guardStart u< guardLimit &&
//     latchStart + guardLimit - 1 - guardStart u>= latchLimit
//   Similarly for ule condition the widened condition is:
//     guardStart u< guardLimit &&
//     latchStart + guardLimit - 1 - guardStart u> latchLimit
//   For slt condition the widened condition is:
//     guardStart u< guardLimit &&
//     latchStart + guardLimit - 1 - guardStart s>= latchLimit
//   For sle condition the widened condition is:
//     guardStart u< guardLimit &&
//     latchStart + guardLimit - 1 - guardStart s> latchLimit
//
// When Step = -1 (i.e. a reverse iterating loop), the transformation is
// supported when:
//   * The loop has a single latch with the condition of the form:
//     B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=.
//   * The guard condition is of the form
//     G(X) = X - 1 u< guardLimit
//
//   For the ugt latch comparison case M is:
//     forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit
//
//   The only way the antecedent can be true and the consequent can be false is if
//     X == 1.
//   If X == 1 then the second half of the antecedent is
//     1 u> latchLimit, and its negation is latchLimit u>= 1.
//
//   So the widened condition is:
//     guardStart u< guardLimit && latchLimit u>= 1.
//   Similarly for sgt condition the widened condition is:
//     guardStart u< guardLimit && latchLimit s>= 1.
//   For uge condition the widened condition is:
//     guardStart u< guardLimit && latchLimit u> 1.
//   For sge condition the widened condition is:
//     guardStart u< guardLimit && latchLimit s> 1.
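//
// As a concrete (informal) instance of the forward iterating case, take the
// example loop from the top of this comment, treating both comparisons as
// unsigned and assuming the latch tests the incremented IV, i.e.
// B(X) = X + 1 u< n and G(X) = X u< len. Then guardStart = 0,
// guardLimit = len, latchStart = 1 and latchLimit = n, and the ult widened
// condition above becomes
//   0 u< len && n u<= len
// which, given the loop was entered (so n >= 1), is exactly the "n - 1 < len"
// guard shown in the example at the top.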
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LoopPredication.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/GuardUtils.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/GuardUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <optional>

#define DEBUG_TYPE "loop-predication"

STATISTIC(TotalConsidered, "Number of guards considered");
STATISTIC(TotalWidened, "Number of checks widened");

using namespace llvm;

static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
                                        cl::Hidden, cl::init(true));

static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
                                         cl::Hidden, cl::init(true));

static cl::opt<bool>
    SkipProfitabilityChecks("loop-predication-skip-profitability-checks",
                            cl::Hidden, cl::init(false));

// This is the scale factor for the latch probability. We use this during
// profitability analysis to find other exiting blocks that have a much higher
// probability of exiting the loop than of exiting via the latch.
// This value should be greater than 1 for a sane profitability check.
static cl::opt<float> LatchExitProbabilityScale(
    "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0),
    cl::desc("scale factor for the latch probability. Value should be greater "
             "than 1. Lower values are ignored"));

static cl::opt<bool> PredicateWidenableBranchGuards(
    "loop-predication-predicate-widenable-branches-to-deopt", cl::Hidden,
    cl::desc("Whether or not we should predicate guards "
             "expressed as widenable branches to deoptimize blocks"),
    cl::init(true));

static cl::opt<bool> InsertAssumesOfPredicatedGuardsConditions(
    "loop-predication-insert-assumes-of-predicated-guards-conditions",
    cl::Hidden,
    cl::desc("Whether or not we should insert assumes of conditions of "
             "predicated guards"),
    cl::init(true));

namespace {
/// Represents an induction variable check:
///   icmp Pred, <induction variable>, <loop invariant limit>
struct LoopICmp {
  ICmpInst::Predicate Pred;
  const SCEVAddRecExpr *IV;
  const SCEV *Limit;
  LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV,
           const SCEV *Limit)
      : Pred(Pred), IV(IV), Limit(Limit) {}
  LoopICmp() = default;
  void dump() {
    dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV
           << ", Limit = " << *Limit << "\n";
  }
};

class LoopPredication {
  AliasAnalysis *AA;
  DominatorTree *DT;
  ScalarEvolution *SE;
  LoopInfo *LI;
  MemorySSAUpdater *MSSAU;

  Loop *L;
  const DataLayout *DL;
  BasicBlock *Preheader;
  LoopICmp LatchCheck;

  bool isSupportedStep(const SCEV* Step);
  std::optional<LoopICmp> parseLoopICmp(ICmpInst *ICI);
  std::optional<LoopICmp> parseLoopLatchICmp();

  /// Return an insertion point suitable for inserting a safe to speculate
  /// instruction whose only user will be 'User' which has operands 'Ops'. A
  /// trivial result would be the User itself, but we try to return a loop
  /// invariant location if possible.
  Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops);
  /// Same as above, *except* that this uses the SCEV definition of invariant
  /// which is that an expression *can be made* invariant via SCEVExpander.
  /// Thus, this version is only suitable for finding an insert point to be
  /// passed to SCEVExpander!
  Instruction *findInsertPt(const SCEVExpander &Expander, Instruction *User,
                            ArrayRef<const SCEV *> Ops);

  /// Return true if the value is known to produce a single fixed value across
  /// all iterations on which it executes. Note that this does not imply
  /// speculation safety. That must be established separately.
  bool isLoopInvariantValue(const SCEV* S);

  Value *expandCheck(SCEVExpander &Expander, Instruction *Guard,
                     ICmpInst::Predicate Pred, const SCEV *LHS,
                     const SCEV *RHS);

  std::optional<Value *> widenICmpRangeCheck(ICmpInst *ICI,
                                             SCEVExpander &Expander,
                                             Instruction *Guard);
  std::optional<Value *>
  widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
                                      SCEVExpander &Expander,
                                      Instruction *Guard);
  std::optional<Value *>
  widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
                                      SCEVExpander &Expander,
                                      Instruction *Guard);
  void widenChecks(SmallVectorImpl<Value *> &Checks,
                   SmallVectorImpl<Value *> &WidenedChecks,
                   SCEVExpander &Expander, Instruction *Guard);
  bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
  bool widenWidenableBranchGuardConditions(BranchInst *Guard,
                                           SCEVExpander &Expander);
  // If the loop always exits through another block in the loop, we should not
  // predicate based on the latch check.
  // For example, the latch check can be a very coarse grained check and there
  // can be more fine grained exit checks within the loop.
  bool isLoopProfitableToPredicate();

  bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);

public:
  LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE,
                  LoopInfo *LI, MemorySSAUpdater *MSSAU)
      : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){};
  bool runOnLoop(Loop *L);
};

} // end namespace

PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
                                           LoopStandardAnalysisResults &AR,
                                           LPMUpdater &U) {
  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (AR.MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA);
  LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI,
                     MSSAU ? MSSAU.get() : nullptr);
  if (!LP.runOnLoop(&L))
    return PreservedAnalyses::all();

  auto PA = getLoopPassPreservedAnalyses();
  if (AR.MSSA)
    PA.preserve<MemorySSAAnalysis>();
  return PA;
}

std::optional<LoopICmp> LoopPredication::parseLoopICmp(ICmpInst *ICI) {
  auto Pred = ICI->getPredicate();
  auto *LHS = ICI->getOperand(0);
  auto *RHS = ICI->getOperand(1);

  const SCEV *LHSS = SE->getSCEV(LHS);
  if (isa<SCEVCouldNotCompute>(LHSS))
    return std::nullopt;
  const SCEV *RHSS = SE->getSCEV(RHS);
  if (isa<SCEVCouldNotCompute>(RHSS))
    return std::nullopt;

  // Canonicalize RHS to be the loop invariant bound, LHS - a loop computable IV
  if (SE->isLoopInvariant(LHSS, L)) {
    std::swap(LHS, RHS);
    std::swap(LHSS, RHSS);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }

  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS);
  if (!AR || AR->getLoop() != L)
    return std::nullopt;

  return LoopICmp(Pred, AR, RHSS);
}

Value *LoopPredication::expandCheck(SCEVExpander &Expander,
                                    Instruction *Guard,
                                    ICmpInst::Predicate Pred, const SCEV *LHS,
                                    const SCEV *RHS) {
  Type *Ty = LHS->getType();
  assert(Ty == RHS->getType() && "expandCheck operands have different types?");

  if (SE->isLoopInvariant(LHS, L) && SE->isLoopInvariant(RHS, L)) {
    IRBuilder<> Builder(Guard);
    if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
      return Builder.getTrue();
    if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred),
                                     LHS, RHS))
      return Builder.getFalse();
  }

  Value *LHSV =
      Expander.expandCodeFor(LHS, Ty, findInsertPt(Expander, Guard, {LHS}));
  Value *RHSV =
      Expander.expandCodeFor(RHS, Ty, findInsertPt(Expander, Guard, {RHS}));
  IRBuilder<> Builder(findInsertPt(Guard, {LHSV, RHSV}));
  return Builder.CreateICmp(Pred, LHSV, RHSV);
}

// Returns true if it's safe to truncate the IV to RangeCheckType.
// When the IV type is wider than the range operand type, we can still do loop
// predication, by generating SCEVs for the range and latch that are of the
// same type. We achieve this by generating a SCEV truncate expression for the
// latch IV. This is done iff truncation of the IV is a safe operation,
// without loss of information.
// Another way to achieve this is by generating a wider type SCEV for the
// range check operand, however, this needs a more involved check that
// operands do not overflow. This can lead to loss of information when the
// range operand is of the form: add i32 %offset, %iv. We need to prove that
// sext(x + y) is the same as sext(x) + sext(y).
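// As an illustrative instance of why that is not true in general: with i32
// operands x = 0x7fffffff and y = 1, x + y wraps to 0x80000000, so
// sext(x + y) is -2^31, while sext(x) + sext(y) computed in i64 is +2^31.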
// This function returns true if we can safely represent the IV type in
// the RangeCheckType without loss of information.
static bool isSafeToTruncateWideIVType(const DataLayout &DL,
                                       ScalarEvolution &SE,
                                       const LoopICmp LatchCheck,
                                       Type *RangeCheckType) {
  if (!EnableIVTruncation)
    return false;
  assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedValue() >
             DL.getTypeSizeInBits(RangeCheckType).getFixedValue() &&
         "Expected latch check IV type to be larger than range check operand "
         "type!");
  // The start and end values of the IV should be known. This is to guarantee
  // that truncating the wide type will not lose information.
  auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit);
  auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart());
  if (!Limit || !Start)
    return false;
  // This check makes sure that the IV does not change sign during loop
  // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE,
  // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the
  // IV wraps around, and the truncation of the IV would lose the range of
  // iterations between 2^32 and 2^64.
  if (!SE.getMonotonicPredicateType(LatchCheck.IV, LatchCheck.Pred))
    return false;
  // The active bits should be less than the bits in the RangeCheckType. This
  // guarantees that truncating the latch check to RangeCheckType is a safe
  // operation.
  auto RangeCheckTypeBitSize =
      DL.getTypeSizeInBits(RangeCheckType).getFixedValue();
  return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
         Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
}


// Return a LoopICmp describing a latch check equivalent to LatchCheck but with
// the requested type if safe to do so.  May involve the use of a new IV.
static std::optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL,
                                                      ScalarEvolution &SE,
                                                      const LoopICmp LatchCheck,
                                                      Type *RangeCheckType) {

  auto *LatchType = LatchCheck.IV->getType();
  if (RangeCheckType == LatchType)
    return LatchCheck;
  // For now, bail out if latch type is narrower than range type.
  if (DL.getTypeSizeInBits(LatchType).getFixedValue() <
      DL.getTypeSizeInBits(RangeCheckType).getFixedValue())
    return std::nullopt;
  if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType))
    return std::nullopt;
  // We can now safely identify the truncated version of the IV and limit for
  // RangeCheckType.
  LoopICmp NewLatchCheck;
  NewLatchCheck.Pred = LatchCheck.Pred;
  NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
      SE.getTruncateExpr(LatchCheck.IV, RangeCheckType));
  if (!NewLatchCheck.IV)
    return std::nullopt;
  NewLatchCheck.Limit = SE.getTruncateExpr(LatchCheck.Limit, RangeCheckType);
  LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType
                    << " can be represented as range check type: "
                    << *RangeCheckType << "\n");
  LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
  LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
  return NewLatchCheck;
}

bool LoopPredication::isSupportedStep(const SCEV* Step) {
  return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
}

Instruction *LoopPredication::findInsertPt(Instruction *Use,
                                           ArrayRef<Value*> Ops) {
  for (Value *Op : Ops)
    if (!L->isLoopInvariant(Op))
      return Use;
  return Preheader->getTerminator();
}

Instruction *LoopPredication::findInsertPt(const SCEVExpander &Expander,
                                           Instruction *Use,
                                           ArrayRef<const SCEV *> Ops) {
  // Subtlety: SCEV considers things to be invariant if the value produced is
  // the same across iterations. This is not the same as being able to
  // evaluate outside the loop, which is what we actually need here.
  for (const SCEV *Op : Ops)
    if (!SE->isLoopInvariant(Op, L) ||
        !Expander.isSafeToExpandAt(Op, Preheader->getTerminator()))
      return Use;
  return Preheader->getTerminator();
}

bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
  // Handling expressions which produce invariant results, but *haven't* yet
  // been removed from the loop serves two important purposes.
  // 1) Most importantly, it resolves a pass ordering cycle which would
  // otherwise need us to iterate licm, loop-predication, and either
  // loop-unswitch or loop-peeling to make progress on examples with lots of
  // predicable range checks in a row. (Since, in the general case, we can't
  // hoist the length checks until the dominating checks have been discharged
  // as we can't prove doing so is safe.)
  // 2) As a nice side effect, this exposes the value of peeling or unswitching
  // much more obviously in the IR. Otherwise, the cost modeling for other
  // transforms would end up needing to duplicate all of this logic to model a
  // check which becomes predictable based on a modeled peel or unswitch.
  //
  // The cost of doing so in the worst case is an extra fill from the stack in
  // the loop to materialize the loop invariant test value instead of checking
  // against the original IV which is presumably in a register inside the loop.
  // Such cases are presumably rare, and hint at missing opportunities for
  // other passes.

  if (SE->isLoopInvariant(S, L))
    // Note: This is the SCEV variant, so the original Value* may be within the
    // loop even though SCEV has proven it is loop invariant.
    return true;

  // Handle a particularly important case which SCEV doesn't yet know about,
  // which shows up in range checks on arrays with immutable lengths.
  // TODO: This should be sunk inside SCEV.
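  // The pattern matched below corresponds to IR along these lines
  // (illustrative only, names invented):
  //
  //   %len = load i32, ptr %len.addr          ; not clobbered within the loop
  //   %inbounds = icmp ult i32 %iv, %len
  //   call void (i1, ...) @llvm.experimental.guard(i1 %inbounds) [ "deopt"() ]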
  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S))
    if (const auto *LI = dyn_cast<LoadInst>(U->getValue()))
      if (LI->isUnordered() && L->hasLoopInvariantOperands(LI))
        if (!isModSet(AA->getModRefInfoMask(LI->getOperand(0))) ||
            LI->hasMetadata(LLVMContext::MD_invariant_load))
          return true;
  return false;
}

std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
    LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
    Instruction *Guard) {
  auto *Ty = RangeCheck.IV->getType();
  // Generate the widened condition for the forward loop:
  //   guardStart u< guardLimit &&
  //   latchLimit <pred> guardLimit - 1 - guardStart + latchStart
  // where <pred> depends on the latch condition predicate. See the file
  // header comment for the reasoning.
  // guardLimit - guardStart + latchStart - 1
  const SCEV *GuardStart = RangeCheck.IV->getStart();
  const SCEV *GuardLimit = RangeCheck.Limit;
  const SCEV *LatchStart = LatchCheck.IV->getStart();
  const SCEV *LatchLimit = LatchCheck.Limit;
  // Subtlety: We need all the values to be *invariant* across all iterations,
  // but we only need to check expansion safety for those which *aren't*
  // already guaranteed to dominate the guard.
  if (!isLoopInvariantValue(GuardStart) ||
      !isLoopInvariantValue(GuardLimit) ||
      !isLoopInvariantValue(LatchStart) ||
      !isLoopInvariantValue(LatchLimit)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
      !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }

  // guardLimit - guardStart + latchStart - 1
  const SCEV *RHS =
      SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart),
                     SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
  auto LimitCheckPred =
      ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);

  LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
  LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n");
  LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");

  auto *LimitCheck =
      expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS);
  auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred,
                                          GuardStart, GuardLimit);
  IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
  return Builder.CreateFreeze(
      Builder.CreateAnd(FirstIterationCheck, LimitCheck));
}

std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
    LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
    Instruction *Guard) {
  auto *Ty = RangeCheck.IV->getType();
  const SCEV *GuardStart = RangeCheck.IV->getStart();
  const SCEV *GuardLimit = RangeCheck.Limit;
  const SCEV *LatchStart = LatchCheck.IV->getStart();
  const SCEV *LatchLimit = LatchCheck.Limit;
  // Subtlety: We need all the values to be *invariant* across all iterations,
  // but we only need to check expansion safety for those which *aren't*
  // already guaranteed to dominate the guard.
  if (!isLoopInvariantValue(GuardStart) ||
      !isLoopInvariantValue(GuardLimit) ||
      !isLoopInvariantValue(LatchStart) ||
      !isLoopInvariantValue(LatchLimit)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
      !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
    LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
    return std::nullopt;
  }
  // The decrement of the latch check IV should be the same as the
  // rangeCheckIV.
  auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
  if (RangeCheck.IV != PostDecLatchCheckIV) {
    LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
                      << *PostDecLatchCheckIV
                      << " and RangeCheckIV: " << *RangeCheck.IV << "\n");
    return std::nullopt;
  }

  // Generate the widened condition for CountDownLoop:
  //   guardStart u< guardLimit &&
  //   latchLimit <pred> 1.
  // See the header comment for the reasoning behind these checks.
  auto LimitCheckPred =
      ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);
  auto *FirstIterationCheck = expandCheck(Expander, Guard,
                                          ICmpInst::ICMP_ULT,
                                          GuardStart, GuardLimit);
  auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit,
                                 SE->getOne(Ty));
  IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
  return Builder.CreateFreeze(
      Builder.CreateAnd(FirstIterationCheck, LimitCheck));
}

static void normalizePredicate(ScalarEvolution *SE, Loop *L,
                               LoopICmp &RC) {
  // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the
  // ULT/UGE form for ease of handling by our caller.
  if (ICmpInst::isEquality(RC.Pred) &&
      RC.IV->getStepRecurrence(*SE)->isOne() &&
      SE->isKnownPredicate(ICmpInst::ICMP_ULE, RC.IV->getStart(), RC.Limit))
    RC.Pred = RC.Pred == ICmpInst::ICMP_NE ?
      ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE;
}

/// If ICI can be widened to a loop invariant condition, emit the loop
/// invariant condition in the loop preheader and return it; otherwise return
/// std::nullopt.
std::optional<Value *>
LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
                                     Instruction *Guard) {
  LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
  LLVM_DEBUG(ICI->dump());

  // parseLoopStructure guarantees that the latch condition is:
  //   ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=.
  // We are looking for the range checks of the form:
  //   i u< guardLimit
  auto RangeCheck = parseLoopICmp(ICI);
  if (!RangeCheck) {
    LLVM_DEBUG(dbgs() << "Failed to parse the range check condition!\n");
    return std::nullopt;
  }
  LLVM_DEBUG(dbgs() << "Guard check:\n");
  LLVM_DEBUG(RangeCheck->dump());
  if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
    LLVM_DEBUG(dbgs() << "Unsupported range check predicate("
                      << RangeCheck->Pred << ")!\n");
    return std::nullopt;
  }
  auto *RangeCheckIV = RangeCheck->IV;
  if (!RangeCheckIV->isAffine()) {
    LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n");
    return std::nullopt;
  }
  const SCEV *Step = RangeCheckIV->getStepRecurrence(*SE);
  // We cannot just compare with the latch IV step because the latch and range
  // IVs may have different types.
  if (!isSupportedStep(Step)) {
    LLVM_DEBUG(dbgs() << "Range check and latch IVs have different steps!\n");
    return std::nullopt;
  }
  auto *Ty = RangeCheckIV->getType();
  auto CurrLatchCheckOpt = generateLoopLatchCheck(*DL, *SE, LatchCheck, Ty);
  if (!CurrLatchCheckOpt) {
    LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check "
                         "corresponding to range type: "
                      << *Ty << "\n");
    return std::nullopt;
  }

  LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
  // At this point, the range and latch step should have the same type, but need
  // not have the same value (we support both 1 and -1 steps).
  assert(Step->getType() ==
             CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
         "Range and latch steps should be of same type!");
  if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
    LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n");
    return std::nullopt;
  }

  if (Step->isOne())
    return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
                                               Expander, Guard);
  else {
    assert(Step->isAllOnesValue() && "Step should be -1!");
    return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
                                               Expander, Guard);
  }
}

void LoopPredication::widenChecks(SmallVectorImpl<Value *> &Checks,
                                  SmallVectorImpl<Value *> &WidenedChecks,
                                  SCEVExpander &Expander, Instruction *Guard) {
  for (auto &Check : Checks)
    if (ICmpInst *ICI = dyn_cast<ICmpInst>(Check))
      if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Guard)) {
        WidenedChecks.push_back(Check);
        Check = *NewRangeCheck;
      }
}

bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
                                           SCEVExpander &Expander) {
  LLVM_DEBUG(dbgs() << "Processing guard:\n");
  LLVM_DEBUG(Guard->dump());

  TotalConsidered++;
  SmallVector<Value *, 4> Checks;
  SmallVector<Value *> WidenedChecks;
  parseWidenableGuard(Guard, Checks);
  widenChecks(Checks, WidenedChecks, Expander, Guard);
  if (WidenedChecks.empty())
    return false;

  TotalWidened += WidenedChecks.size();

  // Emit the new guard condition
  IRBuilder<> Builder(findInsertPt(Guard, Checks));
  Value *AllChecks = Builder.CreateAnd(Checks);
  auto *OldCond = Guard->getOperand(0);
  Guard->setOperand(0, AllChecks);
  if (InsertAssumesOfPredicatedGuardsConditions) {
    Builder.SetInsertPoint(&*++BasicBlock::iterator(Guard));
    Builder.CreateAssumption(OldCond);
  }
  RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);

  LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n");
  return true;
}

bool LoopPredication::widenWidenableBranchGuardConditions(
    BranchInst *BI, SCEVExpander &Expander) {
  assert(isGuardAsWidenableBranch(BI) && "Must be!");
  LLVM_DEBUG(dbgs() << "Processing guard:\n");
  LLVM_DEBUG(BI->dump());

  TotalConsidered++;
  SmallVector<Value *, 4> Checks;
  SmallVector<Value *> WidenedChecks;
  parseWidenableGuard(BI, Checks);
  // At the moment, our matching logic for widenable conditions implicitly
  // assumes we preserve the form: (br (and Cond, WC())).  FIXME
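  // Concretely (illustrative IR only), the matched guard looks like:
  //
  //   %wc = call i1 @llvm.experimental.widenable.condition()
  //   %guard.cond = and i1 %cond, %wc
  //   br i1 %guard.cond, label %guarded, label %deopt
  //
  // and widening amounts to AND'ing additional checks into %cond while
  // keeping the overall (br (and ..., WC())) shape intact.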
  auto WC = extractWidenableCondition(BI);
  Checks.push_back(WC);
  widenChecks(Checks, WidenedChecks, Expander, BI);
  if (WidenedChecks.empty())
    return false;

  TotalWidened += WidenedChecks.size();

  // Emit the new guard condition
  IRBuilder<> Builder(findInsertPt(BI, Checks));
  Value *AllChecks = Builder.CreateAnd(Checks);
  auto *OldCond = BI->getCondition();
  BI->setCondition(AllChecks);
  if (InsertAssumesOfPredicatedGuardsConditions) {
    BasicBlock *IfTrueBB = BI->getSuccessor(0);
    Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt());
    // If this block has other predecessors, we might not be able to use Cond.
    // In this case, create a Phi where every other input is `true` and input
    // from guard block is Cond.
    Value *AssumeCond = Builder.CreateAnd(WidenedChecks);
    if (!IfTrueBB->getUniquePredecessor()) {
      auto *GuardBB = BI->getParent();
      auto *PN = Builder.CreatePHI(AssumeCond->getType(), pred_size(IfTrueBB),
                                   "assume.cond");
      for (auto *Pred : predecessors(IfTrueBB))
        PN->addIncoming(Pred == GuardBB ? AssumeCond : Builder.getTrue(), Pred);
      AssumeCond = PN;
    }
    Builder.CreateAssumption(AssumeCond);
  }
  RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
  assert(isGuardAsWidenableBranch(BI) &&
         "Stopped being a guard after transform?");

  LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n");
  return true;
}

std::optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
  using namespace PatternMatch;

  BasicBlock *LoopLatch = L->getLoopLatch();
  if (!LoopLatch) {
    LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
    return std::nullopt;
  }

  auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
  if (!BI || !BI->isConditional()) {
    LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n");
    return std::nullopt;
  }
  BasicBlock *TrueDest = BI->getSuccessor(0);
  assert(
      (TrueDest == L->getHeader() || BI->getSuccessor(1) == L->getHeader()) &&
      "One of the latch's destinations must be the header");

  auto *ICI = dyn_cast<ICmpInst>(BI->getCondition());
  if (!ICI) {
    LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n");
    return std::nullopt;
  }
  auto Result = parseLoopICmp(ICI);
  if (!Result) {
    LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
    return std::nullopt;
  }

  if (TrueDest != L->getHeader())
    Result->Pred = ICmpInst::getInversePredicate(Result->Pred);

  // Check affine first, so if it's not we don't try to compute the step
  // recurrence.
  if (!Result->IV->isAffine()) {
    LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n");
    return std::nullopt;
  }

  const SCEV *Step = Result->IV->getStepRecurrence(*SE);
  if (!isSupportedStep(Step)) {
    LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
    return std::nullopt;
  }

  auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
    if (Step->isOne()) {
      return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
             Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
    } else {
      assert(Step->isAllOnesValue() && "Step should be -1!");
      return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT &&
             Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE;
    }
  };

  normalizePredicate(SE, L, *Result);
  if (IsUnsupportedPredicate(Step, Result->Pred)) {
    LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
                      << ")!\n");
    return std::nullopt;
  }

  return Result;
}

bool LoopPredication::isLoopProfitableToPredicate() {
  if (SkipProfitabilityChecks)
    return true;

  SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges;
  L->getExitEdges(ExitEdges);
  // If there is only one exiting edge in the loop, it is always profitable to
  // predicate the loop.
  if (ExitEdges.size() == 1)
    return true;

  // Calculate the exiting probabilities of all exiting edges from the loop,
  // starting with the LatchExitProbability.
  // Heuristic for profitability: If any of the exiting blocks' probability of
  // exiting the loop is larger than the probability of exiting through the
  // latch block, it's not profitable to predicate the loop.
  auto *LatchBlock = L->getLoopLatch();
  assert(LatchBlock && "Should have a single latch at this point!");
  auto *LatchTerm = LatchBlock->getTerminator();
  assert(LatchTerm->getNumSuccessors() == 2 &&
         "expected to be an exiting block with 2 succs!");
  unsigned LatchBrExitIdx =
      LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
  // We compute branch probabilities without BPI. We do not rely on BPI since
  // Loop predication is usually run in an LPM and BPI is only preserved
  // lossily within loop pass managers, while BPI has an inherent notion of
  // being complete for an entire function.

  // If the latch exits into a deoptimize or an unreachable block, do not
  // predicate on that latch check.
  auto *LatchExitBlock = LatchTerm->getSuccessor(LatchBrExitIdx);
  if (isa<UnreachableInst>(LatchTerm) ||
      LatchExitBlock->getTerminatingDeoptimizeCall())
    return false;

  // If the latch terminator has no valid profile data, there is nothing to
  // check profitability on.
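  // (When present, profile data takes the form of !"branch_weights" metadata
  //  on the terminator, e.g. !{!"branch_weights", i32 2000, i32 1}; the lambda
  //  below turns those weights into a BranchProbability for an exit edge.)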
  if (!hasValidBranchWeightMD(*LatchTerm))
    return true;

  auto ComputeBranchProbability =
      [&](const BasicBlock *ExitingBlock,
          const BasicBlock *ExitBlock) -> BranchProbability {
    auto *Term = ExitingBlock->getTerminator();
    unsigned NumSucc = Term->getNumSuccessors();
    if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) {
      SmallVector<uint32_t> Weights;
      extractBranchWeights(ProfileData, Weights);
      uint64_t Numerator = 0, Denominator = 0;
      for (auto [i, Weight] : llvm::enumerate(Weights)) {
        if (Term->getSuccessor(i) == ExitBlock)
          Numerator += Weight;
        Denominator += Weight;
      }
      // If all weights are zero act as if there was no profile data
      if (Denominator == 0)
        return BranchProbability::getBranchProbability(1, NumSucc);
      return BranchProbability::getBranchProbability(Numerator, Denominator);
    } else {
      assert(LatchBlock != ExitingBlock &&
             "Latch term should always have profile data!");
      // No profile data, so we choose the weight as 1/num_of_succ(Src)
      return BranchProbability::getBranchProbability(1, NumSucc);
    }
  };

  BranchProbability LatchExitProbability =
      ComputeBranchProbability(LatchBlock, LatchExitBlock);

  // Protect against degenerate inputs provided by the user. Providing a value
  // less than one can invert the definition of profitable loop predication.
  float ScaleFactor = LatchExitProbabilityScale;
  if (ScaleFactor < 1) {
    LLVM_DEBUG(
        dbgs()
        << "Ignored user setting for loop-predication-latch-probability-scale: "
        << LatchExitProbabilityScale << "\n");
    LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
    ScaleFactor = 1.0;
  }
  const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor;

  for (const auto &ExitEdge : ExitEdges) {
    BranchProbability ExitingBlockProbability =
        ComputeBranchProbability(ExitEdge.first, ExitEdge.second);
    // Some exiting edge has higher probability than the latch exiting edge.
    // No longer profitable to predicate.
    if (ExitingBlockProbability > LatchProbabilityThreshold)
      return false;
  }

  // We have concluded that the most probable way to exit from the
  // loop is through the latch (or there's no profile information and all
  // exits are equally likely).
  return true;
}

/// If we can (cheaply) find a widenable branch which controls entry into the
/// loop, return it.
static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI) {
  // Walk back through any unconditionally executed blocks and see if we can
  // find a widenable condition which seems to control execution of this loop.
  // Note that we predict that maythrow calls are likely untaken and thus that
  // it's profitable to widen a branch before a maythrow call with a condition
  // afterwards even though that may cause the slow path to run in a case where
  // it wouldn't have otherwise.
  BasicBlock *BB = L->getLoopPreheader();
  if (!BB)
    return nullptr;
  do {
    if (BasicBlock *Pred = BB->getSinglePredecessor())
      if (BB == Pred->getSingleSuccessor()) {
        BB = Pred;
        continue;
      }
    break;
  } while (true);

  if (BasicBlock *Pred = BB->getSinglePredecessor()) {
    if (auto *BI = dyn_cast<BranchInst>(Pred->getTerminator()))
      if (BI->getSuccessor(0) == BB && isWidenableBranch(BI))
        return BI;
  }
  return nullptr;
}

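// For reference, the CFG shape the function above is looking for is roughly
// (illustrative IR, names invented):
//
//   entry:
//     %wc = call i1 @llvm.experimental.widenable.condition()
//     %c = and i1 %cond, %wc
//     br i1 %c, label %continue, label %deopt
//   continue:                                ; unconditionally reaches the loop
//     call void @may.throw()
//     br label %preheader
//   preheader:
//     br label %loop
//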
/// Return the minimum of all analyzeable exit counts. This is an upper bound
/// on the actual exit count. If there are not at least two analyzeable exits,
/// returns SCEVCouldNotCompute.
static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE,
                                                       DominatorTree &DT,
                                                       Loop *L) {
  SmallVector<BasicBlock *, 16> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);

  SmallVector<const SCEV *, 4> ExitCounts;
  for (BasicBlock *ExitingBB : ExitingBlocks) {
    const SCEV *ExitCount = SE.getExitCount(L, ExitingBB);
    if (isa<SCEVCouldNotCompute>(ExitCount))
      continue;
    assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
           "We should only have known counts for exiting blocks that "
           "dominate latch!");
    ExitCounts.push_back(ExitCount);
  }
  if (ExitCounts.size() < 2)
    return SE.getCouldNotCompute();
  return SE.getUMinFromMismatchedTypes(ExitCounts);
}

/// This implements an analogous, but entirely distinct transform from the main
/// loop predication transform. This one is phrased in terms of using a
/// widenable branch *outside* the loop to allow us to simplify loop exits in a
/// following loop. This is close in spirit to the IndVarSimplify transform
/// of the same name, but is materially different: widening loosens legality
/// sharply.
bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
  // The transformation performed here aims to widen a widenable condition
  // above the loop such that all analyzeable exits leading to deopt are dead.
  // It assumes that the latch is the dominant exit for profitability and that
  // exits branching to deoptimizing blocks are rarely taken. It relies on the
  // semantics of widenable expressions for legality. (i.e. being able to fall
  // down the widenable path spuriously allows us to ignore exit order,
  // unanalyzeable exits, side effects, exceptional exits, and other challenges
  // which restrict the applicability of the non-WC based version of this
  // transform in IndVarSimplify.)
  //
  // NOTE ON POISON/UNDEF - We're hoisting an expression above guards which may
  // imply flags on the expression being hoisted and inserting new uses (flags
  // are only correct for current uses). The result is that we may be
  // inserting a branch on the value which can be either poison or undef. In
  // this case, the branch can legally go either way; we just need to avoid
  // introducing UB. This is achieved through the use of the freeze
  // instruction.

  SmallVector<BasicBlock *, 16> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);

  if (ExitingBlocks.empty())
    return false; // Nothing to do.

  auto *Latch = L->getLoopLatch();
  if (!Latch)
    return false;

  auto *WidenableBR = FindWidenableTerminatorAboveLoop(L, *LI);
  if (!WidenableBR)
    return false;

  const SCEV *LatchEC = SE->getExitCount(L, Latch);
  if (isa<SCEVCouldNotCompute>(LatchEC))
    return false; // profitability - want hot exit in analyzeable set

  // At this point, we have found an analyzeable latch, and a widenable
  // condition above the loop. If we have a widenable exit within the loop
  // (for which we can't compute exit counts), drop the ability to further
  // widen so that we gain the ability to analyze its exit count and perform
  // this transform.
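  // "Dropping widenability" here means replacing the in-loop use of the
  // widenable condition with true, e.g. (illustrative IR):
  //
  //   %wc = call i1 @llvm.experimental.widenable.condition()
  //   %exit.cond = and i1 %cond, %wc
  //
  // becomes
  //
  //   %exit.cond = and i1 %cond, true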
  // TODO: It'd be nice to know for sure the exit became analyzeable after
  // dropping widenability.
  bool ChangedLoop = false;

  for (auto *ExitingBB : ExitingBlocks) {
    if (LI->getLoopFor(ExitingBB) != L)
      continue;

    auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
    if (!BI)
      continue;

    if (auto WC = extractWidenableCondition(BI))
      if (L->contains(BI->getSuccessor(0))) {
        assert(WC->hasOneUse() && "Not appropriate widenable branch!");
        WC->user_back()->replaceUsesOfWith(
            WC, ConstantInt::getTrue(BI->getContext()));
        ChangedLoop = true;
      }
  }
  if (ChangedLoop)
    SE->forgetLoop(L);

  // The insertion point for the widening should be at the widenable condition
  // call, not at the WidenableBR. If we do this at the WidenableBR, we can
  // incorrectly change a loop-invariant condition to a loop-varying one.
  auto *IP = cast<Instruction>(WidenableBR->getCondition());

  // The use of umin(all analyzeable exits) instead of latch is subtle, but
  // important for profitability. We may have a loop which hasn't been fully
  // canonicalized just yet. If the exit we chose to widen is provably never
  // taken, we want the widened form to *also* be provably never taken. We
  // can't guarantee this as a current unanalyzeable exit may later become
  // analyzeable, but we can at least avoid the obvious cases.
  const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(*SE, *DT, L);
  if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() ||
      !SE->isLoopInvariant(MinEC, L) ||
      !Rewriter.isSafeToExpandAt(MinEC, IP))
    return ChangedLoop;

  Rewriter.setInsertPoint(IP);
  IRBuilder<> B(IP);

  bool InvalidateLoop = false;
  Value *MinECV = nullptr; // lazily generated if needed
  for (BasicBlock *ExitingBB : ExitingBlocks) {
    // If our exiting block exits multiple loops, we can only rewrite the
    // innermost one. Otherwise, we're changing how many times the innermost
    // loop runs before it exits.
    if (LI->getLoopFor(ExitingBB) != L)
      continue;

    // Can't rewrite non-branch yet.
    auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
    if (!BI)
      continue;

    // If already constant, nothing to do.
    if (isa<Constant>(BI->getCondition()))
      continue;

    const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
    if (isa<SCEVCouldNotCompute>(ExitCount) ||
        ExitCount->getType()->isPointerTy() ||
        !Rewriter.isSafeToExpandAt(ExitCount, WidenableBR))
      continue;

    const bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
    BasicBlock *ExitBB = BI->getSuccessor(ExitIfTrue ? 0 : 1);
    if (!ExitBB->getPostdominatingDeoptimizeCall())
      continue;

    /// Here we can be fairly sure that executing this exit will most likely
    /// lead to executing llvm.experimental.deoptimize.
    /// This is a profitability heuristic, not a legality constraint.

    // If we found a widenable exit condition, do two things:
    // 1) fold the widened exit test into the widenable condition
    // 2) fold the branch to untaken - avoids infinite looping

    Value *ECV = Rewriter.expandCodeFor(ExitCount);
    if (!MinECV)
      MinECV = Rewriter.expandCodeFor(MinEC);
    Value *RHS = MinECV;
    if (ECV->getType() != RHS->getType()) {
      Type *WiderTy = SE->getWiderType(ECV->getType(), RHS->getType());
      ECV = B.CreateZExt(ECV, WiderTy);
      RHS = B.CreateZExt(RHS, WiderTy);
    }
    assert(!Latch || DT->dominates(ExitingBB, Latch));
    Value *NewCond = B.CreateICmp(ICmpInst::ICMP_UGT, ECV, RHS);
    // Freeze poison or undef to an arbitrary bit pattern to ensure we can
    // branch without introducing UB. See NOTE ON POISON/UNDEF above for
    // context.
    NewCond = B.CreateFreeze(NewCond);

    widenWidenableBranch(WidenableBR, NewCond);

    Value *OldCond = BI->getCondition();
    BI->setCondition(ConstantInt::get(OldCond->getType(), !ExitIfTrue));
    InvalidateLoop = true;
  }

  if (InvalidateLoop)
    // We just mutated a bunch of loop exits changing their exit counts
    // widely. We need to force recomputation of the exit counts given these
    // changes. Note that all of the inserted exits are never taken, and
    // should be removed next time the CFG is modified.
    SE->forgetLoop(L);

  // Always return `true` since we have moved the WidenableBR's condition.
  return true;
}

bool LoopPredication::runOnLoop(Loop *Loop) {
  L = Loop;

  LLVM_DEBUG(dbgs() << "Analyzing ");
  LLVM_DEBUG(L->dump());

  Module *M = L->getHeader()->getModule();

  // There is nothing to do if the module doesn't use guards
  auto *GuardDecl =
      Intrinsic::getDeclarationIfExists(M, Intrinsic::experimental_guard);
  bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty();
  auto *WCDecl = Intrinsic::getDeclarationIfExists(
      M, Intrinsic::experimental_widenable_condition);
  bool HasWidenableConditions =
      PredicateWidenableBranchGuards && WCDecl && !WCDecl->use_empty();
  if (!HasIntrinsicGuards && !HasWidenableConditions)
    return false;

  DL = &M->getDataLayout();

  Preheader = L->getLoopPreheader();
  if (!Preheader)
    return false;

  auto LatchCheckOpt = parseLoopLatchICmp();
  if (!LatchCheckOpt)
    return false;
  LatchCheck = *LatchCheckOpt;

  LLVM_DEBUG(dbgs() << "Latch check:\n");
  LLVM_DEBUG(LatchCheck.dump());

  if (!isLoopProfitableToPredicate()) {
    LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n");
    return false;
  }
  // Collect all the guards into a vector and process later, so as not
  // to invalidate the instruction iterator.
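  // (Intrinsic guards are calls of the form
  //    call void (i1, ...) @llvm.experimental.guard(i1 %cond) [ "deopt"() ]
  //  while widenable-branch guards are the (br (and %cond, WC())) form shown
  //  in widenWidenableBranchGuardConditions above.)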
  SmallVector<IntrinsicInst *, 4> Guards;
  SmallVector<BranchInst *, 4> GuardsAsWidenableBranches;
  for (const auto BB : L->blocks()) {
    for (auto &I : *BB)
      if (isGuard(&I))
        Guards.push_back(cast<IntrinsicInst>(&I));
    if (PredicateWidenableBranchGuards &&
        isGuardAsWidenableBranch(BB->getTerminator()))
      GuardsAsWidenableBranches.push_back(
          cast<BranchInst>(BB->getTerminator()));
  }

  SCEVExpander Expander(*SE, *DL, "loop-predication");
  bool Changed = false;
  for (auto *Guard : Guards)
    Changed |= widenGuardConditions(Guard, Expander);
  for (auto *Guard : GuardsAsWidenableBranches)
    Changed |= widenWidenableBranchGuardConditions(Guard, Expander);
  Changed |= predicateLoopExits(L, Expander);

  if (MSSAU && VerifyMemorySSA)
    MSSAU->getMemorySSA()->verifyMemorySSA();
  return Changed;
}