//===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass recognizes certain loop idioms and transforms them into more
// optimized versions of the same loop. In cases where this happens, it can be
// a significant performance win.
//
// We currently only recognize one loop that finds the first mismatched byte
// between two arrays and returns the index, i.e. something like:
//
//  while (++i != n) {
//    if (a[i] != b[i])
//      break;
//  }
//
// In this example we can actually vectorize the loop despite the early exit,
// although the loop vectorizer does not support it. It requires some extra
// checks to deal with the possibility of faulting loads when crossing page
// boundaries. However, even with these checks it is still profitable to do the
// transformation.
//
//===----------------------------------------------------------------------===//
//
// NOTE: This Pass matches a really specific loop pattern because it's only
// supposed to be a temporary solution until our LoopVectorizer is powerful
// enough to vectorize it automatically.
//
// TODO List:
//
// * Add support for the inverse case where we scan for a matching element.
// * Permit 64-bit induction variable types.
// * Recognize loops that increment the IV *after* comparing bytes.
// * Allow 32-bit sign-extends of the IV used by the GEP.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "loop-idiom-vectorize"

static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
                                cl::init(false),
                                cl::desc("Disable Loop Idiom Vectorize Pass."));

static cl::opt<bool>
    DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
                   cl::init(false),
                   cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
                            "not convert byte-compare loop(s)."));

static cl::opt<bool>
    VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
                cl::desc("Verify loops generated by the Loop Idiom Vectorize "
                         "Pass."));

namespace {

class LoopIdiomVectorize {
  Loop *CurLoop = nullptr;
  DominatorTree *DT;
  LoopInfo *LI;
  const TargetTransformInfo *TTI;
  const DataLayout *DL;

public:
  explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI,
                              const TargetTransformInfo *TTI,
                              const DataLayout *DL)
      : DT(DT), LI(LI), TTI(TTI), DL(DL) {}

  bool run(Loop *L);

private:
  /// \name Countable Loop Idiom Handling
  /// @{

  bool runOnCountableLoop();
  bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
                      SmallVectorImpl<BasicBlock *> &ExitBlocks);

  bool recognizeByteCompare();
  Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                            GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            Instruction *Index, Value *Start, Value *MaxLen);
  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            PHINode *IndPhi, Value *MaxLen, Instruction *Index,
                            Value *Start, bool IncIdx, BasicBlock *FoundBB,
                            BasicBlock *EndBB);
  /// @}
};
} // anonymous namespace

PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
                                              LoopStandardAnalysisResults &AR,
                                              LPMUpdater &) {
  if (DisableAll)
    return PreservedAnalyses::all();

  const auto *DL = &L.getHeader()->getModule()->getDataLayout();

  LoopIdiomVectorize LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
  if (!LIT.run(&L))
    return PreservedAnalyses::all();

  return PreservedAnalyses::none();
}

//===----------------------------------------------------------------------===//
//
//          Implementation of LoopIdiomVectorize
//
//===----------------------------------------------------------------------===//

bool LoopIdiomVectorize::run(Loop *L) {
  CurLoop = L;

  Function &F = *L->getHeader()->getParent();
  if (DisableAll || F.hasOptSize())
    return false;

  if (F.hasFnAttribute(Attribute::NoImplicitFloat)) {
    LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName()
                      << " due to its NoImplicitFloat attribute");
    return false;
  }

  // If the loop could not be converted to canonical form, it must have an
  // indirectbr in it; just give up.
  if (!L->getLoopPreheader())
    return false;

  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
                    << CurLoop->getHeader()->getName() << "\n");

  return recognizeByteCompare();
}

bool LoopIdiomVectorize::recognizeByteCompare() {
  // Currently the transformation only works on scalable vector types, although
  // there is no fundamental reason why it cannot be made to work for fixed
  // width too.

  // We also need to know the minimum page size for the target in order to
  // generate runtime memory checks to ensure the vector version won't fault.
  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
      DisableByteCmp)
    return false;

  BasicBlock *Header = CurLoop->getHeader();

  // In LoopIdiomVectorize::run we have already checked that the loop
  // has a preheader so we can assume it's in canonical form.
  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
    return false;

  PHINode *PN = dyn_cast<PHINode>(&Header->front());
  if (!PN || PN->getNumIncomingValues() != 2)
    return false;

  auto LoopBlocks = CurLoop->getBlocks();
  // The first block in the loop should contain only 4 instructions, e.g.
  //
  //  while.cond:
  //    %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
  //    %inc = add i32 %res.phi, 1
  //    %cmp.not = icmp eq i32 %inc, %n
  //    br i1 %cmp.not, label %while.end, label %while.body
  //
  if (LoopBlocks[0]->sizeWithoutDebug() > 4)
    return false;

  // The second block should contain 7 instructions, e.g.
  //
  //  while.body:
  //    %idx = zext i32 %inc to i64
  //    %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
  //    %load.a = load i8, ptr %idx.a
  //    %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
  //    %load.b = load i8, ptr %idx.b
  //    %cmp.not.ld = icmp eq i8 %load.a, %load.b
  //    br i1 %cmp.not.ld, label %while.cond, label %while.end
  //
  if (LoopBlocks[1]->sizeWithoutDebug() > 7)
    return false;

  // The incoming value to the PHI node from the loop should be an add of 1.
  Value *StartIdx = nullptr;
  Instruction *Index = nullptr;
  if (!CurLoop->contains(PN->getIncomingBlock(0))) {
    StartIdx = PN->getIncomingValue(0);
    Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
  } else {
    StartIdx = PN->getIncomingValue(1);
    Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
  }

  // Limit to 32-bit types for now.
  if (!Index || !Index->getType()->isIntegerTy(32) ||
      !match(Index, m_c_Add(m_Specific(PN), m_One())))
    return false;

  // If we match the pattern, PN and Index will be replaced with the result of
  // the cttz.elts intrinsic. If any other instructions are used outside of
  // the loop, we cannot perform the replacement.
  for (BasicBlock *BB : LoopBlocks)
    for (Instruction &I : *BB)
      if (&I != PN && &I != Index)
        for (User *U : I.users())
          if (!CurLoop->contains(cast<Instruction>(U)))
            return false;

  // Match the branch instruction for the header.
  ICmpInst::Predicate Pred;
  Value *MaxLen;
  BasicBlock *EndBB, *WhileBB;
  if (!match(Header->getTerminator(),
             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
    return false;

  // WhileBB should contain the pattern of load & compare instructions. Match
  // the pattern and find the GEP instructions used by the loads.
  ICmpInst::Predicate WhilePred;
  BasicBlock *FoundBB;
  BasicBlock *TrueBB;
  Value *LoadA, *LoadB;
  if (!match(WhileBB->getTerminator(),
             m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
      WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
    return false;

  Value *A, *B;
  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
    return false;

  LoadInst *LoadAI = cast<LoadInst>(LoadA);
  LoadInst *LoadBI = cast<LoadInst>(LoadB);
  if (!LoadAI->isSimple() || !LoadBI->isSimple())
    return false;

  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);

  if (!GEPA || !GEPB)
    return false;

  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // Check we are loading i8 values from two loop invariant pointers.
  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
      !GEPA->getResultElementType()->isIntegerTy(8) ||
      !GEPB->getResultElementType()->isIntegerTy(8) ||
      !LoadAI->getType()->isIntegerTy(8) ||
      !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
    return false;

  // Check that the index to the GEPs is the index we found earlier.
  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
    return false;

  Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
  Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
    return false;

  // We only ever expect the pre-incremented index value to be used inside the
  // loop.
  if (!PN->hasOneUse())
    return false;

  // Ensure that when the Found and End blocks are identical the PHIs have the
  // supported format. We don't currently allow cases like this:
  // while.cond:
  //   ...
  //   br i1 %cmp.not, label %while.end, label %while.body
  //
  // while.body:
  //   ...
  //   br i1 %cmp.not2, label %while.cond, label %while.end
  //
  // while.end:
  //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
  //
  // Where the incoming values for %final_ptr are unique and from each of the
  // loop blocks, but not actually defined in the loop. This requires extra
  // work setting up the byte.compare block, i.e. by introducing a select to
  // choose the correct value.
  // TODO: We could add support for this in future.
  if (FoundBB == EndBB) {
    for (PHINode &EndPN : EndBB->phis()) {
      Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
      Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);

      // The value of the index when leaving the while.cond block is always the
      // same as the end value (MaxLen) so we permit either. The value when
      // leaving the while.body block should only be the index. Otherwise for
      // any other values we only allow ones that are same for both blocks.
      if (WhileCondVal != WhileBodyVal &&
          ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
           (WhileBodyVal != Index)))
        return false;
    }
  }

  LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
                    << *(EndBB->getParent()) << "\n\n");

  // The index is incremented before the GEP/Load pair so we need to
  // add 1 to the start value.
  transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true,
                       FoundBB, EndBB);
  return true;
}

Value *LoopIdiomVectorize::expandFindMismatch(
    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
    GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // Get the arguments and types for the intrinsic.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  LLVMContext &Ctx = PHBranch->getContext();
  Type *LoadType = Type::getInt8Ty(Ctx);
  Type *ResType = Builder.getInt32Ty();

  // Split block in the original loop preheader.
  BasicBlock *EndBlock =
      SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");

  // Create the blocks that we're going to need:
  //  1. A block for checking that the zero-extended length exceeds 0.
  //  2. A block to check that the start and end addresses of a given array
  //     lie on the same page.
  //  3. The vector loop preheader.
  //  4. The first vector loop block.
  //  5. The vector loop increment block.
  //  6. A block we can jump to from the vector loop when a mismatch is found.
  //  7. The first block of the scalar loop itself, containing PHIs, loads
  //     and cmp.
  //  8. A scalar loop increment block to increment the PHIs and go back
  //     around the loop.

  BasicBlock *MinItCheckBlock = BasicBlock::Create(
      Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);

  // Update the terminator added by SplitBlock to branch to the first block.
  Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);

  BasicBlock *MemCheckBlock = BasicBlock::Create(
      Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);

  BasicBlock *VectorLoopPreheaderBlock = BasicBlock::Create(
      Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock);

  BasicBlock *VectorLoopStartBlock = BasicBlock::Create(
      Ctx, "mismatch_vec_loop", EndBlock->getParent(), EndBlock);

  BasicBlock *VectorLoopIncBlock = BasicBlock::Create(
      Ctx, "mismatch_vec_loop_inc", EndBlock->getParent(), EndBlock);

  BasicBlock *VectorLoopMismatchBlock = BasicBlock::Create(
      Ctx, "mismatch_vec_loop_found", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
      Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopStartBlock =
      BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopIncBlock = BasicBlock::Create(
      Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);

  DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
                    {DominatorTree::Delete, Preheader, EndBlock}});

  // Update LoopInfo with the new vector & scalar loops.
  auto VectorLoop = LI->AllocateLoop();
  auto ScalarLoop = LI->AllocateLoop();

  if (CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopPreheaderBlock,
                                                  *LI);
    CurLoop->getParentLoop()->addChildLoop(VectorLoop);
    CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopMismatchBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
    CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
  } else {
    LI->addTopLevelLoop(VectorLoop);
    LI->addTopLevelLoop(ScalarLoop);
  }

  // Add the new basic blocks to their associated loops.
  VectorLoop->addBasicBlockToLoop(VectorLoopStartBlock, *LI);
  VectorLoop->addBasicBlockToLoop(VectorLoopIncBlock, *LI);

  ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
  ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);

  // Set up some types and constants that we intend to reuse.
  Type *I64Type = Builder.getInt64Ty();

  // Check the zero-extended iteration count > 0.
  Builder.SetInsertPoint(MinItCheckBlock);
  Value *ExtStart = Builder.CreateZExt(Start, I64Type);
  Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
  // This check doesn't really cost us very much.

  Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
  BranchInst *MinItCheckBr =
      BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
  MinItCheckBr->setMetadata(
      LLVMContext::MD_prof,
      MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
  Builder.Insert(MinItCheckBr);

  DTU.applyUpdates(
      {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
       {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});

  // For each of the arrays, check the start/end addresses are on the same
  // page.
  Builder.SetInsertPoint(MemCheckBlock);

  // The early exit in the original loop means that when performing vector
  // loads we are potentially reading ahead of the early exit. So we could
  // fault if crossing a page boundary. Therefore, we create runtime memory
  // checks based on the minimum page size as follows:
  //   1. Calculate the addresses of the first memory accesses in the loop,
  //      i.e. LhsStart and RhsStart.
  //   2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
  //   3. Determine which pages correspond to all the memory accesses, i.e.
  //      LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
  //   4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
  //      we know we won't cross any page boundaries in the loop so we can
  //      enter the vector loop! Otherwise we fall back on the scalar loop.
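  // For example, on a target whose minimum page size is 4096 bytes,
  // AddrShiftAmt below is 12, and two addresses lie on the same page if and
  // only if (Addr1 >> 12) == (Addr2 >> 12).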
  Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart);
  Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart);
  Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type);
  Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type);
  Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd);
  Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd);
  Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type);
  Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type);

  const uint64_t MinPageSize = TTI->getMinPageSize().value();
  const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
  Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt);
  Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt);
  Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt);
  Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt);
  Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage);
  Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage);

  Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp);
  BranchInst *CombinedPageCmpCmpBr = BranchInst::Create(
      LoopPreHeaderBlock, VectorLoopPreheaderBlock, CombinedPageCmp);
  CombinedPageCmpCmpBr->setMetadata(
      LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext())
                                .createBranchWeights(10, 90));
  Builder.Insert(CombinedPageCmpCmpBr);

  DTU.applyUpdates(
      {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
       {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}});

  // Set up the vector loop preheader, i.e. calculate initial loop predicate,
  // zero-extend MaxLen to 64-bits, determine the number of vector elements
  // processed in each iteration, etc.
  Builder.SetInsertPoint(VectorLoopPreheaderBlock);

  // At this point we know two things must be true:
  //  1. Start <= End
  //  2. ExtMaxLen <= MinPageSize due to the page checks.
  // Therefore, we know that we can use a 64-bit induction variable that
  // starts from 0 -> ExtMaxLen and it will not overflow.
  ScalableVectorType *PredVTy =
      ScalableVectorType::get(Builder.getInt1Ty(), 16);

  Value *InitialPred = Builder.CreateIntrinsic(
      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});

  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
  VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
                             /*HasNUW=*/true, /*HasNSW=*/true);

  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
                                            Builder.getInt1(false));

  BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
  Builder.Insert(JumpToVectorLoop);

  DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
                     VectorLoopStartBlock}});

  // Set up the first vector loop block by creating the PHIs, doing the vector
  // loads and comparing the vectors.
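  // Note that both loads below are masked by the current loop predicate, so
  // lanes at or beyond ExtEnd are never accessed. Inactive lanes read the
  // zero passthru value and cannot produce a spurious mismatch, because the
  // compare result is also masked by the predicate via the select below.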
  Builder.SetInsertPoint(VectorLoopStartBlock);
  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
  LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
  PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
  Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
  Value *Passthru = ConstantInt::getNullValue(VectorLoadType);

  Value *VectorLhsGep =
      Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
  Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
                                                  Align(1), LoopPred, Passthru);

  Value *VectorRhsGep =
      Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
  Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
                                                  Align(1), LoopPred, Passthru);

  Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
  VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
  Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
  BranchInst *VectorEarlyExit = BranchInst::Create(
      VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
  Builder.Insert(VectorEarlyExit);

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
       {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});

  // Increment the index counter and calculate the predicate for the next
  // iteration of the loop. We branch back to the start of the loop if there
  // is at least one active lane.
  Builder.SetInsertPoint(VectorLoopIncBlock);
  Value *NewVectorIndexPhi =
      Builder.CreateAdd(VectorIndexPhi, VecLen, "",
                        /*HasNUW=*/true, /*HasNSW=*/true);
  VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
  Value *NewPred =
      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
                              {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
  LoopPred->addIncoming(NewPred, VectorLoopIncBlock);

  Value *PredHasActiveLanes =
      Builder.CreateExtractElement(NewPred, uint64_t(0));
  BranchInst *VectorLoopBranchBack =
      BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
  Builder.Insert(VectorLoopBranchBack);

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
       {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});

  // If we found a mismatch then we need to calculate which lane in the vector
  // had a mismatch and add that on to the current loop index.
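  // The first mismatching lane is the count of trailing zero elements in the
  // predicate formed by AND-ing the compare result with the last loop
  // predicate, which we obtain with the llvm.experimental.cttz.elts intrinsic.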
  Builder.SetInsertPoint(VectorLoopMismatchBlock);
  PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
  FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
  PHINode *LastLoopPred =
      Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
  LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
  PHINode *VectorFoundIndex =
      Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
  VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);

  Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
  Value *Ctz = Builder.CreateIntrinsic(
      Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
      {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
  Ctz = Builder.CreateZExt(Ctz, I64Type);
  Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
                                             /*HasNUW=*/true, /*HasNSW=*/true);
  Value *VectorLoopRes = Builder.CreateTrunc(VectorLoopRes64, ResType);

  Builder.Insert(BranchInst::Create(EndBlock));

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}});

  // Generate code for scalar loop.
  Builder.SetInsertPoint(LoopPreHeaderBlock);
  Builder.Insert(BranchInst::Create(LoopStartBlock));

  DTU.applyUpdates(
      {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});

  Builder.SetInsertPoint(LoopStartBlock);
  PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
  IndexPhi->addIncoming(Start, LoopPreHeaderBlock);

  // Load bytes from each array and compare them.
  Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type);

  Value *LhsGep =
      Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
  Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep);

  Value *RhsGep =
      Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
  Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);

  Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
  // If we have a mismatch then exit the loop ...
  BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
  Builder.Insert(MatchCmpBr);

  DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
                    {DominatorTree::Insert, LoopStartBlock, EndBlock}});

  // Have we reached the maximum permitted length for the loop?
  Builder.SetInsertPoint(LoopIncBlock);
  Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "",
                                    /*HasNUW=*/Index->hasNoUnsignedWrap(),
                                    /*HasNSW=*/Index->hasNoSignedWrap());
  IndexPhi->addIncoming(PhiInc, LoopIncBlock);
  Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen);
  BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp);
  Builder.Insert(IVCmpBr);

  DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock},
                    {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});

  // In the end block we need to insert a PHI node to deal with four cases:
  //  1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
  //  2. We exited the scalar loop early due to a mismatch and need to return
  //     the index that we found.
  //  3. We didn't find a mismatch in the vector loop, so we return MaxLen.
  //  4. We exited the vector loop early due to a mismatch and need to return
  //     the index that we found.
  Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt());
  PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result");
  ResPhi->addIncoming(MaxLen, LoopIncBlock);
  ResPhi->addIncoming(IndexPhi, LoopStartBlock);
  ResPhi->addIncoming(MaxLen, VectorLoopIncBlock);
  ResPhi->addIncoming(VectorLoopRes, VectorLoopMismatchBlock);

  Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType);

  if (VerifyLoops) {
    ScalarLoop->verifyLoop();
    VectorLoop->verifyLoop();
    if (!VectorLoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
    if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }

  return FinalRes;
}

void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
                                              GetElementPtrInst *GEPB,
                                              PHINode *IndPhi, Value *MaxLen,
                                              Instruction *Index, Value *Start,
                                              bool IncIdx, BasicBlock *FoundBB,
                                              BasicBlock *EndBB) {

  // Insert the byte compare code at the end of the preheader block.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BasicBlock *Header = CurLoop->getHeader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  IRBuilder<> Builder(PHBranch);
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());

  // Increment the start value if the index was incremented before the loads
  // in the loop.
  if (IncIdx)
    Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));

  Value *ByteCmpRes =
      expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);

  // Replace uses of the index & induction Phi with the intrinsic result (we
  // already checked that the first instruction of Header is the Phi above).
  assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
  Index->replaceAllUsesWith(ByteCmpRes);

  assert(PHBranch->isUnconditional() &&
         "Expected preheader to terminate with an unconditional branch.");

  // If no mismatch was found, we can jump to the end block. Create a
  // new basic block for the compare instruction.
  auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare",
                                   Preheader->getParent());
  CmpBB->moveBefore(EndBB);

  // Replace the branch in the preheader with an always-true conditional branch.
  // This ensures there is still a reference to the original loop.
  Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header);
  PHBranch->eraseFromParent();

  BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent();
  DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}});

  // Create the branch to either the end or found block depending on the value
  // returned by the intrinsic.
  Builder.SetInsertPoint(CmpBB);
  if (FoundBB != EndBB) {
    Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen);
    Builder.CreateCondBr(FoundCmp, EndBB, FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB},
                      {DominatorTree::Insert, CmpBB, EndBB}});
  } else {
    Builder.CreateBr(FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}});
  }

  auto fixSuccessorPhis = [&](BasicBlock *SuccBB) {
    for (PHINode &PN : SuccBB->phis()) {
      // At this point we've already replaced all uses of the result from the
      // loop with ByteCmpRes.
      // Look through the incoming values to find ByteCmpRes, meaning this is
      // a Phi collecting the results of the byte compare.
      bool ResPhi = false;
      for (Value *Op : PN.incoming_values())
        if (Op == ByteCmpRes) {
          ResPhi = true;
          break;
        }

      // Any PHI that depended upon the result of the byte compare needs a new
      // incoming value from CmpBB. This is because the original loop will get
      // deleted.
      if (ResPhi)
        PN.addIncoming(ByteCmpRes, CmpBB);
      else {
        // There should be no other outside uses of other values in the
        // original loop. Any incoming values should either:
        //  1. Be for blocks outside the loop, which aren't interesting, or
        //  2. Be from blocks in the loop with values defined outside the
        //     loop, in which case we should add a similar incoming value
        //     from CmpBB.
        for (BasicBlock *BB : PN.blocks())
          if (CurLoop->contains(BB)) {
            PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB);
            break;
          }
      }
    }
  };

  // Ensure all Phis in the successors of CmpBB have an incoming value from it.
  fixSuccessorPhis(EndBB);
  if (EndBB != FoundBB)
    fixSuccessorPhis(FoundBB);

  // The new CmpBB block isn't part of the loop, but will need to be added to
  // the outer loop if there is one.
  if (!CurLoop->isOutermost())
    CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI);

  if (VerifyLoops && CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->verifyLoop();
    if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }
}