1 //===- LoopVersioning.cpp - Utility to version a loop ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines a utility class to perform loop versioning. The versioned 10 // loop speculates that otherwise may-aliasing memory accesses don't overlap and 11 // emits checks to prove this. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/Transforms/Utils/LoopVersioning.h" 16 #include "llvm/ADT/ArrayRef.h" 17 #include "llvm/Analysis/AliasAnalysis.h" 18 #include "llvm/Analysis/LoopAccessAnalysis.h" 19 #include "llvm/Analysis/LoopInfo.h" 20 #include "llvm/Analysis/ScalarEvolution.h" 21 #include "llvm/Analysis/TargetLibraryInfo.h" 22 #include "llvm/IR/Dominators.h" 23 #include "llvm/IR/MDBuilder.h" 24 #include "llvm/IR/PassManager.h" 25 #include "llvm/InitializePasses.h" 26 #include "llvm/Support/CommandLine.h" 27 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 28 #include "llvm/Transforms/Utils/Cloning.h" 29 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 30 31 using namespace llvm; 32 33 #define LVER_OPTION "loop-versioning" 34 #define DEBUG_TYPE LVER_OPTION 35 36 static cl::opt<bool> 37 AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true), 38 cl::Hidden, 39 cl::desc("Add no-alias annotation for instructions that " 40 "are disambiguated by memchecks")); 41 42 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, 43 ArrayRef<RuntimePointerCheck> Checks, Loop *L, 44 LoopInfo *LI, DominatorTree *DT, 45 ScalarEvolution *SE) 46 : VersionedLoop(L), NonVersionedLoop(nullptr), 47 AliasChecks(Checks.begin(), Checks.end()), 48 Preds(LAI.getPSE().getUnionPredicate()), LAI(LAI), LI(LI), DT(DT), 49 SE(SE) { 50 } 51 52 void LoopVersioning::versionLoop( 53 const SmallVectorImpl<Instruction *> &DefsUsedOutside) { 54 assert(VersionedLoop->getUniqueExitBlock() && "No single exit block"); 55 assert(VersionedLoop->isLoopSimplifyForm() && 56 "Loop is not in loop-simplify form"); 57 58 Value *MemRuntimeCheck; 59 Value *SCEVRuntimeCheck; 60 Value *RuntimeCheck = nullptr; 61 62 // Add the memcheck in the original preheader (this is empty initially). 63 BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader(); 64 const auto &RtPtrChecking = *LAI.getRuntimePointerChecking(); 65 66 SCEVExpander Exp2(*RtPtrChecking.getSE(), 67 VersionedLoop->getHeader()->getModule()->getDataLayout(), 68 "induction"); 69 MemRuntimeCheck = addRuntimeChecks(RuntimeCheckBB->getTerminator(), 70 VersionedLoop, AliasChecks, Exp2); 71 72 SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(), 73 "scev.check"); 74 SCEVRuntimeCheck = 75 Exp.expandCodeForPredicate(&Preds, RuntimeCheckBB->getTerminator()); 76 auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck); 77 78 // Discard the SCEV runtime check if it is always true. 79 if (CI && CI->isZero()) 80 SCEVRuntimeCheck = nullptr; 81 82 if (MemRuntimeCheck && SCEVRuntimeCheck) { 83 RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck, 84 SCEVRuntimeCheck, "lver.safe"); 85 if (auto *I = dyn_cast<Instruction>(RuntimeCheck)) 86 I->insertBefore(RuntimeCheckBB->getTerminator()); 87 } else 88 RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck; 89 90 if (!RuntimeCheck) { 91 LLVM_DEBUG(dbgs() << DEBUG_TYPE " No MemRuntimeCheck or SCEVRuntimeCheck found" 92 << ", set RuntimeCheck to False\n"); 93 RuntimeCheck = ConstantInt::getFalse(RuntimeCheckBB->getContext()); 94 } 95 96 // Rename the block to make the IR more readable. 97 RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() + 98 ".lver.check"); 99 100 // Create empty preheader for the loop (and after cloning for the 101 // non-versioned loop). 102 BasicBlock *PH = 103 SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI, 104 nullptr, VersionedLoop->getHeader()->getName() + ".ph"); 105 106 // Clone the loop including the preheader. 107 // 108 // FIXME: This does not currently preserve SimplifyLoop because the exit 109 // block is a join between the two loops. 110 SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks; 111 NonVersionedLoop = 112 cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap, 113 ".lver.orig", LI, DT, NonVersionedLoopBlocks); 114 remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap); 115 116 // Insert the conditional branch based on the result of the memchecks. 117 Instruction *OrigTerm = RuntimeCheckBB->getTerminator(); 118 RuntimeCheckBI = BranchInst::Create(NonVersionedLoop->getLoopPreheader(), 119 VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm); 120 OrigTerm->eraseFromParent(); 121 122 // The loops merge in the original exit block. This is now dominated by the 123 // memchecking block. 124 DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB); 125 126 // Adds the necessary PHI nodes for the versioned loops based on the 127 // loop-defined values used outside of the loop. 128 addPHINodes(DefsUsedOutside); 129 formDedicatedExitBlocks(NonVersionedLoop, DT, LI, nullptr, true); 130 formDedicatedExitBlocks(VersionedLoop, DT, LI, nullptr, true); 131 assert(NonVersionedLoop->isLoopSimplifyForm() && 132 VersionedLoop->isLoopSimplifyForm() && 133 "The versioned loops should be in simplify form."); 134 135 // RuntimeCheckBB and RuntimeCheckBI is recorded 136 assert(RuntimeCheckBB && RuntimeCheckBI); 137 } 138 139 void LoopVersioning::addPHINodes( 140 const SmallVectorImpl<Instruction *> &DefsUsedOutside) { 141 BasicBlock *PHIBlock = VersionedLoop->getExitBlock(); 142 assert(PHIBlock && "No single successor to loop exit block"); 143 PHINode *PN; 144 145 // First add a single-operand PHI for each DefsUsedOutside if one does not 146 // exists yet. 147 for (auto *Inst : DefsUsedOutside) { 148 // See if we have a single-operand PHI with the value defined by the 149 // original loop. 150 for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) { 151 if (PN->getIncomingValue(0) == Inst) 152 break; 153 } 154 // If not create it. 155 if (!PN) { 156 PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver", 157 &PHIBlock->front()); 158 SmallVector<User*, 8> UsersToUpdate; 159 for (User *U : Inst->users()) 160 if (!VersionedLoop->contains(cast<Instruction>(U)->getParent())) 161 UsersToUpdate.push_back(U); 162 for (User *U : UsersToUpdate) 163 U->replaceUsesOfWith(Inst, PN); 164 PN->addIncoming(Inst, VersionedLoop->getExitingBlock()); 165 } 166 } 167 168 // Then for each PHI add the operand for the edge from the cloned loop. 169 for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) { 170 assert(PN->getNumOperands() == 1 && 171 "Exit block should only have on predecessor"); 172 173 // If the definition was cloned used that otherwise use the same value. 174 Value *ClonedValue = PN->getIncomingValue(0); 175 auto Mapped = VMap.find(ClonedValue); 176 if (Mapped != VMap.end()) 177 ClonedValue = Mapped->second; 178 179 PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock()); 180 } 181 } 182 183 void LoopVersioning::prepareNoAliasMetadata() { 184 // We need to turn the no-alias relation between pointer checking groups into 185 // no-aliasing annotations between instructions. 186 // 187 // We accomplish this by mapping each pointer checking group (a set of 188 // pointers memchecked together) to an alias scope and then also mapping each 189 // group to the list of scopes it can't alias. 190 191 const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking(); 192 LLVMContext &Context = VersionedLoop->getHeader()->getContext(); 193 194 // First allocate an aliasing scope for each pointer checking group. 195 // 196 // While traversing through the checking groups in the loop, also create a 197 // reverse map from pointers to the pointer checking group they were assigned 198 // to. 199 MDBuilder MDB(Context); 200 MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain"); 201 202 for (const auto &Group : RtPtrChecking->CheckingGroups) { 203 GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain); 204 205 for (unsigned PtrIdx : Group.Members) 206 PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group; 207 } 208 209 // Go through the checks and for each pointer group, collect the scopes for 210 // each non-aliasing pointer group. 211 DenseMap<const RuntimeCheckingPtrGroup *, SmallVector<Metadata *, 4>> 212 GroupToNonAliasingScopes; 213 214 for (const auto &Check : AliasChecks) 215 GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]); 216 217 // Finally, transform the above to actually map to scope list which is what 218 // the metadata uses. 219 220 for (auto Pair : GroupToNonAliasingScopes) 221 GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second); 222 } 223 224 void LoopVersioning::annotateLoopWithNoAlias() { 225 if (!AnnotateNoAlias) 226 return; 227 228 // First prepare the maps. 229 prepareNoAliasMetadata(); 230 231 // Add the scope and no-alias metadata to the instructions. 232 for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) { 233 annotateInstWithNoAlias(I); 234 } 235 } 236 237 void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst, 238 const Instruction *OrigInst) { 239 if (!AnnotateNoAlias) 240 return; 241 242 LLVMContext &Context = VersionedLoop->getHeader()->getContext(); 243 const Value *Ptr = isa<LoadInst>(OrigInst) 244 ? cast<LoadInst>(OrigInst)->getPointerOperand() 245 : cast<StoreInst>(OrigInst)->getPointerOperand(); 246 247 // Find the group for the pointer and then add the scope metadata. 248 auto Group = PtrToGroup.find(Ptr); 249 if (Group != PtrToGroup.end()) { 250 VersionedInst->setMetadata( 251 LLVMContext::MD_alias_scope, 252 MDNode::concatenate( 253 VersionedInst->getMetadata(LLVMContext::MD_alias_scope), 254 MDNode::get(Context, GroupToScope[Group->second]))); 255 256 // Add the no-alias metadata. 257 auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second); 258 if (NonAliasingScopeList != GroupToNonAliasingScopeList.end()) 259 VersionedInst->setMetadata( 260 LLVMContext::MD_noalias, 261 MDNode::concatenate( 262 VersionedInst->getMetadata(LLVMContext::MD_noalias), 263 NonAliasingScopeList->second)); 264 } 265 } 266 267 namespace { 268 bool runImpl(LoopInfo *LI, function_ref<const LoopAccessInfo &(Loop &)> GetLAA, 269 DominatorTree *DT, ScalarEvolution *SE) { 270 // Build up a worklist of inner-loops to version. This is necessary as the 271 // act of versioning a loop creates new loops and can invalidate iterators 272 // across the loops. 273 SmallVector<Loop *, 8> Worklist; 274 275 for (Loop *TopLevelLoop : *LI) 276 for (Loop *L : depth_first(TopLevelLoop)) 277 // We only handle inner-most loops. 278 if (L->isInnermost()) 279 Worklist.push_back(L); 280 281 // Now walk the identified inner loops. 282 bool Changed = false; 283 for (Loop *L : Worklist) { 284 if (!L->isLoopSimplifyForm() || !L->isRotatedForm() || 285 !L->getExitingBlock()) 286 continue; 287 const LoopAccessInfo &LAI = GetLAA(*L); 288 if (!LAI.hasConvergentOp() && 289 (LAI.getNumRuntimePointerChecks() || 290 !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) { 291 LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L, 292 LI, DT, SE); 293 LVer.versionLoop(); 294 LVer.annotateLoopWithNoAlias(); 295 Changed = true; 296 } 297 } 298 299 return Changed; 300 } 301 302 /// Also expose this is a pass. Currently this is only used for 303 /// unit-testing. It adds all memchecks necessary to remove all may-aliasing 304 /// array accesses from the loop. 305 class LoopVersioningLegacyPass : public FunctionPass { 306 public: 307 LoopVersioningLegacyPass() : FunctionPass(ID) { 308 initializeLoopVersioningLegacyPassPass(*PassRegistry::getPassRegistry()); 309 } 310 311 bool runOnFunction(Function &F) override { 312 auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 313 auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & { 314 return getAnalysis<LoopAccessLegacyAnalysis>().getInfo(&L); 315 }; 316 317 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 318 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 319 320 return runImpl(LI, GetLAA, DT, SE); 321 } 322 323 void getAnalysisUsage(AnalysisUsage &AU) const override { 324 AU.addRequired<LoopInfoWrapperPass>(); 325 AU.addPreserved<LoopInfoWrapperPass>(); 326 AU.addRequired<LoopAccessLegacyAnalysis>(); 327 AU.addRequired<DominatorTreeWrapperPass>(); 328 AU.addPreserved<DominatorTreeWrapperPass>(); 329 AU.addRequired<ScalarEvolutionWrapperPass>(); 330 } 331 332 static char ID; 333 }; 334 } 335 336 char LoopVersioningLegacyPass::ID; 337 static const char LVer_name[] = "Loop Versioning"; 338 339 INITIALIZE_PASS_BEGIN(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false, 340 false) 341 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 342 INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) 343 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 344 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) 345 INITIALIZE_PASS_END(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false, 346 false) 347 348 namespace llvm { 349 FunctionPass *createLoopVersioningLegacyPass() { 350 return new LoopVersioningLegacyPass(); 351 } 352 353 PreservedAnalyses LoopVersioningPass::run(Function &F, 354 FunctionAnalysisManager &AM) { 355 auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); 356 auto &LI = AM.getResult<LoopAnalysis>(F); 357 auto &TTI = AM.getResult<TargetIRAnalysis>(F); 358 auto &DT = AM.getResult<DominatorTreeAnalysis>(F); 359 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); 360 auto &AA = AM.getResult<AAManager>(F); 361 auto &AC = AM.getResult<AssumptionAnalysis>(F); 362 363 auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); 364 auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & { 365 LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, 366 TLI, TTI, nullptr, nullptr, nullptr}; 367 return LAM.getResult<LoopAccessAnalysis>(L, AR); 368 }; 369 370 if (runImpl(&LI, GetLAA, &DT, &SE)) 371 return PreservedAnalyses::none(); 372 return PreservedAnalyses::all(); 373 } 374 } // namespace llvm 375