//===- MergeICmps.cpp - Optimize chains of integer comparisons ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass turns chains of integer comparisons into memcmp (the memcmp is
// later typically inlined as a chain of efficient hardware comparisons). This
// typically benefits C++ member or non-member operator==().
//
// The basic idea is to replace a longer chain of integer comparisons loaded
// from contiguous memory locations with a shorter chain of larger integer
// comparisons. The benefits are twofold:
//  - There are fewer jumps, and therefore fewer opportunities for
//    mispredictions and I-cache misses.
//  - Code size is smaller, both because jumps are removed and because the
//    encoding of a 2*n byte compare is smaller than that of two n-byte
//    compares.
//
// Example:
//
//  struct S {
//    int a;
//    char b;
//    char c;
//    uint16_t d;
//    bool operator==(const S& o) const {
//      return a == o.a && b == o.b && c == o.c && d == o.d;
//    }
//  };
//
//  is optimized as:
//
//    bool S::operator==(const S& o) const {
//      return memcmp(this, &o, 8) == 0;
//    }
//
//  Which will later be expanded (by ExpandMemCmp) into a single 8-byte icmp.
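//
//  As a rough sketch (the exact IR depends on the target and on ExpandMemCmp),
//  the expansion would be along the lines of:
//
//    %lhs = load i64, ptr %this
//    %rhs = load i64, ptr %o
//    %res = icmp eq i64 %lhs, %rhs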
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/MergeICmps.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

using namespace llvm;

namespace {

#define DEBUG_TYPE "mergeicmps"

// A BCE atom ("Binary Compare Expression Atom") represents an integer load
// that is a constant offset from a base value, e.g. `a` or `o.c` in the
// example at the top.
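// As an illustrative sketch (value names are made up, not from any specific
// test), the atom for `o.c` in the example at the top would typically come
// from IR such as:
//   %gep = getelementptr inbounds %struct.S, ptr %o, i32 0, i32 2
//   %c = load i8, ptr %gep
// i.e. Base = %o, Offset = 5 and LoadI = %c.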
struct BCEAtom {
  BCEAtom() = default;
  BCEAtom(GetElementPtrInst *GEP, LoadInst *LoadI, int BaseId, APInt Offset)
      : GEP(GEP), LoadI(LoadI), BaseId(BaseId), Offset(Offset) {}

  BCEAtom(const BCEAtom &) = delete;
  BCEAtom &operator=(const BCEAtom &) = delete;

  BCEAtom(BCEAtom &&that) = default;
  BCEAtom &operator=(BCEAtom &&that) {
    if (this == &that)
      return *this;
    GEP = that.GEP;
    LoadI = that.LoadI;
    BaseId = that.BaseId;
    Offset = std::move(that.Offset);
    return *this;
  }

  // We want to order BCEAtoms by (Base, Offset). However we cannot use
  // the pointer values for Base because these are non-deterministic.
  // To make sure that the sort order is stable, we first assign to each atom
  // base value an index based on its order of appearance in the chain of
  // comparisons. We call this index `BaseOrdering`. For example, for:
  //    b[3] == c[2] && a[1] == d[1] && b[4] == c[3]
  //    |  block 1 |    |  block 2 |    |  block 3 |
  // b gets assigned a lower index than a, because b appears as LHS in block 1,
  // which is before block 2.
  // We then sort by (BaseOrdering[LHS.Base()], LHS.Offset), which is stable.
  bool operator<(const BCEAtom &O) const {
    return BaseId != O.BaseId ? BaseId < O.BaseId : Offset.slt(O.Offset);
  }

  GetElementPtrInst *GEP = nullptr;
  LoadInst *LoadI = nullptr;
  unsigned BaseId = 0;
  APInt Offset;
};

// A class that assigns increasing ids to values in the order in which they are
// seen. See comment in `BCEAtom::operator<()`.
class BaseIdentifier {
public:
  // Returns the id for value `Base`, after assigning one if `Base` has not been
  // seen before.
  int getBaseId(const Value *Base) {
    assert(Base && "invalid base");
    const auto Insertion = BaseToIndex.try_emplace(Base, Order);
    if (Insertion.second)
      ++Order;
    return Insertion.first->second;
  }

private:
  unsigned Order = 1;
  DenseMap<const Value*, int> BaseToIndex;
};
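
// Illustrative example: for the chain `b[3] == c[2] && a[1] == d[1]` above,
// getBaseId() would assign b -> 1, c -> 2, a -> 3, d -> 4, following the order
// in which the bases are first seen; only the relative order of the ids
// matters.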

// If this value is a load from a constant offset w.r.t. a base address, and
// there are no other users of the load or address, returns the base address and
// the offset.
BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) {
  auto *const LoadI = dyn_cast<LoadInst>(Val);
  if (!LoadI)
    return {};
  LLVM_DEBUG(dbgs() << "load\n");
  if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) {
    LLVM_DEBUG(dbgs() << "used outside of block\n");
    return {};
  }
  // Do not optimize atomic loads to non-atomic memcmp
  if (!LoadI->isSimple()) {
    LLVM_DEBUG(dbgs() << "volatile or atomic\n");
    return {};
  }
  Value *Addr = LoadI->getOperand(0);
  if (Addr->getType()->getPointerAddressSpace() != 0) {
    LLVM_DEBUG(dbgs() << "from non-zero AddressSpace\n");
    return {};
  }
  const auto &DL = LoadI->getModule()->getDataLayout();
  if (!isDereferenceablePointer(Addr, LoadI->getType(), DL)) {
    LLVM_DEBUG(dbgs() << "not dereferenceable\n");
    // We need to make sure that we can do comparison in any order, so we
    // require memory to be unconditionally dereferenceable.
    return {};
  }

  APInt Offset = APInt(DL.getIndexTypeSizeInBits(Addr->getType()), 0);
  Value *Base = Addr;
  auto *GEP = dyn_cast<GetElementPtrInst>(Addr);
  if (GEP) {
    LLVM_DEBUG(dbgs() << "GEP\n");
    if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) {
      LLVM_DEBUG(dbgs() << "used outside of block\n");
      return {};
    }
    if (!GEP->accumulateConstantOffset(DL, Offset))
      return {};
    Base = GEP->getPointerOperand();
  }
  return BCEAtom(GEP, LoadI, BaseId.getBaseId(Base), Offset);
}

// A comparison between two BCE atoms, e.g. `a == o.a` in the example at the
// top.
// Note: the terminology is misleading: the comparison is symmetric, so there
// is no real {l/r}hs. What we want though is to have the same base on the
// left (resp. right), so that we can detect consecutive loads. To ensure this
// we put the smallest atom on the left.
struct BCECmp {
  BCEAtom Lhs;
  BCEAtom Rhs;
  int SizeBits;
  const ICmpInst *CmpI;

  BCECmp(BCEAtom L, BCEAtom R, int SizeBits, const ICmpInst *CmpI)
      : Lhs(std::move(L)), Rhs(std::move(R)), SizeBits(SizeBits), CmpI(CmpI) {
    if (Rhs < Lhs) std::swap(Rhs, Lhs);
  }
};
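
// Illustrative note: in a chain like `a == o.a && o.b == b`, the second
// compare has its operands flipped; the swap in the BCECmp constructor above
// canonicalizes it so that the atom with the smaller BaseId (here the one
// based on `this`) ends up as Lhs in both blocks, which lets areContiguous()
// below match consecutive offsets.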

// A basic block with a comparison between two BCE atoms.
// The block might do extra work besides the atom comparison, in which case
// doesOtherWork() returns true. Under some conditions, the block can be
// split into the atom comparison part and the "other work" part
// (see canSplit()).
class BCECmpBlock {
 public:
  typedef SmallDenseSet<const Instruction *, 8> InstructionSet;

  BCECmpBlock(BCECmp Cmp, BasicBlock *BB, InstructionSet BlockInsts)
      : BB(BB), BlockInsts(std::move(BlockInsts)), Cmp(std::move(Cmp)) {}

  const BCEAtom &Lhs() const { return Cmp.Lhs; }
  const BCEAtom &Rhs() const { return Cmp.Rhs; }
  int SizeBits() const { return Cmp.SizeBits; }

  // Returns true if the block does other work besides the comparison.
  bool doesOtherWork() const;

  // Returns true if the non-BCE-cmp instructions can be separated from BCE-cmp
  // instructions in the block.
  bool canSplit(AliasAnalysis &AA) const;

  // Return true if all the relevant instructions in the BCE-cmp-block can
  // be sunk below this instruction. By doing this, we know we can separate the
  // BCE-cmp-block instructions from the non-BCE-cmp-block instructions in the
  // block.
  bool canSinkBCECmpInst(const Instruction *, AliasAnalysis &AA) const;

  // We can separate the BCE-cmp-block instructions and the non-BCE-cmp-block
  // instructions. Split the old block and move all non-BCE-cmp-insts into the
  // new parent block.
  void split(BasicBlock *NewParent, AliasAnalysis &AA) const;

  // The basic block where this comparison happens.
  BasicBlock *BB;
  // Instructions relating to the BCECmp and branch.
  InstructionSet BlockInsts;
  // The block requires splitting.
  bool RequireSplit = false;
  // Original order of this block in the chain.
  unsigned OrigOrder = 0;

private:
  BCECmp Cmp;
};

bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst,
                                    AliasAnalysis &AA) const {
  // If this instruction may clobber the loads and is in the middle of the BCE
  // cmp block instructions, then bail for now.
  if (Inst->mayWriteToMemory()) {
    auto MayClobber = [&](LoadInst *LI) {
      // If a potentially clobbering instruction comes before the load,
      // we can still safely sink the load.
      return (Inst->getParent() != LI->getParent() || !Inst->comesBefore(LI)) &&
             isModSet(AA.getModRefInfo(Inst, MemoryLocation::get(LI)));
    };
    if (MayClobber(Cmp.Lhs.LoadI) || MayClobber(Cmp.Rhs.LoadI))
      return false;
  }
  // Make sure this instruction does not use any of the BCE cmp block
  // instructions as operand.
  return llvm::none_of(Inst->operands(), [&](const Value *Op) {
    const Instruction *OpI = dyn_cast<Instruction>(Op);
    return OpI && BlockInsts.contains(OpI);
  });
}

void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis &AA) const {
  llvm::SmallVector<Instruction *, 4> OtherInsts;
  for (Instruction &Inst : *BB) {
    if (BlockInsts.count(&Inst))
      continue;
    assert(canSinkBCECmpInst(&Inst, AA) && "Split unsplittable block");
    // This is a non-BCE-cmp-block instruction, and it can be separated
    // from the BCE-cmp-block instructions.
    OtherInsts.push_back(&Inst);
  }

  // Do the actual splitting.
  for (Instruction *Inst : reverse(OtherInsts))
    Inst->moveBefore(*NewParent, NewParent->begin());
}

bool BCECmpBlock::canSplit(AliasAnalysis &AA) const {
  for (Instruction &Inst : *BB) {
    if (!BlockInsts.count(&Inst)) {
      if (!canSinkBCECmpInst(&Inst, AA))
        return false;
    }
  }
  return true;
}

bool BCECmpBlock::doesOtherWork() const {
  // TODO(courbet): Can we allow some other things? This is very conservative.
  // We might be able to get away with anything that does not have any side
  // effects outside of the basic block.
  // Note: The GEPs and/or loads are not necessarily in the same block.
  for (const Instruction &Inst : *BB) {
    if (!BlockInsts.count(&Inst))
      return true;
  }
  return false;
}

// Visit the given comparison. If this is a comparison between two valid
// BCE atoms, returns the comparison.
std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI,
                                const ICmpInst::Predicate ExpectedPredicate,
                                BaseIdentifier &BaseId) {
  // The comparison can only be used once:
  //  - For intermediate blocks, as a branch condition.
  //  - For the final block, as an incoming value for the Phi.
  // If there are any other uses of the comparison, we cannot merge it with
  // other comparisons as we would create an orphan use of the value.
  if (!CmpI->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "cmp has several uses\n");
    return std::nullopt;
  }
  if (CmpI->getPredicate() != ExpectedPredicate)
    return std::nullopt;
  LLVM_DEBUG(dbgs() << "cmp "
                    << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne")
                    << "\n");
  auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0), BaseId);
  if (!Lhs.BaseId)
    return std::nullopt;
  auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1), BaseId);
  if (!Rhs.BaseId)
    return std::nullopt;
  const auto &DL = CmpI->getModule()->getDataLayout();
  return BCECmp(std::move(Lhs), std::move(Rhs),
                DL.getTypeSizeInBits(CmpI->getOperand(0)->getType()), CmpI);
}

// Visit the given comparison block. If this is a comparison between two valid
// BCE atoms, returns the comparison.
std::optional<BCECmpBlock>
visitCmpBlock(Value *const Baseline, ICmpInst::Predicate &Predicate,
              Value *const Val, BasicBlock *const Block,
              const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) {
  if (Block->empty())
    return std::nullopt;
  auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator());
  if (!BranchI)
    return std::nullopt;
  LLVM_DEBUG(dbgs() << "branch\n");
  Value *Cond;
  ICmpInst::Predicate ExpectedPredicate;
  if (BranchI->isUnconditional()) {
    // In this case, we expect an incoming value which is the result of the
    // comparison. This is the last link in the chain of comparisons (note
    // that this does not mean that this is the last incoming value, blocks
    // can be reordered).
    Cond = Val;
    const auto *const ConstBase = cast<ConstantInt>(Baseline);
    assert(ConstBase->getType()->isIntegerTy(1) &&
           "Select condition is not an i1?");
    ExpectedPredicate =
        ConstBase->isOne() ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;

    // Remember the correct predicate.
    Predicate = ExpectedPredicate;
  } else {
    // All the incoming values must be consistent.
    if (Baseline != Val)
      return std::nullopt;
    // In this case, we expect a constant incoming value (the comparison is
    // chained).
    const auto *const Const = cast<ConstantInt>(Val);
    assert(Const->getType()->isIntegerTy(1) &&
           "Incoming value is not an i1?");
    LLVM_DEBUG(dbgs() << "const i1 value\n");
    if (!Const->isZero() && !Const->isOne())
      return std::nullopt;
    LLVM_DEBUG(dbgs() << *Const << "\n");
    assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch");
    BasicBlock *const FalseBlock = BranchI->getSuccessor(1);
    Cond = BranchI->getCondition();
    ExpectedPredicate =
        FalseBlock == PhiBlock ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
  }

  auto *CmpI = dyn_cast<ICmpInst>(Cond);
  if (!CmpI)
    return std::nullopt;
  LLVM_DEBUG(dbgs() << "icmp\n");

  std::optional<BCECmp> Result = visitICmp(CmpI, ExpectedPredicate, BaseId);
  if (!Result)
    return std::nullopt;

  BCECmpBlock::InstructionSet BlockInsts(
      {Result->Lhs.LoadI, Result->Rhs.LoadI, Result->CmpI, BranchI});
  if (Result->Lhs.GEP)
    BlockInsts.insert(Result->Lhs.GEP);
  if (Result->Rhs.GEP)
    BlockInsts.insert(Result->Rhs.GEP);
  return BCECmpBlock(std::move(*Result), Block, BlockInsts);
}

static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons,
                                BCECmpBlock &&Comparison) {
  LLVM_DEBUG(dbgs() << "Block '" << Comparison.BB->getName()
                    << "': Found cmp of " << Comparison.SizeBits()
                    << " bits between " << Comparison.Lhs().BaseId << " + "
                    << Comparison.Lhs().Offset << " and "
                    << Comparison.Rhs().BaseId << " + "
                    << Comparison.Rhs().Offset << "\n");
  LLVM_DEBUG(dbgs() << "\n");
  Comparison.OrigOrder = Comparisons.size();
  Comparisons.push_back(std::move(Comparison));
}

// A chain of comparisons.
class BCECmpChain {
public:
  using ContiguousBlocks = std::vector<BCECmpBlock>;

  BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
              AliasAnalysis &AA);

  bool simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
                DomTreeUpdater &DTU);

  bool atLeastOneMerged() const {
    return any_of(MergedBlocks_,
                  [](const auto &Blocks) { return Blocks.size() > 1; });
  }

private:
  PHINode &Phi_;
  // The list of all blocks in the chain, grouped by contiguity.
  std::vector<ContiguousBlocks> MergedBlocks_;
  // The original entry block (before sorting).
  BasicBlock *EntryBlock_;
  // Remember the predicate type of the chain.
  ICmpInst::Predicate Predicate_;
};

static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) {
  return First.Lhs().BaseId == Second.Lhs().BaseId &&
         First.Rhs().BaseId == Second.Rhs().BaseId &&
         First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset &&
         First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset;
}
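
// For instance (illustrative), a 32-bit compare of a+0 vs b+0 followed by a
// 32-bit compare of a+4 vs b+4 is contiguous (0 + 32/8 == 4 on both sides),
// whereas a+0/b+0 followed by a+8/b+8 is not.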

static unsigned getMinOrigOrder(const BCECmpChain::ContiguousBlocks &Blocks) {
  unsigned MinOrigOrder = std::numeric_limits<unsigned>::max();
  for (const BCECmpBlock &Block : Blocks)
    MinOrigOrder = std::min(MinOrigOrder, Block.OrigOrder);
  return MinOrigOrder;
}

/// Given a chain of comparison blocks, groups the blocks into contiguous
/// ranges that can be merged together into a single comparison.
static std::vector<BCECmpChain::ContiguousBlocks>
mergeBlocks(std::vector<BCECmpBlock> &&Blocks) {
  std::vector<BCECmpChain::ContiguousBlocks> MergedBlocks;

  // Sort to detect contiguous offsets.
  llvm::sort(Blocks,
             [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) {
               return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) <
                      std::tie(RhsBlock.Lhs(), RhsBlock.Rhs());
             });

  BCECmpChain::ContiguousBlocks *LastMergedBlock = nullptr;
  for (BCECmpBlock &Block : Blocks) {
    if (!LastMergedBlock || !areContiguous(LastMergedBlock->back(), Block)) {
      MergedBlocks.emplace_back();
      LastMergedBlock = &MergedBlocks.back();
    } else {
      LLVM_DEBUG(dbgs() << "Merging block " << Block.BB->getName() << " into "
                        << LastMergedBlock->back().BB->getName() << "\n");
    }
    LastMergedBlock->push_back(std::move(Block));
  }

  // While we allow reordering for merging, do not reorder unmerged comparisons.
  // Doing so may introduce branch on poison.
  llvm::sort(MergedBlocks, [](const BCECmpChain::ContiguousBlocks &LhsBlocks,
                              const BCECmpChain::ContiguousBlocks &RhsBlocks) {
    return getMinOrigOrder(LhsBlocks) < getMinOrigOrder(RhsBlocks);
  });

  return MergedBlocks;
}
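
// Illustrative sketch: with 4-byte compares at offsets 0, 4 and 12 (on both
// sides of the same base pair), mergeBlocks() yields the groups {0, 4} and
// {12}, since the block covering offset 8 is missing; only groups of more
// than one block are later turned into a memcmp.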

BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
                         AliasAnalysis &AA)
    : Phi_(Phi) {
  assert(!Blocks.empty() && "a chain should have at least one block");
  // Now look inside blocks to check for BCE comparisons.
  std::vector<BCECmpBlock> Comparisons;
  BaseIdentifier BaseId;
  Value *const Baseline = Phi.getIncomingValueForBlock(Blocks[0]);
  Predicate_ = CmpInst::BAD_ICMP_PREDICATE;
  for (BasicBlock *const Block : Blocks) {
    assert(Block && "invalid block");
    std::optional<BCECmpBlock> Comparison =
        visitCmpBlock(Baseline, Predicate_, Phi.getIncomingValueForBlock(Block),
                      Block, Phi.getParent(), BaseId);
    if (!Comparison) {
      LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n");
      return;
    }
    if (Comparison->doesOtherWork()) {
      LLVM_DEBUG(dbgs() << "block '" << Comparison->BB->getName()
                        << "' does extra work besides compare\n");
      if (Comparisons.empty()) {
        // This is the initial block in the chain. If this block does other
        // work, we can try to split the block and move the irrelevant
        // instructions to the predecessor.
        //
        // If this is not the initial block in the chain, splitting it won't
        // work: once split, there would still be instructions before the BCE
        // cmp instructions that do other work in program order, i.e. within
        // the chain before sorting, unless we could abort the chain at this
        // point and start anew.
        //
        // NOTE: we only handle blocks with a single predecessor for now.
        if (Comparison->canSplit(AA)) {
          LLVM_DEBUG(dbgs()
                     << "Split initial block '" << Comparison->BB->getName()
                     << "' that does extra work besides compare\n");
          Comparison->RequireSplit = true;
          enqueueBlock(Comparisons, std::move(*Comparison));
        } else {
          LLVM_DEBUG(dbgs()
                     << "ignoring initial block '" << Comparison->BB->getName()
                     << "' that does extra work besides compare\n");
        }
        continue;
      }
      // TODO(courbet): Right now we abort the whole chain. We could be
      // merging only the blocks that don't do other work and resume the
      // chain from there. For example:
      //  if (a[0] == b[0]) {  // bb1
      //    if (a[1] == b[1]) {  // bb2
      //      some_value = 3; //bb3
      //      if (a[2] == b[2]) { //bb3
      //        do a ton of stuff  //bb4
      //      }
      //    }
      //  }
      //
      // This is:
      //
      // bb1 --eq--> bb2 --eq--> bb3* -eq--> bb4 --+
      //  \            \           \               \
      //   ne           ne          ne              \
      //    \            \           \               v
      //     +------------+-----------+----------> bb_phi
      //
      // We can only merge the first two comparisons, because bb3* does
      // "other work" (setting some_value to 3).
      // We could still merge bb1 and bb2 though.
      return;
    }
    enqueueBlock(Comparisons, std::move(*Comparison));
  }

  // It is possible we have no suitable comparison to merge.
  if (Comparisons.empty()) {
    LLVM_DEBUG(dbgs() << "chain with no BCE basic blocks, no merge\n");
    return;
  }
  EntryBlock_ = Comparisons[0].BB;
  MergedBlocks_ = mergeBlocks(std::move(Comparisons));
}

namespace {

// A class to compute the name of a set of merged basic blocks.
// This is optimized for the common case of no block names.
class MergedBlockName {
  // Storage for the uncommon case of several named blocks.
  SmallString<16> Scratch;

public:
  explicit MergedBlockName(ArrayRef<BCECmpBlock> Comparisons)
      : Name(makeName(Comparisons)) {}
  const StringRef Name;

private:
  StringRef makeName(ArrayRef<BCECmpBlock> Comparisons) {
    assert(!Comparisons.empty() && "no basic block");
    // Fast path: only one block, or no names at all.
    if (Comparisons.size() == 1)
      return Comparisons[0].BB->getName();
    const int size = std::accumulate(Comparisons.begin(), Comparisons.end(), 0,
                                     [](int i, const BCECmpBlock &Cmp) {
                                       return i + Cmp.BB->getName().size();
                                     });
    if (size == 0)
      return StringRef("", 0);

    // Slow path: at least two blocks, at least one block with a name.
    Scratch.clear();
    // We'll have `size` bytes for name and `Comparisons.size() - 1` bytes for
    // separators.
    Scratch.reserve(size + Comparisons.size() - 1);
    const auto append = [this](StringRef str) {
      Scratch.append(str.begin(), str.end());
    };
    append(Comparisons[0].BB->getName());
    for (int I = 1, E = Comparisons.size(); I < E; ++I) {
      const BasicBlock *const BB = Comparisons[I].BB;
      if (!BB->getName().empty()) {
        append("+");
        append(BB->getName());
      }
    }
    return Scratch.str();
  }
};
} // namespace

// Check whether the given set of comparisons includes the last comparison of
// the chain. Only the last set of comparisons includes the last link (the
// unconditional branch). Due to reordering of contiguous comparison blocks,
// the block with the unconditional branch may not be placed at the end of the
// Comparisons array.
static bool isLastLinkComparison(ArrayRef<BCECmpBlock> Comparisons) {
  for (const BCECmpBlock &Cmp : Comparisons) {
    auto *const BranchI = cast<BranchInst>(Cmp.BB->getTerminator());
    // Only the last set of comparisons includes the unconditional branch.
    if (BranchI->isUnconditional())
      return true;
  }
  return false;
}

// Merges the given contiguous comparison blocks into one memcmp block.
static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
                                    BasicBlock *const InsertBefore,
                                    BasicBlock *const NextCmpBlock,
                                    PHINode &Phi, const TargetLibraryInfo &TLI,
                                    AliasAnalysis &AA, DomTreeUpdater &DTU,
                                    ICmpInst::Predicate Predicate) {
  assert(!Comparisons.empty() && "merging zero comparisons");
  LLVMContext &Context = NextCmpBlock->getContext();
  const BCECmpBlock &FirstCmp = Comparisons[0];

  // Create a new cmp block before next cmp block.
  BasicBlock *const BB =
      BasicBlock::Create(Context, MergedBlockName(Comparisons).Name,
                         NextCmpBlock->getParent(), InsertBefore);
  IRBuilder<> Builder(BB);
  // Add the GEPs from the first BCECmpBlock.
  Value *Lhs, *Rhs;
  if (FirstCmp.Lhs().GEP)
    Lhs = Builder.Insert(FirstCmp.Lhs().GEP->clone());
  else
    Lhs = FirstCmp.Lhs().LoadI->getPointerOperand();
  if (FirstCmp.Rhs().GEP)
    Rhs = Builder.Insert(FirstCmp.Rhs().GEP->clone());
  else
    Rhs = FirstCmp.Rhs().LoadI->getPointerOperand();

  Value *ICmpValue = nullptr;
  LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons -> "
                    << BB->getName() << "\n");

  // If there is one block that requires splitting, we do it now, i.e.
  // just before we know we will collapse the chain. The instructions
  // can be executed before any of the instructions in the chain.
  const auto ToSplit = llvm::find_if(
      Comparisons, [](const BCECmpBlock &B) { return B.RequireSplit; });
  if (ToSplit != Comparisons.end()) {
    LLVM_DEBUG(dbgs() << "Splitting non_BCE work to header\n");
    ToSplit->split(BB, AA);
  }

  // For an ICmp chain, Predicate records the predicate of the last link in
  // the chain of comparisons. When we split the chain, the new split chains
  // of comparisons end with ICMP_EQ.
  ICmpInst::Predicate Pred =
      isLastLinkComparison(Comparisons) ? Predicate : ICmpInst::ICMP_EQ;
  if (Comparisons.size() == 1) {
    LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n");
    // Use clone to keep the metadata
    Instruction *const LhsLoad = Builder.Insert(FirstCmp.Lhs().LoadI->clone());
    Instruction *const RhsLoad = Builder.Insert(FirstCmp.Rhs().LoadI->clone());
    LhsLoad->replaceUsesOfWith(LhsLoad->getOperand(0), Lhs);
    RhsLoad->replaceUsesOfWith(RhsLoad->getOperand(0), Rhs);
    // There are no blocks to merge, just do the comparison.
    ICmpValue = Builder.CreateICmp(Pred, LhsLoad, RhsLoad);
  } else {
    const unsigned TotalSizeBits = std::accumulate(
        Comparisons.begin(), Comparisons.end(), 0u,
        [](int Size, const BCECmpBlock &C) { return Size + C.SizeBits(); });

    // memcmp expects a 'size_t' argument and returns 'int'.
    unsigned SizeTBits = TLI.getSizeTSize(*Phi.getModule());
    unsigned IntBits = TLI.getIntSize();

    // Create memcmp() == 0.
    const auto &DL = Phi.getModule()->getDataLayout();
    Value *const MemCmpCall = emitMemCmp(
        Lhs, Rhs,
        ConstantInt::get(Builder.getIntNTy(SizeTBits), TotalSizeBits / 8),
        Builder, DL, &TLI);
    ICmpValue = Builder.CreateICmp(
        Pred, MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0));
  }
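  // As a rough sketch (value names are illustrative and size_t is assumed to
  // be 64 bits), a fully merged 8-byte chain ends up as:
  //   %memcmp = call i32 @memcmp(ptr %lhs, ptr %rhs, i64 8)
  //   %cmp = icmp eq i32 %memcmp, 0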

  BasicBlock *const PhiBB = Phi.getParent();
  // Add a branch to the next basic block in the chain.
  if (NextCmpBlock == PhiBB) {
    // Continue to phi, passing it the comparison result.
    Builder.CreateBr(PhiBB);
    Phi.addIncoming(ICmpValue, BB);
    DTU.applyUpdates({{DominatorTree::Insert, BB, PhiBB}});
  } else {
    // Continue to the next block if equal, otherwise exit to the phi.
    Builder.CreateCondBr(ICmpValue, NextCmpBlock, PhiBB);
    Value *ConstantVal = Predicate == CmpInst::ICMP_NE
                             ? ConstantInt::getTrue(Context)
                             : ConstantInt::getFalse(Context);
    Phi.addIncoming(ConstantVal, BB);
    DTU.applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock},
                      {DominatorTree::Insert, BB, PhiBB}});
  }
  return BB;
}

bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
                           DomTreeUpdater &DTU) {
  assert(atLeastOneMerged() && "simplifying trivial BCECmpChain");
  LLVM_DEBUG(dbgs() << "Simplifying comparison chain starting at block "
                    << EntryBlock_->getName() << "\n");

  // Effectively merge blocks. We go in the reverse direction from the phi block
  // so that the next block is always available to branch to.
  BasicBlock *InsertBefore = EntryBlock_;
  BasicBlock *NextCmpBlock = Phi_.getParent();
  assert(Predicate_ != CmpInst::BAD_ICMP_PREDICATE &&
         "Got the chain of comparisons");
  for (const auto &Blocks : reverse(MergedBlocks_)) {
    InsertBefore = NextCmpBlock = mergeComparisons(
        Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU, Predicate_);
  }

  // Replace the original cmp chain with the new cmp chain by pointing all
  // predecessors of EntryBlock_ to NextCmpBlock instead. This makes all cmp
  // blocks in the old chain unreachable.
  while (!pred_empty(EntryBlock_)) {
    BasicBlock* const Pred = *pred_begin(EntryBlock_);
    LLVM_DEBUG(dbgs() << "Updating jump into old chain from " << Pred->getName()
                      << "\n");
    Pred->getTerminator()->replaceUsesOfWith(EntryBlock_, NextCmpBlock);
    DTU.applyUpdates({{DominatorTree::Delete, Pred, EntryBlock_},
                      {DominatorTree::Insert, Pred, NextCmpBlock}});
  }

  // If the old cmp chain was the function entry, we need to update the function
  // entry.
  const bool ChainEntryIsFnEntry = EntryBlock_->isEntryBlock();
  if (ChainEntryIsFnEntry && DTU.hasDomTree()) {
    LLVM_DEBUG(dbgs() << "Changing function entry from "
                      << EntryBlock_->getName() << " to "
                      << NextCmpBlock->getName() << "\n");
    DTU.getDomTree().setNewRoot(NextCmpBlock);
    DTU.applyUpdates({{DominatorTree::Delete, NextCmpBlock, EntryBlock_}});
  }
  EntryBlock_ = nullptr;

  // Delete merged blocks. This also removes incoming values in phi.
  SmallVector<BasicBlock *, 16> DeadBlocks;
  for (const auto &Blocks : MergedBlocks_) {
    for (const BCECmpBlock &Block : Blocks) {
      LLVM_DEBUG(dbgs() << "Deleting merged block " << Block.BB->getName()
                        << "\n");
      DeadBlocks.push_back(Block.BB);
    }
  }
  DeleteDeadBlocks(DeadBlocks, &DTU);

  MergedBlocks_.clear();
  return true;
}

std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi,
                                           BasicBlock *const LastBlock,
                                           int NumBlocks) {
  // Walk up from the last block to find other blocks.
  std::vector<BasicBlock *> Blocks(NumBlocks);
  assert(LastBlock && "invalid last block");
  BasicBlock *CurBlock = LastBlock;
  for (int BlockIndex = NumBlocks - 1; BlockIndex > 0; --BlockIndex) {
    if (CurBlock->hasAddressTaken()) {
      // Somebody is jumping to the block through an address, all bets are
      // off.
      LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex
                        << " has its address taken\n");
      return {};
    }
    Blocks[BlockIndex] = CurBlock;
    auto *SinglePredecessor = CurBlock->getSinglePredecessor();
    if (!SinglePredecessor) {
      // The block has two or more predecessors.
      LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex
                        << " has two or more predecessors\n");
      return {};
    }
    if (Phi.getBasicBlockIndex(SinglePredecessor) < 0) {
      // The block does not link back to the phi.
      LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex
                        << " does not link back to the phi\n");
      return {};
    }
    CurBlock = SinglePredecessor;
  }
  Blocks[0] = CurBlock;
  return Blocks;
}

bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA,
                DomTreeUpdater &DTU) {
  LLVM_DEBUG(dbgs() << "processPhi()\n");
  if (Phi.getNumIncomingValues() <= 1) {
    LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n");
    return false;
  }
  // We are looking for something that has the following structure:
  //   bb1 --eq--> bb2 --eq--> bb3 --eq--> bb4 --+
  //     \            \           \               \
  //      ne           ne          ne              \
  //       \            \           \               v
  //        +------------+-----------+----------> bb_phi
  //
  //  - The last basic block (bb4 here) must branch unconditionally to bb_phi.
  //    It's the only block that contributes a non-constant value to the Phi.
  //  - All other blocks (bb1, bb2, bb3) must have exactly two successors, one
  //    of them being the phi block.
  //  - All intermediate blocks (bb2, bb3) must have only one predecessor.
  //  - Blocks cannot do other work besides the comparison, see doesOtherWork().

  // The blocks are not necessarily ordered in the phi, so we start from the
  // last block and reconstruct the order.
  BasicBlock *LastBlock = nullptr;
  for (unsigned I = 0; I < Phi.getNumIncomingValues(); ++I) {
    if (isa<ConstantInt>(Phi.getIncomingValue(I))) continue;
    if (LastBlock) {
      // There are several non-constant values.
      LLVM_DEBUG(dbgs() << "skip: several non-constant values\n");
      return false;
    }
    if (!isa<ICmpInst>(Phi.getIncomingValue(I)) ||
        cast<ICmpInst>(Phi.getIncomingValue(I))->getParent() !=
            Phi.getIncomingBlock(I)) {
      // Non-constant incoming value is not from a cmp instruction or not
      // produced by the last block. We could end up processing the value
      // producing block more than once.
      //
      // This is an uncommon case, so we bail.
      LLVM_DEBUG(
          dbgs()
          << "skip: non-constant value not from cmp or not from last block.\n");
      return false;
    }
    LastBlock = Phi.getIncomingBlock(I);
  }
  if (!LastBlock) {
    // There is no non-constant block.
    LLVM_DEBUG(dbgs() << "skip: no non-constant block\n");
    return false;
  }
  if (LastBlock->getSingleSuccessor() != Phi.getParent()) {
    LLVM_DEBUG(dbgs() << "skip: last block non-phi successor\n");
    return false;
  }

  const auto Blocks =
      getOrderedBlocks(Phi, LastBlock, Phi.getNumIncomingValues());
  if (Blocks.empty()) return false;
  BCECmpChain CmpChain(Blocks, Phi, AA);

  if (!CmpChain.atLeastOneMerged()) {
    LLVM_DEBUG(dbgs() << "skip: nothing merged\n");
    return false;
  }

  return CmpChain.simplify(TLI, AA, DTU);
}

static bool runImpl(Function &F, const TargetLibraryInfo &TLI,
                    const TargetTransformInfo &TTI, AliasAnalysis &AA,
                    DominatorTree *DT) {
  LLVM_DEBUG(dbgs() << "MergeICmpsLegacyPass: " << F.getName() << "\n");

  // We only try merging comparisons if the target wants to expand memcmp later.
  // The rationale is to avoid turning small chains into memcmp calls.
  if (!TTI.enableMemCmpExpansion(F.hasOptSize(), true))
    return false;

  // If we don't have memcmp available we can't emit calls to it.
  if (!TLI.has(LibFunc_memcmp))
    return false;

  DomTreeUpdater DTU(DT, /*PostDominatorTree*/ nullptr,
                     DomTreeUpdater::UpdateStrategy::Eager);

  bool MadeChange = false;

  for (BasicBlock &BB : llvm::drop_begin(F)) {
    // A Phi operation is always first in a basic block.
    if (auto *const Phi = dyn_cast<PHINode>(&*BB.begin()))
      MadeChange |= processPhi(*Phi, TLI, AA, DTU);
  }

  return MadeChange;
}

class MergeICmpsLegacyPass : public FunctionPass {
public:
  static char ID;

  MergeICmpsLegacyPass() : FunctionPass(ID) {
    initializeMergeICmpsLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F)) return false;
    const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    // MergeICmps does not need the DominatorTree, but we update it if it's
    // already available.
    auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    return runImpl(F, TLI, TTI, AA, DTWP ? &DTWP->getDomTree() : nullptr);
  }

 private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // namespace

char MergeICmpsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(MergeICmpsLegacyPass, "mergeicmps",
                      "Merge contiguous icmps into a memcmp", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(MergeICmpsLegacyPass, "mergeicmps",
                    "Merge contiguous icmps into a memcmp", false, false)

Pass *llvm::createMergeICmpsLegacyPass() { return new MergeICmpsLegacyPass(); }

PreservedAnalyses MergeICmpsPass::run(Function &F,
                                      FunctionAnalysisManager &AM) {
  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &AA = AM.getResult<AAManager>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  const bool MadeChanges = runImpl(F, TLI, TTI, AA, DT);
  if (!MadeChanges)
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}