xref: /llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp (revision 94f9cbbe49b4c836cfbed046637cdc0c63a4a083)
1 //===- LoopUnrollAndJam.cpp - Loop unroll and jam pass --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements an unroll and jam pass. Most of the work is done by
10 // Utils/UnrollLoopAndJam.cpp.
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Transforms/Scalar/LoopUnrollAndJamPass.h"
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/PriorityWorklist.h"
16 #include "llvm/ADT/SmallPtrSet.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/Analysis/AssumptionCache.h"
19 #include "llvm/Analysis/CodeMetrics.h"
20 #include "llvm/Analysis/DependenceAnalysis.h"
21 #include "llvm/Analysis/LoopAnalysisManager.h"
22 #include "llvm/Analysis/LoopInfo.h"
23 #include "llvm/Analysis/LoopNestAnalysis.h"
24 #include "llvm/Analysis/LoopPass.h"
25 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
26 #include "llvm/Analysis/ScalarEvolution.h"
27 #include "llvm/Analysis/TargetTransformInfo.h"
28 #include "llvm/IR/BasicBlock.h"
29 #include "llvm/IR/Constants.h"
30 #include "llvm/IR/Dominators.h"
31 #include "llvm/IR/Function.h"
32 #include "llvm/IR/Instructions.h"
33 #include "llvm/IR/Metadata.h"
34 #include "llvm/IR/PassManager.h"
35 #include "llvm/Support/Casting.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include "llvm/Transforms/Scalar/LoopPassManager.h"
40 #include "llvm/Transforms/Utils/LoopPeel.h"
41 #include "llvm/Transforms/Utils/LoopUtils.h"
42 #include "llvm/Transforms/Utils/UnrollLoop.h"
43 #include <cassert>
44 #include <cstdint>
45 
namespace llvm {
class Instruction;
class Value;
} // namespace llvm

using namespace llvm;

#define DEBUG_TYPE "loop-unroll-and-jam"

/// @{
/// Names of the loop-metadata attributes whose payload becomes the loop ID of
/// the loops produced by unroll-and-jam: every produced loop ("all"), the
/// jammed inner loop, the unrolled outer loop, and the inner/outer loops of
/// the remainder (epilogue).
static constexpr const char *LLVMLoopUnrollAndJamFollowupAll =
    "llvm.loop.unroll_and_jam.followup_all";
static constexpr const char *LLVMLoopUnrollAndJamFollowupInner =
    "llvm.loop.unroll_and_jam.followup_inner";
static constexpr const char *LLVMLoopUnrollAndJamFollowupOuter =
    "llvm.loop.unroll_and_jam.followup_outer";
static constexpr const char *LLVMLoopUnrollAndJamFollowupRemainderInner =
    "llvm.loop.unroll_and_jam.followup_remainder_inner";
static constexpr const char *LLVMLoopUnrollAndJamFollowupRemainderOuter =
    "llvm.loop.unroll_and_jam.followup_remainder_outer";
