xref: /llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp (revision c83c4b58d1c1125cbac0a54266ff570e9e1f6845)
1 //===- LoopVersioning.cpp - Utility to version a loop ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a utility class to perform loop versioning.  The versioned
10 // loop speculates that otherwise may-aliasing memory accesses don't overlap and
11 // emits checks to prove this.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/Utils/LoopVersioning.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/Analysis/AliasAnalysis.h"
18 #include "llvm/Analysis/InstSimplifyFolder.h"
19 #include "llvm/Analysis/LoopAccessAnalysis.h"
20 #include "llvm/Analysis/LoopInfo.h"
21 #include "llvm/Analysis/ScalarEvolution.h"
22 #include "llvm/Analysis/TargetLibraryInfo.h"
23 #include "llvm/IR/Dominators.h"
24 #include "llvm/IR/MDBuilder.h"
25 #include "llvm/IR/PassManager.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
29 #include "llvm/Transforms/Utils/Cloning.h"
30 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
31 
32 using namespace llvm;
33 
34 #define DEBUG_TYPE "loop-versioning"
35 
36 static cl::opt<bool>
37     AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
38                     cl::Hidden,
39                     cl::desc("Add no-alias annotation for instructions that "
40                              "are disambiguated by memchecks"));
41 
42 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI,
43                                ArrayRef<RuntimePointerCheck> Checks, Loop *L,
44                                LoopInfo *LI, DominatorTree *DT,
45                                ScalarEvolution *SE)
46     : VersionedLoop(L), AliasChecks(Checks.begin(), Checks.end()),
47       Preds(LAI.getPSE().getPredicate()), LAI(LAI), LI(LI), DT(DT),
48       SE(SE) {
49 }
50 
51 void LoopVersioning::versionLoop(
52     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
53   assert(VersionedLoop->getUniqueExitBlock() && "No single exit block");
54   assert(VersionedLoop->isLoopSimplifyForm() &&
55          "Loop is not in loop-simplify form");
56 
57   Value *MemRuntimeCheck;
58   Value *SCEVRuntimeCheck;
59   Value *RuntimeCheck = nullptr;
60 
61   // Add the memcheck in the original preheader (this is empty initially).
62   BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
63   const auto &RtPtrChecking = *LAI.getRuntimePointerChecking();
64 
65   SCEVExpander Exp2(*RtPtrChecking.getSE(),
66                     VersionedLoop->getHeader()->getModule()->getDataLayout(),
67                     "induction");
68   MemRuntimeCheck = addRuntimeChecks(RuntimeCheckBB->getTerminator(),
69                                      VersionedLoop, AliasChecks, Exp2);
70 
71   SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
72                    "scev.check");
73   SCEVRuntimeCheck =
74       Exp.expandCodeForPredicate(&Preds, RuntimeCheckBB->getTerminator());
75 
76   IRBuilder<InstSimplifyFolder> Builder(
77       RuntimeCheckBB->getContext(),
78       InstSimplifyFolder(RuntimeCheckBB->getModule()->getDataLayout()));
79   if (MemRuntimeCheck && SCEVRuntimeCheck) {
80     Builder.SetInsertPoint(RuntimeCheckBB->getTerminator());
81     RuntimeCheck =
82         Builder.CreateOr(MemRuntimeCheck, SCEVRuntimeCheck, "lver.safe");
83   } else
84     RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
85 
86   assert(RuntimeCheck && "called even though we don't need "
87                          "any runtime checks");
88 
89   // Rename the block to make the IR more readable.
90   RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
91                           ".lver.check");
92 
93   // Create empty preheader for the loop (and after cloning for the
94   // non-versioned loop).
95   BasicBlock *PH =
96       SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI,
97                  nullptr, VersionedLoop->getHeader()->getName() + ".ph");
98 
99   // Clone the loop including the preheader.
100   //
101   // FIXME: This does not currently preserve SimplifyLoop because the exit
102   // block is a join between the two loops.
103   SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
104   NonVersionedLoop =
105       cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
106                              ".lver.orig", LI, DT, NonVersionedLoopBlocks);
107   remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
108 
109   // Insert the conditional branch based on the result of the memchecks.
110   Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
111   Builder.SetInsertPoint(OrigTerm);
112   Builder.CreateCondBr(RuntimeCheck, NonVersionedLoop->getLoopPreheader(),
113                        VersionedLoop->getLoopPreheader());
114   OrigTerm->eraseFromParent();
115 
116   // The loops merge in the original exit block.  This is now dominated by the
117   // memchecking block.
118   DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
119 
120   // Adds the necessary PHI nodes for the versioned loops based on the
121   // loop-defined values used outside of the loop.
122   addPHINodes(DefsUsedOutside);
123   formDedicatedExitBlocks(NonVersionedLoop, DT, LI, nullptr, true);
124   formDedicatedExitBlocks(VersionedLoop, DT, LI, nullptr, true);
125   assert(NonVersionedLoop->isLoopSimplifyForm() &&
126          VersionedLoop->isLoopSimplifyForm() &&
127          "The versioned loops should be in simplify form.");
128 }
129 
130 void LoopVersioning::addPHINodes(
131     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
132   BasicBlock *PHIBlock = VersionedLoop->getExitBlock();
133   assert(PHIBlock && "No single successor to loop exit block");
134   PHINode *PN;
135 
136   // First add a single-operand PHI for each DefsUsedOutside if one does not
137   // exists yet.
138   for (auto *Inst : DefsUsedOutside) {
139     // See if we have a single-operand PHI with the value defined by the
140     // original loop.
141     for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
142       if (PN->getIncomingValue(0) == Inst) {
143         SE->forgetValue(PN);
144         break;
145       }
146     }
147     // If not create it.
148     if (!PN) {
149       PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
150                            &PHIBlock->front());
151       SmallVector<User*, 8> UsersToUpdate;
152       for (User *U : Inst->users())
153         if (!VersionedLoop->contains(cast<Instruction>(U)->getParent()))
154           UsersToUpdate.push_back(U);
155       for (User *U : UsersToUpdate)
156         U->replaceUsesOfWith(Inst, PN);
157       PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
158     }
159   }
160 
161   // Then for each PHI add the operand for the edge from the cloned loop.
162   for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
163     assert(PN->getNumOperands() == 1 &&
164            "Exit block should only have on predecessor");
165 
166     // If the definition was cloned used that otherwise use the same value.
167     Value *ClonedValue = PN->getIncomingValue(0);
168     auto Mapped = VMap.find(ClonedValue);
169     if (Mapped != VMap.end())
170       ClonedValue = Mapped->second;
171 
172     PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
173   }
174 }
175 
176 void LoopVersioning::prepareNoAliasMetadata() {
177   // We need to turn the no-alias relation between pointer checking groups into
178   // no-aliasing annotations between instructions.
179   //
180   // We accomplish this by mapping each pointer checking group (a set of
181   // pointers memchecked together) to an alias scope and then also mapping each
182   // group to the list of scopes it can't alias.
183 
184   const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
185   LLVMContext &Context = VersionedLoop->getHeader()->getContext();
186 
187   // First allocate an aliasing scope for each pointer checking group.
188   //
189   // While traversing through the checking groups in the loop, also create a
190   // reverse map from pointers to the pointer checking group they were assigned
191   // to.
192   MDBuilder MDB(Context);
193   MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
194 
195   for (const auto &Group : RtPtrChecking->CheckingGroups) {
196     GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
197 
198     for (unsigned PtrIdx : Group.Members)
199       PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
200   }
201 
202   // Go through the checks and for each pointer group, collect the scopes for
203   // each non-aliasing pointer group.
204   DenseMap<const RuntimeCheckingPtrGroup *, SmallVector<Metadata *, 4>>
205       GroupToNonAliasingScopes;
206 
207   for (const auto &Check : AliasChecks)
208     GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
209 
210   // Finally, transform the above to actually map to scope list which is what
211   // the metadata uses.
212 
213   for (const auto &Pair : GroupToNonAliasingScopes)
214     GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
215 }
216 
217 void LoopVersioning::annotateLoopWithNoAlias() {
218   if (!AnnotateNoAlias)
219     return;
220 
221   // First prepare the maps.
222   prepareNoAliasMetadata();
223 
224   // Add the scope and no-alias metadata to the instructions.
225   for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
226     annotateInstWithNoAlias(I);
227   }
228 }
229 
230 void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
231                                              const Instruction *OrigInst) {
232   if (!AnnotateNoAlias)
233     return;
234 
235   LLVMContext &Context = VersionedLoop->getHeader()->getContext();
236   const Value *Ptr = isa<LoadInst>(OrigInst)
237                          ? cast<LoadInst>(OrigInst)->getPointerOperand()
238                          : cast<StoreInst>(OrigInst)->getPointerOperand();
239 
240   // Find the group for the pointer and then add the scope metadata.
241   auto Group = PtrToGroup.find(Ptr);
242   if (Group != PtrToGroup.end()) {
243     VersionedInst->setMetadata(
244         LLVMContext::MD_alias_scope,
245         MDNode::concatenate(
246             VersionedInst->getMetadata(LLVMContext::MD_alias_scope),
247             MDNode::get(Context, GroupToScope[Group->second])));
248 
249     // Add the no-alias metadata.
250     auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second);
251     if (NonAliasingScopeList != GroupToNonAliasingScopeList.end())
252       VersionedInst->setMetadata(
253           LLVMContext::MD_noalias,
254           MDNode::concatenate(
255               VersionedInst->getMetadata(LLVMContext::MD_noalias),
256               NonAliasingScopeList->second));
257   }
258 }
259 
260 namespace {
261 bool runImpl(LoopInfo *LI, LoopAccessInfoManager &LAIs, DominatorTree *DT,
262              ScalarEvolution *SE) {
263   // Build up a worklist of inner-loops to version. This is necessary as the
264   // act of versioning a loop creates new loops and can invalidate iterators
265   // across the loops.
266   SmallVector<Loop *, 8> Worklist;
267 
268   for (Loop *TopLevelLoop : *LI)
269     for (Loop *L : depth_first(TopLevelLoop))
270       // We only handle inner-most loops.
271       if (L->isInnermost())
272         Worklist.push_back(L);
273 
274   // Now walk the identified inner loops.
275   bool Changed = false;
276   for (Loop *L : Worklist) {
277     if (!L->isLoopSimplifyForm() || !L->isRotatedForm() ||
278         !L->getExitingBlock())
279       continue;
280     const LoopAccessInfo &LAI = LAIs.getInfo(*L);
281     if (!LAI.hasConvergentOp() &&
282         (LAI.getNumRuntimePointerChecks() ||
283          !LAI.getPSE().getPredicate().isAlwaysTrue())) {
284       LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L,
285                           LI, DT, SE);
286       LVer.versionLoop();
287       LVer.annotateLoopWithNoAlias();
288       Changed = true;
289       LAIs.clear();
290     }
291   }
292 
293   return Changed;
294 }
295 }
296 
297 PreservedAnalyses LoopVersioningPass::run(Function &F,
298                                           FunctionAnalysisManager &AM) {
299   auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
300   auto &LI = AM.getResult<LoopAnalysis>(F);
301   LoopAccessInfoManager &LAIs = AM.getResult<LoopAccessAnalysis>(F);
302   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
303 
304   if (runImpl(&LI, LAIs, &DT, &SE))
305     return PreservedAnalyses::none();
306   return PreservedAnalyses::all();
307 }
308