1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/Statistic.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/Twine.h"
22 #include "llvm/Analysis/AssumptionCache.h"
23 #include "llvm/Analysis/DependenceAnalysis.h"
24 #include "llvm/Analysis/DomTreeUpdater.h"
25 #include "llvm/Analysis/LoopInfo.h"
26 #include "llvm/Analysis/LoopIterator.h"
27 #include "llvm/Analysis/MustExecute.h"
28 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
29 #include "llvm/Analysis/ScalarEvolution.h"
30 #include "llvm/IR/BasicBlock.h"
31 #include "llvm/IR/DebugInfoMetadata.h"
32 #include "llvm/IR/DebugLoc.h"
33 #include "llvm/IR/DiagnosticInfo.h"
34 #include "llvm/IR/Dominators.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/Instruction.h"
37 #include "llvm/IR/Instructions.h"
38 #include "llvm/IR/IntrinsicInst.h"
39 #include "llvm/IR/User.h"
40 #include "llvm/IR/Value.h"
41 #include "llvm/IR/ValueHandle.h"
42 #include "llvm/IR/ValueMap.h"
43 #include "llvm/Support/Casting.h"
44 #include "llvm/Support/Debug.h"
45 #include "llvm/Support/ErrorHandling.h"
46 #include "llvm/Support/GenericDomTree.h"
47 #include "llvm/Support/raw_ostream.h"
48 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
49 #include "llvm/Transforms/Utils/Cloning.h"
50 #include "llvm/Transforms/Utils/LoopUtils.h"
51 #include "llvm/Transforms/Utils/UnrollLoop.h"
52 #include "llvm/Transforms/Utils/ValueMapper.h"
53 #include <assert.h>
54 #include <memory>
55 #include <type_traits>
56 #include <vector>
57 
58 using namespace llvm;
59 
60 #define DEBUG_TYPE "loop-unroll-and-jam"
61 
62 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
63 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops completely unroll and jammed");
64 
65 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
66 
67 // Partition blocks in an outer/inner loop pair into blocks before and after
68 // the inner loop
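// "Fore" blocks are the blocks of the outer loop (excluding the subloop) that
// are not dominated by the subloop latch; "Aft" blocks are those that are
// dominated by it.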
69 static bool partitionLoopBlocks(Loop &L, BasicBlockSet &ForeBlocks,
70                                 BasicBlockSet &AftBlocks, DominatorTree &DT) {
71   Loop *SubLoop = L.getSubLoops()[0];
72   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
73 
74   for (BasicBlock *BB : L.blocks()) {
75     if (!SubLoop->contains(BB)) {
76       if (DT.dominates(SubLoopLatch, BB))
77         AftBlocks.insert(BB);
78       else
79         ForeBlocks.insert(BB);
80     }
81   }
82 
83   // Check that all blocks in ForeBlocks together dominate the subloop
84   // TODO: This might ideally be done better with dominators/postdominators.
85   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
86   for (BasicBlock *BB : ForeBlocks) {
87     if (BB == SubLoopPreHeader)
88       continue;
89     Instruction *TI = BB->getTerminator();
90     for (BasicBlock *Succ : successors(TI))
91       if (!ForeBlocks.count(Succ))
92         return false;
93   }
94 
95   return true;
96 }
97 
98 /// Partition blocks in a loop nest into blocks before and after each inner
99 /// loop.
100 static bool partitionOuterLoopBlocks(
101     Loop &Root, Loop &JamLoop, BasicBlockSet &JamLoopBlocks,
102     DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
103     DenseMap<Loop *, BasicBlockSet> &AftBlocksMap, DominatorTree &DT) {
104   JamLoopBlocks.insert(JamLoop.block_begin(), JamLoop.block_end());
105 
106   for (Loop *L : Root.getLoopsInPreorder()) {
107     if (L == &JamLoop)
108       break;
109 
110     if (!partitionLoopBlocks(*L, ForeBlocksMap[L], AftBlocksMap[L], DT))
111       return false;
112   }
113 
114   return true;
115 }
116 
117 // TODO: Remove when UnrollAndJamLoop is changed to support unroll and jamming
118 // loops nested more than two levels deep.
119 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
120                                      BasicBlockSet &ForeBlocks,
121                                      BasicBlockSet &SubLoopBlocks,
122                                      BasicBlockSet &AftBlocks,
123                                      DominatorTree *DT) {
124   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
125   return partitionLoopBlocks(*L, ForeBlocks, AftBlocks, *DT);
126 }
127 
128 // Looks at the phi nodes in Header for values coming from Latch. Calls Visit
129 // on each such instruction and, recursively, on the operands of any visited
130 // instruction that lives in AftBlocks (so those operands are visited before
131 // their users). Returns false if Visit returns false, otherwise returns true.
132 // This is used to process the instructions in the Aft blocks that need to be
133 // moved before the subloop. It is used in two places: first to check that the
134 // required set of instructions can be moved before the loop, then to collect
135 // the instructions to actually move in moveHeaderPhiOperandsToForeBlocks.
136 template <typename T>
137 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
138                                      BasicBlockSet &AftBlocks, T Visit) {
139   SmallPtrSet<Instruction *, 8> VisitedInstr;
140 
141   std::function<bool(Instruction * I)> ProcessInstr = [&](Instruction *I) {
142     if (!VisitedInstr.insert(I).second)
143       return true;
144 
145     if (AftBlocks.count(I->getParent()))
146       for (auto &U : I->operands())
147         if (Instruction *II = dyn_cast<Instruction>(U))
148           if (!ProcessInstr(II))
149             return false;
150 
151     return Visit(I);
152   };
153 
154   for (auto &Phi : Header->phis()) {
155     Value *V = Phi.getIncomingValueForBlock(Latch);
156     if (Instruction *I = dyn_cast<Instruction>(V))
157       if (!ProcessInstr(I))
158         return false;
159   }
160 
161   return true;
162 }
163 
164 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
165 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
166                                               BasicBlock *Latch,
167                                               Instruction *InsertLoc,
168                                               BasicBlockSet &AftBlocks) {
169   // We need to ensure we move the instructions in the correct order,
170   // starting with the earliest required instruction and moving forward.
171   processHeaderPhiOperands(Header, Latch, AftBlocks,
172                            [&AftBlocks, &InsertLoc](Instruction *I) {
173                              if (AftBlocks.count(I->getParent()))
174                                I->moveBefore(InsertLoc);
175                              return true;
176                            });
177 }
178 
179 /*
180   This method performs Unroll and Jam. For a simple loop like:
181   for (i = ..)
182     Fore(i)
183     for (j = ..)
184       SubLoop(i, j)
185     Aft(i)
186 
187   Instead of doing normal inner or outer unrolling, we do:
188   for (i = .., i+=2)
189     Fore(i)
190     Fore(i+1)
191     for (j = ..)
192       SubLoop(i, j)
193       SubLoop(i+1, j)
194     Aft(i)
195     Aft(i+1)
196 
197   So the outer loop is essentially unrolled and then the inner loops are fused
198   ("jammed") together into a single loop. This can increase speed when there
199   are loads in SubLoop that are invariant to i, as they become shared between
200   the now jammed inner loops.
201 
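  As an illustrative sketch (the array names below are made up for this
  comment), consider a nest like
    for (i = ..)
      for (j = ..)
        sum[i] += A[j] * B[i];
  The load of A[j] is invariant to i, so after unroll and jam by 2 both copies
  of the inner body load the same A[j] in the same inner iteration, and later
  passes (such as CSE/GVN) can keep just one of those loads.
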
202   We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
203   Fore blocks are those before the inner loop, Aft are those after. Normal
204   Unroll code is used to copy each of these sets of blocks and the results are
205   combined together into the final form above.
206 
207   isSafeToUnrollAndJam should be used prior to calling this to make sure the
208   unrolling will be valid. Checking profitability is also advisable.
209 
210   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
211   necessary to create one and not fully unrolled).
212 */
213 LoopUnrollResult
214 llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
215                        unsigned TripMultiple, bool UnrollRemainder,
216                        LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
217                        AssumptionCache *AC, const TargetTransformInfo *TTI,
218                        OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {
219 
220   // When we enter here we should have already checked that it is safe
221   BasicBlock *Header = L->getHeader();
222   assert(Header && "No header.");
223   assert(L->getSubLoops().size() == 1);
224   Loop *SubLoop = *L->begin();
225 
226   // Don't enter the unroll code if there is nothing to do.
227   if (TripCount == 0 && Count < 2) {
228     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
229     return LoopUnrollResult::Unmodified;
230   }
231 
232   assert(Count > 0);
233   assert(TripMultiple > 0);
234   assert(TripCount == 0 || TripCount % TripMultiple == 0);
235 
236   // Are we eliminating the loop control altogether?
237   bool CompletelyUnroll = (Count == TripCount);
238 
239   // Use a runtime remainder if Count may not divide the trip count.
240   if (TripMultiple % Count != 0) {
241     if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
242                                     /*UseEpilogRemainder*/ true,
243                                     UnrollRemainder, /*ForgetAllSCEV*/ false,
244                                     LI, SE, DT, AC, TTI, true, EpilogueLoop)) {
245       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
246                            "generated when assuming runtime trip count\n");
247       return LoopUnrollResult::Unmodified;
248     }
249   }
250 
251   // Notify ScalarEvolution that the loop will be substantially changed,
252   // if not outright eliminated.
253   if (SE) {
254     SE->forgetLoop(L);
255     SE->forgetBlockAndLoopDispositions();
256   }
257 
258   using namespace ore;
259   // Report the unrolling decision.
260   if (CompletelyUnroll) {
261     LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
262                       << Header->getName() << " with trip count " << TripCount
263                       << "!\n");
264     ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
265                                  L->getHeader())
266               << "completely unroll and jammed loop with "
267               << NV("UnrollCount", TripCount) << " iterations");
268   } else {
269     auto DiagBuilder = [&]() {
270       OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
271                               L->getHeader());
272       return Diag << "unroll and jammed loop by a factor of "
273                   << NV("UnrollCount", Count);
274     };
275 
276     LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
277                       << " by " << Count);
278     if (TripMultiple != 1) {
279       LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
280       ORE->emit([&]() {
281         return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
282                              << " trips per branch";
283       });
284     } else {
285       LLVM_DEBUG(dbgs() << " with run-time trip count");
286       ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
287     }
288     LLVM_DEBUG(dbgs() << "!\n");
289   }
290 
291   BasicBlock *Preheader = L->getLoopPreheader();
292   BasicBlock *LatchBlock = L->getLoopLatch();
293   assert(Preheader && "No preheader");
294   assert(LatchBlock && "No latch block");
295   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
296   assert(BI && !BI->isUnconditional());
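  // ContinueOnTrue is set if the latch branch stays in the loop on its true
  // (first) successor; LoopExit is the latch successor that leaves the loop.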
297   bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
298   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
299   bool SubLoopContinueOnTrue = SubLoop->contains(
300       SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));
301 
302   // Partition blocks in an outer/inner loop pair into blocks before and after
303   // the inner loop
304   BasicBlockSet SubLoopBlocks;
305   BasicBlockSet ForeBlocks;
306   BasicBlockSet AftBlocks;
307   partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
308                            DT);
309 
310   // We keep track of the entering/first and exiting/last block of each of
311   // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
312   // blocks easier.
313   std::vector<BasicBlock *> ForeBlocksFirst;
314   std::vector<BasicBlock *> ForeBlocksLast;
315   std::vector<BasicBlock *> SubLoopBlocksFirst;
316   std::vector<BasicBlock *> SubLoopBlocksLast;
317   std::vector<BasicBlock *> AftBlocksFirst;
318   std::vector<BasicBlock *> AftBlocksLast;
319   ForeBlocksFirst.push_back(Header);
320   ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
321   SubLoopBlocksFirst.push_back(SubLoop->getHeader());
322   SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
323   AftBlocksFirst.push_back(SubLoop->getExitBlock());
324   AftBlocksLast.push_back(L->getExitingBlock());
325   // Maps Blocks[0] -> Blocks[It]
326   ValueToValueMapTy LastValueMap;
327 
328   // Move any instructions the Header phis depend on out of AftBlocks into Fore.
329   moveHeaderPhiOperandsToForeBlocks(
330       Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);
331 
332   // The current on-the-fly SSA update requires blocks to be processed in
333   // reverse postorder so that LastValueMap contains the correct value at each
334   // exit.
335   LoopBlocksDFS DFS(L);
336   DFS.perform(LI);
337   // Stash the DFS iterators before adding blocks to the loop.
338   LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
339   LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
340 
341   // When an FSDiscriminator is enabled, we don't need to add the multiply
342   // factors to the discriminators.
343   if (Header->getParent()->shouldEmitDebugInfoForProfiling() &&
344       !EnableFSDiscriminator)
345     for (BasicBlock *BB : L->getBlocks())
346       for (Instruction &I : *BB)
347         if (!I.isDebugOrPseudoInst())
348           if (const DILocation *DIL = I.getDebugLoc()) {
349             auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
350             if (NewDIL)
351               I.setDebugLoc(*NewDIL);
352             else
353               LLVM_DEBUG(dbgs()
354                          << "Failed to create new discriminator: "
355                          << DIL->getFilename() << " Line: " << DIL->getLine());
356           }
357 
358   // Copy all blocks
359   for (unsigned It = 1; It != Count; ++It) {
360     SmallVector<BasicBlock *, 8> NewBlocks;
361     // Maps Blocks[It] -> Blocks[It-1]
362     DenseMap<Value *, Value *> PrevItValueMap;
363     SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
364     NewLoops[L] = L;
365     NewLoops[SubLoop] = SubLoop;
366 
367     for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
368       ValueToValueMapTy VMap;
369       BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
370       Header->getParent()->insert(Header->getParent()->end(), New);
371 
372       // Tell LI about New.
373       addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);
374 
375       if (ForeBlocks.count(*BB)) {
376         if (*BB == ForeBlocksFirst[0])
377           ForeBlocksFirst.push_back(New);
378         if (*BB == ForeBlocksLast[0])
379           ForeBlocksLast.push_back(New);
380       } else if (SubLoopBlocks.count(*BB)) {
381         if (*BB == SubLoopBlocksFirst[0])
382           SubLoopBlocksFirst.push_back(New);
383         if (*BB == SubLoopBlocksLast[0])
384           SubLoopBlocksLast.push_back(New);
385       } else if (AftBlocks.count(*BB)) {
386         if (*BB == AftBlocksFirst[0])
387           AftBlocksFirst.push_back(New);
388         if (*BB == AftBlocksLast[0])
389           AftBlocksLast.push_back(New);
390       } else {
391         llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
392       }
393 
394       // Update our running maps of newest clones
395       PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
396       LastValueMap[*BB] = New;
397       for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
398            VI != VE; ++VI) {
399         PrevItValueMap[VI->second] =
400             const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
401         LastValueMap[VI->first] = VI->second;
402       }
403 
404       NewBlocks.push_back(New);
405 
406       // Update DomTree:
407       if (*BB == ForeBlocksFirst[0])
408         DT->addNewBlock(New, ForeBlocksLast[It - 1]);
409       else if (*BB == SubLoopBlocksFirst[0])
410         DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
411       else if (*BB == AftBlocksFirst[0])
412         DT->addNewBlock(New, AftBlocksLast[It - 1]);
413       else {
414         // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
415         // structure.
416         auto BBDomNode = DT->getNode(*BB);
417         auto BBIDom = BBDomNode->getIDom();
418         BasicBlock *OriginalBBIDom = BBIDom->getBlock();
419         assert(OriginalBBIDom);
420         assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
421         DT->addNewBlock(
422             New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
423       }
424     }
425 
426     // Remap all instructions in the most recent iteration
427     remapInstructionsInBlocks(NewBlocks, LastValueMap);
428     for (BasicBlock *NewBlock : NewBlocks) {
429       for (Instruction &I : *NewBlock) {
430         if (auto *II = dyn_cast<AssumeInst>(&I))
431           AC->registerAssumption(II);
432       }
433     }
434 
435   // Alter the ForeBlocks phis, pointing them at the latest version of the
436     // value from the previous iteration's phis
437     for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
438       Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
439       assert(OldValue && "should have incoming edge from Aft[It]");
440       Value *NewValue = OldValue;
441       if (Value *PrevValue = PrevItValueMap[OldValue])
442         NewValue = PrevValue;
443 
444       assert(Phi.getNumOperands() == 2);
445       Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
446       Phi.setIncomingValue(0, NewValue);
447       Phi.removeIncomingValue(1);
448     }
449   }
450 
451   // Now that all the basic blocks for the unrolled iterations are in place,
452   // finish up connecting the blocks and phi nodes. At this point LastValueMap
453   // holds the values from the last unrolled iteration.
454 
455   // Update Phis in BB from OldBB to point to NewBB and use the latest value
456   // from LastValueMap
457   auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
458                                      BasicBlock *NewBB,
459                                      ValueToValueMapTy &LastValueMap) {
460     for (PHINode &Phi : BB->phis()) {
461       for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
462         if (Phi.getIncomingBlock(b) == OldBB) {
463           Value *OldValue = Phi.getIncomingValue(b);
464           if (Value *LastValue = LastValueMap[OldValue])
465             Phi.setIncomingValue(b, LastValue);
466           Phi.setIncomingBlock(b, NewBB);
467           break;
468         }
469       }
470     }
471   };
472   // Move all the phis from Src into Dest
473   auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
474     BasicBlock::iterator insertPoint = Dest->getFirstNonPHIIt();
475     while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
476       Phi->moveBefore(*Dest, insertPoint);
477   };
478 
479   // Update the PHI values outside the loop to point to the last block
480   updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
481                            LastValueMap);
482 
483   // Update ForeBlocks successors and phi nodes
484   BranchInst *ForeTerm =
485       cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
486   assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
487   ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);
488 
489   if (CompletelyUnroll) {
490     while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
491       Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
492       Phi->eraseFromParent();
493     }
494   } else {
495     // Update the PHI values to point to the last aft block
496     updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
497                              AftBlocksLast.back(), LastValueMap);
498   }
499 
500   for (unsigned It = 1; It != Count; It++) {
501     // Remap ForeBlock successors from the previous iteration to this one
502     BranchInst *ForeTerm =
503         cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
504     assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
505     ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
506   }
507 
508   // Subloop successors and phis
509   BranchInst *SubTerm =
510       cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
511   SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
512   SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
513   SubLoopBlocksFirst[0]->replacePhiUsesWith(ForeBlocksLast[0],
514                                             ForeBlocksLast.back());
515   SubLoopBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
516                                             SubLoopBlocksLast.back());
517 
518   for (unsigned It = 1; It != Count; It++) {
519     // Replace the conditional branch of the previous iteration's subloop with
520     // an unconditional branch to this iteration's subloop
521     BranchInst *SubTerm =
522         cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
523     BranchInst::Create(SubLoopBlocksFirst[It], SubTerm->getIterator());
524     SubTerm->eraseFromParent();
525 
526     SubLoopBlocksFirst[It]->replacePhiUsesWith(ForeBlocksLast[It],
527                                                ForeBlocksLast.back());
528     SubLoopBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
529                                                SubLoopBlocksLast.back());
530     movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
531   }
532 
533   // Aft blocks successors and phis
534   BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
535   if (CompletelyUnroll) {
536     BranchInst::Create(LoopExit, AftTerm->getIterator());
537     AftTerm->eraseFromParent();
538   } else {
539     AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
540     assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
541            "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
542   }
543   AftBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
544                                         SubLoopBlocksLast.back());
545 
546   for (unsigned It = 1; It != Count; It++) {
547     // Replace the conditional branch of the previous iteration's Aft block
548     // with an unconditional branch to this iteration's Aft block
549     BranchInst *AftTerm =
550         cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
551     BranchInst::Create(AftBlocksFirst[It], AftTerm->getIterator());
552     AftTerm->eraseFromParent();
553 
554     AftBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
555                                            SubLoopBlocksLast.back());
556     movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
557   }
558 
559   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
560   // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
561   // new ones required.
562   if (Count != 1) {
563     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
564     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
565                            SubLoopBlocksFirst[0]);
566     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
567                            SubLoopBlocksLast[0], AftBlocksFirst[0]);
568 
569     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
570                            ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
571     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
572                            SubLoopBlocksLast.back(), AftBlocksFirst[0]);
573     DTU.applyUpdatesPermissive(DTUpdates);
574   }
575 
576   // Merge adjacent basic blocks, if possible.
577   SmallPtrSet<BasicBlock *, 16> MergeBlocks;
578   MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
579   MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
580   MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
581 
582   MergeBlockSuccessorsIntoGivenBlocks(MergeBlocks, L, &DTU, LI);
583 
584   // Apply updates to the DomTree.
585   DT = &DTU.getDomTree();
586 
587   // At this point, the code is well formed.  We now do a quick sweep over the
588   // inserted code, doing constant propagation and dead code elimination as we
589   // go.
590   simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
591   simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
592                           TTI);
593 
594   NumCompletelyUnrolledAndJammed += CompletelyUnroll;
595   ++NumUnrolledAndJammed;
596 
597   // Update LoopInfo if the loop is completely removed.
598   if (CompletelyUnroll)
599     LI->erase(L);
600 
601 #ifndef NDEBUG
602   // We shouldn't have done anything to break loop simplify form or LCSSA.
603   Loop *OutestLoop = SubLoop->getParentLoop()
604                          ? SubLoop->getParentLoop()->getParentLoop()
605                                ? SubLoop->getParentLoop()->getParentLoop()
606                                : SubLoop->getParentLoop()
607                          : SubLoop;
608   assert(DT->verify());
609   LI->verify(*DT);
610   assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
611   if (!CompletelyUnroll)
612     assert(L->isLoopSimplifyForm());
613   assert(SubLoop->isLoopSimplifyForm());
614   SE->verify();
615 #endif
616 
617   return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
618                           : LoopUnrollResult::PartiallyUnrolled;
619 }
620 
621 static bool getLoadsAndStores(BasicBlockSet &Blocks,
622                               SmallVector<Instruction *, 4> &MemInstr) {
623   // Scan the BBs and collect legal loads and stores.
624   // Returns false if non-simple loads/stores are found.
625   for (BasicBlock *BB : Blocks) {
626     for (Instruction &I : *BB) {
627       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
628         if (!Ld->isSimple())
629           return false;
630         MemInstr.push_back(&I);
631       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
632         if (!St->isSimple())
633           return false;
634         MemInstr.push_back(&I);
635       } else if (I.mayReadOrWriteMemory()) {
636         return false;
637       }
638     }
639   }
640   return true;
641 }
642 
643 static bool preservesForwardDependence(Instruction *Src, Instruction *Dst,
644                                        unsigned UnrollLevel, unsigned JamLevel,
645                                        bool Sequentialized, Dependence *D) {
646   // The UnrollLevel loop might carry the dependency Src --> Dst.
647   // Will a different (inner) loop still carry it after unrolling?
648   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
649        ++CurLoopDepth) {
650     auto JammedDir = D->getDirection(CurLoopDepth);
651     if (JammedDir == Dependence::DVEntry::LT)
652       return true;
653 
654     if (JammedDir & Dependence::DVEntry::GT)
655       return false;
656   }
657 
658   return true;
659 }
660 
661 static bool preservesBackwardDependence(Instruction *Src, Instruction *Dst,
662                                         unsigned UnrollLevel, unsigned JamLevel,
663                                         bool Sequentialized, Dependence *D) {
664   // UnrollLevel might carry the dependency Dst --> Src
665   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
666        ++CurLoopDepth) {
667     auto JammedDir = D->getDirection(CurLoopDepth);
668     if (JammedDir == Dependence::DVEntry::GT)
669       return true;
670 
671     if (JammedDir & Dependence::DVEntry::LT)
672       return false;
673   }
674 
675   // Backward dependencies are only preserved if not interleaved.
676   return Sequentialized;
677 }
678 
679 // Check whether it is semantically safe to unroll-and-jam Src and Dst,
680 // considering any potential dependency between them.
681 //
682 // @param UnrollLevel The level of the loop being unrolled
683 // @param JamLevel    The level of the loop being jammed; if Src and Dst are on
684 // different levels, the outermost common loop counts as jammed level
685 //
686 // @return true if it is safe and false if there is a dependency violation.
687 static bool checkDependency(Instruction *Src, Instruction *Dst,
688                             unsigned UnrollLevel, unsigned JamLevel,
689                             bool Sequentialized, DependenceInfo &DI) {
690   assert(UnrollLevel <= JamLevel &&
691          "Expecting JamLevel to be at least UnrollLevel");
692 
693   if (Src == Dst)
694     return true;
695   // Ignore Input dependencies.
696   if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
697     return true;
698 
699   // Check whether unroll-and-jam may violate a dependency.
700   // By construction, every dependency will be lexicographically non-negative
701   // (if it were not, it would violate the current execution order), such as
702   //   (0,0,>,*,*)
703   // Unroll-and-jam changes the GT execution of two executions to the same
704   // iteration of the chosen unroll level. That is, a GT dependence becomes a GE
705   // dependence (or EQ, if we fully unrolled the loop) at the loop's position:
706   //   (0,0,>=,*,*)
707   // Now, the dependency is not necessarily non-negative anymore, i.e.
708   // unroll-and-jam may violate correctness.
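  // As an illustrative example (hypothetical code, not from this function):
  //   for (i) for (j) A[i][j] = A[i-1][j+1] + ...;
  // The value stored at iteration (i-1, j+1) is read at iteration (i, j), a
  // dependence with direction vector (<, >). Unrolling and jamming the i loop
  // by 2 would make the copy for i+1 read A[i][j+1] at inner iteration j,
  // before the copy for i has stored it at inner iteration j+1, so the
  // dependence would be violated.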
709   std::unique_ptr<Dependence> D = DI.depends(Src, Dst, true);
710   if (!D)
711     return true;
712   assert(D->isOrdered() && "Expected an output, flow or anti dep.");
713 
714   if (D->isConfused()) {
715     LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
716                       << "  " << *Src << "\n"
717                       << "  " << *Dst << "\n");
718     return false;
719   }
720 
721   // If outer levels (levels enclosing the loop being unroll-and-jammed) have a
722   // non-equal direction, then the locations accessed in the inner levels cannot
723   // overlap in memory. We assume the indexes never overlap into neighboring
724   // dimensions.
725   for (unsigned CurLoopDepth = 1; CurLoopDepth < UnrollLevel; ++CurLoopDepth)
726     if (!(D->getDirection(CurLoopDepth) & Dependence::DVEntry::EQ))
727       return true;
728 
729   auto UnrollDirection = D->getDirection(UnrollLevel);
730 
731   // If the distance carried by the unrolled loop is 0, then after unrolling
732   // that distance will become non-zero resulting in non-overlapping accesses in
733   // the inner loops.
734   if (UnrollDirection == Dependence::DVEntry::EQ)
735     return true;
736 
737   if (UnrollDirection & Dependence::DVEntry::LT &&
738       !preservesForwardDependence(Src, Dst, UnrollLevel, JamLevel,
739                                   Sequentialized, D.get()))
740     return false;
741 
742   if (UnrollDirection & Dependence::DVEntry::GT &&
743       !preservesBackwardDependence(Src, Dst, UnrollLevel, JamLevel,
744                                    Sequentialized, D.get()))
745     return false;
746 
747   return true;
748 }
749 
750 static bool
751 checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
752                   const DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
753                   const DenseMap<Loop *, BasicBlockSet> &AftBlocksMap,
754                   DependenceInfo &DI, LoopInfo &LI) {
755   SmallVector<BasicBlockSet, 8> AllBlocks;
756   for (Loop *L : Root.getLoopsInPreorder())
757     if (ForeBlocksMap.contains(L))
758       AllBlocks.push_back(ForeBlocksMap.lookup(L));
759   AllBlocks.push_back(SubLoopBlocks);
760   for (Loop *L : Root.getLoopsInPreorder())
761     if (AftBlocksMap.contains(L))
762       AllBlocks.push_back(AftBlocksMap.lookup(L));
763 
764   unsigned LoopDepth = Root.getLoopDepth();
765   SmallVector<Instruction *, 4> EarlierLoadsAndStores;
766   SmallVector<Instruction *, 4> CurrentLoadsAndStores;
767   for (BasicBlockSet &Blocks : AllBlocks) {
768     CurrentLoadsAndStores.clear();
769     if (!getLoadsAndStores(Blocks, CurrentLoadsAndStores))
770       return false;
771 
772     Loop *CurLoop = LI.getLoopFor((*Blocks.begin())->front().getParent());
773     unsigned CurLoopDepth = CurLoop->getLoopDepth();
774 
775     for (auto *Earlier : EarlierLoadsAndStores) {
776       Loop *EarlierLoop = LI.getLoopFor(Earlier->getParent());
777       unsigned EarlierDepth = EarlierLoop->getLoopDepth();
778       unsigned CommonLoopDepth = std::min(EarlierDepth, CurLoopDepth);
779       for (auto *Later : CurrentLoadsAndStores) {
780         if (!checkDependency(Earlier, Later, LoopDepth, CommonLoopDepth, false,
781                              DI))
782           return false;
783       }
784     }
785 
786     size_t NumInsts = CurrentLoadsAndStores.size();
787     for (size_t I = 0; I < NumInsts; ++I) {
788       for (size_t J = I; J < NumInsts; ++J) {
789         if (!checkDependency(CurrentLoadsAndStores[I], CurrentLoadsAndStores[J],
790                              LoopDepth, CurLoopDepth, true, DI))
791           return false;
792       }
793     }
794 
795     EarlierLoadsAndStores.append(CurrentLoadsAndStores.begin(),
796                                  CurrentLoadsAndStores.end());
797   }
798   return true;
799 }
800 
801 static bool isEligibleLoopForm(const Loop &Root) {
802   // Root must have a child.
803   if (Root.getSubLoops().size() != 1)
804     return false;
805 
806   const Loop *L = &Root;
807   do {
808     // All loops in Root need to be in simplify and rotated form.
809     if (!L->isLoopSimplifyForm())
810       return false;
811 
812     if (!L->isRotatedForm())
813       return false;
814 
815     if (L->getHeader()->hasAddressTaken()) {
816       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
817       return false;
818     }
819 
820     unsigned SubLoopsSize = L->getSubLoops().size();
821     if (SubLoopsSize == 0)
822       return true;
823 
824     // Only one child is allowed.
825     if (SubLoopsSize != 1)
826       return false;
827 
828     // Only loops with a single exit block can be unrolled and jammed.
829     // The function getExitBlock() is used for this check, rather than
830     // getUniqueExitBlock() to ensure loops with multiple exit edges are
831     // disallowed.
832     if (!L->getExitBlock()) {
833       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single exit "
834                            "blocks can be unrolled and jammed.\n");
835       return false;
836     }
837 
838     // Only loops with a single exiting block can be unrolled and jammed.
839     if (!L->getExitingBlock()) {
840       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; only loops with single "
841                            "exiting blocks can be unrolled and jammed.\n");
842       return false;
843     }
844 
845     L = L->getSubLoops()[0];
846   } while (L);
847 
848   return true;
849 }
850 
851 static Loop *getInnerMostLoop(Loop *L) {
852   while (!L->getSubLoops().empty())
853     L = L->getSubLoops()[0];
854   return L;
855 }
856 
857 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
858                                 DependenceInfo &DI, LoopInfo &LI) {
859   if (!isEligibleLoopForm(*L)) {
860     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Ineligible loop form\n");
861     return false;
862   }
863 
864   /* We currently handle outer loops like this:
865         |
866     ForeFirst    <------\   }
867      Blocks             |   } ForeBlocks of L
868     ForeLast            |   }
869         |               |
870        ...              |
871         |               |
872     ForeFirst    <----\ |   }
873      Blocks           | |   } ForeBlocks of an inner loop of L
874     ForeLast          | |   }
875         |             | |
876     JamLoopFirst  <\  | |   }
877      Blocks        |  | |   } JamLoopBlocks of the innermost loop
878     JamLoopLast   -/  | |   }
879         |             | |
880     AftFirst          | |   }
881      Blocks           | |   } AftBlocks of an inner loop of L
882     AftLast     ------/ |   }
883         |               |
884        ...              |
885         |               |
886     AftFirst            |   }
887      Blocks             |   } AftBlocks of L
888     AftLast     --------/   }
889         |
890 
891     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
892     and AftBlocks, providing that there is one edge from Fores to SubLoops,
893     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
894     In practice we currently limit Aft blocks to a single block, and limit
895     things further in the profitability checks of the unroll and jam pass.
896 
897     Because of the way we rearrange basic blocks, we also require that
898     the Fore blocks of L on all unrolled iterations are safe to move before the
899     blocks of the direct child of L of all iterations. So we require that the
900     phi node looping operands of ForeHeader can be moved to at least the end of
901     ForeEnd, so that we can arrange cloned Fore Blocks before the subloop and
902     match up Phis correctly.
903 
904     i.e. The old order of blocks used to be
905            (F1)1 (F2)1 J1_1 J1_2 (A2)1 (A1)1 (F1)2 (F2)2 J2_1 J2_2 (A2)2 (A1)2.
906          It needs to be safe to transform this to
907            (F1)1 (F1)2 (F2)1 (F2)2 J1_1 J1_2 J2_1 J2_2 (A2)1 (A2)2 (A1)1 (A1)2.
908 
909     There are then a number of checks along the lines of no calls, no
910     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
911     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
912     UnrollAndJamLoop if the trip count cannot be easily calculated.
913   */
914 
915   // Split blocks into Fore/SubLoop/Aft based on dominators
916   Loop *JamLoop = getInnerMostLoop(L);
917   BasicBlockSet SubLoopBlocks;
918   DenseMap<Loop *, BasicBlockSet> ForeBlocksMap;
919   DenseMap<Loop *, BasicBlockSet> AftBlocksMap;
920   if (!partitionOuterLoopBlocks(*L, *JamLoop, SubLoopBlocks, ForeBlocksMap,
921                                 AftBlocksMap, DT)) {
922     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
923     return false;
924   }
925 
926   // Aft blocks may need to move instructions to fore blocks, which becomes more
927   // difficult if there are multiple (potentially conditionally executed)
928   // blocks. For now we just exclude loops with multiple aft blocks.
929   if (AftBlocksMap[L].size() != 1) {
930     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
931                          "multiple blocks after the loop\n");
932     return false;
933   }
934 
935   // Check that the inner loop backedge count is invariant across all
936   // iterations of the outer loop
937   if (any_of(L->getLoopsInPreorder(), [&SE](Loop *SubLoop) {
938         return !hasIterationCountInvariantInParent(SubLoop, SE);
939       })) {
940     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
941                          "not consistent on each iteration\n");
942     return false;
943   }
944 
945   // Check the loop safety info for exceptions.
946   SimpleLoopSafetyInfo LSI;
947   LSI.computeLoopSafetyInfo(L);
948   if (LSI.anyBlockMayThrow()) {
949     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
950     return false;
951   }
952 
953   // We've ruled out the easy stuff and now need to check that there are no
954   // interdependencies which may prevent us from moving:
955   //  ForeBlocks before Subloop and AftBlocks,
956   //  Subloop before AftBlocks, and
957   //  ForeBlock phi operands before the subloop.
958 
959   // Make sure we can move all instructions we need to before the subloop
960   BasicBlock *Header = L->getHeader();
961   BasicBlock *Latch = L->getLoopLatch();
962   BasicBlockSet AftBlocks = AftBlocksMap[L];
963   Loop *SubLoop = L->getSubLoops()[0];
964   if (!processHeaderPhiOperands(
965           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
966             if (SubLoop->contains(I->getParent()))
967               return false;
968             if (AftBlocks.count(I->getParent())) {
969               // If we hit a phi node in afts we know we are done (probably
970               // LCSSA)
971               if (isa<PHINode>(I))
972                 return false;
973               // Can't move instructions with side effects or memory
974               // reads/writes
975               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
976                 return false;
977             }
978             // Keep going
979             return true;
980           })) {
981     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
982                          "instructions after subloop to before it\n");
983     return false;
984   }
985 
986   // Check for memory dependencies which prohibit the unrolling we are doing.
987   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
988   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
989   if (!checkDependencies(*L, SubLoopBlocks, ForeBlocksMap, AftBlocksMap, DI,
990                          LI)) {
991     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
992     return false;
993   }
994 
995   return true;
996 }
997