xref: /llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp (revision 4a0d53a0b0a58a3c6980a7c551357ac71ba3db10)
1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h"
10 #include "llvm/ADT/Sequence.h"
11 #include "llvm/Analysis/LoopAnalysisManager.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/ScalarEvolution.h"
14 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
15 #include "llvm/IR/PatternMatch.h"
16 #include "llvm/Transforms/Scalar/LoopPassManager.h"
17 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
18 #include "llvm/Transforms/Utils/Cloning.h"
19 #include "llvm/Transforms/Utils/LoopSimplify.h"
20 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
21 
22 #define DEBUG_TYPE "loop-bound-split"
23 
24 namespace llvm {
25 
26 using namespace PatternMatch;
27 
28 namespace {
29 struct ConditionInfo {
30   /// Branch instruction with this condition
31   BranchInst *BI = nullptr;
32   /// ICmp instruction with this condition
33   ICmpInst *ICmp = nullptr;
34   /// Preciate info
35   CmpPredicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
36   /// AddRec llvm value
37   Value *AddRecValue = nullptr;
38   /// Non PHI AddRec llvm value
39   Value *NonPHIAddRecValue;
40   /// Bound llvm value
41   Value *BoundValue = nullptr;
42   /// AddRec SCEV
43   const SCEVAddRecExpr *AddRecSCEV = nullptr;
44   /// Bound SCEV
45   const SCEV *BoundSCEV = nullptr;
46 
47   ConditionInfo() = default;
48 };
49 } // namespace
50 
51 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
52                         ConditionInfo &Cond, const Loop &L) {
53   Cond.ICmp = ICmp;
54   if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),
55                          m_Value(Cond.BoundValue)))) {
56     const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue);
57     const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue);
58     const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
59     const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV);
60     // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
61     if (!LHSAddRecSCEV && RHSAddRecSCEV) {
62       std::swap(Cond.AddRecValue, Cond.BoundValue);
63       std::swap(AddRecSCEV, BoundSCEV);
64       Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);
65     }
66 
67     Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
68     Cond.BoundSCEV = BoundSCEV;
69     Cond.NonPHIAddRecValue = Cond.AddRecValue;
70 
71     // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with
72     // value from backedge.
73     if (Cond.AddRecSCEV && isa<PHINode>(Cond.AddRecValue)) {
74       PHINode *PN = cast<PHINode>(Cond.AddRecValue);
75       Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch());
76     }
77   }
78 }
79 
80 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
81                                 ConditionInfo &Cond, bool IsExitCond) {
82   if (IsExitCond) {
83     const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());
84     if (isa<SCEVCouldNotCompute>(ExitCount))
85       return false;
86 
87     Cond.BoundSCEV = ExitCount;
88     return true;
89   }
90 
91   // For non-exit condtion, if pred is LT, keep existing bound.
92   if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
93     return true;
94 
95   // For non-exit condition, if pre is LE, try to convert it to LT.
96   //      Range                 Range
97   // AddRec <= Bound  -->  AddRec < Bound + 1
98   if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
99     return false;
100 
101   if (IntegerType *BoundSCEVIntType =
102           dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {
103     unsigned BitWidth = BoundSCEVIntType->getBitWidth();
104     APInt Max = ICmpInst::isSigned(Cond.Pred)
105                     ? APInt::getSignedMaxValue(BitWidth)
106                     : APInt::getMaxValue(BitWidth);
107     const SCEV *MaxSCEV = SE.getConstant(Max);
108     // Check Bound < INT_MAX
109     ICmpInst::Predicate Pred =
110         ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
111     if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {
112       const SCEV *BoundPlusOneSCEV =
113           SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));
114       Cond.BoundSCEV = BoundPlusOneSCEV;
115       Cond.Pred = Pred;
116       return true;
117     }
118   }
119 
120   // ToDo: Support ICMP_NE/EQ.
121 
122   return false;
123 }
124 
125 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
126                                     ICmpInst *ICmp, ConditionInfo &Cond,
127                                     bool IsExitCond) {
128   analyzeICmp(SE, ICmp, Cond, L);
129 
130   // The BoundSCEV should be evaluated at loop entry.
131   if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))
132     return false;
133 
134   // Allowed AddRec as induction variable.
135   if (!Cond.AddRecSCEV)
136     return false;
137 
138   if (!Cond.AddRecSCEV->isAffine())
139     return false;
140 
141   const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE);
142   // Allowed constant step.
143   if (!isa<SCEVConstant>(StepRecSCEV))
144     return false;
145 
146   ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();
147   // Allowed positive step for now.
148   // TODO: Support negative step.
149   if (StepCI->isNegative() || StepCI->isZero())
150     return false;
151 
152   // Calculate upper bound.
153   if (!calculateUpperBound(L, SE, Cond, IsExitCond))
154     return false;
155 
156   return true;
157 }
158 
159 static bool isProcessableCondBI(const ScalarEvolution &SE,
160                                 const BranchInst *BI) {
161   BasicBlock *TrueSucc = nullptr;
162   BasicBlock *FalseSucc = nullptr;
163   Value *LHS, *RHS;
164   if (!match(BI, m_Br(m_ICmp(m_Value(LHS), m_Value(RHS)),
165                       m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
166     return false;
167 
168   if (!SE.isSCEVable(LHS->getType()))
169     return false;
170   assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
171 
172   if (TrueSucc == FalseSucc)
173     return false;
174 
175   return true;
176 }
177 
178 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
179                               ScalarEvolution &SE, ConditionInfo &Cond) {
180   // Skip function with optsize.
181   if (L.getHeader()->getParent()->hasOptSize())
182     return false;
183 
184   // Split only innermost loop.
185   if (!L.isInnermost())
186     return false;
187 
188   // Check loop is in simplified form.
189   if (!L.isLoopSimplifyForm())
190     return false;
191 
192   // Check loop is in LCSSA form.
193   if (!L.isLCSSAForm(DT))
194     return false;
195 
196   // Skip loop that cannot be cloned.
197   if (!L.isSafeToClone())
198     return false;
199 
200   BasicBlock *ExitingBB = L.getExitingBlock();
201   // Assumed only one exiting block.
202   if (!ExitingBB)
203     return false;
204 
205   BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
206   if (!ExitingBI)
207     return false;
208 
209   // Allowed only conditional branch with ICmp.
210   if (!isProcessableCondBI(SE, ExitingBI))
211     return false;
212 
213   // Check the condition is processable.
214   ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());
215   if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
216     return false;
217 
218   Cond.BI = ExitingBI;
219   return true;
220 }
221 
222 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
223   // If the conditional branch splits a loop into two halves, we could
224   // generally say it is profitable.
225   //
226   // ToDo: Add more profitable cases here.
227 
228   // Check this branch causes diamond CFG.
229   BasicBlock *Succ0 = BI->getSuccessor(0);
230   BasicBlock *Succ1 = BI->getSuccessor(1);
231 
232   BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
233   BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
234   if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
235     return false;
236 
237   // ToDo: Calculate each successor's instruction cost.
238 
239   return true;
240 }
241 
242 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
243                                       ConditionInfo &ExitingCond,
244                                       ConditionInfo &SplitCandidateCond) {
245   for (auto *BB : L.blocks()) {
246     // Skip condition of backedge.
247     if (L.getLoopLatch() == BB)
248       continue;
249 
250     auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
251     if (!BI)
252       continue;
253 
254     // Check conditional branch with ICmp.
255     if (!isProcessableCondBI(SE, BI))
256       continue;
257 
258     // Skip loop invariant condition.
259     if (L.isLoopInvariant(BI->getCondition()))
260       continue;
261 
262     // Check the condition is processable.
263     ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition());
264     if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond,
265                                  /*IsExitCond*/ false))
266       continue;
267 
268     if (ExitingCond.BoundSCEV->getType() !=
269         SplitCandidateCond.BoundSCEV->getType())
270       continue;
271 
272     // After transformation, we assume the split condition of the pre-loop is
273     // always true. In order to guarantee it, we need to check the start value
274     // of the split cond AddRec satisfies the split condition.
275     if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred,
276                                      SplitCandidateCond.AddRecSCEV->getStart(),
277                                      SplitCandidateCond.BoundSCEV))
278       continue;
279 
280     SplitCandidateCond.BI = BI;
281     return BI;
282   }
283 
284   return nullptr;
285 }
286 
287 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
288                            ScalarEvolution &SE, LPMUpdater &U) {
289   ConditionInfo SplitCandidateCond;
290   ConditionInfo ExitingCond;
291 
292   // Check we can split this loop's bound.
293   if (!canSplitLoopBound(L, DT, SE, ExitingCond))
294     return false;
295 
296   if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
297     return false;
298 
299   if (!isProfitableToTransform(L, SplitCandidateCond.BI))
300     return false;
301 
302   // Now, we have a split candidate. Let's build a form as below.
303   //    +--------------------+
304   //    |     preheader      |
305   //    |  set up newbound   |
306   //    +--------------------+
307   //             |     /----------------\
308   //    +--------v----v------+          |
309   //    |      header        |---\      |
310   //    | with true condition|   |      |
311   //    +--------------------+   |      |
312   //             |               |      |
313   //    +--------v-----------+   |      |
314   //    |     if.then.BB     |   |      |
315   //    +--------------------+   |      |
316   //             |               |      |
317   //    +--------v-----------<---/      |
318   //    |       latch        >----------/
319   //    |   with newbound    |
320   //    +--------------------+
321   //             |
322   //    +--------v-----------+
323   //    |     preheader2     |--------------\
324   //    | if (AddRec i !=    |              |
325   //    |     org bound)     |              |
326   //    +--------------------+              |
327   //             |     /----------------\   |
328   //    +--------v----v------+          |   |
329   //    |      header2       |---\      |   |
330   //    | conditional branch |   |      |   |
331   //    |with false condition|   |      |   |
332   //    +--------------------+   |      |   |
333   //             |               |      |   |
334   //    +--------v-----------+   |      |   |
335   //    |    if.then.BB2     |   |      |   |
336   //    +--------------------+   |      |   |
337   //             |               |      |   |
338   //    +--------v-----------<---/      |   |
339   //    |       latch2       >----------/   |
340   //    |   with org bound   |              |
341   //    +--------v-----------+              |
342   //             |                          |
343   //             |  +---------------+       |
344   //             +-->     exit      <-------/
345   //                +---------------+
346 
347   // Let's create post loop.
348   SmallVector<BasicBlock *, 8> PostLoopBlocks;
349   Loop *PostLoop;
350   ValueToValueMapTy VMap;
351   BasicBlock *PreHeader = L.getLoopPreheader();
352   BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);
353   PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,
354                                     ".split", &LI, &DT, PostLoopBlocks);
355   remapInstructionsInBlocks(PostLoopBlocks, VMap);
356 
357   BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
358   IRBuilder<> Builder(&PostLoopPreHeader->front());
359 
360   // Update phi nodes in header of post-loop.
361   bool isExitingLatch =
362       (L.getExitingBlock() == L.getLoopLatch()) ? true : false;
363   Value *ExitingCondLCSSAPhi = nullptr;
364   for (PHINode &PN : L.getHeader()->phis()) {
365     // Create LCSSA phi node in preheader of post-loop.
366     PHINode *LCSSAPhi =
367         Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
368     LCSSAPhi->setDebugLoc(PN.getDebugLoc());
369     // If the exiting block is loop latch, the phi does not have the update at
370     // last iteration. In this case, update lcssa phi with value from backedge.
371     LCSSAPhi->addIncoming(
372         isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN,
373         L.getExitingBlock());
374 
375     // Update the start value of phi node in post-loop with the LCSSA phi node.
376     PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);
377     PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi);
378 
379     // Find PHI with exiting condition from pre-loop. The PHI should be
380     // SCEVAddRecExpr and have same incoming value from backedge with
381     // ExitingCond.
382     if (!SE.isSCEVable(PN.getType()))
383       continue;
384 
385     const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
386     if (PhiSCEV && ExitingCond.NonPHIAddRecValue ==
387                        PN.getIncomingValueForBlock(L.getLoopLatch()))
388       ExitingCondLCSSAPhi = LCSSAPhi;
389   }
390 
391   // Add conditional branch to check we can skip post-loop in its preheader.
392   Instruction *OrigBI = PostLoopPreHeader->getTerminator();
393   ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
394   Value *Cond =
395       Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue);
396   Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());
397   OrigBI->eraseFromParent();
398 
399   // Create new loop bound and add it into preheader of pre-loop.
400   const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
401   const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
402   NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)
403                      ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)
404                      : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);
405 
406   SCEVExpander Expander(
407       SE, L.getHeader()->getDataLayout(), "split");
408   Instruction *InsertPt = SplitLoopPH->getTerminator();
409   Value *NewBoundValue =
410       Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);
411   NewBoundValue->setName("new.bound");
412 
413   // Replace exiting bound value of pre-loop NewBound.
414   ExitingCond.ICmp->setOperand(1, NewBoundValue);
415 
416   // Replace SplitCandidateCond.BI's condition of pre-loop by True.
417   LLVMContext &Context = PreHeader->getContext();
418   SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
419 
420   // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
421   BranchInst *ClonedSplitCandidateBI =
422       cast<BranchInst>(VMap[SplitCandidateCond.BI]);
423   ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
424 
425   // Replace exit branch target of pre-loop by post-loop's preheader.
426   if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))
427     ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);
428   else
429     ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
430 
431   // Update phi node in exit block of post-loop.
432   Builder.SetInsertPoint(PostLoopPreHeader, PostLoopPreHeader->begin());
433   for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
434     for (auto i : seq<int>(0, PN.getNumOperands())) {
435       // Check incoming block is pre-loop's exiting block.
436       if (PN.getIncomingBlock(i) == L.getExitingBlock()) {
437         Value *IncomingValue = PN.getIncomingValue(i);
438 
439         // Create LCSSA phi node for incoming value.
440         PHINode *LCSSAPhi =
441             Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
442         LCSSAPhi->setDebugLoc(PN.getDebugLoc());
443         LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i));
444 
445         // Replace pre-loop's exiting block by post-loop's preheader.
446         PN.setIncomingBlock(i, PostLoopPreHeader);
447         // Replace incoming value by LCSSAPhi.
448         PN.setIncomingValue(i, LCSSAPhi);
449         // Add a new incoming value with post-loop's exiting block.
450         PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock());
451       }
452     }
453   }
454 
455   // Update dominator tree.
456   DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());
457   DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);
458 
459   // Invalidate cached SE information.
460   SE.forgetLoop(&L);
461 
462   // Canonicalize loops.
463   simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);
464   simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);
465 
466   // Add new post-loop to loop pass manager.
467   U.addSiblingLoops(PostLoop);
468 
469   return true;
470 }
471 
472 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
473                                           LoopStandardAnalysisResults &AR,
474                                           LPMUpdater &U) {
475   Function &F = *L.getHeader()->getParent();
476   (void)F;
477 
478   LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
479                     << "\n");
480 
481   if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U))
482     return PreservedAnalyses::all();
483 
484   assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
485   AR.LI.verify(AR.DT);
486 
487   return getLoopPassPreservedAnalyses();
488 }
489 
490 } // end namespace llvm
491