xref: /freebsd-src/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp (revision e8d8bef961a50d4dc22501cde4fb9fb0be1b2532)
1*e8d8bef9SDimitry Andric //===- LoopFlatten.cpp - Loop flattening pass------------------------------===//
2*e8d8bef9SDimitry Andric //
3*e8d8bef9SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*e8d8bef9SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*e8d8bef9SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*e8d8bef9SDimitry Andric //
7*e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
8*e8d8bef9SDimitry Andric //
// This pass flattens pairs of nested loops into a single loop.
10*e8d8bef9SDimitry Andric //
11*e8d8bef9SDimitry Andric // The intention is to optimise loop nests like this, which together access an
12*e8d8bef9SDimitry Andric // array linearly:
13*e8d8bef9SDimitry Andric //   for (int i = 0; i < N; ++i)
14*e8d8bef9SDimitry Andric //     for (int j = 0; j < M; ++j)
15*e8d8bef9SDimitry Andric //       f(A[i*M+j]);
16*e8d8bef9SDimitry Andric // into one loop:
17*e8d8bef9SDimitry Andric //   for (int i = 0; i < (N*M); ++i)
18*e8d8bef9SDimitry Andric //     f(A[i]);
19*e8d8bef9SDimitry Andric //
// It can also flatten loops where the induction variables are not used in the
// loop at all. When they are used, flattening is only worth doing if every use
// is in an expression like i*M+j. If they had any other uses, we would have to
// insert a div/mod to reconstruct the original values, so this wouldn't be
// profitable.
24*e8d8bef9SDimitry Andric //
25*e8d8bef9SDimitry Andric // We also need to prove that N*M will not overflow.
26*e8d8bef9SDimitry Andric //
27*e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
28*e8d8bef9SDimitry Andric 
29*e8d8bef9SDimitry Andric #include "llvm/Transforms/Scalar/LoopFlatten.h"
30*e8d8bef9SDimitry Andric #include "llvm/Analysis/AssumptionCache.h"
31*e8d8bef9SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
32*e8d8bef9SDimitry Andric #include "llvm/Analysis/OptimizationRemarkEmitter.h"
33*e8d8bef9SDimitry Andric #include "llvm/Analysis/ScalarEvolution.h"
34*e8d8bef9SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
35*e8d8bef9SDimitry Andric #include "llvm/Analysis/ValueTracking.h"
36*e8d8bef9SDimitry Andric #include "llvm/IR/Dominators.h"
37*e8d8bef9SDimitry Andric #include "llvm/IR/Function.h"
38*e8d8bef9SDimitry Andric #include "llvm/IR/IRBuilder.h"
39*e8d8bef9SDimitry Andric #include "llvm/IR/Module.h"
40*e8d8bef9SDimitry Andric #include "llvm/IR/PatternMatch.h"
41*e8d8bef9SDimitry Andric #include "llvm/IR/Verifier.h"
42*e8d8bef9SDimitry Andric #include "llvm/InitializePasses.h"
43*e8d8bef9SDimitry Andric #include "llvm/Pass.h"
44*e8d8bef9SDimitry Andric #include "llvm/Support/Debug.h"
45*e8d8bef9SDimitry Andric #include "llvm/Support/raw_ostream.h"
46*e8d8bef9SDimitry Andric #include "llvm/Transforms/Scalar.h"
47*e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/Local.h"
48*e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h"
49*e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
50*e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/SimplifyIndVar.h"
51*e8d8bef9SDimitry Andric 
52*e8d8bef9SDimitry Andric #define DEBUG_TYPE "loop-flatten"
53*e8d8bef9SDimitry Andric 
54*e8d8bef9SDimitry Andric using namespace llvm;
55*e8d8bef9SDimitry Andric using namespace llvm::PatternMatch;
56*e8d8bef9SDimitry Andric 
57*e8d8bef9SDimitry Andric static cl::opt<unsigned> RepeatedInstructionThreshold(
58*e8d8bef9SDimitry Andric     "loop-flatten-cost-threshold", cl::Hidden, cl::init(2),
59*e8d8bef9SDimitry Andric     cl::desc("Limit on the cost of instructions that can be repeated due to "
60*e8d8bef9SDimitry Andric              "loop flattening"));
61*e8d8bef9SDimitry Andric 
62*e8d8bef9SDimitry Andric static cl::opt<bool>
63*e8d8bef9SDimitry Andric     AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden,
64*e8d8bef9SDimitry Andric                      cl::init(false),
65*e8d8bef9SDimitry Andric                      cl::desc("Assume that the product of the two iteration "
66*e8d8bef9SDimitry Andric                               "limits will never overflow"));
67*e8d8bef9SDimitry Andric 
68*e8d8bef9SDimitry Andric static cl::opt<bool>
69*e8d8bef9SDimitry Andric     WidenIV("loop-flatten-widen-iv", cl::Hidden,
70*e8d8bef9SDimitry Andric             cl::init(true),
71*e8d8bef9SDimitry Andric             cl::desc("Widen the loop induction variables, if possible, so "
72*e8d8bef9SDimitry Andric                      "overflow checks won't reject flattening"));
73*e8d8bef9SDimitry Andric 
74*e8d8bef9SDimitry Andric struct FlattenInfo {
75*e8d8bef9SDimitry Andric   Loop *OuterLoop = nullptr;
76*e8d8bef9SDimitry Andric   Loop *InnerLoop = nullptr;
77*e8d8bef9SDimitry Andric   PHINode *InnerInductionPHI = nullptr;
78*e8d8bef9SDimitry Andric   PHINode *OuterInductionPHI = nullptr;
79*e8d8bef9SDimitry Andric   Value *InnerLimit = nullptr;
80*e8d8bef9SDimitry Andric   Value *OuterLimit = nullptr;
81*e8d8bef9SDimitry Andric   BinaryOperator *InnerIncrement = nullptr;
82*e8d8bef9SDimitry Andric   BinaryOperator *OuterIncrement = nullptr;
83*e8d8bef9SDimitry Andric   BranchInst *InnerBranch = nullptr;
84*e8d8bef9SDimitry Andric   BranchInst *OuterBranch = nullptr;
85*e8d8bef9SDimitry Andric   SmallPtrSet<Value *, 4> LinearIVUses;
86*e8d8bef9SDimitry Andric   SmallPtrSet<PHINode *, 4> InnerPHIsToTransform;
87*e8d8bef9SDimitry Andric 
88*e8d8bef9SDimitry Andric   // Whether this holds the flatten info before or after widening.
89*e8d8bef9SDimitry Andric   bool Widened = false;
90*e8d8bef9SDimitry Andric 
91*e8d8bef9SDimitry Andric   FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
92*e8d8bef9SDimitry Andric };
93*e8d8bef9SDimitry Andric 
94*e8d8bef9SDimitry Andric // Finds the induction variable, increment and limit for a simple loop that we
95*e8d8bef9SDimitry Andric // can flatten.
96*e8d8bef9SDimitry Andric static bool findLoopComponents(
97*e8d8bef9SDimitry Andric     Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions,
98*e8d8bef9SDimitry Andric     PHINode *&InductionPHI, Value *&Limit, BinaryOperator *&Increment,
99*e8d8bef9SDimitry Andric     BranchInst *&BackBranch, ScalarEvolution *SE) {
100*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n");
101*e8d8bef9SDimitry Andric 
102*e8d8bef9SDimitry Andric   if (!L->isLoopSimplifyForm()) {
103*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Loop is not in normal form\n");
104*e8d8bef9SDimitry Andric     return false;
105*e8d8bef9SDimitry Andric   }
106*e8d8bef9SDimitry Andric 
107*e8d8bef9SDimitry Andric   // There must be exactly one exiting block, and it must be the same at the
108*e8d8bef9SDimitry Andric   // latch.
109*e8d8bef9SDimitry Andric   BasicBlock *Latch = L->getLoopLatch();
110*e8d8bef9SDimitry Andric   if (L->getExitingBlock() != Latch) {
111*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Exiting and latch block are different\n");
112*e8d8bef9SDimitry Andric     return false;
113*e8d8bef9SDimitry Andric   }
114*e8d8bef9SDimitry Andric   // Latch block must end in a conditional branch.
115*e8d8bef9SDimitry Andric   BackBranch = dyn_cast<BranchInst>(Latch->getTerminator());
116*e8d8bef9SDimitry Andric   if (!BackBranch || !BackBranch->isConditional()) {
117*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Could not find back-branch\n");
118*e8d8bef9SDimitry Andric     return false;
119*e8d8bef9SDimitry Andric   }
120*e8d8bef9SDimitry Andric   IterationInstructions.insert(BackBranch);
121*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "Found back branch: "; BackBranch->dump());
122*e8d8bef9SDimitry Andric   bool ContinueOnTrue = L->contains(BackBranch->getSuccessor(0));
123*e8d8bef9SDimitry Andric 
124*e8d8bef9SDimitry Andric   // Find the induction PHI. If there is no induction PHI, we can't do the
125*e8d8bef9SDimitry Andric   // transformation. TODO: could other variables trigger this? Do we have to
126*e8d8bef9SDimitry Andric   // search for the best one?
127*e8d8bef9SDimitry Andric   InductionPHI = nullptr;
128*e8d8bef9SDimitry Andric   for (PHINode &PHI : L->getHeader()->phis()) {
129*e8d8bef9SDimitry Andric     InductionDescriptor ID;
130*e8d8bef9SDimitry Andric     if (InductionDescriptor::isInductionPHI(&PHI, L, SE, ID)) {
131*e8d8bef9SDimitry Andric       InductionPHI = &PHI;
132*e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump());
133*e8d8bef9SDimitry Andric       break;
134*e8d8bef9SDimitry Andric     }
135*e8d8bef9SDimitry Andric   }
136*e8d8bef9SDimitry Andric   if (!InductionPHI) {
137*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Could not find induction PHI\n");
138*e8d8bef9SDimitry Andric     return false;
139*e8d8bef9SDimitry Andric   }
140*e8d8bef9SDimitry Andric 
141*e8d8bef9SDimitry Andric   auto IsValidPredicate = [&](ICmpInst::Predicate Pred) {
142*e8d8bef9SDimitry Andric     if (ContinueOnTrue)
143*e8d8bef9SDimitry Andric       return Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT;
144*e8d8bef9SDimitry Andric     else
145*e8d8bef9SDimitry Andric       return Pred == CmpInst::ICMP_EQ;
146*e8d8bef9SDimitry Andric   };
147*e8d8bef9SDimitry Andric 
148*e8d8bef9SDimitry Andric   // Find Compare and make sure it is valid
149*e8d8bef9SDimitry Andric   ICmpInst *Compare = dyn_cast<ICmpInst>(BackBranch->getCondition());
150*e8d8bef9SDimitry Andric   if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) ||
151*e8d8bef9SDimitry Andric       Compare->hasNUsesOrMore(2)) {
152*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Could not find valid comparison\n");
153*e8d8bef9SDimitry Andric     return false;
154*e8d8bef9SDimitry Andric   }
155*e8d8bef9SDimitry Andric   IterationInstructions.insert(Compare);
156*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump());
157*e8d8bef9SDimitry Andric 
158*e8d8bef9SDimitry Andric   // Find increment and limit from the compare
159*e8d8bef9SDimitry Andric   Increment = nullptr;
160*e8d8bef9SDimitry Andric   if (match(Compare->getOperand(0),
161*e8d8bef9SDimitry Andric             m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
162*e8d8bef9SDimitry Andric     Increment = dyn_cast<BinaryOperator>(Compare->getOperand(0));
163*e8d8bef9SDimitry Andric     Limit = Compare->getOperand(1);
164*e8d8bef9SDimitry Andric   } else if (Compare->getUnsignedPredicate() == CmpInst::ICMP_NE &&
165*e8d8bef9SDimitry Andric              match(Compare->getOperand(1),
166*e8d8bef9SDimitry Andric                    m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
167*e8d8bef9SDimitry Andric     Increment = dyn_cast<BinaryOperator>(Compare->getOperand(1));
168*e8d8bef9SDimitry Andric     Limit = Compare->getOperand(0);
169*e8d8bef9SDimitry Andric   }
170*e8d8bef9SDimitry Andric   if (!Increment || Increment->hasNUsesOrMore(3)) {
171*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Cound not find valid increment\n");
172*e8d8bef9SDimitry Andric     return false;
173*e8d8bef9SDimitry Andric   }
174*e8d8bef9SDimitry Andric   IterationInstructions.insert(Increment);
175*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump());
176*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "Found limit: "; Limit->dump());
177*e8d8bef9SDimitry Andric 
178*e8d8bef9SDimitry Andric   assert(InductionPHI->getNumIncomingValues() == 2);
179*e8d8bef9SDimitry Andric   assert(InductionPHI->getIncomingValueForBlock(Latch) == Increment &&
180*e8d8bef9SDimitry Andric          "PHI value is not increment inst");
181*e8d8bef9SDimitry Andric 
182*e8d8bef9SDimitry Andric   auto *CI = dyn_cast<ConstantInt>(
183*e8d8bef9SDimitry Andric       InductionPHI->getIncomingValueForBlock(L->getLoopPreheader()));
184*e8d8bef9SDimitry Andric   if (!CI || !CI->isZero()) {
185*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "PHI value is not zero: "; CI->dump());
186*e8d8bef9SDimitry Andric     return false;
187*e8d8bef9SDimitry Andric   }
188*e8d8bef9SDimitry Andric 
189*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
190*e8d8bef9SDimitry Andric   return true;
191*e8d8bef9SDimitry Andric }
192*e8d8bef9SDimitry Andric 
193*e8d8bef9SDimitry Andric static bool checkPHIs(struct FlattenInfo &FI,
194*e8d8bef9SDimitry Andric                       const TargetTransformInfo *TTI) {
195*e8d8bef9SDimitry Andric   // All PHIs in the inner and outer headers must either be:
196*e8d8bef9SDimitry Andric   // - The induction PHI, which we are going to rewrite as one induction in
197*e8d8bef9SDimitry Andric   //   the new loop. This is already checked by findLoopComponents.
198*e8d8bef9SDimitry Andric   // - An outer header PHI with all incoming values from outside the loop.
199*e8d8bef9SDimitry Andric   //   LoopSimplify guarantees we have a pre-header, so we don't need to
200*e8d8bef9SDimitry Andric   //   worry about that here.
201*e8d8bef9SDimitry Andric   // - Pairs of PHIs in the inner and outer headers, which implement a
202*e8d8bef9SDimitry Andric   //   loop-carried dependency that will still be valid in the new loop. To
203*e8d8bef9SDimitry Andric   //   be valid, this variable must be modified only in the inner loop.
204*e8d8bef9SDimitry Andric 
205*e8d8bef9SDimitry Andric   // The set of PHI nodes in the outer loop header that we know will still be
206*e8d8bef9SDimitry Andric   // valid after the transformation. These will not need to be modified (with
207*e8d8bef9SDimitry Andric   // the exception of the induction variable), but we do need to check that
208*e8d8bef9SDimitry Andric   // there are no unsafe PHI nodes.
209*e8d8bef9SDimitry Andric   SmallPtrSet<PHINode *, 4> SafeOuterPHIs;
210*e8d8bef9SDimitry Andric   SafeOuterPHIs.insert(FI.OuterInductionPHI);
211*e8d8bef9SDimitry Andric 
212*e8d8bef9SDimitry Andric   // Check that all PHI nodes in the inner loop header match one of the valid
213*e8d8bef9SDimitry Andric   // patterns.
214*e8d8bef9SDimitry Andric   for (PHINode &InnerPHI : FI.InnerLoop->getHeader()->phis()) {
215*e8d8bef9SDimitry Andric     // The induction PHIs break these rules, and that's OK because we treat
216*e8d8bef9SDimitry Andric     // them specially when doing the transformation.
217*e8d8bef9SDimitry Andric     if (&InnerPHI == FI.InnerInductionPHI)
218*e8d8bef9SDimitry Andric       continue;
219*e8d8bef9SDimitry Andric 
220*e8d8bef9SDimitry Andric     // Each inner loop PHI node must have two incoming values/blocks - one
221*e8d8bef9SDimitry Andric     // from the pre-header, and one from the latch.
222*e8d8bef9SDimitry Andric     assert(InnerPHI.getNumIncomingValues() == 2);
223*e8d8bef9SDimitry Andric     Value *PreHeaderValue =
224*e8d8bef9SDimitry Andric         InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopPreheader());
225*e8d8bef9SDimitry Andric     Value *LatchValue =
226*e8d8bef9SDimitry Andric         InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopLatch());
227*e8d8bef9SDimitry Andric 
228*e8d8bef9SDimitry Andric     // The incoming value from the outer loop must be the PHI node in the
229*e8d8bef9SDimitry Andric     // outer loop header, with no modifications made in the top of the outer
230*e8d8bef9SDimitry Andric     // loop.
231*e8d8bef9SDimitry Andric     PHINode *OuterPHI = dyn_cast<PHINode>(PreHeaderValue);
232*e8d8bef9SDimitry Andric     if (!OuterPHI || OuterPHI->getParent() != FI.OuterLoop->getHeader()) {
233*e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << "value modified in top of outer loop\n");
234*e8d8bef9SDimitry Andric       return false;
235*e8d8bef9SDimitry Andric     }
236*e8d8bef9SDimitry Andric 
237*e8d8bef9SDimitry Andric     // The other incoming value must come from the inner loop, without any
238*e8d8bef9SDimitry Andric     // modifications in the tail end of the outer loop. We are in LCSSA form,
239*e8d8bef9SDimitry Andric     // so this will actually be a PHI in the inner loop's exit block, which
240*e8d8bef9SDimitry Andric     // only uses values from inside the inner loop.
241*e8d8bef9SDimitry Andric     PHINode *LCSSAPHI = dyn_cast<PHINode>(
242*e8d8bef9SDimitry Andric         OuterPHI->getIncomingValueForBlock(FI.OuterLoop->getLoopLatch()));
243*e8d8bef9SDimitry Andric     if (!LCSSAPHI) {
244*e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << "could not find LCSSA PHI\n");
245*e8d8bef9SDimitry Andric       return false;
246*e8d8bef9SDimitry Andric     }
247*e8d8bef9SDimitry Andric 
248*e8d8bef9SDimitry Andric     // The value used by the LCSSA PHI must be the same one that the inner
249*e8d8bef9SDimitry Andric     // loop's PHI uses.
250*e8d8bef9SDimitry Andric     if (LCSSAPHI->hasConstantValue() != LatchValue) {
251*e8d8bef9SDimitry Andric       LLVM_DEBUG(
252*e8d8bef9SDimitry Andric           dbgs() << "LCSSA PHI incoming value does not match latch value\n");
253*e8d8bef9SDimitry Andric       return false;
254*e8d8bef9SDimitry Andric     }
255*e8d8bef9SDimitry Andric 
256*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "PHI pair is safe:\n");
257*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "  Inner: "; InnerPHI.dump());
258*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "  Outer: "; OuterPHI->dump());
259*e8d8bef9SDimitry Andric     SafeOuterPHIs.insert(OuterPHI);
260*e8d8bef9SDimitry Andric     FI.InnerPHIsToTransform.insert(&InnerPHI);
261*e8d8bef9SDimitry Andric   }
262*e8d8bef9SDimitry Andric 
263*e8d8bef9SDimitry Andric   for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) {
264*e8d8bef9SDimitry Andric     if (!SafeOuterPHIs.count(&OuterPHI)) {
265*e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: "; OuterPHI.dump());
266*e8d8bef9SDimitry Andric       return false;
267*e8d8bef9SDimitry Andric     }
268*e8d8bef9SDimitry Andric   }
269*e8d8bef9SDimitry Andric 
270*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "checkPHIs: OK\n");
271*e8d8bef9SDimitry Andric   return true;
272*e8d8bef9SDimitry Andric }
273*e8d8bef9SDimitry Andric 
274*e8d8bef9SDimitry Andric static bool
275*e8d8bef9SDimitry Andric checkOuterLoopInsts(struct FlattenInfo &FI,
276*e8d8bef9SDimitry Andric                     SmallPtrSetImpl<Instruction *> &IterationInstructions,
277*e8d8bef9SDimitry Andric                     const TargetTransformInfo *TTI) {
278*e8d8bef9SDimitry Andric   // Check for instructions in the outer but not inner loop. If any of these
279*e8d8bef9SDimitry Andric   // have side-effects then this transformation is not legal, and if there is
280*e8d8bef9SDimitry Andric   // a significant amount of code here which can't be optimised out that it's
281*e8d8bef9SDimitry Andric   // not profitable (as these instructions would get executed for each
282*e8d8bef9SDimitry Andric   // iteration of the inner loop).
283*e8d8bef9SDimitry Andric   unsigned RepeatedInstrCost = 0;
284*e8d8bef9SDimitry Andric   for (auto *B : FI.OuterLoop->getBlocks()) {
285*e8d8bef9SDimitry Andric     if (FI.InnerLoop->contains(B))
286*e8d8bef9SDimitry Andric       continue;
287*e8d8bef9SDimitry Andric 
288*e8d8bef9SDimitry Andric     for (auto &I : *B) {
289*e8d8bef9SDimitry Andric       if (!isa<PHINode>(&I) && !I.isTerminator() &&
290*e8d8bef9SDimitry Andric           !isSafeToSpeculativelyExecute(&I)) {
291*e8d8bef9SDimitry Andric         LLVM_DEBUG(dbgs() << "Cannot flatten because instruction may have "
292*e8d8bef9SDimitry Andric                              "side effects: ";
293*e8d8bef9SDimitry Andric                    I.dump());
294*e8d8bef9SDimitry Andric         return false;
295*e8d8bef9SDimitry Andric       }
296*e8d8bef9SDimitry Andric       // The execution count of the outer loop's iteration instructions
297*e8d8bef9SDimitry Andric       // (increment, compare and branch) will be increased, but the
298*e8d8bef9SDimitry Andric       // equivalent instructions will be removed from the inner loop, so
299*e8d8bef9SDimitry Andric       // they make a net difference of zero.
300*e8d8bef9SDimitry Andric       if (IterationInstructions.count(&I))
301*e8d8bef9SDimitry Andric         continue;
302*e8d8bef9SDimitry Andric       // The uncoditional branch to the inner loop's header will turn into
303*e8d8bef9SDimitry Andric       // a fall-through, so adds no cost.
304*e8d8bef9SDimitry Andric       BranchInst *Br = dyn_cast<BranchInst>(&I);
305*e8d8bef9SDimitry Andric       if (Br && Br->isUnconditional() &&
306*e8d8bef9SDimitry Andric           Br->getSuccessor(0) == FI.InnerLoop->getHeader())
307*e8d8bef9SDimitry Andric         continue;
308*e8d8bef9SDimitry Andric       // Multiplies of the outer iteration variable and inner iteration
309*e8d8bef9SDimitry Andric       // count will be optimised out.
310*e8d8bef9SDimitry Andric       if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI),
311*e8d8bef9SDimitry Andric                             m_Specific(FI.InnerLimit))))
312*e8d8bef9SDimitry Andric         continue;
313*e8d8bef9SDimitry Andric       int Cost = TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
314*e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump());
315*e8d8bef9SDimitry Andric       RepeatedInstrCost += Cost;
316*e8d8bef9SDimitry Andric     }
317*e8d8bef9SDimitry Andric   }
318*e8d8bef9SDimitry Andric 
319*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "Cost of instructions that will be repeated: "
320*e8d8bef9SDimitry Andric                     << RepeatedInstrCost << "\n");
321*e8d8bef9SDimitry Andric   // Bail out if flattening the loops would cause instructions in the outer
322*e8d8bef9SDimitry Andric   // loop but not in the inner loop to be executed extra times.
323*e8d8bef9SDimitry Andric   if (RepeatedInstrCost > RepeatedInstructionThreshold) {
324*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: not profitable, bailing.\n");
325*e8d8bef9SDimitry Andric     return false;
326*e8d8bef9SDimitry Andric   }
327*e8d8bef9SDimitry Andric 
328*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: OK\n");
329*e8d8bef9SDimitry Andric   return true;
330*e8d8bef9SDimitry Andric }
331*e8d8bef9SDimitry Andric 
332*e8d8bef9SDimitry Andric static bool checkIVUsers(struct FlattenInfo &FI) {
333*e8d8bef9SDimitry Andric   // We require all uses of both induction variables to match this pattern:
334*e8d8bef9SDimitry Andric   //
335*e8d8bef9SDimitry Andric   //   (OuterPHI * InnerLimit) + InnerPHI
336*e8d8bef9SDimitry Andric   //
337*e8d8bef9SDimitry Andric   // Any uses of the induction variables not matching that pattern would
338*e8d8bef9SDimitry Andric   // require a div/mod to reconstruct in the flattened loop, so the
339*e8d8bef9SDimitry Andric   // transformation wouldn't be profitable.
340*e8d8bef9SDimitry Andric 
341*e8d8bef9SDimitry Andric   Value *InnerLimit = FI.InnerLimit;
342*e8d8bef9SDimitry Andric   if (FI.Widened &&
343*e8d8bef9SDimitry Andric       (isa<SExtInst>(InnerLimit) || isa<ZExtInst>(InnerLimit)))
344*e8d8bef9SDimitry Andric     InnerLimit = cast<Instruction>(InnerLimit)->getOperand(0);
345*e8d8bef9SDimitry Andric 
346*e8d8bef9SDimitry Andric   // Check that all uses of the inner loop's induction variable match the
347*e8d8bef9SDimitry Andric   // expected pattern, recording the uses of the outer IV.
348*e8d8bef9SDimitry Andric   SmallPtrSet<Value *, 4> ValidOuterPHIUses;
349*e8d8bef9SDimitry Andric   for (User *U : FI.InnerInductionPHI->users()) {
350*e8d8bef9SDimitry Andric     if (U == FI.InnerIncrement)
351*e8d8bef9SDimitry Andric       continue;
352*e8d8bef9SDimitry Andric 
353*e8d8bef9SDimitry Andric     // After widening the IVs, a trunc instruction might have been introduced, so
354*e8d8bef9SDimitry Andric     // look through truncs.
355*e8d8bef9SDimitry Andric     if (isa<TruncInst>(U)) {
356*e8d8bef9SDimitry Andric       if (!U->hasOneUse())
357*e8d8bef9SDimitry Andric         return false;
358*e8d8bef9SDimitry Andric       U = *U->user_begin();
359*e8d8bef9SDimitry Andric     }
360*e8d8bef9SDimitry Andric 
361*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
362*e8d8bef9SDimitry Andric 
363*e8d8bef9SDimitry Andric     Value *MatchedMul;
364*e8d8bef9SDimitry Andric     Value *MatchedItCount;
365*e8d8bef9SDimitry Andric     bool IsAdd = match(U, m_c_Add(m_Specific(FI.InnerInductionPHI),
366*e8d8bef9SDimitry Andric                                   m_Value(MatchedMul))) &&
367*e8d8bef9SDimitry Andric                  match(MatchedMul, m_c_Mul(m_Specific(FI.OuterInductionPHI),
368*e8d8bef9SDimitry Andric                                            m_Value(MatchedItCount)));
369*e8d8bef9SDimitry Andric 
370*e8d8bef9SDimitry Andric     // Matches the same pattern as above, except it also looks for truncs
371*e8d8bef9SDimitry Andric     // on the phi, which can be the result of widening the induction variables.
372*e8d8bef9SDimitry Andric     bool IsAddTrunc = match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)),
373*e8d8bef9SDimitry Andric                                        m_Value(MatchedMul))) &&
374*e8d8bef9SDimitry Andric                       match(MatchedMul,
375*e8d8bef9SDimitry Andric                             m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
376*e8d8bef9SDimitry Andric                             m_Value(MatchedItCount)));
377*e8d8bef9SDimitry Andric 
378*e8d8bef9SDimitry Andric     if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerLimit) {
379*e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << "Use is optimisable\n");
380*e8d8bef9SDimitry Andric       ValidOuterPHIUses.insert(MatchedMul);
381*e8d8bef9SDimitry Andric       FI.LinearIVUses.insert(U);
382*e8d8bef9SDimitry Andric     } else {
383*e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
384*e8d8bef9SDimitry Andric       return false;
385*e8d8bef9SDimitry Andric     }
386*e8d8bef9SDimitry Andric   }
387*e8d8bef9SDimitry Andric 
388*e8d8bef9SDimitry Andric   // Check that there are no uses of the outer IV other than the ones found
389*e8d8bef9SDimitry Andric   // as part of the pattern above.
390*e8d8bef9SDimitry Andric   for (User *U : FI.OuterInductionPHI->users()) {
391*e8d8bef9SDimitry Andric     if (U == FI.OuterIncrement)
392*e8d8bef9SDimitry Andric       continue;
393*e8d8bef9SDimitry Andric 
394*e8d8bef9SDimitry Andric     auto IsValidOuterPHIUses = [&] (User *U) -> bool {
395*e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
396*e8d8bef9SDimitry Andric       if (!ValidOuterPHIUses.count(U)) {
397*e8d8bef9SDimitry Andric         LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
398*e8d8bef9SDimitry Andric         return false;
399*e8d8bef9SDimitry Andric       }
400*e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << "Use is optimisable\n");
401*e8d8bef9SDimitry Andric       return true;
402*e8d8bef9SDimitry Andric     };
403*e8d8bef9SDimitry Andric 
404*e8d8bef9SDimitry Andric     if (auto *V = dyn_cast<TruncInst>(U)) {
405*e8d8bef9SDimitry Andric       for (auto *K : V->users()) {
406*e8d8bef9SDimitry Andric         if (!IsValidOuterPHIUses(K))
407*e8d8bef9SDimitry Andric           return false;
408*e8d8bef9SDimitry Andric       }
409*e8d8bef9SDimitry Andric       continue;
410*e8d8bef9SDimitry Andric     }
411*e8d8bef9SDimitry Andric 
412*e8d8bef9SDimitry Andric     if (!IsValidOuterPHIUses(U))
413*e8d8bef9SDimitry Andric       return false;
414*e8d8bef9SDimitry Andric   }
415*e8d8bef9SDimitry Andric 
416*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n";
417*e8d8bef9SDimitry Andric              dbgs() << "Found " << FI.LinearIVUses.size()
418*e8d8bef9SDimitry Andric                     << " value(s) that can be replaced:\n";
419*e8d8bef9SDimitry Andric              for (Value *V : FI.LinearIVUses) {
420*e8d8bef9SDimitry Andric                dbgs() << "  ";
421*e8d8bef9SDimitry Andric                V->dump();
422*e8d8bef9SDimitry Andric              });
423*e8d8bef9SDimitry Andric   return true;
424*e8d8bef9SDimitry Andric }
425*e8d8bef9SDimitry Andric 
426*e8d8bef9SDimitry Andric // Return an OverflowResult dependant on if overflow of the multiplication of
427*e8d8bef9SDimitry Andric // InnerLimit and OuterLimit can be assumed not to happen.
428*e8d8bef9SDimitry Andric static OverflowResult checkOverflow(struct FlattenInfo &FI,
429*e8d8bef9SDimitry Andric                                     DominatorTree *DT, AssumptionCache *AC) {
430*e8d8bef9SDimitry Andric   Function *F = FI.OuterLoop->getHeader()->getParent();
431*e8d8bef9SDimitry Andric   const DataLayout &DL = F->getParent()->getDataLayout();
432*e8d8bef9SDimitry Andric 
433*e8d8bef9SDimitry Andric   // For debugging/testing.
434*e8d8bef9SDimitry Andric   if (AssumeNoOverflow)
435*e8d8bef9SDimitry Andric     return OverflowResult::NeverOverflows;
436*e8d8bef9SDimitry Andric 
437*e8d8bef9SDimitry Andric   // Check if the multiply could not overflow due to known ranges of the
438*e8d8bef9SDimitry Andric   // input values.
439*e8d8bef9SDimitry Andric   OverflowResult OR = computeOverflowForUnsignedMul(
440*e8d8bef9SDimitry Andric       FI.InnerLimit, FI.OuterLimit, DL, AC,
441*e8d8bef9SDimitry Andric       FI.OuterLoop->getLoopPreheader()->getTerminator(), DT);
442*e8d8bef9SDimitry Andric   if (OR != OverflowResult::MayOverflow)
443*e8d8bef9SDimitry Andric     return OR;
444*e8d8bef9SDimitry Andric 
445*e8d8bef9SDimitry Andric   for (Value *V : FI.LinearIVUses) {
446*e8d8bef9SDimitry Andric     for (Value *U : V->users()) {
447*e8d8bef9SDimitry Andric       if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
448*e8d8bef9SDimitry Andric         // The IV is used as the operand of a GEP, and the IV is at least as
449*e8d8bef9SDimitry Andric         // wide as the address space of the GEP. In this case, the GEP would
450*e8d8bef9SDimitry Andric         // wrap around the address space before the IV increment wraps, which
451*e8d8bef9SDimitry Andric         // would be UB.
452*e8d8bef9SDimitry Andric         if (GEP->isInBounds() &&
453*e8d8bef9SDimitry Andric             V->getType()->getIntegerBitWidth() >=
454*e8d8bef9SDimitry Andric                 DL.getPointerTypeSizeInBits(GEP->getType())) {
455*e8d8bef9SDimitry Andric           LLVM_DEBUG(
456*e8d8bef9SDimitry Andric               dbgs() << "use of linear IV would be UB if overflow occurred: ";
457*e8d8bef9SDimitry Andric               GEP->dump());
458*e8d8bef9SDimitry Andric           return OverflowResult::NeverOverflows;
459*e8d8bef9SDimitry Andric         }
460*e8d8bef9SDimitry Andric       }
461*e8d8bef9SDimitry Andric     }
462*e8d8bef9SDimitry Andric   }
463*e8d8bef9SDimitry Andric 
464*e8d8bef9SDimitry Andric   return OverflowResult::MayOverflow;
465*e8d8bef9SDimitry Andric }
466*e8d8bef9SDimitry Andric 
467*e8d8bef9SDimitry Andric static bool CanFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
468*e8d8bef9SDimitry Andric                                LoopInfo *LI, ScalarEvolution *SE,
469*e8d8bef9SDimitry Andric                                AssumptionCache *AC, const TargetTransformInfo *TTI) {
470*e8d8bef9SDimitry Andric   SmallPtrSet<Instruction *, 8> IterationInstructions;
471*e8d8bef9SDimitry Andric   if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI,
472*e8d8bef9SDimitry Andric                           FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE))
473*e8d8bef9SDimitry Andric     return false;
474*e8d8bef9SDimitry Andric   if (!findLoopComponents(FI.OuterLoop, IterationInstructions, FI.OuterInductionPHI,
475*e8d8bef9SDimitry Andric                           FI.OuterLimit, FI.OuterIncrement, FI.OuterBranch, SE))
476*e8d8bef9SDimitry Andric     return false;
477*e8d8bef9SDimitry Andric 
478*e8d8bef9SDimitry Andric   // Both of the loop limit values must be invariant in the outer loop
479*e8d8bef9SDimitry Andric   // (non-instructions are all inherently invariant).
480*e8d8bef9SDimitry Andric   if (!FI.OuterLoop->isLoopInvariant(FI.InnerLimit)) {
481*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "inner loop limit not invariant\n");
482*e8d8bef9SDimitry Andric     return false;
483*e8d8bef9SDimitry Andric   }
484*e8d8bef9SDimitry Andric   if (!FI.OuterLoop->isLoopInvariant(FI.OuterLimit)) {
485*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "outer loop limit not invariant\n");
486*e8d8bef9SDimitry Andric     return false;
487*e8d8bef9SDimitry Andric   }
488*e8d8bef9SDimitry Andric 
489*e8d8bef9SDimitry Andric   if (!checkPHIs(FI, TTI))
490*e8d8bef9SDimitry Andric     return false;
491*e8d8bef9SDimitry Andric 
492*e8d8bef9SDimitry Andric   // FIXME: it should be possible to handle different types correctly.
493*e8d8bef9SDimitry Andric   if (FI.InnerInductionPHI->getType() != FI.OuterInductionPHI->getType())
494*e8d8bef9SDimitry Andric     return false;
495*e8d8bef9SDimitry Andric 
496*e8d8bef9SDimitry Andric   if (!checkOuterLoopInsts(FI, IterationInstructions, TTI))
497*e8d8bef9SDimitry Andric     return false;
498*e8d8bef9SDimitry Andric 
499*e8d8bef9SDimitry Andric   // Find the values in the loop that can be replaced with the linearized
500*e8d8bef9SDimitry Andric   // induction variable, and check that there are no other uses of the inner
501*e8d8bef9SDimitry Andric   // or outer induction variable. If there were, we could still do this
502*e8d8bef9SDimitry Andric   // transformation, but we'd have to insert a div/mod to calculate the
503*e8d8bef9SDimitry Andric   // original IVs, so it wouldn't be profitable.
504*e8d8bef9SDimitry Andric   if (!checkIVUsers(FI))
505*e8d8bef9SDimitry Andric     return false;
506*e8d8bef9SDimitry Andric 
507*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "CanFlattenLoopPair: OK\n");
508*e8d8bef9SDimitry Andric   return true;
509*e8d8bef9SDimitry Andric }
510*e8d8bef9SDimitry Andric 
511*e8d8bef9SDimitry Andric static bool DoFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
512*e8d8bef9SDimitry Andric                               LoopInfo *LI, ScalarEvolution *SE,
513*e8d8bef9SDimitry Andric                               AssumptionCache *AC,
514*e8d8bef9SDimitry Andric                               const TargetTransformInfo *TTI) {
515*e8d8bef9SDimitry Andric   Function *F = FI.OuterLoop->getHeader()->getParent();
516*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n");
517*e8d8bef9SDimitry Andric   {
518*e8d8bef9SDimitry Andric     using namespace ore;
519*e8d8bef9SDimitry Andric     OptimizationRemark Remark(DEBUG_TYPE, "Flattened", FI.InnerLoop->getStartLoc(),
520*e8d8bef9SDimitry Andric                               FI.InnerLoop->getHeader());
521*e8d8bef9SDimitry Andric     OptimizationRemarkEmitter ORE(F);
522*e8d8bef9SDimitry Andric     Remark << "Flattened into outer loop";
523*e8d8bef9SDimitry Andric     ORE.emit(Remark);
524*e8d8bef9SDimitry Andric   }
525*e8d8bef9SDimitry Andric 
526*e8d8bef9SDimitry Andric   Value *NewTripCount =
527*e8d8bef9SDimitry Andric       BinaryOperator::CreateMul(FI.InnerLimit, FI.OuterLimit, "flatten.tripcount",
528*e8d8bef9SDimitry Andric                                 FI.OuterLoop->getLoopPreheader()->getTerminator());
529*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
530*e8d8bef9SDimitry Andric              NewTripCount->dump());
531*e8d8bef9SDimitry Andric 
532*e8d8bef9SDimitry Andric   // Fix up PHI nodes that take values from the inner loop back-edge, which
533*e8d8bef9SDimitry Andric   // we are about to remove.
534*e8d8bef9SDimitry Andric   FI.InnerInductionPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
535*e8d8bef9SDimitry Andric 
536*e8d8bef9SDimitry Andric   // The old Phi will be optimised away later, but for now we can't leave
537*e8d8bef9SDimitry Andric   // leave it in an invalid state, so are updating them too.
538*e8d8bef9SDimitry Andric   for (PHINode *PHI : FI.InnerPHIsToTransform)
539*e8d8bef9SDimitry Andric     PHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
540*e8d8bef9SDimitry Andric 
541*e8d8bef9SDimitry Andric   // Modify the trip count of the outer loop to be the product of the two
542*e8d8bef9SDimitry Andric   // trip counts.
543*e8d8bef9SDimitry Andric   cast<User>(FI.OuterBranch->getCondition())->setOperand(1, NewTripCount);
544*e8d8bef9SDimitry Andric 
545*e8d8bef9SDimitry Andric   // Replace the inner loop backedge with an unconditional branch to the exit.
546*e8d8bef9SDimitry Andric   BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock();
547*e8d8bef9SDimitry Andric   BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock();
548*e8d8bef9SDimitry Andric   InnerExitingBlock->getTerminator()->eraseFromParent();
549*e8d8bef9SDimitry Andric   BranchInst::Create(InnerExitBlock, InnerExitingBlock);
550*e8d8bef9SDimitry Andric   DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
551*e8d8bef9SDimitry Andric 
552*e8d8bef9SDimitry Andric   // Replace all uses of the polynomial calculated from the two induction
553*e8d8bef9SDimitry Andric   // variables with the one new one.
554*e8d8bef9SDimitry Andric   IRBuilder<> Builder(FI.OuterInductionPHI->getParent()->getTerminator());
555*e8d8bef9SDimitry Andric   for (Value *V : FI.LinearIVUses) {
556*e8d8bef9SDimitry Andric     Value *OuterValue = FI.OuterInductionPHI;
557*e8d8bef9SDimitry Andric     if (FI.Widened)
558*e8d8bef9SDimitry Andric       OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(),
559*e8d8bef9SDimitry Andric                                        "flatten.trunciv");
560*e8d8bef9SDimitry Andric 
561*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Replacing: "; V->dump();
562*e8d8bef9SDimitry Andric                dbgs() << "with:      "; OuterValue->dump());
563*e8d8bef9SDimitry Andric     V->replaceAllUsesWith(OuterValue);
564*e8d8bef9SDimitry Andric   }
565*e8d8bef9SDimitry Andric 
566*e8d8bef9SDimitry Andric   // Tell LoopInfo, SCEV and the pass manager that the inner loop has been
567*e8d8bef9SDimitry Andric   // deleted, and any information that have about the outer loop invalidated.
568*e8d8bef9SDimitry Andric   SE->forgetLoop(FI.OuterLoop);
569*e8d8bef9SDimitry Andric   SE->forgetLoop(FI.InnerLoop);
570*e8d8bef9SDimitry Andric   LI->erase(FI.InnerLoop);
571*e8d8bef9SDimitry Andric   return true;
572*e8d8bef9SDimitry Andric }
573*e8d8bef9SDimitry Andric 
574*e8d8bef9SDimitry Andric static bool CanWidenIV(struct FlattenInfo &FI, DominatorTree *DT,
575*e8d8bef9SDimitry Andric                        LoopInfo *LI, ScalarEvolution *SE,
576*e8d8bef9SDimitry Andric                        AssumptionCache *AC, const TargetTransformInfo *TTI) {
577*e8d8bef9SDimitry Andric   if (!WidenIV) {
578*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Widening the IVs is disabled\n");
579*e8d8bef9SDimitry Andric     return false;
580*e8d8bef9SDimitry Andric   }
581*e8d8bef9SDimitry Andric 
582*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "Try widening the IVs\n");
583*e8d8bef9SDimitry Andric   Module *M = FI.InnerLoop->getHeader()->getParent()->getParent();
584*e8d8bef9SDimitry Andric   auto &DL = M->getDataLayout();
585*e8d8bef9SDimitry Andric   auto *InnerType = FI.InnerInductionPHI->getType();
586*e8d8bef9SDimitry Andric   auto *OuterType = FI.OuterInductionPHI->getType();
587*e8d8bef9SDimitry Andric   unsigned MaxLegalSize = DL.getLargestLegalIntTypeSizeInBits();
588*e8d8bef9SDimitry Andric   auto *MaxLegalType = DL.getLargestLegalIntType(M->getContext());
589*e8d8bef9SDimitry Andric 
590*e8d8bef9SDimitry Andric   // If both induction types are less than the maximum legal integer width,
591*e8d8bef9SDimitry Andric   // promote both to the widest type available so we know calculating
592*e8d8bef9SDimitry Andric   // (OuterLimit * InnerLimit) as the new trip count is safe.
593*e8d8bef9SDimitry Andric   if (InnerType != OuterType ||
594*e8d8bef9SDimitry Andric       InnerType->getScalarSizeInBits() >= MaxLegalSize ||
595*e8d8bef9SDimitry Andric       MaxLegalType->getScalarSizeInBits() < InnerType->getScalarSizeInBits() * 2) {
596*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Can't widen the IV\n");
597*e8d8bef9SDimitry Andric     return false;
598*e8d8bef9SDimitry Andric   }
599*e8d8bef9SDimitry Andric 
600*e8d8bef9SDimitry Andric   SCEVExpander Rewriter(*SE, DL, "loopflatten");
601*e8d8bef9SDimitry Andric   SmallVector<WideIVInfo, 2> WideIVs;
602*e8d8bef9SDimitry Andric   SmallVector<WeakTrackingVH, 4> DeadInsts;
603*e8d8bef9SDimitry Andric   WideIVs.push_back( {FI.InnerInductionPHI, MaxLegalType, false });
604*e8d8bef9SDimitry Andric   WideIVs.push_back( {FI.OuterInductionPHI, MaxLegalType, false });
605*e8d8bef9SDimitry Andric   unsigned ElimExt;
606*e8d8bef9SDimitry Andric   unsigned Widened;
607*e8d8bef9SDimitry Andric 
608*e8d8bef9SDimitry Andric   for (unsigned i = 0; i < WideIVs.size(); i++) {
609*e8d8bef9SDimitry Andric     PHINode *WidePhi = createWideIV(WideIVs[i], LI, SE, Rewriter, DT, DeadInsts,
610*e8d8bef9SDimitry Andric                                     ElimExt, Widened, true /* HasGuards */,
611*e8d8bef9SDimitry Andric                                     true /* UsePostIncrementRanges */);
612*e8d8bef9SDimitry Andric     if (!WidePhi)
613*e8d8bef9SDimitry Andric       return false;
614*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump());
615*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Deleting old phi: "; WideIVs[i].NarrowIV->dump());
616*e8d8bef9SDimitry Andric     RecursivelyDeleteDeadPHINode(WideIVs[i].NarrowIV);
617*e8d8bef9SDimitry Andric   }
618*e8d8bef9SDimitry Andric   // After widening, rediscover all the loop components.
619*e8d8bef9SDimitry Andric   assert(Widened && "Widenend IV expected");
620*e8d8bef9SDimitry Andric   FI.Widened = true;
621*e8d8bef9SDimitry Andric   return CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI);
622*e8d8bef9SDimitry Andric }
623*e8d8bef9SDimitry Andric 
624*e8d8bef9SDimitry Andric static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
625*e8d8bef9SDimitry Andric                             LoopInfo *LI, ScalarEvolution *SE,
626*e8d8bef9SDimitry Andric                             AssumptionCache *AC,
627*e8d8bef9SDimitry Andric                             const TargetTransformInfo *TTI) {
628*e8d8bef9SDimitry Andric   LLVM_DEBUG(
629*e8d8bef9SDimitry Andric       dbgs() << "Loop flattening running on outer loop "
630*e8d8bef9SDimitry Andric              << FI.OuterLoop->getHeader()->getName() << " and inner loop "
631*e8d8bef9SDimitry Andric              << FI.InnerLoop->getHeader()->getName() << " in "
632*e8d8bef9SDimitry Andric              << FI.OuterLoop->getHeader()->getParent()->getName() << "\n");
633*e8d8bef9SDimitry Andric 
634*e8d8bef9SDimitry Andric   if (!CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI))
635*e8d8bef9SDimitry Andric     return false;
636*e8d8bef9SDimitry Andric 
637*e8d8bef9SDimitry Andric   // Check if we can widen the induction variables to avoid overflow checks.
638*e8d8bef9SDimitry Andric   if (CanWidenIV(FI, DT, LI, SE, AC, TTI))
639*e8d8bef9SDimitry Andric     return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI);
640*e8d8bef9SDimitry Andric 
641*e8d8bef9SDimitry Andric   // Check if the new iteration variable might overflow. In this case, we
642*e8d8bef9SDimitry Andric   // need to version the loop, and select the original version at runtime if
643*e8d8bef9SDimitry Andric   // the iteration space is too large.
644*e8d8bef9SDimitry Andric   // TODO: We currently don't version the loop.
645*e8d8bef9SDimitry Andric   OverflowResult OR = checkOverflow(FI, DT, AC);
646*e8d8bef9SDimitry Andric   if (OR == OverflowResult::AlwaysOverflowsHigh ||
647*e8d8bef9SDimitry Andric       OR == OverflowResult::AlwaysOverflowsLow) {
648*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n");
649*e8d8bef9SDimitry Andric     return false;
650*e8d8bef9SDimitry Andric   } else if (OR == OverflowResult::MayOverflow) {
651*e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n");
652*e8d8bef9SDimitry Andric     return false;
653*e8d8bef9SDimitry Andric   }
654*e8d8bef9SDimitry Andric 
655*e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
656*e8d8bef9SDimitry Andric   return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI);
657*e8d8bef9SDimitry Andric }
658*e8d8bef9SDimitry Andric 
659*e8d8bef9SDimitry Andric bool Flatten(DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
660*e8d8bef9SDimitry Andric              AssumptionCache *AC, TargetTransformInfo *TTI) {
661*e8d8bef9SDimitry Andric   bool Changed = false;
662*e8d8bef9SDimitry Andric   for (auto *InnerLoop : LI->getLoopsInPreorder()) {
663*e8d8bef9SDimitry Andric     auto *OuterLoop = InnerLoop->getParentLoop();
664*e8d8bef9SDimitry Andric     if (!OuterLoop)
665*e8d8bef9SDimitry Andric       continue;
666*e8d8bef9SDimitry Andric     struct FlattenInfo FI(OuterLoop, InnerLoop);
667*e8d8bef9SDimitry Andric     Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI);
668*e8d8bef9SDimitry Andric   }
669*e8d8bef9SDimitry Andric   return Changed;
670*e8d8bef9SDimitry Andric }
671*e8d8bef9SDimitry Andric 
672*e8d8bef9SDimitry Andric PreservedAnalyses LoopFlattenPass::run(Function &F,
673*e8d8bef9SDimitry Andric                                        FunctionAnalysisManager &AM) {
674*e8d8bef9SDimitry Andric   auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
675*e8d8bef9SDimitry Andric   auto *LI = &AM.getResult<LoopAnalysis>(F);
676*e8d8bef9SDimitry Andric   auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
677*e8d8bef9SDimitry Andric   auto *AC = &AM.getResult<AssumptionAnalysis>(F);
678*e8d8bef9SDimitry Andric   auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
679*e8d8bef9SDimitry Andric 
680*e8d8bef9SDimitry Andric   if (!Flatten(DT, LI, SE, AC, TTI))
681*e8d8bef9SDimitry Andric     return PreservedAnalyses::all();
682*e8d8bef9SDimitry Andric 
683*e8d8bef9SDimitry Andric   PreservedAnalyses PA;
684*e8d8bef9SDimitry Andric   PA.preserveSet<CFGAnalyses>();
685*e8d8bef9SDimitry Andric   return PA;
686*e8d8bef9SDimitry Andric }
687*e8d8bef9SDimitry Andric 
688*e8d8bef9SDimitry Andric namespace {
689*e8d8bef9SDimitry Andric class LoopFlattenLegacyPass : public FunctionPass {
690*e8d8bef9SDimitry Andric public:
691*e8d8bef9SDimitry Andric   static char ID; // Pass ID, replacement for typeid
692*e8d8bef9SDimitry Andric   LoopFlattenLegacyPass() : FunctionPass(ID) {
693*e8d8bef9SDimitry Andric     initializeLoopFlattenLegacyPassPass(*PassRegistry::getPassRegistry());
694*e8d8bef9SDimitry Andric   }
695*e8d8bef9SDimitry Andric 
696*e8d8bef9SDimitry Andric   // Possibly flatten loop L into its child.
697*e8d8bef9SDimitry Andric   bool runOnFunction(Function &F) override;
698*e8d8bef9SDimitry Andric 
699*e8d8bef9SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
700*e8d8bef9SDimitry Andric     getLoopAnalysisUsage(AU);
701*e8d8bef9SDimitry Andric     AU.addRequired<TargetTransformInfoWrapperPass>();
702*e8d8bef9SDimitry Andric     AU.addPreserved<TargetTransformInfoWrapperPass>();
703*e8d8bef9SDimitry Andric     AU.addRequired<AssumptionCacheTracker>();
704*e8d8bef9SDimitry Andric     AU.addPreserved<AssumptionCacheTracker>();
705*e8d8bef9SDimitry Andric   }
706*e8d8bef9SDimitry Andric };
707*e8d8bef9SDimitry Andric } // namespace
708*e8d8bef9SDimitry Andric 
709*e8d8bef9SDimitry Andric char LoopFlattenLegacyPass::ID = 0;
710*e8d8bef9SDimitry Andric INITIALIZE_PASS_BEGIN(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
711*e8d8bef9SDimitry Andric                       false, false)
712*e8d8bef9SDimitry Andric INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
713*e8d8bef9SDimitry Andric INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
714*e8d8bef9SDimitry Andric INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
715*e8d8bef9SDimitry Andric                     false, false)
716*e8d8bef9SDimitry Andric 
717*e8d8bef9SDimitry Andric FunctionPass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); }
718*e8d8bef9SDimitry Andric 
719*e8d8bef9SDimitry Andric bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
720*e8d8bef9SDimitry Andric   ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
721*e8d8bef9SDimitry Andric   LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
722*e8d8bef9SDimitry Andric   auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
723*e8d8bef9SDimitry Andric   DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr;
724*e8d8bef9SDimitry Andric   auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
725*e8d8bef9SDimitry Andric   auto *TTI = &TTIP.getTTI(F);
726*e8d8bef9SDimitry Andric   auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
727*e8d8bef9SDimitry Andric   return Flatten(DT, LI, SE, AC, TTI);
728*e8d8bef9SDimitry Andric }
729