xref: /llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp (revision 89c0124273339076b25bf860f6c2ee765ab96db3)
1 //===- LoopVersioning.cpp - Utility to version a loop ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a utility class to perform loop versioning.  The versioned
10 // loop speculates that otherwise may-aliasing memory accesses don't overlap and
11 // emits checks to prove this.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/Utils/LoopVersioning.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/Analysis/AliasAnalysis.h"
18 #include "llvm/Analysis/LoopAccessAnalysis.h"
19 #include "llvm/Analysis/LoopInfo.h"
20 #include "llvm/Analysis/MemorySSA.h"
21 #include "llvm/Analysis/ScalarEvolution.h"
22 #include "llvm/Analysis/TargetLibraryInfo.h"
23 #include "llvm/IR/Dominators.h"
24 #include "llvm/IR/MDBuilder.h"
25 #include "llvm/IR/PassManager.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
29 #include "llvm/Transforms/Utils/Cloning.h"
30 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
31 
32 using namespace llvm;
33 
34 static cl::opt<bool>
35     AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
36                     cl::Hidden,
37                     cl::desc("Add no-alias annotation for instructions that "
38                              "are disambiguated by memchecks"));
39 
40 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI,
41                                ArrayRef<RuntimePointerCheck> Checks, Loop *L,
42                                LoopInfo *LI, DominatorTree *DT,
43                                ScalarEvolution *SE)
44     : VersionedLoop(L), NonVersionedLoop(nullptr),
45       AliasChecks(Checks.begin(), Checks.end()),
46       Preds(LAI.getPSE().getUnionPredicate()), LAI(LAI), LI(LI), DT(DT),
47       SE(SE) {
48   assert(L->getExitBlock() && "No single exit block");
49   assert(L->isLoopSimplifyForm() && "Loop is not in loop-simplify form");
50 }
51 
52 void LoopVersioning::versionLoop(
53     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
54   Instruction *FirstCheckInst;
55   Instruction *MemRuntimeCheck;
56   Value *SCEVRuntimeCheck;
57   Value *RuntimeCheck = nullptr;
58 
59   // Add the memcheck in the original preheader (this is empty initially).
60   BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
61   const auto &RtPtrChecking = *LAI.getRuntimePointerChecking();
62   std::tie(FirstCheckInst, MemRuntimeCheck) =
63       addRuntimeChecks(RuntimeCheckBB->getTerminator(), VersionedLoop,
64                        AliasChecks, RtPtrChecking.getSE());
65 
66   SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
67                    "scev.check");
68   SCEVRuntimeCheck =
69       Exp.expandCodeForPredicate(&Preds, RuntimeCheckBB->getTerminator());
70   auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck);
71 
72   // Discard the SCEV runtime check if it is always true.
73   if (CI && CI->isZero())
74     SCEVRuntimeCheck = nullptr;
75 
76   if (MemRuntimeCheck && SCEVRuntimeCheck) {
77     RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
78                                           SCEVRuntimeCheck, "lver.safe");
79     if (auto *I = dyn_cast<Instruction>(RuntimeCheck))
80       I->insertBefore(RuntimeCheckBB->getTerminator());
81   } else
82     RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
83 
84   assert(RuntimeCheck && "called even though we don't need "
85                          "any runtime checks");
86 
87   // Rename the block to make the IR more readable.
88   RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
89                           ".lver.check");
90 
91   // Create empty preheader for the loop (and after cloning for the
92   // non-versioned loop).
93   BasicBlock *PH =
94       SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI,
95                  nullptr, VersionedLoop->getHeader()->getName() + ".ph");
96 
97   // Clone the loop including the preheader.
98   //
99   // FIXME: This does not currently preserve SimplifyLoop because the exit
100   // block is a join between the two loops.
101   SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
102   NonVersionedLoop =
103       cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
104                              ".lver.orig", LI, DT, NonVersionedLoopBlocks);
105   remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
106 
107   // Insert the conditional branch based on the result of the memchecks.
108   Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
109   BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
110                      VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
111   OrigTerm->eraseFromParent();
112 
113   // The loops merge in the original exit block.  This is now dominated by the
114   // memchecking block.
115   DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
116 
117   // Adds the necessary PHI nodes for the versioned loops based on the
118   // loop-defined values used outside of the loop.
119   addPHINodes(DefsUsedOutside);
120 }
121 
122 void LoopVersioning::addPHINodes(
123     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
124   BasicBlock *PHIBlock = VersionedLoop->getExitBlock();
125   assert(PHIBlock && "No single successor to loop exit block");
126   PHINode *PN;
127 
128   // First add a single-operand PHI for each DefsUsedOutside if one does not
129   // exists yet.
130   for (auto *Inst : DefsUsedOutside) {
131     // See if we have a single-operand PHI with the value defined by the
132     // original loop.
133     for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
134       if (PN->getIncomingValue(0) == Inst)
135         break;
136     }
137     // If not create it.
138     if (!PN) {
139       PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
140                            &PHIBlock->front());
141       SmallVector<User*, 8> UsersToUpdate;
142       for (User *U : Inst->users())
143         if (!VersionedLoop->contains(cast<Instruction>(U)->getParent()))
144           UsersToUpdate.push_back(U);
145       for (User *U : UsersToUpdate)
146         U->replaceUsesOfWith(Inst, PN);
147       PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
148     }
149   }
150 
151   // Then for each PHI add the operand for the edge from the cloned loop.
152   for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
153     assert(PN->getNumOperands() == 1 &&
154            "Exit block should only have on predecessor");
155 
156     // If the definition was cloned used that otherwise use the same value.
157     Value *ClonedValue = PN->getIncomingValue(0);
158     auto Mapped = VMap.find(ClonedValue);
159     if (Mapped != VMap.end())
160       ClonedValue = Mapped->second;
161 
162     PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
163   }
164 }
165 
166 void LoopVersioning::prepareNoAliasMetadata() {
167   // We need to turn the no-alias relation between pointer checking groups into
168   // no-aliasing annotations between instructions.
169   //
170   // We accomplish this by mapping each pointer checking group (a set of
171   // pointers memchecked together) to an alias scope and then also mapping each
172   // group to the list of scopes it can't alias.
173 
174   const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
175   LLVMContext &Context = VersionedLoop->getHeader()->getContext();
176 
177   // First allocate an aliasing scope for each pointer checking group.
178   //
179   // While traversing through the checking groups in the loop, also create a
180   // reverse map from pointers to the pointer checking group they were assigned
181   // to.
182   MDBuilder MDB(Context);
183   MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
184 
185   for (const auto &Group : RtPtrChecking->CheckingGroups) {
186     GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
187 
188     for (unsigned PtrIdx : Group.Members)
189       PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
190   }
191 
192   // Go through the checks and for each pointer group, collect the scopes for
193   // each non-aliasing pointer group.
194   DenseMap<const RuntimeCheckingPtrGroup *, SmallVector<Metadata *, 4>>
195       GroupToNonAliasingScopes;
196 
197   for (const auto &Check : AliasChecks)
198     GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
199 
200   // Finally, transform the above to actually map to scope list which is what
201   // the metadata uses.
202 
203   for (auto Pair : GroupToNonAliasingScopes)
204     GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
205 }
206 
207 void LoopVersioning::annotateLoopWithNoAlias() {
208   if (!AnnotateNoAlias)
209     return;
210 
211   // First prepare the maps.
212   prepareNoAliasMetadata();
213 
214   // Add the scope and no-alias metadata to the instructions.
215   for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
216     annotateInstWithNoAlias(I);
217   }
218 }
219 
220 void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
221                                              const Instruction *OrigInst) {
222   if (!AnnotateNoAlias)
223     return;
224 
225   LLVMContext &Context = VersionedLoop->getHeader()->getContext();
226   const Value *Ptr = isa<LoadInst>(OrigInst)
227                          ? cast<LoadInst>(OrigInst)->getPointerOperand()
228                          : cast<StoreInst>(OrigInst)->getPointerOperand();
229 
230   // Find the group for the pointer and then add the scope metadata.
231   auto Group = PtrToGroup.find(Ptr);
232   if (Group != PtrToGroup.end()) {
233     VersionedInst->setMetadata(
234         LLVMContext::MD_alias_scope,
235         MDNode::concatenate(
236             VersionedInst->getMetadata(LLVMContext::MD_alias_scope),
237             MDNode::get(Context, GroupToScope[Group->second])));
238 
239     // Add the no-alias metadata.
240     auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second);
241     if (NonAliasingScopeList != GroupToNonAliasingScopeList.end())
242       VersionedInst->setMetadata(
243           LLVMContext::MD_noalias,
244           MDNode::concatenate(
245               VersionedInst->getMetadata(LLVMContext::MD_noalias),
246               NonAliasingScopeList->second));
247   }
248 }
249 
250 namespace {
251 bool runImpl(LoopInfo *LI, function_ref<const LoopAccessInfo &(Loop &)> GetLAA,
252              DominatorTree *DT, ScalarEvolution *SE) {
253   // Build up a worklist of inner-loops to version. This is necessary as the
254   // act of versioning a loop creates new loops and can invalidate iterators
255   // across the loops.
256   SmallVector<Loop *, 8> Worklist;
257 
258   for (Loop *TopLevelLoop : *LI)
259     for (Loop *L : depth_first(TopLevelLoop))
260       // We only handle inner-most loops.
261       if (L->isInnermost())
262         Worklist.push_back(L);
263 
264   // Now walk the identified inner loops.
265   bool Changed = false;
266   for (Loop *L : Worklist) {
267     const LoopAccessInfo &LAI = GetLAA(*L);
268     if (L->isLoopSimplifyForm() && !LAI.hasConvergentOp() &&
269         (LAI.getNumRuntimePointerChecks() ||
270          !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) {
271       LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L,
272                           LI, DT, SE);
273       LVer.versionLoop();
274       LVer.annotateLoopWithNoAlias();
275       Changed = true;
276     }
277   }
278 
279   return Changed;
280 }
281 
282 /// Also expose this is a pass.  Currently this is only used for
283 /// unit-testing.  It adds all memchecks necessary to remove all may-aliasing
284 /// array accesses from the loop.
285 class LoopVersioningLegacyPass : public FunctionPass {
286 public:
287   LoopVersioningLegacyPass() : FunctionPass(ID) {
288     initializeLoopVersioningLegacyPassPass(*PassRegistry::getPassRegistry());
289   }
290 
291   bool runOnFunction(Function &F) override {
292     auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
293     auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & {
294       return getAnalysis<LoopAccessLegacyAnalysis>().getInfo(&L);
295     };
296 
297     auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
298     auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
299 
300     return runImpl(LI, GetLAA, DT, SE);
301   }
302 
303   void getAnalysisUsage(AnalysisUsage &AU) const override {
304     AU.addRequired<LoopInfoWrapperPass>();
305     AU.addPreserved<LoopInfoWrapperPass>();
306     AU.addRequired<LoopAccessLegacyAnalysis>();
307     AU.addRequired<DominatorTreeWrapperPass>();
308     AU.addPreserved<DominatorTreeWrapperPass>();
309     AU.addRequired<ScalarEvolutionWrapperPass>();
310   }
311 
312   static char ID;
313 };
314 }
315 
316 #define LVER_OPTION "loop-versioning"
317 #define DEBUG_TYPE LVER_OPTION
318 
319 char LoopVersioningLegacyPass::ID;
320 static const char LVer_name[] = "Loop Versioning";
321 
322 INITIALIZE_PASS_BEGIN(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
323                       false)
324 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
325 INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
326 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
327 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
328 INITIALIZE_PASS_END(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
329                     false)
330 
331 namespace llvm {
332 FunctionPass *createLoopVersioningLegacyPass() {
333   return new LoopVersioningLegacyPass();
334 }
335 
336 PreservedAnalyses LoopVersioningPass::run(Function &F,
337                                           FunctionAnalysisManager &AM) {
338   auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
339   auto &LI = AM.getResult<LoopAnalysis>(F);
340   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
341   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
342   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
343   auto &AA = AM.getResult<AAManager>(F);
344   auto &AC = AM.getResult<AssumptionAnalysis>(F);
345   MemorySSA *MSSA = EnableMSSALoopDependency
346                         ? &AM.getResult<MemorySSAAnalysis>(F).getMSSA()
347                         : nullptr;
348 
349   auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager();
350   auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & {
351     LoopStandardAnalysisResults AR = {AA,  AC,  DT,      LI,  SE,
352                                       TLI, TTI, nullptr, MSSA};
353     return LAM.getResult<LoopAccessAnalysis>(L, AR);
354   };
355 
356   if (runImpl(&LI, GetLAA, &DT, &SE))
357     return PreservedAnalyses::none();
358   return PreservedAnalyses::all();
359 }
360 } // namespace llvm
361