xref: /llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp (revision a19ae77d2a9016428fee7cd5af03fd20ad6d4464)
1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/Statistic.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/Twine.h"
22 #include "llvm/ADT/iterator_range.h"
23 #include "llvm/Analysis/AssumptionCache.h"
24 #include "llvm/Analysis/DependenceAnalysis.h"
25 #include "llvm/Analysis/DomTreeUpdater.h"
26 #include "llvm/Analysis/LoopInfo.h"
27 #include "llvm/Analysis/LoopIterator.h"
28 #include "llvm/Analysis/MustExecute.h"
29 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
30 #include "llvm/Analysis/ScalarEvolution.h"
31 #include "llvm/IR/BasicBlock.h"
32 #include "llvm/IR/DebugInfoMetadata.h"
33 #include "llvm/IR/DebugLoc.h"
34 #include "llvm/IR/DiagnosticInfo.h"
35 #include "llvm/IR/Dominators.h"
36 #include "llvm/IR/Function.h"
37 #include "llvm/IR/Instruction.h"
38 #include "llvm/IR/Instructions.h"
39 #include "llvm/IR/IntrinsicInst.h"
40 #include "llvm/IR/User.h"
41 #include "llvm/IR/Value.h"
42 #include "llvm/IR/ValueHandle.h"
43 #include "llvm/IR/ValueMap.h"
44 #include "llvm/Support/Casting.h"
45 #include "llvm/Support/Debug.h"
46 #include "llvm/Support/ErrorHandling.h"
47 #include "llvm/Support/GenericDomTree.h"
48 #include "llvm/Support/raw_ostream.h"
49 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
50 #include "llvm/Transforms/Utils/Cloning.h"
51 #include "llvm/Transforms/Utils/LoopUtils.h"
52 #include "llvm/Transforms/Utils/UnrollLoop.h"
53 #include "llvm/Transforms/Utils/ValueMapper.h"
54 #include <assert.h>
55 #include <memory>
56 #include <type_traits>
57 #include <vector>
58 
59 using namespace llvm;
60 
61 #define DEBUG_TYPE "loop-unroll-and-jam"
62 
63 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
64 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops fully unroll and jammed");
65 
66 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
67 
68 // Partition blocks in an outer/inner loop pair into blocks before and after
69 // the inner loop.
70 static bool partitionLoopBlocks(Loop &L, BasicBlockSet &ForeBlocks,
71                                 BasicBlockSet &AftBlocks, DominatorTree &DT) {
72   Loop *SubLoop = L.getSubLoops()[0];
73   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
74 
75   for (BasicBlock *BB : L.blocks()) {
76     if (!SubLoop->contains(BB)) {
77       if (DT.dominates(SubLoopLatch, BB))
78         AftBlocks.insert(BB);
79       else
80         ForeBlocks.insert(BB);
81     }
82   }
83 
84   // Check that all blocks in ForeBlocks together dominate the subloop
85   // TODO: This might ideally be done better with dominators/postdominators.
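  // For example, a Fore block that branched directly out of the loop would have
  // a successor outside ForeBlocks, so the loop below would reject the partition.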
86   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
87   for (BasicBlock *BB : ForeBlocks) {
88     if (BB == SubLoopPreHeader)
89       continue;
90     Instruction *TI = BB->getTerminator();
91     for (BasicBlock *Succ : successors(TI))
92       if (!ForeBlocks.count(Succ))
93         return false;
94   }
95 
96   return true;
97 }
98 
99 /// Partition blocks in a loop nest into blocks before and after each inner
100 /// loop.
101 static bool partitionOuterLoopBlocks(
102     Loop &Root, Loop &JamLoop, BasicBlockSet &JamLoopBlocks,
103     DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
104     DenseMap<Loop *, BasicBlockSet> &AftBlocksMap, DominatorTree &DT) {
105   JamLoopBlocks.insert(JamLoop.block_begin(), JamLoop.block_end());
106 
107   for (Loop *L : Root.getLoopsInPreorder()) {
108     if (L == &JamLoop)
109       break;
110 
111     if (!partitionLoopBlocks(*L, ForeBlocksMap[L], AftBlocksMap[L], DT))
112       return false;
113   }
114 
115   return true;
116 }
117 
118 // TODO: Remove when UnrollAndJamLoop is changed to support unroll and jamming
119 // more than 2 levels of loops.
120 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
121                                      BasicBlockSet &ForeBlocks,
122                                      BasicBlockSet &SubLoopBlocks,
123                                      BasicBlockSet &AftBlocks,
124                                      DominatorTree *DT) {
125   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
126   return partitionLoopBlocks(*L, ForeBlocks, AftBlocks, *DT);
127 }
128 
129 // Looks at the phi nodes in Header for values coming from Latch. Calls Visit
130 // on these instructions and their operands, continuing into the operands of
131 // any instruction that lies in AftBlocks. Returns false if Visit returns
132 // false, otherwise returns true. This is used to process the instructions in
133 // the Aft blocks that need to be moved before the subloop. It is used in two
134 // places: first to check that the required set of instructions can be moved
135 // before the loop, and then to collect the instructions to actually move in
136 // moveHeaderPhiOperandsToForeBlocks.
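// For example, given a header phi such as
//   %iv = phi i64 [ 0, %preheader ], [ %iv.next, %latch ]
// the walk starts at %iv.next (typically defined in the latch, an Aft block)
// and keeps following operands while they are defined inside AftBlocks.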
137 template <typename T>
138 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
139                                      BasicBlockSet &AftBlocks, T Visit) {
140   SmallVector<Instruction *, 8> Worklist;
141   SmallPtrSet<Instruction *, 8> VisitedInstr;
142   for (auto &Phi : Header->phis()) {
143     Value *V = Phi.getIncomingValueForBlock(Latch);
144     if (Instruction *I = dyn_cast<Instruction>(V))
145       Worklist.push_back(I);
146   }
147 
148   while (!Worklist.empty()) {
149     Instruction *I = Worklist.pop_back_val();
150     if (!Visit(I))
151       return false;
152     VisitedInstr.insert(I);
153 
154     if (AftBlocks.count(I->getParent()))
155       for (auto &U : I->operands())
156         if (Instruction *II = dyn_cast<Instruction>(U))
157           if (!VisitedInstr.count(II))
158             Worklist.push_back(II);
159   }
160 
161   return true;
162 }
163 
164 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
165 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
166                                               BasicBlock *Latch,
167                                               Instruction *InsertLoc,
168                                               BasicBlockSet &AftBlocks) {
169   // We need to ensure we move the instructions in the correct order,
170   // starting with the earliest required instruction and moving forward.
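  // A typical candidate is the outer induction update (e.g. %iv.next = add %iv, 1),
  // which is defined in the latch (an Aft block) but must be available before the
  // subloop once the Fore blocks of all unrolled iterations are placed together.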
171   std::vector<Instruction *> Visited;
172   processHeaderPhiOperands(Header, Latch, AftBlocks,
173                            [&Visited, &AftBlocks](Instruction *I) {
174                              if (AftBlocks.count(I->getParent()))
175                                Visited.push_back(I);
176                              return true;
177                            });
178 
179   // Move all instructions in program order to before the InsertLoc
180   BasicBlock *InsertLocBB = InsertLoc->getParent();
181   for (Instruction *I : reverse(Visited)) {
182     if (I->getParent() != InsertLocBB)
183       I->moveBefore(InsertLoc);
184   }
185 }
186 
187 /*
188   This method performs Unroll and Jam. For a simple loop like:
189   for (i = ..)
190     Fore(i)
191     for (j = ..)
192       SubLoop(i, j)
193     Aft(i)
194 
195   Instead of doing normal inner or outer unrolling, we do:
196   for (i = .., i+=2)
197     Fore(i)
198     Fore(i+1)
199     for (j = ..)
200       SubLoop(i, j)
201       SubLoop(i+1, j)
202     Aft(i)
203     Aft(i+1)
204 
205   So the outer loop is essentially unrolled and then the inner loops are fused
206   ("jammed") together into a single loop. This can increase speed when there
207   are loads in SubLoop that are invariant to i, as they become shared between
208   the now jammed inner loops.
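
  For instance, in a nest like
    for (i = ..)
      for (j = ..)
        sum[i] += A[i][j] * B[j];
  both copies of the inner body produced by unroll and jamming by 2 read the
  same B[j], which does not depend on i, so that load can be shared within the
  jammed inner loop.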
209 
210   We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
211   Fore blocks are those before the inner loop, Aft are those after. Normal
212   Unroll code is used to copy each of these sets of blocks and the results are
213   combined together into the final form above.
214 
215   isSafeToUnrollAndJam should be used prior to calling this to make sure the
216   unrolling will be valid. Checking profitability is also advisable.
217 
218   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
219   necessary to create one and not fully unrolled).
220 */
221 LoopUnrollResult
222 llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
223                        unsigned TripMultiple, bool UnrollRemainder,
224                        LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
225                        AssumptionCache *AC, const TargetTransformInfo *TTI,
226                        OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {
227 
228   // When we enter here we should have already checked that it is safe
229   BasicBlock *Header = L->getHeader();
230   assert(Header && "No header.");
231   assert(L->getSubLoops().size() == 1);
232   Loop *SubLoop = *L->begin();
233 
234   // Don't enter the unroll code if there is nothing to do.
235   if (TripCount == 0 && Count < 2) {
236     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
237     return LoopUnrollResult::Unmodified;
238   }
239 
240   assert(Count > 0);
241   assert(TripMultiple > 0);
242   assert(TripCount == 0 || TripCount % TripMultiple == 0);
243 
244   // Are we eliminating the loop control altogether?
245   bool CompletelyUnroll = (Count == TripCount);
246 
247   // We use the runtime remainder in cases where we don't know the trip multiple.
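  // For example, with Count == 4 and a trip count only known at run time, this
  // emits an epilogue loop that executes the final TripCount % 4 iterations, so
  // the unroll-and-jammed body only ever runs whole groups of 4 outer iterations.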
248   if (TripMultiple % Count != 0) {
249     if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
250                                     /*UseEpilogRemainder*/ true,
251                                     UnrollRemainder, /*ForgetAllSCEV*/ false,
252                                     LI, SE, DT, AC, TTI, true, EpilogueLoop)) {
253       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
254                            "generated when assuming runtime trip count\n");
255       return LoopUnrollResult::Unmodified;
256     }
257   }
258 
259   // Notify ScalarEvolution that the loop will be substantially changed,
260   // if not outright eliminated.
261   if (SE) {
262     SE->forgetLoop(L);
263     SE->forgetLoop(SubLoop);
264   }
265 
266   using namespace ore;
267   // Report the unrolling decision.
268   if (CompletelyUnroll) {
269     LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
270                       << Header->getName() << " with trip count " << TripCount
271                       << "!\n");
272     ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
273                                  L->getHeader())
274               << "completely unroll and jammed loop with "
275               << NV("UnrollCount", TripCount) << " iterations");
276   } else {
277     auto DiagBuilder = [&]() {
278       OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
279                               L->getHeader());
280       return Diag << "unroll and jammed loop by a factor of "
281                   << NV("UnrollCount", Count);
282     };
283 
284     LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
285                       << " by " << Count);
286     if (TripMultiple != 1) {
287       LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
288       ORE->emit([&]() {
289         return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
290                              << " trips per branch";
291       });
292     } else {
293       LLVM_DEBUG(dbgs() << " with run-time trip count");
294       ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
295     }
296     LLVM_DEBUG(dbgs() << "!\n");
297   }
298 
299   BasicBlock *Preheader = L->getLoopPreheader();
300   BasicBlock *LatchBlock = L->getLoopLatch();
301   assert(Preheader && "No preheader");
302   assert(LatchBlock && "No latch block");
303   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
304   assert(BI && !BI->isUnconditional());
305   bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
306   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
307   bool SubLoopContinueOnTrue = SubLoop->contains(
308       SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));
309 
310   // Partition blocks in an outer/inner loop pair into blocks before and after
311   // the inner loop.
312   BasicBlockSet SubLoopBlocks;
313   BasicBlockSet ForeBlocks;
314   BasicBlockSet AftBlocks;
315   partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
316                            DT);
317 
318   // We keep track of the entering/first and exiting/last block of each of
319   // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
320   // blocks easier.
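  // For example, with Count == 3, ForeBlocksFirst/ForeBlocksLast end up holding
  // the entry block and the subloop preheader of each of the three Fore copies,
  // so copy It can later be chained onto copy It-1 by retargeting a single branch.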
321   std::vector<BasicBlock *> ForeBlocksFirst;
322   std::vector<BasicBlock *> ForeBlocksLast;
323   std::vector<BasicBlock *> SubLoopBlocksFirst;
324   std::vector<BasicBlock *> SubLoopBlocksLast;
325   std::vector<BasicBlock *> AftBlocksFirst;
326   std::vector<BasicBlock *> AftBlocksLast;
327   ForeBlocksFirst.push_back(Header);
328   ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
329   SubLoopBlocksFirst.push_back(SubLoop->getHeader());
330   SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
331   AftBlocksFirst.push_back(SubLoop->getExitBlock());
332   AftBlocksLast.push_back(L->getExitingBlock());
333   // Maps Blocks[0] -> Blocks[It]
334   ValueToValueMapTy LastValueMap;
335 
336   // Move any instructions required by the fore header phi operands out of AftBlocks into Fore.
337   moveHeaderPhiOperandsToForeBlocks(
338       Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);
339 
340   // The current on-the-fly SSA update requires blocks to be processed in
341   // reverse postorder so that LastValueMap contains the correct value at each
342   // exit.
343   LoopBlocksDFS DFS(L);
344   DFS.perform(LI);
345   // Stash the DFS iterators before adding blocks to the loop.
346   LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
347   LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
348 
349   // When a FSDiscriminator is enabled, we don't need to add the multiply
350   // factors to the discriminators.
351   if (Header->getParent()->isDebugInfoForProfiling() && !EnableFSDiscriminator)
352     for (BasicBlock *BB : L->getBlocks())
353       for (Instruction &I : *BB)
354         if (!isa<DbgInfoIntrinsic>(&I))
355           if (const DILocation *DIL = I.getDebugLoc()) {
356             auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
357             if (NewDIL)
358               I.setDebugLoc(*NewDIL);
359             else
360               LLVM_DEBUG(dbgs()
361                          << "Failed to create new discriminator: "
362                          << DIL->getFilename() << " Line: " << DIL->getLine());
363           }
364 
365   // Copy all blocks
366   for (unsigned It = 1; It != Count; ++It) {
367     SmallVector<BasicBlock *, 8> NewBlocks;
368     // Maps Blocks[It] -> Blocks[It-1]
369     DenseMap<Value *, Value *> PrevItValueMap;
370     SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
371     NewLoops[L] = L;
372     NewLoops[SubLoop] = SubLoop;
373 
374     for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
375       ValueToValueMapTy VMap;
376       BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
377       Header->getParent()->insertBasicBlockAt(Header->getParent()->end(), New);
378 
379       // Tell LI about New.
380       addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);
381 
382       if (ForeBlocks.count(*BB)) {
383         if (*BB == ForeBlocksFirst[0])
384           ForeBlocksFirst.push_back(New);
385         if (*BB == ForeBlocksLast[0])
386           ForeBlocksLast.push_back(New);
387       } else if (SubLoopBlocks.count(*BB)) {
388         if (*BB == SubLoopBlocksFirst[0])
389           SubLoopBlocksFirst.push_back(New);
390         if (*BB == SubLoopBlocksLast[0])
391           SubLoopBlocksLast.push_back(New);
392       } else if (AftBlocks.count(*BB)) {
393         if (*BB == AftBlocksFirst[0])
394           AftBlocksFirst.push_back(New);
395         if (*BB == AftBlocksLast[0])
396           AftBlocksLast.push_back(New);
397       } else {
398         llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
399       }
400 
401       // Update our running maps of newest clones
402       PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
403       LastValueMap[*BB] = New;
404       for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
405            VI != VE; ++VI) {
406         PrevItValueMap[VI->second] =
407             const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
408         LastValueMap[VI->first] = VI->second;
409       }
410 
411       NewBlocks.push_back(New);
412 
413       // Update DomTree:
414       if (*BB == ForeBlocksFirst[0])
415         DT->addNewBlock(New, ForeBlocksLast[It - 1]);
416       else if (*BB == SubLoopBlocksFirst[0])
417         DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
418       else if (*BB == AftBlocksFirst[0])
419         DT->addNewBlock(New, AftBlocksLast[It - 1]);
420       else {
421         // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
422         // structure.
423         auto BBDomNode = DT->getNode(*BB);
424         auto BBIDom = BBDomNode->getIDom();
425         BasicBlock *OriginalBBIDom = BBIDom->getBlock();
426         assert(OriginalBBIDom);
427         assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
428         DT->addNewBlock(
429             New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
430       }
431     }
432 
433     // Remap all instructions in the most recent iteration
434     remapInstructionsInBlocks(NewBlocks, LastValueMap);
435     for (BasicBlock *NewBlock : NewBlocks) {
436       for (Instruction &I : *NewBlock) {
437         if (auto *II = dyn_cast<AssumeInst>(&I))
438           AC->registerAssumption(II);
439       }
440     }
441 
442     // Alter the ForeBlocks phis, pointing them at the latest version of the
443     // value from the previous iteration's phis
444     for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
445       Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
446       assert(OldValue && "should have incoming edge from Aft[It]");
447       Value *NewValue = OldValue;
448       if (Value *PrevValue = PrevItValueMap[OldValue])
449         NewValue = PrevValue;
450 
451       assert(Phi.getNumOperands() == 2);
452       Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
453       Phi.setIncomingValue(0, NewValue);
454       Phi.removeIncomingValue(1);
455     }
456   }
457 
458   // Now that all the basic blocks for the unrolled iterations are in place,
459   // finish up connecting the blocks and phi nodes. At this point LastValueMap
460   // holds the values from the last unrolled iteration.
461 
462   // Update Phis in BB from OldBB to point to NewBB and use the latest value
463   // from LastValueMap
464   auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
465                                      BasicBlock *NewBB,
466                                      ValueToValueMapTy &LastValueMap) {
467     for (PHINode &Phi : BB->phis()) {
468       for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
469         if (Phi.getIncomingBlock(b) == OldBB) {
470           Value *OldValue = Phi.getIncomingValue(b);
471           if (Value *LastValue = LastValueMap[OldValue])
472             Phi.setIncomingValue(b, LastValue);
473           Phi.setIncomingBlock(b, NewBB);
474           break;
475         }
476       }
477     }
478   };
479   // Move all the phis from Src into Dest
480   auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
481     Instruction *insertPoint = Dest->getFirstNonPHI();
482     while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
483       Phi->moveBefore(insertPoint);
484   };
485 
486   // Update the PHI values outside the loop to point to the last block
487   updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
488                            LastValueMap);
489 
490   // Update ForeBlocks successors and phi nodes
491   BranchInst *ForeTerm =
492       cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
493   assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
494   ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);
495 
496   if (CompletelyUnroll) {
497     while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
498       Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
499       Phi->eraseFromParent();
500     }
501   } else {
502     // Update the PHI values to point to the last aft block
503     updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
504                              AftBlocksLast.back(), LastValueMap);
505   }
506 
507   for (unsigned It = 1; It != Count; It++) {
508     // Remap ForeBlock successors from previous iteration to this
509     BranchInst *ForeTerm =
510         cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
511     assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
512     ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
513   }
514 
515   // Subloop successors and phis
516   BranchInst *SubTerm =
517       cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
518   SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
519   SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
520   SubLoopBlocksFirst[0]->replacePhiUsesWith(ForeBlocksLast[0],
521                                             ForeBlocksLast.back());
522   SubLoopBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
523                                             SubLoopBlocksLast.back());
524 
525   for (unsigned It = 1; It != Count; It++) {
526     // Replace the conditional branch of the previous iteration subloop with an
527     // unconditional one to this one
528     BranchInst *SubTerm =
529         cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
530     BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
531     SubTerm->eraseFromParent();
532 
533     SubLoopBlocksFirst[It]->replacePhiUsesWith(ForeBlocksLast[It],
534                                                ForeBlocksLast.back());
535     SubLoopBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
536                                                SubLoopBlocksLast.back());
537     movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
538   }
539 
540   // Aft blocks successors and phis
541   BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
542   if (CompletelyUnroll) {
543     BranchInst::Create(LoopExit, AftTerm);
544     AftTerm->eraseFromParent();
545   } else {
546     AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
547     assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
548            "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
549   }
550   AftBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
551                                         SubLoopBlocksLast.back());
552 
553   for (unsigned It = 1; It != Count; It++) {
554     // Replace the conditional branch at the end of the previous iteration's aft
555     // blocks with an unconditional branch to this iteration's aft blocks
556     BranchInst *AftTerm =
557         cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
558     BranchInst::Create(AftBlocksFirst[It], AftTerm);
559     AftTerm->eraseFromParent();
560 
561     AftBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
562                                            SubLoopBlocksLast.back());
563     movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
564   }
565 
566   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
567   // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
568   // new ones required.
569   if (Count != 1) {
570     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
571     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
572                            SubLoopBlocksFirst[0]);
573     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
574                            SubLoopBlocksLast[0], AftBlocksFirst[0]);
575 
576     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
577                            ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
578     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
579                            SubLoopBlocksLast.back(), AftBlocksFirst[0]);
580     DTU.applyUpdatesPermissive(DTUpdates);
581   }
582 
583   // Merge adjacent basic blocks, if possible.
584   SmallPtrSet<BasicBlock *, 16> MergeBlocks;
585   MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
586   MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
587   MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
588 
589   MergeBlockSuccessorsIntoGivenBlocks(MergeBlocks, L, &DTU, LI);
590 
591   // Apply updates to the DomTree.
592   DT = &DTU.getDomTree();
593 
594   // At this point, the code is well formed.  We now do a quick sweep over the
595   // inserted code, doing constant propagation and dead code elimination as we
596   // go.
597   simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
598   simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
599                           TTI);
600 
601   NumCompletelyUnrolledAndJammed += CompletelyUnroll;
602   ++NumUnrolledAndJammed;
603 
604   // Update LoopInfo if the loop is completely removed.
605   if (CompletelyUnroll)
606     LI->erase(L);
607 
608 #ifndef NDEBUG
609   // We shouldn't have done anything to break loop simplify form or LCSSA.
610   Loop *OutestLoop = SubLoop->getParentLoop()
611                          ? SubLoop->getParentLoop()->getParentLoop()
612                                ? SubLoop->getParentLoop()->getParentLoop()
613                                : SubLoop->getParentLoop()
614                          : SubLoop;
615   assert(DT->verify());
616   LI->verify(*DT);
617   assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
618   if (!CompletelyUnroll)
619     assert(L->isLoopSimplifyForm());
620   assert(SubLoop->isLoopSimplifyForm());
621   SE->verify();
622 #endif
623 
624   return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
625                           : LoopUnrollResult::PartiallyUnrolled;
626 }
627 
628 static bool getLoadsAndStores(BasicBlockSet &Blocks,
629                               SmallVector<Instruction *, 4> &MemInstr) {
630   // Scan the BBs and collect legal loads and stores.
631   // Returns false on non-simple loads/stores or any other memory-accessing instructions.
632   for (BasicBlock *BB : Blocks) {
633     for (Instruction &I : *BB) {
634       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
635         if (!Ld->isSimple())
636           return false;
637         MemInstr.push_back(&I);
638       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
639         if (!St->isSimple())
640           return false;
641         MemInstr.push_back(&I);
642       } else if (I.mayReadOrWriteMemory()) {
643         return false;
644       }
645     }
646   }
647   return true;
648 }
649 
650 static bool preservesForwardDependence(Instruction *Src, Instruction *Dst,
651                                        unsigned UnrollLevel, unsigned JamLevel,
652                                        bool Sequentialized, Dependence *D) {
653   // UnrollLevel might carry the dependency Src --> Dst.
654   // Does a different loop carry it after unrolling?
655   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
656        ++CurLoopDepth) {
657     auto JammedDir = D->getDirection(CurLoopDepth);
658     if (JammedDir == Dependence::DVEntry::LT)
659       return true;
660 
661     if (JammedDir & Dependence::DVEntry::GT)
662       return false;
663   }
664 
665   return true;
666 }
667 
668 static bool preservesBackwardDependence(Instruction *Src, Instruction *Dst,
669                                         unsigned UnrollLevel, unsigned JamLevel,
670                                         bool Sequentialized, Dependence *D) {
671   // UnrollLevel might carry the dependency Dst --> Src
672   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
673        ++CurLoopDepth) {
674     auto JammedDir = D->getDirection(CurLoopDepth);
675     if (JammedDir == Dependence::DVEntry::GT)
676       return true;
677 
678     if (JammedDir & Dependence::DVEntry::LT)
679       return false;
680   }
681 
682   // Backward dependencies are only preserved if not interleaved.
683   return Sequentialized;
684 }
685 
686 // Check whether unroll-and-jam is semantically safe for Src and Dst,
687 // considering any potential dependency between them.
688 //
689 // @param UnrollLevel The level of the loop being unrolled
690 // @param JamLevel    The level of the loop being jammed; if Src and Dst are on
691 // different levels, the outermost common loop counts as jammed level
692 //
693 // @return true if it is safe and false if there is a dependency violation.
694 static bool checkDependency(Instruction *Src, Instruction *Dst,
695                             unsigned UnrollLevel, unsigned JamLevel,
696                             bool Sequentialized, DependenceInfo &DI) {
697   assert(UnrollLevel <= JamLevel &&
698          "Expecting JamLevel to be at least UnrollLevel");
699 
700   if (Src == Dst)
701     return true;
702   // Ignore Input dependencies.
703   if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
704     return true;
705 
706   // Check whether unroll-and-jam may violate a dependency.
707   // By construction, every dependency will be lexicographically non-negative
708   // (if it were not, it would violate the current execution order), such as
709   //   (0,0,>,*,*)
710   // Unroll-and-jam places executions that differ (GT) at the chosen unroll
711   // level into the same iteration of that level. That is, a GT dependence becomes a GE
712   // dependence (or EQ, if we fully unrolled the loop) at the loop's position:
713   //   (0,0,>=,*,*)
714   // Now, the dependency is not necessarily non-negative anymore, i.e.
715   // unroll-and-jam may violate correctness.
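  // For instance, for a jammed nest containing A[i][j] = f(A[i-1][j+1]) the
  // write-to-read dependence has direction (<, >): carried forward by the
  // unrolled i loop but backward by the jammed j loop. After jamming, the copy
  // for outer iteration i would read A[i-1][j+1] during inner iteration j,
  // before the copy for iteration i-1 writes it during inner iteration j+1,
  // so the checks below must reject this case.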
716   std::unique_ptr<Dependence> D = DI.depends(Src, Dst, true);
717   if (!D)
718     return true;
719   assert(D->isOrdered() && "Expected an output, flow or anti dep.");
720 
721   if (D->isConfused()) {
722     LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
723                       << "  " << *Src << "\n"
724                       << "  " << *Dst << "\n");
725     return false;
726   }
727 
728   // If outer levels (levels enclosing the loop being unroll-and-jammed) have a
729   // non-equal direction, then the locations accessed in the inner levels cannot
730   // overlap in memory. We assume the indices never overlap into neighboring
731   // dimensions.
732   for (unsigned CurLoopDepth = 1; CurLoopDepth < UnrollLevel; ++CurLoopDepth)
733     if (!(D->getDirection(CurLoopDepth) & Dependence::DVEntry::EQ))
734       return true;
735 
736   auto UnrollDirection = D->getDirection(UnrollLevel);
737 
738   // If the distance carried by the unrolled loop is 0, then after unrolling
739   // that distance will become non-zero resulting in non-overlapping accesses in
740   // the inner loops.
741   if (UnrollDirection == Dependence::DVEntry::EQ)
742     return true;
743 
744   if (UnrollDirection & Dependence::DVEntry::LT &&
745       !preservesForwardDependence(Src, Dst, UnrollLevel, JamLevel,
746                                   Sequentialized, D.get()))
747     return false;
748 
749   if (UnrollDirection & Dependence::DVEntry::GT &&
750       !preservesBackwardDependence(Src, Dst, UnrollLevel, JamLevel,
751                                    Sequentialized, D.get()))
752     return false;
753 
754   return true;
755 }
756 
757 static bool
758 checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
759                   const DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
760                   const DenseMap<Loop *, BasicBlockSet> &AftBlocksMap,
761                   DependenceInfo &DI, LoopInfo &LI) {
762   SmallVector<BasicBlockSet, 8> AllBlocks;
763   for (Loop *L : Root.getLoopsInPreorder())
764     if (ForeBlocksMap.find(L) != ForeBlocksMap.end())
765       AllBlocks.push_back(ForeBlocksMap.lookup(L));
766   AllBlocks.push_back(SubLoopBlocks);
767   for (Loop *L : Root.getLoopsInPreorder())
768     if (AftBlocksMap.find(L) != AftBlocksMap.end())
769       AllBlocks.push_back(AftBlocksMap.lookup(L));
770 
771   unsigned LoopDepth = Root.getLoopDepth();
772   SmallVector<Instruction *, 4> EarlierLoadsAndStores;
773   SmallVector<Instruction *, 4> CurrentLoadsAndStores;
774   for (BasicBlockSet &Blocks : AllBlocks) {
775     CurrentLoadsAndStores.clear();
776     if (!getLoadsAndStores(Blocks, CurrentLoadsAndStores))
777       return false;
778 
779     Loop *CurLoop = LI.getLoopFor((*Blocks.begin())->front().getParent());
780     unsigned CurLoopDepth = CurLoop->getLoopDepth();
781 
782     for (auto *Earlier : EarlierLoadsAndStores) {
783       Loop *EarlierLoop = LI.getLoopFor(Earlier->getParent());
784       unsigned EarlierDepth = EarlierLoop->getLoopDepth();
785       unsigned CommonLoopDepth = std::min(EarlierDepth, CurLoopDepth);
786       for (auto *Later : CurrentLoadsAndStores) {
787         if (!checkDependency(Earlier, Later, LoopDepth, CommonLoopDepth, false,
788                              DI))
789           return false;
790       }
791     }
792 
793     size_t NumInsts = CurrentLoadsAndStores.size();
794     for (size_t I = 0; I < NumInsts; ++I) {
795       for (size_t J = I; J < NumInsts; ++J) {
796         if (!checkDependency(CurrentLoadsAndStores[I], CurrentLoadsAndStores[J],
797                              LoopDepth, CurLoopDepth, true, DI))
798           return false;
799       }
800     }
801 
802     EarlierLoadsAndStores.append(CurrentLoadsAndStores.begin(),
803                                  CurrentLoadsAndStores.end());
804   }
805   return true;
806 }
807 
808 static bool isEligibleLoopForm(const Loop &Root) {
809   // Root must have exactly one child loop.
810   if (Root.getSubLoops().size() != 1)
811     return false;
812 
813   const Loop *L = &Root;
814   do {
815     // All loops in Root need to be in simplify and rotated form.
816     if (!L->isLoopSimplifyForm())
817       return false;
818 
819     if (!L->isRotatedForm())
820       return false;
821 
822     if (L->getHeader()->hasAddressTaken()) {
823       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
824       return false;
825     }
826 
827     unsigned SubLoopsSize = L->getSubLoops().size();
828     if (SubLoopsSize == 0)
829       return true;
830 
831     // Only one child is allowed.
832     if (SubLoopsSize != 1)
833       return false;
834 
835     // Only loops with a single exit block can be unrolled and jammed.
836     // The function getExitBlock() is used for this check, rather than
837     // getUniqueExitBlock() to ensure loops with multiple exit edges are
838     // disallowed.
839     if (!L->getExitBlock()) {
840       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single exit "
841                            "blocks can be unrolled and jammed.\n");
842       return false;
843     }
844 
845     // Only loops with a single exiting block can be unrolled and jammed.
846     if (!L->getExitingBlock()) {
847       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single "
848                            "exiting blocks can be unrolled and jammed.\n");
849       return false;
850     }
851 
852     L = L->getSubLoops()[0];
853   } while (L);
854 
855   return true;
856 }
857 
858 static Loop *getInnerMostLoop(Loop *L) {
859   while (!L->getSubLoops().empty())
860     L = L->getSubLoops()[0];
861   return L;
862 }
863 
864 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
865                                 DependenceInfo &DI, LoopInfo &LI) {
866   if (!isEligibleLoopForm(*L)) {
867     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Ineligible loop form\n");
868     return false;
869   }
870 
871   /* We currently handle outer loops like this:
872         |
873     ForeFirst    <------\   }
874      Blocks             |   } ForeBlocks of L
875     ForeLast            |   }
876         |               |
877        ...              |
878         |               |
879     ForeFirst    <----\ |   }
880      Blocks           | |   } ForeBlocks of an inner loop of L
881     ForeLast          | |   }
882         |             | |
883     JamLoopFirst  <\  | |   }
884      Blocks        |  | |   } JamLoopBlocks of the innermost loop
885     JamLoopLast   -/  | |   }
886         |             | |
887     AftFirst          | |   }
888      Blocks           | |   } AftBlocks of an inner loop of L
889     AftLast     ------/ |   }
890         |               |
891        ...              |
892         |               |
893     AftFirst            |   }
894      Blocks             |   } AftBlocks of L
895     AftLast     --------/   }
896         |
897 
898     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
899     and AftBlocks, provided that there is one edge from Fores to SubLoops,
900     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
901     In practice we currently limit Aft blocks to a single block, and limit
902     things further in the profitability checks of the unroll and jam pass.
903 
904     Because of the way we rearrange basic blocks, we also require that
905     the Fore blocks of L on all unrolled iterations are safe to move before the
906     blocks of the direct child of L of all iterations. So we require that the
907     phi node looping operands of ForeHeader can be moved to at least the end of
908     ForeEnd, so that we can arrange cloned Fore Blocks before the subloop and
909     match up Phis correctly.
910 
911     i.e. The old order of blocks used to be
912            (F1)1 (F2)1 J1_1 J1_2 (A2)1 (A1)1 (F1)2 (F2)2 J2_1 J2_2 (A2)2 (A1)2.
913          It needs to be safe to transform this to
914            (F1)1 (F1)2 (F2)1 (F2)2 J1_1 J1_2 J2_1 J2_2 (A2)1 (A2)2 (A1)1 (A1)2.
915 
916     There are then a number of checks along the lines of no calls, no
917     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
918     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
919     UnrollAndJamLoop if the trip count cannot be easily calculated.
920   */
921 
922   // Split blocks into Fore/SubLoop/Aft based on dominators
923   Loop *JamLoop = getInnerMostLoop(L);
924   BasicBlockSet SubLoopBlocks;
925   DenseMap<Loop *, BasicBlockSet> ForeBlocksMap;
926   DenseMap<Loop *, BasicBlockSet> AftBlocksMap;
927   if (!partitionOuterLoopBlocks(*L, *JamLoop, SubLoopBlocks, ForeBlocksMap,
928                                 AftBlocksMap, DT)) {
929     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
930     return false;
931   }
932 
933   // Aft blocks may need to move instructions to fore blocks, which becomes more
934   // difficult if there are multiple (potentially conditionally executed)
935   // blocks. For now we just exclude loops with multiple aft blocks.
936   if (AftBlocksMap[L].size() != 1) {
937     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
938                          "multiple blocks after the loop\n");
939     return false;
940   }
941 
942   // Check inner loop backedge count is consistent on all iterations of the
943   // outer loop
944   if (any_of(L->getLoopsInPreorder(), [&SE](Loop *SubLoop) {
945         return !hasIterationCountInvariantInParent(SubLoop, SE);
946       })) {
947     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
948                          "not consistent on each iteration\n");
949     return false;
950   }
951 
952   // Check the loop safety info for exceptions.
953   SimpleLoopSafetyInfo LSI;
954   LSI.computeLoopSafetyInfo(L);
955   if (LSI.anyBlockMayThrow()) {
956     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
957     return false;
958   }
959 
960   // We've ruled out the easy stuff and now need to check that there are no
961   // interdependencies which may prevent us from moving the:
962   //  ForeBlocks before Subloop and AftBlocks.
963   //  Subloop before AftBlocks.
964   //  ForeBlock phi operands before the subloop
965 
966   // Make sure we can move all instructions we need to before the subloop
967   BasicBlock *Header = L->getHeader();
968   BasicBlock *Latch = L->getLoopLatch();
969   BasicBlockSet AftBlocks = AftBlocksMap[L];
970   Loop *SubLoop = L->getSubLoops()[0];
971   if (!processHeaderPhiOperands(
972           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
973             if (SubLoop->contains(I->getParent()))
974               return false;
975             if (AftBlocks.count(I->getParent())) {
976               // If we hit a phi node in afts we know we are done (probably
977               // LCSSA)
978               if (isa<PHINode>(I))
979                 return false;
980               // Can't move instructions with side effects or memory
981               // reads/writes
982               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
983                 return false;
984             }
985             // Keep going
986             return true;
987           })) {
988     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
989                          "instructions after subloop to before it\n");
990     return false;
991   }
992 
993   // Check for memory dependencies which prohibit the unrolling we are doing.
994   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
995   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
996   if (!checkDependencies(*L, SubLoopBlocks, ForeBlocksMap, AftBlocksMap, DI,
997                          LI)) {
998     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
999     return false;
1000   }
1001 
1002   return true;
1003 }
1004