xref: /llvm-project/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp (revision dc6c3ba4c4372172f504fcbe440f62932edf1cc1)
137e309f1SMin-Yih Hsu //===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===//
237e309f1SMin-Yih Hsu //
337e309f1SMin-Yih Hsu // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
437e309f1SMin-Yih Hsu // See https://llvm.org/LICENSE.txt for license information.
537e309f1SMin-Yih Hsu // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
637e309f1SMin-Yih Hsu //
737e309f1SMin-Yih Hsu //===----------------------------------------------------------------------===//
837e309f1SMin-Yih Hsu //
// This pass recognizes certain loop idioms and transforms them into more
// optimized versions of the same loop. When this happens, it can be a
// significant performance win.
1237e309f1SMin-Yih Hsu //
1337e309f1SMin-Yih Hsu // We currently only recognize one loop that finds the first mismatched byte
1437e309f1SMin-Yih Hsu // in an array and returns the index, i.e. something like:
1537e309f1SMin-Yih Hsu //
1637e309f1SMin-Yih Hsu //  while (++i != n) {
1737e309f1SMin-Yih Hsu //    if (a[i] != b[i])
1837e309f1SMin-Yih Hsu //      break;
1937e309f1SMin-Yih Hsu //  }
2037e309f1SMin-Yih Hsu //
2137e309f1SMin-Yih Hsu // In this example we can actually vectorize the loop despite the early exit,
2237e309f1SMin-Yih Hsu // although the loop vectorizer does not support it. It requires some extra
2337e309f1SMin-Yih Hsu // checks to deal with the possibility of faulting loads when crossing page
2437e309f1SMin-Yih Hsu // boundaries. However, even with these checks it is still profitable to do the
2537e309f1SMin-Yih Hsu // transformation.
2637e309f1SMin-Yih Hsu //
2737e309f1SMin-Yih Hsu //===----------------------------------------------------------------------===//
2837e309f1SMin-Yih Hsu //
// NOTE: This pass matches a really specific loop pattern because it's only
// supposed to be a temporary solution until our LoopVectorizer is powerful
// enough to vectorize it automatically.
3237e309f1SMin-Yih Hsu //
3337e309f1SMin-Yih Hsu // TODO List:
3437e309f1SMin-Yih Hsu //
3537e309f1SMin-Yih Hsu // * Add support for the inverse case where we scan for a matching element.
3637e309f1SMin-Yih Hsu // * Permit 64-bit induction variable types.
3737e309f1SMin-Yih Hsu // * Recognize loops that increment the IV *after* comparing bytes.
3837e309f1SMin-Yih Hsu // * Allow 32-bit sign-extends of the IV used by the GEP.
3937e309f1SMin-Yih Hsu //
4037e309f1SMin-Yih Hsu //===----------------------------------------------------------------------===//
4137e309f1SMin-Yih Hsu 
4237e309f1SMin-Yih Hsu #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
4337e309f1SMin-Yih Hsu #include "llvm/Analysis/DomTreeUpdater.h"
4437e309f1SMin-Yih Hsu #include "llvm/Analysis/LoopPass.h"
4537e309f1SMin-Yih Hsu #include "llvm/Analysis/TargetTransformInfo.h"
4637e309f1SMin-Yih Hsu #include "llvm/IR/Dominators.h"
4737e309f1SMin-Yih Hsu #include "llvm/IR/IRBuilder.h"
4837e309f1SMin-Yih Hsu #include "llvm/IR/Intrinsics.h"
4937e309f1SMin-Yih Hsu #include "llvm/IR/MDBuilder.h"
5037e309f1SMin-Yih Hsu #include "llvm/IR/PatternMatch.h"
5137e309f1SMin-Yih Hsu #include "llvm/Transforms/Utils/BasicBlockUtils.h"
5237e309f1SMin-Yih Hsu 
5337e309f1SMin-Yih Hsu using namespace llvm;
5437e309f1SMin-Yih Hsu using namespace PatternMatch;
5537e309f1SMin-Yih Hsu 
5637e309f1SMin-Yih Hsu #define DEBUG_TYPE "loop-idiom-vectorize"
5737e309f1SMin-Yih Hsu 
5837e309f1SMin-Yih Hsu static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
5937e309f1SMin-Yih Hsu                                 cl::init(false),
6037e309f1SMin-Yih Hsu                                 cl::desc("Disable Loop Idiom Vectorize Pass."));
6137e309f1SMin-Yih Hsu 
628b55d342SMin-Yih Hsu static cl::opt<LoopIdiomVectorizeStyle>
638b55d342SMin-Yih Hsu     LITVecStyle("loop-idiom-vectorize-style", cl::Hidden,
648b55d342SMin-Yih Hsu                 cl::desc("The vectorization style for loop idiom transform."),
658b55d342SMin-Yih Hsu                 cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked",
668b55d342SMin-Yih Hsu                                       "Use masked vector intrinsics"),
678b55d342SMin-Yih Hsu                            clEnumValN(LoopIdiomVectorizeStyle::Predicated,
688b55d342SMin-Yih Hsu                                       "predicated", "Use VP intrinsics")),
698b55d342SMin-Yih Hsu                 cl::init(LoopIdiomVectorizeStyle::Masked));
708b55d342SMin-Yih Hsu 
7137e309f1SMin-Yih Hsu static cl::opt<bool>
7237e309f1SMin-Yih Hsu     DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
7337e309f1SMin-Yih Hsu                    cl::init(false),
7437e309f1SMin-Yih Hsu                    cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
7537e309f1SMin-Yih Hsu                             "not convert byte-compare loop(s)."));
7637e309f1SMin-Yih Hsu 
778b55d342SMin-Yih Hsu static cl::opt<unsigned>
788b55d342SMin-Yih Hsu     ByteCmpVF("loop-idiom-vectorize-bytecmp-vf", cl::Hidden,
798b55d342SMin-Yih Hsu               cl::desc("The vectorization factor for byte-compare patterns."),
808b55d342SMin-Yih Hsu               cl::init(16));
818b55d342SMin-Yih Hsu 
8237e309f1SMin-Yih Hsu static cl::opt<bool>
8337e309f1SMin-Yih Hsu     VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
8437e309f1SMin-Yih Hsu                 cl::desc("Verify loops generated Loop Idiom Vectorize Pass."));
8537e309f1SMin-Yih Hsu 
namespace {
/// Per-loop driver for the loop idiom vectorizer. Currently only the
/// byte-compare ("find first mismatching byte") idiom described in the file
/// header is recognized; it is rewritten into a vector loop built from either
/// masked or VP intrinsics, depending on VectorizeStyle.
class LoopIdiomVectorize {
  /// Intrinsic family used for the generated vector loop: masked vector
  /// intrinsics (Masked) or VP intrinsics (Predicated).
  LoopIdiomVectorizeStyle VectorizeStyle;
  /// Minimum element count N of the scalable vectors (<vscale x N x i8>)
  /// used when vectorizing byte-compare loops.
  unsigned ByteCompareVF;
  /// The loop currently being processed; assigned by run().
  Loop *CurLoop = nullptr;
  DominatorTree *DT;
  LoopInfo *LI;
  const TargetTransformInfo *TTI;
  const DataLayout *DL;

  // Blocks that will be used for inserting vectorized code.
  // NOTE(review): they are not assigned in this part of the file; presumably
  // expandFindMismatch creates them before calling the create*FindMismatch
  // helpers, which only fill them in — confirm.
  BasicBlock *EndBlock = nullptr;
  BasicBlock *VectorLoopPreheaderBlock = nullptr;
  BasicBlock *VectorLoopStartBlock = nullptr;
  BasicBlock *VectorLoopMismatchBlock = nullptr;
  BasicBlock *VectorLoopIncBlock = nullptr;

public:
  LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
                     LoopInfo *LI, const TargetTransformInfo *TTI,
                     const DataLayout *DL)
      : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
  }

  /// Try to transform \p L. Returns true if the IR was modified.
  bool run(Loop *L);

