//===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass recognizes certain loop idioms and transforms them into more
// optimized versions of the same loop. In cases where this happens, it can be
// a significant performance win.
//
// We currently only recognize one idiom, namely a loop that finds the first
// mismatched byte in an array and returns the index, i.e. something like:
//
//  while (++i != n) {
//    if (a[i] != b[i])
//      break;
//  }
//
// In this example we can actually vectorize the loop despite the early exit,
// although the loop vectorizer does not support it. It requires some extra
// checks to deal with the possibility of faulting loads when crossing page
// boundaries. However, even with these checks it is still profitable to do the
// transformation.
//
//===----------------------------------------------------------------------===//
//
// NOTE: This pass matches a really specific loop pattern because it's only
// supposed to be a temporary solution until our LoopVectorizer is powerful
// enough to vectorize it automatically.
//
// TODO List:
//
// * Add support for the inverse case where we scan for a matching element.
// * Permit 64-bit induction variable types.
// * Recognize loops that increment the IV *after* comparing bytes.
// * Allow 32-bit sign-extends of the IV used by the GEP.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "loop-idiom-vectorize"

static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
                                cl::init(false),
                                cl::desc("Disable Loop Idiom Vectorize Pass."));

static cl::opt<LoopIdiomVectorizeStyle>
    LITVecStyle("loop-idiom-vectorize-style", cl::Hidden,
                cl::desc("The vectorization style for loop idiom transform."),
                cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked",
                                      "Use masked vector intrinsics"),
                           clEnumValN(LoopIdiomVectorizeStyle::Predicated,
                                      "predicated", "Use VP intrinsics")),
                cl::init(LoopIdiomVectorizeStyle::Masked));

static cl::opt<bool>
    DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
                   cl::init(false),
                   cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
                            "not convert byte-compare loop(s)."));

static cl::opt<unsigned>
    ByteCmpVF("loop-idiom-vectorize-bytecmp-vf", cl::Hidden,
              cl::desc("The vectorization factor for byte-compare patterns."),
              cl::init(16));

static cl::opt<bool>
    VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
                cl::desc("Verify loops generated by the Loop Idiom Vectorize "
                         "Pass."));

namespace {
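// Conceptually, for the default "masked" style the generated replacement for
// the byte-compare idiom above looks roughly like the following pseudocode.
// This is illustrative only; see createMaskedFindMismatch and
// expandFindMismatch below for the exact IR that is emitted, including the
// runtime page checks that are omitted here:
//
//   i = start;
//   while (true) {
//     pred = get.active.lane.mask(i, n);       // lanes where i + lane < n
//     va   = masked.load(&a[i], pred);
//     vb   = masked.load(&b[i], pred);
//     ne   = select(pred, va != vb, false);
//     if (or.reduce(ne))
//       return i + cttz.elts(ne);              // index of the first mismatch
//     i += vscale * VF;
//     if (!extractelement(get.active.lane.mask(i, n), 0))
//       return n;                              // no mismatch found
//   }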
class LoopIdiomVectorize {
  LoopIdiomVectorizeStyle VectorizeStyle;
  unsigned ByteCompareVF;
  Loop *CurLoop = nullptr;
  DominatorTree *DT;
  LoopInfo *LI;
  const TargetTransformInfo *TTI;
  const DataLayout *DL;

  // Blocks that will be used for inserting vectorized code.
  BasicBlock *EndBlock = nullptr;
  BasicBlock *VectorLoopPreheaderBlock = nullptr;
  BasicBlock *VectorLoopStartBlock = nullptr;
  BasicBlock *VectorLoopMismatchBlock = nullptr;
  BasicBlock *VectorLoopIncBlock = nullptr;

public:
  LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
                     LoopInfo *LI, const TargetTransformInfo *TTI,
                     const DataLayout *DL)
      : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
  }

  bool run(Loop *L);

private:
  /// \name Countable Loop Idiom Handling
  /// @{

  bool runOnCountableLoop();
  bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
                      SmallVectorImpl<BasicBlock *> &ExitBlocks);

  bool recognizeByteCompare();

  Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                            GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            Instruction *Index, Value *Start, Value *MaxLen);

  Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                                  GetElementPtrInst *GEPA,
                                  GetElementPtrInst *GEPB, Value *ExtStart,
                                  Value *ExtEnd);
  Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                                      GetElementPtrInst *GEPA,
                                      GetElementPtrInst *GEPB, Value *ExtStart,
                                      Value *ExtEnd);

  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            PHINode *IndPhi, Value *MaxLen, Instruction *Index,
                            Value *Start, bool IncIdx, BasicBlock *FoundBB,
                            BasicBlock *EndBB);
  /// @}
};
} // anonymous namespace

PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
                                              LoopStandardAnalysisResults &AR,
                                              LPMUpdater &) {
  if (DisableAll)
    return PreservedAnalyses::all();

  const auto *DL = &L.getHeader()->getDataLayout();

  LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
  if (LITVecStyle.getNumOccurrences())
    VecStyle = LITVecStyle;

  unsigned BCVF = ByteCompareVF;
  if (ByteCmpVF.getNumOccurrences())
    BCVF = ByteCmpVF;

  LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL);
  if (!LIV.run(&L))
    return PreservedAnalyses::all();

  return PreservedAnalyses::none();
}

//===----------------------------------------------------------------------===//
//
// Implementation of LoopIdiomVectorize
//
//===----------------------------------------------------------------------===//

bool LoopIdiomVectorize::run(Loop *L) {
  CurLoop = L;

  Function &F = *L->getHeader()->getParent();
  if (DisableAll || F.hasOptSize())
    return false;

  if (F.hasFnAttribute(Attribute::NoImplicitFloat)) {
    LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName()
                      << " due to its NoImplicitFloat attribute");
    return false;
  }

  // If the loop could not be converted to canonical form, it must have an
  // indirectbr in it, just give up.
  if (!L->getLoopPreheader())
    return false;

  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
                    << CurLoop->getHeader()->getName() << "\n");

  return recognizeByteCompare();
}

bool LoopIdiomVectorize::recognizeByteCompare() {
  // Currently the transformation only works on scalable vector types, although
  // there is no fundamental reason why it cannot be made to work for fixed
  // width too.

  // We also need to know the minimum page size for the target in order to
  // generate runtime memory checks to ensure the vector version won't fault.
  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
      DisableByteCmp)
    return false;

  BasicBlock *Header = CurLoop->getHeader();

  // In LoopIdiomVectorize::run we have already checked that the loop
  // has a preheader so we can assume it's in a canonical form.
  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
    return false;

  PHINode *PN = dyn_cast<PHINode>(&Header->front());
  if (!PN || PN->getNumIncomingValues() != 2)
    return false;

  auto LoopBlocks = CurLoop->getBlocks();
  // The first block in the loop should contain only 4 instructions, e.g.
  //
  //  while.cond:
  //    %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
  //    %inc = add i32 %res.phi, 1
  //    %cmp.not = icmp eq i32 %inc, %n
  //    br i1 %cmp.not, label %while.end, label %while.body
  //
  if (LoopBlocks[0]->sizeWithoutDebug() > 4)
    return false;

  // The second block should contain 7 instructions, e.g.
  //
  //  while.body:
  //    %idx = zext i32 %inc to i64
  //    %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
  //    %load.a = load i8, ptr %idx.a
  //    %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
  //    %load.b = load i8, ptr %idx.b
  //    %cmp.not.ld = icmp eq i8 %load.a, %load.b
  //    br i1 %cmp.not.ld, label %while.cond, label %while.end
  //
  if (LoopBlocks[1]->sizeWithoutDebug() > 7)
    return false;

  // The incoming value to the PHI node from the loop should be an add of 1.
  Value *StartIdx = nullptr;
  Instruction *Index = nullptr;
  if (!CurLoop->contains(PN->getIncomingBlock(0))) {
    StartIdx = PN->getIncomingValue(0);
    Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
  } else {
    StartIdx = PN->getIncomingValue(1);
    Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
  }

  // Limit to 32-bit types for now.
  if (!Index || !Index->getType()->isIntegerTy(32) ||
      !match(Index, m_c_Add(m_Specific(PN), m_One())))
    return false;

  // If we match the pattern, PN and Index will be replaced with the result of
  // the cttz.elts intrinsic. If any other instructions are used outside of
  // the loop, we cannot replace it.
  for (BasicBlock *BB : LoopBlocks)
    for (Instruction &I : *BB)
      if (&I != PN && &I != Index)
        for (User *U : I.users())
          if (!CurLoop->contains(cast<Instruction>(U)))
            return false;

  // Match the branch instruction for the header.
  Value *MaxLen;
  BasicBlock *EndBB, *WhileBB;
  if (!match(Header->getTerminator(),
             m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(Index),
                                 m_Value(MaxLen)),
                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
      !CurLoop->contains(WhileBB))
    return false;

  // WhileBB should contain the pattern of load & compare instructions. Match
  // the pattern and find the GEP instructions used by the loads.
  BasicBlock *FoundBB;
  BasicBlock *TrueBB;
  Value *LoadA, *LoadB;
  if (!match(WhileBB->getTerminator(),
             m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(LoadA),
                                 m_Value(LoadB)),
                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
      !CurLoop->contains(TrueBB))
    return false;

  Value *A, *B;
  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
    return false;

  LoadInst *LoadAI = cast<LoadInst>(LoadA);
  LoadInst *LoadBI = cast<LoadInst>(LoadB);
  if (!LoadAI->isSimple() || !LoadBI->isSimple())
    return false;

  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);

  if (!GEPA || !GEPB)
    return false;

  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // Check we are loading i8 values from two loop-invariant pointers.
  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
      !GEPA->getResultElementType()->isIntegerTy(8) ||
      !GEPB->getResultElementType()->isIntegerTy(8) ||
      !LoadAI->getType()->isIntegerTy(8) ||
      !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
    return false;

  // Check that the index to the GEPs is the index we found earlier.
  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
    return false;

  Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
  Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
    return false;

  // We only ever expect the pre-incremented index value to be used inside the
  // loop.
  if (!PN->hasOneUse())
    return false;

  // Ensure that when the Found and End blocks are identical the PHIs have the
  // supported format. We don't currently allow cases like this:
  //
  //  while.cond:
  //    ...
  //    br i1 %cmp.not, label %while.end, label %while.body
  //
  //  while.body:
  //    ...
  //    br i1 %cmp.not2, label %while.cond, label %while.end
  //
  //  while.end:
  //    %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
  //
  // where the incoming values for %final_ptr are unique and come from each of
  // the loop blocks, but are not actually defined in the loop. This requires
  // extra work setting up the byte.compare block, i.e. by introducing a select
  // to choose the correct value.
  // TODO: We could add support for this in future.
  if (FoundBB == EndBB) {
    for (PHINode &EndPN : EndBB->phis()) {
      Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
      Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);

      // The value of the index when leaving the while.cond block is always the
      // same as the end value (MaxLen) so we permit either. The value when
      // leaving the while.body block should only be the index. Otherwise for
      // any other values we only allow ones that are the same for both blocks.
      if (WhileCondVal != WhileBodyVal &&
          ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
           (WhileBodyVal != Index)))
        return false;
    }
  }

  LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
                    << *(EndBB->getParent()) << "\n\n");

  // The index is incremented before the GEP/Load pair, so we need to
  // add 1 to the start value.
  transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true,
                       FoundBB, EndBB);
  return true;
}

Value *LoopIdiomVectorize::createMaskedFindMismatch(
    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
    GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
  Type *I64Type = Builder.getInt64Ty();
  Type *ResType = Builder.getInt32Ty();
  Type *LoadType = Builder.getInt8Ty();
  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  ScalableVectorType *PredVTy =
      ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF);

  Value *InitialPred = Builder.CreateIntrinsic(
      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});

  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
  VecLen =
      Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
                        /*HasNUW=*/true, /*HasNSW=*/true);

  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
                                            Builder.getInt1(false));

  BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
  Builder.Insert(JumpToVectorLoop);

  DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
                     VectorLoopStartBlock}});

  // Set up the first vector loop block by creating the PHIs, doing the vector
  // loads and comparing the vectors.
  Builder.SetInsertPoint(VectorLoopStartBlock);
  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
  LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
  PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
  Type *VectorLoadType =
      ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF);
  Value *Passthru = ConstantInt::getNullValue(VectorLoadType);

  Value *VectorLhsGep =
      Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
  Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
                                                  Align(1), LoopPred, Passthru);

  Value *VectorRhsGep =
      Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
  Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
                                                  Align(1), LoopPred, Passthru);

  Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
  VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
  Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
  BranchInst *VectorEarlyExit = BranchInst::Create(
      VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
  Builder.Insert(VectorEarlyExit);

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
       {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});

  // Increment the index counter and calculate the predicate for the next
  // iteration of the loop. We branch back to the start of the loop if there
  // is at least one active lane.
  Builder.SetInsertPoint(VectorLoopIncBlock);
  Value *NewVectorIndexPhi =
      Builder.CreateAdd(VectorIndexPhi, VecLen, "",
                        /*HasNUW=*/true, /*HasNSW=*/true);
  VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
  Value *NewPred =
      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
                              {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
  LoopPred->addIncoming(NewPred, VectorLoopIncBlock);

  Value *PredHasActiveLanes =
      Builder.CreateExtractElement(NewPred, uint64_t(0));
  BranchInst *VectorLoopBranchBack =
      BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
  Builder.Insert(VectorLoopBranchBack);

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
       {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});

  // If we found a mismatch then we need to calculate which lane in the vector
  // had a mismatch and add that on to the current loop index.
  Builder.SetInsertPoint(VectorLoopMismatchBlock);
  PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
  FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
  PHINode *LastLoopPred =
      Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
  LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
  PHINode *VectorFoundIndex =
      Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
  VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);

  Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
  Value *Ctz = Builder.CreateCountTrailingZeroElems(ResType, PredMatchCmp);
  Ctz = Builder.CreateZExt(Ctz, I64Type);
  Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
                                             /*HasNUW=*/true, /*HasNSW=*/true);
  return Builder.CreateTrunc(VectorLoopRes64, ResType);
}

Value *LoopIdiomVectorize::createPredicatedFindMismatch(
    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
    GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
  Type *I64Type = Builder.getInt64Ty();
  Type *I32Type = Builder.getInt32Ty();
  Type *ResType = I32Type;
  Type *LoadType = Builder.getInt8Ty();
  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
  Builder.Insert(JumpToVectorLoop);

  DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
                     VectorLoopStartBlock}});

  // Set up the first vector loop block by creating the PHIs, doing the vector
  // loads and comparing the vectors.
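  // Conceptually, each iteration of the VP-based loop built below does roughly
  // the following (illustrative pseudo-IR only, not the exact instruction
  // sequence or value names used in this function):
  //
  //   %avl = sub i64 %ext.end, %index
  //   %vl  = get.vector.length(%avl, VF, /*scalable=*/true)
  //   %va  = vp.load(%a + %index, all-true mask, %vl)
  //   %vb  = vp.load(%b + %index, all-true mask, %vl)
  //   %ne  = vp.icmp ne %va, %vb, all-true mask, %vl
  //   %ctz = vp.cttz.elts(%ne, all-true mask, %vl)
  //   ; a mismatch was found iff %ctz != %vl; otherwise %index += %vl and we
  //   ; loop again until %index == %ext.end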
  Builder.SetInsertPoint(VectorLoopStartBlock);
  auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index");
  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);

  // Calculate AVL by subtracting the vector loop index from the trip count.
  Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true,
                                 /*HasNSW=*/true);

  auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF);
  auto *VF = ConstantInt::get(I32Type, ByteCompareVF);

  Value *VL = Builder.CreateIntrinsic(Intrinsic::experimental_get_vector_length,
                                      {I64Type}, {AVL, VF, Builder.getTrue()});
  Value *GepOffset = VectorIndexPhi;

  Value *VectorLhsGep =
      Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
  VectorType *TrueMaskTy =
      VectorType::get(Builder.getInt1Ty(), VectorLoadType->getElementCount());
  Value *AllTrueMask = Constant::getAllOnesValue(TrueMaskTy);
  Value *VectorLhsLoad = Builder.CreateIntrinsic(
      Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
      {VectorLhsGep, AllTrueMask, VL}, nullptr, "lhs.load");

  Value *VectorRhsGep =
      Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
  Value *VectorRhsLoad = Builder.CreateIntrinsic(
      Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
      {VectorRhsGep, AllTrueMask, VL}, nullptr, "rhs.load");

  StringRef PredicateStr = CmpInst::getPredicateName(CmpInst::ICMP_NE);
  auto *PredicateMDS = MDString::get(VectorLhsLoad->getContext(), PredicateStr);
  Value *Pred = MetadataAsValue::get(VectorLhsLoad->getContext(), PredicateMDS);
  Value *VectorMatchCmp = Builder.CreateIntrinsic(
      Intrinsic::vp_icmp, {VectorLhsLoad->getType()},
      {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr,
      "mismatch.cmp");
  Value *CTZ = Builder.CreateIntrinsic(
      Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType()},
      {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(false), AllTrueMask,
       VL});
  Value *MismatchFound = Builder.CreateICmpNE(CTZ, VL);
  auto *VectorEarlyExit = BranchInst::Create(VectorLoopMismatchBlock,
                                             VectorLoopIncBlock, MismatchFound);
  Builder.Insert(VectorEarlyExit);

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
       {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});

  // Increment the index counter by the number of elements processed in this
  // iteration and branch back to the start of the loop if there are still
  // elements left to process.
  Builder.SetInsertPoint(VectorLoopIncBlock);
  Value *VL64 = Builder.CreateZExt(VL, I64Type);
  Value *NewVectorIndexPhi =
      Builder.CreateAdd(VectorIndexPhi, VL64, "",
                        /*HasNUW=*/true, /*HasNSW=*/true);
  VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
  Value *ExitCond = Builder.CreateICmpNE(NewVectorIndexPhi, ExtEnd);
  auto *VectorLoopBranchBack =
      BranchInst::Create(VectorLoopStartBlock, EndBlock, ExitCond);
  Builder.Insert(VectorLoopBranchBack);

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
       {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});

  // If we found a mismatch then we need to calculate which lane in the vector
  // had a mismatch and add that on to the current loop index.
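  // For example (illustrative numbers only): if the vector loop index was 64
  // when the early exit was taken and vp.cttz.elts reported the first
  // mismatching lane as 3, the mismatch index returned from this block is
  // 64 + 3 = 67.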
  Builder.SetInsertPoint(VectorLoopMismatchBlock);

  // Add LCSSA phis for CTZ and VectorIndexPhi.
  auto *CTZLCSSAPhi = Builder.CreatePHI(CTZ->getType(), 1, "ctz");
  CTZLCSSAPhi->addIncoming(CTZ, VectorLoopStartBlock);
  auto *VectorIndexLCSSAPhi =
      Builder.CreatePHI(VectorIndexPhi->getType(), 1, "mismatch_vector_index");
  VectorIndexLCSSAPhi->addIncoming(VectorIndexPhi, VectorLoopStartBlock);

  Value *CTZI64 = Builder.CreateZExt(CTZLCSSAPhi, I64Type);
  Value *VectorLoopRes64 = Builder.CreateAdd(VectorIndexLCSSAPhi, CTZI64, "",
                                             /*HasNUW=*/true, /*HasNSW=*/true);
  return Builder.CreateTrunc(VectorLoopRes64, ResType);
}

Value *LoopIdiomVectorize::expandFindMismatch(
    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
    GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // Get the arguments and types for the intrinsic.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  LLVMContext &Ctx = PHBranch->getContext();
  Type *LoadType = Type::getInt8Ty(Ctx);
  Type *ResType = Builder.getInt32Ty();

  // Split block in the original loop preheader.
  EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");

  // Create the blocks that we're going to need:
  //  1. A block for checking that the zero-extended length exceeds 0.
  //  2. A block to check that the start and end addresses of a given array
  //     lie on the same page.
  //  3. The vector loop preheader.
  //  4. The first vector loop block.
  //  5. The vector loop increment block.
  //  6. A block we can jump to from the vector loop when a mismatch is found.
  //  7. The first block of the scalar loop itself, containing PHIs, loads
  //     and cmp.
  //  8. A scalar loop increment block to increment the PHIs and go back
  //     around the loop.

  BasicBlock *MinItCheckBlock = BasicBlock::Create(
      Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);

  // Update the terminator added by SplitBlock to branch to the first block.
  Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);

  BasicBlock *MemCheckBlock = BasicBlock::Create(
      Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);

  VectorLoopPreheaderBlock = BasicBlock::Create(
      Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock);

  VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop",
                                            EndBlock->getParent(), EndBlock);

  VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc",
                                          EndBlock->getParent(), EndBlock);

  VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found",
                                               EndBlock->getParent(), EndBlock);

  BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
      Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopStartBlock =
      BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopIncBlock = BasicBlock::Create(
      Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);

  DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
                    {DominatorTree::Delete, Preheader, EndBlock}});

  // Update LoopInfo with the new vector & scalar loops.
  auto VectorLoop = LI->AllocateLoop();
  auto ScalarLoop = LI->AllocateLoop();

  if (CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopPreheaderBlock,
                                                  *LI);
    CurLoop->getParentLoop()->addChildLoop(VectorLoop);
    CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopMismatchBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
    CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
  } else {
    LI->addTopLevelLoop(VectorLoop);
    LI->addTopLevelLoop(ScalarLoop);
  }

  // Add the new basic blocks to their associated loops.
  VectorLoop->addBasicBlockToLoop(VectorLoopStartBlock, *LI);
  VectorLoop->addBasicBlockToLoop(VectorLoopIncBlock, *LI);

  ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
  ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);

  // Set up some types and constants that we intend to reuse.
  Type *I64Type = Builder.getInt64Ty();

  // Check that the zero-extended iteration count > 0.
  Builder.SetInsertPoint(MinItCheckBlock);
  Value *ExtStart = Builder.CreateZExt(Start, I64Type);
  Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);

  // This check doesn't really cost us very much.
  Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
  BranchInst *MinItCheckBr =
      BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
  MinItCheckBr->setMetadata(
      LLVMContext::MD_prof,
      MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
  Builder.Insert(MinItCheckBr);

  DTU.applyUpdates(
      {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
       {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});

  // For each of the arrays, check that the start/end addresses are on the same
  // page.
  Builder.SetInsertPoint(MemCheckBlock);

  // The early exit in the original loop means that when performing vector
  // loads we are potentially reading ahead of the early exit. So we could
  // fault if crossing a page boundary. Therefore, we create runtime memory
  // checks based on the minimum page size as follows:
  //  1. Calculate the addresses of the first memory accesses in the loop,
  //     i.e. LhsStart and RhsStart.
  //  2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
  //  3. Determine which pages correspond to all the memory accesses, i.e.
  //     LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
  //  4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
  //     we know we won't cross any page boundaries in the loop so we can
  //     enter the vector loop! Otherwise we fall back on the scalar loop.
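  // As a worked example, assuming (purely for illustration) a 4096-byte
  // minimum page size: an address is mapped to its page by shifting it right
  // by log2(4096) = 12 bits, so the vector loop is only entered when
  //   (LhsStart >> 12) == (LhsEnd >> 12) && (RhsStart >> 12) == (RhsEnd >> 12).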
  Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart);
  Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart);
  Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type);
  Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type);
  Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd);
  Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd);
  Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type);
  Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type);

  const uint64_t MinPageSize = TTI->getMinPageSize().value();
  const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
  Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt);
  Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt);
  Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt);
  Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt);
  Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage);
  Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage);

  Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp);
  BranchInst *CombinedPageCmpCmpBr = BranchInst::Create(
      LoopPreHeaderBlock, VectorLoopPreheaderBlock, CombinedPageCmp);
  CombinedPageCmpCmpBr->setMetadata(
      LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext())
                                .createBranchWeights(10, 90));
  Builder.Insert(CombinedPageCmpCmpBr);

  DTU.applyUpdates(
      {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
       {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}});

  // Set up the vector loop preheader, i.e. calculate initial loop predicate,
  // zero-extend MaxLen to 64-bits, determine the number of vector elements
  // processed in each iteration, etc.
  Builder.SetInsertPoint(VectorLoopPreheaderBlock);

  // At this point we know two things must be true:
  //  1. Start <= End
  //  2. ExtMaxLen <= MinPageSize due to the page checks.
  // Therefore, we know that we can use a 64-bit induction variable that
  // runs from 0 to ExtMaxLen and will not overflow.
  Value *VectorLoopRes = nullptr;
  switch (VectorizeStyle) {
  case LoopIdiomVectorizeStyle::Masked:
    VectorLoopRes =
        createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
    break;
  case LoopIdiomVectorizeStyle::Predicated:
    VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB,
                                                 ExtStart, ExtEnd);
    break;
  }

  Builder.Insert(BranchInst::Create(EndBlock));

  DTU.applyUpdates(
      {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}});

  // Generate code for the scalar loop.
  Builder.SetInsertPoint(LoopPreHeaderBlock);
  Builder.Insert(BranchInst::Create(LoopStartBlock));

  DTU.applyUpdates(
      {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});

  Builder.SetInsertPoint(LoopStartBlock);
  PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
  IndexPhi->addIncoming(Start, LoopPreHeaderBlock);

  // Load bytes from each array and compare them.
  Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type);

  Value *LhsGep =
      Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
  Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep);

  Value *RhsGep =
      Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
  Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);

  Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
  // If we have a mismatch then exit the loop ...
  BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
  Builder.Insert(MatchCmpBr);

  DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
                    {DominatorTree::Insert, LoopStartBlock, EndBlock}});

  // Have we reached the maximum permitted length for the loop?
  Builder.SetInsertPoint(LoopIncBlock);
  Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "",
                                    /*HasNUW=*/Index->hasNoUnsignedWrap(),
                                    /*HasNSW=*/Index->hasNoSignedWrap());
  IndexPhi->addIncoming(PhiInc, LoopIncBlock);
  Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen);
  BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp);
  Builder.Insert(IVCmpBr);

  DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock},
                    {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});

  // In the end block we need to insert a PHI node to deal with four cases:
  //  1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
  //  2. We exited the scalar loop early due to a mismatch and need to return
  //     the index that we found.
  //  3. We didn't find a mismatch in the vector loop, so we return MaxLen.
  //  4. We exited the vector loop early due to a mismatch and need to return
  //     the index that we found.
  Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt());
  PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result");
  ResPhi->addIncoming(MaxLen, LoopIncBlock);
  ResPhi->addIncoming(IndexPhi, LoopStartBlock);
  ResPhi->addIncoming(MaxLen, VectorLoopIncBlock);
  ResPhi->addIncoming(VectorLoopRes, VectorLoopMismatchBlock);

  Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType);

  if (VerifyLoops) {
    ScalarLoop->verifyLoop();
    VectorLoop->verifyLoop();
    if (!VectorLoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
    if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }

  return FinalRes;
}

void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
                                              GetElementPtrInst *GEPB,
                                              PHINode *IndPhi, Value *MaxLen,
                                              Instruction *Index, Value *Start,
                                              bool IncIdx, BasicBlock *FoundBB,
                                              BasicBlock *EndBB) {

  // Insert the byte compare code at the end of the preheader block.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BasicBlock *Header = CurLoop->getHeader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  IRBuilder<> Builder(PHBranch);
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());

  // Increment the start index if the IV was incremented before the loads in
  // the loop.
  if (IncIdx)
    Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));

  Value *ByteCmpRes =
      expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);

  // Replace uses of the index & induction Phi with the byte-compare result (we
  // already checked that the first instruction of Header is the Phi above).
  assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
  Index->replaceAllUsesWith(ByteCmpRes);

  assert(PHBranch->isUnconditional() &&
         "Expected preheader to terminate with an unconditional branch.");

  // If no mismatch was found, we can jump to the end block. Create a
  // new basic block for the compare instruction.
  auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare",
                                   Preheader->getParent());
  CmpBB->moveBefore(EndBB);

  // Replace the branch in the preheader with an always-true conditional branch.
  // This ensures there is still a reference to the original loop.
  Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header);
  PHBranch->eraseFromParent();

  BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent();
  DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}});

  // Create the branch to either the end or found block depending on the value
  // returned by the intrinsic.
  Builder.SetInsertPoint(CmpBB);
  if (FoundBB != EndBB) {
    Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen);
    Builder.CreateCondBr(FoundCmp, EndBB, FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB},
                      {DominatorTree::Insert, CmpBB, EndBB}});
  } else {
    Builder.CreateBr(FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}});
  }

  auto fixSuccessorPhis = [&](BasicBlock *SuccBB) {
    for (PHINode &PN : SuccBB->phis()) {
      // At this point we've already replaced all uses of the result from the
      // loop with ByteCmpRes. Look through the incoming values to find
      // ByteCmpRes, meaning this is a Phi collecting the results of the byte
      // compare.
      bool ResPhi = false;
      for (Value *Op : PN.incoming_values())
        if (Op == ByteCmpRes) {
          ResPhi = true;
          break;
        }

      // Any PHI that depended upon the result of the byte compare needs a new
      // incoming value from CmpBB. This is because the original loop will get
      // deleted.
      if (ResPhi)
        PN.addIncoming(ByteCmpRes, CmpBB);
      else {
        // There should be no other outside uses of other values in the
        // original loop. Any incoming values should either:
        //  1. Be for blocks outside the loop, which aren't interesting, or
        //  2. Be from blocks in the loop with values defined outside the
        //     loop. We should add a similar incoming value from CmpBB.
        for (BasicBlock *BB : PN.blocks())
          if (CurLoop->contains(BB)) {
            PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB);
            break;
          }
      }
    }
  };

  // Ensure all Phis in the successors of CmpBB have an incoming value from it.
  fixSuccessorPhis(EndBB);
  if (EndBB != FoundBB)
    fixSuccessorPhis(FoundBB);

  // The new CmpBB block isn't part of the loop, but will need to be added to
  // the outer loop if there is one.
  if (!CurLoop->isOutermost())
    CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI);

  if (VerifyLoops && CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->verifyLoop();
    if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }
}