1 //===-- LoopPredication.cpp - Guard based loop predication pass -----------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // The LoopPredication pass tries to convert loop variant range checks to loop 11 // invariant by widening checks across loop iterations. For example, it will 12 // convert 13 // 14 // for (i = 0; i < n; i++) { 15 // guard(i < len); 16 // ... 17 // } 18 // 19 // to 20 // 21 // for (i = 0; i < n; i++) { 22 // guard(n - 1 < len); 23 // ... 24 // } 25 // 26 // After this transformation the condition of the guard is loop invariant, so 27 // loop-unswitch can later unswitch the loop by this condition which basically 28 // predicates the loop by the widened condition: 29 // 30 // if (n - 1 < len) 31 // for (i = 0; i < n; i++) { 32 // ... 33 // } 34 // else 35 // deoptimize 36 // 37 //===----------------------------------------------------------------------===// 38 39 #include "llvm/Transforms/Scalar/LoopPredication.h" 40 #include "llvm/Pass.h" 41 #include "llvm/Analysis/LoopInfo.h" 42 #include "llvm/Analysis/LoopPass.h" 43 #include "llvm/Analysis/ScalarEvolution.h" 44 #include "llvm/Analysis/ScalarEvolutionExpander.h" 45 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 46 #include "llvm/IR/Function.h" 47 #include "llvm/IR/GlobalValue.h" 48 #include "llvm/IR/IntrinsicInst.h" 49 #include "llvm/IR/Module.h" 50 #include "llvm/IR/PatternMatch.h" 51 #include "llvm/Support/Debug.h" 52 #include "llvm/Transforms/Scalar.h" 53 #include "llvm/Transforms/Utils/LoopUtils.h" 54 55 #define DEBUG_TYPE "loop-predication" 56 57 using namespace llvm; 58 59 namespace { 60 class LoopPredication { 61 /// Represents an induction variable check: 62 /// icmp Pred, <induction variable>, <loop invariant limit> 63 struct LoopICmp { 64 ICmpInst::Predicate Pred; 65 const SCEVAddRecExpr *IV; 66 const SCEV *Limit; 67 LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV, 68 const SCEV *Limit) 69 : Pred(Pred), IV(IV), Limit(Limit) {} 70 LoopICmp() {} 71 }; 72 73 ScalarEvolution *SE; 74 75 Loop *L; 76 const DataLayout *DL; 77 BasicBlock *Preheader; 78 79 Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI); 80 81 Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder, 82 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, 83 Instruction *InsertAt); 84 85 Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, 86 IRBuilder<> &Builder); 87 bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); 88 89 public: 90 LoopPredication(ScalarEvolution *SE) : SE(SE){}; 91 bool runOnLoop(Loop *L); 92 }; 93 94 class LoopPredicationLegacyPass : public LoopPass { 95 public: 96 static char ID; 97 LoopPredicationLegacyPass() : LoopPass(ID) { 98 initializeLoopPredicationLegacyPassPass(*PassRegistry::getPassRegistry()); 99 } 100 101 void getAnalysisUsage(AnalysisUsage &AU) const override { 102 getLoopAnalysisUsage(AU); 103 } 104 105 bool runOnLoop(Loop *L, LPPassManager &LPM) override { 106 if (skipLoop(L)) 107 return false; 108 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 109 LoopPredication LP(SE); 110 return LP.runOnLoop(L); 111 } 112 }; 113 114 char LoopPredicationLegacyPass::ID = 0; 115 } // end namespace llvm 116 117 INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication", 118 "Loop predication", false, false) 119 INITIALIZE_PASS_DEPENDENCY(LoopPass) 120 INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication", 121 "Loop predication", false, false) 122 123 Pass *llvm::createLoopPredicationPass() { 124 return new LoopPredicationLegacyPass(); 125 } 126 127 PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, 128 LoopStandardAnalysisResults &AR, 129 LPMUpdater &U) { 130 LoopPredication LP(&AR.SE); 131 if (!LP.runOnLoop(&L)) 132 return PreservedAnalyses::all(); 133 134 return getLoopPassPreservedAnalyses(); 135 } 136 137 Optional<LoopPredication::LoopICmp> 138 LoopPredication::parseLoopICmp(ICmpInst *ICI) { 139 ICmpInst::Predicate Pred = ICI->getPredicate(); 140 141 Value *LHS = ICI->getOperand(0); 142 Value *RHS = ICI->getOperand(1); 143 const SCEV *LHSS = SE->getSCEV(LHS); 144 if (isa<SCEVCouldNotCompute>(LHSS)) 145 return None; 146 const SCEV *RHSS = SE->getSCEV(RHS); 147 if (isa<SCEVCouldNotCompute>(RHSS)) 148 return None; 149 150 // Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV 151 if (SE->isLoopInvariant(LHSS, L)) { 152 std::swap(LHS, RHS); 153 std::swap(LHSS, RHSS); 154 Pred = ICmpInst::getSwappedPredicate(Pred); 155 } 156 157 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS); 158 if (!AR || AR->getLoop() != L) 159 return None; 160 161 return LoopICmp(Pred, AR, RHSS); 162 } 163 164 Value *LoopPredication::expandCheck(SCEVExpander &Expander, 165 IRBuilder<> &Builder, 166 ICmpInst::Predicate Pred, const SCEV *LHS, 167 const SCEV *RHS, Instruction *InsertAt) { 168 Type *Ty = LHS->getType(); 169 assert(Ty == RHS->getType() && "expandCheck operands have different types?"); 170 Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt); 171 Value *RHSV = Expander.expandCodeFor(RHS, Ty, InsertAt); 172 return Builder.CreateICmp(Pred, LHSV, RHSV); 173 } 174 175 /// If ICI can be widened to a loop invariant condition emits the loop 176 /// invariant condition in the loop preheader and return it, otherwise 177 /// returns None. 178 Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, 179 SCEVExpander &Expander, 180 IRBuilder<> &Builder) { 181 DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); 182 DEBUG(ICI->dump()); 183 184 auto RangeCheck = parseLoopICmp(ICI); 185 if (!RangeCheck) { 186 DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); 187 return None; 188 } 189 190 ICmpInst::Predicate Pred = RangeCheck->Pred; 191 const SCEVAddRecExpr *IndexAR = RangeCheck->IV; 192 const SCEV *RHSS = RangeCheck->Limit; 193 194 auto CanExpand = [this](const SCEV *S) { 195 return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE); 196 }; 197 if (!CanExpand(RHSS)) 198 return None; 199 200 DEBUG(dbgs() << "IndexAR: "); 201 DEBUG(IndexAR->dump()); 202 203 bool IsIncreasing = false; 204 if (!SE->isMonotonicPredicate(IndexAR, Pred, IsIncreasing)) 205 return None; 206 207 // If the predicate is increasing the condition can change from false to true 208 // as the loop progresses, in this case take the value on the first iteration 209 // for the widened check. Otherwise the condition can change from true to 210 // false as the loop progresses, so take the value on the last iteration. 211 const SCEV *NewLHSS = IsIncreasing 212 ? IndexAR->getStart() 213 : SE->getSCEVAtScope(IndexAR, L->getParentLoop()); 214 if (NewLHSS == IndexAR) { 215 DEBUG(dbgs() << "Can't compute NewLHSS!\n"); 216 return None; 217 } 218 219 DEBUG(dbgs() << "NewLHSS: "); 220 DEBUG(NewLHSS->dump()); 221 222 if (!CanExpand(NewLHSS)) 223 return None; 224 225 DEBUG(dbgs() << "NewLHSS is loop invariant and safe to expand. Expand!\n"); 226 227 Instruction *InsertAt = Preheader->getTerminator(); 228 return expandCheck(Expander, Builder, Pred, NewLHSS, RHSS, InsertAt); 229 } 230 231 bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, 232 SCEVExpander &Expander) { 233 DEBUG(dbgs() << "Processing guard:\n"); 234 DEBUG(Guard->dump()); 235 236 IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator())); 237 238 // The guard condition is expected to be in form of: 239 // cond1 && cond2 && cond3 ... 240 // Iterate over subconditions looking for for icmp conditions which can be 241 // widened across loop iterations. Widening these conditions remember the 242 // resulting list of subconditions in Checks vector. 243 SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0)); 244 SmallPtrSet<Value *, 4> Visited; 245 246 SmallVector<Value *, 4> Checks; 247 248 unsigned NumWidened = 0; 249 do { 250 Value *Condition = Worklist.pop_back_val(); 251 if (!Visited.insert(Condition).second) 252 continue; 253 254 Value *LHS, *RHS; 255 using namespace llvm::PatternMatch; 256 if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) { 257 Worklist.push_back(LHS); 258 Worklist.push_back(RHS); 259 continue; 260 } 261 262 if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) { 263 if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Builder)) { 264 Checks.push_back(NewRangeCheck.getValue()); 265 NumWidened++; 266 continue; 267 } 268 } 269 270 // Save the condition as is if we can't widen it 271 Checks.push_back(Condition); 272 } while (Worklist.size() != 0); 273 274 if (NumWidened == 0) 275 return false; 276 277 // Emit the new guard condition 278 Builder.SetInsertPoint(Guard); 279 Value *LastCheck = nullptr; 280 for (auto *Check : Checks) 281 if (!LastCheck) 282 LastCheck = Check; 283 else 284 LastCheck = Builder.CreateAnd(LastCheck, Check); 285 Guard->setOperand(0, LastCheck); 286 287 DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); 288 return true; 289 } 290 291 bool LoopPredication::runOnLoop(Loop *Loop) { 292 L = Loop; 293 294 DEBUG(dbgs() << "Analyzing "); 295 DEBUG(L->dump()); 296 297 Module *M = L->getHeader()->getModule(); 298 299 // There is nothing to do if the module doesn't use guards 300 auto *GuardDecl = 301 M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard)); 302 if (!GuardDecl || GuardDecl->use_empty()) 303 return false; 304 305 DL = &M->getDataLayout(); 306 307 Preheader = L->getLoopPreheader(); 308 if (!Preheader) 309 return false; 310 311 // Collect all the guards into a vector and process later, so as not 312 // to invalidate the instruction iterator. 313 SmallVector<IntrinsicInst *, 4> Guards; 314 for (const auto BB : L->blocks()) 315 for (auto &I : *BB) 316 if (auto *II = dyn_cast<IntrinsicInst>(&I)) 317 if (II->getIntrinsicID() == Intrinsic::experimental_guard) 318 Guards.push_back(II); 319 320 if (Guards.empty()) 321 return false; 322 323 SCEVExpander Expander(*SE, *DL, "loop-predication"); 324 325 bool Changed = false; 326 for (auto *Guard : Guards) 327 Changed |= widenGuardConditions(Guard, Expander); 328 329 return Changed; 330 } 331