private:
  /// \name Countable Loop Idiom Handling
  /// @{

  // NOTE(review): no definitions of these two are visible in this chunk of
  // the file — confirm they are still needed.
  bool runOnCountableLoop();
  bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
                      SmallVectorImpl<BasicBlock *> &ExitBlocks);

  /// Match the byte-compare idiom in CurLoop and transform it on success.
  /// Returns true if the loop was transformed.
  bool recognizeByteCompare();

  /// Emit the vectorized mismatch search.
  /// NOTE(review): definition not visible here; presumably it creates the
  /// vector blocks above and dispatches on VectorizeStyle — confirm.
  Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                            GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            Instruction *Index, Value *Start, Value *MaxLen);

  /// Build the vector loop with masked loads and active-lane-mask predicates
  /// over the i64 range [ExtStart, ExtEnd). Returns the i32 index of the
  /// first mismatching byte, computed in VectorLoopMismatchBlock.
  Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                                  GetElementPtrInst *GEPA,
                                  GetElementPtrInst *GEPB, Value *ExtStart,
                                  Value *ExtEnd);
  /// Same as createMaskedFindMismatch but using VP intrinsics (vp.load with
  /// an explicit vector length from experimental.get.vector.length).
  Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                                      GetElementPtrInst *GEPA,
                                      GetElementPtrInst *GEPB, Value *ExtStart,
                                      Value *ExtEnd);

  /// Rewrite the matched loop. \p IncIdx indicates that the scalar loop
  /// increments the index before the GEP/load pair, so 1 presumably has to
  /// be added to \p Start — confirm against the definition.
  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            PHINode *IndPhi, Value *MaxLen, Instruction *Index,
                            Value *Start, bool IncIdx, BasicBlock *FoundBB,
                            BasicBlock *EndBB);
  /// @}
};
} // anonymous namespace
14237e309f1SMin-Yih Hsu 
14337e309f1SMin-Yih Hsu PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
14437e309f1SMin-Yih Hsu                                               LoopStandardAnalysisResults &AR,
14537e309f1SMin-Yih Hsu                                               LPMUpdater &) {
14637e309f1SMin-Yih Hsu   if (DisableAll)
14737e309f1SMin-Yih Hsu     return PreservedAnalyses::all();
14837e309f1SMin-Yih Hsu 
1492d209d96SNikita Popov   const auto *DL = &L.getHeader()->getDataLayout();
15037e309f1SMin-Yih Hsu 
1518b55d342SMin-Yih Hsu   LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
1528b55d342SMin-Yih Hsu   if (LITVecStyle.getNumOccurrences())
1538b55d342SMin-Yih Hsu     VecStyle = LITVecStyle;
1548b55d342SMin-Yih Hsu 
1558b55d342SMin-Yih Hsu   unsigned BCVF = ByteCompareVF;
1568b55d342SMin-Yih Hsu   if (ByteCmpVF.getNumOccurrences())
1578b55d342SMin-Yih Hsu     BCVF = ByteCmpVF;
1588b55d342SMin-Yih Hsu 
1598b55d342SMin-Yih Hsu   LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL);
1608b55d342SMin-Yih Hsu   if (!LIV.run(&L))
16137e309f1SMin-Yih Hsu     return PreservedAnalyses::all();
16237e309f1SMin-Yih Hsu 
16337e309f1SMin-Yih Hsu   return PreservedAnalyses::none();
16437e309f1SMin-Yih Hsu }
16537e309f1SMin-Yih Hsu 
16637e309f1SMin-Yih Hsu //===----------------------------------------------------------------------===//
16737e309f1SMin-Yih Hsu //
16837e309f1SMin-Yih Hsu //          Implementation of LoopIdiomVectorize
16937e309f1SMin-Yih Hsu //
17037e309f1SMin-Yih Hsu //===----------------------------------------------------------------------===//
17137e309f1SMin-Yih Hsu 
17237e309f1SMin-Yih Hsu bool LoopIdiomVectorize::run(Loop *L) {
17337e309f1SMin-Yih Hsu   CurLoop = L;
17437e309f1SMin-Yih Hsu 
17537e309f1SMin-Yih Hsu   Function &F = *L->getHeader()->getParent();
17637e309f1SMin-Yih Hsu   if (DisableAll || F.hasOptSize())
17737e309f1SMin-Yih Hsu     return false;
17837e309f1SMin-Yih Hsu 
17937e309f1SMin-Yih Hsu   if (F.hasFnAttribute(Attribute::NoImplicitFloat)) {
18037e309f1SMin-Yih Hsu     LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName()
18137e309f1SMin-Yih Hsu                       << " due to its NoImplicitFloat attribute");
18237e309f1SMin-Yih Hsu     return false;
18337e309f1SMin-Yih Hsu   }
18437e309f1SMin-Yih Hsu 
18537e309f1SMin-Yih Hsu   // If the loop could not be converted to canonical form, it must have an
18637e309f1SMin-Yih Hsu   // indirectbr in it, just give up.
18737e309f1SMin-Yih Hsu   if (!L->getLoopPreheader())
18837e309f1SMin-Yih Hsu     return false;
18937e309f1SMin-Yih Hsu 
19037e309f1SMin-Yih Hsu   LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
19137e309f1SMin-Yih Hsu                     << CurLoop->getHeader()->getName() << "\n");
19237e309f1SMin-Yih Hsu 
19337e309f1SMin-Yih Hsu   return recognizeByteCompare();
19437e309f1SMin-Yih Hsu }
19537e309f1SMin-Yih Hsu 
19637e309f1SMin-Yih Hsu bool LoopIdiomVectorize::recognizeByteCompare() {
19737e309f1SMin-Yih Hsu   // Currently the transformation only works on scalable vector types, although
19837e309f1SMin-Yih Hsu   // there is no fundamental reason why it cannot be made to work for fixed
19937e309f1SMin-Yih Hsu   // width too.
20037e309f1SMin-Yih Hsu 
20137e309f1SMin-Yih Hsu   // We also need to know the minimum page size for the target in order to
20237e309f1SMin-Yih Hsu   // generate runtime memory checks to ensure the vector version won't fault.
20337e309f1SMin-Yih Hsu   if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
20437e309f1SMin-Yih Hsu       DisableByteCmp)
20537e309f1SMin-Yih Hsu     return false;
20637e309f1SMin-Yih Hsu 
20737e309f1SMin-Yih Hsu   BasicBlock *Header = CurLoop->getHeader();
20837e309f1SMin-Yih Hsu 
20937e309f1SMin-Yih Hsu   // In LoopIdiomVectorize::run we have already checked that the loop
21037e309f1SMin-Yih Hsu   // has a preheader so we can assume it's in a canonical form.
21137e309f1SMin-Yih Hsu   if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
21237e309f1SMin-Yih Hsu     return false;
21337e309f1SMin-Yih Hsu 
21437e309f1SMin-Yih Hsu   PHINode *PN = dyn_cast<PHINode>(&Header->front());
21537e309f1SMin-Yih Hsu   if (!PN || PN->getNumIncomingValues() != 2)
21637e309f1SMin-Yih Hsu     return false;
21737e309f1SMin-Yih Hsu 
21837e309f1SMin-Yih Hsu   auto LoopBlocks = CurLoop->getBlocks();
21937e309f1SMin-Yih Hsu   // The first block in the loop should contain only 4 instructions, e.g.
22037e309f1SMin-Yih Hsu   //
22137e309f1SMin-Yih Hsu   //  while.cond:
22237e309f1SMin-Yih Hsu   //   %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
22337e309f1SMin-Yih Hsu   //   %inc = add i32 %res.phi, 1
22437e309f1SMin-Yih Hsu   //   %cmp.not = icmp eq i32 %inc, %n
22537e309f1SMin-Yih Hsu   //   br i1 %cmp.not, label %while.end, label %while.body
22637e309f1SMin-Yih Hsu   //
22737e309f1SMin-Yih Hsu   if (LoopBlocks[0]->sizeWithoutDebug() > 4)
22837e309f1SMin-Yih Hsu     return false;
22937e309f1SMin-Yih Hsu 
23037e309f1SMin-Yih Hsu   // The second block should contain 7 instructions, e.g.
23137e309f1SMin-Yih Hsu   //
23237e309f1SMin-Yih Hsu   // while.body:
23337e309f1SMin-Yih Hsu   //   %idx = zext i32 %inc to i64
23437e309f1SMin-Yih Hsu   //   %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
23537e309f1SMin-Yih Hsu   //   %load.a = load i8, ptr %idx.a
23637e309f1SMin-Yih Hsu   //   %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
23737e309f1SMin-Yih Hsu   //   %load.b = load i8, ptr %idx.b
23837e309f1SMin-Yih Hsu   //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
23937e309f1SMin-Yih Hsu   //   br i1 %cmp.not.ld, label %while.cond, label %while.end
24037e309f1SMin-Yih Hsu   //
24137e309f1SMin-Yih Hsu   if (LoopBlocks[1]->sizeWithoutDebug() > 7)
24237e309f1SMin-Yih Hsu     return false;
24337e309f1SMin-Yih Hsu 
24437e309f1SMin-Yih Hsu   // The incoming value to the PHI node from the loop should be an add of 1.
24537e309f1SMin-Yih Hsu   Value *StartIdx = nullptr;
24637e309f1SMin-Yih Hsu   Instruction *Index = nullptr;
24737e309f1SMin-Yih Hsu   if (!CurLoop->contains(PN->getIncomingBlock(0))) {
24837e309f1SMin-Yih Hsu     StartIdx = PN->getIncomingValue(0);
24937e309f1SMin-Yih Hsu     Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
25037e309f1SMin-Yih Hsu   } else {
25137e309f1SMin-Yih Hsu     StartIdx = PN->getIncomingValue(1);
25237e309f1SMin-Yih Hsu     Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
25337e309f1SMin-Yih Hsu   }
25437e309f1SMin-Yih Hsu 
25537e309f1SMin-Yih Hsu   // Limit to 32-bit types for now
25637e309f1SMin-Yih Hsu   if (!Index || !Index->getType()->isIntegerTy(32) ||
25737e309f1SMin-Yih Hsu       !match(Index, m_c_Add(m_Specific(PN), m_One())))
25837e309f1SMin-Yih Hsu     return false;
25937e309f1SMin-Yih Hsu 
26037e309f1SMin-Yih Hsu   // If we match the pattern, PN and Index will be replaced with the result of
26137e309f1SMin-Yih Hsu   // the cttz.elts intrinsic. If any other instructions are used outside of
26237e309f1SMin-Yih Hsu   // the loop, we cannot replace it.
26337e309f1SMin-Yih Hsu   for (BasicBlock *BB : LoopBlocks)
26437e309f1SMin-Yih Hsu     for (Instruction &I : *BB)
26537e309f1SMin-Yih Hsu       if (&I != PN && &I != Index)
26637e309f1SMin-Yih Hsu         for (User *U : I.users())
26737e309f1SMin-Yih Hsu           if (!CurLoop->contains(cast<Instruction>(U)))
26837e309f1SMin-Yih Hsu             return false;
26937e309f1SMin-Yih Hsu 
27037e309f1SMin-Yih Hsu   // Match the branch instruction for the header
27137e309f1SMin-Yih Hsu   Value *MaxLen;
27237e309f1SMin-Yih Hsu   BasicBlock *EndBB, *WhileBB;
27337e309f1SMin-Yih Hsu   if (!match(Header->getTerminator(),
27462e9f409SYingwei Zheng              m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(Index),
27562e9f409SYingwei Zheng                                  m_Value(MaxLen)),
27637e309f1SMin-Yih Hsu                   m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
27762e9f409SYingwei Zheng       !CurLoop->contains(WhileBB))
27837e309f1SMin-Yih Hsu     return false;
27937e309f1SMin-Yih Hsu 
28037e309f1SMin-Yih Hsu   // WhileBB should contain the pattern of load & compare instructions. Match
28137e309f1SMin-Yih Hsu   // the pattern and find the GEP instructions used by the loads.
28237e309f1SMin-Yih Hsu   BasicBlock *FoundBB;
28337e309f1SMin-Yih Hsu   BasicBlock *TrueBB;
28437e309f1SMin-Yih Hsu   Value *LoadA, *LoadB;
28537e309f1SMin-Yih Hsu   if (!match(WhileBB->getTerminator(),
28662e9f409SYingwei Zheng              m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(LoadA),
28762e9f409SYingwei Zheng                                  m_Value(LoadB)),
28837e309f1SMin-Yih Hsu                   m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
28962e9f409SYingwei Zheng       !CurLoop->contains(TrueBB))
29037e309f1SMin-Yih Hsu     return false;
29137e309f1SMin-Yih Hsu 
29237e309f1SMin-Yih Hsu   Value *A, *B;
29337e309f1SMin-Yih Hsu   if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
29437e309f1SMin-Yih Hsu     return false;
29537e309f1SMin-Yih Hsu 
29637e309f1SMin-Yih Hsu   LoadInst *LoadAI = cast<LoadInst>(LoadA);
29737e309f1SMin-Yih Hsu   LoadInst *LoadBI = cast<LoadInst>(LoadB);
29837e309f1SMin-Yih Hsu   if (!LoadAI->isSimple() || !LoadBI->isSimple())
29937e309f1SMin-Yih Hsu     return false;
30037e309f1SMin-Yih Hsu 
30137e309f1SMin-Yih Hsu   GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
30237e309f1SMin-Yih Hsu   GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
30337e309f1SMin-Yih Hsu 
30437e309f1SMin-Yih Hsu   if (!GEPA || !GEPB)
30537e309f1SMin-Yih Hsu     return false;
30637e309f1SMin-Yih Hsu 
30737e309f1SMin-Yih Hsu   Value *PtrA = GEPA->getPointerOperand();
30837e309f1SMin-Yih Hsu   Value *PtrB = GEPB->getPointerOperand();
30937e309f1SMin-Yih Hsu 
31037e309f1SMin-Yih Hsu   // Check we are loading i8 values from two loop invariant pointers
31137e309f1SMin-Yih Hsu   if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
31237e309f1SMin-Yih Hsu       !GEPA->getResultElementType()->isIntegerTy(8) ||
31337e309f1SMin-Yih Hsu       !GEPB->getResultElementType()->isIntegerTy(8) ||
31437e309f1SMin-Yih Hsu       !LoadAI->getType()->isIntegerTy(8) ||
31537e309f1SMin-Yih Hsu       !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
31637e309f1SMin-Yih Hsu     return false;
31737e309f1SMin-Yih Hsu 
31837e309f1SMin-Yih Hsu   // Check that the index to the GEPs is the index we found earlier
31937e309f1SMin-Yih Hsu   if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
32037e309f1SMin-Yih Hsu     return false;
32137e309f1SMin-Yih Hsu 
32237e309f1SMin-Yih Hsu   Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
32337e309f1SMin-Yih Hsu   Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
32437e309f1SMin-Yih Hsu   if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
32537e309f1SMin-Yih Hsu     return false;
32637e309f1SMin-Yih Hsu 
32737e309f1SMin-Yih Hsu   // We only ever expect the pre-incremented index value to be used inside the
32837e309f1SMin-Yih Hsu   // loop.
32937e309f1SMin-Yih Hsu   if (!PN->hasOneUse())
33037e309f1SMin-Yih Hsu     return false;
33137e309f1SMin-Yih Hsu 
33237e309f1SMin-Yih Hsu   // Ensure that when the Found and End blocks are identical the PHIs have the
33337e309f1SMin-Yih Hsu   // supported format. We don't currently allow cases like this:
33437e309f1SMin-Yih Hsu   // while.cond:
33537e309f1SMin-Yih Hsu   //   ...
33637e309f1SMin-Yih Hsu   //   br i1 %cmp.not, label %while.end, label %while.body
33737e309f1SMin-Yih Hsu   //
33837e309f1SMin-Yih Hsu   // while.body:
33937e309f1SMin-Yih Hsu   //   ...
34037e309f1SMin-Yih Hsu   //   br i1 %cmp.not2, label %while.cond, label %while.end
34137e309f1SMin-Yih Hsu   //
34237e309f1SMin-Yih Hsu   // while.end:
34337e309f1SMin-Yih Hsu   //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
34437e309f1SMin-Yih Hsu   //
34537e309f1SMin-Yih Hsu   // Where the incoming values for %final_ptr are unique and from each of the
34637e309f1SMin-Yih Hsu   // loop blocks, but not actually defined in the loop. This requires extra
34737e309f1SMin-Yih Hsu   // work setting up the byte.compare block, i.e. by introducing a select to
34837e309f1SMin-Yih Hsu   // choose the correct value.
34937e309f1SMin-Yih Hsu   // TODO: We could add support for this in future.
35037e309f1SMin-Yih Hsu   if (FoundBB == EndBB) {
35137e309f1SMin-Yih Hsu     for (PHINode &EndPN : EndBB->phis()) {
35237e309f1SMin-Yih Hsu       Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
35337e309f1SMin-Yih Hsu       Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);
35437e309f1SMin-Yih Hsu 
35537e309f1SMin-Yih Hsu       // The value of the index when leaving the while.cond block is always the
35637e309f1SMin-Yih Hsu       // same as the end value (MaxLen) so we permit either. The value when
35737e309f1SMin-Yih Hsu       // leaving the while.body block should only be the index. Otherwise for
35837e309f1SMin-Yih Hsu       // any other values we only allow ones that are same for both blocks.
35937e309f1SMin-Yih Hsu       if (WhileCondVal != WhileBodyVal &&
36037e309f1SMin-Yih Hsu           ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
36137e309f1SMin-Yih Hsu            (WhileBodyVal != Index)))
36237e309f1SMin-Yih Hsu         return false;
36337e309f1SMin-Yih Hsu     }
36437e309f1SMin-Yih Hsu   }
36537e309f1SMin-Yih Hsu 
36637e309f1SMin-Yih Hsu   LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
36737e309f1SMin-Yih Hsu                     << *(EndBB->getParent()) << "\n\n");
36837e309f1SMin-Yih Hsu 
36937e309f1SMin-Yih Hsu   // The index is incremented before the GEP/Load pair so we need to
37037e309f1SMin-Yih Hsu   // add 1 to the start value.
37137e309f1SMin-Yih Hsu   transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true,
37237e309f1SMin-Yih Hsu                        FoundBB, EndBB);
37337e309f1SMin-Yih Hsu   return true;
37437e309f1SMin-Yih Hsu }
37537e309f1SMin-Yih Hsu 
376de5ff38aSMin-Yih Hsu Value *LoopIdiomVectorize::createMaskedFindMismatch(
377de5ff38aSMin-Yih Hsu     IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
378de5ff38aSMin-Yih Hsu     GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
379de5ff38aSMin-Yih Hsu   Type *I64Type = Builder.getInt64Ty();
380de5ff38aSMin-Yih Hsu   Type *ResType = Builder.getInt32Ty();
381de5ff38aSMin-Yih Hsu   Type *LoadType = Builder.getInt8Ty();
382de5ff38aSMin-Yih Hsu   Value *PtrA = GEPA->getPointerOperand();
383de5ff38aSMin-Yih Hsu   Value *PtrB = GEPB->getPointerOperand();
384de5ff38aSMin-Yih Hsu 
385de5ff38aSMin-Yih Hsu   ScalableVectorType *PredVTy =
3868b55d342SMin-Yih Hsu       ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF);
387de5ff38aSMin-Yih Hsu 
388de5ff38aSMin-Yih Hsu   Value *InitialPred = Builder.CreateIntrinsic(
389de5ff38aSMin-Yih Hsu       Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
390de5ff38aSMin-Yih Hsu 
391de5ff38aSMin-Yih Hsu   Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
3928b55d342SMin-Yih Hsu   VecLen =
3938b55d342SMin-Yih Hsu       Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
394de5ff38aSMin-Yih Hsu                         /*HasNUW=*/true, /*HasNSW=*/true);
395de5ff38aSMin-Yih Hsu 
396de5ff38aSMin-Yih Hsu   Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
397de5ff38aSMin-Yih Hsu                                             Builder.getInt1(false));
398de5ff38aSMin-Yih Hsu 
399de5ff38aSMin-Yih Hsu   BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
400de5ff38aSMin-Yih Hsu   Builder.Insert(JumpToVectorLoop);
401de5ff38aSMin-Yih Hsu 
402de5ff38aSMin-Yih Hsu   DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
403de5ff38aSMin-Yih Hsu                      VectorLoopStartBlock}});
404de5ff38aSMin-Yih Hsu 
405de5ff38aSMin-Yih Hsu   // Set up the first vector loop block by creating the PHIs, doing the vector
406de5ff38aSMin-Yih Hsu   // loads and comparing the vectors.
407de5ff38aSMin-Yih Hsu   Builder.SetInsertPoint(VectorLoopStartBlock);
408de5ff38aSMin-Yih Hsu   PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
409de5ff38aSMin-Yih Hsu   LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
410de5ff38aSMin-Yih Hsu   PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
411de5ff38aSMin-Yih Hsu   VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
4128b55d342SMin-Yih Hsu   Type *VectorLoadType =
4138b55d342SMin-Yih Hsu       ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF);
414de5ff38aSMin-Yih Hsu   Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
415de5ff38aSMin-Yih Hsu 
416de5ff38aSMin-Yih Hsu   Value *VectorLhsGep =
417de5ff38aSMin-Yih Hsu       Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
418de5ff38aSMin-Yih Hsu   Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
419de5ff38aSMin-Yih Hsu                                                   Align(1), LoopPred, Passthru);
420de5ff38aSMin-Yih Hsu 
421de5ff38aSMin-Yih Hsu   Value *VectorRhsGep =
422de5ff38aSMin-Yih Hsu       Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
423de5ff38aSMin-Yih Hsu   Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
424de5ff38aSMin-Yih Hsu                                                   Align(1), LoopPred, Passthru);
425de5ff38aSMin-Yih Hsu 
426de5ff38aSMin-Yih Hsu   Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
427de5ff38aSMin-Yih Hsu   VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
428de5ff38aSMin-Yih Hsu   Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
429de5ff38aSMin-Yih Hsu   BranchInst *VectorEarlyExit = BranchInst::Create(
430de5ff38aSMin-Yih Hsu       VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
431de5ff38aSMin-Yih Hsu   Builder.Insert(VectorEarlyExit);
432de5ff38aSMin-Yih Hsu 
433de5ff38aSMin-Yih Hsu   DTU.applyUpdates(
434de5ff38aSMin-Yih Hsu       {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
435de5ff38aSMin-Yih Hsu        {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
436de5ff38aSMin-Yih Hsu 
437de5ff38aSMin-Yih Hsu   // Increment the index counter and calculate the predicate for the next
438de5ff38aSMin-Yih Hsu   // iteration of the loop. We branch back to the start of the loop if there
439de5ff38aSMin-Yih Hsu   // is at least one active lane.
440de5ff38aSMin-Yih Hsu   Builder.SetInsertPoint(VectorLoopIncBlock);
441de5ff38aSMin-Yih Hsu   Value *NewVectorIndexPhi =
442de5ff38aSMin-Yih Hsu       Builder.CreateAdd(VectorIndexPhi, VecLen, "",
443de5ff38aSMin-Yih Hsu                         /*HasNUW=*/true, /*HasNSW=*/true);
444de5ff38aSMin-Yih Hsu   VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
445de5ff38aSMin-Yih Hsu   Value *NewPred =
446de5ff38aSMin-Yih Hsu       Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
447de5ff38aSMin-Yih Hsu                               {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
448de5ff38aSMin-Yih Hsu   LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
449de5ff38aSMin-Yih Hsu 
450de5ff38aSMin-Yih Hsu   Value *PredHasActiveLanes =
451de5ff38aSMin-Yih Hsu       Builder.CreateExtractElement(NewPred, uint64_t(0));
452de5ff38aSMin-Yih Hsu   BranchInst *VectorLoopBranchBack =
453de5ff38aSMin-Yih Hsu       BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
454de5ff38aSMin-Yih Hsu   Builder.Insert(VectorLoopBranchBack);
455de5ff38aSMin-Yih Hsu 
456de5ff38aSMin-Yih Hsu   DTU.applyUpdates(
457de5ff38aSMin-Yih Hsu       {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
458de5ff38aSMin-Yih Hsu        {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
459de5ff38aSMin-Yih Hsu 
460de5ff38aSMin-Yih Hsu   // If we found a mismatch then we need to calculate which lane in the vector
461de5ff38aSMin-Yih Hsu   // had a mismatch and add that on to the current loop index.
462de5ff38aSMin-Yih Hsu   Builder.SetInsertPoint(VectorLoopMismatchBlock);
463de5ff38aSMin-Yih Hsu   PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
464de5ff38aSMin-Yih Hsu   FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
465de5ff38aSMin-Yih Hsu   PHINode *LastLoopPred =
466de5ff38aSMin-Yih Hsu       Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
467de5ff38aSMin-Yih Hsu   LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
468de5ff38aSMin-Yih Hsu   PHINode *VectorFoundIndex =
469de5ff38aSMin-Yih Hsu       Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
470de5ff38aSMin-Yih Hsu   VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
471de5ff38aSMin-Yih Hsu 
472de5ff38aSMin-Yih Hsu   Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
473*dc6c3ba4SDavid Sherwood   Value *Ctz = Builder.CreateCountTrailingZeroElems(ResType, PredMatchCmp);
474de5ff38aSMin-Yih Hsu   Ctz = Builder.CreateZExt(Ctz, I64Type);
475de5ff38aSMin-Yih Hsu   Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
476de5ff38aSMin-Yih Hsu                                              /*HasNUW=*/true, /*HasNSW=*/true);
477de5ff38aSMin-Yih Hsu   return Builder.CreateTrunc(VectorLoopRes64, ResType);
478de5ff38aSMin-Yih Hsu }
479de5ff38aSMin-Yih Hsu 
4808b55d342SMin-Yih Hsu Value *LoopIdiomVectorize::createPredicatedFindMismatch(
4818b55d342SMin-Yih Hsu     IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
4828b55d342SMin-Yih Hsu     GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
4838b55d342SMin-Yih Hsu   Type *I64Type = Builder.getInt64Ty();
4848b55d342SMin-Yih Hsu   Type *I32Type = Builder.getInt32Ty();
4858b55d342SMin-Yih Hsu   Type *ResType = I32Type;
4868b55d342SMin-Yih Hsu   Type *LoadType = Builder.getInt8Ty();
4878b55d342SMin-Yih Hsu   Value *PtrA = GEPA->getPointerOperand();
4888b55d342SMin-Yih Hsu   Value *PtrB = GEPB->getPointerOperand();
4898b55d342SMin-Yih Hsu 
4908b55d342SMin-Yih Hsu   auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
4918b55d342SMin-Yih Hsu   Builder.Insert(JumpToVectorLoop);
4928b55d342SMin-Yih Hsu 
4938b55d342SMin-Yih Hsu   DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
4948b55d342SMin-Yih Hsu                      VectorLoopStartBlock}});
4958b55d342SMin-Yih Hsu 
4968b55d342SMin-Yih Hsu   // Set up the first Vector loop block by creating the PHIs, doing the vector
4978b55d342SMin-Yih Hsu   // loads and comparing the vectors.
4988b55d342SMin-Yih Hsu   Builder.SetInsertPoint(VectorLoopStartBlock);
4998b55d342SMin-Yih Hsu   auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index");
5008b55d342SMin-Yih Hsu   VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
5018b55d342SMin-Yih Hsu 
5028b55d342SMin-Yih Hsu   // Calculate AVL by subtracting the vector loop index from the trip count
5038b55d342SMin-Yih Hsu   Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true,
5048b55d342SMin-Yih Hsu                                  /*HasNSW=*/true);
5058b55d342SMin-Yih Hsu 
5068b55d342SMin-Yih Hsu   auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF);
5078b55d342SMin-Yih Hsu   auto *VF = ConstantInt::get(I32Type, ByteCompareVF);
5088b55d342SMin-Yih Hsu 
5098b55d342SMin-Yih Hsu   Value *VL = Builder.CreateIntrinsic(Intrinsic::experimental_get_vector_length,
5108b55d342SMin-Yih Hsu                                       {I64Type}, {AVL, VF, Builder.getTrue()});
5118b55d342SMin-Yih Hsu   Value *GepOffset = VectorIndexPhi;
5128b55d342SMin-Yih Hsu 
5138b55d342SMin-Yih Hsu   Value *VectorLhsGep =
5148b55d342SMin-Yih Hsu       Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
5158b55d342SMin-Yih Hsu   VectorType *TrueMaskTy =
5168b55d342SMin-Yih Hsu       VectorType::get(Builder.getInt1Ty(), VectorLoadType->getElementCount());
5178b55d342SMin-Yih Hsu   Value *AllTrueMask = Constant::getAllOnesValue(TrueMaskTy);
5188b55d342SMin-Yih Hsu   Value *VectorLhsLoad = Builder.CreateIntrinsic(
5198b55d342SMin-Yih Hsu       Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
5208b55d342SMin-Yih Hsu       {VectorLhsGep, AllTrueMask, VL}, nullptr, "lhs.load");
5218b55d342SMin-Yih Hsu 
5228b55d342SMin-Yih Hsu   Value *VectorRhsGep =
5238b55d342SMin-Yih Hsu       Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
5248b55d342SMin-Yih Hsu   Value *VectorRhsLoad = Builder.CreateIntrinsic(
5258b55d342SMin-Yih Hsu       Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
5268b55d342SMin-Yih Hsu       {VectorRhsGep, AllTrueMask, VL}, nullptr, "rhs.load");
5278b55d342SMin-Yih Hsu 
5288b55d342SMin-Yih Hsu   StringRef PredicateStr = CmpInst::getPredicateName(CmpInst::ICMP_NE);
5298b55d342SMin-Yih Hsu   auto *PredicateMDS = MDString::get(VectorLhsLoad->getContext(), PredicateStr);
5308b55d342SMin-Yih Hsu   Value *Pred = MetadataAsValue::get(VectorLhsLoad->getContext(), PredicateMDS);
5318b55d342SMin-Yih Hsu   Value *VectorMatchCmp = Builder.CreateIntrinsic(
5328b55d342SMin-Yih Hsu       Intrinsic::vp_icmp, {VectorLhsLoad->getType()},
5338b55d342SMin-Yih Hsu       {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr,
5348b55d342SMin-Yih Hsu       "mismatch.cmp");
5358b55d342SMin-Yih Hsu   Value *CTZ = Builder.CreateIntrinsic(
5368b55d342SMin-Yih Hsu       Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType()},
5378b55d342SMin-Yih Hsu       {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(false), AllTrueMask,
5388b55d342SMin-Yih Hsu        VL});
5398b55d342SMin-Yih Hsu   Value *MismatchFound = Builder.CreateICmpNE(CTZ, VL);
5408b55d342SMin-Yih Hsu   auto *VectorEarlyExit = BranchInst::Create(VectorLoopMismatchBlock,
5418b55d342SMin-Yih Hsu                                              VectorLoopIncBlock, MismatchFound);
5428b55d342SMin-Yih Hsu   Builder.Insert(VectorEarlyExit);
5438b55d342SMin-Yih Hsu 
5448b55d342SMin-Yih Hsu   DTU.applyUpdates(
5458b55d342SMin-Yih Hsu       {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
5468b55d342SMin-Yih Hsu        {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
5478b55d342SMin-Yih Hsu 
5488b55d342SMin-Yih Hsu   // Increment the index counter and calculate the predicate for the next
5498b55d342SMin-Yih Hsu   // iteration of the loop. We branch back to the start of the loop if there
5508b55d342SMin-Yih Hsu   // is at least one active lane.
5518b55d342SMin-Yih Hsu   Builder.SetInsertPoint(VectorLoopIncBlock);
5528b55d342SMin-Yih Hsu   Value *VL64 = Builder.CreateZExt(VL, I64Type);
5538b55d342SMin-Yih Hsu   Value *NewVectorIndexPhi =
5548b55d342SMin-Yih Hsu       Builder.CreateAdd(VectorIndexPhi, VL64, "",
5558b55d342SMin-Yih Hsu                         /*HasNUW=*/true, /*HasNSW=*/true);
5568b55d342SMin-Yih Hsu   VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
5578b55d342SMin-Yih Hsu   Value *ExitCond = Builder.CreateICmpNE(NewVectorIndexPhi, ExtEnd);
5588b55d342SMin-Yih Hsu   auto *VectorLoopBranchBack =
5598b55d342SMin-Yih Hsu       BranchInst::Create(VectorLoopStartBlock, EndBlock, ExitCond);
5608b55d342SMin-Yih Hsu   Builder.Insert(VectorLoopBranchBack);
5618b55d342SMin-Yih Hsu 
5628b55d342SMin-Yih Hsu   DTU.applyUpdates(
5638b55d342SMin-Yih Hsu       {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
5648b55d342SMin-Yih Hsu        {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
5658b55d342SMin-Yih Hsu 
5668b55d342SMin-Yih Hsu   // If we found a mismatch then we need to calculate which lane in the vector
5678b55d342SMin-Yih Hsu   // had a mismatch and add that on to the current loop index.
5688b55d342SMin-Yih Hsu   Builder.SetInsertPoint(VectorLoopMismatchBlock);
5698b55d342SMin-Yih Hsu 
5708b55d342SMin-Yih Hsu   // Add LCSSA phis for CTZ and VectorIndexPhi.
5718b55d342SMin-Yih Hsu   auto *CTZLCSSAPhi = Builder.CreatePHI(CTZ->getType(), 1, "ctz");
5728b55d342SMin-Yih Hsu   CTZLCSSAPhi->addIncoming(CTZ, VectorLoopStartBlock);
5738b55d342SMin-Yih Hsu   auto *VectorIndexLCSSAPhi =
5748b55d342SMin-Yih Hsu       Builder.CreatePHI(VectorIndexPhi->getType(), 1, "mismatch_vector_index");
5758b55d342SMin-Yih Hsu   VectorIndexLCSSAPhi->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
5768b55d342SMin-Yih Hsu 
5778b55d342SMin-Yih Hsu   Value *CTZI64 = Builder.CreateZExt(CTZLCSSAPhi, I64Type);
5788b55d342SMin-Yih Hsu   Value *VectorLoopRes64 = Builder.CreateAdd(VectorIndexLCSSAPhi, CTZI64, "",
5798b55d342SMin-Yih Hsu                                              /*HasNUW=*/true, /*HasNSW=*/true);
5808b55d342SMin-Yih Hsu   return Builder.CreateTrunc(VectorLoopRes64, ResType);
5818b55d342SMin-Yih Hsu }
5828b55d342SMin-Yih Hsu 
58337e309f1SMin-Yih Hsu Value *LoopIdiomVectorize::expandFindMismatch(
58437e309f1SMin-Yih Hsu     IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
58537e309f1SMin-Yih Hsu     GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
58637e309f1SMin-Yih Hsu   Value *PtrA = GEPA->getPointerOperand();
58737e309f1SMin-Yih Hsu   Value *PtrB = GEPB->getPointerOperand();
58837e309f1SMin-Yih Hsu 
58937e309f1SMin-Yih Hsu   // Get the arguments and types for the intrinsic.
59037e309f1SMin-Yih Hsu   BasicBlock *Preheader = CurLoop->getLoopPreheader();
59137e309f1SMin-Yih Hsu   BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
59237e309f1SMin-Yih Hsu   LLVMContext &Ctx = PHBranch->getContext();
59337e309f1SMin-Yih Hsu   Type *LoadType = Type::getInt8Ty(Ctx);
59437e309f1SMin-Yih Hsu   Type *ResType = Builder.getInt32Ty();
59537e309f1SMin-Yih Hsu 
59637e309f1SMin-Yih Hsu   // Split block in the original loop preheader.
597de5ff38aSMin-Yih Hsu   EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
59837e309f1SMin-Yih Hsu 
59937e309f1SMin-Yih Hsu   // Create the blocks that we're going to need:
60037e309f1SMin-Yih Hsu   //  1. A block for checking the zero-extended length exceeds 0
60137e309f1SMin-Yih Hsu   //  2. A block to check that the start and end addresses of a given array
60237e309f1SMin-Yih Hsu   //     lie on the same page.
60337e309f1SMin-Yih Hsu   //  3. The vector loop preheader.
60437e309f1SMin-Yih Hsu   //  4. The first vector loop block.
60537e309f1SMin-Yih Hsu   //  5. The vector loop increment block.
60637e309f1SMin-Yih Hsu   //  6. A block we can jump to from the vector loop when a mismatch is found.
60737e309f1SMin-Yih Hsu   //  7. The first block of the scalar loop itself, containing PHIs , loads
60837e309f1SMin-Yih Hsu   //  and cmp.
60937e309f1SMin-Yih Hsu   //  8. A scalar loop increment block to increment the PHIs and go back
61037e309f1SMin-Yih Hsu   //  around the loop.
61137e309f1SMin-Yih Hsu 
61237e309f1SMin-Yih Hsu   BasicBlock *MinItCheckBlock = BasicBlock::Create(
61337e309f1SMin-Yih Hsu       Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);
61437e309f1SMin-Yih Hsu 
61537e309f1SMin-Yih Hsu   // Update the terminator added by SplitBlock to branch to the first block
61637e309f1SMin-Yih Hsu   Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);
61737e309f1SMin-Yih Hsu 
61837e309f1SMin-Yih Hsu   BasicBlock *MemCheckBlock = BasicBlock::Create(
61937e309f1SMin-Yih Hsu       Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);
62037e309f1SMin-Yih Hsu 
621de5ff38aSMin-Yih Hsu   VectorLoopPreheaderBlock = BasicBlock::Create(
62237e309f1SMin-Yih Hsu       Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock);
62337e309f1SMin-Yih Hsu 
624de5ff38aSMin-Yih Hsu   VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop",
625de5ff38aSMin-Yih Hsu                                             EndBlock->getParent(), EndBlock);
62637e309f1SMin-Yih Hsu 
627de5ff38aSMin-Yih Hsu   VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc",
628de5ff38aSMin-Yih Hsu                                           EndBlock->getParent(), EndBlock);
62937e309f1SMin-Yih Hsu 
630de5ff38aSMin-Yih Hsu   VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found",
631de5ff38aSMin-Yih Hsu                                                EndBlock->getParent(), EndBlock);
63237e309f1SMin-Yih Hsu 
63337e309f1SMin-Yih Hsu   BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
63437e309f1SMin-Yih Hsu       Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);
63537e309f1SMin-Yih Hsu 
63637e309f1SMin-Yih Hsu   BasicBlock *LoopStartBlock =
63737e309f1SMin-Yih Hsu       BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);
63837e309f1SMin-Yih Hsu 
63937e309f1SMin-Yih Hsu   BasicBlock *LoopIncBlock = BasicBlock::Create(
64037e309f1SMin-Yih Hsu       Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);
64137e309f1SMin-Yih Hsu 
64237e309f1SMin-Yih Hsu   DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
64337e309f1SMin-Yih Hsu                     {DominatorTree::Delete, Preheader, EndBlock}});
64437e309f1SMin-Yih Hsu 
64537e309f1SMin-Yih Hsu   // Update LoopInfo with the new vector & scalar loops.
64637e309f1SMin-Yih Hsu   auto VectorLoop = LI->AllocateLoop();
64737e309f1SMin-Yih Hsu   auto ScalarLoop = LI->AllocateLoop();
64837e309f1SMin-Yih Hsu 
64937e309f1SMin-Yih Hsu   if (CurLoop->getParentLoop()) {
65037e309f1SMin-Yih Hsu     CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI);
65137e309f1SMin-Yih Hsu     CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI);
65237e309f1SMin-Yih Hsu     CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopPreheaderBlock,
65337e309f1SMin-Yih Hsu                                                   *LI);
65437e309f1SMin-Yih Hsu     CurLoop->getParentLoop()->addChildLoop(VectorLoop);
65537e309f1SMin-Yih Hsu     CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopMismatchBlock, *LI);
65637e309f1SMin-Yih Hsu     CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
65737e309f1SMin-Yih Hsu     CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
65837e309f1SMin-Yih Hsu   } else {
65937e309f1SMin-Yih Hsu     LI->addTopLevelLoop(VectorLoop);
66037e309f1SMin-Yih Hsu     LI->addTopLevelLoop(ScalarLoop);
66137e309f1SMin-Yih Hsu   }
66237e309f1SMin-Yih Hsu 
66337e309f1SMin-Yih Hsu   // Add the new basic blocks to their associated loops.
66437e309f1SMin-Yih Hsu   VectorLoop->addBasicBlockToLoop(VectorLoopStartBlock, *LI);
66537e309f1SMin-Yih Hsu   VectorLoop->addBasicBlockToLoop(VectorLoopIncBlock, *LI);
66637e309f1SMin-Yih Hsu 
66737e309f1SMin-Yih Hsu   ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
66837e309f1SMin-Yih Hsu   ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);
66937e309f1SMin-Yih Hsu 
67037e309f1SMin-Yih Hsu   // Set up some types and constants that we intend to reuse.
67137e309f1SMin-Yih Hsu   Type *I64Type = Builder.getInt64Ty();
67237e309f1SMin-Yih Hsu 
67337e309f1SMin-Yih Hsu   // Check the zero-extended iteration count > 0
67437e309f1SMin-Yih Hsu   Builder.SetInsertPoint(MinItCheckBlock);
67537e309f1SMin-Yih Hsu   Value *ExtStart = Builder.CreateZExt(Start, I64Type);
67637e309f1SMin-Yih Hsu   Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
67737e309f1SMin-Yih Hsu   // This check doesn't really cost us very much.
67837e309f1SMin-Yih Hsu 
67937e309f1SMin-Yih Hsu   Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
68037e309f1SMin-Yih Hsu   BranchInst *MinItCheckBr =
68137e309f1SMin-Yih Hsu       BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
68237e309f1SMin-Yih Hsu   MinItCheckBr->setMetadata(
68337e309f1SMin-Yih Hsu       LLVMContext::MD_prof,
68437e309f1SMin-Yih Hsu       MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
68537e309f1SMin-Yih Hsu   Builder.Insert(MinItCheckBr);
68637e309f1SMin-Yih Hsu 
68737e309f1SMin-Yih Hsu   DTU.applyUpdates(
68837e309f1SMin-Yih Hsu       {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
68937e309f1SMin-Yih Hsu        {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});
69037e309f1SMin-Yih Hsu 
69137e309f1SMin-Yih Hsu   // For each of the arrays, check the start/end addresses are on the same
69237e309f1SMin-Yih Hsu   // page.
69337e309f1SMin-Yih Hsu   Builder.SetInsertPoint(MemCheckBlock);
69437e309f1SMin-Yih Hsu 
69537e309f1SMin-Yih Hsu   // The early exit in the original loop means that when performing vector
69637e309f1SMin-Yih Hsu   // loads we are potentially reading ahead of the early exit. So we could
69737e309f1SMin-Yih Hsu   // fault if crossing a page boundary. Therefore, we create runtime memory
69837e309f1SMin-Yih Hsu   // checks based on the minimum page size as follows:
69937e309f1SMin-Yih Hsu   //   1. Calculate the addresses of the first memory accesses in the loop,
70037e309f1SMin-Yih Hsu   //      i.e. LhsStart and RhsStart.
70137e309f1SMin-Yih Hsu   //   2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
70237e309f1SMin-Yih Hsu   //   3. Determine which pages correspond to all the memory accesses, i.e
70337e309f1SMin-Yih Hsu   //      LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
70437e309f1SMin-Yih Hsu   //   4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
70537e309f1SMin-Yih Hsu   //      we know we won't cross any page boundaries in the loop so we can
70637e309f1SMin-Yih Hsu   //      enter the vector loop! Otherwise we fall back on the scalar loop.
70737e309f1SMin-Yih Hsu   Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart);
70837e309f1SMin-Yih Hsu   Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart);
70937e309f1SMin-Yih Hsu   Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type);
71037e309f1SMin-Yih Hsu   Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type);
71137e309f1SMin-Yih Hsu   Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd);
71237e309f1SMin-Yih Hsu   Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd);
71337e309f1SMin-Yih Hsu   Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type);
71437e309f1SMin-Yih Hsu   Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type);
71537e309f1SMin-Yih Hsu 
71637e309f1SMin-Yih Hsu   const uint64_t MinPageSize = TTI->getMinPageSize().value();
71737e309f1SMin-Yih Hsu   const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
71837e309f1SMin-Yih Hsu   Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt);
71937e309f1SMin-Yih Hsu   Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt);
72037e309f1SMin-Yih Hsu   Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt);
72137e309f1SMin-Yih Hsu   Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt);
72237e309f1SMin-Yih Hsu   Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage);
72337e309f1SMin-Yih Hsu   Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage);
72437e309f1SMin-Yih Hsu 
72537e309f1SMin-Yih Hsu   Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp);
72637e309f1SMin-Yih Hsu   BranchInst *CombinedPageCmpCmpBr = BranchInst::Create(
72737e309f1SMin-Yih Hsu       LoopPreHeaderBlock, VectorLoopPreheaderBlock, CombinedPageCmp);
72837e309f1SMin-Yih Hsu   CombinedPageCmpCmpBr->setMetadata(
72937e309f1SMin-Yih Hsu       LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext())
73037e309f1SMin-Yih Hsu                                 .createBranchWeights(10, 90));
73137e309f1SMin-Yih Hsu   Builder.Insert(CombinedPageCmpCmpBr);
73237e309f1SMin-Yih Hsu 
73337e309f1SMin-Yih Hsu   DTU.applyUpdates(
73437e309f1SMin-Yih Hsu       {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
73537e309f1SMin-Yih Hsu        {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}});
73637e309f1SMin-Yih Hsu 
73737e309f1SMin-Yih Hsu   // Set up the vector loop preheader, i.e. calculate initial loop predicate,
73837e309f1SMin-Yih Hsu   // zero-extend MaxLen to 64-bits, determine the number of vector elements
73937e309f1SMin-Yih Hsu   // processed in each iteration, etc.
74037e309f1SMin-Yih Hsu   Builder.SetInsertPoint(VectorLoopPreheaderBlock);
74137e309f1SMin-Yih Hsu 
7428b55d342SMin-Yih Hsu   // At this point we know two things must be true:
7438b55d342SMin-Yih Hsu   //  1. Start <= End
7448b55d342SMin-Yih Hsu   //  2. ExtMaxLen <= MinPageSize due to the page checks.
7458b55d342SMin-Yih Hsu   // Therefore, we know that we can use a 64-bit induction variable that
7468b55d342SMin-Yih Hsu   // starts from 0 -> ExtMaxLen and it will not overflow.
7478b55d342SMin-Yih Hsu   Value *VectorLoopRes = nullptr;
7488b55d342SMin-Yih Hsu   switch (VectorizeStyle) {
7498b55d342SMin-Yih Hsu   case LoopIdiomVectorizeStyle::Masked:
7508b55d342SMin-Yih Hsu     VectorLoopRes =
751de5ff38aSMin-Yih Hsu         createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
7528b55d342SMin-Yih Hsu     break;
7538b55d342SMin-Yih Hsu   case LoopIdiomVectorizeStyle::Predicated:
7548b55d342SMin-Yih Hsu     VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB,
7558b55d342SMin-Yih Hsu                                                  ExtStart, ExtEnd);
7568b55d342SMin-Yih Hsu     break;
7578b55d342SMin-Yih Hsu   }
75837e309f1SMin-Yih Hsu 
75937e309f1SMin-Yih Hsu   Builder.Insert(BranchInst::Create(EndBlock));
76037e309f1SMin-Yih Hsu 
76137e309f1SMin-Yih Hsu   DTU.applyUpdates(
76237e309f1SMin-Yih Hsu       {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}});
76337e309f1SMin-Yih Hsu 
76437e309f1SMin-Yih Hsu   // Generate code for scalar loop.
76537e309f1SMin-Yih Hsu   Builder.SetInsertPoint(LoopPreHeaderBlock);
76637e309f1SMin-Yih Hsu   Builder.Insert(BranchInst::Create(LoopStartBlock));
76737e309f1SMin-Yih Hsu 
76837e309f1SMin-Yih Hsu   DTU.applyUpdates(
76937e309f1SMin-Yih Hsu       {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});
77037e309f1SMin-Yih Hsu 
77137e309f1SMin-Yih Hsu   Builder.SetInsertPoint(LoopStartBlock);
77237e309f1SMin-Yih Hsu   PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
77337e309f1SMin-Yih Hsu   IndexPhi->addIncoming(Start, LoopPreHeaderBlock);
77437e309f1SMin-Yih Hsu 
77537e309f1SMin-Yih Hsu   // Otherwise compare the values
77637e309f1SMin-Yih Hsu   // Load bytes from each array and compare them.
77737e309f1SMin-Yih Hsu   Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type);
77837e309f1SMin-Yih Hsu 
77937e309f1SMin-Yih Hsu   Value *LhsGep =
78037e309f1SMin-Yih Hsu       Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
78137e309f1SMin-Yih Hsu   Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep);
78237e309f1SMin-Yih Hsu 
78337e309f1SMin-Yih Hsu   Value *RhsGep =
78437e309f1SMin-Yih Hsu       Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
78537e309f1SMin-Yih Hsu   Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);
78637e309f1SMin-Yih Hsu 
78737e309f1SMin-Yih Hsu   Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
78837e309f1SMin-Yih Hsu   // If we have a mismatch then exit the loop ...
78937e309f1SMin-Yih Hsu   BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
79037e309f1SMin-Yih Hsu   Builder.Insert(MatchCmpBr);
79137e309f1SMin-Yih Hsu 
79237e309f1SMin-Yih Hsu   DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
79337e309f1SMin-Yih Hsu                     {DominatorTree::Insert, LoopStartBlock, EndBlock}});
79437e309f1SMin-Yih Hsu 
79537e309f1SMin-Yih Hsu   // Have we reached the maximum permitted length for the loop?
79637e309f1SMin-Yih Hsu   Builder.SetInsertPoint(LoopIncBlock);
79737e309f1SMin-Yih Hsu   Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "",
79837e309f1SMin-Yih Hsu                                     /*HasNUW=*/Index->hasNoUnsignedWrap(),
79937e309f1SMin-Yih Hsu                                     /*HasNSW=*/Index->hasNoSignedWrap());
80037e309f1SMin-Yih Hsu   IndexPhi->addIncoming(PhiInc, LoopIncBlock);
80137e309f1SMin-Yih Hsu   Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen);
80237e309f1SMin-Yih Hsu   BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp);
80337e309f1SMin-Yih Hsu   Builder.Insert(IVCmpBr);
80437e309f1SMin-Yih Hsu 
80537e309f1SMin-Yih Hsu   DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock},
80637e309f1SMin-Yih Hsu                     {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});
80737e309f1SMin-Yih Hsu 
80837e309f1SMin-Yih Hsu   // In the end block we need to insert a PHI node to deal with three cases:
80937e309f1SMin-Yih Hsu   //  1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
81037e309f1SMin-Yih Hsu   //  2. We exitted the scalar loop early due to a mismatch and need to return
81137e309f1SMin-Yih Hsu   //  the index that we found.
81237e309f1SMin-Yih Hsu   //  3. We didn't find a mismatch in the vector loop, so we return MaxLen.
81337e309f1SMin-Yih Hsu   //  4. We exitted the vector loop early due to a mismatch and need to return
81437e309f1SMin-Yih Hsu   //  the index that we found.
815d75f9dd1SStephen Tozer   Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt());
81637e309f1SMin-Yih Hsu   PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result");
81737e309f1SMin-Yih Hsu   ResPhi->addIncoming(MaxLen, LoopIncBlock);
81837e309f1SMin-Yih Hsu   ResPhi->addIncoming(IndexPhi, LoopStartBlock);
81937e309f1SMin-Yih Hsu   ResPhi->addIncoming(MaxLen, VectorLoopIncBlock);
82037e309f1SMin-Yih Hsu   ResPhi->addIncoming(VectorLoopRes, VectorLoopMismatchBlock);
82137e309f1SMin-Yih Hsu 
82237e309f1SMin-Yih Hsu   Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType);
82337e309f1SMin-Yih Hsu 
82437e309f1SMin-Yih Hsu   if (VerifyLoops) {
82537e309f1SMin-Yih Hsu     ScalarLoop->verifyLoop();
82637e309f1SMin-Yih Hsu     VectorLoop->verifyLoop();
82737e309f1SMin-Yih Hsu     if (!VectorLoop->isRecursivelyLCSSAForm(*DT, *LI))
82837e309f1SMin-Yih Hsu       report_fatal_error("Loops must remain in LCSSA form!");
82937e309f1SMin-Yih Hsu     if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
83037e309f1SMin-Yih Hsu       report_fatal_error("Loops must remain in LCSSA form!");
83137e309f1SMin-Yih Hsu   }
83237e309f1SMin-Yih Hsu 
83337e309f1SMin-Yih Hsu   return FinalRes;
83437e309f1SMin-Yih Hsu }
83537e309f1SMin-Yih Hsu 
83637e309f1SMin-Yih Hsu void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
83737e309f1SMin-Yih Hsu                                               GetElementPtrInst *GEPB,
83837e309f1SMin-Yih Hsu                                               PHINode *IndPhi, Value *MaxLen,
83937e309f1SMin-Yih Hsu                                               Instruction *Index, Value *Start,
84037e309f1SMin-Yih Hsu                                               bool IncIdx, BasicBlock *FoundBB,
84137e309f1SMin-Yih Hsu                                               BasicBlock *EndBB) {
84237e309f1SMin-Yih Hsu 
84337e309f1SMin-Yih Hsu   // Insert the byte compare code at the end of the preheader block
84437e309f1SMin-Yih Hsu   BasicBlock *Preheader = CurLoop->getLoopPreheader();
84537e309f1SMin-Yih Hsu   BasicBlock *Header = CurLoop->getHeader();
84637e309f1SMin-Yih Hsu   BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
84737e309f1SMin-Yih Hsu   IRBuilder<> Builder(PHBranch);
84837e309f1SMin-Yih Hsu   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
84937e309f1SMin-Yih Hsu   Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
85037e309f1SMin-Yih Hsu 
85137e309f1SMin-Yih Hsu   // Increment the pointer if this was done before the loads in the loop.
85237e309f1SMin-Yih Hsu   if (IncIdx)
85337e309f1SMin-Yih Hsu     Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));
85437e309f1SMin-Yih Hsu 
85537e309f1SMin-Yih Hsu   Value *ByteCmpRes =
85637e309f1SMin-Yih Hsu       expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);
85737e309f1SMin-Yih Hsu 
85837e309f1SMin-Yih Hsu   // Replaces uses of index & induction Phi with intrinsic (we already
85937e309f1SMin-Yih Hsu   // checked that the the first instruction of Header is the Phi above).
86037e309f1SMin-Yih Hsu   assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
86137e309f1SMin-Yih Hsu   Index->replaceAllUsesWith(ByteCmpRes);
86237e309f1SMin-Yih Hsu 
86337e309f1SMin-Yih Hsu   assert(PHBranch->isUnconditional() &&
86437e309f1SMin-Yih Hsu          "Expected preheader to terminate with an unconditional branch.");
86537e309f1SMin-Yih Hsu 
86637e309f1SMin-Yih Hsu   // If no mismatch was found, we can jump to the end block. Create a
86737e309f1SMin-Yih Hsu   // new basic block for the compare instruction.
86837e309f1SMin-Yih Hsu   auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare",
86937e309f1SMin-Yih Hsu                                    Preheader->getParent());
87037e309f1SMin-Yih Hsu   CmpBB->moveBefore(EndBB);
87137e309f1SMin-Yih Hsu 
87237e309f1SMin-Yih Hsu   // Replace the branch in the preheader with an always-true conditional branch.
87337e309f1SMin-Yih Hsu   // This ensures there is still a reference to the original loop.
87437e309f1SMin-Yih Hsu   Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header);
87537e309f1SMin-Yih Hsu   PHBranch->eraseFromParent();
87637e309f1SMin-Yih Hsu 
87737e309f1SMin-Yih Hsu   BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent();
87837e309f1SMin-Yih Hsu   DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}});
87937e309f1SMin-Yih Hsu 
88037e309f1SMin-Yih Hsu   // Create the branch to either the end or found block depending on the value
88137e309f1SMin-Yih Hsu   // returned by the intrinsic.
88237e309f1SMin-Yih Hsu   Builder.SetInsertPoint(CmpBB);
88337e309f1SMin-Yih Hsu   if (FoundBB != EndBB) {
88437e309f1SMin-Yih Hsu     Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen);
88537e309f1SMin-Yih Hsu     Builder.CreateCondBr(FoundCmp, EndBB, FoundBB);
88637e309f1SMin-Yih Hsu     DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB},
88737e309f1SMin-Yih Hsu                       {DominatorTree::Insert, CmpBB, EndBB}});
88837e309f1SMin-Yih Hsu 
88937e309f1SMin-Yih Hsu   } else {
89037e309f1SMin-Yih Hsu     Builder.CreateBr(FoundBB);
89137e309f1SMin-Yih Hsu     DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}});
89237e309f1SMin-Yih Hsu   }
89337e309f1SMin-Yih Hsu 
89437e309f1SMin-Yih Hsu   auto fixSuccessorPhis = [&](BasicBlock *SuccBB) {
89537e309f1SMin-Yih Hsu     for (PHINode &PN : SuccBB->phis()) {
89637e309f1SMin-Yih Hsu       // At this point we've already replaced all uses of the result from the
89737e309f1SMin-Yih Hsu       // loop with ByteCmp. Look through the incoming values to find ByteCmp,
89837e309f1SMin-Yih Hsu       // meaning this is a Phi collecting the results of the byte compare.
89937e309f1SMin-Yih Hsu       bool ResPhi = false;
90037e309f1SMin-Yih Hsu       for (Value *Op : PN.incoming_values())
90137e309f1SMin-Yih Hsu         if (Op == ByteCmpRes) {
90237e309f1SMin-Yih Hsu           ResPhi = true;
90337e309f1SMin-Yih Hsu           break;
90437e309f1SMin-Yih Hsu         }
90537e309f1SMin-Yih Hsu 
90637e309f1SMin-Yih Hsu       // Any PHI that depended upon the result of the byte compare needs a new
90737e309f1SMin-Yih Hsu       // incoming value from CmpBB. This is because the original loop will get
90837e309f1SMin-Yih Hsu       // deleted.
90937e309f1SMin-Yih Hsu       if (ResPhi)
91037e309f1SMin-Yih Hsu         PN.addIncoming(ByteCmpRes, CmpBB);
91137e309f1SMin-Yih Hsu       else {
91237e309f1SMin-Yih Hsu         // There should be no other outside uses of other values in the
91337e309f1SMin-Yih Hsu         // original loop. Any incoming values should either:
91437e309f1SMin-Yih Hsu         //   1. Be for blocks outside the loop, which aren't interesting. Or ..
91537e309f1SMin-Yih Hsu         //   2. These are from blocks in the loop with values defined outside
91637e309f1SMin-Yih Hsu         //      the loop. We should a similar incoming value from CmpBB.
91737e309f1SMin-Yih Hsu         for (BasicBlock *BB : PN.blocks())
91837e309f1SMin-Yih Hsu           if (CurLoop->contains(BB)) {
91937e309f1SMin-Yih Hsu             PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB);
92037e309f1SMin-Yih Hsu             break;
92137e309f1SMin-Yih Hsu           }
92237e309f1SMin-Yih Hsu       }
92337e309f1SMin-Yih Hsu     }
92437e309f1SMin-Yih Hsu   };
92537e309f1SMin-Yih Hsu 
92637e309f1SMin-Yih Hsu   // Ensure all Phis in the successors of CmpBB have an incoming value from it.
92737e309f1SMin-Yih Hsu   fixSuccessorPhis(EndBB);
92837e309f1SMin-Yih Hsu   if (EndBB != FoundBB)
92937e309f1SMin-Yih Hsu     fixSuccessorPhis(FoundBB);
93037e309f1SMin-Yih Hsu 
93137e309f1SMin-Yih Hsu   // The new CmpBB block isn't part of the loop, but will need to be added to
93237e309f1SMin-Yih Hsu   // the outer loop if there is one.
93337e309f1SMin-Yih Hsu   if (!CurLoop->isOutermost())
93437e309f1SMin-Yih Hsu     CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI);
93537e309f1SMin-Yih Hsu 
93637e309f1SMin-Yih Hsu   if (VerifyLoops && CurLoop->getParentLoop()) {
93737e309f1SMin-Yih Hsu     CurLoop->getParentLoop()->verifyLoop();
93837e309f1SMin-Yih Hsu     if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
93937e309f1SMin-Yih Hsu       report_fatal_error("Loops must remain in LCSSA form!");
94037e309f1SMin-Yih Hsu   }
94137e309f1SMin-Yih Hsu }
942