xref: /llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp (revision 2ea4c2c598b7c6f95b5d4db747bdf72770e586df)
1 //===- LoopVersioning.cpp - Utility to version a loop ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a utility class to perform loop versioning.  The versioned
10 // loop speculates that otherwise may-aliasing memory accesses don't overlap and
11 // emits checks to prove this.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/Utils/LoopVersioning.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/Analysis/AliasAnalysis.h"
18 #include "llvm/Analysis/LoopAccessAnalysis.h"
19 #include "llvm/Analysis/LoopInfo.h"
20 #include "llvm/Analysis/MemorySSA.h"
21 #include "llvm/Analysis/ScalarEvolution.h"
22 #include "llvm/Analysis/TargetLibraryInfo.h"
23 #include "llvm/IR/Dominators.h"
24 #include "llvm/IR/MDBuilder.h"
25 #include "llvm/IR/PassManager.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
29 #include "llvm/Transforms/Utils/Cloning.h"
30 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
31 
32 using namespace llvm;
33 
34 static cl::opt<bool>
35     AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
36                     cl::Hidden,
37                     cl::desc("Add no-alias annotation for instructions that "
38                              "are disambiguated by memchecks"));
39 
40 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
41                                DominatorTree *DT, ScalarEvolution *SE,
42                                bool UseLAIChecks)
43     : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT),
44       SE(SE) {
45   assert(L->getExitBlock() && "No single exit block");
46   assert(L->isLoopSimplifyForm() && "Loop is not in loop-simplify form");
47   if (UseLAIChecks) {
48     setAliasChecks(LAI.getRuntimePointerChecking()->getChecks());
49     setSCEVChecks(LAI.getPSE().getUnionPredicate());
50   }
51 }
52 
53 void LoopVersioning::setAliasChecks(ArrayRef<RuntimePointerCheck> Checks) {
54   AliasChecks = {Checks.begin(), Checks.end()};
55 }
56 
57 void LoopVersioning::setSCEVChecks(SCEVUnionPredicate Check) {
58   Preds = std::move(Check);
59 }
60 
61 void LoopVersioning::versionLoop(
62     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
63   Instruction *FirstCheckInst;
64   Instruction *MemRuntimeCheck;
65   Value *SCEVRuntimeCheck;
66   Value *RuntimeCheck = nullptr;
67 
68   // Add the memcheck in the original preheader (this is empty initially).
69   BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
70   const auto &RtPtrChecking = *LAI.getRuntimePointerChecking();
71   std::tie(FirstCheckInst, MemRuntimeCheck) =
72       addRuntimeChecks(RuntimeCheckBB->getTerminator(), VersionedLoop,
73                        AliasChecks, RtPtrChecking.getSE());
74 
75   const SCEVUnionPredicate &Pred = LAI.getPSE().getUnionPredicate();
76   SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
77                    "scev.check");
78   SCEVRuntimeCheck =
79       Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator());
80   auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck);
81 
82   // Discard the SCEV runtime check if it is always true.
83   if (CI && CI->isZero())
84     SCEVRuntimeCheck = nullptr;
85 
86   if (MemRuntimeCheck && SCEVRuntimeCheck) {
87     RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
88                                           SCEVRuntimeCheck, "lver.safe");
89     if (auto *I = dyn_cast<Instruction>(RuntimeCheck))
90       I->insertBefore(RuntimeCheckBB->getTerminator());
91   } else
92     RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
93 
94   assert(RuntimeCheck && "called even though we don't need "
95                          "any runtime checks");
96 
97   // Rename the block to make the IR more readable.
98   RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
99                           ".lver.check");
100 
101   // Create empty preheader for the loop (and after cloning for the
102   // non-versioned loop).
103   BasicBlock *PH =
104       SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI,
105                  nullptr, VersionedLoop->getHeader()->getName() + ".ph");
106 
107   // Clone the loop including the preheader.
108   //
109   // FIXME: This does not currently preserve SimplifyLoop because the exit
110   // block is a join between the two loops.
111   SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
112   NonVersionedLoop =
113       cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
114                              ".lver.orig", LI, DT, NonVersionedLoopBlocks);
115   remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
116 
117   // Insert the conditional branch based on the result of the memchecks.
118   Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
119   BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
120                      VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
121   OrigTerm->eraseFromParent();
122 
123   // The loops merge in the original exit block.  This is now dominated by the
124   // memchecking block.
125   DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
126 
127   // Adds the necessary PHI nodes for the versioned loops based on the
128   // loop-defined values used outside of the loop.
129   addPHINodes(DefsUsedOutside);
130 }
131 
132 void LoopVersioning::addPHINodes(
133     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
134   BasicBlock *PHIBlock = VersionedLoop->getExitBlock();
135   assert(PHIBlock && "No single successor to loop exit block");
136   PHINode *PN;
137 
138   // First add a single-operand PHI for each DefsUsedOutside if one does not
139   // exists yet.
140   for (auto *Inst : DefsUsedOutside) {
141     // See if we have a single-operand PHI with the value defined by the
142     // original loop.
143     for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
144       if (PN->getIncomingValue(0) == Inst)
145         break;
146     }
147     // If not create it.
148     if (!PN) {
149       PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
150                            &PHIBlock->front());
151       SmallVector<User*, 8> UsersToUpdate;
152       for (User *U : Inst->users())
153         if (!VersionedLoop->contains(cast<Instruction>(U)->getParent()))
154           UsersToUpdate.push_back(U);
155       for (User *U : UsersToUpdate)
156         U->replaceUsesOfWith(Inst, PN);
157       PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
158     }
159   }
160 
161   // Then for each PHI add the operand for the edge from the cloned loop.
162   for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
163     assert(PN->getNumOperands() == 1 &&
164            "Exit block should only have on predecessor");
165 
166     // If the definition was cloned used that otherwise use the same value.
167     Value *ClonedValue = PN->getIncomingValue(0);
168     auto Mapped = VMap.find(ClonedValue);
169     if (Mapped != VMap.end())
170       ClonedValue = Mapped->second;
171 
172     PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
173   }
174 }
175 
176 void LoopVersioning::prepareNoAliasMetadata() {
177   // We need to turn the no-alias relation between pointer checking groups into
178   // no-aliasing annotations between instructions.
179   //
180   // We accomplish this by mapping each pointer checking group (a set of
181   // pointers memchecked together) to an alias scope and then also mapping each
182   // group to the list of scopes it can't alias.
183 
184   const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
185   LLVMContext &Context = VersionedLoop->getHeader()->getContext();
186 
187   // First allocate an aliasing scope for each pointer checking group.
188   //
189   // While traversing through the checking groups in the loop, also create a
190   // reverse map from pointers to the pointer checking group they were assigned
191   // to.
192   MDBuilder MDB(Context);
193   MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
194 
195   for (const auto &Group : RtPtrChecking->CheckingGroups) {
196     GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
197 
198     for (unsigned PtrIdx : Group.Members)
199       PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
200   }
201 
202   // Go through the checks and for each pointer group, collect the scopes for
203   // each non-aliasing pointer group.
204   DenseMap<const RuntimeCheckingPtrGroup *, SmallVector<Metadata *, 4>>
205       GroupToNonAliasingScopes;
206 
207   for (const auto &Check : AliasChecks)
208     GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
209 
210   // Finally, transform the above to actually map to scope list which is what
211   // the metadata uses.
212 
213   for (auto Pair : GroupToNonAliasingScopes)
214     GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
215 }
216 
217 void LoopVersioning::annotateLoopWithNoAlias() {
218   if (!AnnotateNoAlias)
219     return;
220 
221   // First prepare the maps.
222   prepareNoAliasMetadata();
223 
224   // Add the scope and no-alias metadata to the instructions.
225   for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
226     annotateInstWithNoAlias(I);
227   }
228 }
229 
230 void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
231                                              const Instruction *OrigInst) {
232   if (!AnnotateNoAlias)
233     return;
234 
235   LLVMContext &Context = VersionedLoop->getHeader()->getContext();
236   const Value *Ptr = isa<LoadInst>(OrigInst)
237                          ? cast<LoadInst>(OrigInst)->getPointerOperand()
238                          : cast<StoreInst>(OrigInst)->getPointerOperand();
239 
240   // Find the group for the pointer and then add the scope metadata.
241   auto Group = PtrToGroup.find(Ptr);
242   if (Group != PtrToGroup.end()) {
243     VersionedInst->setMetadata(
244         LLVMContext::MD_alias_scope,
245         MDNode::concatenate(
246             VersionedInst->getMetadata(LLVMContext::MD_alias_scope),
247             MDNode::get(Context, GroupToScope[Group->second])));
248 
249     // Add the no-alias metadata.
250     auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second);
251     if (NonAliasingScopeList != GroupToNonAliasingScopeList.end())
252       VersionedInst->setMetadata(
253           LLVMContext::MD_noalias,
254           MDNode::concatenate(
255               VersionedInst->getMetadata(LLVMContext::MD_noalias),
256               NonAliasingScopeList->second));
257   }
258 }
259 
260 namespace {
261 bool runImpl(LoopInfo *LI, function_ref<const LoopAccessInfo &(Loop &)> GetLAA,
262              DominatorTree *DT, ScalarEvolution *SE) {
263   // Build up a worklist of inner-loops to version. This is necessary as the
264   // act of versioning a loop creates new loops and can invalidate iterators
265   // across the loops.
266   SmallVector<Loop *, 8> Worklist;
267 
268   for (Loop *TopLevelLoop : *LI)
269     for (Loop *L : depth_first(TopLevelLoop))
270       // We only handle inner-most loops.
271       if (L->empty())
272         Worklist.push_back(L);
273 
274   // Now walk the identified inner loops.
275   bool Changed = false;
276   for (Loop *L : Worklist) {
277     const LoopAccessInfo &LAI = GetLAA(*L);
278     if (L->isLoopSimplifyForm() && !LAI.hasConvergentOp() &&
279         (LAI.getNumRuntimePointerChecks() ||
280          !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) {
281       LoopVersioning LVer(LAI, L, LI, DT, SE);
282       LVer.versionLoop();
283       LVer.annotateLoopWithNoAlias();
284       Changed = true;
285     }
286   }
287 
288   return Changed;
289 }
290 
291 /// Also expose this is a pass.  Currently this is only used for
292 /// unit-testing.  It adds all memchecks necessary to remove all may-aliasing
293 /// array accesses from the loop.
294 class LoopVersioningLegacyPass : public FunctionPass {
295 public:
296   LoopVersioningLegacyPass() : FunctionPass(ID) {
297     initializeLoopVersioningLegacyPassPass(*PassRegistry::getPassRegistry());
298   }
299 
300   bool runOnFunction(Function &F) override {
301     auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
302     auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & {
303       return getAnalysis<LoopAccessLegacyAnalysis>().getInfo(&L);
304     };
305 
306     auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
307     auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
308 
309     return runImpl(LI, GetLAA, DT, SE);
310   }
311 
312   void getAnalysisUsage(AnalysisUsage &AU) const override {
313     AU.addRequired<LoopInfoWrapperPass>();
314     AU.addPreserved<LoopInfoWrapperPass>();
315     AU.addRequired<LoopAccessLegacyAnalysis>();
316     AU.addRequired<DominatorTreeWrapperPass>();
317     AU.addPreserved<DominatorTreeWrapperPass>();
318     AU.addRequired<ScalarEvolutionWrapperPass>();
319   }
320 
321   static char ID;
322 };
323 }
324 
325 #define LVER_OPTION "loop-versioning"
326 #define DEBUG_TYPE LVER_OPTION
327 
328 char LoopVersioningLegacyPass::ID;
329 static const char LVer_name[] = "Loop Versioning";
330 
331 INITIALIZE_PASS_BEGIN(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
332                       false)
333 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
334 INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
335 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
336 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
337 INITIALIZE_PASS_END(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
338                     false)
339 
340 namespace llvm {
341 FunctionPass *createLoopVersioningLegacyPass() {
342   return new LoopVersioningLegacyPass();
343 }
344 
345 PreservedAnalyses LoopVersioningPass::run(Function &F,
346                                           FunctionAnalysisManager &AM) {
347   auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
348   auto &LI = AM.getResult<LoopAnalysis>(F);
349   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
350   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
351   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
352   auto &AA = AM.getResult<AAManager>(F);
353   auto &AC = AM.getResult<AssumptionAnalysis>(F);
354   MemorySSA *MSSA = EnableMSSALoopDependency
355                         ? &AM.getResult<MemorySSAAnalysis>(F).getMSSA()
356                         : nullptr;
357 
358   auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager();
359   auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & {
360     LoopStandardAnalysisResults AR = {AA,  AC,  DT,      LI,  SE,
361                                       TLI, TTI, nullptr, MSSA};
362     return LAM.getResult<LoopAccessAnalysis>(L, AR);
363   };
364 
365   if (runImpl(&LI, GetLAA, &DT, &SE))
366     return PreservedAnalyses::none();
367   return PreservedAnalyses::all();
368 }
369 } // namespace llvm
370