xref: /llvm-project/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp (revision de5ff38a0d20faedac43a1d838fb65b67e77c34e)
1 //===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass that recognizes certain loop idioms and
10 // transforms them into more optimized versions of the same loop. In cases
11 // where this happens, it can be a significant performance win.
12 //
13 // We currently only recognize one loop that finds the first mismatched byte
14 // in an array and returns the index, i.e. something like:
15 //
16 //  while (++i != n) {
17 //    if (a[i] != b[i])
18 //      break;
19 //  }
20 //
21 // In this example we can actually vectorize the loop despite the early exit,
22 // although the loop vectorizer does not support it. It requires some extra
23 // checks to deal with the possibility of faulting loads when crossing page
24 // boundaries. However, even with these checks it is still profitable to do the
25 // transformation.
26 //
27 //===----------------------------------------------------------------------===//
28 //
29 // NOTE: This Pass matches a really specific loop pattern because it's only
30 // supposed to be a temporary solution until our LoopVectorizer is powerful
31 // enough to vectorize it automatically.
32 //
33 // TODO List:
34 //
35 // * Add support for the inverse case where we scan for a matching element.
36 // * Permit 64-bit induction variable types.
37 // * Recognize loops that increment the IV *after* comparing bytes.
38 // * Allow 32-bit sign-extends of the IV used by the GEP.
39 //
40 //===----------------------------------------------------------------------===//
41 
42 #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
43 #include "llvm/Analysis/DomTreeUpdater.h"
44 #include "llvm/Analysis/LoopPass.h"
45 #include "llvm/Analysis/TargetTransformInfo.h"
46 #include "llvm/IR/Dominators.h"
47 #include "llvm/IR/IRBuilder.h"
48 #include "llvm/IR/Intrinsics.h"
49 #include "llvm/IR/MDBuilder.h"
50 #include "llvm/IR/PatternMatch.h"
51 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
52 
53 using namespace llvm;
54 using namespace PatternMatch;
55 
56 #define DEBUG_TYPE "loop-idiom-vectorize"
57 
58 static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
59                                 cl::init(false),
60                                 cl::desc("Disable Loop Idiom Vectorize Pass."));
61 
62 static cl::opt<bool>
63     DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
64                    cl::init(false),
65                    cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
66                             "not convert byte-compare loop(s)."));
67 
68 static cl::opt<bool>
69     VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
70                 cl::desc("Verify loops generated Loop Idiom Vectorize Pass."));
71 
72 namespace {
73 
74 class LoopIdiomVectorize {
75   Loop *CurLoop = nullptr;
76   DominatorTree *DT;
77   LoopInfo *LI;
78   const TargetTransformInfo *TTI;
79   const DataLayout *DL;
80 
81   // Blocks that will be used for inserting vectorized code.
82   BasicBlock *EndBlock = nullptr;
83   BasicBlock *VectorLoopPreheaderBlock = nullptr;
84   BasicBlock *VectorLoopStartBlock = nullptr;
85   BasicBlock *VectorLoopMismatchBlock = nullptr;
86   BasicBlock *VectorLoopIncBlock = nullptr;
87 
88 public:
89   explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI,
90                               const TargetTransformInfo *TTI,
91                               const DataLayout *DL)
92       : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
93 
94   bool run(Loop *L);
95 
96 private:
97   /// \name Countable Loop Idiom Handling
98   /// @{
99 
100   bool runOnCountableLoop();
101   bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
102                       SmallVectorImpl<BasicBlock *> &ExitBlocks);
103 
104   bool recognizeByteCompare();
105 
106   Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
107                             GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
108                             Instruction *Index, Value *Start, Value *MaxLen);
109 
110   Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
111                                   GetElementPtrInst *GEPA,
112                                   GetElementPtrInst *GEPB, Value *ExtStart,
113                                   Value *ExtEnd);
114 
115   void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
116                             PHINode *IndPhi, Value *MaxLen, Instruction *Index,
117                             Value *Start, bool IncIdx, BasicBlock *FoundBB,
118                             BasicBlock *EndBB);
119   /// @}
120 };
121 } // anonymous namespace
122 
123 PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
124                                               LoopStandardAnalysisResults &AR,
125                                               LPMUpdater &) {
126   if (DisableAll)
127     return PreservedAnalyses::all();
128 
129   const auto *DL = &L.getHeader()->getDataLayout();
130 
131   LoopIdiomVectorize LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
132   if (!LIT.run(&L))
133     return PreservedAnalyses::all();
134 
135   return PreservedAnalyses::none();
136 }
137 
138 //===----------------------------------------------------------------------===//
139 //
140 //          Implementation of LoopIdiomVectorize
141 //
142 //===----------------------------------------------------------------------===//
143 
144 bool LoopIdiomVectorize::run(Loop *L) {
145   CurLoop = L;
146 
147   Function &F = *L->getHeader()->getParent();
148   if (DisableAll || F.hasOptSize())
149     return false;
150 
151   if (F.hasFnAttribute(Attribute::NoImplicitFloat)) {
152     LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName()
153                       << " due to its NoImplicitFloat attribute");
154     return false;
155   }
156 
157   // If the loop could not be converted to canonical form, it must have an
158   // indirectbr in it, just give up.
159   if (!L->getLoopPreheader())
160     return false;
161 
162   LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
163                     << CurLoop->getHeader()->getName() << "\n");
164 
165   return recognizeByteCompare();
166 }
167 
168 bool LoopIdiomVectorize::recognizeByteCompare() {
169   // Currently the transformation only works on scalable vector types, although
170   // there is no fundamental reason why it cannot be made to work for fixed
171   // width too.
172 
173   // We also need to know the minimum page size for the target in order to
174   // generate runtime memory checks to ensure the vector version won't fault.
175   if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
176       DisableByteCmp)
177     return false;
178 
179   BasicBlock *Header = CurLoop->getHeader();
180 
181   // In LoopIdiomVectorize::run we have already checked that the loop
182   // has a preheader so we can assume it's in a canonical form.
183   if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
184     return false;
185 
186   PHINode *PN = dyn_cast<PHINode>(&Header->front());
187   if (!PN || PN->getNumIncomingValues() != 2)
188     return false;
189 
190   auto LoopBlocks = CurLoop->getBlocks();
191   // The first block in the loop should contain only 4 instructions, e.g.
192   //
193   //  while.cond:
194   //   %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
195   //   %inc = add i32 %res.phi, 1
196   //   %cmp.not = icmp eq i32 %inc, %n
197   //   br i1 %cmp.not, label %while.end, label %while.body
198   //
199   if (LoopBlocks[0]->sizeWithoutDebug() > 4)
200     return false;
201 
202   // The second block should contain 7 instructions, e.g.
203   //
204   // while.body:
205   //   %idx = zext i32 %inc to i64
206   //   %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
207   //   %load.a = load i8, ptr %idx.a
208   //   %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
209   //   %load.b = load i8, ptr %idx.b
210   //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
211   //   br i1 %cmp.not.ld, label %while.cond, label %while.end
212   //
213   if (LoopBlocks[1]->sizeWithoutDebug() > 7)
214     return false;
215 
216   // The incoming value to the PHI node from the loop should be an add of 1.
217   Value *StartIdx = nullptr;
218   Instruction *Index = nullptr;
219   if (!CurLoop->contains(PN->getIncomingBlock(0))) {
220     StartIdx = PN->getIncomingValue(0);
221     Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
222   } else {
223     StartIdx = PN->getIncomingValue(1);
224     Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
225   }
226 
227   // Limit to 32-bit types for now
228   if (!Index || !Index->getType()->isIntegerTy(32) ||
229       !match(Index, m_c_Add(m_Specific(PN), m_One())))
230     return false;
231 
232   // If we match the pattern, PN and Index will be replaced with the result of
233   // the cttz.elts intrinsic. If any other instructions are used outside of
234   // the loop, we cannot replace it.
235   for (BasicBlock *BB : LoopBlocks)
236     for (Instruction &I : *BB)
237       if (&I != PN && &I != Index)
238         for (User *U : I.users())
239           if (!CurLoop->contains(cast<Instruction>(U)))
240             return false;
241 
242   // Match the branch instruction for the header
243   ICmpInst::Predicate Pred;
244   Value *MaxLen;
245   BasicBlock *EndBB, *WhileBB;
246   if (!match(Header->getTerminator(),
247              m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
248                   m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
249       Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
250     return false;
251 
252   // WhileBB should contain the pattern of load & compare instructions. Match
253   // the pattern and find the GEP instructions used by the loads.
254   ICmpInst::Predicate WhilePred;
255   BasicBlock *FoundBB;
256   BasicBlock *TrueBB;
257   Value *LoadA, *LoadB;
258   if (!match(WhileBB->getTerminator(),
259              m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
260                   m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
261       WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
262     return false;
263 
264   Value *A, *B;
265   if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
266     return false;
267 
268   LoadInst *LoadAI = cast<LoadInst>(LoadA);
269   LoadInst *LoadBI = cast<LoadInst>(LoadB);
270   if (!LoadAI->isSimple() || !LoadBI->isSimple())
271     return false;
272 
273   GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
274   GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
275 
276   if (!GEPA || !GEPB)
277     return false;
278 
279   Value *PtrA = GEPA->getPointerOperand();
280   Value *PtrB = GEPB->getPointerOperand();
281 
282   // Check we are loading i8 values from two loop invariant pointers
283   if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
284       !GEPA->getResultElementType()->isIntegerTy(8) ||
285       !GEPB->getResultElementType()->isIntegerTy(8) ||
286       !LoadAI->getType()->isIntegerTy(8) ||
287       !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
288     return false;
289 
290   // Check that the index to the GEPs is the index we found earlier
291   if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
292     return false;
293 
294   Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
295   Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
296   if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
297     return false;
298 
299   // We only ever expect the pre-incremented index value to be used inside the
300   // loop.
301   if (!PN->hasOneUse())
302     return false;
303 
304   // Ensure that when the Found and End blocks are identical the PHIs have the
305   // supported format. We don't currently allow cases like this:
306   // while.cond:
307   //   ...
308   //   br i1 %cmp.not, label %while.end, label %while.body
309   //
310   // while.body:
311   //   ...
312   //   br i1 %cmp.not2, label %while.cond, label %while.end
313   //
314   // while.end:
315   //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
316   //
317   // Where the incoming values for %final_ptr are unique and from each of the
318   // loop blocks, but not actually defined in the loop. This requires extra
319   // work setting up the byte.compare block, i.e. by introducing a select to
320   // choose the correct value.
321   // TODO: We could add support for this in future.
322   if (FoundBB == EndBB) {
323     for (PHINode &EndPN : EndBB->phis()) {
324       Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
325       Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);
326 
327       // The value of the index when leaving the while.cond block is always the
328       // same as the end value (MaxLen) so we permit either. The value when
329       // leaving the while.body block should only be the index. Otherwise for
330       // any other values we only allow ones that are same for both blocks.
331       if (WhileCondVal != WhileBodyVal &&
332           ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
333            (WhileBodyVal != Index)))
334         return false;
335     }
336   }
337 
338   LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
339                     << *(EndBB->getParent()) << "\n\n");
340 
341   // The index is incremented before the GEP/Load pair so we need to
342   // add 1 to the start value.
343   transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true,
344                        FoundBB, EndBB);
345   return true;
346 }
347 
348 Value *LoopIdiomVectorize::createMaskedFindMismatch(
349     IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
350     GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
351   Type *I64Type = Builder.getInt64Ty();
352   Type *ResType = Builder.getInt32Ty();
353   Type *LoadType = Builder.getInt8Ty();
354   Value *PtrA = GEPA->getPointerOperand();
355   Value *PtrB = GEPB->getPointerOperand();
356 
357   // At this point we know two things must be true:
358   //  1. Start <= End
359   //  2. ExtMaxLen <= MinPageSize due to the page checks.
360   // Therefore, we know that we can use a 64-bit induction variable that
361   // starts from 0 -> ExtMaxLen and it will not overflow.
362   ScalableVectorType *PredVTy =
363       ScalableVectorType::get(Builder.getInt1Ty(), 16);
364 
365   Value *InitialPred = Builder.CreateIntrinsic(
366       Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
367 
368   Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
369   VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
370                              /*HasNUW=*/true, /*HasNSW=*/true);
371 
372   Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
373                                             Builder.getInt1(false));
374 
375   BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
376   Builder.Insert(JumpToVectorLoop);
377 
378   DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
379                      VectorLoopStartBlock}});
380 
381   // Set up the first vector loop block by creating the PHIs, doing the vector
382   // loads and comparing the vectors.
383   Builder.SetInsertPoint(VectorLoopStartBlock);
384   PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
385   LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
386   PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
387   VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
388   Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
389   Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
390 
391   Value *VectorLhsGep =
392       Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
393   Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
394                                                   Align(1), LoopPred, Passthru);
395 
396   Value *VectorRhsGep =
397       Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
398   Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
399                                                   Align(1), LoopPred, Passthru);
400 
401   Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
402   VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
403   Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
404   BranchInst *VectorEarlyExit = BranchInst::Create(
405       VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
406   Builder.Insert(VectorEarlyExit);
407 
408   DTU.applyUpdates(
409       {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
410        {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
411 
412   // Increment the index counter and calculate the predicate for the next
413   // iteration of the loop. We branch back to the start of the loop if there
414   // is at least one active lane.
415   Builder.SetInsertPoint(VectorLoopIncBlock);
416   Value *NewVectorIndexPhi =
417       Builder.CreateAdd(VectorIndexPhi, VecLen, "",
418                         /*HasNUW=*/true, /*HasNSW=*/true);
419   VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
420   Value *NewPred =
421       Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
422                               {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
423   LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
424 
425   Value *PredHasActiveLanes =
426       Builder.CreateExtractElement(NewPred, uint64_t(0));
427   BranchInst *VectorLoopBranchBack =
428       BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
429   Builder.Insert(VectorLoopBranchBack);
430 
431   DTU.applyUpdates(
432       {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
433        {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
434 
435   // If we found a mismatch then we need to calculate which lane in the vector
436   // had a mismatch and add that on to the current loop index.
437   Builder.SetInsertPoint(VectorLoopMismatchBlock);
438   PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
439   FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
440   PHINode *LastLoopPred =
441       Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
442   LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
443   PHINode *VectorFoundIndex =
444       Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
445   VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
446 
447   Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
448   Value *Ctz = Builder.CreateIntrinsic(
449       Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
450       {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
451   Ctz = Builder.CreateZExt(Ctz, I64Type);
452   Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
453                                              /*HasNUW=*/true, /*HasNSW=*/true);
454   return Builder.CreateTrunc(VectorLoopRes64, ResType);
455 }
456 
457 Value *LoopIdiomVectorize::expandFindMismatch(
458     IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
459     GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
460   Value *PtrA = GEPA->getPointerOperand();
461   Value *PtrB = GEPB->getPointerOperand();
462 
463   // Get the arguments and types for the intrinsic.
464   BasicBlock *Preheader = CurLoop->getLoopPreheader();
465   BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
466   LLVMContext &Ctx = PHBranch->getContext();
467   Type *LoadType = Type::getInt8Ty(Ctx);
468   Type *ResType = Builder.getInt32Ty();
469 
470   // Split block in the original loop preheader.
471   EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
472 
473   // Create the blocks that we're going to need:
474   //  1. A block for checking the zero-extended length exceeds 0
475   //  2. A block to check that the start and end addresses of a given array
476   //     lie on the same page.
477   //  3. The vector loop preheader.
478   //  4. The first vector loop block.
479   //  5. The vector loop increment block.
480   //  6. A block we can jump to from the vector loop when a mismatch is found.
481   //  7. The first block of the scalar loop itself, containing PHIs , loads
482   //  and cmp.
483   //  8. A scalar loop increment block to increment the PHIs and go back
484   //  around the loop.
485 
486   BasicBlock *MinItCheckBlock = BasicBlock::Create(
487       Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);
488 
489   // Update the terminator added by SplitBlock to branch to the first block
490   Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);
491 
492   BasicBlock *MemCheckBlock = BasicBlock::Create(
493       Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);
494 
495   VectorLoopPreheaderBlock = BasicBlock::Create(
496       Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock);
497 
498   VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop",
499                                             EndBlock->getParent(), EndBlock);
500 
501   VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc",
502                                           EndBlock->getParent(), EndBlock);
503 
504   VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found",
505                                                EndBlock->getParent(), EndBlock);
506 
507   BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
508       Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);
509 
510   BasicBlock *LoopStartBlock =
511       BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);
512 
513   BasicBlock *LoopIncBlock = BasicBlock::Create(
514       Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);
515 
516   DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
517                     {DominatorTree::Delete, Preheader, EndBlock}});
518 
519   // Update LoopInfo with the new vector & scalar loops.
520   auto VectorLoop = LI->AllocateLoop();
521   auto ScalarLoop = LI->AllocateLoop();
522 
523   if (CurLoop->getParentLoop()) {
524     CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI);
525     CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI);
526     CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopPreheaderBlock,
527                                                   *LI);
528     CurLoop->getParentLoop()->addChildLoop(VectorLoop);
529     CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopMismatchBlock, *LI);
530     CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
531     CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
532   } else {
533     LI->addTopLevelLoop(VectorLoop);
534     LI->addTopLevelLoop(ScalarLoop);
535   }
536 
537   // Add the new basic blocks to their associated loops.
538   VectorLoop->addBasicBlockToLoop(VectorLoopStartBlock, *LI);
539   VectorLoop->addBasicBlockToLoop(VectorLoopIncBlock, *LI);
540 
541   ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
542   ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);
543 
544   // Set up some types and constants that we intend to reuse.
545   Type *I64Type = Builder.getInt64Ty();
546 
547   // Check the zero-extended iteration count > 0
548   Builder.SetInsertPoint(MinItCheckBlock);
549   Value *ExtStart = Builder.CreateZExt(Start, I64Type);
550   Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
551   // This check doesn't really cost us very much.
552 
553   Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
554   BranchInst *MinItCheckBr =
555       BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
556   MinItCheckBr->setMetadata(
557       LLVMContext::MD_prof,
558       MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
559   Builder.Insert(MinItCheckBr);
560 
561   DTU.applyUpdates(
562       {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
563        {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});
564 
565   // For each of the arrays, check the start/end addresses are on the same
566   // page.
567   Builder.SetInsertPoint(MemCheckBlock);
568 
569   // The early exit in the original loop means that when performing vector
570   // loads we are potentially reading ahead of the early exit. So we could
571   // fault if crossing a page boundary. Therefore, we create runtime memory
572   // checks based on the minimum page size as follows:
573   //   1. Calculate the addresses of the first memory accesses in the loop,
574   //      i.e. LhsStart and RhsStart.
575   //   2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
576   //   3. Determine which pages correspond to all the memory accesses, i.e
577   //      LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
578   //   4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
579   //      we know we won't cross any page boundaries in the loop so we can
580   //      enter the vector loop! Otherwise we fall back on the scalar loop.
581   Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart);
582   Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart);
583   Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type);
584   Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type);
585   Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd);
586   Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd);
587   Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type);
588   Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type);
589 
590   const uint64_t MinPageSize = TTI->getMinPageSize().value();
591   const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
592   Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt);
593   Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt);
594   Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt);
595   Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt);
596   Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage);
597   Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage);
598 
599   Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp);
600   BranchInst *CombinedPageCmpCmpBr = BranchInst::Create(
601       LoopPreHeaderBlock, VectorLoopPreheaderBlock, CombinedPageCmp);
602   CombinedPageCmpCmpBr->setMetadata(
603       LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext())
604                                 .createBranchWeights(10, 90));
605   Builder.Insert(CombinedPageCmpCmpBr);
606 
607   DTU.applyUpdates(
608       {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
609        {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}});
610 
611   // Set up the vector loop preheader, i.e. calculate initial loop predicate,
612   // zero-extend MaxLen to 64-bits, determine the number of vector elements
613   // processed in each iteration, etc.
614   Builder.SetInsertPoint(VectorLoopPreheaderBlock);
615 
616   Value *VectorLoopRes =
617       createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
618 
619   Builder.Insert(BranchInst::Create(EndBlock));
620 
621   DTU.applyUpdates(
622       {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}});
623 
624   // Generate code for scalar loop.
625   Builder.SetInsertPoint(LoopPreHeaderBlock);
626   Builder.Insert(BranchInst::Create(LoopStartBlock));
627 
628   DTU.applyUpdates(
629       {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});
630 
631   Builder.SetInsertPoint(LoopStartBlock);
632   PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
633   IndexPhi->addIncoming(Start, LoopPreHeaderBlock);
634 
635   // Otherwise compare the values
636   // Load bytes from each array and compare them.
637   Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type);
638 
639   Value *LhsGep =
640       Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
641   Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep);
642 
643   Value *RhsGep =
644       Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
645   Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);
646 
647   Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
648   // If we have a mismatch then exit the loop ...
649   BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
650   Builder.Insert(MatchCmpBr);
651 
652   DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
653                     {DominatorTree::Insert, LoopStartBlock, EndBlock}});
654 
655   // Have we reached the maximum permitted length for the loop?
656   Builder.SetInsertPoint(LoopIncBlock);
657   Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "",
658                                     /*HasNUW=*/Index->hasNoUnsignedWrap(),
659                                     /*HasNSW=*/Index->hasNoSignedWrap());
660   IndexPhi->addIncoming(PhiInc, LoopIncBlock);
661   Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen);
662   BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp);
663   Builder.Insert(IVCmpBr);
664 
665   DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock},
666                     {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});
667 
668   // In the end block we need to insert a PHI node to deal with three cases:
669   //  1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
670   //  2. We exitted the scalar loop early due to a mismatch and need to return
671   //  the index that we found.
672   //  3. We didn't find a mismatch in the vector loop, so we return MaxLen.
673   //  4. We exitted the vector loop early due to a mismatch and need to return
674   //  the index that we found.
675   Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt());
676   PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result");
677   ResPhi->addIncoming(MaxLen, LoopIncBlock);
678   ResPhi->addIncoming(IndexPhi, LoopStartBlock);
679   ResPhi->addIncoming(MaxLen, VectorLoopIncBlock);
680   ResPhi->addIncoming(VectorLoopRes, VectorLoopMismatchBlock);
681 
682   Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType);
683 
684   if (VerifyLoops) {
685     ScalarLoop->verifyLoop();
686     VectorLoop->verifyLoop();
687     if (!VectorLoop->isRecursivelyLCSSAForm(*DT, *LI))
688       report_fatal_error("Loops must remain in LCSSA form!");
689     if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
690       report_fatal_error("Loops must remain in LCSSA form!");
691   }
692 
693   return FinalRes;
694 }
695 
696 void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
697                                               GetElementPtrInst *GEPB,
698                                               PHINode *IndPhi, Value *MaxLen,
699                                               Instruction *Index, Value *Start,
700                                               bool IncIdx, BasicBlock *FoundBB,
701                                               BasicBlock *EndBB) {
702 
703   // Insert the byte compare code at the end of the preheader block
704   BasicBlock *Preheader = CurLoop->getLoopPreheader();
705   BasicBlock *Header = CurLoop->getHeader();
706   BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
707   IRBuilder<> Builder(PHBranch);
708   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
709   Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
710 
711   // Increment the pointer if this was done before the loads in the loop.
712   if (IncIdx)
713     Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));
714 
715   Value *ByteCmpRes =
716       expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);
717 
718   // Replaces uses of index & induction Phi with intrinsic (we already
719   // checked that the the first instruction of Header is the Phi above).
720   assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
721   Index->replaceAllUsesWith(ByteCmpRes);
722 
723   assert(PHBranch->isUnconditional() &&
724          "Expected preheader to terminate with an unconditional branch.");
725 
726   // If no mismatch was found, we can jump to the end block. Create a
727   // new basic block for the compare instruction.
728   auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare",
729                                    Preheader->getParent());
730   CmpBB->moveBefore(EndBB);
731 
732   // Replace the branch in the preheader with an always-true conditional branch.
733   // This ensures there is still a reference to the original loop.
734   Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header);
735   PHBranch->eraseFromParent();
736 
737   BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent();
738   DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}});
739 
740   // Create the branch to either the end or found block depending on the value
741   // returned by the intrinsic.
742   Builder.SetInsertPoint(CmpBB);
743   if (FoundBB != EndBB) {
744     Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen);
745     Builder.CreateCondBr(FoundCmp, EndBB, FoundBB);
746     DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB},
747                       {DominatorTree::Insert, CmpBB, EndBB}});
748 
749   } else {
750     Builder.CreateBr(FoundBB);
751     DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}});
752   }
753 
754   auto fixSuccessorPhis = [&](BasicBlock *SuccBB) {
755     for (PHINode &PN : SuccBB->phis()) {
756       // At this point we've already replaced all uses of the result from the
757       // loop with ByteCmp. Look through the incoming values to find ByteCmp,
758       // meaning this is a Phi collecting the results of the byte compare.
759       bool ResPhi = false;
760       for (Value *Op : PN.incoming_values())
761         if (Op == ByteCmpRes) {
762           ResPhi = true;
763           break;
764         }
765 
766       // Any PHI that depended upon the result of the byte compare needs a new
767       // incoming value from CmpBB. This is because the original loop will get
768       // deleted.
769       if (ResPhi)
770         PN.addIncoming(ByteCmpRes, CmpBB);
771       else {
772         // There should be no other outside uses of other values in the
773         // original loop. Any incoming values should either:
774         //   1. Be for blocks outside the loop, which aren't interesting. Or ..
775         //   2. These are from blocks in the loop with values defined outside
776         //      the loop. We should a similar incoming value from CmpBB.
777         for (BasicBlock *BB : PN.blocks())
778           if (CurLoop->contains(BB)) {
779             PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB);
780             break;
781           }
782       }
783     }
784   };
785 
786   // Ensure all Phis in the successors of CmpBB have an incoming value from it.
787   fixSuccessorPhis(EndBB);
788   if (EndBB != FoundBB)
789     fixSuccessorPhis(FoundBB);
790 
791   // The new CmpBB block isn't part of the loop, but will need to be added to
792   // the outer loop if there is one.
793   if (!CurLoop->isOutermost())
794     CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI);
795 
796   if (VerifyLoops && CurLoop->getParentLoop()) {
797     CurLoop->getParentLoop()->verifyLoop();
798     if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
799       report_fatal_error("Loops must remain in LCSSA form!");
800   }
801 }
802