xref: /llvm-project/llvm/lib/Transforms/Scalar/LoopTermFold.cpp (revision 94f9cbbe49b4c836cfbed046637cdc0c63a4a083)
1*27a62ec7SPhilip Reames //===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
2*27a62ec7SPhilip Reames //
3*27a62ec7SPhilip Reames // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*27a62ec7SPhilip Reames // See https://llvm.org/LICENSE.txt for license information.
5*27a62ec7SPhilip Reames // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*27a62ec7SPhilip Reames //
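//===----------------------------------------------------------------------===//
//
// This pass rewrites the exit test in the latch of an innermost loop so that
// it compares an alternate induction variable against that IV's final value
// (computed once in the preheader). When the IV used by the original exit
// test has no other uses, it becomes dead and is deleted.
//
//===----------------------------------------------------------------------===//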
9*27a62ec7SPhilip Reames 
10*27a62ec7SPhilip Reames #include "llvm/Transforms/Scalar/LoopTermFold.h"
11*27a62ec7SPhilip Reames #include "llvm/ADT/Statistic.h"
12*27a62ec7SPhilip Reames #include "llvm/Analysis/LoopAnalysisManager.h"
13*27a62ec7SPhilip Reames #include "llvm/Analysis/LoopInfo.h"
14*27a62ec7SPhilip Reames #include "llvm/Analysis/LoopPass.h"
15*27a62ec7SPhilip Reames #include "llvm/Analysis/MemorySSA.h"
16*27a62ec7SPhilip Reames #include "llvm/Analysis/MemorySSAUpdater.h"
17*27a62ec7SPhilip Reames #include "llvm/Analysis/ScalarEvolution.h"
18*27a62ec7SPhilip Reames #include "llvm/Analysis/ScalarEvolutionExpressions.h"
19*27a62ec7SPhilip Reames #include "llvm/Analysis/TargetLibraryInfo.h"
20*27a62ec7SPhilip Reames #include "llvm/Analysis/TargetTransformInfo.h"
21*27a62ec7SPhilip Reames #include "llvm/Analysis/ValueTracking.h"
22*27a62ec7SPhilip Reames #include "llvm/Config/llvm-config.h"
23*27a62ec7SPhilip Reames #include "llvm/IR/BasicBlock.h"
24*27a62ec7SPhilip Reames #include "llvm/IR/Dominators.h"
25*27a62ec7SPhilip Reames #include "llvm/IR/IRBuilder.h"
26*27a62ec7SPhilip Reames #include "llvm/IR/InstrTypes.h"
27*27a62ec7SPhilip Reames #include "llvm/IR/Instruction.h"
28*27a62ec7SPhilip Reames #include "llvm/IR/Instructions.h"
29*27a62ec7SPhilip Reames #include "llvm/IR/Type.h"
30*27a62ec7SPhilip Reames #include "llvm/IR/Value.h"
31*27a62ec7SPhilip Reames #include "llvm/InitializePasses.h"
32*27a62ec7SPhilip Reames #include "llvm/Pass.h"
33*27a62ec7SPhilip Reames #include "llvm/Support/Debug.h"
34*27a62ec7SPhilip Reames #include "llvm/Support/raw_ostream.h"
35*27a62ec7SPhilip Reames #include "llvm/Transforms/Scalar.h"
36*27a62ec7SPhilip Reames #include "llvm/Transforms/Utils.h"
37*27a62ec7SPhilip Reames #include "llvm/Transforms/Utils/BasicBlockUtils.h"
38*27a62ec7SPhilip Reames #include "llvm/Transforms/Utils/Local.h"
39*27a62ec7SPhilip Reames #include "llvm/Transforms/Utils/LoopUtils.h"
40*27a62ec7SPhilip Reames #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
41*27a62ec7SPhilip Reames #include <cassert>
42*27a62ec7SPhilip Reames #include <optional>
43*27a62ec7SPhilip Reames 
44*27a62ec7SPhilip Reames using namespace llvm;
45*27a62ec7SPhilip Reames 
46*27a62ec7SPhilip Reames #define DEBUG_TYPE "loop-term-fold"
47*27a62ec7SPhilip Reames 
48*27a62ec7SPhilip Reames STATISTIC(NumTermFold,
49*27a62ec7SPhilip Reames           "Number of terminating condition fold recognized and performed");
50*27a62ec7SPhilip Reames 
51*27a62ec7SPhilip Reames static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
52*27a62ec7SPhilip Reames canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
53*27a62ec7SPhilip Reames                       const LoopInfo &LI, const TargetTransformInfo &TTI) {
54*27a62ec7SPhilip Reames   if (!L->isInnermost()) {
55*27a62ec7SPhilip Reames     LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
56*27a62ec7SPhilip Reames     return std::nullopt;
57*27a62ec7SPhilip Reames   }
58*27a62ec7SPhilip Reames   // Only inspect on simple loop structure
59*27a62ec7SPhilip Reames   if (!L->isLoopSimplifyForm()) {
60*27a62ec7SPhilip Reames     LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
61*27a62ec7SPhilip Reames     return std::nullopt;
62*27a62ec7SPhilip Reames   }
63*27a62ec7SPhilip Reames 
64*27a62ec7SPhilip Reames   if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
65*27a62ec7SPhilip Reames     LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
66*27a62ec7SPhilip Reames     return std::nullopt;
67*27a62ec7SPhilip Reames   }
68*27a62ec7SPhilip Reames 
69*27a62ec7SPhilip Reames   BasicBlock *LoopLatch = L->getLoopLatch();
70*27a62ec7SPhilip Reames   BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
71*27a62ec7SPhilip Reames   if (!BI || BI->isUnconditional())
72*27a62ec7SPhilip Reames     return std::nullopt;
73*27a62ec7SPhilip Reames   auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
74*27a62ec7SPhilip Reames   if (!TermCond) {
75*27a62ec7SPhilip Reames     LLVM_DEBUG(
76*27a62ec7SPhilip Reames         dbgs() << "Cannot fold on branching condition that is not an ICmpInst");
77*27a62ec7SPhilip Reames     return std::nullopt;
78*27a62ec7SPhilip Reames   }
79*27a62ec7SPhilip Reames   if (!TermCond->hasOneUse()) {
80*27a62ec7SPhilip Reames     LLVM_DEBUG(
81*27a62ec7SPhilip Reames         dbgs()
82*27a62ec7SPhilip Reames         << "Cannot replace terminating condition with more than one use\n");
83*27a62ec7SPhilip Reames     return std::nullopt;
84*27a62ec7SPhilip Reames   }
85*27a62ec7SPhilip Reames 
86*27a62ec7SPhilip Reames   BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
87*27a62ec7SPhilip Reames   Value *RHS = TermCond->getOperand(1);
88*27a62ec7SPhilip Reames   if (!LHS || !L->isLoopInvariant(RHS))
89*27a62ec7SPhilip Reames     // We could pattern match the inverse form of the icmp, but that is
90*27a62ec7SPhilip Reames     // non-canonical, and this pass is running *very* late in the pipeline.
91*27a62ec7SPhilip Reames     return std::nullopt;
92*27a62ec7SPhilip Reames 
93*27a62ec7SPhilip Reames   // Find the IV used by the current exit condition.
94*27a62ec7SPhilip Reames   PHINode *ToFold;
95*27a62ec7SPhilip Reames   Value *ToFoldStart, *ToFoldStep;
96*27a62ec7SPhilip Reames   if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
97*27a62ec7SPhilip Reames     return std::nullopt;
98*27a62ec7SPhilip Reames 
99*27a62ec7SPhilip Reames   // Ensure the simple recurrence is a part of the current loop.
100*27a62ec7SPhilip Reames   if (ToFold->getParent() != L->getHeader())
101*27a62ec7SPhilip Reames     return std::nullopt;
102*27a62ec7SPhilip Reames 
103*27a62ec7SPhilip Reames   // If that IV isn't dead after we rewrite the exit condition in terms of
104*27a62ec7SPhilip Reames   // another IV, there's no point in doing the transform.
105*27a62ec7SPhilip Reames   if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
106*27a62ec7SPhilip Reames     return std::nullopt;
107*27a62ec7SPhilip Reames 
108*27a62ec7SPhilip Reames   // Inserting instructions in the preheader has a runtime cost, scale
109*27a62ec7SPhilip Reames   // the allowed cost with the loops trip count as best we can.
110*27a62ec7SPhilip Reames   const unsigned ExpansionBudget = [&]() {
111*27a62ec7SPhilip Reames     unsigned Budget = 2 * SCEVCheapExpansionBudget;
112*27a62ec7SPhilip Reames     if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
113*27a62ec7SPhilip Reames       return std::min(Budget, SmallTC);
114*27a62ec7SPhilip Reames     if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
115*27a62ec7SPhilip Reames       return std::min(Budget, *SmallTC);
116*27a62ec7SPhilip Reames     // Unknown trip count, assume long running by default.
117*27a62ec7SPhilip Reames     return Budget;
118*27a62ec7SPhilip Reames   }();
119*27a62ec7SPhilip Reames 
120*27a62ec7SPhilip Reames   const SCEV *BECount = SE.getBackedgeTakenCount(L);
121*27a62ec7SPhilip Reames   const DataLayout &DL = L->getHeader()->getDataLayout();
122*27a62ec7SPhilip Reames   SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
123*27a62ec7SPhilip Reames 
124*27a62ec7SPhilip Reames   PHINode *ToHelpFold = nullptr;
125*27a62ec7SPhilip Reames   const SCEV *TermValueS = nullptr;
126*27a62ec7SPhilip Reames   bool MustDropPoison = false;
127*27a62ec7SPhilip Reames   auto InsertPt = L->getLoopPreheader()->getTerminator();
128*27a62ec7SPhilip Reames   for (PHINode &PN : L->getHeader()->phis()) {
129*27a62ec7SPhilip Reames     if (ToFold == &PN)
130*27a62ec7SPhilip Reames       continue;
131*27a62ec7SPhilip Reames 
132*27a62ec7SPhilip Reames     if (!SE.isSCEVable(PN.getType())) {
133*27a62ec7SPhilip Reames       LLVM_DEBUG(dbgs() << "IV of phi '" << PN
134*27a62ec7SPhilip Reames                         << "' is not SCEV-able, not qualified for the "
135*27a62ec7SPhilip Reames                            "terminating condition folding.\n");
136*27a62ec7SPhilip Reames       continue;
137*27a62ec7SPhilip Reames     }
138*27a62ec7SPhilip Reames     const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
139*27a62ec7SPhilip Reames     // Only speculate on affine AddRec
140*27a62ec7SPhilip Reames     if (!AddRec || !AddRec->isAffine()) {
141*27a62ec7SPhilip Reames       LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
142*27a62ec7SPhilip Reames                         << "' is not an affine add recursion, not qualified "
143*27a62ec7SPhilip Reames                            "for the terminating condition folding.\n");
144*27a62ec7SPhilip Reames       continue;
145*27a62ec7SPhilip Reames     }
146*27a62ec7SPhilip Reames 
147*27a62ec7SPhilip Reames     // Check that we can compute the value of AddRec on the exiting iteration
148*27a62ec7SPhilip Reames     // without soundness problems.  evaluateAtIteration internally needs
149*27a62ec7SPhilip Reames     // to multiply the stride of the iteration number - which may wrap around.
150*27a62ec7SPhilip Reames     // The issue here is subtle because computing the result accounting for
151*27a62ec7SPhilip Reames     // wrap is insufficient. In order to use the result in an exit test, we
152*27a62ec7SPhilip Reames     // must also know that AddRec doesn't take the same value on any previous
153*27a62ec7SPhilip Reames     // iteration. The simplest case to consider is a candidate IV which is
154*27a62ec7SPhilip Reames     // narrower than the trip count (and thus original IV), but this can
155*27a62ec7SPhilip Reames     // also happen due to non-unit strides on the candidate IVs.
156*27a62ec7SPhilip Reames     if (!AddRec->hasNoSelfWrap() ||
157*27a62ec7SPhilip Reames         !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
158*27a62ec7SPhilip Reames       continue;
159*27a62ec7SPhilip Reames 
160*27a62ec7SPhilip Reames     const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
161*27a62ec7SPhilip Reames     const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
162*27a62ec7SPhilip Reames     if (!Expander.isSafeToExpand(TermValueSLocal)) {
163*27a62ec7SPhilip Reames       LLVM_DEBUG(
164*27a62ec7SPhilip Reames           dbgs() << "Is not safe to expand terminating value for phi node" << PN
165*27a62ec7SPhilip Reames                  << "\n");
166*27a62ec7SPhilip Reames       continue;
167*27a62ec7SPhilip Reames     }
168*27a62ec7SPhilip Reames 
169*27a62ec7SPhilip Reames     if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI,
170*27a62ec7SPhilip Reames                                      InsertPt)) {
171*27a62ec7SPhilip Reames       LLVM_DEBUG(
172*27a62ec7SPhilip Reames           dbgs() << "Is too expensive to expand terminating value for phi node"
173*27a62ec7SPhilip Reames                  << PN << "\n");
174*27a62ec7SPhilip Reames       continue;
175*27a62ec7SPhilip Reames     }
176*27a62ec7SPhilip Reames 
177*27a62ec7SPhilip Reames     // The candidate IV may have been otherwise dead and poison from the
178*27a62ec7SPhilip Reames     // very first iteration.  If we can't disprove that, we can't use the IV.
179*27a62ec7SPhilip Reames     if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
180*27a62ec7SPhilip Reames       LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");
181*27a62ec7SPhilip Reames       continue;
182*27a62ec7SPhilip Reames     }
183*27a62ec7SPhilip Reames 
184*27a62ec7SPhilip Reames     // The candidate IV may become poison on the last iteration.  If this
185*27a62ec7SPhilip Reames     // value is not branched on, this is a well defined program.  We're
186*27a62ec7SPhilip Reames     // about to add a new use to this IV, and we have to ensure we don't
187*27a62ec7SPhilip Reames     // insert UB which didn't previously exist.
188*27a62ec7SPhilip Reames     bool MustDropPoisonLocal = false;
189*27a62ec7SPhilip Reames     Instruction *PostIncV =
190*27a62ec7SPhilip Reames         cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
191*27a62ec7SPhilip Reames     if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
192*27a62ec7SPhilip Reames                                        &DT)) {
193*27a62ec7SPhilip Reames       LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN
194*27a62ec7SPhilip Reames                         << "\n");
195*27a62ec7SPhilip Reames 
196*27a62ec7SPhilip Reames       // If this is a complex recurrance with multiple instructions computing
197*27a62ec7SPhilip Reames       // the backedge value, we might need to strip poison flags from all of
198*27a62ec7SPhilip Reames       // them.
199*27a62ec7SPhilip Reames       if (PostIncV->getOperand(0) != &PN)
200*27a62ec7SPhilip Reames         continue;
201*27a62ec7SPhilip Reames 
202*27a62ec7SPhilip Reames       // In order to perform the transform, we need to drop the poison
203*27a62ec7SPhilip Reames       // generating flags on this instruction (if any).
204*27a62ec7SPhilip Reames       MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
205*27a62ec7SPhilip Reames     }
206*27a62ec7SPhilip Reames 
207*27a62ec7SPhilip Reames     // We pick the last legal alternate IV.  We could expore choosing an optimal
208*27a62ec7SPhilip Reames     // alternate IV if we had a decent heuristic to do so.
209*27a62ec7SPhilip Reames     ToHelpFold = &PN;
210*27a62ec7SPhilip Reames     TermValueS = TermValueSLocal;
211*27a62ec7SPhilip Reames     MustDropPoison = MustDropPoisonLocal;
212*27a62ec7SPhilip Reames   }
213*27a62ec7SPhilip Reames 
214*27a62ec7SPhilip Reames   LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
215*27a62ec7SPhilip Reames                  << "Cannot find other AddRec IV to help folding\n";);
216*27a62ec7SPhilip Reames 
217*27a62ec7SPhilip Reames   LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs()
218*27a62ec7SPhilip Reames              << "\nFound loop that can fold terminating condition\n"
219*27a62ec7SPhilip Reames              << "  BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n"
220*27a62ec7SPhilip Reames              << "  TermCond: " << *TermCond << "\n"
221*27a62ec7SPhilip Reames              << "  BrandInst: " << *BI << "\n"
222*27a62ec7SPhilip Reames              << "  ToFold: " << *ToFold << "\n"
223*27a62ec7SPhilip Reames              << "  ToHelpFold: " << *ToHelpFold << "\n");
224*27a62ec7SPhilip Reames 
225*27a62ec7SPhilip Reames   if (!ToFold || !ToHelpFold)
226*27a62ec7SPhilip Reames     return std::nullopt;
227*27a62ec7SPhilip Reames   return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
228*27a62ec7SPhilip Reames }
229*27a62ec7SPhilip Reames 
230*27a62ec7SPhilip Reames static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
231*27a62ec7SPhilip Reames                         LoopInfo &LI, const TargetTransformInfo &TTI,
232*27a62ec7SPhilip Reames                         TargetLibraryInfo &TLI, MemorySSA *MSSA) {
233*27a62ec7SPhilip Reames   std::unique_ptr<MemorySSAUpdater> MSSAU;
234*27a62ec7SPhilip Reames   if (MSSA)
235*27a62ec7SPhilip Reames     MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
236*27a62ec7SPhilip Reames 
237*27a62ec7SPhilip Reames   auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
238*27a62ec7SPhilip Reames   if (!Opt)
239*27a62ec7SPhilip Reames     return false;
240*27a62ec7SPhilip Reames 
241*27a62ec7SPhilip Reames   auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
242*27a62ec7SPhilip Reames 
243*27a62ec7SPhilip Reames   NumTermFold++;
244*27a62ec7SPhilip Reames 
245*27a62ec7SPhilip Reames   BasicBlock *LoopPreheader = L->getLoopPreheader();
246*27a62ec7SPhilip Reames   BasicBlock *LoopLatch = L->getLoopLatch();
247*27a62ec7SPhilip Reames 
248*27a62ec7SPhilip Reames   (void)ToFold;
249*27a62ec7SPhilip Reames   LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
250*27a62ec7SPhilip Reames                     << *ToFold << "\n"
251*27a62ec7SPhilip Reames                     << "New term-cond phi-node:\n"
252*27a62ec7SPhilip Reames                     << *ToHelpFold << "\n");
253*27a62ec7SPhilip Reames 
254*27a62ec7SPhilip Reames   Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
255*27a62ec7SPhilip Reames   (void)StartValue;
256*27a62ec7SPhilip Reames   Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
257*27a62ec7SPhilip Reames 
258*27a62ec7SPhilip Reames   // See comment in canFoldTermCondOfLoop on why this is sufficient.
259*27a62ec7SPhilip Reames   if (MustDrop)
260*27a62ec7SPhilip Reames     cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();
261*27a62ec7SPhilip Reames 
262*27a62ec7SPhilip Reames   // SCEVExpander for both use in preheader and latch
263*27a62ec7SPhilip Reames   const DataLayout &DL = L->getHeader()->getDataLayout();
264*27a62ec7SPhilip Reames   SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
265*27a62ec7SPhilip Reames 
266*27a62ec7SPhilip Reames   assert(Expander.isSafeToExpand(TermValueS) &&
267*27a62ec7SPhilip Reames          "Terminating value was checked safe in canFoldTerminatingCondition");
268*27a62ec7SPhilip Reames 
269*27a62ec7SPhilip Reames   // Create new terminating value at loop preheader
270*27a62ec7SPhilip Reames   Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
271*27a62ec7SPhilip Reames                                             LoopPreheader->getTerminator());
272*27a62ec7SPhilip Reames 
273*27a62ec7SPhilip Reames   LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
274*27a62ec7SPhilip Reames                     << *StartValue << "\n"
275*27a62ec7SPhilip Reames                     << "Terminating value of new term-cond phi-node:\n"
276*27a62ec7SPhilip Reames                     << *TermValue << "\n");
277*27a62ec7SPhilip Reames 
278*27a62ec7SPhilip Reames   // Create new terminating condition at loop latch
279*27a62ec7SPhilip Reames   BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
280*27a62ec7SPhilip Reames   ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
281*27a62ec7SPhilip Reames   IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
282*27a62ec7SPhilip Reames   Value *NewTermCond =
283*27a62ec7SPhilip Reames       LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
284*27a62ec7SPhilip Reames                               "lsr_fold_term_cond.replaced_term_cond");
285*27a62ec7SPhilip Reames   // Swap successors to exit loop body if IV equals to new TermValue
286*27a62ec7SPhilip Reames   if (BI->getSuccessor(0) == L->getHeader())
287*27a62ec7SPhilip Reames     BI->swapSuccessors();
288*27a62ec7SPhilip Reames 
289*27a62ec7SPhilip Reames   LLVM_DEBUG(dbgs() << "Old term-cond:\n"
290*27a62ec7SPhilip Reames                     << *OldTermCond << "\n"
291*27a62ec7SPhilip Reames                     << "New term-cond:\n"
292*27a62ec7SPhilip Reames                     << *NewTermCond << "\n");
293*27a62ec7SPhilip Reames 
294*27a62ec7SPhilip Reames   BI->setCondition(NewTermCond);
295*27a62ec7SPhilip Reames 
296*27a62ec7SPhilip Reames   Expander.clear();
297*27a62ec7SPhilip Reames   OldTermCond->eraseFromParent();
298*27a62ec7SPhilip Reames   DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
299*27a62ec7SPhilip Reames   return true;
300*27a62ec7SPhilip Reames }
301*27a62ec7SPhilip Reames 
302*27a62ec7SPhilip Reames namespace {
303*27a62ec7SPhilip Reames 
304*27a62ec7SPhilip Reames class LoopTermFold : public LoopPass {
305*27a62ec7SPhilip Reames public:
306*27a62ec7SPhilip Reames   static char ID; // Pass ID, replacement for typeid
307*27a62ec7SPhilip Reames 
308*27a62ec7SPhilip Reames   LoopTermFold();
309*27a62ec7SPhilip Reames 
310*27a62ec7SPhilip Reames private:
311*27a62ec7SPhilip Reames   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
312*27a62ec7SPhilip Reames   void getAnalysisUsage(AnalysisUsage &AU) const override;
313*27a62ec7SPhilip Reames };
314*27a62ec7SPhilip Reames 
315*27a62ec7SPhilip Reames } // end anonymous namespace
316*27a62ec7SPhilip Reames 
317*27a62ec7SPhilip Reames LoopTermFold::LoopTermFold() : LoopPass(ID) {
318*27a62ec7SPhilip Reames   initializeLoopTermFoldPass(*PassRegistry::getPassRegistry());
319*27a62ec7SPhilip Reames }
320*27a62ec7SPhilip Reames 
321*27a62ec7SPhilip Reames void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
322*27a62ec7SPhilip Reames   AU.addRequired<LoopInfoWrapperPass>();
323*27a62ec7SPhilip Reames   AU.addPreserved<LoopInfoWrapperPass>();
324*27a62ec7SPhilip Reames   AU.addPreservedID(LoopSimplifyID);
325*27a62ec7SPhilip Reames   AU.addRequiredID(LoopSimplifyID);
326*27a62ec7SPhilip Reames   AU.addRequired<DominatorTreeWrapperPass>();
327*27a62ec7SPhilip Reames   AU.addPreserved<DominatorTreeWrapperPass>();
328*27a62ec7SPhilip Reames   AU.addRequired<ScalarEvolutionWrapperPass>();
329*27a62ec7SPhilip Reames   AU.addPreserved<ScalarEvolutionWrapperPass>();
330*27a62ec7SPhilip Reames   AU.addRequired<TargetLibraryInfoWrapperPass>();
331*27a62ec7SPhilip Reames   AU.addRequired<TargetTransformInfoWrapperPass>();
332*27a62ec7SPhilip Reames   AU.addPreserved<MemorySSAWrapperPass>();
333*27a62ec7SPhilip Reames }
334*27a62ec7SPhilip Reames 
335*27a62ec7SPhilip Reames bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
336*27a62ec7SPhilip Reames   if (skipLoop(L))
337*27a62ec7SPhilip Reames     return false;
338*27a62ec7SPhilip Reames 
339*27a62ec7SPhilip Reames   auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
340*27a62ec7SPhilip Reames   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
341*27a62ec7SPhilip Reames   auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
342*27a62ec7SPhilip Reames   const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
343*27a62ec7SPhilip Reames       *L->getHeader()->getParent());
344*27a62ec7SPhilip Reames   auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
345*27a62ec7SPhilip Reames       *L->getHeader()->getParent());
346*27a62ec7SPhilip Reames   auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
347*27a62ec7SPhilip Reames   MemorySSA *MSSA = nullptr;
348*27a62ec7SPhilip Reames   if (MSSAAnalysis)
349*27a62ec7SPhilip Reames     MSSA = &MSSAAnalysis->getMSSA();
350*27a62ec7SPhilip Reames   return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);
351*27a62ec7SPhilip Reames }
352*27a62ec7SPhilip Reames 
353*27a62ec7SPhilip Reames PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM,
354*27a62ec7SPhilip Reames                                         LoopStandardAnalysisResults &AR,
355*27a62ec7SPhilip Reames                                         LPMUpdater &) {
356*27a62ec7SPhilip Reames   if (!RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA))
357*27a62ec7SPhilip Reames     return PreservedAnalyses::all();
358*27a62ec7SPhilip Reames 
359*27a62ec7SPhilip Reames   auto PA = getLoopPassPreservedAnalyses();
360*27a62ec7SPhilip Reames   if (AR.MSSA)
361*27a62ec7SPhilip Reames     PA.preserve<MemorySSAAnalysis>();
362*27a62ec7SPhilip Reames   return PA;
363*27a62ec7SPhilip Reames }
364*27a62ec7SPhilip Reames 
365*27a62ec7SPhilip Reames char LoopTermFold::ID = 0;
366*27a62ec7SPhilip Reames 
367*27a62ec7SPhilip Reames INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
368*27a62ec7SPhilip Reames                       false, false)
369*27a62ec7SPhilip Reames INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
370*27a62ec7SPhilip Reames INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
371*27a62ec7SPhilip Reames INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
372*27a62ec7SPhilip Reames INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
373*27a62ec7SPhilip Reames INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
374*27a62ec7SPhilip Reames INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
375*27a62ec7SPhilip Reames                     false, false)
376*27a62ec7SPhilip Reames 
377*27a62ec7SPhilip Reames Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); }
378