xref: /llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp (revision 85c17e40926132575d1b98ca1a36b8394fe511cd)
//===- LoopFlatten.cpp - Loop flattening pass------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass flattens pairs of nested loops into a single loop.
//
// The intention is to optimise loop nests like this, which together access an
// array linearly:
//
//   for (int i = 0; i < N; ++i)
//     for (int j = 0; j < M; ++j)
//       f(A[i*M+j]);
//
// into one loop:
//
//   for (int i = 0; i < (N*M); ++i)
//     f(A[i]);
//
// It can also flatten loops where the induction variables are not used in the
// loop. This is only worth doing if the induction variables are only used in an
// expression like i*M+j. If they had any other uses, we would have to insert a
// div/mod to reconstruct the original values, so this wouldn't be profitable.
//
// We also need to prove that N*M will not overflow. The preferred solution is
// to widen the IV, which avoids overflow checks, so that is tried first. If
// the IV cannot be widened, then we try to determine that this new tripcount
// expression won't overflow.
//
// Q: Does LoopFlatten use SCEV?
// Short answer: Yes and no.
//
// Long answer:
// For this transformation to be valid, we require all uses of the induction
// variables to be linear expressions of the form i*M+j. The different Loop
// APIs are used to get some loop components like the induction variable,
// compare statement, etc. In addition, we do some pattern matching to find the
// linear expressions and other loop components like the loop increment. The
// latter are examples of expressions that do use the induction variable, but
// are safe to ignore when we check all uses to be of the form i*M+j. We keep
// track of all of this in bookkeeping struct FlattenInfo.
// We assume the loops to be canonical, i.e. starting at 0 and increment with
// 1. This makes RHS of the compare the loop tripcount (with the right
// predicate). We use SCEV to then sanity check that this tripcount matches
// with the tripcount as computed by SCEV.
//
//===----------------------------------------------------------------------===//
51d53b4beeSSjoerd Meijer 
52d53b4beeSSjoerd Meijer #include "llvm/Transforms/Scalar/LoopFlatten.h"
53e2217247SRosie Sumpter 
54e2217247SRosie Sumpter #include "llvm/ADT/Statistic.h"
55d53b4beeSSjoerd Meijer #include "llvm/Analysis/AssumptionCache.h"
56d53b4beeSSjoerd Meijer #include "llvm/Analysis/LoopInfo.h"
5759630917Sserge-sans-paille #include "llvm/Analysis/LoopNestAnalysis.h"
58d544a89aSSjoerd Meijer #include "llvm/Analysis/MemorySSAUpdater.h"
59d53b4beeSSjoerd Meijer #include "llvm/Analysis/OptimizationRemarkEmitter.h"
60d53b4beeSSjoerd Meijer #include "llvm/Analysis/ScalarEvolution.h"
61d53b4beeSSjoerd Meijer #include "llvm/Analysis/TargetTransformInfo.h"
62d53b4beeSSjoerd Meijer #include "llvm/Analysis/ValueTracking.h"
63d53b4beeSSjoerd Meijer #include "llvm/IR/Dominators.h"
64d53b4beeSSjoerd Meijer #include "llvm/IR/Function.h"
6533b2c88fSSjoerd Meijer #include "llvm/IR/IRBuilder.h"
66d53b4beeSSjoerd Meijer #include "llvm/IR/Module.h"
67d53b4beeSSjoerd Meijer #include "llvm/IR/PatternMatch.h"
68d53b4beeSSjoerd Meijer #include "llvm/Support/Debug.h"
69d53b4beeSSjoerd Meijer #include "llvm/Support/raw_ostream.h"
7059630917Sserge-sans-paille #include "llvm/Transforms/Scalar/LoopPassManager.h"
719aa77338SSjoerd Meijer #include "llvm/Transforms/Utils/Local.h"
72d53b4beeSSjoerd Meijer #include "llvm/Transforms/Utils/LoopUtils.h"
73a04d4a03SJohn Brawn #include "llvm/Transforms/Utils/LoopVersioning.h"
749aa77338SSjoerd Meijer #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
759aa77338SSjoerd Meijer #include "llvm/Transforms/Utils/SimplifyIndVar.h"
76bba55813SKazu Hirata #include <optional>
77d53b4beeSSjoerd Meijer 
78d53b4beeSSjoerd Meijer using namespace llvm;
79d53b4beeSSjoerd Meijer using namespace llvm::PatternMatch;
80d53b4beeSSjoerd Meijer 
81e2217247SRosie Sumpter #define DEBUG_TYPE "loop-flatten"
82e2217247SRosie Sumpter 
83e2217247SRosie Sumpter STATISTIC(NumFlattened, "Number of loops flattened");
84e2217247SRosie Sumpter 
85d53b4beeSSjoerd Meijer static cl::opt<unsigned> RepeatedInstructionThreshold(
86d53b4beeSSjoerd Meijer     "loop-flatten-cost-threshold", cl::Hidden, cl::init(2),
87d53b4beeSSjoerd Meijer     cl::desc("Limit on the cost of instructions that can be repeated due to "
88d53b4beeSSjoerd Meijer              "loop flattening"));
89d53b4beeSSjoerd Meijer 
90d53b4beeSSjoerd Meijer static cl::opt<bool>
91d53b4beeSSjoerd Meijer     AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden,
92d53b4beeSSjoerd Meijer                      cl::init(false),
93d53b4beeSSjoerd Meijer                      cl::desc("Assume that the product of the two iteration "
94491ac280SRosie Sumpter                               "trip counts will never overflow"));
95d53b4beeSSjoerd Meijer 
969aa77338SSjoerd Meijer static cl::opt<bool>
97d544a89aSSjoerd Meijer     WidenIV("loop-flatten-widen-iv", cl::Hidden, cl::init(true),
989aa77338SSjoerd Meijer             cl::desc("Widen the loop induction variables, if possible, so "
999aa77338SSjoerd Meijer                      "overflow checks won't reject flattening"));
1009aa77338SSjoerd Meijer 
101a04d4a03SJohn Brawn static cl::opt<bool>
102a04d4a03SJohn Brawn     VersionLoops("loop-flatten-version-loops", cl::Hidden, cl::init(true),
103a04d4a03SJohn Brawn                  cl::desc("Version loops if flattened loop could overflow"));
104a04d4a03SJohn Brawn 
105b6942a28SBenjamin Kramer namespace {
106f6ac8088SSjoerd Meijer // We require all uses of both induction variables to match this pattern:
107f6ac8088SSjoerd Meijer //
108f6ac8088SSjoerd Meijer //   (OuterPHI * InnerTripCount) + InnerPHI
109f6ac8088SSjoerd Meijer //
110f6ac8088SSjoerd Meijer // I.e., it needs to be a linear expression of the induction variables and the
111f6ac8088SSjoerd Meijer // inner loop trip count. We keep track of all different expressions on which
112f6ac8088SSjoerd Meijer // checks will be performed in this bookkeeping struct.
113f6ac8088SSjoerd Meijer //
114e2dcea44SSjoerd Meijer struct FlattenInfo {
115f6ac8088SSjoerd Meijer   Loop *OuterLoop = nullptr;  // The loop pair to be flattened.
116e2dcea44SSjoerd Meijer   Loop *InnerLoop = nullptr;
117f6ac8088SSjoerd Meijer 
118f6ac8088SSjoerd Meijer   PHINode *InnerInductionPHI = nullptr; // These PHINodes correspond to loop
119f6ac8088SSjoerd Meijer   PHINode *OuterInductionPHI = nullptr; // induction variables, which are
120f6ac8088SSjoerd Meijer                                         // expected to start at zero and
121f6ac8088SSjoerd Meijer                                         // increment by one on each loop.
122f6ac8088SSjoerd Meijer 
123f6ac8088SSjoerd Meijer   Value *InnerTripCount = nullptr; // The product of these two tripcounts
124f6ac8088SSjoerd Meijer   Value *OuterTripCount = nullptr; // will be the new flattened loop
125f6ac8088SSjoerd Meijer                                    // tripcount. Also used to recognise a
126f6ac8088SSjoerd Meijer                                    // linear expression that will be replaced.
127f6ac8088SSjoerd Meijer 
128f6ac8088SSjoerd Meijer   SmallPtrSet<Value *, 4> LinearIVUses;  // Contains the linear expressions
129f6ac8088SSjoerd Meijer                                          // of the form i*M+j that will be
130f6ac8088SSjoerd Meijer                                          // replaced.
131f6ac8088SSjoerd Meijer 
132f6ac8088SSjoerd Meijer   BinaryOperator *InnerIncrement = nullptr;  // Uses of induction variables in
133f6ac8088SSjoerd Meijer   BinaryOperator *OuterIncrement = nullptr;  // loop control statements that
134f6ac8088SSjoerd Meijer   BranchInst *InnerBranch = nullptr;         // are safe to ignore.
135f6ac8088SSjoerd Meijer 
136f6ac8088SSjoerd Meijer   BranchInst *OuterBranch = nullptr; // The instruction that needs to be
137f6ac8088SSjoerd Meijer                                      // updated with new tripcount.
138f6ac8088SSjoerd Meijer 
139e2dcea44SSjoerd Meijer   SmallPtrSet<PHINode *, 4> InnerPHIsToTransform;
140e2dcea44SSjoerd Meijer 
141f6ac8088SSjoerd Meijer   bool Widened = false; // Whether this holds the flatten info before or after
142f6ac8088SSjoerd Meijer                         // widening.
14333b2c88fSSjoerd Meijer 
144f6ac8088SSjoerd Meijer   PHINode *NarrowInnerInductionPHI = nullptr; // Holds the old/narrow induction
145f6ac8088SSjoerd Meijer   PHINode *NarrowOuterInductionPHI = nullptr; // phis, i.e. the Phis before IV
1460e37ef01SKazu Hirata                                               // has been applied. Used to skip
147f6ac8088SSjoerd Meijer                                               // checks on phi nodes.
1486a076fa9SSjoerd Meijer 
149a04d4a03SJohn Brawn   Value *NewTripCount = nullptr; // The tripcount of the flattened loop.
150a04d4a03SJohn Brawn 
151e2dcea44SSjoerd Meijer   FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){};
1520ea77502SSjoerd Meijer 
1530ea77502SSjoerd Meijer   bool isNarrowInductionPhi(PHINode *Phi) {
1540ea77502SSjoerd Meijer     // This can't be the narrow phi if we haven't widened the IV first.
1550ea77502SSjoerd Meijer     if (!Widened)
1560ea77502SSjoerd Meijer       return false;
1570ea77502SSjoerd Meijer     return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi;
1580ea77502SSjoerd Meijer   }
159ada6d78aSSjoerd Meijer   bool isInnerLoopIncrement(User *U) {
160ada6d78aSSjoerd Meijer     return InnerIncrement == U;
161ada6d78aSSjoerd Meijer   }
162ada6d78aSSjoerd Meijer   bool isOuterLoopIncrement(User *U) {
163ada6d78aSSjoerd Meijer     return OuterIncrement == U;
164ada6d78aSSjoerd Meijer   }
165ada6d78aSSjoerd Meijer   bool isInnerLoopTest(User *U) {
166ada6d78aSSjoerd Meijer     return InnerBranch->getCondition() == U;
167ada6d78aSSjoerd Meijer   }
168ada6d78aSSjoerd Meijer 
169ada6d78aSSjoerd Meijer   bool checkOuterInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
170ada6d78aSSjoerd Meijer     for (User *U : OuterInductionPHI->users()) {
171ada6d78aSSjoerd Meijer       if (isOuterLoopIncrement(U))
172ada6d78aSSjoerd Meijer         continue;
173ada6d78aSSjoerd Meijer 
174ada6d78aSSjoerd Meijer       auto IsValidOuterPHIUses = [&] (User *U) -> bool {
175ada6d78aSSjoerd Meijer         LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
176ada6d78aSSjoerd Meijer         if (!ValidOuterPHIUses.count(U)) {
177ada6d78aSSjoerd Meijer           LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
178ada6d78aSSjoerd Meijer           return false;
179ada6d78aSSjoerd Meijer         }
180ada6d78aSSjoerd Meijer         LLVM_DEBUG(dbgs() << "Use is optimisable\n");
181ada6d78aSSjoerd Meijer         return true;
182ada6d78aSSjoerd Meijer       };
183ada6d78aSSjoerd Meijer 
184ada6d78aSSjoerd Meijer       if (auto *V = dyn_cast<TruncInst>(U)) {
185ada6d78aSSjoerd Meijer         for (auto *K : V->users()) {
186ada6d78aSSjoerd Meijer           if (!IsValidOuterPHIUses(K))
187ada6d78aSSjoerd Meijer             return false;
188ada6d78aSSjoerd Meijer         }
189ada6d78aSSjoerd Meijer         continue;
190ada6d78aSSjoerd Meijer       }
191ada6d78aSSjoerd Meijer 
192ada6d78aSSjoerd Meijer       if (!IsValidOuterPHIUses(U))
193ada6d78aSSjoerd Meijer         return false;
194ada6d78aSSjoerd Meijer     }
195ada6d78aSSjoerd Meijer     return true;
196ada6d78aSSjoerd Meijer   }
197ada6d78aSSjoerd Meijer 
198ada6d78aSSjoerd Meijer   bool matchLinearIVUser(User *U, Value *InnerTripCount,
199ada6d78aSSjoerd Meijer                          SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
200218e0c69SSjoerd Meijer     LLVM_DEBUG(dbgs() << "Checking linear i*M+j expression for: "; U->dump());
201ada6d78aSSjoerd Meijer     Value *MatchedMul = nullptr;
202ada6d78aSSjoerd Meijer     Value *MatchedItCount = nullptr;
203ada6d78aSSjoerd Meijer 
204ada6d78aSSjoerd Meijer     bool IsAdd = match(U, m_c_Add(m_Specific(InnerInductionPHI),
205ada6d78aSSjoerd Meijer                                   m_Value(MatchedMul))) &&
206ada6d78aSSjoerd Meijer                  match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),
207ada6d78aSSjoerd Meijer                                            m_Value(MatchedItCount)));
208ada6d78aSSjoerd Meijer 
209ada6d78aSSjoerd Meijer     // Matches the same pattern as above, except it also looks for truncs
210ada6d78aSSjoerd Meijer     // on the phi, which can be the result of widening the induction variables.
211ada6d78aSSjoerd Meijer     bool IsAddTrunc =
212ada6d78aSSjoerd Meijer         match(U, m_c_Add(m_Trunc(m_Specific(InnerInductionPHI)),
213ada6d78aSSjoerd Meijer                          m_Value(MatchedMul))) &&
214ada6d78aSSjoerd Meijer         match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)),
215ada6d78aSSjoerd Meijer                                   m_Value(MatchedItCount)));
216ada6d78aSSjoerd Meijer 
217ae978baaSJohn Brawn     // Matches the pattern ptr+i*M+j, with the two additions being done via GEP.
218ae978baaSJohn Brawn     bool IsGEP = match(U, m_GEP(m_GEP(m_Value(), m_Value(MatchedMul)),
219ae978baaSJohn Brawn                                 m_Specific(InnerInductionPHI))) &&
220ae978baaSJohn Brawn                  match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),
221ae978baaSJohn Brawn                                            m_Value(MatchedItCount)));
222ae978baaSJohn Brawn 
223ada6d78aSSjoerd Meijer     if (!MatchedItCount)
224ada6d78aSSjoerd Meijer       return false;
225ada6d78aSSjoerd Meijer 
226218e0c69SSjoerd Meijer     LLVM_DEBUG(dbgs() << "Matched multiplication: "; MatchedMul->dump());
227218e0c69SSjoerd Meijer     LLVM_DEBUG(dbgs() << "Matched iteration count: "; MatchedItCount->dump());
228218e0c69SSjoerd Meijer 
229161bfa5fSDavid Green     // The mul should not have any other uses. Widening may leave trivially dead
230161bfa5fSDavid Green     // uses, which can be ignored.
231161bfa5fSDavid Green     if (count_if(MatchedMul->users(), [](User *U) {
232161bfa5fSDavid Green           return !isInstructionTriviallyDead(cast<Instruction>(U));
233161bfa5fSDavid Green         }) > 1) {
234161bfa5fSDavid Green       LLVM_DEBUG(dbgs() << "Multiply has more than one use\n");
235161bfa5fSDavid Green       return false;
236161bfa5fSDavid Green     }
237161bfa5fSDavid Green 
238d73684e2SCraig Topper     // Look through extends if the IV has been widened. Don't look through
239d73684e2SCraig Topper     // extends if we already looked through a trunc.
240ae978baaSJohn Brawn     if (Widened && (IsAdd || IsGEP) &&
241ada6d78aSSjoerd Meijer         (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
242ada6d78aSSjoerd Meijer       assert(MatchedItCount->getType() == InnerInductionPHI->getType() &&
243ada6d78aSSjoerd Meijer              "Unexpected type mismatch in types after widening");
244ada6d78aSSjoerd Meijer       MatchedItCount = isa<SExtInst>(MatchedItCount)
245ada6d78aSSjoerd Meijer                            ? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0)
246ada6d78aSSjoerd Meijer                            : dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0);
247ada6d78aSSjoerd Meijer     }
248ada6d78aSSjoerd Meijer 
249218e0c69SSjoerd Meijer     LLVM_DEBUG(dbgs() << "Looking for inner trip count: ";
250218e0c69SSjoerd Meijer                InnerTripCount->dump());
251218e0c69SSjoerd Meijer 
252ae978baaSJohn Brawn     if ((IsAdd || IsAddTrunc || IsGEP) && MatchedItCount == InnerTripCount) {
253218e0c69SSjoerd Meijer       LLVM_DEBUG(dbgs() << "Found. This sse is optimisable\n");
254ada6d78aSSjoerd Meijer       ValidOuterPHIUses.insert(MatchedMul);
255ada6d78aSSjoerd Meijer       LinearIVUses.insert(U);
256ada6d78aSSjoerd Meijer       return true;
257ada6d78aSSjoerd Meijer     }
258ada6d78aSSjoerd Meijer 
259ada6d78aSSjoerd Meijer     LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
260ada6d78aSSjoerd Meijer     return false;
261ada6d78aSSjoerd Meijer   }
262ada6d78aSSjoerd Meijer 
263ada6d78aSSjoerd Meijer   bool checkInnerInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
264ada6d78aSSjoerd Meijer     Value *SExtInnerTripCount = InnerTripCount;
265ada6d78aSSjoerd Meijer     if (Widened &&
266ada6d78aSSjoerd Meijer         (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
267ada6d78aSSjoerd Meijer       SExtInnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
268ada6d78aSSjoerd Meijer 
269ada6d78aSSjoerd Meijer     for (User *U : InnerInductionPHI->users()) {
270218e0c69SSjoerd Meijer       LLVM_DEBUG(dbgs() << "Checking User: "; U->dump());
271218e0c69SSjoerd Meijer       if (isInnerLoopIncrement(U)) {
272218e0c69SSjoerd Meijer         LLVM_DEBUG(dbgs() << "Use is inner loop increment, continuing\n");
273ada6d78aSSjoerd Meijer         continue;
274218e0c69SSjoerd Meijer       }
275ada6d78aSSjoerd Meijer 
276ada6d78aSSjoerd Meijer       // After widening the IVs, a trunc instruction might have been introduced,
277ada6d78aSSjoerd Meijer       // so look through truncs.
278ada6d78aSSjoerd Meijer       if (isa<TruncInst>(U)) {
279ada6d78aSSjoerd Meijer         if (!U->hasOneUse())
280ada6d78aSSjoerd Meijer           return false;
281ada6d78aSSjoerd Meijer         U = *U->user_begin();
282ada6d78aSSjoerd Meijer       }
283ada6d78aSSjoerd Meijer 
284ada6d78aSSjoerd Meijer       // If the use is in the compare (which is also the condition of the inner
285ada6d78aSSjoerd Meijer       // branch) then the compare has been altered by another transformation e.g
286ada6d78aSSjoerd Meijer       // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is
287ada6d78aSSjoerd Meijer       // a constant. Ignore this use as the compare gets removed later anyway.
288218e0c69SSjoerd Meijer       if (isInnerLoopTest(U)) {
289218e0c69SSjoerd Meijer         LLVM_DEBUG(dbgs() << "Use is the inner loop test, continuing\n");
290ada6d78aSSjoerd Meijer         continue;
291218e0c69SSjoerd Meijer       }
292ada6d78aSSjoerd Meijer 
293218e0c69SSjoerd Meijer       if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses)) {
294218e0c69SSjoerd Meijer         LLVM_DEBUG(dbgs() << "Not a linear IV user\n");
295ada6d78aSSjoerd Meijer         return false;
296ada6d78aSSjoerd Meijer       }
297218e0c69SSjoerd Meijer       LLVM_DEBUG(dbgs() << "Linear IV users found!\n");
298218e0c69SSjoerd Meijer     }
299ada6d78aSSjoerd Meijer     return true;
300ada6d78aSSjoerd Meijer   }
301e2dcea44SSjoerd Meijer };
302b6942a28SBenjamin Kramer } // namespace
303e2dcea44SSjoerd Meijer 
30446abd1fbSRosie Sumpter static bool
30546abd1fbSRosie Sumpter setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment,
30646abd1fbSRosie Sumpter                   SmallPtrSetImpl<Instruction *> &IterationInstructions) {
30746abd1fbSRosie Sumpter   TripCount = TC;
30846abd1fbSRosie Sumpter   IterationInstructions.insert(Increment);
30946abd1fbSRosie Sumpter   LLVM_DEBUG(dbgs() << "Found Increment: "; Increment->dump());
31046abd1fbSRosie Sumpter   LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
31146abd1fbSRosie Sumpter   LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
31246abd1fbSRosie Sumpter   return true;
31346abd1fbSRosie Sumpter }
31446abd1fbSRosie Sumpter 
315ada6d78aSSjoerd Meijer // Given the RHS of the loop latch compare instruction, verify with SCEV
316ada6d78aSSjoerd Meijer // that this is indeed the loop tripcount.
317ada6d78aSSjoerd Meijer // TODO: This used to be a straightforward check but has grown to be quite
318ada6d78aSSjoerd Meijer // complicated now. It is therefore worth revisiting what the additional
319ada6d78aSSjoerd Meijer // benefits are of this (compared to relying on canonical loops and pattern
320ada6d78aSSjoerd Meijer // matching).
321ada6d78aSSjoerd Meijer static bool verifyTripCount(Value *RHS, Loop *L,
322ada6d78aSSjoerd Meijer      SmallPtrSetImpl<Instruction *> &IterationInstructions,
323ada6d78aSSjoerd Meijer     PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,
324ada6d78aSSjoerd Meijer     BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {
325ada6d78aSSjoerd Meijer   const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
326ada6d78aSSjoerd Meijer   if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
327ada6d78aSSjoerd Meijer     LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
328ada6d78aSSjoerd Meijer     return false;
329ada6d78aSSjoerd Meijer   }
330ada6d78aSSjoerd Meijer 
33109d879d0SPhilip Reames   // Evaluating in the trip count's type can not overflow here as the overflow
33209d879d0SPhilip Reames   // checks are performed in checkOverflow, but are first tried to avoid by
33309d879d0SPhilip Reames   // widening the IV.
334ada6d78aSSjoerd Meijer   const SCEV *SCEVTripCount =
33509d879d0SPhilip Reames     SE->getTripCountFromExitCount(BackedgeTakenCount,
33609d879d0SPhilip Reames                                   BackedgeTakenCount->getType(), L);
337ada6d78aSSjoerd Meijer 
338ada6d78aSSjoerd Meijer   const SCEV *SCEVRHS = SE->getSCEV(RHS);
339ada6d78aSSjoerd Meijer   if (SCEVRHS == SCEVTripCount)
340ada6d78aSSjoerd Meijer     return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
341ada6d78aSSjoerd Meijer   ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
342ada6d78aSSjoerd Meijer   if (ConstantRHS) {
343ada6d78aSSjoerd Meijer     const SCEV *BackedgeTCExt = nullptr;
344ada6d78aSSjoerd Meijer     if (IsWidened) {
345ada6d78aSSjoerd Meijer       const SCEV *SCEVTripCountExt;
346ada6d78aSSjoerd Meijer       // Find the extended backedge taken count and extended trip count using
347ada6d78aSSjoerd Meijer       // SCEV. One of these should now match the RHS of the compare.
348ada6d78aSSjoerd Meijer       BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
34909d879d0SPhilip Reames       SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt,
35009d879d0SPhilip Reames                                                        RHS->getType(), L);
351ada6d78aSSjoerd Meijer       if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
352ada6d78aSSjoerd Meijer         LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
353ada6d78aSSjoerd Meijer         return false;
354ada6d78aSSjoerd Meijer       }
355ada6d78aSSjoerd Meijer     }
356ada6d78aSSjoerd Meijer     // If the RHS of the compare is equal to the backedge taken count we need
357ada6d78aSSjoerd Meijer     // to add one to get the trip count.
358ada6d78aSSjoerd Meijer     if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
359dea16ebdSPaul Walker       Value *NewRHS = ConstantInt::get(ConstantRHS->getContext(),
360dea16ebdSPaul Walker                                        ConstantRHS->getValue() + 1);
361ada6d78aSSjoerd Meijer       return setLoopComponents(NewRHS, TripCount, Increment,
362ada6d78aSSjoerd Meijer                                IterationInstructions);
363ada6d78aSSjoerd Meijer     }
364ada6d78aSSjoerd Meijer     return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
365ada6d78aSSjoerd Meijer   }
366ada6d78aSSjoerd Meijer   // If the RHS isn't a constant then check that the reason it doesn't match
367ada6d78aSSjoerd Meijer   // the SCEV trip count is because the RHS is a ZExt or SExt instruction
368ada6d78aSSjoerd Meijer   // (and take the trip count to be the RHS).
369ada6d78aSSjoerd Meijer   if (!IsWidened) {
370ada6d78aSSjoerd Meijer     LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
371ada6d78aSSjoerd Meijer     return false;
372ada6d78aSSjoerd Meijer   }
373ada6d78aSSjoerd Meijer   auto *TripCountInst = dyn_cast<Instruction>(RHS);
374ada6d78aSSjoerd Meijer   if (!TripCountInst) {
375ada6d78aSSjoerd Meijer     LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
376ada6d78aSSjoerd Meijer     return false;
377ada6d78aSSjoerd Meijer   }
378ada6d78aSSjoerd Meijer   if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
379ada6d78aSSjoerd Meijer       SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
380ada6d78aSSjoerd Meijer     LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
381ada6d78aSSjoerd Meijer     return false;
382ada6d78aSSjoerd Meijer   }
383ada6d78aSSjoerd Meijer   return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
384ada6d78aSSjoerd Meijer }
385ada6d78aSSjoerd Meijer 
386491ac280SRosie Sumpter // Finds the induction variable, increment and trip count for a simple loop that
387491ac280SRosie Sumpter // we can flatten.
388d53b4beeSSjoerd Meijer static bool findLoopComponents(
389d53b4beeSSjoerd Meijer     Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions,
390491ac280SRosie Sumpter     PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,
391491ac280SRosie Sumpter     BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {
392d53b4beeSSjoerd Meijer   LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n");
393d53b4beeSSjoerd Meijer 
394d53b4beeSSjoerd Meijer   if (!L->isLoopSimplifyForm()) {
395d53b4beeSSjoerd Meijer     LLVM_DEBUG(dbgs() << "Loop is not in normal form\n");
396d53b4beeSSjoerd Meijer     return false;
397d53b4beeSSjoerd Meijer   }
398d53b4beeSSjoerd Meijer 
399491ac280SRosie Sumpter   // Currently, to simplify the implementation, the Loop induction variable must
400491ac280SRosie Sumpter   // start at zero and increment with a step size of one.
401491ac280SRosie Sumpter   if (!L->isCanonical(*SE)) {
402491ac280SRosie Sumpter     LLVM_DEBUG(dbgs() << "Loop is not canonical\n");
403491ac280SRosie Sumpter     return false;
404491ac280SRosie Sumpter   }
405491ac280SRosie Sumpter 
406d53b4beeSSjoerd Meijer   // There must be exactly one exiting block, and it must be the same at the
407d53b4beeSSjoerd Meijer   // latch.
408d53b4beeSSjoerd Meijer   BasicBlock *Latch = L->getLoopLatch();
409d53b4beeSSjoerd Meijer   if (L->getExitingBlock() != Latch) {
410d53b4beeSSjoerd Meijer     LLVM_DEBUG(dbgs() << "Exiting and latch block are different\n");
411d53b4beeSSjoerd Meijer     return false;
412d53b4beeSSjoerd Meijer   }
413d53b4beeSSjoerd Meijer 
414d53b4beeSSjoerd Meijer   // Find the induction PHI. If there is no induction PHI, we can't do the
415d53b4beeSSjoerd Meijer   // transformation. TODO: could other variables trigger this? Do we have to
416d53b4beeSSjoerd Meijer   // search for the best one?
41734d68205SRosie Sumpter   InductionPHI = L->getInductionVariable(*SE);
418d53b4beeSSjoerd Meijer   if (!InductionPHI) {
419d53b4beeSSjoerd Meijer     LLVM_DEBUG(dbgs() << "Could not find induction PHI\n");
420d53b4beeSSjoerd Meijer     return false;
421d53b4beeSSjoerd Meijer   }
42234d68205SRosie Sumpter   LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump());
423d53b4beeSSjoerd Meijer 
42444c9adb4SRosie Sumpter   bool ContinueOnTrue = L->contains(Latch->getTerminator()->getSuccessor(0));
425d53b4beeSSjoerd Meijer   auto IsValidPredicate = [&](ICmpInst::Predicate Pred) {
426d53b4beeSSjoerd Meijer     if (ContinueOnTrue)
427d53b4beeSSjoerd Meijer       return Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT;
428d53b4beeSSjoerd Meijer     else
429d53b4beeSSjoerd Meijer       return Pred == CmpInst::ICMP_EQ;
430d53b4beeSSjoerd Meijer   };
431d53b4beeSSjoerd Meijer 
43244c9adb4SRosie Sumpter   // Find Compare and make sure it is valid. getLatchCmpInst checks that the
43344c9adb4SRosie Sumpter   // back branch of the latch is conditional.
43444c9adb4SRosie Sumpter   ICmpInst *Compare = L->getLatchCmpInst();
435d53b4beeSSjoerd Meijer   if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) ||
436d53b4beeSSjoerd Meijer       Compare->hasNUsesOrMore(2)) {
437d53b4beeSSjoerd Meijer     LLVM_DEBUG(dbgs() << "Could not find valid comparison\n");
438d53b4beeSSjoerd Meijer     return false;
439d53b4beeSSjoerd Meijer   }
44044c9adb4SRosie Sumpter   BackBranch = cast<BranchInst>(Latch->getTerminator());
44144c9adb4SRosie Sumpter   IterationInstructions.insert(BackBranch);
44244c9adb4SRosie Sumpter   LLVM_DEBUG(dbgs() << "Found back branch: "; BackBranch->dump());
443d53b4beeSSjoerd Meijer   IterationInstructions.insert(Compare);
444d53b4beeSSjoerd Meijer   LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump());
445d53b4beeSSjoerd Meijer 
446491ac280SRosie Sumpter   // Find increment and trip count.
447491ac280SRosie Sumpter   // There are exactly 2 incoming values to the induction phi; one from the
448491ac280SRosie Sumpter   // pre-header and one from the latch. The incoming latch value is the
449491ac280SRosie Sumpter   // increment variable.
450491ac280SRosie Sumpter   Increment =
451fdd58435SCraig Topper       cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch));
4528e9e22f0SDavid Green   if ((Compare->getOperand(0) != Increment || !Increment->hasNUses(2)) &&
4538e9e22f0SDavid Green       !Increment->hasNUses(1)) {
454491ac280SRosie Sumpter     LLVM_DEBUG(dbgs() << "Could not find valid increment\n");
455d53b4beeSSjoerd Meijer     return false;
456d53b4beeSSjoerd Meijer   }
457491ac280SRosie Sumpter   // The trip count is the RHS of the compare. If this doesn't match the trip
45846abd1fbSRosie Sumpter   // count computed by SCEV then this is because the trip count variable
45946abd1fbSRosie Sumpter   // has been widened so the types don't match, or because it is a constant and
46046abd1fbSRosie Sumpter   // another transformation has changed the compare (e.g. icmp ult %inc,
46146abd1fbSRosie Sumpter   // tripcount -> icmp ult %j, tripcount-1), or both.
46246abd1fbSRosie Sumpter   Value *RHS = Compare->getOperand(1);
463ada6d78aSSjoerd Meijer 
464ada6d78aSSjoerd Meijer   return verifyTripCount(RHS, L, IterationInstructions, InductionPHI, TripCount,
465ada6d78aSSjoerd Meijer                          Increment, BackBranch, SE, IsWidened);
466d53b4beeSSjoerd Meijer }
467d53b4beeSSjoerd Meijer 
4688fde25b3STa-Wei Tu static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {
469d53b4beeSSjoerd Meijer   // All PHIs in the inner and outer headers must either be:
470d53b4beeSSjoerd Meijer   // - The induction PHI, which we are going to rewrite as one induction in
471d53b4beeSSjoerd Meijer   //   the new loop. This is already checked by findLoopComponents.
472d53b4beeSSjoerd Meijer   // - An outer header PHI with all incoming values from outside the loop.
473d53b4beeSSjoerd Meijer   //   LoopSimplify guarantees we have a pre-header, so we don't need to
474d53b4beeSSjoerd Meijer   //   worry about that here.
475d53b4beeSSjoerd Meijer   // - Pairs of PHIs in the inner and outer headers, which implement a
476d53b4beeSSjoerd Meijer   //   loop-carried dependency that will still be valid in the new loop. To
477d53b4beeSSjoerd Meijer   //   be valid, this variable must be modified only in the inner loop.
478d53b4beeSSjoerd Meijer 
479d53b4beeSSjoerd Meijer   // The set of PHI nodes in the outer loop header that we know will still be
480d53b4beeSSjoerd Meijer   // valid after the transformation. These will not need to be modified (with
481d53b4beeSSjoerd Meijer   // the exception of the induction variable), but we do need to check that
482d53b4beeSSjoerd Meijer   // there are no unsafe PHI nodes.
483d53b4beeSSjoerd Meijer   SmallPtrSet<PHINode *, 4> SafeOuterPHIs;
484e2dcea44SSjoerd Meijer   SafeOuterPHIs.insert(FI.OuterInductionPHI);
485d53b4beeSSjoerd Meijer 
486d53b4beeSSjoerd Meijer   // Check that all PHI nodes in the inner loop header match one of the valid
487d53b4beeSSjoerd Meijer   // patterns.
488e2dcea44SSjoerd Meijer   for (PHINode &InnerPHI : FI.InnerLoop->getHeader()->phis()) {
489d53b4beeSSjoerd Meijer     // The induction PHIs break these rules, and that's OK because we treat
490d53b4beeSSjoerd Meijer     // them specially when doing the transformation.
491e2dcea44SSjoerd Meijer     if (&InnerPHI == FI.InnerInductionPHI)
492d53b4beeSSjoerd Meijer       continue;
4930ea77502SSjoerd Meijer     if (FI.isNarrowInductionPhi(&InnerPHI))
4946a076fa9SSjoerd Meijer       continue;
495d53b4beeSSjoerd Meijer 
496d53b4beeSSjoerd Meijer     // Each inner loop PHI node must have two incoming values/blocks - one
497d53b4beeSSjoerd Meijer     // from the pre-header, and one from the latch.
498d53b4beeSSjoerd Meijer     assert(InnerPHI.getNumIncomingValues() == 2);
499d53b4beeSSjoerd Meijer     Value *PreHeaderValue =
500e2dcea44SSjoerd Meijer         InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopPreheader());
501d53b4beeSSjoerd Meijer     Value *LatchValue =
502e2dcea44SSjoerd Meijer         InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopLatch());
503d53b4beeSSjoerd Meijer 
504d53b4beeSSjoerd Meijer     // The incoming value from the outer loop must be the PHI node in the
505d53b4beeSSjoerd Meijer     // outer loop header, with no modifications made in the top of the outer
506d53b4beeSSjoerd Meijer     // loop.
507d53b4beeSSjoerd Meijer     PHINode *OuterPHI = dyn_cast<PHINode>(PreHeaderValue);
508e2dcea44SSjoerd Meijer     if (!OuterPHI || OuterPHI->getParent() != FI.OuterLoop->getHeader()) {
509d53b4beeSSjoerd Meijer       LLVM_DEBUG(dbgs() << "value modified in top of outer loop\n");
510d53b4beeSSjoerd Meijer       return false;
511d53b4beeSSjoerd Meijer     }
512d53b4beeSSjoerd Meijer 
513d53b4beeSSjoerd Meijer     // The other incoming value must come from the inner loop, without any
514d53b4beeSSjoerd Meijer     // modifications in the tail end of the outer loop. We are in LCSSA form,
515d53b4beeSSjoerd Meijer     // so this will actually be a PHI in the inner loop's exit block, which
516d53b4beeSSjoerd Meijer     // only uses values from inside the inner loop.
517d53b4beeSSjoerd Meijer     PHINode *LCSSAPHI = dyn_cast<PHINode>(
518e2dcea44SSjoerd Meijer         OuterPHI->getIncomingValueForBlock(FI.OuterLoop->getLoopLatch()));
519d53b4beeSSjoerd Meijer     if (!LCSSAPHI) {
520d53b4beeSSjoerd Meijer       LLVM_DEBUG(dbgs() << "could not find LCSSA PHI\n");
521d53b4beeSSjoerd Meijer       return false;
522d53b4beeSSjoerd Meijer     }
523d53b4beeSSjoerd Meijer 
524d53b4beeSSjoerd Meijer     // The value used by the LCSSA PHI must be the same one that the inner
525d53b4beeSSjoerd Meijer     // loop's PHI uses.
526d53b4beeSSjoerd Meijer     if (LCSSAPHI->hasConstantValue() != LatchValue) {
527d53b4beeSSjoerd Meijer       LLVM_DEBUG(
528d53b4beeSSjoerd Meijer           dbgs() << "LCSSA PHI incoming value does not match latch value\n");
529d53b4beeSSjoerd Meijer       return false;
530d53b4beeSSjoerd Meijer     }
531d53b4beeSSjoerd Meijer 
532d53b4beeSSjoerd Meijer     LLVM_DEBUG(dbgs() << "PHI pair is safe:\n");
533d53b4beeSSjoerd Meijer     LLVM_DEBUG(dbgs() << "  Inner: "; InnerPHI.dump());
534d53b4beeSSjoerd Meijer     LLVM_DEBUG(dbgs() << "  Outer: "; OuterPHI->dump());
535d53b4beeSSjoerd Meijer     SafeOuterPHIs.insert(OuterPHI);
536e2dcea44SSjoerd Meijer     FI.InnerPHIsToTransform.insert(&InnerPHI);
537d53b4beeSSjoerd Meijer   }
538d53b4beeSSjoerd Meijer 
539e2dcea44SSjoerd Meijer   for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) {
5400ea77502SSjoerd Meijer     if (FI.isNarrowInductionPhi(&OuterPHI))
5416a076fa9SSjoerd Meijer       continue;
542d53b4beeSSjoerd Meijer     if (!SafeOuterPHIs.count(&OuterPHI)) {
543d53b4beeSSjoerd Meijer       LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: "; OuterPHI.dump());
544d53b4beeSSjoerd Meijer       return false;
545d53b4beeSSjoerd Meijer     }
546d53b4beeSSjoerd Meijer   }
547d53b4beeSSjoerd Meijer 
5489aa77338SSjoerd Meijer   LLVM_DEBUG(dbgs() << "checkPHIs: OK\n");
549d53b4beeSSjoerd Meijer   return true;
550d53b4beeSSjoerd Meijer }
551d53b4beeSSjoerd Meijer 
552d53b4beeSSjoerd Meijer static bool
5538fde25b3STa-Wei Tu checkOuterLoopInsts(FlattenInfo &FI,
554d53b4beeSSjoerd Meijer                     SmallPtrSetImpl<Instruction *> &IterationInstructions,
555e2dcea44SSjoerd Meijer                     const TargetTransformInfo *TTI) {
556d53b4beeSSjoerd Meijer   // Check for instructions in the outer but not inner loop. If any of these
557d53b4beeSSjoerd Meijer   // have side-effects then this transformation is not legal, and if there is
558d53b4beeSSjoerd Meijer   // a significant amount of code here which can't be optimised out that it's
559d53b4beeSSjoerd Meijer   // not profitable (as these instructions would get executed for each
560d53b4beeSSjoerd Meijer   // iteration of the inner loop).
561ae27274bSSander de Smalen   InstructionCost RepeatedInstrCost = 0;
562e2dcea44SSjoerd Meijer   for (auto *B : FI.OuterLoop->getBlocks()) {
563e2dcea44SSjoerd Meijer     if (FI.InnerLoop->contains(B))
564d53b4beeSSjoerd Meijer       continue;
565d53b4beeSSjoerd Meijer 
566d53b4beeSSjoerd Meijer     for (auto &I : *B) {
567d53b4beeSSjoerd Meijer       if (!isa<PHINode>(&I) && !I.isTerminator() &&
568d53b4beeSSjoerd Meijer           !isSafeToSpeculativelyExecute(&I)) {
569d53b4beeSSjoerd Meijer         LLVM_DEBUG(dbgs() << "Cannot flatten because instruction may have "
570d53b4beeSSjoerd Meijer                              "side effects: ";
571d53b4beeSSjoerd Meijer                    I.dump());
572d53b4beeSSjoerd Meijer         return false;
573d53b4beeSSjoerd Meijer       }
574d53b4beeSSjoerd Meijer       // The execution count of the outer loop's iteration instructions
575d53b4beeSSjoerd Meijer       // (increment, compare and branch) will be increased, but the
576d53b4beeSSjoerd Meijer       // equivalent instructions will be removed from the inner loop, so
577d53b4beeSSjoerd Meijer       // they make a net difference of zero.
578d53b4beeSSjoerd Meijer       if (IterationInstructions.count(&I))
579d53b4beeSSjoerd Meijer         continue;
580a2d45017SKazu Hirata       // The unconditional branch to the inner loop's header will turn into
581d53b4beeSSjoerd Meijer       // a fall-through, so adds no cost.
582d53b4beeSSjoerd Meijer       BranchInst *Br = dyn_cast<BranchInst>(&I);
583d53b4beeSSjoerd Meijer       if (Br && Br->isUnconditional() &&
584e2dcea44SSjoerd Meijer           Br->getSuccessor(0) == FI.InnerLoop->getHeader())
585d53b4beeSSjoerd Meijer         continue;
586d53b4beeSSjoerd Meijer       // Multiplies of the outer iteration variable and inner iteration
587d53b4beeSSjoerd Meijer       // count will be optimised out.
588e2dcea44SSjoerd Meijer       if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI),
589491ac280SRosie Sumpter                             m_Specific(FI.InnerTripCount))))
590d53b4beeSSjoerd Meijer         continue;
591ae27274bSSander de Smalen       InstructionCost Cost =
592fdec5018SSimon Pilgrim           TTI->getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
593d53b4beeSSjoerd Meijer       LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump());
594d53b4beeSSjoerd Meijer       RepeatedInstrCost += Cost;
595d53b4beeSSjoerd Meijer     }
596d53b4beeSSjoerd Meijer   }
597d53b4beeSSjoerd Meijer 
598d53b4beeSSjoerd Meijer   LLVM_DEBUG(dbgs() << "Cost of instructions that will be repeated: "
599d53b4beeSSjoerd Meijer                     << RepeatedInstrCost << "\n");
600d53b4beeSSjoerd Meijer   // Bail out if flattening the loops would cause instructions in the outer
601d53b4beeSSjoerd Meijer   // loop but not in the inner loop to be executed extra times.
6029aa77338SSjoerd Meijer   if (RepeatedInstrCost > RepeatedInstructionThreshold) {
6039aa77338SSjoerd Meijer     LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: not profitable, bailing.\n");
604d53b4beeSSjoerd Meijer     return false;
6059aa77338SSjoerd Meijer   }
606d53b4beeSSjoerd Meijer 
6079aa77338SSjoerd Meijer   LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: OK\n");
608d53b4beeSSjoerd Meijer   return true;
609d53b4beeSSjoerd Meijer }
610d53b4beeSSjoerd Meijer 
611ada6d78aSSjoerd Meijer 
612ada6d78aSSjoerd Meijer 
613d53b4beeSSjoerd Meijer // We require all uses of both induction variables to match this pattern:
614d53b4beeSSjoerd Meijer //
615491ac280SRosie Sumpter //   (OuterPHI * InnerTripCount) + InnerPHI
616d53b4beeSSjoerd Meijer //
617d53b4beeSSjoerd Meijer // Any uses of the induction variables not matching that pattern would
618d53b4beeSSjoerd Meijer // require a div/mod to reconstruct in the flattened loop, so the
619d53b4beeSSjoerd Meijer // transformation wouldn't be profitable.
620ada6d78aSSjoerd Meijer static bool checkIVUsers(FlattenInfo &FI) {
621d53b4beeSSjoerd Meijer   // Check that all uses of the inner loop's induction variable match the
622d53b4beeSSjoerd Meijer   // expected pattern, recording the uses of the outer IV.
623d53b4beeSSjoerd Meijer   SmallPtrSet<Value *, 4> ValidOuterPHIUses;
624ada6d78aSSjoerd Meijer   if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses))
6259aa77338SSjoerd Meijer     return false;
626d53b4beeSSjoerd Meijer 
627d53b4beeSSjoerd Meijer   // Check that there are no uses of the outer IV other than the ones found
628d53b4beeSSjoerd Meijer   // as part of the pattern above.
629ada6d78aSSjoerd Meijer   if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses))
630d53b4beeSSjoerd Meijer     return false;
6319aa77338SSjoerd Meijer 
6329aa77338SSjoerd Meijer   LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n";
6339aa77338SSjoerd Meijer              dbgs() << "Found " << FI.LinearIVUses.size()
634d53b4beeSSjoerd Meijer                     << " value(s) that can be replaced:\n";
635e2dcea44SSjoerd Meijer              for (Value *V : FI.LinearIVUses) {
636d53b4beeSSjoerd Meijer                dbgs() << "  ";
637d53b4beeSSjoerd Meijer                V->dump();
638d53b4beeSSjoerd Meijer              });
639d53b4beeSSjoerd Meijer   return true;
640d53b4beeSSjoerd Meijer }
641d53b4beeSSjoerd Meijer 
642d53b4beeSSjoerd Meijer // Return an OverflowResult dependant on if overflow of the multiplication of
643491ac280SRosie Sumpter // InnerTripCount and OuterTripCount can be assumed not to happen.
6448fde25b3STa-Wei Tu static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
6458fde25b3STa-Wei Tu                                     AssumptionCache *AC) {
646e2dcea44SSjoerd Meijer   Function *F = FI.OuterLoop->getHeader()->getParent();
6479df71d76SNikita Popov   const DataLayout &DL = F->getDataLayout();
648d53b4beeSSjoerd Meijer 
649d53b4beeSSjoerd Meijer   // For debugging/testing.
650d53b4beeSSjoerd Meijer   if (AssumeNoOverflow)
651d53b4beeSSjoerd Meijer     return OverflowResult::NeverOverflows;
652d53b4beeSSjoerd Meijer 
653d53b4beeSSjoerd Meijer   // Check if the multiply could not overflow due to known ranges of the
654d53b4beeSSjoerd Meijer   // input values.
655d53b4beeSSjoerd Meijer   OverflowResult OR = computeOverflowForUnsignedMul(
6561b3cc4e7SNikita Popov       FI.InnerTripCount, FI.OuterTripCount,
6571b3cc4e7SNikita Popov       SimplifyQuery(DL, DT, AC,
6581b3cc4e7SNikita Popov                     FI.OuterLoop->getLoopPreheader()->getTerminator()));
659d53b4beeSSjoerd Meijer   if (OR != OverflowResult::MayOverflow)
660d53b4beeSSjoerd Meijer     return OR;
661d53b4beeSSjoerd Meijer 
662ae978baaSJohn Brawn   auto CheckGEP = [&](GetElementPtrInst *GEP, Value *GEPOperand) {
663ae978baaSJohn Brawn     for (Value *GEPUser : GEP->users()) {
66466383038SSimon Pilgrim       auto *GEPUserInst = cast<Instruction>(GEPUser);
665d1aa0751SRosie Sumpter       if (!isa<LoadInst>(GEPUserInst) &&
666ae978baaSJohn Brawn           !(isa<StoreInst>(GEPUserInst) && GEP == GEPUserInst->getOperand(1)))
667d1aa0751SRosie Sumpter         continue;
668ae978baaSJohn Brawn       if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst, FI.InnerLoop))
669d1aa0751SRosie Sumpter         continue;
670d1aa0751SRosie Sumpter       // The IV is used as the operand of a GEP which dominates the loop
671d1aa0751SRosie Sumpter       // latch, and the IV is at least as wide as the address space of the
672d1aa0751SRosie Sumpter       // GEP. In this case, the GEP would wrap around the address space
673d1aa0751SRosie Sumpter       // before the IV increment wraps, which would be UB.
674d53b4beeSSjoerd Meijer       if (GEP->isInBounds() &&
675ae978baaSJohn Brawn           GEPOperand->getType()->getIntegerBitWidth() >=
676d53b4beeSSjoerd Meijer               DL.getPointerTypeSizeInBits(GEP->getType())) {
677d53b4beeSSjoerd Meijer         LLVM_DEBUG(
678d53b4beeSSjoerd Meijer             dbgs() << "use of linear IV would be UB if overflow occurred: ";
679d53b4beeSSjoerd Meijer             GEP->dump());
680ae978baaSJohn Brawn         return true;
681ae978baaSJohn Brawn       }
682ae978baaSJohn Brawn     }
683ae978baaSJohn Brawn     return false;
684ae978baaSJohn Brawn   };
685ae978baaSJohn Brawn 
686ae978baaSJohn Brawn   // Check if any IV user is, or is used by, a GEP that would cause UB if the
687ae978baaSJohn Brawn   // multiply overflows.
688ae978baaSJohn Brawn   for (Value *V : FI.LinearIVUses) {
689ae978baaSJohn Brawn     if (auto *GEP = dyn_cast<GetElementPtrInst>(V))
690ae978baaSJohn Brawn       if (GEP->getNumIndices() == 1 && CheckGEP(GEP, GEP->getOperand(1)))
691d53b4beeSSjoerd Meijer         return OverflowResult::NeverOverflows;
692ae978baaSJohn Brawn     for (Value *U : V->users())
693ae978baaSJohn Brawn       if (auto *GEP = dyn_cast<GetElementPtrInst>(U))
694ae978baaSJohn Brawn         if (CheckGEP(GEP, V))
695ae978baaSJohn Brawn           return OverflowResult::NeverOverflows;
696d1aa0751SRosie Sumpter   }
697d53b4beeSSjoerd Meijer 
698d53b4beeSSjoerd Meijer   return OverflowResult::MayOverflow;
699d53b4beeSSjoerd Meijer }
700d53b4beeSSjoerd Meijer 
7018fde25b3STa-Wei Tu static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
7028fde25b3STa-Wei Tu                                ScalarEvolution *SE, AssumptionCache *AC,
7038fde25b3STa-Wei Tu                                const TargetTransformInfo *TTI) {
704d53b4beeSSjoerd Meijer   SmallPtrSet<Instruction *, 8> IterationInstructions;
705491ac280SRosie Sumpter   if (!findLoopComponents(FI.InnerLoop, IterationInstructions,
706491ac280SRosie Sumpter                           FI.InnerInductionPHI, FI.InnerTripCount,
707491ac280SRosie Sumpter                           FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened))
708d53b4beeSSjoerd Meijer     return false;
709491ac280SRosie Sumpter   if (!findLoopComponents(FI.OuterLoop, IterationInstructions,
710491ac280SRosie Sumpter                           FI.OuterInductionPHI, FI.OuterTripCount,
711491ac280SRosie Sumpter                           FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened))
712d53b4beeSSjoerd Meijer     return false;
713d53b4beeSSjoerd Meijer 
714491ac280SRosie Sumpter   // Both of the loop trip count values must be invariant in the outer loop
715d53b4beeSSjoerd Meijer   // (non-instructions are all inherently invariant).
716491ac280SRosie Sumpter   if (!FI.OuterLoop->isLoopInvariant(FI.InnerTripCount)) {
717491ac280SRosie Sumpter     LLVM_DEBUG(dbgs() << "inner loop trip count not invariant\n");
718d53b4beeSSjoerd Meijer     return false;
719d53b4beeSSjoerd Meijer   }
720491ac280SRosie Sumpter   if (!FI.OuterLoop->isLoopInvariant(FI.OuterTripCount)) {
721491ac280SRosie Sumpter     LLVM_DEBUG(dbgs() << "outer loop trip count not invariant\n");
722d53b4beeSSjoerd Meijer     return false;
723d53b4beeSSjoerd Meijer   }
724d53b4beeSSjoerd Meijer 
725e2dcea44SSjoerd Meijer   if (!checkPHIs(FI, TTI))
726d53b4beeSSjoerd Meijer     return false;
727d53b4beeSSjoerd Meijer 
728d53b4beeSSjoerd Meijer   // FIXME: it should be possible to handle different types correctly.
729e2dcea44SSjoerd Meijer   if (FI.InnerInductionPHI->getType() != FI.OuterInductionPHI->getType())
730d53b4beeSSjoerd Meijer     return false;
731d53b4beeSSjoerd Meijer 
732e2dcea44SSjoerd Meijer   if (!checkOuterLoopInsts(FI, IterationInstructions, TTI))
733d53b4beeSSjoerd Meijer     return false;
734d53b4beeSSjoerd Meijer 
735d53b4beeSSjoerd Meijer   // Find the values in the loop that can be replaced with the linearized
736d53b4beeSSjoerd Meijer   // induction variable, and check that there are no other uses of the inner
737d53b4beeSSjoerd Meijer   // or outer induction variable. If there were, we could still do this
738d53b4beeSSjoerd Meijer   // transformation, but we'd have to insert a div/mod to calculate the
739d53b4beeSSjoerd Meijer   // original IVs, so it wouldn't be profitable.
740e2dcea44SSjoerd Meijer   if (!checkIVUsers(FI))
741d53b4beeSSjoerd Meijer     return false;
742d53b4beeSSjoerd Meijer 
7439aa77338SSjoerd Meijer   LLVM_DEBUG(dbgs() << "CanFlattenLoopPair: OK\n");
7449aa77338SSjoerd Meijer   return true;
745d53b4beeSSjoerd Meijer }
746d53b4beeSSjoerd Meijer 
7478fde25b3STa-Wei Tu static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
7488fde25b3STa-Wei Tu                               ScalarEvolution *SE, AssumptionCache *AC,
749d544a89aSSjoerd Meijer                               const TargetTransformInfo *TTI, LPMUpdater *U,
750d544a89aSSjoerd Meijer                               MemorySSAUpdater *MSSAU) {
7519aa77338SSjoerd Meijer   Function *F = FI.OuterLoop->getHeader()->getParent();
752d53b4beeSSjoerd Meijer   LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n");
753d53b4beeSSjoerd Meijer   {
754d53b4beeSSjoerd Meijer     using namespace ore;
755e2dcea44SSjoerd Meijer     OptimizationRemark Remark(DEBUG_TYPE, "Flattened", FI.InnerLoop->getStartLoc(),
756e2dcea44SSjoerd Meijer                               FI.InnerLoop->getHeader());
757d53b4beeSSjoerd Meijer     OptimizationRemarkEmitter ORE(F);
758d53b4beeSSjoerd Meijer     Remark << "Flattened into outer loop";
759d53b4beeSSjoerd Meijer     ORE.emit(Remark);
760d53b4beeSSjoerd Meijer   }
761d53b4beeSSjoerd Meijer 
762a04d4a03SJohn Brawn   if (!FI.NewTripCount) {
763a04d4a03SJohn Brawn     FI.NewTripCount = BinaryOperator::CreateMul(
764491ac280SRosie Sumpter         FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount",
7652fe81edeSJeremy Morse         FI.OuterLoop->getLoopPreheader()->getTerminator()->getIterator());
766d53b4beeSSjoerd Meijer     LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
767a04d4a03SJohn Brawn                FI.NewTripCount->dump());
768a04d4a03SJohn Brawn   }
769d53b4beeSSjoerd Meijer 
770d53b4beeSSjoerd Meijer   // Fix up PHI nodes that take values from the inner loop back-edge, which
771d53b4beeSSjoerd Meijer   // we are about to remove.
772e2dcea44SSjoerd Meijer   FI.InnerInductionPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
7739aa77338SSjoerd Meijer 
7749aa77338SSjoerd Meijer   // The old Phi will be optimised away later, but for now we can't leave
7759aa77338SSjoerd Meijer   // leave it in an invalid state, so are updating them too.
776e2dcea44SSjoerd Meijer   for (PHINode *PHI : FI.InnerPHIsToTransform)
777e2dcea44SSjoerd Meijer     PHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
778d53b4beeSSjoerd Meijer 
779d53b4beeSSjoerd Meijer   // Modify the trip count of the outer loop to be the product of the two
780d53b4beeSSjoerd Meijer   // trip counts.
781a04d4a03SJohn Brawn   cast<User>(FI.OuterBranch->getCondition())->setOperand(1, FI.NewTripCount);
782d53b4beeSSjoerd Meijer 
783d53b4beeSSjoerd Meijer   // Replace the inner loop backedge with an unconditional branch to the exit.
784e2dcea44SSjoerd Meijer   BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock();
785e2dcea44SSjoerd Meijer   BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock();
786e0f4d27aSShan Huang   Instruction *Term = InnerExitingBlock->getTerminator();
787e0f4d27aSShan Huang   Instruction *BI = BranchInst::Create(InnerExitBlock, InnerExitingBlock);
788e0f4d27aSShan Huang   BI->setDebugLoc(Term->getDebugLoc());
789e0f4d27aSShan Huang   Term->eraseFromParent();
790d544a89aSSjoerd Meijer 
791d544a89aSSjoerd Meijer   // Update the DomTree and MemorySSA.
792e2dcea44SSjoerd Meijer   DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
793d544a89aSSjoerd Meijer   if (MSSAU)
794d544a89aSSjoerd Meijer     MSSAU->removeEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
795d53b4beeSSjoerd Meijer 
796d53b4beeSSjoerd Meijer   // Replace all uses of the polynomial calculated from the two induction
797d53b4beeSSjoerd Meijer   // variables with the one new one.
79833b2c88fSSjoerd Meijer   IRBuilder<> Builder(FI.OuterInductionPHI->getParent()->getTerminator());
7999aa77338SSjoerd Meijer   for (Value *V : FI.LinearIVUses) {
80033b2c88fSSjoerd Meijer     Value *OuterValue = FI.OuterInductionPHI;
80133b2c88fSSjoerd Meijer     if (FI.Widened)
80233b2c88fSSjoerd Meijer       OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(),
80333b2c88fSSjoerd Meijer                                        "flatten.trunciv");
80433b2c88fSSjoerd Meijer 
805ae978baaSJohn Brawn     if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
806ae978baaSJohn Brawn       // Replace the GEP with one that uses OuterValue as the offset.
807ae978baaSJohn Brawn       auto *InnerGEP = cast<GetElementPtrInst>(GEP->getOperand(0));
808ae978baaSJohn Brawn       Value *Base = InnerGEP->getOperand(0);
809ae978baaSJohn Brawn       // When the base of the GEP doesn't dominate the outer induction phi then
810ae978baaSJohn Brawn       // we need to insert the new GEP where the old GEP was.
811ae978baaSJohn Brawn       if (!DT->dominates(Base, &*Builder.GetInsertPoint()))
812ae978baaSJohn Brawn         Builder.SetInsertPoint(cast<Instruction>(V));
8134d1ecf19SAtariDreams       OuterValue =
8144d1ecf19SAtariDreams           Builder.CreateGEP(GEP->getSourceElementType(), Base, OuterValue,
8154d1ecf19SAtariDreams                             "flatten." + V->getName(),
8164d1ecf19SAtariDreams                             GEP->isInBounds() && InnerGEP->isInBounds());
817ae978baaSJohn Brawn     }
818ae978baaSJohn Brawn 
819d544a89aSSjoerd Meijer     LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with:      ";
820d544a89aSSjoerd Meijer                OuterValue->dump());
82133b2c88fSSjoerd Meijer     V->replaceAllUsesWith(OuterValue);
8229aa77338SSjoerd Meijer   }
823d53b4beeSSjoerd Meijer 
824d53b4beeSSjoerd Meijer   // Tell LoopInfo, SCEV and the pass manager that the inner loop has been
825b2b4d958SJoshua Cao   // deleted, and invalidate any outer loop information.
826e2dcea44SSjoerd Meijer   SE->forgetLoop(FI.OuterLoop);
82798eb9179Sluxufan   SE->forgetBlockAndLoopDispositions();
828e3129fb7SNikita Popov   if (U)
829e3129fb7SNikita Popov     U->markLoopAsDeleted(*FI.InnerLoop, FI.InnerLoop->getName());
830e2dcea44SSjoerd Meijer   LI->erase(FI.InnerLoop);
831e2217247SRosie Sumpter 
832e2217247SRosie Sumpter   // Increment statistic value.
833e2217247SRosie Sumpter   NumFlattened++;
834e2217247SRosie Sumpter 
835d53b4beeSSjoerd Meijer   return true;
836d53b4beeSSjoerd Meijer }
837d53b4beeSSjoerd Meijer 
8388fde25b3STa-Wei Tu static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
8398fde25b3STa-Wei Tu                        ScalarEvolution *SE, AssumptionCache *AC,
8408fde25b3STa-Wei Tu                        const TargetTransformInfo *TTI) {
8419aa77338SSjoerd Meijer   if (!WidenIV) {
8429aa77338SSjoerd Meijer     LLVM_DEBUG(dbgs() << "Widening the IVs is disabled\n");
8439aa77338SSjoerd Meijer     return false;
8449aa77338SSjoerd Meijer   }
8459aa77338SSjoerd Meijer 
8469aa77338SSjoerd Meijer   LLVM_DEBUG(dbgs() << "Try widening the IVs\n");
8479aa77338SSjoerd Meijer   Module *M = FI.InnerLoop->getHeader()->getParent()->getParent();
8489aa77338SSjoerd Meijer   auto &DL = M->getDataLayout();
8499aa77338SSjoerd Meijer   auto *InnerType = FI.InnerInductionPHI->getType();
8509aa77338SSjoerd Meijer   auto *OuterType = FI.OuterInductionPHI->getType();
8519aa77338SSjoerd Meijer   unsigned MaxLegalSize = DL.getLargestLegalIntTypeSizeInBits();
8529aa77338SSjoerd Meijer   auto *MaxLegalType = DL.getLargestLegalIntType(M->getContext());
8539aa77338SSjoerd Meijer 
8549aa77338SSjoerd Meijer   // If both induction types are less than the maximum legal integer width,
8559aa77338SSjoerd Meijer   // promote both to the widest type available so we know calculating
856491ac280SRosie Sumpter   // (OuterTripCount * InnerTripCount) as the new trip count is safe.
8579aa77338SSjoerd Meijer   if (InnerType != OuterType ||
8589aa77338SSjoerd Meijer       InnerType->getScalarSizeInBits() >= MaxLegalSize ||
859d544a89aSSjoerd Meijer       MaxLegalType->getScalarSizeInBits() <
860d544a89aSSjoerd Meijer           InnerType->getScalarSizeInBits() * 2) {
8619aa77338SSjoerd Meijer     LLVM_DEBUG(dbgs() << "Can't widen the IV\n");
8629aa77338SSjoerd Meijer     return false;
8639aa77338SSjoerd Meijer   }
8649aa77338SSjoerd Meijer 
8659aa77338SSjoerd Meijer   SCEVExpander Rewriter(*SE, DL, "loopflatten");
8669aa77338SSjoerd Meijer   SmallVector<WeakTrackingVH, 4> DeadInsts;
867b8aba76aSSimon Pilgrim   unsigned ElimExt = 0;
868b8aba76aSSimon Pilgrim   unsigned Widened = 0;
8699aa77338SSjoerd Meijer 
8706a076fa9SSjoerd Meijer   auto CreateWideIV = [&](WideIVInfo WideIV, bool &Deleted) -> bool {
871d544a89aSSjoerd Meijer     PHINode *WidePhi =
872d544a89aSSjoerd Meijer         createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts, ElimExt, Widened,
873d544a89aSSjoerd Meijer                      true /* HasGuards */, true /* UsePostIncrementRanges */);
8749aa77338SSjoerd Meijer     if (!WidePhi)
8759aa77338SSjoerd Meijer       return false;
8769aa77338SSjoerd Meijer     LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump());
877b8aba76aSSimon Pilgrim     LLVM_DEBUG(dbgs() << "Deleting old phi: "; WideIV.NarrowIV->dump());
8786a076fa9SSjoerd Meijer     Deleted = RecursivelyDeleteDeadPHINode(WideIV.NarrowIV);
8796a076fa9SSjoerd Meijer     return true;
8806a076fa9SSjoerd Meijer   };
8816a076fa9SSjoerd Meijer 
8826a076fa9SSjoerd Meijer   bool Deleted;
8836a076fa9SSjoerd Meijer   if (!CreateWideIV({FI.InnerInductionPHI, MaxLegalType, false}, Deleted))
8846a076fa9SSjoerd Meijer     return false;
8850ea77502SSjoerd Meijer   // Add the narrow phi to list, so that it will be adjusted later when the
8860ea77502SSjoerd Meijer   // the transformation is performed.
8876a076fa9SSjoerd Meijer   if (!Deleted)
8880ea77502SSjoerd Meijer     FI.InnerPHIsToTransform.insert(FI.InnerInductionPHI);
8890ea77502SSjoerd Meijer 
8906a076fa9SSjoerd Meijer   if (!CreateWideIV({FI.OuterInductionPHI, MaxLegalType, false}, Deleted))
8916a076fa9SSjoerd Meijer     return false;
8926a076fa9SSjoerd Meijer 
893b8aba76aSSimon Pilgrim   assert(Widened && "Widened IV expected");
89433b2c88fSSjoerd Meijer   FI.Widened = true;
8956a076fa9SSjoerd Meijer 
8966a076fa9SSjoerd Meijer   // Save the old/narrow induction phis, which we need to ignore in CheckPHIs.
8970ea77502SSjoerd Meijer   FI.NarrowInnerInductionPHI = FI.InnerInductionPHI;
8980ea77502SSjoerd Meijer   FI.NarrowOuterInductionPHI = FI.OuterInductionPHI;
8996a076fa9SSjoerd Meijer 
9006a076fa9SSjoerd Meijer   // After widening, rediscover all the loop components.
9019aa77338SSjoerd Meijer   return CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI);
9029aa77338SSjoerd Meijer }
9039aa77338SSjoerd Meijer 
9048fde25b3STa-Wei Tu static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
9058fde25b3STa-Wei Tu                             ScalarEvolution *SE, AssumptionCache *AC,
906d544a89aSSjoerd Meijer                             const TargetTransformInfo *TTI, LPMUpdater *U,
907a04d4a03SJohn Brawn                             MemorySSAUpdater *MSSAU,
908a04d4a03SJohn Brawn                             const LoopAccessInfo &LAI) {
9092e7455f0SBenjamin Kramer   LLVM_DEBUG(
9102e7455f0SBenjamin Kramer       dbgs() << "Loop flattening running on outer loop "
9119aa77338SSjoerd Meijer              << FI.OuterLoop->getHeader()->getName() << " and inner loop "
9129aa77338SSjoerd Meijer              << FI.InnerLoop->getHeader()->getName() << " in "
9132e7455f0SBenjamin Kramer              << FI.OuterLoop->getHeader()->getParent()->getName() << "\n");
9149aa77338SSjoerd Meijer 
9159aa77338SSjoerd Meijer   if (!CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI))
9169aa77338SSjoerd Meijer     return false;
9179aa77338SSjoerd Meijer 
9189aa77338SSjoerd Meijer   // Check if we can widen the induction variables to avoid overflow checks.
919367df180SSjoerd Meijer   bool CanFlatten = CanWidenIV(FI, DT, LI, SE, AC, TTI);
920367df180SSjoerd Meijer 
  // It can happen that after widening of the IV, flattening may not be
  // possible/happening, e.g. when it is deemed unprofitable. So bail here if
  // that is the case.
  // TODO: IV widening without performing the actual flattening transformation
  // is not ideal. While this codegen change should not matter much, it is an
  // unnecessary change which is better to avoid. It's unlikely this happens
  // often, because if it's unprofitable after widening, it should be
  // unprofitable before widening as checked in the first round of checks. But
  // 'RepeatedInstructionThreshold' is set to only 2, which can probably be
  // relaxed. Because this is making a code change (the IV widening, but not
  // the flattening), we return true here.
