//===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that recognizes certain loop idioms and
// transforms them into more optimized versions of the same loop. In cases
// where this happens, it can be a significant performance win.
//
// We currently only recognize one loop that finds the first mismatched byte
// in an array and returns the index, i.e. something like:
//
//  while (++i != n) {
//    if (a[i] != b[i])
//      break;
//  }
//
// In this example we can actually vectorize the loop despite the early exit,
// although the loop vectorizer does not support it. It requires some extra
// checks to deal with the possibility of faulting loads when crossing page
// boundaries. However, even with these checks it is still profitable to do the
// transformation.
//
//===----------------------------------------------------------------------===//
//
// NOTE: This Pass matches a really specific loop pattern because it's only
// supposed to be a temporary solution until our LoopVectorizer is powerful
// enough to vectorize it automatically.
//
// TODO List:
//
// * Add support for the inverse case where we scan for a matching element.
// * Permit 64-bit induction variable types.
// * Recognize loops that increment the IV *after* comparing bytes.
// * Allow 32-bit sign-extends of the IV used by the GEP.
39 // 40 //===----------------------------------------------------------------------===// 41 42 #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h" 43 #include "llvm/Analysis/DomTreeUpdater.h" 44 #include "llvm/Analysis/LoopPass.h" 45 #include "llvm/Analysis/TargetTransformInfo.h" 46 #include "llvm/IR/Dominators.h" 47 #include "llvm/IR/IRBuilder.h" 48 #include "llvm/IR/Intrinsics.h" 49 #include "llvm/IR/MDBuilder.h" 50 #include "llvm/IR/PatternMatch.h" 51 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 52 53 using namespace llvm; 54 using namespace PatternMatch; 55 56 #define DEBUG_TYPE "loop-idiom-vectorize" 57 58 static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden, 59 cl::init(false), 60 cl::desc("Disable Loop Idiom Vectorize Pass.")); 61 62 static cl::opt<bool> 63 DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden, 64 cl::init(false), 65 cl::desc("Proceed with Loop Idiom Vectorize Pass, but do " 66 "not convert byte-compare loop(s).")); 67 68 static cl::opt<bool> 69 VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false), 70 cl::desc("Verify loops generated Loop Idiom Vectorize Pass.")); 71 72 namespace { 73 74 class LoopIdiomVectorize { 75 Loop *CurLoop = nullptr; 76 DominatorTree *DT; 77 LoopInfo *LI; 78 const TargetTransformInfo *TTI; 79 const DataLayout *DL; 80 81 // Blocks that will be used for inserting vectorized code. 
82 BasicBlock *EndBlock = nullptr; 83 BasicBlock *VectorLoopPreheaderBlock = nullptr; 84 BasicBlock *VectorLoopStartBlock = nullptr; 85 BasicBlock *VectorLoopMismatchBlock = nullptr; 86 BasicBlock *VectorLoopIncBlock = nullptr; 87 88 public: 89 explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI, 90 const TargetTransformInfo *TTI, 91 const DataLayout *DL) 92 : DT(DT), LI(LI), TTI(TTI), DL(DL) {} 93 94 bool run(Loop *L); 95 96 private: 97 /// \name Countable Loop Idiom Handling 98 /// @{ 99 100 bool runOnCountableLoop(); 101 bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount, 102 SmallVectorImpl<BasicBlock *> &ExitBlocks); 103 104 bool recognizeByteCompare(); 105 106 Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, 107 GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, 108 Instruction *Index, Value *Start, Value *MaxLen); 109 110 Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, 111 GetElementPtrInst *GEPA, 112 GetElementPtrInst *GEPB, Value *ExtStart, 113 Value *ExtEnd); 114 115 void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, 116 PHINode *IndPhi, Value *MaxLen, Instruction *Index, 117 Value *Start, bool IncIdx, BasicBlock *FoundBB, 118 BasicBlock *EndBB); 119 /// @} 120 }; 121 } // anonymous namespace 122 123 PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM, 124 LoopStandardAnalysisResults &AR, 125 LPMUpdater &) { 126 if (DisableAll) 127 return PreservedAnalyses::all(); 128 129 const auto *DL = &L.getHeader()->getDataLayout(); 130 131 LoopIdiomVectorize LIT(&AR.DT, &AR.LI, &AR.TTI, DL); 132 if (!LIT.run(&L)) 133 return PreservedAnalyses::all(); 134 135 return PreservedAnalyses::none(); 136 } 137 138 //===----------------------------------------------------------------------===// 139 // 140 // Implementation of LoopIdiomVectorize 141 // 142 //===----------------------------------------------------------------------===// 143 144 bool 
LoopIdiomVectorize::run(Loop *L) { 145 CurLoop = L; 146 147 Function &F = *L->getHeader()->getParent(); 148 if (DisableAll || F.hasOptSize()) 149 return false; 150 151 if (F.hasFnAttribute(Attribute::NoImplicitFloat)) { 152 LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName() 153 << " due to its NoImplicitFloat attribute"); 154 return false; 155 } 156 157 // If the loop could not be converted to canonical form, it must have an 158 // indirectbr in it, just give up. 159 if (!L->getLoopPreheader()) 160 return false; 161 162 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %" 163 << CurLoop->getHeader()->getName() << "\n"); 164 165 return recognizeByteCompare(); 166 } 167 168 bool LoopIdiomVectorize::recognizeByteCompare() { 169 // Currently the transformation only works on scalable vector types, although 170 // there is no fundamental reason why it cannot be made to work for fixed 171 // width too. 172 173 // We also need to know the minimum page size for the target in order to 174 // generate runtime memory checks to ensure the vector version won't fault. 175 if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() || 176 DisableByteCmp) 177 return false; 178 179 BasicBlock *Header = CurLoop->getHeader(); 180 181 // In LoopIdiomVectorize::run we have already checked that the loop 182 // has a preheader so we can assume it's in a canonical form. 183 if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2) 184 return false; 185 186 PHINode *PN = dyn_cast<PHINode>(&Header->front()); 187 if (!PN || PN->getNumIncomingValues() != 2) 188 return false; 189 190 auto LoopBlocks = CurLoop->getBlocks(); 191 // The first block in the loop should contain only 4 instructions, e.g. 
192 // 193 // while.cond: 194 // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ] 195 // %inc = add i32 %res.phi, 1 196 // %cmp.not = icmp eq i32 %inc, %n 197 // br i1 %cmp.not, label %while.end, label %while.body 198 // 199 if (LoopBlocks[0]->sizeWithoutDebug() > 4) 200 return false; 201 202 // The second block should contain 7 instructions, e.g. 203 // 204 // while.body: 205 // %idx = zext i32 %inc to i64 206 // %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx 207 // %load.a = load i8, ptr %idx.a 208 // %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx 209 // %load.b = load i8, ptr %idx.b 210 // %cmp.not.ld = icmp eq i8 %load.a, %load.b 211 // br i1 %cmp.not.ld, label %while.cond, label %while.end 212 // 213 if (LoopBlocks[1]->sizeWithoutDebug() > 7) 214 return false; 215 216 // The incoming value to the PHI node from the loop should be an add of 1. 217 Value *StartIdx = nullptr; 218 Instruction *Index = nullptr; 219 if (!CurLoop->contains(PN->getIncomingBlock(0))) { 220 StartIdx = PN->getIncomingValue(0); 221 Index = dyn_cast<Instruction>(PN->getIncomingValue(1)); 222 } else { 223 StartIdx = PN->getIncomingValue(1); 224 Index = dyn_cast<Instruction>(PN->getIncomingValue(0)); 225 } 226 227 // Limit to 32-bit types for now 228 if (!Index || !Index->getType()->isIntegerTy(32) || 229 !match(Index, m_c_Add(m_Specific(PN), m_One()))) 230 return false; 231 232 // If we match the pattern, PN and Index will be replaced with the result of 233 // the cttz.elts intrinsic. If any other instructions are used outside of 234 // the loop, we cannot replace it. 
235 for (BasicBlock *BB : LoopBlocks) 236 for (Instruction &I : *BB) 237 if (&I != PN && &I != Index) 238 for (User *U : I.users()) 239 if (!CurLoop->contains(cast<Instruction>(U))) 240 return false; 241 242 // Match the branch instruction for the header 243 ICmpInst::Predicate Pred; 244 Value *MaxLen; 245 BasicBlock *EndBB, *WhileBB; 246 if (!match(Header->getTerminator(), 247 m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)), 248 m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) || 249 Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB)) 250 return false; 251 252 // WhileBB should contain the pattern of load & compare instructions. Match 253 // the pattern and find the GEP instructions used by the loads. 254 ICmpInst::Predicate WhilePred; 255 BasicBlock *FoundBB; 256 BasicBlock *TrueBB; 257 Value *LoadA, *LoadB; 258 if (!match(WhileBB->getTerminator(), 259 m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)), 260 m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) || 261 WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB)) 262 return false; 263 264 Value *A, *B; 265 if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B)))) 266 return false; 267 268 LoadInst *LoadAI = cast<LoadInst>(LoadA); 269 LoadInst *LoadBI = cast<LoadInst>(LoadB); 270 if (!LoadAI->isSimple() || !LoadBI->isSimple()) 271 return false; 272 273 GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A); 274 GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B); 275 276 if (!GEPA || !GEPB) 277 return false; 278 279 Value *PtrA = GEPA->getPointerOperand(); 280 Value *PtrB = GEPB->getPointerOperand(); 281 282 // Check we are loading i8 values from two loop invariant pointers 283 if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) || 284 !GEPA->getResultElementType()->isIntegerTy(8) || 285 !GEPB->getResultElementType()->isIntegerTy(8) || 286 !LoadAI->getType()->isIntegerTy(8) || 287 !LoadBI->getType()->isIntegerTy(8) || PtrA == 
PtrB) 288 return false; 289 290 // Check that the index to the GEPs is the index we found earlier 291 if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1) 292 return false; 293 294 Value *IdxA = GEPA->getOperand(GEPA->getNumIndices()); 295 Value *IdxB = GEPB->getOperand(GEPB->getNumIndices()); 296 if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index)))) 297 return false; 298 299 // We only ever expect the pre-incremented index value to be used inside the 300 // loop. 301 if (!PN->hasOneUse()) 302 return false; 303 304 // Ensure that when the Found and End blocks are identical the PHIs have the 305 // supported format. We don't currently allow cases like this: 306 // while.cond: 307 // ... 308 // br i1 %cmp.not, label %while.end, label %while.body 309 // 310 // while.body: 311 // ... 312 // br i1 %cmp.not2, label %while.cond, label %while.end 313 // 314 // while.end: 315 // %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ] 316 // 317 // Where the incoming values for %final_ptr are unique and from each of the 318 // loop blocks, but not actually defined in the loop. This requires extra 319 // work setting up the byte.compare block, i.e. by introducing a select to 320 // choose the correct value. 321 // TODO: We could add support for this in future. 322 if (FoundBB == EndBB) { 323 for (PHINode &EndPN : EndBB->phis()) { 324 Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header); 325 Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB); 326 327 // The value of the index when leaving the while.cond block is always the 328 // same as the end value (MaxLen) so we permit either. The value when 329 // leaving the while.body block should only be the index. Otherwise for 330 // any other values we only allow ones that are same for both blocks. 
331 if (WhileCondVal != WhileBodyVal && 332 ((WhileCondVal != Index && WhileCondVal != MaxLen) || 333 (WhileBodyVal != Index))) 334 return false; 335 } 336 } 337 338 LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n" 339 << *(EndBB->getParent()) << "\n\n"); 340 341 // The index is incremented before the GEP/Load pair so we need to 342 // add 1 to the start value. 343 transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true, 344 FoundBB, EndBB); 345 return true; 346 } 347 348 Value *LoopIdiomVectorize::createMaskedFindMismatch( 349 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, 350 GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { 351 Type *I64Type = Builder.getInt64Ty(); 352 Type *ResType = Builder.getInt32Ty(); 353 Type *LoadType = Builder.getInt8Ty(); 354 Value *PtrA = GEPA->getPointerOperand(); 355 Value *PtrB = GEPB->getPointerOperand(); 356 357 // At this point we know two things must be true: 358 // 1. Start <= End 359 // 2. ExtMaxLen <= MinPageSize due to the page checks. 360 // Therefore, we know that we can use a 64-bit induction variable that 361 // starts from 0 -> ExtMaxLen and it will not overflow. 
362 ScalableVectorType *PredVTy = 363 ScalableVectorType::get(Builder.getInt1Ty(), 16); 364 365 Value *InitialPred = Builder.CreateIntrinsic( 366 Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd}); 367 368 Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {}); 369 VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "", 370 /*HasNUW=*/true, /*HasNSW=*/true); 371 372 Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(), 373 Builder.getInt1(false)); 374 375 BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock); 376 Builder.Insert(JumpToVectorLoop); 377 378 DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock, 379 VectorLoopStartBlock}}); 380 381 // Set up the first vector loop block by creating the PHIs, doing the vector 382 // loads and comparing the vectors. 383 Builder.SetInsertPoint(VectorLoopStartBlock); 384 PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred"); 385 LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock); 386 PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index"); 387 VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock); 388 Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16); 389 Value *Passthru = ConstantInt::getNullValue(VectorLoadType); 390 391 Value *VectorLhsGep = 392 Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds()); 393 Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep, 394 Align(1), LoopPred, Passthru); 395 396 Value *VectorRhsGep = 397 Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds()); 398 Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep, 399 Align(1), LoopPred, Passthru); 400 401 Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad); 402 VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse); 403 Value 
*VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp); 404 BranchInst *VectorEarlyExit = BranchInst::Create( 405 VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes); 406 Builder.Insert(VectorEarlyExit); 407 408 DTU.applyUpdates( 409 {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, 410 {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); 411 412 // Increment the index counter and calculate the predicate for the next 413 // iteration of the loop. We branch back to the start of the loop if there 414 // is at least one active lane. 415 Builder.SetInsertPoint(VectorLoopIncBlock); 416 Value *NewVectorIndexPhi = 417 Builder.CreateAdd(VectorIndexPhi, VecLen, "", 418 /*HasNUW=*/true, /*HasNSW=*/true); 419 VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock); 420 Value *NewPred = 421 Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, 422 {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd}); 423 LoopPred->addIncoming(NewPred, VectorLoopIncBlock); 424 425 Value *PredHasActiveLanes = 426 Builder.CreateExtractElement(NewPred, uint64_t(0)); 427 BranchInst *VectorLoopBranchBack = 428 BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes); 429 Builder.Insert(VectorLoopBranchBack); 430 431 DTU.applyUpdates( 432 {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, 433 {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); 434 435 // If we found a mismatch then we need to calculate which lane in the vector 436 // had a mismatch and add that on to the current loop index. 
437 Builder.SetInsertPoint(VectorLoopMismatchBlock); 438 PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred"); 439 FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock); 440 PHINode *LastLoopPred = 441 Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred"); 442 LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock); 443 PHINode *VectorFoundIndex = 444 Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index"); 445 VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock); 446 447 Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred); 448 Value *Ctz = Builder.CreateIntrinsic( 449 Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()}, 450 {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)}); 451 Ctz = Builder.CreateZExt(Ctz, I64Type); 452 Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "", 453 /*HasNUW=*/true, /*HasNSW=*/true); 454 return Builder.CreateTrunc(VectorLoopRes64, ResType); 455 } 456 457 Value *LoopIdiomVectorize::expandFindMismatch( 458 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, 459 GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) { 460 Value *PtrA = GEPA->getPointerOperand(); 461 Value *PtrB = GEPB->getPointerOperand(); 462 463 // Get the arguments and types for the intrinsic. 464 BasicBlock *Preheader = CurLoop->getLoopPreheader(); 465 BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator()); 466 LLVMContext &Ctx = PHBranch->getContext(); 467 Type *LoadType = Type::getInt8Ty(Ctx); 468 Type *ResType = Builder.getInt32Ty(); 469 470 // Split block in the original loop preheader. 471 EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end"); 472 473 // Create the blocks that we're going to need: 474 // 1. A block for checking the zero-extended length exceeds 0 475 // 2. A block to check that the start and end addresses of a given array 476 // lie on the same page. 477 // 3. 
The vector loop preheader. 478 // 4. The first vector loop block. 479 // 5. The vector loop increment block. 480 // 6. A block we can jump to from the vector loop when a mismatch is found. 481 // 7. The first block of the scalar loop itself, containing PHIs , loads 482 // and cmp. 483 // 8. A scalar loop increment block to increment the PHIs and go back 484 // around the loop. 485 486 BasicBlock *MinItCheckBlock = BasicBlock::Create( 487 Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock); 488 489 // Update the terminator added by SplitBlock to branch to the first block 490 Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock); 491 492 BasicBlock *MemCheckBlock = BasicBlock::Create( 493 Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock); 494 495 VectorLoopPreheaderBlock = BasicBlock::Create( 496 Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock); 497 498 VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop", 499 EndBlock->getParent(), EndBlock); 500 501 VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc", 502 EndBlock->getParent(), EndBlock); 503 504 VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found", 505 EndBlock->getParent(), EndBlock); 506 507 BasicBlock *LoopPreHeaderBlock = BasicBlock::Create( 508 Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock); 509 510 BasicBlock *LoopStartBlock = 511 BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock); 512 513 BasicBlock *LoopIncBlock = BasicBlock::Create( 514 Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock); 515 516 DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock}, 517 {DominatorTree::Delete, Preheader, EndBlock}}); 518 519 // Update LoopInfo with the new vector & scalar loops. 
520 auto VectorLoop = LI->AllocateLoop(); 521 auto ScalarLoop = LI->AllocateLoop(); 522 523 if (CurLoop->getParentLoop()) { 524 CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI); 525 CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI); 526 CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopPreheaderBlock, 527 *LI); 528 CurLoop->getParentLoop()->addChildLoop(VectorLoop); 529 CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopMismatchBlock, *LI); 530 CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI); 531 CurLoop->getParentLoop()->addChildLoop(ScalarLoop); 532 } else { 533 LI->addTopLevelLoop(VectorLoop); 534 LI->addTopLevelLoop(ScalarLoop); 535 } 536 537 // Add the new basic blocks to their associated loops. 538 VectorLoop->addBasicBlockToLoop(VectorLoopStartBlock, *LI); 539 VectorLoop->addBasicBlockToLoop(VectorLoopIncBlock, *LI); 540 541 ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI); 542 ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI); 543 544 // Set up some types and constants that we intend to reuse. 545 Type *I64Type = Builder.getInt64Ty(); 546 547 // Check the zero-extended iteration count > 0 548 Builder.SetInsertPoint(MinItCheckBlock); 549 Value *ExtStart = Builder.CreateZExt(Start, I64Type); 550 Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type); 551 // This check doesn't really cost us very much. 552 553 Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen); 554 BranchInst *MinItCheckBr = 555 BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck); 556 MinItCheckBr->setMetadata( 557 LLVMContext::MD_prof, 558 MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1)); 559 Builder.Insert(MinItCheckBr); 560 561 DTU.applyUpdates( 562 {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock}, 563 {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}}); 564 565 // For each of the arrays, check the start/end addresses are on the same 566 // page. 
567 Builder.SetInsertPoint(MemCheckBlock); 568 569 // The early exit in the original loop means that when performing vector 570 // loads we are potentially reading ahead of the early exit. So we could 571 // fault if crossing a page boundary. Therefore, we create runtime memory 572 // checks based on the minimum page size as follows: 573 // 1. Calculate the addresses of the first memory accesses in the loop, 574 // i.e. LhsStart and RhsStart. 575 // 2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd. 576 // 3. Determine which pages correspond to all the memory accesses, i.e 577 // LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage. 578 // 4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then 579 // we know we won't cross any page boundaries in the loop so we can 580 // enter the vector loop! Otherwise we fall back on the scalar loop. 581 Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart); 582 Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart); 583 Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type); 584 Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type); 585 Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd); 586 Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd); 587 Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type); 588 Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type); 589 590 const uint64_t MinPageSize = TTI->getMinPageSize().value(); 591 const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize); 592 Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt); 593 Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt); 594 Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt); 595 Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt); 596 Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage); 597 Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage); 598 599 Value *CombinedPageCmp = 
Builder.CreateOr(LhsPageCmp, RhsPageCmp); 600 BranchInst *CombinedPageCmpCmpBr = BranchInst::Create( 601 LoopPreHeaderBlock, VectorLoopPreheaderBlock, CombinedPageCmp); 602 CombinedPageCmpCmpBr->setMetadata( 603 LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext()) 604 .createBranchWeights(10, 90)); 605 Builder.Insert(CombinedPageCmpCmpBr); 606 607 DTU.applyUpdates( 608 {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock}, 609 {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}}); 610 611 // Set up the vector loop preheader, i.e. calculate initial loop predicate, 612 // zero-extend MaxLen to 64-bits, determine the number of vector elements 613 // processed in each iteration, etc. 614 Builder.SetInsertPoint(VectorLoopPreheaderBlock); 615 616 Value *VectorLoopRes = 617 createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd); 618 619 Builder.Insert(BranchInst::Create(EndBlock)); 620 621 DTU.applyUpdates( 622 {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}}); 623 624 // Generate code for scalar loop. 625 Builder.SetInsertPoint(LoopPreHeaderBlock); 626 Builder.Insert(BranchInst::Create(LoopStartBlock)); 627 628 DTU.applyUpdates( 629 {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}}); 630 631 Builder.SetInsertPoint(LoopStartBlock); 632 PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index"); 633 IndexPhi->addIncoming(Start, LoopPreHeaderBlock); 634 635 // Otherwise compare the values 636 // Load bytes from each array and compare them. 
637 Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type); 638 639 Value *LhsGep = 640 Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds()); 641 Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep); 642 643 Value *RhsGep = 644 Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds()); 645 Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep); 646 647 Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad); 648 // If we have a mismatch then exit the loop ... 649 BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp); 650 Builder.Insert(MatchCmpBr); 651 652 DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock}, 653 {DominatorTree::Insert, LoopStartBlock, EndBlock}}); 654 655 // Have we reached the maximum permitted length for the loop? 656 Builder.SetInsertPoint(LoopIncBlock); 657 Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "", 658 /*HasNUW=*/Index->hasNoUnsignedWrap(), 659 /*HasNSW=*/Index->hasNoSignedWrap()); 660 IndexPhi->addIncoming(PhiInc, LoopIncBlock); 661 Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen); 662 BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp); 663 Builder.Insert(IVCmpBr); 664 665 DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock}, 666 {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}}); 667 668 // In the end block we need to insert a PHI node to deal with three cases: 669 // 1. We didn't find a mismatch in the scalar loop, so we return MaxLen. 670 // 2. We exitted the scalar loop early due to a mismatch and need to return 671 // the index that we found. 672 // 3. We didn't find a mismatch in the vector loop, so we return MaxLen. 673 // 4. We exitted the vector loop early due to a mismatch and need to return 674 // the index that we found. 
675 Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt()); 676 PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result"); 677 ResPhi->addIncoming(MaxLen, LoopIncBlock); 678 ResPhi->addIncoming(IndexPhi, LoopStartBlock); 679 ResPhi->addIncoming(MaxLen, VectorLoopIncBlock); 680 ResPhi->addIncoming(VectorLoopRes, VectorLoopMismatchBlock); 681 682 Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType); 683 684 if (VerifyLoops) { 685 ScalarLoop->verifyLoop(); 686 VectorLoop->verifyLoop(); 687 if (!VectorLoop->isRecursivelyLCSSAForm(*DT, *LI)) 688 report_fatal_error("Loops must remain in LCSSA form!"); 689 if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI)) 690 report_fatal_error("Loops must remain in LCSSA form!"); 691 } 692 693 return FinalRes; 694 } 695 696 void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA, 697 GetElementPtrInst *GEPB, 698 PHINode *IndPhi, Value *MaxLen, 699 Instruction *Index, Value *Start, 700 bool IncIdx, BasicBlock *FoundBB, 701 BasicBlock *EndBB) { 702 703 // Insert the byte compare code at the end of the preheader block 704 BasicBlock *Preheader = CurLoop->getLoopPreheader(); 705 BasicBlock *Header = CurLoop->getHeader(); 706 BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator()); 707 IRBuilder<> Builder(PHBranch); 708 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 709 Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc()); 710 711 // Increment the pointer if this was done before the loads in the loop. 712 if (IncIdx) 713 Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1)); 714 715 Value *ByteCmpRes = 716 expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen); 717 718 // Replaces uses of index & induction Phi with intrinsic (we already 719 // checked that the the first instruction of Header is the Phi above). 
720 assert(IndPhi->hasOneUse() && "Index phi node has more than one use!"); 721 Index->replaceAllUsesWith(ByteCmpRes); 722 723 assert(PHBranch->isUnconditional() && 724 "Expected preheader to terminate with an unconditional branch."); 725 726 // If no mismatch was found, we can jump to the end block. Create a 727 // new basic block for the compare instruction. 728 auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare", 729 Preheader->getParent()); 730 CmpBB->moveBefore(EndBB); 731 732 // Replace the branch in the preheader with an always-true conditional branch. 733 // This ensures there is still a reference to the original loop. 734 Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header); 735 PHBranch->eraseFromParent(); 736 737 BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent(); 738 DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}}); 739 740 // Create the branch to either the end or found block depending on the value 741 // returned by the intrinsic. 742 Builder.SetInsertPoint(CmpBB); 743 if (FoundBB != EndBB) { 744 Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen); 745 Builder.CreateCondBr(FoundCmp, EndBB, FoundBB); 746 DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}, 747 {DominatorTree::Insert, CmpBB, EndBB}}); 748 749 } else { 750 Builder.CreateBr(FoundBB); 751 DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}}); 752 } 753 754 auto fixSuccessorPhis = [&](BasicBlock *SuccBB) { 755 for (PHINode &PN : SuccBB->phis()) { 756 // At this point we've already replaced all uses of the result from the 757 // loop with ByteCmp. Look through the incoming values to find ByteCmp, 758 // meaning this is a Phi collecting the results of the byte compare. 759 bool ResPhi = false; 760 for (Value *Op : PN.incoming_values()) 761 if (Op == ByteCmpRes) { 762 ResPhi = true; 763 break; 764 } 765 766 // Any PHI that depended upon the result of the byte compare needs a new 767 // incoming value from CmpBB. 
This is because the original loop will get 768 // deleted. 769 if (ResPhi) 770 PN.addIncoming(ByteCmpRes, CmpBB); 771 else { 772 // There should be no other outside uses of other values in the 773 // original loop. Any incoming values should either: 774 // 1. Be for blocks outside the loop, which aren't interesting. Or .. 775 // 2. These are from blocks in the loop with values defined outside 776 // the loop. We should a similar incoming value from CmpBB. 777 for (BasicBlock *BB : PN.blocks()) 778 if (CurLoop->contains(BB)) { 779 PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB); 780 break; 781 } 782 } 783 } 784 }; 785 786 // Ensure all Phis in the successors of CmpBB have an incoming value from it. 787 fixSuccessorPhis(EndBB); 788 if (EndBB != FoundBB) 789 fixSuccessorPhis(FoundBB); 790 791 // The new CmpBB block isn't part of the loop, but will need to be added to 792 // the outer loop if there is one. 793 if (!CurLoop->isOutermost()) 794 CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI); 795 796 if (VerifyLoops && CurLoop->getParentLoop()) { 797 CurLoop->getParentLoop()->verifyLoop(); 798 if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI)) 799 report_fatal_error("Loops must remain in LCSSA form!"); 800 } 801 } 802