//===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass recognizes certain loop idioms and transforms them into more
// optimized versions of the same loop. In cases where this happens, it can be
// a significant performance win.
//
// We currently only recognize one loop that finds the first mismatched byte
// in an array and returns the index, i.e. something like:
//
//  while (++i != n) {
//    if (a[i] != b[i])
//      break;
//  }
//
// In this example we can actually vectorize the loop despite the early exit,
// although the loop vectorizer does not support it. It requires some extra
// checks to deal with the possibility of faulting loads when crossing page
// boundaries. However, even with these checks it is still profitable to do the
// transformation.
//
//===----------------------------------------------------------------------===//
//
// NOTE: This Pass matches a really specific loop pattern because it's only
// supposed to be a temporary solution until our LoopVectorizer is powerful
// enough to vectorize it automatically.
//
// TODO List:
//
// * Add support for the inverse case where we scan for a matching element.
// * Permit 64-bit induction variable types.
// * Recognize loops that increment the IV *after* comparing bytes.
// * Allow 32-bit sign-extends of the IV used by the GEP.
39 // 40 //===----------------------------------------------------------------------===// 41 42 #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h" 43 #include "llvm/Analysis/DomTreeUpdater.h" 44 #include "llvm/Analysis/LoopPass.h" 45 #include "llvm/Analysis/TargetTransformInfo.h" 46 #include "llvm/IR/Dominators.h" 47 #include "llvm/IR/IRBuilder.h" 48 #include "llvm/IR/Intrinsics.h" 49 #include "llvm/IR/MDBuilder.h" 50 #include "llvm/IR/PatternMatch.h" 51 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 52 53 using namespace llvm; 54 using namespace PatternMatch; 55 56 #define DEBUG_TYPE "loop-idiom-vectorize" 57 58 static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden, 59 cl::init(false), 60 cl::desc("Disable Loop Idiom Vectorize Pass.")); 61 62 static cl::opt<LoopIdiomVectorizeStyle> 63 LITVecStyle("loop-idiom-vectorize-style", cl::Hidden, 64 cl::desc("The vectorization style for loop idiom transform."), 65 cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked", 66 "Use masked vector intrinsics"), 67 clEnumValN(LoopIdiomVectorizeStyle::Predicated, 68 "predicated", "Use VP intrinsics")), 69 cl::init(LoopIdiomVectorizeStyle::Masked)); 70 71 static cl::opt<bool> 72 DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden, 73 cl::init(false), 74 cl::desc("Proceed with Loop Idiom Vectorize Pass, but do " 75 "not convert byte-compare loop(s).")); 76 77 static cl::opt<unsigned> 78 ByteCmpVF("loop-idiom-vectorize-bytecmp-vf", cl::Hidden, 79 cl::desc("The vectorization factor for byte-compare patterns."), 80 cl::init(16)); 81 82 static cl::opt<bool> 83 VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false), 84 cl::desc("Verify loops generated Loop Idiom Vectorize Pass.")); 85 86 namespace { 87 class LoopIdiomVectorize { 88 LoopIdiomVectorizeStyle VectorizeStyle; 89 unsigned ByteCompareVF; 90 Loop *CurLoop = nullptr; 91 DominatorTree *DT; 92 LoopInfo *LI; 93 const TargetTransformInfo *TTI; 94 const 
DataLayout *DL; 95 96 // Blocks that will be used for inserting vectorized code. 97 BasicBlock *EndBlock = nullptr; 98 BasicBlock *VectorLoopPreheaderBlock = nullptr; 99 BasicBlock *VectorLoopStartBlock = nullptr; 100 BasicBlock *VectorLoopMismatchBlock = nullptr; 101 BasicBlock *VectorLoopIncBlock = nullptr; 102 103 public: 104 LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT, 105 LoopInfo *LI, const TargetTransformInfo *TTI, 106 const DataLayout *DL) 107 : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) { 108 } 109 110 bool run(Loop *L); 111 112 private: 113 /// \name Countable Loop Idiom Handling 114 /// @{ 115 116 bool runOnCountableLoop(); 117 bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount, 118 SmallVectorImpl<BasicBlock *> &ExitBlocks); 119 120 bool recognizeByteCompare(); 121 122 Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, 123 GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, 124 Instruction *Index, Value *Start, Value *MaxLen); 125 126 Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, 127 GetElementPtrInst *GEPA, 128 GetElementPtrInst *GEPB, Value *ExtStart, 129 Value *ExtEnd); 130 Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, 131 GetElementPtrInst *GEPA, 132 GetElementPtrInst *GEPB, Value *ExtStart, 133 Value *ExtEnd); 134 135 void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, 136 PHINode *IndPhi, Value *MaxLen, Instruction *Index, 137 Value *Start, bool IncIdx, BasicBlock *FoundBB, 138 BasicBlock *EndBB); 139 /// @} 140 }; 141 } // anonymous namespace 142 143 PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM, 144 LoopStandardAnalysisResults &AR, 145 LPMUpdater &) { 146 if (DisableAll) 147 return PreservedAnalyses::all(); 148 149 const auto *DL = &L.getHeader()->getDataLayout(); 150 151 LoopIdiomVectorizeStyle VecStyle = VectorizeStyle; 152 if 
(LITVecStyle.getNumOccurrences()) 153 VecStyle = LITVecStyle; 154 155 unsigned BCVF = ByteCompareVF; 156 if (ByteCmpVF.getNumOccurrences()) 157 BCVF = ByteCmpVF; 158 159 LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL); 160 if (!LIV.run(&L)) 161 return PreservedAnalyses::all(); 162 163 return PreservedAnalyses::none(); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // 168 // Implementation of LoopIdiomVectorize 169 // 170 //===----------------------------------------------------------------------===// 171 172 bool LoopIdiomVectorize::run(Loop *L) { 173 CurLoop = L; 174 175 Function &F = *L->getHeader()->getParent(); 176 if (DisableAll || F.hasOptSize()) 177 return false; 178 179 if (F.hasFnAttribute(Attribute::NoImplicitFloat)) { 180 LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName() 181 << " due to its NoImplicitFloat attribute"); 182 return false; 183 } 184 185 // If the loop could not be converted to canonical form, it must have an 186 // indirectbr in it, just give up. 187 if (!L->getLoopPreheader()) 188 return false; 189 190 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %" 191 << CurLoop->getHeader()->getName() << "\n"); 192 193 return recognizeByteCompare(); 194 } 195 196 bool LoopIdiomVectorize::recognizeByteCompare() { 197 // Currently the transformation only works on scalable vector types, although 198 // there is no fundamental reason why it cannot be made to work for fixed 199 // width too. 200 201 // We also need to know the minimum page size for the target in order to 202 // generate runtime memory checks to ensure the vector version won't fault. 
203 if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() || 204 DisableByteCmp) 205 return false; 206 207 BasicBlock *Header = CurLoop->getHeader(); 208 209 // In LoopIdiomVectorize::run we have already checked that the loop 210 // has a preheader so we can assume it's in a canonical form. 211 if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2) 212 return false; 213 214 PHINode *PN = dyn_cast<PHINode>(&Header->front()); 215 if (!PN || PN->getNumIncomingValues() != 2) 216 return false; 217 218 auto LoopBlocks = CurLoop->getBlocks(); 219 // The first block in the loop should contain only 4 instructions, e.g. 220 // 221 // while.cond: 222 // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ] 223 // %inc = add i32 %res.phi, 1 224 // %cmp.not = icmp eq i32 %inc, %n 225 // br i1 %cmp.not, label %while.end, label %while.body 226 // 227 if (LoopBlocks[0]->sizeWithoutDebug() > 4) 228 return false; 229 230 // The second block should contain 7 instructions, e.g. 231 // 232 // while.body: 233 // %idx = zext i32 %inc to i64 234 // %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx 235 // %load.a = load i8, ptr %idx.a 236 // %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx 237 // %load.b = load i8, ptr %idx.b 238 // %cmp.not.ld = icmp eq i8 %load.a, %load.b 239 // br i1 %cmp.not.ld, label %while.cond, label %while.end 240 // 241 if (LoopBlocks[1]->sizeWithoutDebug() > 7) 242 return false; 243 244 // The incoming value to the PHI node from the loop should be an add of 1. 
245 Value *StartIdx = nullptr; 246 Instruction *Index = nullptr; 247 if (!CurLoop->contains(PN->getIncomingBlock(0))) { 248 StartIdx = PN->getIncomingValue(0); 249 Index = dyn_cast<Instruction>(PN->getIncomingValue(1)); 250 } else { 251 StartIdx = PN->getIncomingValue(1); 252 Index = dyn_cast<Instruction>(PN->getIncomingValue(0)); 253 } 254 255 // Limit to 32-bit types for now 256 if (!Index || !Index->getType()->isIntegerTy(32) || 257 !match(Index, m_c_Add(m_Specific(PN), m_One()))) 258 return false; 259 260 // If we match the pattern, PN and Index will be replaced with the result of 261 // the cttz.elts intrinsic. If any other instructions are used outside of 262 // the loop, we cannot replace it. 263 for (BasicBlock *BB : LoopBlocks) 264 for (Instruction &I : *BB) 265 if (&I != PN && &I != Index) 266 for (User *U : I.users()) 267 if (!CurLoop->contains(cast<Instruction>(U))) 268 return false; 269 270 // Match the branch instruction for the header 271 ICmpInst::Predicate Pred; 272 Value *MaxLen; 273 BasicBlock *EndBB, *WhileBB; 274 if (!match(Header->getTerminator(), 275 m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)), 276 m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) || 277 Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB)) 278 return false; 279 280 // WhileBB should contain the pattern of load & compare instructions. Match 281 // the pattern and find the GEP instructions used by the loads. 
282 ICmpInst::Predicate WhilePred; 283 BasicBlock *FoundBB; 284 BasicBlock *TrueBB; 285 Value *LoadA, *LoadB; 286 if (!match(WhileBB->getTerminator(), 287 m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)), 288 m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) || 289 WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB)) 290 return false; 291 292 Value *A, *B; 293 if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B)))) 294 return false; 295 296 LoadInst *LoadAI = cast<LoadInst>(LoadA); 297 LoadInst *LoadBI = cast<LoadInst>(LoadB); 298 if (!LoadAI->isSimple() || !LoadBI->isSimple()) 299 return false; 300 301 GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A); 302 GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B); 303 304 if (!GEPA || !GEPB) 305 return false; 306 307 Value *PtrA = GEPA->getPointerOperand(); 308 Value *PtrB = GEPB->getPointerOperand(); 309 310 // Check we are loading i8 values from two loop invariant pointers 311 if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) || 312 !GEPA->getResultElementType()->isIntegerTy(8) || 313 !GEPB->getResultElementType()->isIntegerTy(8) || 314 !LoadAI->getType()->isIntegerTy(8) || 315 !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB) 316 return false; 317 318 // Check that the index to the GEPs is the index we found earlier 319 if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1) 320 return false; 321 322 Value *IdxA = GEPA->getOperand(GEPA->getNumIndices()); 323 Value *IdxB = GEPB->getOperand(GEPB->getNumIndices()); 324 if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index)))) 325 return false; 326 327 // We only ever expect the pre-incremented index value to be used inside the 328 // loop. 329 if (!PN->hasOneUse()) 330 return false; 331 332 // Ensure that when the Found and End blocks are identical the PHIs have the 333 // supported format. We don't currently allow cases like this: 334 // while.cond: 335 // ... 
336 // br i1 %cmp.not, label %while.end, label %while.body 337 // 338 // while.body: 339 // ... 340 // br i1 %cmp.not2, label %while.cond, label %while.end 341 // 342 // while.end: 343 // %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ] 344 // 345 // Where the incoming values for %final_ptr are unique and from each of the 346 // loop blocks, but not actually defined in the loop. This requires extra 347 // work setting up the byte.compare block, i.e. by introducing a select to 348 // choose the correct value. 349 // TODO: We could add support for this in future. 350 if (FoundBB == EndBB) { 351 for (PHINode &EndPN : EndBB->phis()) { 352 Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header); 353 Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB); 354 355 // The value of the index when leaving the while.cond block is always the 356 // same as the end value (MaxLen) so we permit either. The value when 357 // leaving the while.body block should only be the index. Otherwise for 358 // any other values we only allow ones that are same for both blocks. 359 if (WhileCondVal != WhileBodyVal && 360 ((WhileCondVal != Index && WhileCondVal != MaxLen) || 361 (WhileBodyVal != Index))) 362 return false; 363 } 364 } 365 366 LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n" 367 << *(EndBB->getParent()) << "\n\n"); 368 369 // The index is incremented before the GEP/Load pair so we need to 370 // add 1 to the start value. 
371 transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true, 372 FoundBB, EndBB); 373 return true; 374 } 375 376 Value *LoopIdiomVectorize::createMaskedFindMismatch( 377 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, 378 GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { 379 Type *I64Type = Builder.getInt64Ty(); 380 Type *ResType = Builder.getInt32Ty(); 381 Type *LoadType = Builder.getInt8Ty(); 382 Value *PtrA = GEPA->getPointerOperand(); 383 Value *PtrB = GEPB->getPointerOperand(); 384 385 ScalableVectorType *PredVTy = 386 ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF); 387 388 Value *InitialPred = Builder.CreateIntrinsic( 389 Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd}); 390 391 Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {}); 392 VecLen = 393 Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "", 394 /*HasNUW=*/true, /*HasNSW=*/true); 395 396 Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(), 397 Builder.getInt1(false)); 398 399 BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock); 400 Builder.Insert(JumpToVectorLoop); 401 402 DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock, 403 VectorLoopStartBlock}}); 404 405 // Set up the first vector loop block by creating the PHIs, doing the vector 406 // loads and comparing the vectors. 
407 Builder.SetInsertPoint(VectorLoopStartBlock); 408 PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred"); 409 LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock); 410 PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index"); 411 VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock); 412 Type *VectorLoadType = 413 ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF); 414 Value *Passthru = ConstantInt::getNullValue(VectorLoadType); 415 416 Value *VectorLhsGep = 417 Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds()); 418 Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep, 419 Align(1), LoopPred, Passthru); 420 421 Value *VectorRhsGep = 422 Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds()); 423 Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep, 424 Align(1), LoopPred, Passthru); 425 426 Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad); 427 VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse); 428 Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp); 429 BranchInst *VectorEarlyExit = BranchInst::Create( 430 VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes); 431 Builder.Insert(VectorEarlyExit); 432 433 DTU.applyUpdates( 434 {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, 435 {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); 436 437 // Increment the index counter and calculate the predicate for the next 438 // iteration of the loop. We branch back to the start of the loop if there 439 // is at least one active lane. 
440 Builder.SetInsertPoint(VectorLoopIncBlock); 441 Value *NewVectorIndexPhi = 442 Builder.CreateAdd(VectorIndexPhi, VecLen, "", 443 /*HasNUW=*/true, /*HasNSW=*/true); 444 VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock); 445 Value *NewPred = 446 Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, 447 {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd}); 448 LoopPred->addIncoming(NewPred, VectorLoopIncBlock); 449 450 Value *PredHasActiveLanes = 451 Builder.CreateExtractElement(NewPred, uint64_t(0)); 452 BranchInst *VectorLoopBranchBack = 453 BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes); 454 Builder.Insert(VectorLoopBranchBack); 455 456 DTU.applyUpdates( 457 {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, 458 {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); 459 460 // If we found a mismatch then we need to calculate which lane in the vector 461 // had a mismatch and add that on to the current loop index. 462 Builder.SetInsertPoint(VectorLoopMismatchBlock); 463 PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred"); 464 FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock); 465 PHINode *LastLoopPred = 466 Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred"); 467 LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock); 468 PHINode *VectorFoundIndex = 469 Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index"); 470 VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock); 471 472 Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred); 473 Value *Ctz = Builder.CreateIntrinsic( 474 Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()}, 475 {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)}); 476 Ctz = Builder.CreateZExt(Ctz, I64Type); 477 Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "", 478 /*HasNUW=*/true, /*HasNSW=*/true); 479 return Builder.CreateTrunc(VectorLoopRes64, ResType); 480 } 481 482 
Value *LoopIdiomVectorize::createPredicatedFindMismatch( 483 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, 484 GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { 485 Type *I64Type = Builder.getInt64Ty(); 486 Type *I32Type = Builder.getInt32Ty(); 487 Type *ResType = I32Type; 488 Type *LoadType = Builder.getInt8Ty(); 489 Value *PtrA = GEPA->getPointerOperand(); 490 Value *PtrB = GEPB->getPointerOperand(); 491 492 auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock); 493 Builder.Insert(JumpToVectorLoop); 494 495 DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock, 496 VectorLoopStartBlock}}); 497 498 // Set up the first Vector loop block by creating the PHIs, doing the vector 499 // loads and comparing the vectors. 500 Builder.SetInsertPoint(VectorLoopStartBlock); 501 auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index"); 502 VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock); 503 504 // Calculate AVL by subtracting the vector loop index from the trip count 505 Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true, 506 /*HasNSW=*/true); 507 508 auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF); 509 auto *VF = ConstantInt::get(I32Type, ByteCompareVF); 510 511 Value *VL = Builder.CreateIntrinsic(Intrinsic::experimental_get_vector_length, 512 {I64Type}, {AVL, VF, Builder.getTrue()}); 513 Value *GepOffset = VectorIndexPhi; 514 515 Value *VectorLhsGep = 516 Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds()); 517 VectorType *TrueMaskTy = 518 VectorType::get(Builder.getInt1Ty(), VectorLoadType->getElementCount()); 519 Value *AllTrueMask = Constant::getAllOnesValue(TrueMaskTy); 520 Value *VectorLhsLoad = Builder.CreateIntrinsic( 521 Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()}, 522 {VectorLhsGep, AllTrueMask, VL}, nullptr, "lhs.load"); 523 524 Value *VectorRhsGep = 525 Builder.CreateGEP(LoadType, PtrB, 
GepOffset, "", GEPB->isInBounds()); 526 Value *VectorRhsLoad = Builder.CreateIntrinsic( 527 Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()}, 528 {VectorRhsGep, AllTrueMask, VL}, nullptr, "rhs.load"); 529 530 StringRef PredicateStr = CmpInst::getPredicateName(CmpInst::ICMP_NE); 531 auto *PredicateMDS = MDString::get(VectorLhsLoad->getContext(), PredicateStr); 532 Value *Pred = MetadataAsValue::get(VectorLhsLoad->getContext(), PredicateMDS); 533 Value *VectorMatchCmp = Builder.CreateIntrinsic( 534 Intrinsic::vp_icmp, {VectorLhsLoad->getType()}, 535 {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr, 536 "mismatch.cmp"); 537 Value *CTZ = Builder.CreateIntrinsic( 538 Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType()}, 539 {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(false), AllTrueMask, 540 VL}); 541 Value *MismatchFound = Builder.CreateICmpNE(CTZ, VL); 542 auto *VectorEarlyExit = BranchInst::Create(VectorLoopMismatchBlock, 543 VectorLoopIncBlock, MismatchFound); 544 Builder.Insert(VectorEarlyExit); 545 546 DTU.applyUpdates( 547 {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, 548 {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); 549 550 // Increment the index counter and calculate the predicate for the next 551 // iteration of the loop. We branch back to the start of the loop if there 552 // is at least one active lane. 
553 Builder.SetInsertPoint(VectorLoopIncBlock); 554 Value *VL64 = Builder.CreateZExt(VL, I64Type); 555 Value *NewVectorIndexPhi = 556 Builder.CreateAdd(VectorIndexPhi, VL64, "", 557 /*HasNUW=*/true, /*HasNSW=*/true); 558 VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock); 559 Value *ExitCond = Builder.CreateICmpNE(NewVectorIndexPhi, ExtEnd); 560 auto *VectorLoopBranchBack = 561 BranchInst::Create(VectorLoopStartBlock, EndBlock, ExitCond); 562 Builder.Insert(VectorLoopBranchBack); 563 564 DTU.applyUpdates( 565 {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, 566 {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); 567 568 // If we found a mismatch then we need to calculate which lane in the vector 569 // had a mismatch and add that on to the current loop index. 570 Builder.SetInsertPoint(VectorLoopMismatchBlock); 571 572 // Add LCSSA phis for CTZ and VectorIndexPhi. 573 auto *CTZLCSSAPhi = Builder.CreatePHI(CTZ->getType(), 1, "ctz"); 574 CTZLCSSAPhi->addIncoming(CTZ, VectorLoopStartBlock); 575 auto *VectorIndexLCSSAPhi = 576 Builder.CreatePHI(VectorIndexPhi->getType(), 1, "mismatch_vector_index"); 577 VectorIndexLCSSAPhi->addIncoming(VectorIndexPhi, VectorLoopStartBlock); 578 579 Value *CTZI64 = Builder.CreateZExt(CTZLCSSAPhi, I64Type); 580 Value *VectorLoopRes64 = Builder.CreateAdd(VectorIndexLCSSAPhi, CTZI64, "", 581 /*HasNUW=*/true, /*HasNSW=*/true); 582 return Builder.CreateTrunc(VectorLoopRes64, ResType); 583 } 584 585 Value *LoopIdiomVectorize::expandFindMismatch( 586 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, 587 GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) { 588 Value *PtrA = GEPA->getPointerOperand(); 589 Value *PtrB = GEPB->getPointerOperand(); 590 591 // Get the arguments and types for the intrinsic. 
592 BasicBlock *Preheader = CurLoop->getLoopPreheader(); 593 BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator()); 594 LLVMContext &Ctx = PHBranch->getContext(); 595 Type *LoadType = Type::getInt8Ty(Ctx); 596 Type *ResType = Builder.getInt32Ty(); 597 598 // Split block in the original loop preheader. 599 EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end"); 600 601 // Create the blocks that we're going to need: 602 // 1. A block for checking the zero-extended length exceeds 0 603 // 2. A block to check that the start and end addresses of a given array 604 // lie on the same page. 605 // 3. The vector loop preheader. 606 // 4. The first vector loop block. 607 // 5. The vector loop increment block. 608 // 6. A block we can jump to from the vector loop when a mismatch is found. 609 // 7. The first block of the scalar loop itself, containing PHIs , loads 610 // and cmp. 611 // 8. A scalar loop increment block to increment the PHIs and go back 612 // around the loop. 
613 614 BasicBlock *MinItCheckBlock = BasicBlock::Create( 615 Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock); 616 617 // Update the terminator added by SplitBlock to branch to the first block 618 Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock); 619 620 BasicBlock *MemCheckBlock = BasicBlock::Create( 621 Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock); 622 623 VectorLoopPreheaderBlock = BasicBlock::Create( 624 Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock); 625 626 VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop", 627 EndBlock->getParent(), EndBlock); 628 629 VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc", 630 EndBlock->getParent(), EndBlock); 631 632 VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found", 633 EndBlock->getParent(), EndBlock); 634 635 BasicBlock *LoopPreHeaderBlock = BasicBlock::Create( 636 Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock); 637 638 BasicBlock *LoopStartBlock = 639 BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock); 640 641 BasicBlock *LoopIncBlock = BasicBlock::Create( 642 Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock); 643 644 DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock}, 645 {DominatorTree::Delete, Preheader, EndBlock}}); 646 647 // Update LoopInfo with the new vector & scalar loops. 
648 auto VectorLoop = LI->AllocateLoop(); 649 auto ScalarLoop = LI->AllocateLoop(); 650 651 if (CurLoop->getParentLoop()) { 652 CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI); 653 CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI); 654 CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopPreheaderBlock, 655 *LI); 656 CurLoop->getParentLoop()->addChildLoop(VectorLoop); 657 CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopMismatchBlock, *LI); 658 CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI); 659 CurLoop->getParentLoop()->addChildLoop(ScalarLoop); 660 } else { 661 LI->addTopLevelLoop(VectorLoop); 662 LI->addTopLevelLoop(ScalarLoop); 663 } 664 665 // Add the new basic blocks to their associated loops. 666 VectorLoop->addBasicBlockToLoop(VectorLoopStartBlock, *LI); 667 VectorLoop->addBasicBlockToLoop(VectorLoopIncBlock, *LI); 668 669 ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI); 670 ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI); 671 672 // Set up some types and constants that we intend to reuse. 673 Type *I64Type = Builder.getInt64Ty(); 674 675 // Check the zero-extended iteration count > 0 676 Builder.SetInsertPoint(MinItCheckBlock); 677 Value *ExtStart = Builder.CreateZExt(Start, I64Type); 678 Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type); 679 // This check doesn't really cost us very much. 680 681 Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen); 682 BranchInst *MinItCheckBr = 683 BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck); 684 MinItCheckBr->setMetadata( 685 LLVMContext::MD_prof, 686 MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1)); 687 Builder.Insert(MinItCheckBr); 688 689 DTU.applyUpdates( 690 {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock}, 691 {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}}); 692 693 // For each of the arrays, check the start/end addresses are on the same 694 // page. 
695 Builder.SetInsertPoint(MemCheckBlock); 696 697 // The early exit in the original loop means that when performing vector 698 // loads we are potentially reading ahead of the early exit. So we could 699 // fault if crossing a page boundary. Therefore, we create runtime memory 700 // checks based on the minimum page size as follows: 701 // 1. Calculate the addresses of the first memory accesses in the loop, 702 // i.e. LhsStart and RhsStart. 703 // 2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd. 704 // 3. Determine which pages correspond to all the memory accesses, i.e 705 // LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage. 706 // 4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then 707 // we know we won't cross any page boundaries in the loop so we can 708 // enter the vector loop! Otherwise we fall back on the scalar loop. 709 Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart); 710 Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart); 711 Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type); 712 Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type); 713 Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd); 714 Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd); 715 Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type); 716 Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type); 717 718 const uint64_t MinPageSize = TTI->getMinPageSize().value(); 719 const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize); 720 Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt); 721 Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt); 722 Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt); 723 Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt); 724 Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage); 725 Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage); 726 727 Value *CombinedPageCmp = 
Builder.CreateOr(LhsPageCmp, RhsPageCmp); 728 BranchInst *CombinedPageCmpCmpBr = BranchInst::Create( 729 LoopPreHeaderBlock, VectorLoopPreheaderBlock, CombinedPageCmp); 730 CombinedPageCmpCmpBr->setMetadata( 731 LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext()) 732 .createBranchWeights(10, 90)); 733 Builder.Insert(CombinedPageCmpCmpBr); 734 735 DTU.applyUpdates( 736 {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock}, 737 {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}}); 738 739 // Set up the vector loop preheader, i.e. calculate initial loop predicate, 740 // zero-extend MaxLen to 64-bits, determine the number of vector elements 741 // processed in each iteration, etc. 742 Builder.SetInsertPoint(VectorLoopPreheaderBlock); 743 744 // At this point we know two things must be true: 745 // 1. Start <= End 746 // 2. ExtMaxLen <= MinPageSize due to the page checks. 747 // Therefore, we know that we can use a 64-bit induction variable that 748 // starts from 0 -> ExtMaxLen and it will not overflow. 749 Value *VectorLoopRes = nullptr; 750 switch (VectorizeStyle) { 751 case LoopIdiomVectorizeStyle::Masked: 752 VectorLoopRes = 753 createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd); 754 break; 755 case LoopIdiomVectorizeStyle::Predicated: 756 VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB, 757 ExtStart, ExtEnd); 758 break; 759 } 760 761 Builder.Insert(BranchInst::Create(EndBlock)); 762 763 DTU.applyUpdates( 764 {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}}); 765 766 // Generate code for scalar loop. 
767 Builder.SetInsertPoint(LoopPreHeaderBlock); 768 Builder.Insert(BranchInst::Create(LoopStartBlock)); 769 770 DTU.applyUpdates( 771 {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}}); 772 773 Builder.SetInsertPoint(LoopStartBlock); 774 PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index"); 775 IndexPhi->addIncoming(Start, LoopPreHeaderBlock); 776 777 // Otherwise compare the values 778 // Load bytes from each array and compare them. 779 Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type); 780 781 Value *LhsGep = 782 Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds()); 783 Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep); 784 785 Value *RhsGep = 786 Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds()); 787 Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep); 788 789 Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad); 790 // If we have a mismatch then exit the loop ... 791 BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp); 792 Builder.Insert(MatchCmpBr); 793 794 DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock}, 795 {DominatorTree::Insert, LoopStartBlock, EndBlock}}); 796 797 // Have we reached the maximum permitted length for the loop? 798 Builder.SetInsertPoint(LoopIncBlock); 799 Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "", 800 /*HasNUW=*/Index->hasNoUnsignedWrap(), 801 /*HasNSW=*/Index->hasNoSignedWrap()); 802 IndexPhi->addIncoming(PhiInc, LoopIncBlock); 803 Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen); 804 BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp); 805 Builder.Insert(IVCmpBr); 806 807 DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock}, 808 {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}}); 809 810 // In the end block we need to insert a PHI node to deal with three cases: 811 // 1. 
We didn't find a mismatch in the scalar loop, so we return MaxLen. 812 // 2. We exitted the scalar loop early due to a mismatch and need to return 813 // the index that we found. 814 // 3. We didn't find a mismatch in the vector loop, so we return MaxLen. 815 // 4. We exitted the vector loop early due to a mismatch and need to return 816 // the index that we found. 817 Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt()); 818 PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result"); 819 ResPhi->addIncoming(MaxLen, LoopIncBlock); 820 ResPhi->addIncoming(IndexPhi, LoopStartBlock); 821 ResPhi->addIncoming(MaxLen, VectorLoopIncBlock); 822 ResPhi->addIncoming(VectorLoopRes, VectorLoopMismatchBlock); 823 824 Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType); 825 826 if (VerifyLoops) { 827 ScalarLoop->verifyLoop(); 828 VectorLoop->verifyLoop(); 829 if (!VectorLoop->isRecursivelyLCSSAForm(*DT, *LI)) 830 report_fatal_error("Loops must remain in LCSSA form!"); 831 if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI)) 832 report_fatal_error("Loops must remain in LCSSA form!"); 833 } 834 835 return FinalRes; 836 } 837 838 void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA, 839 GetElementPtrInst *GEPB, 840 PHINode *IndPhi, Value *MaxLen, 841 Instruction *Index, Value *Start, 842 bool IncIdx, BasicBlock *FoundBB, 843 BasicBlock *EndBB) { 844 845 // Insert the byte compare code at the end of the preheader block 846 BasicBlock *Preheader = CurLoop->getLoopPreheader(); 847 BasicBlock *Header = CurLoop->getHeader(); 848 BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator()); 849 IRBuilder<> Builder(PHBranch); 850 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 851 Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc()); 852 853 // Increment the pointer if this was done before the loads in the loop. 
854 if (IncIdx) 855 Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1)); 856 857 Value *ByteCmpRes = 858 expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen); 859 860 // Replaces uses of index & induction Phi with intrinsic (we already 861 // checked that the the first instruction of Header is the Phi above). 862 assert(IndPhi->hasOneUse() && "Index phi node has more than one use!"); 863 Index->replaceAllUsesWith(ByteCmpRes); 864 865 assert(PHBranch->isUnconditional() && 866 "Expected preheader to terminate with an unconditional branch."); 867 868 // If no mismatch was found, we can jump to the end block. Create a 869 // new basic block for the compare instruction. 870 auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare", 871 Preheader->getParent()); 872 CmpBB->moveBefore(EndBB); 873 874 // Replace the branch in the preheader with an always-true conditional branch. 875 // This ensures there is still a reference to the original loop. 876 Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header); 877 PHBranch->eraseFromParent(); 878 879 BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent(); 880 DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}}); 881 882 // Create the branch to either the end or found block depending on the value 883 // returned by the intrinsic. 884 Builder.SetInsertPoint(CmpBB); 885 if (FoundBB != EndBB) { 886 Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen); 887 Builder.CreateCondBr(FoundCmp, EndBB, FoundBB); 888 DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}, 889 {DominatorTree::Insert, CmpBB, EndBB}}); 890 891 } else { 892 Builder.CreateBr(FoundBB); 893 DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}}); 894 } 895 896 auto fixSuccessorPhis = [&](BasicBlock *SuccBB) { 897 for (PHINode &PN : SuccBB->phis()) { 898 // At this point we've already replaced all uses of the result from the 899 // loop with ByteCmp. 
Look through the incoming values to find ByteCmp, 900 // meaning this is a Phi collecting the results of the byte compare. 901 bool ResPhi = false; 902 for (Value *Op : PN.incoming_values()) 903 if (Op == ByteCmpRes) { 904 ResPhi = true; 905 break; 906 } 907 908 // Any PHI that depended upon the result of the byte compare needs a new 909 // incoming value from CmpBB. This is because the original loop will get 910 // deleted. 911 if (ResPhi) 912 PN.addIncoming(ByteCmpRes, CmpBB); 913 else { 914 // There should be no other outside uses of other values in the 915 // original loop. Any incoming values should either: 916 // 1. Be for blocks outside the loop, which aren't interesting. Or .. 917 // 2. These are from blocks in the loop with values defined outside 918 // the loop. We should a similar incoming value from CmpBB. 919 for (BasicBlock *BB : PN.blocks()) 920 if (CurLoop->contains(BB)) { 921 PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB); 922 break; 923 } 924 } 925 } 926 }; 927 928 // Ensure all Phis in the successors of CmpBB have an incoming value from it. 929 fixSuccessorPhis(EndBB); 930 if (EndBB != FoundBB) 931 fixSuccessorPhis(FoundBB); 932 933 // The new CmpBB block isn't part of the loop, but will need to be added to 934 // the outer loop if there is one. 935 if (!CurLoop->isOutermost()) 936 CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI); 937 938 if (VerifyLoops && CurLoop->getParentLoop()) { 939 CurLoop->getParentLoop()->verifyLoop(); 940 if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI)) 941 report_fatal_error("Loops must remain in LCSSA form!"); 942 } 943 } 944