xref: /llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp (revision 4288b6520a8ed37a17fe3f41c3a686f3d3bf9f6e)
1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h"
10 #include "llvm/ADT/Sequence.h"
11 #include "llvm/Analysis/LoopAccessAnalysis.h"
12 #include "llvm/Analysis/LoopAnalysisManager.h"
13 #include "llvm/Analysis/LoopInfo.h"
14 #include "llvm/Analysis/LoopIterator.h"
15 #include "llvm/Analysis/LoopPass.h"
16 #include "llvm/Analysis/MemorySSA.h"
17 #include "llvm/Analysis/MemorySSAUpdater.h"
18 #include "llvm/Analysis/ScalarEvolution.h"
19 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
20 #include "llvm/IR/PatternMatch.h"
21 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
22 #include "llvm/Transforms/Utils/Cloning.h"
23 #include "llvm/Transforms/Utils/LoopSimplify.h"
24 #include "llvm/Transforms/Utils/LoopUtils.h"
25 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
26 
27 #define DEBUG_TYPE "loop-bound-split"
28 
29 namespace llvm {
30 
31 using namespace PatternMatch;
32 
33 namespace {
34 struct ConditionInfo {
35   /// Branch instruction with this condition
36   BranchInst *BI;
37   /// ICmp instruction with this condition
38   ICmpInst *ICmp;
39   /// Preciate info
40   ICmpInst::Predicate Pred;
41   /// AddRec llvm value
42   Value *AddRecValue;
43   /// Bound llvm value
44   Value *BoundValue;
45   /// AddRec SCEV
46   const SCEVAddRecExpr *AddRecSCEV;
47   /// Bound SCEV
48   const SCEV *BoundSCEV;
49 
50   ConditionInfo()
51       : BI(nullptr), ICmp(nullptr), Pred(ICmpInst::BAD_ICMP_PREDICATE),
52         AddRecValue(nullptr), BoundValue(nullptr), AddRecSCEV(nullptr),
53         BoundSCEV(nullptr) {}
54 };
55 } // namespace
56 
57 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
58                         ConditionInfo &Cond) {
59   Cond.ICmp = ICmp;
60   if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),
61                          m_Value(Cond.BoundValue)))) {
62     const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue);
63     const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue);
64     const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
65     const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV);
66     // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
67     if (!LHSAddRecSCEV && RHSAddRecSCEV) {
68       std::swap(Cond.AddRecValue, Cond.BoundValue);
69       std::swap(AddRecSCEV, BoundSCEV);
70       Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);
71     }
72 
73     Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
74     Cond.BoundSCEV = BoundSCEV;
75   }
76 }
77 
78 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
79                                 ConditionInfo &Cond, bool IsExitCond) {
80   if (IsExitCond) {
81     const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());
82     if (isa<SCEVCouldNotCompute>(ExitCount))
83       return false;
84 
85     Cond.BoundSCEV = ExitCount;
86     return true;
87   }
88 
89   // For non-exit condtion, if pred is LT, keep existing bound.
90   if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
91     return true;
92 
93   // For non-exit condition, if pre is LE, try to convert it to LT.
94   //      Range                 Range
95   // AddRec <= Bound  -->  AddRec < Bound + 1
96   if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
97     return false;
98 
99   if (IntegerType *BoundSCEVIntType =
100           dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {
101     unsigned BitWidth = BoundSCEVIntType->getBitWidth();
102     APInt Max = ICmpInst::isSigned(Cond.Pred)
103                     ? APInt::getSignedMaxValue(BitWidth)
104                     : APInt::getMaxValue(BitWidth);
105     const SCEV *MaxSCEV = SE.getConstant(Max);
106     // Check Bound < INT_MAX
107     ICmpInst::Predicate Pred =
108         ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
109     if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {
110       const SCEV *BoundPlusOneSCEV =
111           SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));
112       Cond.BoundSCEV = BoundPlusOneSCEV;
113       Cond.Pred = Pred;
114       return true;
115     }
116   }
117 
118   // ToDo: Support ICMP_NE/EQ.
119 
120   return false;
121 }
122 
123 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
124                                     ICmpInst *ICmp, ConditionInfo &Cond,
125                                     bool IsExitCond) {
126   analyzeICmp(SE, ICmp, Cond);
127 
128   // The BoundSCEV should be evaluated at loop entry.
129   if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))
130     return false;
131 
132   // Allowed AddRec as induction variable.
133   if (!Cond.AddRecSCEV)
134     return false;
135 
136   if (!Cond.AddRecSCEV->isAffine())
137     return false;
138 
139   const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE);
140   // Allowed constant step.
141   if (!isa<SCEVConstant>(StepRecSCEV))
142     return false;
143 
144   ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();
145   // Allowed positive step for now.
146   // TODO: Support negative step.
147   if (StepCI->isNegative() || StepCI->isZero())
148     return false;
149 
150   // Calculate upper bound.
151   if (!calculateUpperBound(L, SE, Cond, IsExitCond))
152     return false;
153 
154   return true;
155 }
156 
157 static bool isProcessableCondBI(const ScalarEvolution &SE,
158                                 const BranchInst *BI) {
159   BasicBlock *TrueSucc = nullptr;
160   BasicBlock *FalseSucc = nullptr;
161   ICmpInst::Predicate Pred;
162   Value *LHS, *RHS;
163   if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
164                       m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
165     return false;
166 
167   if (!SE.isSCEVable(LHS->getType()))
168     return false;
169   assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
170 
171   if (TrueSucc == FalseSucc)
172     return false;
173 
174   return true;
175 }
176 
177 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
178                               ScalarEvolution &SE, ConditionInfo &Cond) {
179   // Skip function with optsize.
180   if (L.getHeader()->getParent()->hasOptSize())
181     return false;
182 
183   // Split only innermost loop.
184   if (!L.isInnermost())
185     return false;
186 
187   // Check loop is in simplified form.
188   if (!L.isLoopSimplifyForm())
189     return false;
190 
191   // Check loop is in LCSSA form.
192   if (!L.isLCSSAForm(DT))
193     return false;
194 
195   // Skip loop that cannot be cloned.
196   if (!L.isSafeToClone())
197     return false;
198 
199   BasicBlock *ExitingBB = L.getExitingBlock();
200   // Assumed only one exiting block.
201   if (!ExitingBB)
202     return false;
203 
204   BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
205   if (!ExitingBI)
206     return false;
207 
208   // Allowed only conditional branch with ICmp.
209   if (!isProcessableCondBI(SE, ExitingBI))
210     return false;
211 
212   // Check the condition is processable.
213   ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());
214   if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
215     return false;
216 
217   Cond.BI = ExitingBI;
218   return true;
219 }
220 
221 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
222   // If the conditional branch splits a loop into two halves, we could
223   // generally say it is profitable.
224   //
225   // ToDo: Add more profitable cases here.
226 
227   // Check this branch causes diamond CFG.
228   BasicBlock *Succ0 = BI->getSuccessor(0);
229   BasicBlock *Succ1 = BI->getSuccessor(1);
230 
231   BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
232   BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
233   if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
234     return false;
235 
236   // ToDo: Calculate each successor's instruction cost.
237 
238   return true;
239 }
240 
241 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
242                                       ConditionInfo &ExitingCond,
243                                       ConditionInfo &SplitCandidateCond) {
244   for (auto *BB : L.blocks()) {
245     // Skip condition of backedge.
246     if (L.getLoopLatch() == BB)
247       continue;
248 
249     auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
250     if (!BI)
251       continue;
252 
253     // Check conditional branch with ICmp.
254     if (!isProcessableCondBI(SE, BI))
255       continue;
256 
257     // Skip loop invariant condition.
258     if (L.isLoopInvariant(BI->getCondition()))
259       continue;
260 
261     // Check the condition is processable.
262     ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition());
263     if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond,
264                                  /*IsExitCond*/ false))
265       continue;
266 
267     if (ExitingCond.BoundSCEV->getType() !=
268         SplitCandidateCond.BoundSCEV->getType())
269       continue;
270 
271     // After transformation, we assume the split condition of the pre-loop is
272     // always true. In order to guarantee it, we need to check the start value
273     // of the split cond AddRec satisfies the split condition.
274     if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred,
275                                      SplitCandidateCond.AddRecSCEV->getStart(),
276                                      SplitCandidateCond.BoundSCEV))
277       continue;
278 
279     SplitCandidateCond.BI = BI;
280     return BI;
281   }
282 
283   return nullptr;
284 }
285 
286 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
287                            ScalarEvolution &SE, LPMUpdater &U) {
288   ConditionInfo SplitCandidateCond;
289   ConditionInfo ExitingCond;
290 
291   // Check we can split this loop's bound.
292   if (!canSplitLoopBound(L, DT, SE, ExitingCond))
293     return false;
294 
295   if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
296     return false;
297 
298   if (!isProfitableToTransform(L, SplitCandidateCond.BI))
299     return false;
300 
301   // Now, we have a split candidate. Let's build a form as below.
302   //    +--------------------+
303   //    |     preheader      |
304   //    |  set up newbound   |
305   //    +--------------------+
306   //             |     /----------------\
307   //    +--------v----v------+          |
308   //    |      header        |---\      |
309   //    | with true condition|   |      |
310   //    +--------------------+   |      |
311   //             |               |      |
312   //    +--------v-----------+   |      |
313   //    |     if.then.BB     |   |      |
314   //    +--------------------+   |      |
315   //             |               |      |
316   //    +--------v-----------<---/      |
317   //    |       latch        >----------/
318   //    |   with newbound    |
319   //    +--------------------+
320   //             |
321   //    +--------v-----------+
322   //    |     preheader2     |--------------\
323   //    | if (AddRec i !=    |              |
324   //    |     org bound)     |              |
325   //    +--------------------+              |
326   //             |     /----------------\   |
327   //    +--------v----v------+          |   |
328   //    |      header2       |---\      |   |
329   //    | conditional branch |   |      |   |
330   //    |with false condition|   |      |   |
331   //    +--------------------+   |      |   |
332   //             |               |      |   |
333   //    +--------v-----------+   |      |   |
334   //    |    if.then.BB2     |   |      |   |
335   //    +--------------------+   |      |   |
336   //             |               |      |   |
337   //    +--------v-----------<---/      |   |
338   //    |       latch2       >----------/   |
339   //    |   with org bound   |              |
340   //    +--------v-----------+              |
341   //             |                          |
342   //             |  +---------------+       |
343   //             +-->     exit      <-------/
344   //                +---------------+
345 
346   // Let's create post loop.
347   SmallVector<BasicBlock *, 8> PostLoopBlocks;
348   Loop *PostLoop;
349   ValueToValueMapTy VMap;
350   BasicBlock *PreHeader = L.getLoopPreheader();
351   BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);
352   PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,
353                                     ".split", &LI, &DT, PostLoopBlocks);
354   remapInstructionsInBlocks(PostLoopBlocks, VMap);
355 
356   // Add conditional branch to check we can skip post-loop in its preheader.
357   BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
358   IRBuilder<> Builder(PostLoopPreHeader);
359   Instruction *OrigBI = PostLoopPreHeader->getTerminator();
360   ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
361   Value *Cond =
362       Builder.CreateICmp(Pred, ExitingCond.AddRecValue, ExitingCond.BoundValue);
363   Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());
364   OrigBI->eraseFromParent();
365 
366   // Create new loop bound and add it into preheader of pre-loop.
367   const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
368   const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
369   NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)
370                      ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)
371                      : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);
372 
373   SCEVExpander Expander(
374       SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split");
375   Instruction *InsertPt = SplitLoopPH->getTerminator();
376   Value *NewBoundValue =
377       Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);
378   NewBoundValue->setName("new.bound");
379 
380   // Replace exiting bound value of pre-loop NewBound.
381   ExitingCond.ICmp->setOperand(1, NewBoundValue);
382 
383   // Replace IV's start value of post-loop by NewBound.
384   for (PHINode &PN : L.getHeader()->phis()) {
385     // Find PHI with exiting condition from pre-loop.
386     if (SE.isSCEVable(PN.getType()) && isa<SCEVAddRecExpr>(SE.getSCEV(&PN))) {
387       for (Value *Op : PN.incoming_values()) {
388         if (Op == ExitingCond.AddRecValue) {
389           // Find cloned PHI for post-loop.
390           PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);
391           PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader,
392                                                NewBoundValue);
393         }
394       }
395     }
396   }
397 
398   // Replace SplitCandidateCond.BI's condition of pre-loop by True.
399   LLVMContext &Context = PreHeader->getContext();
400   SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
401 
402   // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
403   BranchInst *ClonedSplitCandidateBI =
404       cast<BranchInst>(VMap[SplitCandidateCond.BI]);
405   ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
406 
407   // Replace exit branch target of pre-loop by post-loop's preheader.
408   if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))
409     ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);
410   else
411     ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
412 
413   // Update phi node in exit block of post-loop.
414   for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
415     for (auto i : seq<int>(0, PN.getNumOperands())) {
416       // Check incoming block is pre-loop's exiting block.
417       if (PN.getIncomingBlock(i) == L.getExitingBlock()) {
418         // Replace pre-loop's exiting block by post-loop's preheader.
419         PN.setIncomingBlock(i, PostLoopPreHeader);
420         // Add a new incoming value with post-loop's exiting block.
421         PN.addIncoming(VMap[PN.getIncomingValue(i)],
422                        PostLoop->getExitingBlock());
423       }
424     }
425   }
426 
427   // Update dominator tree.
428   DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());
429   DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);
430 
431   // Invalidate cached SE information.
432   SE.forgetLoop(&L);
433 
434   // Canonicalize loops.
435   // TODO: Try to update LCSSA information according to above change.
436   formLCSSA(L, DT, &LI, &SE);
437   simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);
438   formLCSSA(*PostLoop, DT, &LI, &SE);
439   simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);
440 
441   // Add new post-loop to loop pass manager.
442   U.addSiblingLoops(PostLoop);
443 
444   return true;
445 }
446 
447 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
448                                           LoopStandardAnalysisResults &AR,
449                                           LPMUpdater &U) {
450   Function &F = *L.getHeader()->getParent();
451   (void)F;
452 
453   LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
454                     << "\n");
455 
456   if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U))
457     return PreservedAnalyses::all();
458 
459   assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
460   AR.LI.verify(AR.DT);
461 
462   return getLoopPassPreservedAnalyses();
463 }
464 
465 } // end namespace llvm
466