xref: /llvm-project/llvm/lib/Transforms/Scalar/LoopPredication.cpp (revision c8016e7a65ffc6f0266845c4674f7a08dffff3ea)
1 //===-- LoopPredication.cpp - Guard based loop predication pass -----------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // The LoopPredication pass tries to convert loop variant range checks to loop
11 // invariant by widening checks across loop iterations. For example, it will
12 // convert
13 //
14 //   for (i = 0; i < n; i++) {
15 //     guard(i < len);
16 //     ...
17 //   }
18 //
19 // to
20 //
21 //   for (i = 0; i < n; i++) {
22 //     guard(n - 1 < len);
23 //     ...
24 //   }
25 //
26 // After this transformation the condition of the guard is loop invariant, so
27 // loop-unswitch can later unswitch the loop by this condition which basically
28 // predicates the loop by the widened condition:
29 //
30 //   if (n - 1 < len)
31 //     for (i = 0; i < n; i++) {
32 //       ...
33 //     }
34 //   else
35 //     deoptimize
36 //
37 // It's tempting to rely on SCEV here, but it has proven to be problematic.
38 // Generally the facts SCEV provides about the increment step of add
39 // recurrences are true if the backedge of the loop is taken, which implicitly
40 // assumes that the guard doesn't fail. Using these facts to optimize the
41 // guard results in a circular logic where the guard is optimized under the
42 // assumption that it never fails.
43 //
44 // For example, in the loop below the induction variable will be marked as nuw
// based on the guard. Based on nuw the guard predicate will be considered
46 // monotonic. Given a monotonic condition it's tempting to replace the induction
47 // variable in the condition with its value on the last iteration. But this
48 // transformation is not correct, e.g. e = 4, b = 5 breaks the loop.
49 //
50 //   for (int i = b; i != e; i++)
51 //     guard(i u< len)
52 //
53 // One of the ways to reason about this problem is to use an inductive proof
54 // approach. Given the loop:
55 //
56 //   if (B(0)) {
57 //     do {
58 //       I = PHI(0, I.INC)
59 //       I.INC = I + Step
60 //       guard(G(I));
61 //     } while (B(I));
62 //   }
63 //
64 // where B(x) and G(x) are predicates that map integers to booleans, we want a
// loop invariant expression M such that the following program has the same
// semantics as the above:
67 //
68 //   if (B(0)) {
69 //     do {
70 //       I = PHI(0, I.INC)
71 //       I.INC = I + Step
72 //       guard(G(0) && M);
73 //     } while (B(I));
74 //   }
75 //
76 // One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step)
77 //
78 // Informal proof that the transformation above is correct:
79 //
80 //   By the definition of guards we can rewrite the guard condition to:
81 //     G(I) && G(0) && M
82 //
83 //   Let's prove that for each iteration of the loop:
84 //     G(0) && M => G(I)
//   And the condition above can be simplified to G(0) && M.
86 //
87 //   Induction base.
88 //     G(0) && M => G(0)
89 //
//   Induction step. Assuming G(0) && M => G(I) holds on the current
//   iteration, show that G(I + Step) holds on the subsequent iteration:
92 //
93 //     B(I) is true because it's the backedge condition.
94 //     G(I) is true because the backedge is guarded by this condition.
95 //
96 //   So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step).
97 //
98 // Note that we can use anything stronger than M, i.e. any condition which
99 // implies M.
100 //
101 // When S = 1 (i.e. forward iterating loop), the transformation is supported
102 // when:
103 //   * The loop has a single latch with the condition of the form:
104 //     B(X) = latchStart + X <pred> latchLimit,
105 //     where <pred> is u<, u<=, s<, or s<=.
106 //   * The guard condition is of the form
107 //     G(X) = guardStart + X u< guardLimit
108 //
109 //   For the ult latch comparison case M is:
//     forall X . guardStart + X u< guardLimit && latchStart + X u< latchLimit =>
111 //        guardStart + X + 1 u< guardLimit
112 //
113 //   The only way the antecedent can be true and the consequent can be false is
114 //   if
115 //     X == guardLimit - 1 - guardStart
116 //   (and guardLimit is non-zero, but we won't use this latter fact).
117 //   If X == guardLimit - 1 - guardStart then the second half of the antecedent is
118 //     latchStart + guardLimit - 1 - guardStart u< latchLimit
119 //   and its negation is
120 //     latchStart + guardLimit - 1 - guardStart u>= latchLimit
121 //
122 //   In other words, if
123 //     latchLimit u<= latchStart + guardLimit - 1 - guardStart
124 //   then:
125 //   (the ranges below are written in ConstantRange notation, where [A, B) is the
126 //   set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
127 //
128 //      forall X . guardStart + X u< guardLimit &&
129 //                 latchStart + X u< latchLimit =>
130 //        guardStart + X + 1 u< guardLimit
131 //   == forall X . guardStart + X u< guardLimit &&
132 //                 latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
133 //        guardStart + X + 1 u< guardLimit
134 //   == forall X . (guardStart + X) in [0, guardLimit) &&
135 //                 (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
136 //        (guardStart + X + 1) in [0, guardLimit)
137 //   == forall X . X in [-guardStart, guardLimit - guardStart) &&
138 //                 X in [-latchStart, guardLimit - 1 - guardStart) =>
139 //         X in [-guardStart - 1, guardLimit - guardStart - 1)
140 //   == true
141 //
142 //   So the widened condition is:
143 //     guardStart u< guardLimit &&
144 //     latchStart + guardLimit - 1 - guardStart u>= latchLimit
145 //   Similarly for ule condition the widened condition is:
146 //     guardStart u< guardLimit &&
147 //     latchStart + guardLimit - 1 - guardStart u> latchLimit
148 //   For slt condition the widened condition is:
149 //     guardStart u< guardLimit &&
150 //     latchStart + guardLimit - 1 - guardStart s>= latchLimit
151 //   For sle condition the widened condition is:
152 //     guardStart u< guardLimit &&
153 //     latchStart + guardLimit - 1 - guardStart s> latchLimit
154 //
// When Step = -1 (i.e. reverse iterating loop), the transformation is supported
156 // when:
157 //   * The loop has a single latch with the condition of the form:
158 //     B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=.
159 //   * The guard condition is of the form
160 //     G(X) = X - 1 u< guardLimit
161 //
162 //   For the ugt latch comparison case M is:
163 //     forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit
164 //
165 //   The only way the antecedent can be true and the consequent can be false is if
166 //     X == 1.
167 //   If X == 1 then the second half of the antecedent is
168 //     1 u> latchLimit, and its negation is latchLimit u>= 1.
169 //
170 //   So the widened condition is:
171 //     guardStart u< guardLimit && latchLimit u>= 1.
172 //   Similarly for sgt condition the widened condition is:
173 //     guardStart u< guardLimit && latchLimit s>= 1.
174 //   For uge condition the widened condition is:
175 //     guardStart u< guardLimit && latchLimit u> 1.
176 //   For sge condition the widened condition is:
177 //     guardStart u< guardLimit && latchLimit s> 1.
178 //===----------------------------------------------------------------------===//
179 
180 #include "llvm/Transforms/Scalar/LoopPredication.h"
181 #include "llvm/Analysis/LoopInfo.h"
182 #include "llvm/Analysis/LoopPass.h"
183 #include "llvm/Analysis/ScalarEvolution.h"
184 #include "llvm/Analysis/ScalarEvolutionExpander.h"
185 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
186 #include "llvm/IR/Function.h"
187 #include "llvm/IR/GlobalValue.h"
188 #include "llvm/IR/IntrinsicInst.h"
189 #include "llvm/IR/Module.h"
190 #include "llvm/IR/PatternMatch.h"
191 #include "llvm/Pass.h"
192 #include "llvm/Support/Debug.h"
193 #include "llvm/Transforms/Scalar.h"
194 #include "llvm/Transforms/Utils/LoopUtils.h"
195 
196 #define DEBUG_TYPE "loop-predication"
197 
198 using namespace llvm;
199 
200 static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
201                                         cl::Hidden, cl::init(true));
202 
203 static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
204                                         cl::Hidden, cl::init(true));
205 namespace {
206 class LoopPredication {
207   /// Represents an induction variable check:
208   ///   icmp Pred, <induction variable>, <loop invariant limit>
209   struct LoopICmp {
210     ICmpInst::Predicate Pred;
211     const SCEVAddRecExpr *IV;
212     const SCEV *Limit;
213     LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV,
214              const SCEV *Limit)
215         : Pred(Pred), IV(IV), Limit(Limit) {}
216     LoopICmp() {}
217     void dump() {
218       dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV
219              << ", Limit = " << *Limit << "\n";
220     }
221   };
222 
223   ScalarEvolution *SE;
224 
225   Loop *L;
226   const DataLayout *DL;
227   BasicBlock *Preheader;
228   LoopICmp LatchCheck;
229 
230   bool isSupportedStep(const SCEV* Step);
231   Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI) {
232     return parseLoopICmp(ICI->getPredicate(), ICI->getOperand(0),
233                          ICI->getOperand(1));
234   }
235   Optional<LoopICmp> parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS,
236                                    Value *RHS);
237 
238   Optional<LoopICmp> parseLoopLatchICmp();
239 
240   bool CanExpand(const SCEV* S);
241   Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder,
242                      ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
243                      Instruction *InsertAt);
244 
245   Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
246                                         IRBuilder<> &Builder);
247   Optional<Value *> widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck,
248                                                         LoopICmp RangeCheck,
249                                                         SCEVExpander &Expander,
250                                                         IRBuilder<> &Builder);
251   Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck,
252                                                         LoopICmp RangeCheck,
253                                                         SCEVExpander &Expander,
254                                                         IRBuilder<> &Builder);
255   bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
256 
257   // When the IV type is wider than the range operand type, we can still do loop
258   // predication, by generating SCEVs for the range and latch that are of the
259   // same type. We achieve this by generating a SCEV truncate expression for the
260   // latch IV. This is done iff truncation of the IV is a safe operation,
261   // without loss of information.
262   // Another way to achieve this is by generating a wider type SCEV for the
263   // range check operand, however, this needs a more involved check that
264   // operands do not overflow. This can lead to loss of information when the
265   // range operand is of the form: add i32 %offset, %iv. We need to prove that
266   // sext(x + y) is same as sext(x) + sext(y).
267   // This function returns true if we can safely represent the IV type in
268   // the RangeCheckType without loss of information.
269   bool isSafeToTruncateWideIVType(Type *RangeCheckType);
270   // Return the loopLatchCheck corresponding to the RangeCheckType if safe to do
271   // so.
272   Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType);
273 
274   // Returns the latch predicate for guard. SGT -> SGE, UGT -> UGE, SGE -> SGT,
275   // UGE -> UGT, etc.
276   ICmpInst::Predicate getLatchPredicateForGuard(ICmpInst::Predicate Pred);
277 
278 public:
279   LoopPredication(ScalarEvolution *SE) : SE(SE){};
280   bool runOnLoop(Loop *L);
281 };
282 
283 class LoopPredicationLegacyPass : public LoopPass {
284 public:
285   static char ID;
286   LoopPredicationLegacyPass() : LoopPass(ID) {
287     initializeLoopPredicationLegacyPassPass(*PassRegistry::getPassRegistry());
288   }
289 
290   void getAnalysisUsage(AnalysisUsage &AU) const override {
291     getLoopAnalysisUsage(AU);
292   }
293 
294   bool runOnLoop(Loop *L, LPPassManager &LPM) override {
295     if (skipLoop(L))
296       return false;
297     auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
298     LoopPredication LP(SE);
299     return LP.runOnLoop(L);
300   }
301 };
302 
303 char LoopPredicationLegacyPass::ID = 0;
304 } // end namespace llvm
305 
306 INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
307                       "Loop predication", false, false)
308 INITIALIZE_PASS_DEPENDENCY(LoopPass)
309 INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
310                     "Loop predication", false, false)
311 
312 Pass *llvm::createLoopPredicationPass() {
313   return new LoopPredicationLegacyPass();
314 }
315 
316 PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
317                                            LoopStandardAnalysisResults &AR,
318                                            LPMUpdater &U) {
319   LoopPredication LP(&AR.SE);
320   if (!LP.runOnLoop(&L))
321     return PreservedAnalyses::all();
322 
323   return getLoopPassPreservedAnalyses();
324 }
325 
326 Optional<LoopPredication::LoopICmp>
327 LoopPredication::parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS,
328                                Value *RHS) {
329   const SCEV *LHSS = SE->getSCEV(LHS);
330   if (isa<SCEVCouldNotCompute>(LHSS))
331     return None;
332   const SCEV *RHSS = SE->getSCEV(RHS);
333   if (isa<SCEVCouldNotCompute>(RHSS))
334     return None;
335 
336   // Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV
337   if (SE->isLoopInvariant(LHSS, L)) {
338     std::swap(LHS, RHS);
339     std::swap(LHSS, RHSS);
340     Pred = ICmpInst::getSwappedPredicate(Pred);
341   }
342 
343   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS);
344   if (!AR || AR->getLoop() != L)
345     return None;
346 
347   return LoopICmp(Pred, AR, RHSS);
348 }
349 
350 Value *LoopPredication::expandCheck(SCEVExpander &Expander,
351                                     IRBuilder<> &Builder,
352                                     ICmpInst::Predicate Pred, const SCEV *LHS,
353                                     const SCEV *RHS, Instruction *InsertAt) {
354   // TODO: we can check isLoopEntryGuardedByCond before emitting the check
355 
356   Type *Ty = LHS->getType();
357   assert(Ty == RHS->getType() && "expandCheck operands have different types?");
358 
359   if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
360     return Builder.getTrue();
361 
362   Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt);
363   Value *RHSV = Expander.expandCodeFor(RHS, Ty, InsertAt);
364   return Builder.CreateICmp(Pred, LHSV, RHSV);
365 }
366 
367 Optional<LoopPredication::LoopICmp>
368 LoopPredication::generateLoopLatchCheck(Type *RangeCheckType) {
369 
370   auto *LatchType = LatchCheck.IV->getType();
371   if (RangeCheckType == LatchType)
372     return LatchCheck;
373   // For now, bail out if latch type is narrower than range type.
374   if (DL->getTypeSizeInBits(LatchType) < DL->getTypeSizeInBits(RangeCheckType))
375     return None;
376   if (!isSafeToTruncateWideIVType(RangeCheckType))
377     return None;
378   // We can now safely identify the truncated version of the IV and limit for
379   // RangeCheckType.
380   LoopICmp NewLatchCheck;
381   NewLatchCheck.Pred = LatchCheck.Pred;
382   NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
383       SE->getTruncateExpr(LatchCheck.IV, RangeCheckType));
384   if (!NewLatchCheck.IV)
385     return None;
386   NewLatchCheck.Limit = SE->getTruncateExpr(LatchCheck.Limit, RangeCheckType);
387   DEBUG(dbgs() << "IV of type: " << *LatchType
388                << "can be represented as range check type:" << *RangeCheckType
389                << "\n");
390   DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
391   DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
392   return NewLatchCheck;
393 }
394 
395 bool LoopPredication::isSupportedStep(const SCEV* Step) {
396   return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
397 }
398 
399 bool LoopPredication::CanExpand(const SCEV* S) {
400   return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE);
401 }
402 
403 ICmpInst::Predicate
404 LoopPredication::getLatchPredicateForGuard(ICmpInst::Predicate Pred) {
405   switch (LatchCheck.Pred) {
406   case ICmpInst::ICMP_ULT:
407     return ICmpInst::ICMP_ULE;
408   case ICmpInst::ICMP_ULE:
409     return ICmpInst::ICMP_ULT;
410   case ICmpInst::ICMP_SLT:
411     return ICmpInst::ICMP_SLE;
412   case ICmpInst::ICMP_SLE:
413     return ICmpInst::ICMP_SLT;
414   case ICmpInst::ICMP_UGT:
415     return ICmpInst::ICMP_UGE;
416   case ICmpInst::ICMP_UGE:
417     return ICmpInst::ICMP_UGT;
418   case ICmpInst::ICMP_SGT:
419     return ICmpInst::ICMP_SGE;
420   case ICmpInst::ICMP_SGE:
421     return ICmpInst::ICMP_SGT;
422   default:
423     llvm_unreachable("Unsupported loop latch!");
424   }
425 }
426 
427 Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
428     LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
429     SCEVExpander &Expander, IRBuilder<> &Builder) {
430   auto *Ty = RangeCheck.IV->getType();
431   // Generate the widened condition for the forward loop:
432   //   guardStart u< guardLimit &&
433   //   latchLimit <pred> guardLimit - 1 - guardStart + latchStart
434   // where <pred> depends on the latch condition predicate. See the file
435   // header comment for the reasoning.
436   // guardLimit - guardStart + latchStart - 1
437   const SCEV *GuardStart = RangeCheck.IV->getStart();
438   const SCEV *GuardLimit = RangeCheck.Limit;
439   const SCEV *LatchStart = LatchCheck.IV->getStart();
440   const SCEV *LatchLimit = LatchCheck.Limit;
441 
442   // guardLimit - guardStart + latchStart - 1
443   const SCEV *RHS =
444       SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart),
445                      SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
446   if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
447       !CanExpand(LatchLimit) || !CanExpand(RHS)) {
448     DEBUG(dbgs() << "Can't expand limit check!\n");
449     return None;
450   }
451   auto LimitCheckPred = getLatchPredicateForGuard(LatchCheck.Pred);
452 
453   DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
454   DEBUG(dbgs() << "RHS: " << *RHS << "\n");
455   DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");
456 
457   Instruction *InsertAt = Preheader->getTerminator();
458   auto *LimitCheck =
459       expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, RHS, InsertAt);
460   auto *FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck.Pred,
461                                           GuardStart, GuardLimit, InsertAt);
462   return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
463 }
464 
465 Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
466     LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
467     SCEVExpander &Expander, IRBuilder<> &Builder) {
468   auto *Ty = RangeCheck.IV->getType();
469   const SCEV *GuardStart = RangeCheck.IV->getStart();
470   const SCEV *GuardLimit = RangeCheck.Limit;
471   const SCEV *LatchLimit = LatchCheck.Limit;
472   if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
473       !CanExpand(LatchLimit)) {
474     DEBUG(dbgs() << "Can't expand limit check!\n");
475     return None;
476   }
477   // The decrement of the latch check IV should be the same as the
478   // rangeCheckIV.
479   auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
480   if (RangeCheck.IV != PostDecLatchCheckIV) {
481     DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
482                  << *PostDecLatchCheckIV
483                  << "  and RangeCheckIV: " << *RangeCheck.IV << "\n");
484     return None;
485   }
486 
487   // Generate the widened condition for CountDownLoop:
488   // guardStart u< guardLimit &&
489   // latchLimit <pred> 1.
490   // See the header comment for reasoning of the checks.
491   Instruction *InsertAt = Preheader->getTerminator();
492   auto LimitCheckPred = getLatchPredicateForGuard(LatchCheck.Pred);
493   auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT,
494                                           GuardStart, GuardLimit, InsertAt);
495   auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit,
496                                  SE->getOne(Ty), InsertAt);
497   return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
498 }
499 
500 /// If ICI can be widened to a loop invariant condition emits the loop
501 /// invariant condition in the loop preheader and return it, otherwise
502 /// returns None.
503 Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
504                                                        SCEVExpander &Expander,
505                                                        IRBuilder<> &Builder) {
506   DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
507   DEBUG(ICI->dump());
508 
509   // parseLoopStructure guarantees that the latch condition is:
510   //   ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=.
511   // We are looking for the range checks of the form:
512   //   i u< guardLimit
513   auto RangeCheck = parseLoopICmp(ICI);
514   if (!RangeCheck) {
515     DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
516     return None;
517   }
518   DEBUG(dbgs() << "Guard check:\n");
519   DEBUG(RangeCheck->dump());
520   if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
521     DEBUG(dbgs() << "Unsupported range check predicate(" << RangeCheck->Pred
522                  << ")!\n");
523     return None;
524   }
525   auto *RangeCheckIV = RangeCheck->IV;
526   if (!RangeCheckIV->isAffine()) {
527     DEBUG(dbgs() << "Range check IV is not affine!\n");
528     return None;
529   }
530   auto *Step = RangeCheckIV->getStepRecurrence(*SE);
531   // We cannot just compare with latch IV step because the latch and range IVs
532   // may have different types.
533   if (!isSupportedStep(Step)) {
534     DEBUG(dbgs() << "Range check and latch have IVs different steps!\n");
535     return None;
536   }
537   auto *Ty = RangeCheckIV->getType();
538   auto CurrLatchCheckOpt = generateLoopLatchCheck(Ty);
539   if (!CurrLatchCheckOpt) {
540     DEBUG(dbgs() << "Failed to generate a loop latch check "
541                     "corresponding to range type: "
542                  << *Ty << "\n");
543     return None;
544   }
545 
546   LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
547   // At this point, the range and latch step should have the same type, but need
548   // not have the same value (we support both 1 and -1 steps).
549   assert(Step->getType() ==
550              CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
551          "Range and latch steps should be of same type!");
552   if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
553     DEBUG(dbgs() << "Range and latch have different step values!\n");
554     return None;
555   }
556 
557   if (Step->isOne())
558     return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
559                                                Expander, Builder);
560   else {
561     assert(Step->isAllOnesValue() && "Step should be -1!");
562     return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
563                                                Expander, Builder);
564   }
565 }
566 
567 bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
568                                            SCEVExpander &Expander) {
569   DEBUG(dbgs() << "Processing guard:\n");
570   DEBUG(Guard->dump());
571 
572   IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator()));
573 
574   // The guard condition is expected to be in form of:
575   //   cond1 && cond2 && cond3 ...
576   // Iterate over subconditions looking for icmp conditions which can be
577   // widened across loop iterations. Widening these conditions remember the
578   // resulting list of subconditions in Checks vector.
579   SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0));
580   SmallPtrSet<Value *, 4> Visited;
581 
582   SmallVector<Value *, 4> Checks;
583 
584   unsigned NumWidened = 0;
585   do {
586     Value *Condition = Worklist.pop_back_val();
587     if (!Visited.insert(Condition).second)
588       continue;
589 
590     Value *LHS, *RHS;
591     using namespace llvm::PatternMatch;
592     if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) {
593       Worklist.push_back(LHS);
594       Worklist.push_back(RHS);
595       continue;
596     }
597 
598     if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
599       if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Builder)) {
600         Checks.push_back(NewRangeCheck.getValue());
601         NumWidened++;
602         continue;
603       }
604     }
605 
606     // Save the condition as is if we can't widen it
607     Checks.push_back(Condition);
608   } while (Worklist.size() != 0);
609 
610   if (NumWidened == 0)
611     return false;
612 
613   // Emit the new guard condition
614   Builder.SetInsertPoint(Guard);
615   Value *LastCheck = nullptr;
616   for (auto *Check : Checks)
617     if (!LastCheck)
618       LastCheck = Check;
619     else
620       LastCheck = Builder.CreateAnd(LastCheck, Check);
621   Guard->setOperand(0, LastCheck);
622 
623   DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
624   return true;
625 }
626 
627 Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() {
628   using namespace PatternMatch;
629 
630   BasicBlock *LoopLatch = L->getLoopLatch();
631   if (!LoopLatch) {
632     DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
633     return None;
634   }
635 
636   ICmpInst::Predicate Pred;
637   Value *LHS, *RHS;
638   BasicBlock *TrueDest, *FalseDest;
639 
640   if (!match(LoopLatch->getTerminator(),
641              m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), TrueDest,
642                   FalseDest))) {
643     DEBUG(dbgs() << "Failed to match the latch terminator!\n");
644     return None;
645   }
646   assert((TrueDest == L->getHeader() || FalseDest == L->getHeader()) &&
647          "One of the latch's destinations must be the header");
648   if (TrueDest != L->getHeader())
649     Pred = ICmpInst::getInversePredicate(Pred);
650 
651   auto Result = parseLoopICmp(Pred, LHS, RHS);
652   if (!Result) {
653     DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
654     return None;
655   }
656 
657   // Check affine first, so if it's not we don't try to compute the step
658   // recurrence.
659   if (!Result->IV->isAffine()) {
660     DEBUG(dbgs() << "The induction variable is not affine!\n");
661     return None;
662   }
663 
664   auto *Step = Result->IV->getStepRecurrence(*SE);
665   if (!isSupportedStep(Step)) {
666     DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
667     return None;
668   }
669 
670   auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
671     if (Step->isOne()) {
672       return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
673              Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
674     } else {
675       assert(Step->isAllOnesValue() && "Step should be -1!");
676       return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT &&
677              Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE;
678     }
679   };
680 
681   if (IsUnsupportedPredicate(Step, Result->Pred)) {
682     DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
683                  << ")!\n");
684     return None;
685   }
686   return Result;
687 }
688 
689 // Returns true if its safe to truncate the IV to RangeCheckType.
690 bool LoopPredication::isSafeToTruncateWideIVType(Type *RangeCheckType) {
691   if (!EnableIVTruncation)
692     return false;
693   assert(DL->getTypeSizeInBits(LatchCheck.IV->getType()) >
694              DL->getTypeSizeInBits(RangeCheckType) &&
695          "Expected latch check IV type to be larger than range check operand "
696          "type!");
697   // The start and end values of the IV should be known. This is to guarantee
698   // that truncating the wide type will not lose information.
699   auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit);
700   auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart());
701   if (!Limit || !Start)
702     return false;
703   // This check makes sure that the IV does not change sign during loop
704   // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE,
705   // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the
706   // IV wraps around, and the truncation of the IV would lose the range of
707   // iterations between 2^32 and 2^64.
708   bool Increasing;
709   if (!SE->isMonotonicPredicate(LatchCheck.IV, LatchCheck.Pred, Increasing))
710     return false;
711   // The active bits should be less than the bits in the RangeCheckType. This
712   // guarantees that truncating the latch check to RangeCheckType is a safe
713   // operation.
714   auto RangeCheckTypeBitSize = DL->getTypeSizeInBits(RangeCheckType);
715   return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
716          Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
717 }
718 
719 bool LoopPredication::runOnLoop(Loop *Loop) {
720   L = Loop;
721 
722   DEBUG(dbgs() << "Analyzing ");
723   DEBUG(L->dump());
724 
725   Module *M = L->getHeader()->getModule();
726 
727   // There is nothing to do if the module doesn't use guards
728   auto *GuardDecl =
729       M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
730   if (!GuardDecl || GuardDecl->use_empty())
731     return false;
732 
733   DL = &M->getDataLayout();
734 
735   Preheader = L->getLoopPreheader();
736   if (!Preheader)
737     return false;
738 
739   auto LatchCheckOpt = parseLoopLatchICmp();
740   if (!LatchCheckOpt)
741     return false;
742   LatchCheck = *LatchCheckOpt;
743 
744   DEBUG(dbgs() << "Latch check:\n");
745   DEBUG(LatchCheck.dump());
746 
747   // Collect all the guards into a vector and process later, so as not
748   // to invalidate the instruction iterator.
749   SmallVector<IntrinsicInst *, 4> Guards;
750   for (const auto BB : L->blocks())
751     for (auto &I : *BB)
752       if (auto *II = dyn_cast<IntrinsicInst>(&I))
753         if (II->getIntrinsicID() == Intrinsic::experimental_guard)
754           Guards.push_back(II);
755 
756   if (Guards.empty())
757     return false;
758 
759   SCEVExpander Expander(*SE, *DL, "loop-predication");
760 
761   bool Changed = false;
762   for (auto *Guard : Guards)
763     Changed |= widenGuardConditions(Guard, Expander);
764 
765   return Changed;
766 }
767