//===- LoopVersioning.cpp - Utility to version a loop ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a utility class to perform loop versioning. The versioned
// loop speculates that otherwise may-aliasing memory accesses don't overlap and
// emits checks to prove this.
//
//===----------------------------------------------------------------------===//
1409467b48Spatrick
1509467b48Spatrick #include "llvm/Transforms/Utils/LoopVersioning.h"
16097a140dSpatrick #include "llvm/ADT/ArrayRef.h"
17*d415bd75Srobert #include "llvm/Analysis/AliasAnalysis.h"
18*d415bd75Srobert #include "llvm/Analysis/InstSimplifyFolder.h"
1909467b48Spatrick #include "llvm/Analysis/LoopAccessAnalysis.h"
2009467b48Spatrick #include "llvm/Analysis/LoopInfo.h"
2173471bf0Spatrick #include "llvm/Analysis/ScalarEvolution.h"
2273471bf0Spatrick #include "llvm/Analysis/TargetLibraryInfo.h"
2309467b48Spatrick #include "llvm/IR/Dominators.h"
2409467b48Spatrick #include "llvm/IR/MDBuilder.h"
2573471bf0Spatrick #include "llvm/IR/PassManager.h"
2609467b48Spatrick #include "llvm/InitializePasses.h"
2709467b48Spatrick #include "llvm/Support/CommandLine.h"
2809467b48Spatrick #include "llvm/Transforms/Utils/BasicBlockUtils.h"
2909467b48Spatrick #include "llvm/Transforms/Utils/Cloning.h"
30097a140dSpatrick #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
3109467b48Spatrick
3209467b48Spatrick using namespace llvm;
3309467b48Spatrick
3409467b48Spatrick static cl::opt<bool>
3509467b48Spatrick AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
3609467b48Spatrick cl::Hidden,
3709467b48Spatrick cl::desc("Add no-alias annotation for instructions that "
3809467b48Spatrick "are disambiguated by memchecks"));
3909467b48Spatrick
LoopVersioning(const LoopAccessInfo & LAI,ArrayRef<RuntimePointerCheck> Checks,Loop * L,LoopInfo * LI,DominatorTree * DT,ScalarEvolution * SE)4073471bf0Spatrick LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI,
4173471bf0Spatrick ArrayRef<RuntimePointerCheck> Checks, Loop *L,
4273471bf0Spatrick LoopInfo *LI, DominatorTree *DT,
4373471bf0Spatrick ScalarEvolution *SE)
44*d415bd75Srobert : VersionedLoop(L), AliasChecks(Checks.begin(), Checks.end()),
45*d415bd75Srobert Preds(LAI.getPSE().getPredicate()), LAI(LAI), LI(LI), DT(DT),
4609467b48Spatrick SE(SE) {
4709467b48Spatrick }
4809467b48Spatrick
versionLoop(const SmallVectorImpl<Instruction * > & DefsUsedOutside)4909467b48Spatrick void LoopVersioning::versionLoop(
5009467b48Spatrick const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
5173471bf0Spatrick assert(VersionedLoop->getUniqueExitBlock() && "No single exit block");
5273471bf0Spatrick assert(VersionedLoop->isLoopSimplifyForm() &&
5373471bf0Spatrick "Loop is not in loop-simplify form");
5473471bf0Spatrick
55*d415bd75Srobert Value *MemRuntimeCheck;
5609467b48Spatrick Value *SCEVRuntimeCheck;
5709467b48Spatrick Value *RuntimeCheck = nullptr;
5809467b48Spatrick
5909467b48Spatrick // Add the memcheck in the original preheader (this is empty initially).
6009467b48Spatrick BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
61097a140dSpatrick const auto &RtPtrChecking = *LAI.getRuntimePointerChecking();
6209467b48Spatrick
6373471bf0Spatrick SCEVExpander Exp2(*RtPtrChecking.getSE(),
6473471bf0Spatrick VersionedLoop->getHeader()->getModule()->getDataLayout(),
6573471bf0Spatrick "induction");
66*d415bd75Srobert MemRuntimeCheck = addRuntimeChecks(RuntimeCheckBB->getTerminator(),
67*d415bd75Srobert VersionedLoop, AliasChecks, Exp2);
6873471bf0Spatrick
6909467b48Spatrick SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
7009467b48Spatrick "scev.check");
7109467b48Spatrick SCEVRuntimeCheck =
7273471bf0Spatrick Exp.expandCodeForPredicate(&Preds, RuntimeCheckBB->getTerminator());
7309467b48Spatrick
74*d415bd75Srobert IRBuilder<InstSimplifyFolder> Builder(
75*d415bd75Srobert RuntimeCheckBB->getContext(),
76*d415bd75Srobert InstSimplifyFolder(RuntimeCheckBB->getModule()->getDataLayout()));
7709467b48Spatrick if (MemRuntimeCheck && SCEVRuntimeCheck) {
78*d415bd75Srobert Builder.SetInsertPoint(RuntimeCheckBB->getTerminator());
79*d415bd75Srobert RuntimeCheck =
80*d415bd75Srobert Builder.CreateOr(MemRuntimeCheck, SCEVRuntimeCheck, "lver.safe");
8109467b48Spatrick } else
8209467b48Spatrick RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
8309467b48Spatrick
8409467b48Spatrick assert(RuntimeCheck && "called even though we don't need "
8509467b48Spatrick "any runtime checks");
8609467b48Spatrick
8709467b48Spatrick // Rename the block to make the IR more readable.
8809467b48Spatrick RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
8909467b48Spatrick ".lver.check");
9009467b48Spatrick
9109467b48Spatrick // Create empty preheader for the loop (and after cloning for the
9209467b48Spatrick // non-versioned loop).
9309467b48Spatrick BasicBlock *PH =
9409467b48Spatrick SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI,
9509467b48Spatrick nullptr, VersionedLoop->getHeader()->getName() + ".ph");
9609467b48Spatrick
9709467b48Spatrick // Clone the loop including the preheader.
9809467b48Spatrick //
9909467b48Spatrick // FIXME: This does not currently preserve SimplifyLoop because the exit
10009467b48Spatrick // block is a join between the two loops.
10109467b48Spatrick SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
10209467b48Spatrick NonVersionedLoop =
10309467b48Spatrick cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
10409467b48Spatrick ".lver.orig", LI, DT, NonVersionedLoopBlocks);
10509467b48Spatrick remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
10609467b48Spatrick
10709467b48Spatrick // Insert the conditional branch based on the result of the memchecks.
10809467b48Spatrick Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
109*d415bd75Srobert Builder.SetInsertPoint(OrigTerm);
110*d415bd75Srobert Builder.CreateCondBr(RuntimeCheck, NonVersionedLoop->getLoopPreheader(),
111*d415bd75Srobert VersionedLoop->getLoopPreheader());
11209467b48Spatrick OrigTerm->eraseFromParent();
11309467b48Spatrick
11409467b48Spatrick // The loops merge in the original exit block. This is now dominated by the
11509467b48Spatrick // memchecking block.
11609467b48Spatrick DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
11709467b48Spatrick
11809467b48Spatrick // Adds the necessary PHI nodes for the versioned loops based on the
11909467b48Spatrick // loop-defined values used outside of the loop.
12009467b48Spatrick addPHINodes(DefsUsedOutside);
12173471bf0Spatrick formDedicatedExitBlocks(NonVersionedLoop, DT, LI, nullptr, true);
12273471bf0Spatrick formDedicatedExitBlocks(VersionedLoop, DT, LI, nullptr, true);
12373471bf0Spatrick assert(NonVersionedLoop->isLoopSimplifyForm() &&
12473471bf0Spatrick VersionedLoop->isLoopSimplifyForm() &&
12573471bf0Spatrick "The versioned loops should be in simplify form.");
12609467b48Spatrick }
12709467b48Spatrick
addPHINodes(const SmallVectorImpl<Instruction * > & DefsUsedOutside)12809467b48Spatrick void LoopVersioning::addPHINodes(
12909467b48Spatrick const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
13009467b48Spatrick BasicBlock *PHIBlock = VersionedLoop->getExitBlock();
13109467b48Spatrick assert(PHIBlock && "No single successor to loop exit block");
13209467b48Spatrick PHINode *PN;
13309467b48Spatrick
13409467b48Spatrick // First add a single-operand PHI for each DefsUsedOutside if one does not
13509467b48Spatrick // exists yet.
13609467b48Spatrick for (auto *Inst : DefsUsedOutside) {
13709467b48Spatrick // See if we have a single-operand PHI with the value defined by the
13809467b48Spatrick // original loop.
13909467b48Spatrick for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
140*d415bd75Srobert if (PN->getIncomingValue(0) == Inst) {
141*d415bd75Srobert SE->forgetValue(PN);
14209467b48Spatrick break;
14309467b48Spatrick }
144*d415bd75Srobert }
14509467b48Spatrick // If not create it.
14609467b48Spatrick if (!PN) {
14709467b48Spatrick PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
14809467b48Spatrick &PHIBlock->front());
14909467b48Spatrick SmallVector<User*, 8> UsersToUpdate;
15009467b48Spatrick for (User *U : Inst->users())
15109467b48Spatrick if (!VersionedLoop->contains(cast<Instruction>(U)->getParent()))
15209467b48Spatrick UsersToUpdate.push_back(U);
15309467b48Spatrick for (User *U : UsersToUpdate)
15409467b48Spatrick U->replaceUsesOfWith(Inst, PN);
15509467b48Spatrick PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
15609467b48Spatrick }
15709467b48Spatrick }
15809467b48Spatrick
15909467b48Spatrick // Then for each PHI add the operand for the edge from the cloned loop.
16009467b48Spatrick for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
16109467b48Spatrick assert(PN->getNumOperands() == 1 &&
16209467b48Spatrick "Exit block should only have on predecessor");
16309467b48Spatrick
16409467b48Spatrick // If the definition was cloned used that otherwise use the same value.
16509467b48Spatrick Value *ClonedValue = PN->getIncomingValue(0);
16609467b48Spatrick auto Mapped = VMap.find(ClonedValue);
16709467b48Spatrick if (Mapped != VMap.end())
16809467b48Spatrick ClonedValue = Mapped->second;
16909467b48Spatrick
17009467b48Spatrick PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
17109467b48Spatrick }
17209467b48Spatrick }
17309467b48Spatrick
prepareNoAliasMetadata()17409467b48Spatrick void LoopVersioning::prepareNoAliasMetadata() {
17509467b48Spatrick // We need to turn the no-alias relation between pointer checking groups into
17609467b48Spatrick // no-aliasing annotations between instructions.
17709467b48Spatrick //
17809467b48Spatrick // We accomplish this by mapping each pointer checking group (a set of
17909467b48Spatrick // pointers memchecked together) to an alias scope and then also mapping each
18009467b48Spatrick // group to the list of scopes it can't alias.
18109467b48Spatrick
18209467b48Spatrick const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
18309467b48Spatrick LLVMContext &Context = VersionedLoop->getHeader()->getContext();
18409467b48Spatrick
18509467b48Spatrick // First allocate an aliasing scope for each pointer checking group.
18609467b48Spatrick //
18709467b48Spatrick // While traversing through the checking groups in the loop, also create a
18809467b48Spatrick // reverse map from pointers to the pointer checking group they were assigned
18909467b48Spatrick // to.
19009467b48Spatrick MDBuilder MDB(Context);
19109467b48Spatrick MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
19209467b48Spatrick
19309467b48Spatrick for (const auto &Group : RtPtrChecking->CheckingGroups) {
19409467b48Spatrick GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
19509467b48Spatrick
19609467b48Spatrick for (unsigned PtrIdx : Group.Members)
19709467b48Spatrick PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
19809467b48Spatrick }
19909467b48Spatrick
20009467b48Spatrick // Go through the checks and for each pointer group, collect the scopes for
20109467b48Spatrick // each non-aliasing pointer group.
202097a140dSpatrick DenseMap<const RuntimeCheckingPtrGroup *, SmallVector<Metadata *, 4>>
20309467b48Spatrick GroupToNonAliasingScopes;
20409467b48Spatrick
20509467b48Spatrick for (const auto &Check : AliasChecks)
20609467b48Spatrick GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
20709467b48Spatrick
20809467b48Spatrick // Finally, transform the above to actually map to scope list which is what
20909467b48Spatrick // the metadata uses.
21009467b48Spatrick
21109467b48Spatrick for (auto Pair : GroupToNonAliasingScopes)
21209467b48Spatrick GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
21309467b48Spatrick }
21409467b48Spatrick
annotateLoopWithNoAlias()21509467b48Spatrick void LoopVersioning::annotateLoopWithNoAlias() {
21609467b48Spatrick if (!AnnotateNoAlias)
21709467b48Spatrick return;
21809467b48Spatrick
21909467b48Spatrick // First prepare the maps.
22009467b48Spatrick prepareNoAliasMetadata();
22109467b48Spatrick
22209467b48Spatrick // Add the scope and no-alias metadata to the instructions.
22309467b48Spatrick for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
22409467b48Spatrick annotateInstWithNoAlias(I);
22509467b48Spatrick }
22609467b48Spatrick }
22709467b48Spatrick
annotateInstWithNoAlias(Instruction * VersionedInst,const Instruction * OrigInst)22809467b48Spatrick void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
22909467b48Spatrick const Instruction *OrigInst) {
23009467b48Spatrick if (!AnnotateNoAlias)
23109467b48Spatrick return;
23209467b48Spatrick
23309467b48Spatrick LLVMContext &Context = VersionedLoop->getHeader()->getContext();
23409467b48Spatrick const Value *Ptr = isa<LoadInst>(OrigInst)
23509467b48Spatrick ? cast<LoadInst>(OrigInst)->getPointerOperand()
23609467b48Spatrick : cast<StoreInst>(OrigInst)->getPointerOperand();
23709467b48Spatrick
23809467b48Spatrick // Find the group for the pointer and then add the scope metadata.
23909467b48Spatrick auto Group = PtrToGroup.find(Ptr);
24009467b48Spatrick if (Group != PtrToGroup.end()) {
24109467b48Spatrick VersionedInst->setMetadata(
24209467b48Spatrick LLVMContext::MD_alias_scope,
24309467b48Spatrick MDNode::concatenate(
24409467b48Spatrick VersionedInst->getMetadata(LLVMContext::MD_alias_scope),
24509467b48Spatrick MDNode::get(Context, GroupToScope[Group->second])));
24609467b48Spatrick
24709467b48Spatrick // Add the no-alias metadata.
24809467b48Spatrick auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second);
24909467b48Spatrick if (NonAliasingScopeList != GroupToNonAliasingScopeList.end())
25009467b48Spatrick VersionedInst->setMetadata(
25109467b48Spatrick LLVMContext::MD_noalias,
25209467b48Spatrick MDNode::concatenate(
25309467b48Spatrick VersionedInst->getMetadata(LLVMContext::MD_noalias),
25409467b48Spatrick NonAliasingScopeList->second));
25509467b48Spatrick }
25609467b48Spatrick }
25709467b48Spatrick
25809467b48Spatrick namespace {
runImpl(LoopInfo * LI,LoopAccessInfoManager & LAIs,DominatorTree * DT,ScalarEvolution * SE)259*d415bd75Srobert bool runImpl(LoopInfo *LI, LoopAccessInfoManager &LAIs, DominatorTree *DT,
260*d415bd75Srobert ScalarEvolution *SE) {
26109467b48Spatrick // Build up a worklist of inner-loops to version. This is necessary as the
26209467b48Spatrick // act of versioning a loop creates new loops and can invalidate iterators
26309467b48Spatrick // across the loops.
26409467b48Spatrick SmallVector<Loop *, 8> Worklist;
26509467b48Spatrick
26609467b48Spatrick for (Loop *TopLevelLoop : *LI)
26709467b48Spatrick for (Loop *L : depth_first(TopLevelLoop))
26809467b48Spatrick // We only handle inner-most loops.
26973471bf0Spatrick if (L->isInnermost())
27009467b48Spatrick Worklist.push_back(L);
27109467b48Spatrick
27209467b48Spatrick // Now walk the identified inner loops.
27309467b48Spatrick bool Changed = false;
27409467b48Spatrick for (Loop *L : Worklist) {
27573471bf0Spatrick if (!L->isLoopSimplifyForm() || !L->isRotatedForm() ||
27673471bf0Spatrick !L->getExitingBlock())
27773471bf0Spatrick continue;
278*d415bd75Srobert const LoopAccessInfo &LAI = LAIs.getInfo(*L);
27973471bf0Spatrick if (!LAI.hasConvergentOp() &&
28009467b48Spatrick (LAI.getNumRuntimePointerChecks() ||
281*d415bd75Srobert !LAI.getPSE().getPredicate().isAlwaysTrue())) {
28273471bf0Spatrick LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L,
28373471bf0Spatrick LI, DT, SE);
28409467b48Spatrick LVer.versionLoop();
28509467b48Spatrick LVer.annotateLoopWithNoAlias();
28609467b48Spatrick Changed = true;
287*d415bd75Srobert LAIs.clear();
28809467b48Spatrick }
28909467b48Spatrick }
29009467b48Spatrick
29109467b48Spatrick return Changed;
29209467b48Spatrick }
29309467b48Spatrick
29473471bf0Spatrick /// Also expose this is a pass. Currently this is only used for
29573471bf0Spatrick /// unit-testing. It adds all memchecks necessary to remove all may-aliasing
29673471bf0Spatrick /// array accesses from the loop.
29773471bf0Spatrick class LoopVersioningLegacyPass : public FunctionPass {
29873471bf0Spatrick public:
LoopVersioningLegacyPass()29973471bf0Spatrick LoopVersioningLegacyPass() : FunctionPass(ID) {
30073471bf0Spatrick initializeLoopVersioningLegacyPassPass(*PassRegistry::getPassRegistry());
30173471bf0Spatrick }
30273471bf0Spatrick
runOnFunction(Function & F)30373471bf0Spatrick bool runOnFunction(Function &F) override {
30473471bf0Spatrick auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
305*d415bd75Srobert auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs();
30673471bf0Spatrick auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
30773471bf0Spatrick auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
30873471bf0Spatrick
309*d415bd75Srobert return runImpl(LI, LAIs, DT, SE);
31073471bf0Spatrick }
31173471bf0Spatrick
getAnalysisUsage(AnalysisUsage & AU) const31209467b48Spatrick void getAnalysisUsage(AnalysisUsage &AU) const override {
31309467b48Spatrick AU.addRequired<LoopInfoWrapperPass>();
31409467b48Spatrick AU.addPreserved<LoopInfoWrapperPass>();
31509467b48Spatrick AU.addRequired<LoopAccessLegacyAnalysis>();
31609467b48Spatrick AU.addRequired<DominatorTreeWrapperPass>();
31709467b48Spatrick AU.addPreserved<DominatorTreeWrapperPass>();
31809467b48Spatrick AU.addRequired<ScalarEvolutionWrapperPass>();
31909467b48Spatrick }
32009467b48Spatrick
32109467b48Spatrick static char ID;
32209467b48Spatrick };
32309467b48Spatrick }
32409467b48Spatrick
32509467b48Spatrick #define LVER_OPTION "loop-versioning"
32609467b48Spatrick #define DEBUG_TYPE LVER_OPTION
32709467b48Spatrick
32873471bf0Spatrick char LoopVersioningLegacyPass::ID;
32909467b48Spatrick static const char LVer_name[] = "Loop Versioning";
33009467b48Spatrick
33173471bf0Spatrick INITIALIZE_PASS_BEGIN(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
33273471bf0Spatrick false)
33309467b48Spatrick INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
33409467b48Spatrick INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
33509467b48Spatrick INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
33609467b48Spatrick INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
33773471bf0Spatrick INITIALIZE_PASS_END(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
33873471bf0Spatrick false)
33909467b48Spatrick
34009467b48Spatrick namespace llvm {
createLoopVersioningLegacyPass()34173471bf0Spatrick FunctionPass *createLoopVersioningLegacyPass() {
34273471bf0Spatrick return new LoopVersioningLegacyPass();
34309467b48Spatrick }
34473471bf0Spatrick
run(Function & F,FunctionAnalysisManager & AM)34573471bf0Spatrick PreservedAnalyses LoopVersioningPass::run(Function &F,
34673471bf0Spatrick FunctionAnalysisManager &AM) {
34773471bf0Spatrick auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
34873471bf0Spatrick auto &LI = AM.getResult<LoopAnalysis>(F);
349*d415bd75Srobert LoopAccessInfoManager &LAIs = AM.getResult<LoopAccessAnalysis>(F);
35073471bf0Spatrick auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
35173471bf0Spatrick
352*d415bd75Srobert if (runImpl(&LI, LAIs, &DT, &SE))
35373471bf0Spatrick return PreservedAnalyses::none();
35473471bf0Spatrick return PreservedAnalyses::all();
35509467b48Spatrick }
35673471bf0Spatrick } // namespace llvm
357