932367df180SSjoerd Meijer   if (FI.Widened && !CanFlatten)
933367df180SSjoerd Meijer     return true;
934367df180SSjoerd Meijer 
935367df180SSjoerd Meijer   // If we have widened and can perform the transformation, do that here.
936367df180SSjoerd Meijer   if (CanFlatten)
937d544a89aSSjoerd Meijer     return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
9389aa77338SSjoerd Meijer 
939367df180SSjoerd Meijer   // Otherwise, if we haven't widened the IV, check if the new iteration
940367df180SSjoerd Meijer   // variable might overflow. In this case, we need to version the loop, and
941367df180SSjoerd Meijer   // select the original version at runtime if the iteration space is too
942367df180SSjoerd Meijer   // large.
9439aa77338SSjoerd Meijer   OverflowResult OR = checkOverflow(FI, DT, AC);
9449aa77338SSjoerd Meijer   if (OR == OverflowResult::AlwaysOverflowsHigh ||
9459aa77338SSjoerd Meijer       OR == OverflowResult::AlwaysOverflowsLow) {
9469aa77338SSjoerd Meijer     LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n");
9479aa77338SSjoerd Meijer     return false;
9489aa77338SSjoerd Meijer   } else if (OR == OverflowResult::MayOverflow) {
949a04d4a03SJohn Brawn     Module *M = FI.OuterLoop->getHeader()->getParent()->getParent();
950a04d4a03SJohn Brawn     const DataLayout &DL = M->getDataLayout();
951a04d4a03SJohn Brawn     if (!VersionLoops) {
9529aa77338SSjoerd Meijer       LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n");
9539aa77338SSjoerd Meijer       return false;
954a04d4a03SJohn Brawn     } else if (!DL.isLegalInteger(
955a04d4a03SJohn Brawn                    FI.OuterTripCount->getType()->getScalarSizeInBits())) {
956a04d4a03SJohn Brawn       // If the trip count type isn't legal then it won't be possible to check
957a04d4a03SJohn Brawn       // for overflow using only a single multiply instruction, so don't
958a04d4a03SJohn Brawn       // flatten.
959a04d4a03SJohn Brawn       LLVM_DEBUG(
960a04d4a03SJohn Brawn           dbgs() << "Can't check overflow efficiently, not flattening\n");
961a04d4a03SJohn Brawn       return false;
962a04d4a03SJohn Brawn     }
963a04d4a03SJohn Brawn     LLVM_DEBUG(dbgs() << "Multiply might overflow, versioning loop\n");
964a04d4a03SJohn Brawn 
965a04d4a03SJohn Brawn     // Version the loop. The overflow check isn't a runtime pointer check, so we
966a04d4a03SJohn Brawn     // pass an empty list of runtime pointer checks, causing LoopVersioning to
967a04d4a03SJohn Brawn     // emit 'false' as the branch condition, and add our own check afterwards.
968a04d4a03SJohn Brawn     BasicBlock *CheckBlock = FI.OuterLoop->getLoopPreheader();
969a04d4a03SJohn Brawn     ArrayRef<RuntimePointerCheck> Checks(nullptr, nullptr);
970a04d4a03SJohn Brawn     LoopVersioning LVer(LAI, Checks, FI.OuterLoop, LI, DT, SE);
971a04d4a03SJohn Brawn     LVer.versionLoop();
972a04d4a03SJohn Brawn 
973a04d4a03SJohn Brawn     // Check for overflow by calculating the new tripcount using
974a04d4a03SJohn Brawn     // umul_with_overflow and then checking if it overflowed.
975a04d4a03SJohn Brawn     BranchInst *Br = cast<BranchInst>(CheckBlock->getTerminator());
976a04d4a03SJohn Brawn     assert(Br->isConditional() &&
977a04d4a03SJohn Brawn            "Expected LoopVersioning to generate a conditional branch");
978a04d4a03SJohn Brawn     assert(match(Br->getCondition(), m_Zero()) &&
979a04d4a03SJohn Brawn            "Expected branch condition to be false");
980a04d4a03SJohn Brawn     IRBuilder<> Builder(Br);
981*85c17e40SJay Foad     Value *Call = Builder.CreateIntrinsic(
982*85c17e40SJay Foad         Intrinsic::umul_with_overflow, FI.OuterTripCount->getType(),
983*85c17e40SJay Foad         {FI.OuterTripCount, FI.InnerTripCount},
984*85c17e40SJay Foad         /*FMFSource=*/nullptr, "flatten.mul");
985a04d4a03SJohn Brawn     FI.NewTripCount = Builder.CreateExtractValue(Call, 0, "flatten.tripcount");
986a04d4a03SJohn Brawn     Value *Overflow = Builder.CreateExtractValue(Call, 1, "flatten.overflow");
987a04d4a03SJohn Brawn     Br->setCondition(Overflow);
988a04d4a03SJohn Brawn   } else {
989a04d4a03SJohn Brawn     LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
9909aa77338SSjoerd Meijer   }
9919aa77338SSjoerd Meijer 
992d544a89aSSjoerd Meijer   return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
9939aa77338SSjoerd Meijer }
9949aa77338SSjoerd Meijer 
995fa488ea8SeopXD PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
996fa488ea8SeopXD                                        LoopStandardAnalysisResults &AR,
997fa488ea8SeopXD                                        LPMUpdater &U) {
998706ead0eSSjoerd Meijer 
9991124ad2fSStelios Ioannou   bool Changed = false;
10001124ad2fSStelios Ioannou 
1001bba55813SKazu Hirata   std::optional<MemorySSAUpdater> MSSAU;
1002d544a89aSSjoerd Meijer   if (AR.MSSA) {
1003d544a89aSSjoerd Meijer     MSSAU = MemorySSAUpdater(AR.MSSA);
1004d544a89aSSjoerd Meijer     if (VerifyMemorySSA)
1005d544a89aSSjoerd Meijer       AR.MSSA->verifyMemorySSA();
1006d544a89aSSjoerd Meijer   }
1007d544a89aSSjoerd Meijer 
10081124ad2fSStelios Ioannou   // The loop flattening pass requires loops to be
10091124ad2fSStelios Ioannou   // in simplified form, and also needs LCSSA. Running
10101124ad2fSStelios Ioannou   // this pass will simplify all loops that contain inner loops,
10111124ad2fSStelios Ioannou   // regardless of whether anything ends up being flattened.
101228767afdSFlorian Hahn   LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, &AR.TTI, nullptr);
1013e7d3f43eSFangrui Song   for (Loop *InnerLoop : LN.getLoops()) {
1014e7d3f43eSFangrui Song     auto *OuterLoop = InnerLoop->getParentLoop();
1015e7d3f43eSFangrui Song     if (!OuterLoop)
1016e7d3f43eSFangrui Song       continue;
1017e7d3f43eSFangrui Song     FlattenInfo FI(OuterLoop, InnerLoop);
1018a04d4a03SJohn Brawn     Changed |=
1019a04d4a03SJohn Brawn         FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,
1020a04d4a03SJohn Brawn                         MSSAU ? &*MSSAU : nullptr, LAIM.getInfo(*OuterLoop));
1021e7d3f43eSFangrui Song   }
10221124ad2fSStelios Ioannou 
10231124ad2fSStelios Ioannou   if (!Changed)
1024d53b4beeSSjoerd Meijer     return PreservedAnalyses::all();
1025d53b4beeSSjoerd Meijer 
1026d544a89aSSjoerd Meijer   if (AR.MSSA && VerifyMemorySSA)
1027d544a89aSSjoerd Meijer     AR.MSSA->verifyMemorySSA();
1028d544a89aSSjoerd Meijer 
1029d544a89aSSjoerd Meijer   auto PA = getLoopPassPreservedAnalyses();
1030d544a89aSSjoerd Meijer   if (AR.MSSA)
1031d544a89aSSjoerd Meijer     PA.preserve<MemorySSAAnalysis>();
1032d544a89aSSjoerd Meijer   return PA;
1033d53b4beeSSjoerd Meijer }
1034