109467b48Spatrick //===-- UnrollLoop.cpp - Loop unrolling utilities -------------------------===// 209467b48Spatrick // 309467b48Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 409467b48Spatrick // See https://llvm.org/LICENSE.txt for license information. 509467b48Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 609467b48Spatrick // 709467b48Spatrick //===----------------------------------------------------------------------===// 809467b48Spatrick // 909467b48Spatrick // This file implements some loop unrolling utilities. It does not define any 1009467b48Spatrick // actual pass or policy, but provides a single function to perform loop 1109467b48Spatrick // unrolling. 1209467b48Spatrick // 1309467b48Spatrick // The process of unrolling can produce extraneous basic blocks linked with 1409467b48Spatrick // unconditional branches. This will be corrected in the future. 1509467b48Spatrick // 1609467b48Spatrick //===----------------------------------------------------------------------===// 1709467b48Spatrick 18*097a140dSpatrick #include "llvm/ADT/ArrayRef.h" 19*097a140dSpatrick #include "llvm/ADT/DenseMap.h" 20*097a140dSpatrick #include "llvm/ADT/Optional.h" 21*097a140dSpatrick #include "llvm/ADT/STLExtras.h" 22*097a140dSpatrick #include "llvm/ADT/SetVector.h" 23*097a140dSpatrick #include "llvm/ADT/SmallVector.h" 2409467b48Spatrick #include "llvm/ADT/Statistic.h" 25*097a140dSpatrick #include "llvm/ADT/StringRef.h" 26*097a140dSpatrick #include "llvm/ADT/Twine.h" 27*097a140dSpatrick #include "llvm/ADT/ilist_iterator.h" 28*097a140dSpatrick #include "llvm/ADT/iterator_range.h" 2909467b48Spatrick #include "llvm/Analysis/AssumptionCache.h" 30*097a140dSpatrick #include "llvm/Analysis/DomTreeUpdater.h" 3109467b48Spatrick #include "llvm/Analysis/InstructionSimplify.h" 32*097a140dSpatrick #include "llvm/Analysis/LoopInfo.h" 3309467b48Spatrick #include "llvm/Analysis/LoopIterator.h" 3409467b48Spatrick #include 
"llvm/Analysis/OptimizationRemarkEmitter.h" 3509467b48Spatrick #include "llvm/Analysis/ScalarEvolution.h" 3609467b48Spatrick #include "llvm/IR/BasicBlock.h" 37*097a140dSpatrick #include "llvm/IR/CFG.h" 38*097a140dSpatrick #include "llvm/IR/Constants.h" 3909467b48Spatrick #include "llvm/IR/DebugInfoMetadata.h" 40*097a140dSpatrick #include "llvm/IR/DebugLoc.h" 41*097a140dSpatrick #include "llvm/IR/DiagnosticInfo.h" 4209467b48Spatrick #include "llvm/IR/Dominators.h" 43*097a140dSpatrick #include "llvm/IR/Function.h" 44*097a140dSpatrick #include "llvm/IR/Instruction.h" 45*097a140dSpatrick #include "llvm/IR/Instructions.h" 4609467b48Spatrick #include "llvm/IR/IntrinsicInst.h" 47*097a140dSpatrick #include "llvm/IR/Metadata.h" 48*097a140dSpatrick #include "llvm/IR/Module.h" 49*097a140dSpatrick #include "llvm/IR/Use.h" 50*097a140dSpatrick #include "llvm/IR/User.h" 51*097a140dSpatrick #include "llvm/IR/ValueHandle.h" 52*097a140dSpatrick #include "llvm/IR/ValueMap.h" 53*097a140dSpatrick #include "llvm/Support/Casting.h" 5409467b48Spatrick #include "llvm/Support/CommandLine.h" 5509467b48Spatrick #include "llvm/Support/Debug.h" 56*097a140dSpatrick #include "llvm/Support/GenericDomTree.h" 57*097a140dSpatrick #include "llvm/Support/MathExtras.h" 5809467b48Spatrick #include "llvm/Support/raw_ostream.h" 5909467b48Spatrick #include "llvm/Transforms/Utils/BasicBlockUtils.h" 6009467b48Spatrick #include "llvm/Transforms/Utils/Cloning.h" 6109467b48Spatrick #include "llvm/Transforms/Utils/Local.h" 6209467b48Spatrick #include "llvm/Transforms/Utils/LoopSimplify.h" 6309467b48Spatrick #include "llvm/Transforms/Utils/LoopUtils.h" 6409467b48Spatrick #include "llvm/Transforms/Utils/SimplifyIndVar.h" 6509467b48Spatrick #include "llvm/Transforms/Utils/UnrollLoop.h" 66*097a140dSpatrick #include "llvm/Transforms/Utils/ValueMapper.h" 67*097a140dSpatrick #include <algorithm> 68*097a140dSpatrick #include <assert.h> 69*097a140dSpatrick #include <type_traits> 70*097a140dSpatrick #include <vector> 
namespace llvm {
// Forward declarations of LLVM types referenced in this file.
class DataLayout;
class Value;
} // namespace llvm

using namespace llvm;

// Name under which this file's optimization remarks are emitted (it is passed
// as the first argument of every OptimizationRemark built in UnrollLoop).
#define DEBUG_TYPE "loop-unroll"

// TODO: Should these be here or in LoopUnroll?
STATISTIC(NumCompletelyUnrolled, "Number of loops completely unrolled");
STATISTIC(NumUnrolled, "Number of loops unrolled (completely or otherwise)");
// Bumped by UnrollLoop whenever the latch of the unrolled loop is not an
// exiting block.
STATISTIC(NumUnrolledNotLatch, "Number of loops unrolled without a conditional "
                               "latch (completely or otherwise)");

// Hidden flag, off by default. When it is given explicitly on the command line
// its value replaces the isEpilogProfitable() heuristic in UnrollLoop.
static cl::opt<bool>
UnrollRuntimeEpilog("unroll-runtime-epilog", cl::init(false), cl::Hidden,
                    cl::desc("Allow runtime unrolled loops to be unrolled "
                             "with epilog instead of prolog."));

// Hidden flag requesting dominator tree verification after unrolling. It
// defaults to true only when EXPENSIVE_CHECKS is defined.
static cl::opt<bool>
UnrollVerifyDomtree("unroll-verify-domtree", cl::Hidden,
                    cl::desc("Verify domtree after unrolling"),
#ifdef EXPENSIVE_CHECKS
    cl::init(true)
#else
    cl::init(false)
#endif
                    );

