xref: /llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp (revision d53b4bee0ccd408cfe6e592540858046244e74ce)
1 //===- LoopFlatten.cpp - Loop flattening pass------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass flattens pairs of nested loops into a single loop.
10 //
11 // The intention is to optimise loop nests like this, which together access an
12 // array linearly:
13 //   for (int i = 0; i < N; ++i)
14 //     for (int j = 0; j < M; ++j)
15 //       f(A[i*M+j]);
16 // into one loop:
17 //   for (int i = 0; i < (N*M); ++i)
18 //     f(A[i]);
19 //
20 // It can also flatten loops where the induction variables are not used in the
21 // loop. This is only worth doing if the induction variables are only used in an
22 // expression like i*M+j. If they had any other uses, we would have to insert a
23 // div/mod to reconstruct the original values, so this wouldn't be profitable.
24 //
25 // We also need to prove that N*M will not overflow.
26 //
27 //===----------------------------------------------------------------------===//
28 
29 #include "llvm/Transforms/Scalar/LoopFlatten.h"
30 #include "llvm/Analysis/AssumptionCache.h"
31 #include "llvm/Analysis/LoopInfo.h"
32 #include "llvm/Analysis/LoopPass.h"
33 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
34 #include "llvm/Analysis/ScalarEvolution.h"
35 #include "llvm/Analysis/TargetTransformInfo.h"
36 #include "llvm/Analysis/ValueTracking.h"
37 #include "llvm/IR/Dominators.h"
38 #include "llvm/IR/Function.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/PatternMatch.h"
41 #include "llvm/IR/Verifier.h"
42 #include "llvm/InitializePasses.h"
43 #include "llvm/Pass.h"
44 #include "llvm/Support/Debug.h"
45 #include "llvm/Support/raw_ostream.h"
46 #include "llvm/Transforms/Scalar.h"
47 #include "llvm/Transforms/Utils/LoopUtils.h"
48 
49 #define DEBUG_TYPE "loop-flatten"
50 
51 using namespace llvm;
52 using namespace llvm::PatternMatch;
53 
54 static cl::opt<unsigned> RepeatedInstructionThreshold(
55     "loop-flatten-cost-threshold", cl::Hidden, cl::init(2),
56     cl::desc("Limit on the cost of instructions that can be repeated due to "
57              "loop flattening"));
58 
59 static cl::opt<bool>
60     AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden,
61                      cl::init(false),
62                      cl::desc("Assume that the product of the two iteration "
63                               "limits will never overflow"));
64 
65 // Finds the induction variable, increment and limit for a simple loop that we
66 // can flatten.
67 static bool findLoopComponents(
68     Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions,
69     PHINode *&InductionPHI, Value *&Limit, BinaryOperator *&Increment,
70     BranchInst *&BackBranch, ScalarEvolution *SE) {
71   LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n");
72 
73   if (!L->isLoopSimplifyForm()) {
74     LLVM_DEBUG(dbgs() << "Loop is not in normal form\n");
75     return false;
76   }
77 
78   // There must be exactly one exiting block, and it must be the same at the
79   // latch.
80   BasicBlock *Latch = L->getLoopLatch();
81   if (L->getExitingBlock() != Latch) {
82     LLVM_DEBUG(dbgs() << "Exiting and latch block are different\n");
83     return false;
84   }
85   // Latch block must end in a conditional branch.
86   BackBranch = dyn_cast<BranchInst>(Latch->getTerminator());
87   if (!BackBranch || !BackBranch->isConditional()) {
88     LLVM_DEBUG(dbgs() << "Could not find back-branch\n");
89     return false;
90   }
91   IterationInstructions.insert(BackBranch);
92   LLVM_DEBUG(dbgs() << "Found back branch: "; BackBranch->dump());
93   bool ContinueOnTrue = L->contains(BackBranch->getSuccessor(0));
94 
95   // Find the induction PHI. If there is no induction PHI, we can't do the
96   // transformation. TODO: could other variables trigger this? Do we have to
97   // search for the best one?
98   InductionPHI = nullptr;
99   for (PHINode &PHI : L->getHeader()->phis()) {
100     InductionDescriptor ID;
101     if (InductionDescriptor::isInductionPHI(&PHI, L, SE, ID)) {
102       InductionPHI = &PHI;
103       LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump());
104       break;
105     }
106   }
107   if (!InductionPHI) {
108     LLVM_DEBUG(dbgs() << "Could not find induction PHI\n");
109     return false;
110   }
111 
112   auto IsValidPredicate = [&](ICmpInst::Predicate Pred) {
113     if (ContinueOnTrue)
114       return Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT;
115     else
116       return Pred == CmpInst::ICMP_EQ;
117   };
118 
119   // Find Compare and make sure it is valid
120   ICmpInst *Compare = dyn_cast<ICmpInst>(BackBranch->getCondition());
121   if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) ||
122       Compare->hasNUsesOrMore(2)) {
123     LLVM_DEBUG(dbgs() << "Could not find valid comparison\n");
124     return false;
125   }
126   IterationInstructions.insert(Compare);
127   LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump());
128 
129   // Find increment and limit from the compare
130   Increment = nullptr;
131   if (match(Compare->getOperand(0),
132             m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
133     Increment = dyn_cast<BinaryOperator>(Compare->getOperand(0));
134     Limit = Compare->getOperand(1);
135   } else if (Compare->getUnsignedPredicate() == CmpInst::ICMP_NE &&
136              match(Compare->getOperand(1),
137                    m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
138     Increment = dyn_cast<BinaryOperator>(Compare->getOperand(1));
139     Limit = Compare->getOperand(0);
140   }
141   if (!Increment || Increment->hasNUsesOrMore(3)) {
142     LLVM_DEBUG(dbgs() << "Cound not find valid increment\n");
143     return false;
144   }
145   IterationInstructions.insert(Increment);
146   LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump());
147   LLVM_DEBUG(dbgs() << "Found limit: "; Limit->dump());
148 
149   assert(InductionPHI->getNumIncomingValues() == 2);
150   assert(InductionPHI->getIncomingValueForBlock(Latch) == Increment &&
151          "PHI value is not increment inst");
152 
153   auto *CI = dyn_cast<ConstantInt>(
154       InductionPHI->getIncomingValueForBlock(L->getLoopPreheader()));
155   if (!CI || !CI->isZero()) {
156     LLVM_DEBUG(dbgs() << "PHI value is not zero: "; CI->dump());
157     return false;
158   }
159 
160   LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
161   return true;
162 }
163 
164 static bool checkPHIs(Loop *OuterLoop, Loop *InnerLoop,
165                       SmallPtrSetImpl<PHINode *> &InnerPHIsToTransform,
166                       PHINode *InnerInductionPHI, PHINode *OuterInductionPHI,
167                       TargetTransformInfo *TTI) {
168   // All PHIs in the inner and outer headers must either be:
169   // - The induction PHI, which we are going to rewrite as one induction in
170   //   the new loop. This is already checked by findLoopComponents.
171   // - An outer header PHI with all incoming values from outside the loop.
172   //   LoopSimplify guarantees we have a pre-header, so we don't need to
173   //   worry about that here.
174   // - Pairs of PHIs in the inner and outer headers, which implement a
175   //   loop-carried dependency that will still be valid in the new loop. To
176   //   be valid, this variable must be modified only in the inner loop.
177 
178   // The set of PHI nodes in the outer loop header that we know will still be
179   // valid after the transformation. These will not need to be modified (with
180   // the exception of the induction variable), but we do need to check that
181   // there are no unsafe PHI nodes.
182   SmallPtrSet<PHINode *, 4> SafeOuterPHIs;
183   SafeOuterPHIs.insert(OuterInductionPHI);
184 
185   // Check that all PHI nodes in the inner loop header match one of the valid
186   // patterns.
187   for (PHINode &InnerPHI : InnerLoop->getHeader()->phis()) {
188     // The induction PHIs break these rules, and that's OK because we treat
189     // them specially when doing the transformation.
190     if (&InnerPHI == InnerInductionPHI)
191       continue;
192 
193     // Each inner loop PHI node must have two incoming values/blocks - one
194     // from the pre-header, and one from the latch.
195     assert(InnerPHI.getNumIncomingValues() == 2);
196     Value *PreHeaderValue =
197         InnerPHI.getIncomingValueForBlock(InnerLoop->getLoopPreheader());
198     Value *LatchValue =
199         InnerPHI.getIncomingValueForBlock(InnerLoop->getLoopLatch());
200 
201     // The incoming value from the outer loop must be the PHI node in the
202     // outer loop header, with no modifications made in the top of the outer
203     // loop.
204     PHINode *OuterPHI = dyn_cast<PHINode>(PreHeaderValue);
205     if (!OuterPHI || OuterPHI->getParent() != OuterLoop->getHeader()) {
206       LLVM_DEBUG(dbgs() << "value modified in top of outer loop\n");
207       return false;
208     }
209 
210     // The other incoming value must come from the inner loop, without any
211     // modifications in the tail end of the outer loop. We are in LCSSA form,
212     // so this will actually be a PHI in the inner loop's exit block, which
213     // only uses values from inside the inner loop.
214     PHINode *LCSSAPHI = dyn_cast<PHINode>(
215         OuterPHI->getIncomingValueForBlock(OuterLoop->getLoopLatch()));
216     if (!LCSSAPHI) {
217       LLVM_DEBUG(dbgs() << "could not find LCSSA PHI\n");
218       return false;
219     }
220 
221     // The value used by the LCSSA PHI must be the same one that the inner
222     // loop's PHI uses.
223     if (LCSSAPHI->hasConstantValue() != LatchValue) {
224       LLVM_DEBUG(
225           dbgs() << "LCSSA PHI incoming value does not match latch value\n");
226       return false;
227     }
228 
229     LLVM_DEBUG(dbgs() << "PHI pair is safe:\n");
230     LLVM_DEBUG(dbgs() << "  Inner: "; InnerPHI.dump());
231     LLVM_DEBUG(dbgs() << "  Outer: "; OuterPHI->dump());
232     SafeOuterPHIs.insert(OuterPHI);
233     InnerPHIsToTransform.insert(&InnerPHI);
234   }
235 
236   for (PHINode &OuterPHI : OuterLoop->getHeader()->phis()) {
237     if (!SafeOuterPHIs.count(&OuterPHI)) {
238       LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: "; OuterPHI.dump());
239       return false;
240     }
241   }
242 
243   return true;
244 }
245 
246 static bool
247 checkOuterLoopInsts(Loop *OuterLoop, Loop *InnerLoop,
248                     SmallPtrSetImpl<Instruction *> &IterationInstructions,
249                     Value *InnerLimit, PHINode *OuterPHI,
250                     TargetTransformInfo *TTI) {
251   // Check for instructions in the outer but not inner loop. If any of these
252   // have side-effects then this transformation is not legal, and if there is
253   // a significant amount of code here which can't be optimised out that it's
254   // not profitable (as these instructions would get executed for each
255   // iteration of the inner loop).
256   unsigned RepeatedInstrCost = 0;
257   for (auto *B : OuterLoop->getBlocks()) {
258     if (InnerLoop->contains(B))
259       continue;
260 
261     for (auto &I : *B) {
262       if (!isa<PHINode>(&I) && !I.isTerminator() &&
263           !isSafeToSpeculativelyExecute(&I)) {
264         LLVM_DEBUG(dbgs() << "Cannot flatten because instruction may have "
265                              "side effects: ";
266                    I.dump());
267         return false;
268       }
269       // The execution count of the outer loop's iteration instructions
270       // (increment, compare and branch) will be increased, but the
271       // equivalent instructions will be removed from the inner loop, so
272       // they make a net difference of zero.
273       if (IterationInstructions.count(&I))
274         continue;
275       // The uncoditional branch to the inner loop's header will turn into
276       // a fall-through, so adds no cost.
277       BranchInst *Br = dyn_cast<BranchInst>(&I);
278       if (Br && Br->isUnconditional() &&
279           Br->getSuccessor(0) == InnerLoop->getHeader())
280         continue;
281       // Multiplies of the outer iteration variable and inner iteration
282       // count will be optimised out.
283       if (match(&I, m_c_Mul(m_Specific(OuterPHI), m_Specific(InnerLimit))))
284         continue;
285       int Cost = TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
286       LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump());
287       RepeatedInstrCost += Cost;
288     }
289   }
290 
291   LLVM_DEBUG(dbgs() << "Cost of instructions that will be repeated: "
292                     << RepeatedInstrCost << "\n");
293   // Bail out if flattening the loops would cause instructions in the outer
294   // loop but not in the inner loop to be executed extra times.
295   if (RepeatedInstrCost > RepeatedInstructionThreshold)
296     return false;
297 
298   return true;
299 }
300 
301 static bool checkIVUsers(PHINode *InnerPHI, PHINode *OuterPHI,
302                          BinaryOperator *InnerIncrement,
303                          BinaryOperator *OuterIncrement, Value *InnerLimit,
304                          SmallPtrSetImpl<Value *> &LinearIVUses) {
305   // We require all uses of both induction variables to match this pattern:
306   //
307   //   (OuterPHI * InnerLimit) + InnerPHI
308   //
309   // Any uses of the induction variables not matching that pattern would
310   // require a div/mod to reconstruct in the flattened loop, so the
311   // transformation wouldn't be profitable.
312 
313   // Check that all uses of the inner loop's induction variable match the
314   // expected pattern, recording the uses of the outer IV.
315   SmallPtrSet<Value *, 4> ValidOuterPHIUses;
316   for (User *U : InnerPHI->users()) {
317     if (U == InnerIncrement)
318       continue;
319 
320     LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
321 
322     Value *MatchedMul, *MatchedItCount;
323     if (match(U, m_c_Add(m_Specific(InnerPHI), m_Value(MatchedMul))) &&
324         match(MatchedMul,
325               m_c_Mul(m_Specific(OuterPHI), m_Value(MatchedItCount))) &&
326         MatchedItCount == InnerLimit) {
327       LLVM_DEBUG(dbgs() << "Use is optimisable\n");
328       ValidOuterPHIUses.insert(MatchedMul);
329       LinearIVUses.insert(U);
330     } else {
331       LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
332       return false;
333     }
334   }
335 
336   // Check that there are no uses of the outer IV other than the ones found
337   // as part of the pattern above.
338   for (User *U : OuterPHI->users()) {
339     if (U == OuterIncrement)
340       continue;
341 
342     LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
343 
344     if (!ValidOuterPHIUses.count(U)) {
345       LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
346       return false;
347     } else {
348       LLVM_DEBUG(dbgs() << "Use is optimisable\n");
349     }
350   }
351 
352   LLVM_DEBUG(dbgs() << "Found " << LinearIVUses.size()
353                     << " value(s) that can be replaced:\n";
354              for (Value *V : LinearIVUses) {
355                dbgs() << "  ";
356                V->dump();
357              });
358 
359   return true;
360 }
361 
362 // Return an OverflowResult dependant on if overflow of the multiplication of
363 // InnerLimit and OuterLimit can be assumed not to happen.
364 static OverflowResult checkOverflow(Loop *OuterLoop, Value *InnerLimit,
365                                     Value *OuterLimit,
366                                     SmallPtrSetImpl<Value *> &LinearIVUses,
367                                     DominatorTree *DT, AssumptionCache *AC) {
368   Function *F = OuterLoop->getHeader()->getParent();
369   const DataLayout &DL = F->getParent()->getDataLayout();
370 
371   // For debugging/testing.
372   if (AssumeNoOverflow)
373     return OverflowResult::NeverOverflows;
374 
375   // Check if the multiply could not overflow due to known ranges of the
376   // input values.
377   OverflowResult OR = computeOverflowForUnsignedMul(
378       InnerLimit, OuterLimit, DL, AC,
379       OuterLoop->getLoopPreheader()->getTerminator(), DT);
380   if (OR != OverflowResult::MayOverflow)
381     return OR;
382 
383   for (Value *V : LinearIVUses) {
384     for (Value *U : V->users()) {
385       if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
386         // The IV is used as the operand of a GEP, and the IV is at least as
387         // wide as the address space of the GEP. In this case, the GEP would
388         // wrap around the address space before the IV increment wraps, which
389         // would be UB.
390         if (GEP->isInBounds() &&
391             V->getType()->getIntegerBitWidth() >=
392                 DL.getPointerTypeSizeInBits(GEP->getType())) {
393           LLVM_DEBUG(
394               dbgs() << "use of linear IV would be UB if overflow occurred: ";
395               GEP->dump());
396           return OverflowResult::NeverOverflows;
397         }
398       }
399     }
400   }
401 
402   return OverflowResult::MayOverflow;
403 }
404 
405 static bool FlattenLoopPair(Loop *OuterLoop, Loop *InnerLoop, DominatorTree *DT,
406                             LoopInfo *LI, ScalarEvolution *SE,
407                             AssumptionCache *AC, TargetTransformInfo *TTI,
408                             std::function<void(Loop *)> markLoopAsDeleted) {
409   Function *F = OuterLoop->getHeader()->getParent();
410 
411   LLVM_DEBUG(dbgs() << "Loop flattening running on outer loop "
412                     << OuterLoop->getHeader()->getName() << " and inner loop "
413                     << InnerLoop->getHeader()->getName() << " in "
414                     << F->getName() << "\n");
415 
416   SmallPtrSet<Instruction *, 8> IterationInstructions;
417 
418   PHINode *InnerInductionPHI, *OuterInductionPHI;
419   Value *InnerLimit, *OuterLimit;
420   BinaryOperator *InnerIncrement, *OuterIncrement;
421   BranchInst *InnerBranch, *OuterBranch;
422 
423   if (!findLoopComponents(InnerLoop, IterationInstructions, InnerInductionPHI,
424                           InnerLimit, InnerIncrement, InnerBranch, SE))
425     return false;
426   if (!findLoopComponents(OuterLoop, IterationInstructions, OuterInductionPHI,
427                           OuterLimit, OuterIncrement, OuterBranch, SE))
428     return false;
429 
430   // Both of the loop limit values must be invariant in the outer loop
431   // (non-instructions are all inherently invariant).
432   if (!OuterLoop->isLoopInvariant(InnerLimit)) {
433     LLVM_DEBUG(dbgs() << "inner loop limit not invariant\n");
434     return false;
435   }
436   if (!OuterLoop->isLoopInvariant(OuterLimit)) {
437     LLVM_DEBUG(dbgs() << "outer loop limit not invariant\n");
438     return false;
439   }
440 
441   SmallPtrSet<PHINode *, 4> InnerPHIsToTransform;
442   if (!checkPHIs(OuterLoop, InnerLoop, InnerPHIsToTransform, InnerInductionPHI,
443                  OuterInductionPHI, TTI))
444     return false;
445 
446   // FIXME: it should be possible to handle different types correctly.
447   if (InnerInductionPHI->getType() != OuterInductionPHI->getType())
448     return false;
449 
450   if (!checkOuterLoopInsts(OuterLoop, InnerLoop, IterationInstructions,
451                            InnerLimit, OuterInductionPHI, TTI))
452     return false;
453 
454   // Find the values in the loop that can be replaced with the linearized
455   // induction variable, and check that there are no other uses of the inner
456   // or outer induction variable. If there were, we could still do this
457   // transformation, but we'd have to insert a div/mod to calculate the
458   // original IVs, so it wouldn't be profitable.
459   SmallPtrSet<Value *, 4> LinearIVUses;
460   if (!checkIVUsers(InnerInductionPHI, OuterInductionPHI, InnerIncrement,
461                     OuterIncrement, InnerLimit, LinearIVUses))
462     return false;
463 
464   // Check if the new iteration variable might overflow. In this case, we
465   // need to version the loop, and select the original version at runtime if
466   // the iteration space is too large.
467   // TODO: We currently don't version the loop.
468   // TODO: it might be worth using a wider iteration variable rather than
469   // versioning the loop, if a wide enough type is legal.
470   bool MustVersionLoop = true;
471   OverflowResult OR =
472       checkOverflow(OuterLoop, InnerLimit, OuterLimit, LinearIVUses, DT, AC);
473   if (OR == OverflowResult::AlwaysOverflowsHigh ||
474       OR == OverflowResult::AlwaysOverflowsLow) {
475     LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n");
476     return false;
477   } else if (OR == OverflowResult::MayOverflow) {
478     LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n");
479   } else {
480     LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
481     MustVersionLoop = false;
482   }
483 
484   // We cannot safely flatten the loop. Exit now.
485   if (MustVersionLoop)
486     return false;
487 
488   // Do the actual transformation.
489   LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n");
490 
491   {
492     using namespace ore;
493     OptimizationRemark Remark(DEBUG_TYPE, "Flattened", InnerLoop->getStartLoc(),
494                               InnerLoop->getHeader());
495     OptimizationRemarkEmitter ORE(F);
496     Remark << "Flattened into outer loop";
497     ORE.emit(Remark);
498   }
499 
500   Value *NewTripCount =
501       BinaryOperator::CreateMul(InnerLimit, OuterLimit, "flatten.tripcount",
502                                 OuterLoop->getLoopPreheader()->getTerminator());
503   LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
504              NewTripCount->dump());
505 
506   // Fix up PHI nodes that take values from the inner loop back-edge, which
507   // we are about to remove.
508   InnerInductionPHI->removeIncomingValue(InnerLoop->getLoopLatch());
509   for (PHINode *PHI : InnerPHIsToTransform)
510     PHI->removeIncomingValue(InnerLoop->getLoopLatch());
511 
512   // Modify the trip count of the outer loop to be the product of the two
513   // trip counts.
514   cast<User>(OuterBranch->getCondition())->setOperand(1, NewTripCount);
515 
516   // Replace the inner loop backedge with an unconditional branch to the exit.
517   BasicBlock *InnerExitBlock = InnerLoop->getExitBlock();
518   BasicBlock *InnerExitingBlock = InnerLoop->getExitingBlock();
519   InnerExitingBlock->getTerminator()->eraseFromParent();
520   BranchInst::Create(InnerExitBlock, InnerExitingBlock);
521   DT->deleteEdge(InnerExitingBlock, InnerLoop->getHeader());
522 
523   // Replace all uses of the polynomial calculated from the two induction
524   // variables with the one new one.
525   for (Value *V : LinearIVUses)
526     V->replaceAllUsesWith(OuterInductionPHI);
527 
528   // Tell LoopInfo, SCEV and the pass manager that the inner loop has been
529   // deleted, and any information that have about the outer loop invalidated.
530   markLoopAsDeleted(InnerLoop);
531   SE->forgetLoop(OuterLoop);
532   SE->forgetLoop(InnerLoop);
533   LI->erase(InnerLoop);
534 
535   return true;
536 }
537 
538 PreservedAnalyses LoopFlattenPass::run(Loop &L, LoopAnalysisManager &AM,
539                                        LoopStandardAnalysisResults &AR,
540                                        LPMUpdater &Updater) {
541   if (L.getSubLoops().size() != 1)
542     return PreservedAnalyses::all();
543 
544   Loop *InnerLoop = *L.begin();
545   std::string LoopName(InnerLoop->getName());
546   if (!FlattenLoopPair(
547           &L, InnerLoop, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI,
548           [&](Loop *L) { Updater.markLoopAsDeleted(*L, LoopName); }))
549     return PreservedAnalyses::all();
550   return getLoopPassPreservedAnalyses();
551 }
552 
553 namespace {
554 class LoopFlattenLegacyPass : public LoopPass {
555 public:
556   static char ID; // Pass ID, replacement for typeid
557   LoopFlattenLegacyPass() : LoopPass(ID) {
558     initializeLoopFlattenLegacyPassPass(*PassRegistry::getPassRegistry());
559   }
560 
561   // Possibly flatten loop L into its child.
562   bool runOnLoop(Loop *L, LPPassManager &) override;
563 
564   void getAnalysisUsage(AnalysisUsage &AU) const override {
565     getLoopAnalysisUsage(AU);
566     AU.addRequired<TargetTransformInfoWrapperPass>();
567     AU.addPreserved<TargetTransformInfoWrapperPass>();
568     AU.addRequired<AssumptionCacheTracker>();
569     AU.addPreserved<AssumptionCacheTracker>();
570   }
571 };
572 } // namespace
573 
574 char LoopFlattenLegacyPass::ID = 0;
575 INITIALIZE_PASS_BEGIN(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
576                       false, false)
577 INITIALIZE_PASS_DEPENDENCY(LoopPass)
578 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
579 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
580 INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
581                     false, false)
582 
583 Pass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); }
584 
585 bool LoopFlattenLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
586   if (skipLoop(L))
587     return false;
588 
589   if (L->getSubLoops().size() != 1)
590     return false;
591 
592   ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
593   LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
594   auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
595   DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr;
596   auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
597   TargetTransformInfo *TTI = &TTIP.getTTI(*L->getHeader()->getParent());
598   AssumptionCache *AC =
599       &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
600           *L->getHeader()->getParent());
601 
602   Loop *InnerLoop = *L->begin();
603   return FlattenLoopPair(L, InnerLoop, DT, LI, SE, AC, TTI,
604                          [&](Loop *L) { LPM.markLoopAsDeleted(*L); });
605 }
606