/// @}
68 
69 static cl::opt<bool>
70     AllowUnrollAndJam("allow-unroll-and-jam", cl::Hidden,
71                       cl::desc("Allows loops to be unroll-and-jammed."));
72 
73 static cl::opt<unsigned> UnrollAndJamCount(
74     "unroll-and-jam-count", cl::Hidden,
75     cl::desc("Use this unroll count for all loops including those with "
76              "unroll_and_jam_count pragma values, for testing purposes"));
77 
78 static cl::opt<unsigned> UnrollAndJamThreshold(
79     "unroll-and-jam-threshold", cl::init(60), cl::Hidden,
80     cl::desc("Threshold to use for inner loop when doing unroll and jam."));
81 
82 static cl::opt<unsigned> PragmaUnrollAndJamThreshold(
83     "pragma-unroll-and-jam-threshold", cl::init(1024), cl::Hidden,
84     cl::desc("Unrolled size limit for loops with an unroll_and_jam(full) or "
85              "unroll_count pragma."));
86 
87 // Returns the loop hint metadata node with the given name (for example,
88 // "llvm.loop.unroll.count").  If no such metadata node exists, then nullptr is
89 // returned.
90 static MDNode *getUnrollMetadataForLoop(const Loop *L, StringRef Name) {
91   if (MDNode *LoopID = L->getLoopID())
92     return GetUnrollMetadata(LoopID, Name);
93   return nullptr;
94 }
95 
96 // Returns true if the loop has any metadata starting with Prefix. For example a
97 // Prefix of "llvm.loop.unroll." returns true if we have any unroll metadata.
98 static bool hasAnyUnrollPragma(const Loop *L, StringRef Prefix) {
99   if (MDNode *LoopID = L->getLoopID()) {
100     // First operand should refer to the loop id itself.
101     assert(LoopID->getNumOperands() > 0 && "requires at least one operand");
102     assert(LoopID->getOperand(0) == LoopID && "invalid loop id");
103 
104     for (unsigned I = 1, E = LoopID->getNumOperands(); I < E; ++I) {
105       MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(I));
106       if (!MD)
107         continue;
108 
109       MDString *S = dyn_cast<MDString>(MD->getOperand(0));
110       if (!S)
111         continue;
112 
113       if (S->getString().starts_with(Prefix))
114         return true;
115     }
116   }
117   return false;
118 }
119 
120 // Returns true if the loop has an unroll_and_jam(enable) pragma.
121 static bool hasUnrollAndJamEnablePragma(const Loop *L) {
122   return getUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.enable");
123 }
124 
125 // If loop has an unroll_and_jam_count pragma return the (necessarily
126 // positive) value from the pragma.  Otherwise return 0.
127 static unsigned unrollAndJamCountPragmaValue(const Loop *L) {
128   MDNode *MD = getUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.count");
129   if (MD) {
130     assert(MD->getNumOperands() == 2 &&
131            "Unroll count hint metadata should have two operands.");
132     unsigned Count =
133         mdconst::extract<ConstantInt>(MD->getOperand(1))->getZExtValue();
134     assert(Count >= 1 && "Unroll count must be positive.");
135     return Count;
136   }
137   return 0;
138 }
139 
140 // Returns loop size estimation for unrolled loop.
141 static uint64_t
142 getUnrollAndJammedLoopSize(unsigned LoopSize,
143                            TargetTransformInfo::UnrollingPreferences &UP) {
144   assert(LoopSize >= UP.BEInsns && "LoopSize should not be less than BEInsns!");
145   return static_cast<uint64_t>(LoopSize - UP.BEInsns) * UP.Count + UP.BEInsns;
146 }
147 
148 // Calculates unroll and jam count and writes it to UP.Count. Returns true if
149 // unroll count was set explicitly.
150 static bool computeUnrollAndJamCount(
151     Loop *L, Loop *SubLoop, const TargetTransformInfo &TTI, DominatorTree &DT,
152     LoopInfo *LI, AssumptionCache *AC, ScalarEvolution &SE,
153     const SmallPtrSetImpl<const Value *> &EphValues,
154     OptimizationRemarkEmitter *ORE, unsigned OuterTripCount,
155     unsigned OuterTripMultiple, const UnrollCostEstimator &OuterUCE,
156     unsigned InnerTripCount, unsigned InnerLoopSize,
157     TargetTransformInfo::UnrollingPreferences &UP,
158     TargetTransformInfo::PeelingPreferences &PP) {
159   unsigned OuterLoopSize = OuterUCE.getRolledLoopSize();
160   // First up use computeUnrollCount from the loop unroller to get a count
161   // for unrolling the outer loop, plus any loops requiring explicit
162   // unrolling we leave to the unroller. This uses UP.Threshold /
163   // UP.PartialThreshold / UP.MaxCount to come up with sensible loop values.
164   // We have already checked that the loop has no unroll.* pragmas.
165   unsigned MaxTripCount = 0;
166   bool UseUpperBound = false;
167   bool ExplicitUnroll = computeUnrollCount(
168     L, TTI, DT, LI, AC, SE, EphValues, ORE, OuterTripCount, MaxTripCount,
169       /*MaxOrZero*/ false, OuterTripMultiple, OuterUCE, UP, PP,
170       UseUpperBound);
171   if (ExplicitUnroll || UseUpperBound) {
172     // If the user explicitly set the loop as unrolled, dont UnJ it. Leave it
173     // for the unroller instead.
174     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; explicit count set by "
175                          "computeUnrollCount\n");
176     UP.Count = 0;
177     return false;
178   }
179 
180   // Override with any explicit Count from the "unroll-and-jam-count" option.
181   bool UserUnrollCount = UnrollAndJamCount.getNumOccurrences() > 0;
182   if (UserUnrollCount) {
183     UP.Count = UnrollAndJamCount;
184     UP.Force = true;
185     if (UP.AllowRemainder &&
186         getUnrollAndJammedLoopSize(OuterLoopSize, UP) < UP.Threshold &&
187         getUnrollAndJammedLoopSize(InnerLoopSize, UP) <
188             UP.UnrollAndJamInnerLoopThreshold)
189       return true;
190   }
191 
192   // Check for unroll_and_jam pragmas
193   unsigned PragmaCount = unrollAndJamCountPragmaValue(L);
194   if (PragmaCount > 0) {
195     UP.Count = PragmaCount;
196     UP.Runtime = true;
197     UP.Force = true;
198     if ((UP.AllowRemainder || (OuterTripMultiple % PragmaCount == 0)) &&
199         getUnrollAndJammedLoopSize(OuterLoopSize, UP) < UP.Threshold &&
200         getUnrollAndJammedLoopSize(InnerLoopSize, UP) <
201             UP.UnrollAndJamInnerLoopThreshold)
202       return true;
203   }
204 
205   bool PragmaEnableUnroll = hasUnrollAndJamEnablePragma(L);
206   bool ExplicitUnrollAndJamCount = PragmaCount > 0 || UserUnrollCount;
207   bool ExplicitUnrollAndJam = PragmaEnableUnroll || ExplicitUnrollAndJamCount;
208 
209   // If the loop has an unrolling pragma, we want to be more aggressive with
210   // unrolling limits.
211   if (ExplicitUnrollAndJam)
212     UP.UnrollAndJamInnerLoopThreshold = PragmaUnrollAndJamThreshold;
213 
214   if (!UP.AllowRemainder && getUnrollAndJammedLoopSize(InnerLoopSize, UP) >=
215                                 UP.UnrollAndJamInnerLoopThreshold) {
216     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't create remainder and "
217                          "inner loop too large\n");
218     UP.Count = 0;
219     return false;
220   }
221 
222   // We have a sensible limit for the outer loop, now adjust it for the inner
223   // loop and UP.UnrollAndJamInnerLoopThreshold. If the outer limit was set
224   // explicitly, we want to stick to it.
225   if (!ExplicitUnrollAndJamCount && UP.AllowRemainder) {
226     while (UP.Count != 0 && getUnrollAndJammedLoopSize(InnerLoopSize, UP) >=
227                                 UP.UnrollAndJamInnerLoopThreshold)
228       UP.Count--;
229   }
230 
231   // If we are explicitly unroll and jamming, we are done. Otherwise there are a
232   // number of extra performance heuristics to check.
233   if (ExplicitUnrollAndJam)
234     return true;
235 
236   // If the inner loop count is known and small, leave the entire loop nest to
237   // be the unroller
238   if (InnerTripCount && InnerLoopSize * InnerTripCount < UP.Threshold) {
239     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; small inner loop count is "
240                          "being left for the unroller\n");
241     UP.Count = 0;
242     return false;
243   }
244 
245   // Check for situations where UnJ is likely to be unprofitable. Including
246   // subloops with more than 1 block.
247   if (SubLoop->getBlocks().size() != 1) {
248     LLVM_DEBUG(
249         dbgs() << "Won't unroll-and-jam; More than one inner loop block\n");
250     UP.Count = 0;
251     return false;
252   }
253 
254   // Limit to loops where there is something to gain from unrolling and
255   // jamming the loop. In this case, look for loads that are invariant in the
256   // outer loop and can become shared.
257   unsigned NumInvariant = 0;
258   for (BasicBlock *BB : SubLoop->getBlocks()) {
259     for (Instruction &I : *BB) {
260       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
261         Value *V = Ld->getPointerOperand();
262         const SCEV *LSCEV = SE.getSCEVAtScope(V, L);
263         if (SE.isLoopInvariant(LSCEV, L))
264           NumInvariant++;
265       }
266     }
267   }
268   if (NumInvariant == 0) {
269     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; No loop invariant loads\n");
270     UP.Count = 0;
271     return false;
272   }
273 
274   return false;
275 }
276 
277 static LoopUnrollResult
278 tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
279                       ScalarEvolution &SE, const TargetTransformInfo &TTI,
280                       AssumptionCache &AC, DependenceInfo &DI,
281                       OptimizationRemarkEmitter &ORE, int OptLevel) {
282   TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
283       L, SE, TTI, nullptr, nullptr, ORE, OptLevel, std::nullopt, std::nullopt,
284       std::nullopt, std::nullopt, std::nullopt, std::nullopt);
285   TargetTransformInfo::PeelingPreferences PP =
286       gatherPeelingPreferences(L, SE, TTI, std::nullopt, std::nullopt);
287 
288   TransformationMode EnableMode = hasUnrollAndJamTransformation(L);
289   if (EnableMode & TM_Disable)
290     return LoopUnrollResult::Unmodified;
291   if (EnableMode & TM_ForcedByUser)
292     UP.UnrollAndJam = true;
293 
294   if (AllowUnrollAndJam.getNumOccurrences() > 0)
295     UP.UnrollAndJam = AllowUnrollAndJam;
296   if (UnrollAndJamThreshold.getNumOccurrences() > 0)
297     UP.UnrollAndJamInnerLoopThreshold = UnrollAndJamThreshold;
298   // Exit early if unrolling is disabled.
299   if (!UP.UnrollAndJam || UP.UnrollAndJamInnerLoopThreshold == 0)
300     return LoopUnrollResult::Unmodified;
301 
302   LLVM_DEBUG(dbgs() << "Loop Unroll and Jam: F["
303                     << L->getHeader()->getParent()->getName() << "] Loop %"
304                     << L->getHeader()->getName() << "\n");
305 
306   // A loop with any unroll pragma (enabling/disabling/count/etc) is left for
307   // the unroller, so long as it does not explicitly have unroll_and_jam
308   // metadata. This means #pragma nounroll will disable unroll and jam as well
309   // as unrolling
310   if (hasAnyUnrollPragma(L, "llvm.loop.unroll.") &&
311       !hasAnyUnrollPragma(L, "llvm.loop.unroll_and_jam.")) {
312     LLVM_DEBUG(dbgs() << "  Disabled due to pragma.\n");
313     return LoopUnrollResult::Unmodified;
314   }
315 
316   if (!isSafeToUnrollAndJam(L, SE, DT, DI, *LI)) {
317     LLVM_DEBUG(dbgs() << "  Disabled due to not being safe.\n");
318     return LoopUnrollResult::Unmodified;
319   }
320 
321   // Approximate the loop size and collect useful info
322   SmallPtrSet<const Value *, 32> EphValues;
323   CodeMetrics::collectEphemeralValues(L, &AC, EphValues);
324   Loop *SubLoop = L->getSubLoops()[0];
325   UnrollCostEstimator InnerUCE(SubLoop, TTI, EphValues, UP.BEInsns);
326   UnrollCostEstimator OuterUCE(L, TTI, EphValues, UP.BEInsns);
327 
328   if (!InnerUCE.canUnroll() || !OuterUCE.canUnroll()) {
329     LLVM_DEBUG(dbgs() << "  Loop not considered unrollable\n");
330     return LoopUnrollResult::Unmodified;
331   }
332 
333   unsigned InnerLoopSize = InnerUCE.getRolledLoopSize();
334   LLVM_DEBUG(dbgs() << "  Outer Loop Size: " << OuterUCE.getRolledLoopSize()
335                     << "\n");
336   LLVM_DEBUG(dbgs() << "  Inner Loop Size: " << InnerLoopSize << "\n");
337 
338   if (InnerUCE.NumInlineCandidates != 0 || OuterUCE.NumInlineCandidates != 0) {
339     LLVM_DEBUG(dbgs() << "  Not unrolling loop with inlinable calls.\n");
340     return LoopUnrollResult::Unmodified;
341   }
342   // FIXME: The call to canUnroll() allows some controlled convergent
343   // operations, but we block them here for future changes.
344   if (InnerUCE.Convergence != ConvergenceKind::None ||
345       OuterUCE.Convergence != ConvergenceKind::None) {
346     LLVM_DEBUG(
347         dbgs() << "  Not unrolling loop with convergent instructions.\n");
348     return LoopUnrollResult::Unmodified;
349   }
350 
351   // Save original loop IDs for after the transformation.
352   MDNode *OrigOuterLoopID = L->getLoopID();
353   MDNode *OrigSubLoopID = SubLoop->getLoopID();
354 
355   // To assign the loop id of the epilogue, assign it before unrolling it so it
356   // is applied to every inner loop of the epilogue. We later apply the loop ID
357   // for the jammed inner loop.
358   std::optional<MDNode *> NewInnerEpilogueLoopID = makeFollowupLoopID(
359       OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll,
360                         LLVMLoopUnrollAndJamFollowupRemainderInner});
361   if (NewInnerEpilogueLoopID)
362     SubLoop->setLoopID(*NewInnerEpilogueLoopID);
363 
364   // Find trip count and trip multiple
365   BasicBlock *Latch = L->getLoopLatch();
366   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
367   unsigned OuterTripCount = SE.getSmallConstantTripCount(L, Latch);
368   unsigned OuterTripMultiple = SE.getSmallConstantTripMultiple(L, Latch);
369   unsigned InnerTripCount = SE.getSmallConstantTripCount(SubLoop, SubLoopLatch);
370 
371   // Decide if, and by how much, to unroll
372   bool IsCountSetExplicitly = computeUnrollAndJamCount(
373     L, SubLoop, TTI, DT, LI, &AC, SE, EphValues, &ORE, OuterTripCount,
374       OuterTripMultiple, OuterUCE, InnerTripCount, InnerLoopSize, UP, PP);
375   if (UP.Count <= 1)
376     return LoopUnrollResult::Unmodified;
377   // Unroll factor (Count) must be less or equal to TripCount.
378   if (OuterTripCount && UP.Count > OuterTripCount)
379     UP.Count = OuterTripCount;
380 
381   Loop *EpilogueOuterLoop = nullptr;
382   LoopUnrollResult UnrollResult = UnrollAndJamLoop(
383       L, UP.Count, OuterTripCount, OuterTripMultiple, UP.UnrollRemainder, LI,
384       &SE, &DT, &AC, &TTI, &ORE, &EpilogueOuterLoop);
385 
386   // Assign new loop attributes.
387   if (EpilogueOuterLoop) {
388     std::optional<MDNode *> NewOuterEpilogueLoopID = makeFollowupLoopID(
389         OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll,
390                           LLVMLoopUnrollAndJamFollowupRemainderOuter});
391     if (NewOuterEpilogueLoopID)
392       EpilogueOuterLoop->setLoopID(*NewOuterEpilogueLoopID);
393   }
394 
395   std::optional<MDNode *> NewInnerLoopID =
396       makeFollowupLoopID(OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll,
397                                            LLVMLoopUnrollAndJamFollowupInner});
398   if (NewInnerLoopID)
399     SubLoop->setLoopID(*NewInnerLoopID);
400   else
401     SubLoop->setLoopID(OrigSubLoopID);
402 
403   if (UnrollResult == LoopUnrollResult::PartiallyUnrolled) {
404     std::optional<MDNode *> NewOuterLoopID = makeFollowupLoopID(
405         OrigOuterLoopID,
406         {LLVMLoopUnrollAndJamFollowupAll, LLVMLoopUnrollAndJamFollowupOuter});
407     if (NewOuterLoopID) {
408       L->setLoopID(*NewOuterLoopID);
409 
410       // Do not setLoopAlreadyUnrolled if a followup was given.
411       return UnrollResult;
412     }
413   }
414 
415   // If loop has an unroll count pragma or unrolled by explicitly set count
416   // mark loop as unrolled to prevent unrolling beyond that requested.
417   if (UnrollResult != LoopUnrollResult::FullyUnrolled && IsCountSetExplicitly)
418     L->setLoopAlreadyUnrolled();
419 
420   return UnrollResult;
421 }
422 
423 static bool tryToUnrollAndJamLoop(LoopNest &LN, DominatorTree &DT, LoopInfo &LI,
424                                   ScalarEvolution &SE,
425                                   const TargetTransformInfo &TTI,
426                                   AssumptionCache &AC, DependenceInfo &DI,
427                                   OptimizationRemarkEmitter &ORE, int OptLevel,
428                                   LPMUpdater &U) {
429   bool DidSomething = false;
430   ArrayRef<Loop *> Loops = LN.getLoops();
431   Loop *OutmostLoop = &LN.getOutermostLoop();
432 
433   // Add the loop nests in the reverse order of LN. See method
434   // declaration.
435   SmallPriorityWorklist<Loop *, 4> Worklist;
436   appendLoopsToWorklist(Loops, Worklist);
437   while (!Worklist.empty()) {
438     Loop *L = Worklist.pop_back_val();
439     std::string LoopName = std::string(L->getName());
440     LoopUnrollResult Result =
441         tryToUnrollAndJamLoop(L, DT, &LI, SE, TTI, AC, DI, ORE, OptLevel);
442     if (Result != LoopUnrollResult::Unmodified)
443       DidSomething = true;
444     if (L == OutmostLoop && Result == LoopUnrollResult::FullyUnrolled)
445       U.markLoopAsDeleted(*L, LoopName);
446   }
447 
448   return DidSomething;
449 }
450 
451 PreservedAnalyses LoopUnrollAndJamPass::run(LoopNest &LN,
452                                             LoopAnalysisManager &AM,
453                                             LoopStandardAnalysisResults &AR,
454                                             LPMUpdater &U) {
455   Function &F = *LN.getParent();
456 
457   DependenceInfo DI(&F, &AR.AA, &AR.SE, &AR.LI);
458   OptimizationRemarkEmitter ORE(&F);
459 
460   if (!tryToUnrollAndJamLoop(LN, AR.DT, AR.LI, AR.SE, AR.TTI, AR.AC, DI, ORE,
461                              OptLevel, U))
462     return PreservedAnalyses::all();
463 
464   auto PA = getLoopPassPreservedAnalyses();
465   PA.preserve<LoopNestAnalysis>();
466   return PA;
467 }
468