/// Check if unrolling created a situation where we need to insert phi nodes to
/// preserve LCSSA form.
/// \param Blocks is a vector of basic blocks representing unrolled loop.
/// \param L is the outer loop.
/// It's possible that some of the blocks are in L, and some are not. In this
/// case, if there is a use outside L, and the definition is inside L, we need
/// to insert a phi-node, otherwise LCSSA will be broken.
10909467b48Spatrick /// The function is just a helper function for llvm::UnrollLoop that returns 11009467b48Spatrick /// true if this situation occurs, indicating that LCSSA needs to be fixed. 11109467b48Spatrick static bool needToInsertPhisForLCSSA(Loop *L, std::vector<BasicBlock *> Blocks, 11209467b48Spatrick LoopInfo *LI) { 11309467b48Spatrick for (BasicBlock *BB : Blocks) { 11409467b48Spatrick if (LI->getLoopFor(BB) == L) 11509467b48Spatrick continue; 11609467b48Spatrick for (Instruction &I : *BB) { 11709467b48Spatrick for (Use &U : I.operands()) { 11809467b48Spatrick if (auto Def = dyn_cast<Instruction>(U)) { 11909467b48Spatrick Loop *DefLoop = LI->getLoopFor(Def->getParent()); 12009467b48Spatrick if (!DefLoop) 12109467b48Spatrick continue; 12209467b48Spatrick if (DefLoop->contains(L)) 12309467b48Spatrick return true; 12409467b48Spatrick } 12509467b48Spatrick } 12609467b48Spatrick } 12709467b48Spatrick } 12809467b48Spatrick return false; 12909467b48Spatrick } 13009467b48Spatrick 13109467b48Spatrick /// Adds ClonedBB to LoopInfo, creates a new loop for ClonedBB if necessary 13209467b48Spatrick /// and adds a mapping from the original loop to the new loop to NewLoops. 13309467b48Spatrick /// Returns nullptr if no new loop was created and a pointer to the 13409467b48Spatrick /// original loop OriginalBB was part of otherwise. 13509467b48Spatrick const Loop* llvm::addClonedBlockToLoopInfo(BasicBlock *OriginalBB, 13609467b48Spatrick BasicBlock *ClonedBB, LoopInfo *LI, 13709467b48Spatrick NewLoopsMap &NewLoops) { 13809467b48Spatrick // Figure out which loop New is in. 13909467b48Spatrick const Loop *OldLoop = LI->getLoopFor(OriginalBB); 14009467b48Spatrick assert(OldLoop && "Should (at least) be in the loop being unrolled!"); 14109467b48Spatrick 14209467b48Spatrick Loop *&NewLoop = NewLoops[OldLoop]; 14309467b48Spatrick if (!NewLoop) { 14409467b48Spatrick // Found a new sub-loop. 
14509467b48Spatrick assert(OriginalBB == OldLoop->getHeader() && 14609467b48Spatrick "Header should be first in RPO"); 14709467b48Spatrick 14809467b48Spatrick NewLoop = LI->AllocateLoop(); 14909467b48Spatrick Loop *NewLoopParent = NewLoops.lookup(OldLoop->getParentLoop()); 15009467b48Spatrick 15109467b48Spatrick if (NewLoopParent) 15209467b48Spatrick NewLoopParent->addChildLoop(NewLoop); 15309467b48Spatrick else 15409467b48Spatrick LI->addTopLevelLoop(NewLoop); 15509467b48Spatrick 15609467b48Spatrick NewLoop->addBasicBlockToLoop(ClonedBB, *LI); 15709467b48Spatrick return OldLoop; 15809467b48Spatrick } else { 15909467b48Spatrick NewLoop->addBasicBlockToLoop(ClonedBB, *LI); 16009467b48Spatrick return nullptr; 16109467b48Spatrick } 16209467b48Spatrick } 16309467b48Spatrick 16409467b48Spatrick /// The function chooses which type of unroll (epilog or prolog) is more 16509467b48Spatrick /// profitabale. 16609467b48Spatrick /// Epilog unroll is more profitable when there is PHI that starts from 16709467b48Spatrick /// constant. In this case epilog will leave PHI start from constant, 16809467b48Spatrick /// but prolog will convert it to non-constant. 16909467b48Spatrick /// 17009467b48Spatrick /// loop: 17109467b48Spatrick /// PN = PHI [I, Latch], [CI, PreHeader] 17209467b48Spatrick /// I = foo(PN) 17309467b48Spatrick /// ... 17409467b48Spatrick /// 17509467b48Spatrick /// Epilog unroll case. 17609467b48Spatrick /// loop: 17709467b48Spatrick /// PN = PHI [I2, Latch], [CI, PreHeader] 17809467b48Spatrick /// I1 = foo(PN) 17909467b48Spatrick /// I2 = foo(I1) 18009467b48Spatrick /// ... 18109467b48Spatrick /// Prolog unroll case. 18209467b48Spatrick /// NewPN = PHI [PrologI, Prolog], [CI, PreHeader] 18309467b48Spatrick /// loop: 18409467b48Spatrick /// PN = PHI [I2, Latch], [NewPN, PreHeader] 18509467b48Spatrick /// I1 = foo(PN) 18609467b48Spatrick /// I2 = foo(I1) 18709467b48Spatrick /// ... 
18809467b48Spatrick /// 18909467b48Spatrick static bool isEpilogProfitable(Loop *L) { 19009467b48Spatrick BasicBlock *PreHeader = L->getLoopPreheader(); 19109467b48Spatrick BasicBlock *Header = L->getHeader(); 19209467b48Spatrick assert(PreHeader && Header); 19309467b48Spatrick for (const PHINode &PN : Header->phis()) { 19409467b48Spatrick if (isa<ConstantInt>(PN.getIncomingValueForBlock(PreHeader))) 19509467b48Spatrick return true; 19609467b48Spatrick } 19709467b48Spatrick return false; 19809467b48Spatrick } 19909467b48Spatrick 20009467b48Spatrick /// Perform some cleanup and simplifications on loops after unrolling. It is 20109467b48Spatrick /// useful to simplify the IV's in the new loop, as well as do a quick 20209467b48Spatrick /// simplify/dce pass of the instructions. 20309467b48Spatrick void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, 20409467b48Spatrick ScalarEvolution *SE, DominatorTree *DT, 205*097a140dSpatrick AssumptionCache *AC, 206*097a140dSpatrick const TargetTransformInfo *TTI) { 20709467b48Spatrick // Simplify any new induction variables in the partially unrolled loop. 20809467b48Spatrick if (SE && SimplifyIVs) { 20909467b48Spatrick SmallVector<WeakTrackingVH, 16> DeadInsts; 210*097a140dSpatrick simplifyLoopIVs(L, SE, DT, LI, TTI, DeadInsts); 21109467b48Spatrick 21209467b48Spatrick // Aggressively clean up dead instructions that simplifyLoopIVs already 21309467b48Spatrick // identified. Any remaining should be cleaned up below. 214*097a140dSpatrick while (!DeadInsts.empty()) { 215*097a140dSpatrick Value *V = DeadInsts.pop_back_val(); 216*097a140dSpatrick if (Instruction *Inst = dyn_cast_or_null<Instruction>(V)) 21709467b48Spatrick RecursivelyDeleteTriviallyDeadInstructions(Inst); 21809467b48Spatrick } 219*097a140dSpatrick } 22009467b48Spatrick 22109467b48Spatrick // At this point, the code is well formed. 
We now do a quick sweep over the 22209467b48Spatrick // inserted code, doing constant propagation and dead code elimination as we 22309467b48Spatrick // go. 22409467b48Spatrick const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); 22509467b48Spatrick for (BasicBlock *BB : L->getBlocks()) { 22609467b48Spatrick for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { 22709467b48Spatrick Instruction *Inst = &*I++; 22809467b48Spatrick 22909467b48Spatrick if (Value *V = SimplifyInstruction(Inst, {DL, nullptr, DT, AC})) 23009467b48Spatrick if (LI->replacementPreservesLCSSAForm(Inst, V)) 23109467b48Spatrick Inst->replaceAllUsesWith(V); 23209467b48Spatrick if (isInstructionTriviallyDead(Inst)) 23309467b48Spatrick BB->getInstList().erase(Inst); 23409467b48Spatrick } 23509467b48Spatrick } 23609467b48Spatrick 23709467b48Spatrick // TODO: after peeling or unrolling, previously loop variant conditions are 23809467b48Spatrick // likely to fold to constants, eagerly propagating those here will require 23909467b48Spatrick // fewer cleanup passes to be run. Alternatively, a LoopEarlyCSE might be 24009467b48Spatrick // appropriate. 24109467b48Spatrick } 24209467b48Spatrick 24309467b48Spatrick /// Unroll the given loop by Count. The loop must be in LCSSA form. Unrolling 24409467b48Spatrick /// can only fail when the loop's latch block is not terminated by a conditional 24509467b48Spatrick /// branch instruction. However, if the trip count (and multiple) are not known, 24609467b48Spatrick /// loop unrolling will mostly produce more code that is no faster. 24709467b48Spatrick /// 24809467b48Spatrick /// TripCount is the upper bound of the iteration on which control exits 24909467b48Spatrick /// LatchBlock. Control may exit the loop prior to TripCount iterations either 25009467b48Spatrick /// via an early branch in other loop block or via LatchBlock terminator. 
This 25109467b48Spatrick /// is relaxed from the general definition of trip count which is the number of 25209467b48Spatrick /// times the loop header executes. Note that UnrollLoop assumes that the loop 25309467b48Spatrick /// counter test is in LatchBlock in order to remove unnecesssary instances of 25409467b48Spatrick /// the test. If control can exit the loop from the LatchBlock's terminator 25509467b48Spatrick /// prior to TripCount iterations, flag PreserveCondBr needs to be set. 25609467b48Spatrick /// 25709467b48Spatrick /// PreserveCondBr indicates whether the conditional branch of the LatchBlock 25809467b48Spatrick /// needs to be preserved. It is needed when we use trip count upper bound to 25909467b48Spatrick /// fully unroll the loop. If PreserveOnlyFirst is also set then only the first 26009467b48Spatrick /// conditional branch needs to be preserved. 26109467b48Spatrick /// 26209467b48Spatrick /// Similarly, TripMultiple divides the number of times that the LatchBlock may 26309467b48Spatrick /// execute without exiting the loop. 26409467b48Spatrick /// 26509467b48Spatrick /// If AllowRuntime is true then UnrollLoop will consider unrolling loops that 26609467b48Spatrick /// have a runtime (i.e. not compile time constant) trip count. Unrolling these 26709467b48Spatrick /// loops require a unroll "prologue" that runs "RuntimeTripCount % Count" 26809467b48Spatrick /// iterations before branching into the unrolled loop. UnrollLoop will not 26909467b48Spatrick /// runtime-unroll the loop if computing RuntimeTripCount will be expensive and 27009467b48Spatrick /// AllowExpensiveTripCount is false. 27109467b48Spatrick /// 27209467b48Spatrick /// If we want to perform PGO-based loop peeling, PeelCount is set to the 27309467b48Spatrick /// number of iterations we want to peel off. 27409467b48Spatrick /// 27509467b48Spatrick /// The LoopInfo Analysis that is passed will be kept consistent. 
27609467b48Spatrick /// 27709467b48Spatrick /// This utility preserves LoopInfo. It will also preserve ScalarEvolution and 27809467b48Spatrick /// DominatorTree if they are non-null. 27909467b48Spatrick /// 28009467b48Spatrick /// If RemainderLoop is non-null, it will receive the remainder loop (if 28109467b48Spatrick /// required and not fully unrolled). 28209467b48Spatrick LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, 28309467b48Spatrick ScalarEvolution *SE, DominatorTree *DT, 28409467b48Spatrick AssumptionCache *AC, 285*097a140dSpatrick const TargetTransformInfo *TTI, 28609467b48Spatrick OptimizationRemarkEmitter *ORE, 28709467b48Spatrick bool PreserveLCSSA, Loop **RemainderLoop) { 28809467b48Spatrick 28909467b48Spatrick BasicBlock *Preheader = L->getLoopPreheader(); 29009467b48Spatrick if (!Preheader) { 29109467b48Spatrick LLVM_DEBUG(dbgs() << " Can't unroll; loop preheader-insertion failed.\n"); 29209467b48Spatrick return LoopUnrollResult::Unmodified; 29309467b48Spatrick } 29409467b48Spatrick 29509467b48Spatrick BasicBlock *LatchBlock = L->getLoopLatch(); 29609467b48Spatrick if (!LatchBlock) { 29709467b48Spatrick LLVM_DEBUG(dbgs() << " Can't unroll; loop exit-block-insertion failed.\n"); 29809467b48Spatrick return LoopUnrollResult::Unmodified; 29909467b48Spatrick } 30009467b48Spatrick 30109467b48Spatrick // Loops with indirectbr cannot be cloned. 
30209467b48Spatrick if (!L->isSafeToClone()) { 30309467b48Spatrick LLVM_DEBUG(dbgs() << " Can't unroll; Loop body cannot be cloned.\n"); 30409467b48Spatrick return LoopUnrollResult::Unmodified; 30509467b48Spatrick } 30609467b48Spatrick 307*097a140dSpatrick // The current loop unroll pass can unroll loops that have 308*097a140dSpatrick // (1) single latch; and 309*097a140dSpatrick // (2a) latch is unconditional; or 310*097a140dSpatrick // (2b) latch is conditional and is an exiting block 31109467b48Spatrick // FIXME: The implementation can be extended to work with more complicated 31209467b48Spatrick // cases, e.g. loops with multiple latches. 31309467b48Spatrick BasicBlock *Header = L->getHeader(); 314*097a140dSpatrick BranchInst *LatchBI = dyn_cast<BranchInst>(LatchBlock->getTerminator()); 31509467b48Spatrick 316*097a140dSpatrick // A conditional branch which exits the loop, which can be optimized to an 317*097a140dSpatrick // unconditional branch in the unrolled loop in some cases. 318*097a140dSpatrick BranchInst *ExitingBI = nullptr; 319*097a140dSpatrick bool LatchIsExiting = L->isLoopExiting(LatchBlock); 320*097a140dSpatrick if (LatchIsExiting) 321*097a140dSpatrick ExitingBI = LatchBI; 322*097a140dSpatrick else if (BasicBlock *ExitingBlock = L->getExitingBlock()) 323*097a140dSpatrick ExitingBI = dyn_cast<BranchInst>(ExitingBlock->getTerminator()); 324*097a140dSpatrick if (!LatchBI || (LatchBI->isConditional() && !LatchIsExiting)) { 32509467b48Spatrick LLVM_DEBUG( 32609467b48Spatrick dbgs() << "Can't unroll; a conditional latch must exit the loop"); 32709467b48Spatrick return LoopUnrollResult::Unmodified; 32809467b48Spatrick } 329*097a140dSpatrick LLVM_DEBUG({ 330*097a140dSpatrick if (ExitingBI) 331*097a140dSpatrick dbgs() << " Exiting Block = " << ExitingBI->getParent()->getName() 332*097a140dSpatrick << "\n"; 333*097a140dSpatrick else 334*097a140dSpatrick dbgs() << " No single exiting block\n"; 335*097a140dSpatrick }); 33609467b48Spatrick 33709467b48Spatrick 
if (Header->hasAddressTaken()) { 33809467b48Spatrick // The loop-rotate pass can be helpful to avoid this in many cases. 33909467b48Spatrick LLVM_DEBUG( 34009467b48Spatrick dbgs() << " Won't unroll loop: address of header block is taken.\n"); 34109467b48Spatrick return LoopUnrollResult::Unmodified; 34209467b48Spatrick } 34309467b48Spatrick 34409467b48Spatrick if (ULO.TripCount != 0) 34509467b48Spatrick LLVM_DEBUG(dbgs() << " Trip Count = " << ULO.TripCount << "\n"); 34609467b48Spatrick if (ULO.TripMultiple != 1) 34709467b48Spatrick LLVM_DEBUG(dbgs() << " Trip Multiple = " << ULO.TripMultiple << "\n"); 34809467b48Spatrick 34909467b48Spatrick // Effectively "DCE" unrolled iterations that are beyond the tripcount 35009467b48Spatrick // and will never be executed. 35109467b48Spatrick if (ULO.TripCount != 0 && ULO.Count > ULO.TripCount) 35209467b48Spatrick ULO.Count = ULO.TripCount; 35309467b48Spatrick 35409467b48Spatrick // Don't enter the unroll code if there is nothing to do. 35509467b48Spatrick if (ULO.TripCount == 0 && ULO.Count < 2 && ULO.PeelCount == 0) { 35609467b48Spatrick LLVM_DEBUG(dbgs() << "Won't unroll; almost nothing to do\n"); 35709467b48Spatrick return LoopUnrollResult::Unmodified; 35809467b48Spatrick } 35909467b48Spatrick 36009467b48Spatrick assert(ULO.Count > 0); 36109467b48Spatrick assert(ULO.TripMultiple > 0); 36209467b48Spatrick assert(ULO.TripCount == 0 || ULO.TripCount % ULO.TripMultiple == 0); 36309467b48Spatrick 36409467b48Spatrick // Are we eliminating the loop control altogether? 36509467b48Spatrick bool CompletelyUnroll = ULO.Count == ULO.TripCount; 36609467b48Spatrick SmallVector<BasicBlock *, 4> ExitBlocks; 36709467b48Spatrick L->getExitBlocks(ExitBlocks); 36809467b48Spatrick std::vector<BasicBlock*> OriginalLoopBlocks = L->getBlocks(); 36909467b48Spatrick 37009467b48Spatrick // Go through all exits of L and see if there are any phi-nodes there. 
We just 37109467b48Spatrick // conservatively assume that they're inserted to preserve LCSSA form, which 37209467b48Spatrick // means that complete unrolling might break this form. We need to either fix 37309467b48Spatrick // it in-place after the transformation, or entirely rebuild LCSSA. TODO: For 37409467b48Spatrick // now we just recompute LCSSA for the outer loop, but it should be possible 37509467b48Spatrick // to fix it in-place. 37609467b48Spatrick bool NeedToFixLCSSA = PreserveLCSSA && CompletelyUnroll && 37709467b48Spatrick any_of(ExitBlocks, [](const BasicBlock *BB) { 37809467b48Spatrick return isa<PHINode>(BB->begin()); 37909467b48Spatrick }); 38009467b48Spatrick 38109467b48Spatrick // We assume a run-time trip count if the compiler cannot 38209467b48Spatrick // figure out the loop trip count and the unroll-runtime 38309467b48Spatrick // flag is specified. 38409467b48Spatrick bool RuntimeTripCount = 38509467b48Spatrick (ULO.TripCount == 0 && ULO.Count > 0 && ULO.AllowRuntime); 38609467b48Spatrick 38709467b48Spatrick assert((!RuntimeTripCount || !ULO.PeelCount) && 38809467b48Spatrick "Did not expect runtime trip-count unrolling " 38909467b48Spatrick "and peeling for the same loop"); 39009467b48Spatrick 39109467b48Spatrick bool Peeled = false; 39209467b48Spatrick if (ULO.PeelCount) { 39309467b48Spatrick Peeled = peelLoop(L, ULO.PeelCount, LI, SE, DT, AC, PreserveLCSSA); 39409467b48Spatrick 39509467b48Spatrick // Successful peeling may result in a change in the loop preheader/trip 39609467b48Spatrick // counts. If we later unroll the loop, we want these to be updated. 39709467b48Spatrick if (Peeled) { 39809467b48Spatrick // According to our guards and profitability checks the only 39909467b48Spatrick // meaningful exit should be latch block. Other exits go to deopt, 40009467b48Spatrick // so we do not worry about them. 
40109467b48Spatrick BasicBlock *ExitingBlock = L->getLoopLatch(); 40209467b48Spatrick assert(ExitingBlock && "Loop without exiting block?"); 40309467b48Spatrick assert(L->isLoopExiting(ExitingBlock) && "Latch is not exiting?"); 40409467b48Spatrick Preheader = L->getLoopPreheader(); 40509467b48Spatrick ULO.TripCount = SE->getSmallConstantTripCount(L, ExitingBlock); 40609467b48Spatrick ULO.TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock); 40709467b48Spatrick } 40809467b48Spatrick } 40909467b48Spatrick 41009467b48Spatrick // Loops containing convergent instructions must have a count that divides 41109467b48Spatrick // their TripMultiple. 41209467b48Spatrick LLVM_DEBUG( 41309467b48Spatrick { 41409467b48Spatrick bool HasConvergent = false; 41509467b48Spatrick for (auto &BB : L->blocks()) 41609467b48Spatrick for (auto &I : *BB) 417*097a140dSpatrick if (auto *CB = dyn_cast<CallBase>(&I)) 418*097a140dSpatrick HasConvergent |= CB->isConvergent(); 41909467b48Spatrick assert((!HasConvergent || ULO.TripMultiple % ULO.Count == 0) && 42009467b48Spatrick "Unroll count must divide trip multiple if loop contains a " 42109467b48Spatrick "convergent operation."); 42209467b48Spatrick }); 42309467b48Spatrick 42409467b48Spatrick bool EpilogProfitability = 42509467b48Spatrick UnrollRuntimeEpilog.getNumOccurrences() ? 
UnrollRuntimeEpilog 42609467b48Spatrick : isEpilogProfitable(L); 42709467b48Spatrick 42809467b48Spatrick if (RuntimeTripCount && ULO.TripMultiple % ULO.Count != 0 && 42909467b48Spatrick !UnrollRuntimeLoopRemainder(L, ULO.Count, ULO.AllowExpensiveTripCount, 43009467b48Spatrick EpilogProfitability, ULO.UnrollRemainder, 431*097a140dSpatrick ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI, 43209467b48Spatrick PreserveLCSSA, RemainderLoop)) { 43309467b48Spatrick if (ULO.Force) 43409467b48Spatrick RuntimeTripCount = false; 43509467b48Spatrick else { 43609467b48Spatrick LLVM_DEBUG(dbgs() << "Won't unroll; remainder loop could not be " 43709467b48Spatrick "generated when assuming runtime trip count\n"); 43809467b48Spatrick return LoopUnrollResult::Unmodified; 43909467b48Spatrick } 44009467b48Spatrick } 44109467b48Spatrick 44209467b48Spatrick // If we know the trip count, we know the multiple... 44309467b48Spatrick unsigned BreakoutTrip = 0; 44409467b48Spatrick if (ULO.TripCount != 0) { 44509467b48Spatrick BreakoutTrip = ULO.TripCount % ULO.Count; 44609467b48Spatrick ULO.TripMultiple = 0; 44709467b48Spatrick } else { 44809467b48Spatrick // Figure out what multiple to use. 44909467b48Spatrick BreakoutTrip = ULO.TripMultiple = 45009467b48Spatrick (unsigned)GreatestCommonDivisor64(ULO.Count, ULO.TripMultiple); 45109467b48Spatrick } 45209467b48Spatrick 45309467b48Spatrick using namespace ore; 45409467b48Spatrick // Report the unrolling decision. 
45509467b48Spatrick if (CompletelyUnroll) { 45609467b48Spatrick LLVM_DEBUG(dbgs() << "COMPLETELY UNROLLING loop %" << Header->getName() 45709467b48Spatrick << " with trip count " << ULO.TripCount << "!\n"); 45809467b48Spatrick if (ORE) 45909467b48Spatrick ORE->emit([&]() { 46009467b48Spatrick return OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(), 46109467b48Spatrick L->getHeader()) 46209467b48Spatrick << "completely unrolled loop with " 46309467b48Spatrick << NV("UnrollCount", ULO.TripCount) << " iterations"; 46409467b48Spatrick }); 46509467b48Spatrick } else if (ULO.PeelCount) { 46609467b48Spatrick LLVM_DEBUG(dbgs() << "PEELING loop %" << Header->getName() 46709467b48Spatrick << " with iteration count " << ULO.PeelCount << "!\n"); 46809467b48Spatrick if (ORE) 46909467b48Spatrick ORE->emit([&]() { 47009467b48Spatrick return OptimizationRemark(DEBUG_TYPE, "Peeled", L->getStartLoc(), 47109467b48Spatrick L->getHeader()) 47209467b48Spatrick << " peeled loop by " << NV("PeelCount", ULO.PeelCount) 47309467b48Spatrick << " iterations"; 47409467b48Spatrick }); 47509467b48Spatrick } else { 47609467b48Spatrick auto DiagBuilder = [&]() { 47709467b48Spatrick OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(), 47809467b48Spatrick L->getHeader()); 47909467b48Spatrick return Diag << "unrolled loop by a factor of " 48009467b48Spatrick << NV("UnrollCount", ULO.Count); 48109467b48Spatrick }; 48209467b48Spatrick 48309467b48Spatrick LLVM_DEBUG(dbgs() << "UNROLLING loop %" << Header->getName() << " by " 48409467b48Spatrick << ULO.Count); 48509467b48Spatrick if (ULO.TripMultiple == 0 || BreakoutTrip != ULO.TripMultiple) { 48609467b48Spatrick LLVM_DEBUG(dbgs() << " with a breakout at trip " << BreakoutTrip); 48709467b48Spatrick if (ORE) 48809467b48Spatrick ORE->emit([&]() { 48909467b48Spatrick return DiagBuilder() << " with a breakout at trip " 49009467b48Spatrick << NV("BreakoutTrip", BreakoutTrip); 49109467b48Spatrick }); 49209467b48Spatrick } 
else if (ULO.TripMultiple != 1) { 49309467b48Spatrick LLVM_DEBUG(dbgs() << " with " << ULO.TripMultiple << " trips per branch"); 49409467b48Spatrick if (ORE) 49509467b48Spatrick ORE->emit([&]() { 49609467b48Spatrick return DiagBuilder() 49709467b48Spatrick << " with " << NV("TripMultiple", ULO.TripMultiple) 49809467b48Spatrick << " trips per branch"; 49909467b48Spatrick }); 50009467b48Spatrick } else if (RuntimeTripCount) { 50109467b48Spatrick LLVM_DEBUG(dbgs() << " with run-time trip count"); 50209467b48Spatrick if (ORE) 50309467b48Spatrick ORE->emit( 50409467b48Spatrick [&]() { return DiagBuilder() << " with run-time trip count"; }); 50509467b48Spatrick } 50609467b48Spatrick LLVM_DEBUG(dbgs() << "!\n"); 50709467b48Spatrick } 50809467b48Spatrick 50909467b48Spatrick // We are going to make changes to this loop. SCEV may be keeping cached info 51009467b48Spatrick // about it, in particular about backedge taken count. The changes we make 51109467b48Spatrick // are guaranteed to invalidate this information for our loop. It is tempting 51209467b48Spatrick // to only invalidate the loop being unrolled, but it is incorrect as long as 51309467b48Spatrick // all exiting branches from all inner loops have impact on the outer loops, 51409467b48Spatrick // and if something changes inside them then any of outer loops may also 51509467b48Spatrick // change. When we forget outermost loop, we also forget all contained loops 51609467b48Spatrick // and this is what we need here. 
51709467b48Spatrick if (SE) { 51809467b48Spatrick if (ULO.ForgetAllSCEV) 51909467b48Spatrick SE->forgetAllLoops(); 52009467b48Spatrick else 52109467b48Spatrick SE->forgetTopmostLoop(L); 52209467b48Spatrick } 52309467b48Spatrick 524*097a140dSpatrick if (!LatchIsExiting) 525*097a140dSpatrick ++NumUnrolledNotLatch; 526*097a140dSpatrick Optional<bool> ContinueOnTrue = None; 52709467b48Spatrick BasicBlock *LoopExit = nullptr; 528*097a140dSpatrick if (ExitingBI) { 529*097a140dSpatrick ContinueOnTrue = L->contains(ExitingBI->getSuccessor(0)); 530*097a140dSpatrick LoopExit = ExitingBI->getSuccessor(*ContinueOnTrue); 53109467b48Spatrick } 53209467b48Spatrick 53309467b48Spatrick // For the first iteration of the loop, we should use the precloned values for 53409467b48Spatrick // PHI nodes. Insert associations now. 53509467b48Spatrick ValueToValueMapTy LastValueMap; 53609467b48Spatrick std::vector<PHINode*> OrigPHINode; 53709467b48Spatrick for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { 53809467b48Spatrick OrigPHINode.push_back(cast<PHINode>(I)); 53909467b48Spatrick } 54009467b48Spatrick 54109467b48Spatrick std::vector<BasicBlock *> Headers; 542*097a140dSpatrick std::vector<BasicBlock *> ExitingBlocks; 543*097a140dSpatrick std::vector<BasicBlock *> ExitingSucc; 54409467b48Spatrick std::vector<BasicBlock *> Latches; 54509467b48Spatrick Headers.push_back(Header); 54609467b48Spatrick Latches.push_back(LatchBlock); 547*097a140dSpatrick if (ExitingBI) { 548*097a140dSpatrick ExitingBlocks.push_back(ExitingBI->getParent()); 549*097a140dSpatrick ExitingSucc.push_back(ExitingBI->getSuccessor(!(*ContinueOnTrue))); 55009467b48Spatrick } 55109467b48Spatrick 55209467b48Spatrick // The current on-the-fly SSA update requires blocks to be processed in 55309467b48Spatrick // reverse postorder so that LastValueMap contains the correct value at each 55409467b48Spatrick // exit. 
55509467b48Spatrick LoopBlocksDFS DFS(L); 55609467b48Spatrick DFS.perform(LI); 55709467b48Spatrick 55809467b48Spatrick // Stash the DFS iterators before adding blocks to the loop. 55909467b48Spatrick LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO(); 56009467b48Spatrick LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO(); 56109467b48Spatrick 56209467b48Spatrick std::vector<BasicBlock*> UnrolledLoopBlocks = L->getBlocks(); 56309467b48Spatrick 56409467b48Spatrick // Loop Unrolling might create new loops. While we do preserve LoopInfo, we 56509467b48Spatrick // might break loop-simplified form for these loops (as they, e.g., would 56609467b48Spatrick // share the same exit blocks). We'll keep track of loops for which we can 56709467b48Spatrick // break this so that later we can re-simplify them. 56809467b48Spatrick SmallSetVector<Loop *, 4> LoopsToSimplify; 56909467b48Spatrick for (Loop *SubLoop : *L) 57009467b48Spatrick LoopsToSimplify.insert(SubLoop); 57109467b48Spatrick 57209467b48Spatrick if (Header->getParent()->isDebugInfoForProfiling()) 57309467b48Spatrick for (BasicBlock *BB : L->getBlocks()) 57409467b48Spatrick for (Instruction &I : *BB) 57509467b48Spatrick if (!isa<DbgInfoIntrinsic>(&I)) 57609467b48Spatrick if (const DILocation *DIL = I.getDebugLoc()) { 57709467b48Spatrick auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(ULO.Count); 57809467b48Spatrick if (NewDIL) 57909467b48Spatrick I.setDebugLoc(NewDIL.getValue()); 58009467b48Spatrick else 58109467b48Spatrick LLVM_DEBUG(dbgs() 58209467b48Spatrick << "Failed to create new discriminator: " 58309467b48Spatrick << DIL->getFilename() << " Line: " << DIL->getLine()); 58409467b48Spatrick } 58509467b48Spatrick 58609467b48Spatrick for (unsigned It = 1; It != ULO.Count; ++It) { 587*097a140dSpatrick SmallVector<BasicBlock *, 8> NewBlocks; 58809467b48Spatrick SmallDenseMap<const Loop *, Loop *, 4> NewLoops; 58909467b48Spatrick NewLoops[L] = L; 59009467b48Spatrick 59109467b48Spatrick for 
(LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { 59209467b48Spatrick ValueToValueMapTy VMap; 59309467b48Spatrick BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It)); 59409467b48Spatrick Header->getParent()->getBasicBlockList().push_back(New); 59509467b48Spatrick 59609467b48Spatrick assert((*BB != Header || LI->getLoopFor(*BB) == L) && 59709467b48Spatrick "Header should not be in a sub-loop"); 59809467b48Spatrick // Tell LI about New. 59909467b48Spatrick const Loop *OldLoop = addClonedBlockToLoopInfo(*BB, New, LI, NewLoops); 60009467b48Spatrick if (OldLoop) 60109467b48Spatrick LoopsToSimplify.insert(NewLoops[OldLoop]); 60209467b48Spatrick 60309467b48Spatrick if (*BB == Header) 60409467b48Spatrick // Loop over all of the PHI nodes in the block, changing them to use 60509467b48Spatrick // the incoming values from the previous block. 60609467b48Spatrick for (PHINode *OrigPHI : OrigPHINode) { 60709467b48Spatrick PHINode *NewPHI = cast<PHINode>(VMap[OrigPHI]); 60809467b48Spatrick Value *InVal = NewPHI->getIncomingValueForBlock(LatchBlock); 60909467b48Spatrick if (Instruction *InValI = dyn_cast<Instruction>(InVal)) 61009467b48Spatrick if (It > 1 && L->contains(InValI)) 61109467b48Spatrick InVal = LastValueMap[InValI]; 61209467b48Spatrick VMap[OrigPHI] = InVal; 61309467b48Spatrick New->getInstList().erase(NewPHI); 61409467b48Spatrick } 61509467b48Spatrick 61609467b48Spatrick // Update our running map of newest clones 61709467b48Spatrick LastValueMap[*BB] = New; 61809467b48Spatrick for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end(); 61909467b48Spatrick VI != VE; ++VI) 62009467b48Spatrick LastValueMap[VI->first] = VI->second; 62109467b48Spatrick 62209467b48Spatrick // Add phi entries for newly created values to all exit blocks. 
62309467b48Spatrick for (BasicBlock *Succ : successors(*BB)) { 62409467b48Spatrick if (L->contains(Succ)) 62509467b48Spatrick continue; 62609467b48Spatrick for (PHINode &PHI : Succ->phis()) { 62709467b48Spatrick Value *Incoming = PHI.getIncomingValueForBlock(*BB); 62809467b48Spatrick ValueToValueMapTy::iterator It = LastValueMap.find(Incoming); 62909467b48Spatrick if (It != LastValueMap.end()) 63009467b48Spatrick Incoming = It->second; 63109467b48Spatrick PHI.addIncoming(Incoming, New); 63209467b48Spatrick } 63309467b48Spatrick } 63409467b48Spatrick // Keep track of new headers and latches as we create them, so that 63509467b48Spatrick // we can insert the proper branches later. 63609467b48Spatrick if (*BB == Header) 63709467b48Spatrick Headers.push_back(New); 63809467b48Spatrick if (*BB == LatchBlock) 63909467b48Spatrick Latches.push_back(New); 64009467b48Spatrick 641*097a140dSpatrick // Keep track of the exiting block and its successor block contained in 642*097a140dSpatrick // the loop for the current iteration. 643*097a140dSpatrick if (ExitingBI) { 644*097a140dSpatrick if (*BB == ExitingBlocks[0]) 645*097a140dSpatrick ExitingBlocks.push_back(New); 646*097a140dSpatrick if (*BB == ExitingSucc[0]) 647*097a140dSpatrick ExitingSucc.push_back(New); 64809467b48Spatrick } 64909467b48Spatrick 65009467b48Spatrick NewBlocks.push_back(New); 65109467b48Spatrick UnrolledLoopBlocks.push_back(New); 65209467b48Spatrick 65309467b48Spatrick // Update DomTree: since we just copy the loop body, and each copy has a 65409467b48Spatrick // dedicated entry block (copy of the header block), this header's copy 65509467b48Spatrick // dominates all copied blocks. That means, dominance relations in the 65609467b48Spatrick // copied body are the same as in the original body. 
65709467b48Spatrick if (DT) { 65809467b48Spatrick if (*BB == Header) 65909467b48Spatrick DT->addNewBlock(New, Latches[It - 1]); 66009467b48Spatrick else { 66109467b48Spatrick auto BBDomNode = DT->getNode(*BB); 66209467b48Spatrick auto BBIDom = BBDomNode->getIDom(); 66309467b48Spatrick BasicBlock *OriginalBBIDom = BBIDom->getBlock(); 66409467b48Spatrick DT->addNewBlock( 66509467b48Spatrick New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)])); 66609467b48Spatrick } 66709467b48Spatrick } 66809467b48Spatrick } 66909467b48Spatrick 67009467b48Spatrick // Remap all instructions in the most recent iteration 671*097a140dSpatrick remapInstructionsInBlocks(NewBlocks, LastValueMap); 67209467b48Spatrick for (BasicBlock *NewBlock : NewBlocks) { 67309467b48Spatrick for (Instruction &I : *NewBlock) { 67409467b48Spatrick if (auto *II = dyn_cast<IntrinsicInst>(&I)) 67509467b48Spatrick if (II->getIntrinsicID() == Intrinsic::assume) 67609467b48Spatrick AC->registerAssumption(II); 67709467b48Spatrick } 67809467b48Spatrick } 67909467b48Spatrick } 68009467b48Spatrick 68109467b48Spatrick // Loop over the PHI nodes in the original block, setting incoming values. 68209467b48Spatrick for (PHINode *PN : OrigPHINode) { 68309467b48Spatrick if (CompletelyUnroll) { 68409467b48Spatrick PN->replaceAllUsesWith(PN->getIncomingValueForBlock(Preheader)); 68509467b48Spatrick Header->getInstList().erase(PN); 68609467b48Spatrick } else if (ULO.Count > 1) { 68709467b48Spatrick Value *InVal = PN->removeIncomingValue(LatchBlock, false); 68809467b48Spatrick // If this value was defined in the loop, take the value defined by the 68909467b48Spatrick // last iteration of the loop. 
69009467b48Spatrick if (Instruction *InValI = dyn_cast<Instruction>(InVal)) { 69109467b48Spatrick if (L->contains(InValI)) 69209467b48Spatrick InVal = LastValueMap[InVal]; 69309467b48Spatrick } 69409467b48Spatrick assert(Latches.back() == LastValueMap[LatchBlock] && "bad last latch"); 69509467b48Spatrick PN->addIncoming(InVal, Latches.back()); 69609467b48Spatrick } 69709467b48Spatrick } 69809467b48Spatrick 699*097a140dSpatrick auto setDest = [](BasicBlock *Src, BasicBlock *Dest, BasicBlock *BlockInLoop, 700*097a140dSpatrick bool NeedConditional, Optional<bool> ContinueOnTrue, 701*097a140dSpatrick bool IsDestLoopExit) { 70209467b48Spatrick auto *Term = cast<BranchInst>(Src->getTerminator()); 70309467b48Spatrick if (NeedConditional) { 70409467b48Spatrick // Update the conditional branch's successor for the following 70509467b48Spatrick // iteration. 706*097a140dSpatrick assert(ContinueOnTrue.hasValue() && 707*097a140dSpatrick "Expecting valid ContinueOnTrue when NeedConditional is true"); 708*097a140dSpatrick Term->setSuccessor(!(*ContinueOnTrue), Dest); 70909467b48Spatrick } else { 71009467b48Spatrick // Remove phi operands at this loop exit 711*097a140dSpatrick if (!IsDestLoopExit) { 71209467b48Spatrick BasicBlock *BB = Src; 71309467b48Spatrick for (BasicBlock *Succ : successors(BB)) { 71409467b48Spatrick // Preserve the incoming value from BB if we are jumping to the block 71509467b48Spatrick // in the current loop. 71609467b48Spatrick if (Succ == BlockInLoop) 71709467b48Spatrick continue; 71809467b48Spatrick for (PHINode &Phi : Succ->phis()) 71909467b48Spatrick Phi.removeIncomingValue(BB, false); 72009467b48Spatrick } 72109467b48Spatrick } 72209467b48Spatrick // Replace the conditional branch with an unconditional one. 
72309467b48Spatrick BranchInst::Create(Dest, Term); 72409467b48Spatrick Term->eraseFromParent(); 72509467b48Spatrick } 72609467b48Spatrick }; 72709467b48Spatrick 728*097a140dSpatrick // Connect latches of the unrolled iterations to the headers of the next 729*097a140dSpatrick // iteration. If the latch is also the exiting block, the conditional branch 730*097a140dSpatrick // may have to be preserved. 73109467b48Spatrick for (unsigned i = 0, e = Latches.size(); i != e; ++i) { 73209467b48Spatrick // The branch destination. 73309467b48Spatrick unsigned j = (i + 1) % e; 73409467b48Spatrick BasicBlock *Dest = Headers[j]; 735*097a140dSpatrick bool NeedConditional = LatchIsExiting; 73609467b48Spatrick 737*097a140dSpatrick if (LatchIsExiting) { 738*097a140dSpatrick if (RuntimeTripCount && j != 0) 73909467b48Spatrick NeedConditional = false; 74009467b48Spatrick 74109467b48Spatrick // For a complete unroll, make the last iteration end with a branch 74209467b48Spatrick // to the exit block. 74309467b48Spatrick if (CompletelyUnroll) { 74409467b48Spatrick if (j == 0) 74509467b48Spatrick Dest = LoopExit; 746*097a140dSpatrick // If using trip count upper bound to completely unroll, we need to 747*097a140dSpatrick // keep the conditional branch except the last one because the loop 748*097a140dSpatrick // may exit after any iteration. 74909467b48Spatrick assert(NeedConditional && 75009467b48Spatrick "NeedCondition cannot be modified by both complete " 75109467b48Spatrick "unrolling and runtime unrolling"); 75209467b48Spatrick NeedConditional = 75309467b48Spatrick (ULO.PreserveCondBr && j && !(ULO.PreserveOnlyFirst && i != 0)); 75409467b48Spatrick } else if (j != BreakoutTrip && 75509467b48Spatrick (ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) { 75609467b48Spatrick // If we know the trip count or a multiple of it, we can safely use an 75709467b48Spatrick // unconditional branch for some iterations. 
75809467b48Spatrick NeedConditional = false; 75909467b48Spatrick } 76009467b48Spatrick } 761*097a140dSpatrick 762*097a140dSpatrick setDest(Latches[i], Dest, Headers[i], NeedConditional, ContinueOnTrue, 763*097a140dSpatrick Dest == LoopExit); 764*097a140dSpatrick } 765*097a140dSpatrick 766*097a140dSpatrick if (!LatchIsExiting) { 767*097a140dSpatrick // If the latch is not exiting, we may be able to simplify the conditional 768*097a140dSpatrick // branches in the unrolled exiting blocks. 769*097a140dSpatrick for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { 77009467b48Spatrick // The branch destination. 77109467b48Spatrick unsigned j = (i + 1) % e; 77209467b48Spatrick bool NeedConditional = true; 77309467b48Spatrick 77409467b48Spatrick if (RuntimeTripCount && j != 0) 77509467b48Spatrick NeedConditional = false; 77609467b48Spatrick 77709467b48Spatrick if (CompletelyUnroll) 77809467b48Spatrick // We cannot drop the conditional branch for the last condition, as we 77909467b48Spatrick // may have to execute the loop body depending on the condition. 78009467b48Spatrick NeedConditional = j == 0 || ULO.PreserveCondBr; 78109467b48Spatrick else if (j != BreakoutTrip && 78209467b48Spatrick (ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) 78309467b48Spatrick // If we know the trip count or a multiple of it, we can safely use an 78409467b48Spatrick // unconditional branch for some iterations. 78509467b48Spatrick NeedConditional = false; 78609467b48Spatrick 787*097a140dSpatrick // Conditional branches from non-latch exiting block have successors 788*097a140dSpatrick // either in the same loop iteration or outside the loop. The branches are 789*097a140dSpatrick // already correct. 
790*097a140dSpatrick if (NeedConditional) 791*097a140dSpatrick continue; 792*097a140dSpatrick setDest(ExitingBlocks[i], ExitingSucc[i], ExitingSucc[i], NeedConditional, 793*097a140dSpatrick None, false); 79409467b48Spatrick } 79509467b48Spatrick 79609467b48Spatrick // When completely unrolling, the last latch becomes unreachable. 797*097a140dSpatrick if (CompletelyUnroll) { 798*097a140dSpatrick BranchInst *Term = cast<BranchInst>(Latches.back()->getTerminator()); 79909467b48Spatrick new UnreachableInst(Term->getContext(), Term); 80009467b48Spatrick Term->eraseFromParent(); 80109467b48Spatrick } 80209467b48Spatrick } 80309467b48Spatrick 80409467b48Spatrick // Update dominators of blocks we might reach through exits. 80509467b48Spatrick // Immediate dominator of such block might change, because we add more 80609467b48Spatrick // routes which can lead to the exit: we can now reach it from the copied 80709467b48Spatrick // iterations too. 80809467b48Spatrick if (DT && ULO.Count > 1) { 80909467b48Spatrick for (auto *BB : OriginalLoopBlocks) { 81009467b48Spatrick auto *BBDomNode = DT->getNode(BB); 81109467b48Spatrick SmallVector<BasicBlock *, 16> ChildrenToUpdate; 812*097a140dSpatrick for (auto *ChildDomNode : BBDomNode->children()) { 81309467b48Spatrick auto *ChildBB = ChildDomNode->getBlock(); 81409467b48Spatrick if (!L->contains(ChildBB)) 81509467b48Spatrick ChildrenToUpdate.push_back(ChildBB); 81609467b48Spatrick } 81709467b48Spatrick BasicBlock *NewIDom; 818*097a140dSpatrick if (ExitingBI && BB == ExitingBlocks[0]) { 81909467b48Spatrick // The latch is special because we emit unconditional branches in 82009467b48Spatrick // some cases where the original loop contained a conditional branch. 82109467b48Spatrick // Since the latch is always at the bottom of the loop, if the latch 82209467b48Spatrick // dominated an exit before unrolling, the new dominator of that exit 82309467b48Spatrick // must also be a latch. 
Specifically, the dominator is the first 82409467b48Spatrick // latch which ends in a conditional branch, or the last latch if 82509467b48Spatrick // there is no such latch. 826*097a140dSpatrick // For loops exiting from non latch exiting block, we limit the 827*097a140dSpatrick // branch simplification to single exiting block loops. 828*097a140dSpatrick NewIDom = ExitingBlocks.back(); 829*097a140dSpatrick for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { 830*097a140dSpatrick Instruction *Term = ExitingBlocks[i]->getTerminator(); 83109467b48Spatrick if (isa<BranchInst>(Term) && cast<BranchInst>(Term)->isConditional()) { 832*097a140dSpatrick NewIDom = 833*097a140dSpatrick DT->findNearestCommonDominator(ExitingBlocks[i], Latches[i]); 83409467b48Spatrick break; 83509467b48Spatrick } 83609467b48Spatrick } 83709467b48Spatrick } else { 83809467b48Spatrick // The new idom of the block will be the nearest common dominator 83909467b48Spatrick // of all copies of the previous idom. This is equivalent to the 84009467b48Spatrick // nearest common dominator of the previous idom and the first latch, 84109467b48Spatrick // which dominates all copies of the previous idom. 84209467b48Spatrick NewIDom = DT->findNearestCommonDominator(BB, LatchBlock); 84309467b48Spatrick } 84409467b48Spatrick for (auto *ChildBB : ChildrenToUpdate) 84509467b48Spatrick DT->changeImmediateDominator(ChildBB, NewIDom); 84609467b48Spatrick } 84709467b48Spatrick } 84809467b48Spatrick 84909467b48Spatrick assert(!DT || !UnrollVerifyDomtree || 85009467b48Spatrick DT->verify(DominatorTree::VerificationLevel::Fast)); 85109467b48Spatrick 85209467b48Spatrick DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 85309467b48Spatrick // Merge adjacent basic blocks, if possible. 
85409467b48Spatrick for (BasicBlock *Latch : Latches) { 85509467b48Spatrick BranchInst *Term = dyn_cast<BranchInst>(Latch->getTerminator()); 85609467b48Spatrick assert((Term || 85709467b48Spatrick (CompletelyUnroll && !LatchIsExiting && Latch == Latches.back())) && 85809467b48Spatrick "Need a branch as terminator, except when fully unrolling with " 85909467b48Spatrick "unconditional latch"); 86009467b48Spatrick if (Term && Term->isUnconditional()) { 86109467b48Spatrick BasicBlock *Dest = Term->getSuccessor(0); 86209467b48Spatrick BasicBlock *Fold = Dest->getUniquePredecessor(); 86309467b48Spatrick if (MergeBlockIntoPredecessor(Dest, &DTU, LI)) { 86409467b48Spatrick // Dest has been folded into Fold. Update our worklists accordingly. 86509467b48Spatrick std::replace(Latches.begin(), Latches.end(), Dest, Fold); 86609467b48Spatrick UnrolledLoopBlocks.erase(std::remove(UnrolledLoopBlocks.begin(), 86709467b48Spatrick UnrolledLoopBlocks.end(), Dest), 86809467b48Spatrick UnrolledLoopBlocks.end()); 86909467b48Spatrick } 87009467b48Spatrick } 87109467b48Spatrick } 87209467b48Spatrick // Apply updates to the DomTree. 87309467b48Spatrick DT = &DTU.getDomTree(); 87409467b48Spatrick 87509467b48Spatrick // At this point, the code is well formed. We now simplify the unrolled loop, 87609467b48Spatrick // doing constant propagation and dead code elimination as we go. 87709467b48Spatrick simplifyLoopAfterUnroll(L, !CompletelyUnroll && (ULO.Count > 1 || Peeled), LI, 878*097a140dSpatrick SE, DT, AC, TTI); 87909467b48Spatrick 88009467b48Spatrick NumCompletelyUnrolled += CompletelyUnroll; 88109467b48Spatrick ++NumUnrolled; 88209467b48Spatrick 88309467b48Spatrick Loop *OuterL = L->getParentLoop(); 88409467b48Spatrick // Update LoopInfo if the loop is completely removed. 88509467b48Spatrick if (CompletelyUnroll) 88609467b48Spatrick LI->erase(L); 88709467b48Spatrick 88809467b48Spatrick // After complete unrolling most of the blocks should be contained in OuterL. 
88909467b48Spatrick // However, some of them might happen to be out of OuterL (e.g. if they 89009467b48Spatrick // precede a loop exit). In this case we might need to insert PHI nodes in 89109467b48Spatrick // order to preserve LCSSA form. 89209467b48Spatrick // We don't need to check this if we already know that we need to fix LCSSA 89309467b48Spatrick // form. 89409467b48Spatrick // TODO: For now we just recompute LCSSA for the outer loop in this case, but 89509467b48Spatrick // it should be possible to fix it in-place. 89609467b48Spatrick if (PreserveLCSSA && OuterL && CompletelyUnroll && !NeedToFixLCSSA) 89709467b48Spatrick NeedToFixLCSSA |= ::needToInsertPhisForLCSSA(OuterL, UnrolledLoopBlocks, LI); 89809467b48Spatrick 89909467b48Spatrick // If we have a pass and a DominatorTree we should re-simplify impacted loops 90009467b48Spatrick // to ensure subsequent analyses can rely on this form. We want to simplify 90109467b48Spatrick // at least one layer outside of the loop that was unrolled so that any 90209467b48Spatrick // changes to the parent loop exposed by the unrolling are considered. 90309467b48Spatrick if (DT) { 90409467b48Spatrick if (OuterL) { 90509467b48Spatrick // OuterL includes all loops for which we can break loop-simplify, so 90609467b48Spatrick // it's sufficient to simplify only it (it'll recursively simplify inner 90709467b48Spatrick // loops too). 90809467b48Spatrick if (NeedToFixLCSSA) { 90909467b48Spatrick // LCSSA must be performed on the outermost affected loop. The unrolled 91009467b48Spatrick // loop's last loop latch is guaranteed to be in the outermost loop 91109467b48Spatrick // after LoopInfo's been updated by LoopInfo::erase. 
91209467b48Spatrick Loop *LatchLoop = LI->getLoopFor(Latches.back()); 91309467b48Spatrick Loop *FixLCSSALoop = OuterL; 91409467b48Spatrick if (!FixLCSSALoop->contains(LatchLoop)) 91509467b48Spatrick while (FixLCSSALoop->getParentLoop() != LatchLoop) 91609467b48Spatrick FixLCSSALoop = FixLCSSALoop->getParentLoop(); 91709467b48Spatrick 91809467b48Spatrick formLCSSARecursively(*FixLCSSALoop, *DT, LI, SE); 91909467b48Spatrick } else if (PreserveLCSSA) { 92009467b48Spatrick assert(OuterL->isLCSSAForm(*DT) && 92109467b48Spatrick "Loops should be in LCSSA form after loop-unroll."); 92209467b48Spatrick } 92309467b48Spatrick 92409467b48Spatrick // TODO: That potentially might be compile-time expensive. We should try 92509467b48Spatrick // to fix the loop-simplified form incrementally. 92609467b48Spatrick simplifyLoop(OuterL, DT, LI, SE, AC, nullptr, PreserveLCSSA); 92709467b48Spatrick } else { 92809467b48Spatrick // Simplify loops for which we might've broken loop-simplify form. 92909467b48Spatrick for (Loop *SubLoop : LoopsToSimplify) 93009467b48Spatrick simplifyLoop(SubLoop, DT, LI, SE, AC, nullptr, PreserveLCSSA); 93109467b48Spatrick } 93209467b48Spatrick } 93309467b48Spatrick 93409467b48Spatrick return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled 93509467b48Spatrick : LoopUnrollResult::PartiallyUnrolled; 93609467b48Spatrick } 93709467b48Spatrick
/// Given an llvm.loop loop id metadata node, returns the loop hint metadata
/// node with the given name (for example, "llvm.loop.unroll.count"). If no
/// such metadata node exists, then nullptr is returned.
///
/// \param LoopID The loop id node. It must have at least one operand and its
///               first operand must be the node itself (both are asserted).
/// \param Name   The hint name to search for. It is compared against the
///               string held in the first operand of each candidate node.
/// \returns The first operand of \p LoopID after the self-reference that is an
///          MDNode whose own first operand is an MDString equal to \p Name, or
///          nullptr if there is none.
MDNode *llvm::GetUnrollMetadata(MDNode *LoopID, StringRef Name) {
  // First operand should refer to the loop id itself.
  assert(LoopID->getNumOperands() > 0 && "requires at least one operand");
  assert(LoopID->getOperand(0) == LoopID && "invalid loop id");

  // Operand 0 is the self-reference, so the scan starts at operand 1.
  for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) {
    // Skip operands that are null or not metadata nodes, and nodes that have
    // no operands at all: reading operand 0 of an empty node is out of range.
    MDNode *MD = dyn_cast_or_null<MDNode>(LoopID->getOperand(i));
    if (!MD || MD->getNumOperands() == 0)
      continue;

    // Only nodes whose first operand is a string can carry a hint name.
    MDString *S = dyn_cast_or_null<MDString>(MD->getOperand(0));
    if (!S)
      continue;

    if (Name.equals(S->getString()))
      return MD;
  }
  return nullptr;
}