xref: /llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp (revision a95d95d3922e1a24d8b9affdd570c1d8fca00129)
1 //===- MergeICmps.cpp - Optimize chains of integer comparisons ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass turns chains of integer comparisons into memcmp (the memcmp is
10 // later typically inlined as a chain of efficient hardware comparisons). This
11 // typically benefits c++ member or nonmember operator==().
12 //
13 // The basic idea is to replace a longer chain of integer comparisons loaded
14 // from contiguous memory locations into a shorter chain of larger integer
15 // comparisons. Benefits are double:
//  - There are fewer jumps, and therefore fewer opportunities for
//    mispredictions and I-cache misses.
18 //  - Code size is smaller, both because jumps are removed and because the
19 //    encoding of a 2*n byte compare is smaller than that of two n-byte
20 //    compares.
21 //
22 // Example:
23 //
24 //  struct S {
25 //    int a;
26 //    char b;
27 //    char c;
28 //    uint16_t d;
29 //    bool operator==(const S& o) const {
30 //      return a == o.a && b == o.b && c == o.c && d == o.d;
31 //    }
32 //  };
33 //
34 //  Is optimized as :
35 //
36 //    bool S::operator==(const S& o) const {
37 //      return memcmp(this, &o, 8) == 0;
38 //    }
39 //
40 //  Which will later be expanded (ExpandMemCmp) as a single 8-bytes icmp.
41 //
42 //===----------------------------------------------------------------------===//
43 
44 #include "llvm/Analysis/DomTreeUpdater.h"
45 #include "llvm/Analysis/GlobalsModRef.h"
46 #include "llvm/Analysis/Loads.h"
47 #include "llvm/Analysis/TargetLibraryInfo.h"
48 #include "llvm/Analysis/TargetTransformInfo.h"
49 #include "llvm/IR/Dominators.h"
50 #include "llvm/IR/Function.h"
51 #include "llvm/IR/IRBuilder.h"
52 #include "llvm/Pass.h"
53 #include "llvm/Transforms/Scalar.h"
54 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
55 #include "llvm/Transforms/Utils/BuildLibCalls.h"
56 #include <algorithm>
57 #include <numeric>
58 #include <utility>
59 #include <vector>
60 
61 using namespace llvm;
62 
63 namespace {
64 
65 #define DEBUG_TYPE "mergeicmps"
66 
67 // Returns true if the instruction is a simple load or a simple store
68 static bool isSimpleLoadOrStore(const Instruction *I) {
69   if (const LoadInst *LI = dyn_cast<LoadInst>(I))
70     return LI->isSimple();
71   if (const StoreInst *SI = dyn_cast<StoreInst>(I))
72     return SI->isSimple();
73   return false;
74 }
75 
76 // A BCE atom "Binary Compare Expression Atom" represents an integer load
77 // that is a constant offset from a base value, e.g. `a` or `o.c` in the example
78 // at the top.
79 struct BCEAtom {
80   BCEAtom() = default;
81   BCEAtom(GetElementPtrInst *GEP, LoadInst *LoadI, int BaseId, APInt Offset)
82       : GEP(GEP), LoadI(LoadI), BaseId(BaseId), Offset(Offset) {}
83 
84   // We want to order BCEAtoms by (Base, Offset). However we cannot use
85   // the pointer values for Base because these are non-deterministic.
86   // To make sure that the sort order is stable, we first assign to each atom
87   // base value an index based on its order of appearance in the chain of
88   // comparisons. We call this index `BaseOrdering`. For example, for:
89   //    b[3] == c[2] && a[1] == d[1] && b[4] == c[3]
90   //    |  block 1 |    |  block 2 |    |  block 3 |
91   // b gets assigned index 0 and a index 1, because b appears as LHS in block 1,
92   // which is before block 2.
93   // We then sort by (BaseOrdering[LHS.Base()], LHS.Offset), which is stable.
94   bool operator<(const BCEAtom &O) const {
95     return BaseId != O.BaseId ? BaseId < O.BaseId : Offset.slt(O.Offset);
96   }
97 
98   GetElementPtrInst *GEP = nullptr;
99   LoadInst *LoadI = nullptr;
100   unsigned BaseId = 0;
101   APInt Offset;
102 };
103 
104 // A class that assigns increasing ids to values in the order in which they are
105 // seen. See comment in `BCEAtom::operator<()``.
106 class BaseIdentifier {
107 public:
108   // Returns the id for value `Base`, after assigning one if `Base` has not been
109   // seen before.
110   int getBaseId(const Value *Base) {
111     assert(Base && "invalid base");
112     const auto Insertion = BaseToIndex.try_emplace(Base, Order);
113     if (Insertion.second)
114       ++Order;
115     return Insertion.first->second;
116   }
117 
118 private:
119   unsigned Order = 1;
120   DenseMap<const Value*, int> BaseToIndex;
121 };
122 
123 // If this value is a load from a constant offset w.r.t. a base address, and
124 // there are no other users of the load or address, returns the base address and
125 // the offset.
126 BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) {
127   auto *const LoadI = dyn_cast<LoadInst>(Val);
128   if (!LoadI)
129     return {};
130   LLVM_DEBUG(dbgs() << "load\n");
131   if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) {
132     LLVM_DEBUG(dbgs() << "used outside of block\n");
133     return {};
134   }
135   // Do not optimize atomic loads to non-atomic memcmp
136   if (!LoadI->isSimple()) {
137     LLVM_DEBUG(dbgs() << "volatile or atomic\n");
138     return {};
139   }
140   Value *const Addr = LoadI->getOperand(0);
141   auto *const GEP = dyn_cast<GetElementPtrInst>(Addr);
142   if (!GEP)
143     return {};
144   LLVM_DEBUG(dbgs() << "GEP\n");
145   if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) {
146     LLVM_DEBUG(dbgs() << "used outside of block\n");
147     return {};
148   }
149   const auto &DL = GEP->getModule()->getDataLayout();
150   if (!isDereferenceablePointer(GEP, DL)) {
151     LLVM_DEBUG(dbgs() << "not dereferenceable\n");
152     // We need to make sure that we can do comparison in any order, so we
153     // require memory to be unconditionnally dereferencable.
154     return {};
155   }
156   APInt Offset = APInt(DL.getPointerTypeSizeInBits(GEP->getType()), 0);
157   if (!GEP->accumulateConstantOffset(DL, Offset))
158     return {};
159   return BCEAtom(GEP, LoadI, BaseId.getBaseId(GEP->getPointerOperand()),
160                  Offset);
161 }
162 
163 // A basic block with a comparison between two BCE atoms, e.g. `a == o.a` in the
164 // example at the top.
165 // The block might do extra work besides the atom comparison, in which case
166 // doesOtherWork() returns true. Under some conditions, the block can be
167 // split into the atom comparison part and the "other work" part
168 // (see canSplit()).
169 // Note: the terminology is misleading: the comparison is symmetric, so there
170 // is no real {l/r}hs. What we want though is to have the same base on the
171 // left (resp. right), so that we can detect consecutive loads. To ensure this
172 // we put the smallest atom on the left.
173 class BCECmpBlock {
174  public:
175   BCECmpBlock() {}
176 
177   BCECmpBlock(BCEAtom L, BCEAtom R, int SizeBits)
178       : Lhs_(L), Rhs_(R), SizeBits_(SizeBits) {
179     if (Rhs_ < Lhs_) std::swap(Rhs_, Lhs_);
180   }
181 
182   bool IsValid() const { return Lhs_.BaseId != 0 && Rhs_.BaseId != 0; }
183 
184   // Assert the block is consistent: If valid, it should also have
185   // non-null members besides Lhs_ and Rhs_.
186   void AssertConsistent() const {
187     if (IsValid()) {
188       assert(BB);
189       assert(CmpI);
190       assert(BranchI);
191     }
192   }
193 
194   const BCEAtom &Lhs() const { return Lhs_; }
195   const BCEAtom &Rhs() const { return Rhs_; }
196   int SizeBits() const { return SizeBits_; }
197 
198   // Returns true if the block does other works besides comparison.
199   bool doesOtherWork() const;
200 
201   // Returns true if the non-BCE-cmp instructions can be separated from BCE-cmp
202   // instructions in the block.
203   bool canSplit(AliasAnalysis *AA) const;
204 
205   // Return true if this all the relevant instructions in the BCE-cmp-block can
206   // be sunk below this instruction. By doing this, we know we can separate the
207   // BCE-cmp-block instructions from the non-BCE-cmp-block instructions in the
208   // block.
209   bool canSinkBCECmpInst(const Instruction *, DenseSet<Instruction *> &,
210                          AliasAnalysis *AA) const;
211 
212   // We can separate the BCE-cmp-block instructions and the non-BCE-cmp-block
213   // instructions. Split the old block and move all non-BCE-cmp-insts into the
214   // new parent block.
215   void split(BasicBlock *NewParent, AliasAnalysis *AA) const;
216 
217   // The basic block where this comparison happens.
218   BasicBlock *BB = nullptr;
219   // The ICMP for this comparison.
220   ICmpInst *CmpI = nullptr;
221   // The terminating branch.
222   BranchInst *BranchI = nullptr;
223   // The block requires splitting.
224   bool RequireSplit = false;
225 
226 private:
227   BCEAtom Lhs_;
228   BCEAtom Rhs_;
229   int SizeBits_ = 0;
230 };
231 
232 bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst,
233                                     DenseSet<Instruction *> &BlockInsts,
234                                     AliasAnalysis *AA) const {
235   // If this instruction has side effects and its in middle of the BCE cmp block
236   // instructions, then bail for now.
237   if (Inst->mayHaveSideEffects()) {
238     // Bail if this is not a simple load or store
239     if (!isSimpleLoadOrStore(Inst))
240       return false;
241     // Disallow stores that might alias the BCE operands
242     MemoryLocation LLoc = MemoryLocation::get(Lhs_.LoadI);
243     MemoryLocation RLoc = MemoryLocation::get(Rhs_.LoadI);
244     if (isModSet(AA->getModRefInfo(Inst, LLoc)) ||
245         isModSet(AA->getModRefInfo(Inst, RLoc)))
246         return false;
247   }
248   // Make sure this instruction does not use any of the BCE cmp block
249   // instructions as operand.
250   for (auto BI : BlockInsts) {
251     if (is_contained(Inst->operands(), BI))
252       return false;
253   }
254   return true;
255 }
256 
257 void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis *AA) const {
258   DenseSet<Instruction *> BlockInsts(
259       {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI});
260   llvm::SmallVector<Instruction *, 4> OtherInsts;
261   for (Instruction &Inst : *BB) {
262     if (BlockInsts.count(&Inst))
263       continue;
264       assert(canSinkBCECmpInst(&Inst, BlockInsts, AA) &&
265              "Split unsplittable block");
266     // This is a non-BCE-cmp-block instruction. And it can be separated
267     // from the BCE-cmp-block instruction.
268     OtherInsts.push_back(&Inst);
269   }
270 
271   // Do the actual spliting.
272   for (Instruction *Inst : reverse(OtherInsts)) {
273     Inst->moveBefore(&*NewParent->begin());
274   }
275 }
276 
277 bool BCECmpBlock::canSplit(AliasAnalysis *AA) const {
278   DenseSet<Instruction *> BlockInsts(
279       {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI});
280   for (Instruction &Inst : *BB) {
281     if (!BlockInsts.count(&Inst)) {
282       if (!canSinkBCECmpInst(&Inst, BlockInsts, AA))
283         return false;
284     }
285   }
286   return true;
287 }
288 
289 bool BCECmpBlock::doesOtherWork() const {
290   AssertConsistent();
291   // All the instructions we care about in the BCE cmp block.
292   DenseSet<Instruction *> BlockInsts(
293       {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI});
294   // TODO(courbet): Can we allow some other things ? This is very conservative.
295   // We might be able to get away with anything does not have any side
296   // effects outside of the basic block.
297   // Note: The GEPs and/or loads are not necessarily in the same block.
298   for (const Instruction &Inst : *BB) {
299     if (!BlockInsts.count(&Inst))
300       return true;
301   }
302   return false;
303 }
304 
305 // Visit the given comparison. If this is a comparison between two valid
306 // BCE atoms, returns the comparison.
307 BCECmpBlock visitICmp(const ICmpInst *const CmpI,
308                       const ICmpInst::Predicate ExpectedPredicate,
309                       BaseIdentifier &BaseId) {
310   // The comparison can only be used once:
311   //  - For intermediate blocks, as a branch condition.
312   //  - For the final block, as an incoming value for the Phi.
313   // If there are any other uses of the comparison, we cannot merge it with
314   // other comparisons as we would create an orphan use of the value.
315   if (!CmpI->hasOneUse()) {
316     LLVM_DEBUG(dbgs() << "cmp has several uses\n");
317     return {};
318   }
319   if (CmpI->getPredicate() != ExpectedPredicate)
320     return {};
321   LLVM_DEBUG(dbgs() << "cmp "
322                     << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne")
323                     << "\n");
324   auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0), BaseId);
325   if (!Lhs.BaseId)
326     return {};
327   auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1), BaseId);
328   if (!Rhs.BaseId)
329     return {};
330   const auto &DL = CmpI->getModule()->getDataLayout();
331   return BCECmpBlock(std::move(Lhs), std::move(Rhs),
332                      DL.getTypeSizeInBits(CmpI->getOperand(0)->getType()));
333 }
334 
335 // Visit the given comparison block. If this is a comparison between two valid
336 // BCE atoms, returns the comparison.
337 BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block,
338                           const BasicBlock *const PhiBlock,
339                           BaseIdentifier &BaseId) {
340   if (Block->empty()) return {};
341   auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator());
342   if (!BranchI) return {};
343   LLVM_DEBUG(dbgs() << "branch\n");
344   if (BranchI->isUnconditional()) {
345     // In this case, we expect an incoming value which is the result of the
346     // comparison. This is the last link in the chain of comparisons (note
347     // that this does not mean that this is the last incoming value, blocks
348     // can be reordered).
349     auto *const CmpI = dyn_cast<ICmpInst>(Val);
350     if (!CmpI) return {};
351     LLVM_DEBUG(dbgs() << "icmp\n");
352     auto Result = visitICmp(CmpI, ICmpInst::ICMP_EQ, BaseId);
353     Result.CmpI = CmpI;
354     Result.BranchI = BranchI;
355     return Result;
356   } else {
357     // In this case, we expect a constant incoming value (the comparison is
358     // chained).
359     const auto *const Const = dyn_cast<ConstantInt>(Val);
360     LLVM_DEBUG(dbgs() << "const\n");
361     if (!Const->isZero()) return {};
362     LLVM_DEBUG(dbgs() << "false\n");
363     auto *const CmpI = dyn_cast<ICmpInst>(BranchI->getCondition());
364     if (!CmpI) return {};
365     LLVM_DEBUG(dbgs() << "icmp\n");
366     assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch");
367     BasicBlock *const FalseBlock = BranchI->getSuccessor(1);
368     auto Result = visitICmp(
369         CmpI, FalseBlock == PhiBlock ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
370         BaseId);
371     Result.CmpI = CmpI;
372     Result.BranchI = BranchI;
373     return Result;
374   }
375   return {};
376 }
377 
378 static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons,
379                                 BCECmpBlock &Comparison) {
380   LLVM_DEBUG(dbgs() << "Block '" << Comparison.BB->getName()
381                     << "': Found cmp of " << Comparison.SizeBits()
382                     << " bits between " << Comparison.Lhs().BaseId << " + "
383                     << Comparison.Lhs().Offset << " and "
384                     << Comparison.Rhs().BaseId << " + "
385                     << Comparison.Rhs().Offset << "\n");
386   LLVM_DEBUG(dbgs() << "\n");
387   Comparisons.push_back(Comparison);
388 }
389 
390 // A chain of comparisons.
391 class BCECmpChain {
392  public:
393   BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
394               AliasAnalysis *AA);
395 
396   int size() const { return Comparisons_.size(); }
397 
398 #ifdef MERGEICMPS_DOT_ON
399   void dump() const;
400 #endif  // MERGEICMPS_DOT_ON
401 
402   bool simplify(const TargetLibraryInfo *const TLI, AliasAnalysis *AA,
403                 DomTreeUpdater &DTU);
404 
405 private:
406   static bool IsContiguous(const BCECmpBlock &First,
407                            const BCECmpBlock &Second) {
408     return First.Lhs().BaseId == Second.Lhs().BaseId &&
409            First.Rhs().BaseId == Second.Rhs().BaseId &&
410            First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset &&
411            First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset;
412   }
413 
414   PHINode &Phi_;
415   std::vector<BCECmpBlock> Comparisons_;
416   // The original entry block (before sorting);
417   BasicBlock *EntryBlock_;
418 };
419 
420 BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
421                          AliasAnalysis *AA)
422     : Phi_(Phi) {
423   assert(!Blocks.empty() && "a chain should have at least one block");
424   // Now look inside blocks to check for BCE comparisons.
425   std::vector<BCECmpBlock> Comparisons;
426   BaseIdentifier BaseId;
427   for (size_t BlockIdx = 0; BlockIdx < Blocks.size(); ++BlockIdx) {
428     BasicBlock *const Block = Blocks[BlockIdx];
429     assert(Block && "invalid block");
430     BCECmpBlock Comparison = visitCmpBlock(Phi.getIncomingValueForBlock(Block),
431                                            Block, Phi.getParent(), BaseId);
432     Comparison.BB = Block;
433     if (!Comparison.IsValid()) {
434       LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n");
435       return;
436     }
437     if (Comparison.doesOtherWork()) {
438       LLVM_DEBUG(dbgs() << "block '" << Comparison.BB->getName()
439                         << "' does extra work besides compare\n");
440       if (Comparisons.empty()) {
441         // This is the initial block in the chain, in case this block does other
442         // work, we can try to split the block and move the irrelevant
443         // instructions to the predecessor.
444         //
445         // If this is not the initial block in the chain, splitting it wont
446         // work.
447         //
448         // As once split, there will still be instructions before the BCE cmp
449         // instructions that do other work in program order, i.e. within the
450         // chain before sorting. Unless we can abort the chain at this point
451         // and start anew.
452         //
453         // NOTE: we only handle blocks a with single predecessor for now.
454         if (Comparison.canSplit(AA)) {
455           LLVM_DEBUG(dbgs()
456                      << "Split initial block '" << Comparison.BB->getName()
457                      << "' that does extra work besides compare\n");
458           Comparison.RequireSplit = true;
459           enqueueBlock(Comparisons, Comparison);
460         } else {
461           LLVM_DEBUG(dbgs()
462                      << "ignoring initial block '" << Comparison.BB->getName()
463                      << "' that does extra work besides compare\n");
464         }
465         continue;
466       }
467       // TODO(courbet): Right now we abort the whole chain. We could be
468       // merging only the blocks that don't do other work and resume the
469       // chain from there. For example:
470       //  if (a[0] == b[0]) {  // bb1
471       //    if (a[1] == b[1]) {  // bb2
472       //      some_value = 3; //bb3
473       //      if (a[2] == b[2]) { //bb3
474       //        do a ton of stuff  //bb4
475       //      }
476       //    }
477       //  }
478       //
479       // This is:
480       //
481       // bb1 --eq--> bb2 --eq--> bb3* -eq--> bb4 --+
482       //  \            \           \               \
483       //   ne           ne          ne              \
484       //    \            \           \               v
485       //     +------------+-----------+----------> bb_phi
486       //
487       // We can only merge the first two comparisons, because bb3* does
488       // "other work" (setting some_value to 3).
489       // We could still merge bb1 and bb2 though.
490       return;
491     }
492     enqueueBlock(Comparisons, Comparison);
493   }
494 
495   // It is possible we have no suitable comparison to merge.
496   if (Comparisons.empty()) {
497     LLVM_DEBUG(dbgs() << "chain with no BCE basic blocks, no merge\n");
498     return;
499   }
500   EntryBlock_ = Comparisons[0].BB;
501   Comparisons_ = std::move(Comparisons);
502 #ifdef MERGEICMPS_DOT_ON
503   errs() << "BEFORE REORDERING:\n\n";
504   dump();
505 #endif  // MERGEICMPS_DOT_ON
506   // Reorder blocks by LHS. We can do that without changing the
507   // semantics because we are only accessing dereferencable memory.
508   llvm::sort(Comparisons_,
509              [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) {
510                return LhsBlock.Lhs() < RhsBlock.Lhs();
511              });
512 #ifdef MERGEICMPS_DOT_ON
513   errs() << "AFTER REORDERING:\n\n";
514   dump();
515 #endif  // MERGEICMPS_DOT_ON
516 }
517 
#ifdef MERGEICMPS_DOT_ON
// Prints the chain to errs() as a Graphviz digraph: one node per comparison,
// an edge between consecutive comparisons, and an edge from each comparison
// to the phi, labelled with the value the phi receives from that block.
void BCECmpChain::dump() const {
  errs() << "digraph dag {\n";
  errs() << " graph [bgcolor=transparent];\n";
  errs() << " node [color=black,style=filled,fillcolor=lightyellow];\n";
  errs() << " edge [color=black];\n";
  for (size_t I = 0; I < Comparisons_.size(); ++I) {
    const auto &Comparison = Comparisons_[I];
    // BCEAtom has no Base() accessor; the base is the pointer operand of the
    // atom's GEP.
    errs() << " \"" << I << "\" [label=\"%"
           << Comparison.Lhs().GEP->getPointerOperand()->getName() << " + "
           << Comparison.Lhs().Offset << " == %"
           << Comparison.Rhs().GEP->getPointerOperand()->getName() << " + "
           << Comparison.Rhs().Offset << " (" << (Comparison.SizeBits() / 8)
           << " bytes)\"];\n";
    const Value *const Val = Phi_.getIncomingValueForBlock(Comparison.BB);
    if (I > 0) errs() << " \"" << (I - 1) << "\" -> \"" << I << "\";\n";
    errs() << " \"" << I << "\" -> \"Phi\" [label=\"" << *Val << "\"];\n";
  }
  errs() << " \"Phi\" [label=\"Phi\"];\n";
  errs() << "}\n\n";
}
#endif  // MERGEICMPS_DOT_ON
540 
541 namespace {
542 
543 // A class to compute the name of a set of merged basic blocks.
544 // This is optimized for the common case of no block names.
545 class MergedBlockName {
546   // Storage for the uncommon case of several named blocks.
547   SmallString<16> Scratch;
548 
549 public:
550   explicit MergedBlockName(ArrayRef<BCECmpBlock> Comparisons)
551       : Name(makeName(Comparisons)) {}
552   const StringRef Name;
553 
554 private:
555   StringRef makeName(ArrayRef<BCECmpBlock> Comparisons) {
556     assert(!Comparisons.empty() && "no basic block");
557     // Fast path: only one block, or no names at all.
558     if (Comparisons.size() == 1)
559       return Comparisons[0].BB->getName();
560     const int size = std::accumulate(Comparisons.begin(), Comparisons.end(), 0,
561                                      [](int i, const BCECmpBlock &Cmp) {
562                                        return i + Cmp.BB->getName().size();
563                                      });
564     if (size == 0)
565       return StringRef("", 0);
566 
567     // Slow path: at least two blocks, at least one block with a name.
568     Scratch.clear();
569     // We'll have `size` bytes for name and `Comparisons.size() - 1` bytes for
570     // separators.
571     Scratch.reserve(size + Comparisons.size() - 1);
572     const auto append = [this](StringRef str) {
573       Scratch.append(str.begin(), str.end());
574     };
575     append(Comparisons[0].BB->getName());
576     for (int I = 1, E = Comparisons.size(); I < E; ++I) {
577       const BasicBlock *const BB = Comparisons[I].BB;
578       if (!BB->getName().empty()) {
579         append("+");
580         append(BB->getName());
581       }
582     }
583     return StringRef(Scratch);
584   }
585 };
586 } // namespace
587 
588 // Merges the given contiguous comparison blocks into one memcmp block.
589 static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
590                                     BasicBlock *const InsertBefore,
591                                     BasicBlock *const NextCmpBlock,
592                                     PHINode &Phi,
593                                     const TargetLibraryInfo *const TLI,
594                                     AliasAnalysis *AA, DomTreeUpdater &DTU) {
595   assert(!Comparisons.empty() && "merging zero comparisons");
596   LLVMContext &Context = NextCmpBlock->getContext();
597   const BCECmpBlock &FirstCmp = Comparisons[0];
598 
599   // Create a new cmp block before next cmp block.
600   BasicBlock *const BB =
601       BasicBlock::Create(Context, MergedBlockName(Comparisons).Name,
602                          NextCmpBlock->getParent(), InsertBefore);
603   IRBuilder<> Builder(BB);
604   // Add the GEPs from the first BCECmpBlock.
605   Value *const Lhs = Builder.Insert(FirstCmp.Lhs().GEP->clone());
606   Value *const Rhs = Builder.Insert(FirstCmp.Rhs().GEP->clone());
607 
608   Value *IsEqual = nullptr;
609   LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons -> "
610                     << BB->getName() << "\n");
611   if (Comparisons.size() == 1) {
612     LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n");
613     Value *const LhsLoad =
614         Builder.CreateLoad(FirstCmp.Lhs().LoadI->getType(), Lhs);
615     Value *const RhsLoad =
616         Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs);
617     // There are no blocks to merge, just do the comparison.
618     IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
619   } else {
620     // If there is one block that requires splitting, we do it now, i.e.
621     // just before we know we will collapse the chain. The instructions
622     // can be executed before any of the instructions in the chain.
623     const auto ToSplit =
624         std::find_if(Comparisons.begin(), Comparisons.end(),
625                      [](const BCECmpBlock &B) { return B.RequireSplit; });
626     if (ToSplit != Comparisons.end()) {
627       LLVM_DEBUG(dbgs() << "Splitting non_BCE work to header\n");
628       ToSplit->split(BB, AA);
629     }
630 
631     const unsigned TotalSizeBits = std::accumulate(
632         Comparisons.begin(), Comparisons.end(), 0u,
633         [](int Size, const BCECmpBlock &C) { return Size + C.SizeBits(); });
634 
635     // Create memcmp() == 0.
636     const auto &DL = Phi.getModule()->getDataLayout();
637     Value *const MemCmpCall = emitMemCmp(
638         Lhs, Rhs,
639         ConstantInt::get(DL.getIntPtrType(Context), TotalSizeBits / 8), Builder,
640         DL, TLI);
641     IsEqual = Builder.CreateICmpEQ(
642         MemCmpCall, ConstantInt::get(Type::getInt32Ty(Context), 0));
643   }
644 
645   BasicBlock *const PhiBB = Phi.getParent();
646   // Add a branch to the next basic block in the chain.
647   if (NextCmpBlock == PhiBB) {
648     // Continue to phi, passing it the comparison result.
649     Builder.CreateBr(PhiBB);
650     Phi.addIncoming(IsEqual, BB);
651     DTU.applyUpdates({{DominatorTree::Insert, BB, PhiBB}});
652   } else {
653     // Continue to next block if equal, exit to phi else.
654     Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB);
655     Phi.addIncoming(ConstantInt::getFalse(Context), BB);
656     DTU.applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock},
657                       {DominatorTree::Insert, BB, PhiBB}});
658   }
659   return BB;
660 }
661 
662 bool BCECmpChain::simplify(const TargetLibraryInfo *const TLI,
663                            AliasAnalysis *AA, DomTreeUpdater &DTU) {
664   assert(Comparisons_.size() >= 2 && "simplifying trivial BCECmpChain");
665   // First pass to check if there is at least one merge. If not, we don't do
666   // anything and we keep analysis passes intact.
667   const auto AtLeastOneMerged = [this]() {
668     for (size_t I = 1; I < Comparisons_.size(); ++I) {
669       if (IsContiguous(Comparisons_[I - 1], Comparisons_[I]))
670         return true;
671     }
672     return false;
673   };
674   if (!AtLeastOneMerged())
675     return false;
676 
677   LLVM_DEBUG(dbgs() << "Simplifying comparison chain starting at block "
678                     << EntryBlock_->getName() << "\n");
679 
680   // Effectively merge blocks. We go in the reverse direction from the phi block
681   // so that the next block is always available to branch to.
682   const auto mergeRange = [this, TLI, AA, &DTU](int I, int Num,
683                                                 BasicBlock *InsertBefore,
684                                                 BasicBlock *Next) {
685     return mergeComparisons(makeArrayRef(Comparisons_).slice(I, Num),
686                             InsertBefore, Next, Phi_, TLI, AA, DTU);
687   };
688   int NumMerged = 1;
689   BasicBlock *NextCmpBlock = Phi_.getParent();
690   for (int I = static_cast<int>(Comparisons_.size()) - 2; I >= 0; --I) {
691     if (IsContiguous(Comparisons_[I], Comparisons_[I + 1])) {
692       LLVM_DEBUG(dbgs() << "Merging block " << Comparisons_[I].BB->getName()
693                         << " into " << Comparisons_[I + 1].BB->getName()
694                         << "\n");
695       ++NumMerged;
696     } else {
697       NextCmpBlock = mergeRange(I + 1, NumMerged, NextCmpBlock, NextCmpBlock);
698       NumMerged = 1;
699     }
700   }
701   // Insert the entry block for the new chain before the old entry block.
702   // If the old entry block was the function entry, this ensures that the new
703   // entry can become the function entry.
704   NextCmpBlock = mergeRange(0, NumMerged, EntryBlock_, NextCmpBlock);
705 
706   // Replace the original cmp chain with the new cmp chain by pointing all
707   // predecessors of EntryBlock_ to NextCmpBlock instead. This makes all cmp
708   // blocks in the old chain unreachable.
709   while (!pred_empty(EntryBlock_)) {
710     BasicBlock* const Pred = *pred_begin(EntryBlock_);
711     LLVM_DEBUG(dbgs() << "Updating jump into old chain from " << Pred->getName()
712                       << "\n");
713     Pred->getTerminator()->replaceUsesOfWith(EntryBlock_, NextCmpBlock);
714     DTU.applyUpdates({{DominatorTree::Delete, Pred, EntryBlock_},
715                       {DominatorTree::Insert, Pred, NextCmpBlock}});
716   }
717 
718   // If the old cmp chain was the function entry, we need to update the function
719   // entry.
720   const bool ChainEntryIsFnEntry =
721       (EntryBlock_ == &EntryBlock_->getParent()->getEntryBlock());
722   if (ChainEntryIsFnEntry && DTU.hasDomTree()) {
723     LLVM_DEBUG(dbgs() << "Changing function entry from "
724                       << EntryBlock_->getName() << " to "
725                       << NextCmpBlock->getName() << "\n");
726     DTU.getDomTree().setNewRoot(NextCmpBlock);
727     DTU.applyUpdates({{DominatorTree::Delete, NextCmpBlock, EntryBlock_}});
728   }
729   EntryBlock_ = nullptr;
730 
731   // Delete merged blocks. This also removes incoming values in phi.
732   SmallVector<BasicBlock *, 16> DeadBlocks;
733   for (auto &Cmp : Comparisons_) {
734     LLVM_DEBUG(dbgs() << "Deleting merged block " << Cmp.BB->getName() << "\n");
735     DeadBlocks.push_back(Cmp.BB);
736   }
737   DeleteDeadBlocks(DeadBlocks, &DTU);
738 
739   Comparisons_.clear();
740   return true;
741 }
742 
743 std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi,
744                                            BasicBlock *const LastBlock,
745                                            int NumBlocks) {
746   // Walk up from the last block to find other blocks.
747   std::vector<BasicBlock *> Blocks(NumBlocks);
748   assert(LastBlock && "invalid last block");
749   BasicBlock *CurBlock = LastBlock;
750   for (int BlockIndex = NumBlocks - 1; BlockIndex > 0; --BlockIndex) {
751     if (CurBlock->hasAddressTaken()) {
752       // Somebody is jumping to the block through an address, all bets are
753       // off.
754       LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex
755                         << " has its address taken\n");
756       return {};
757     }
758     Blocks[BlockIndex] = CurBlock;
759     auto *SinglePredecessor = CurBlock->getSinglePredecessor();
760     if (!SinglePredecessor) {
761       // The block has two or more predecessors.
762       LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex
763                         << " has two or more predecessors\n");
764       return {};
765     }
766     if (Phi.getBasicBlockIndex(SinglePredecessor) < 0) {
767       // The block does not link back to the phi.
768       LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex
769                         << " does not link back to the phi\n");
770       return {};
771     }
772     CurBlock = SinglePredecessor;
773   }
774   Blocks[0] = CurBlock;
775   return Blocks;
776 }
777 
778 bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI,
779                 AliasAnalysis *AA, DomTreeUpdater &DTU) {
780   LLVM_DEBUG(dbgs() << "processPhi()\n");
781   if (Phi.getNumIncomingValues() <= 1) {
782     LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n");
783     return false;
784   }
785   // We are looking for something that has the following structure:
786   //   bb1 --eq--> bb2 --eq--> bb3 --eq--> bb4 --+
787   //     \            \           \               \
788   //      ne           ne          ne              \
789   //       \            \           \               v
790   //        +------------+-----------+----------> bb_phi
791   //
792   //  - The last basic block (bb4 here) must branch unconditionally to bb_phi.
793   //    It's the only block that contributes a non-constant value to the Phi.
794   //  - All other blocks (b1, b2, b3) must have exactly two successors, one of
795   //    them being the phi block.
796   //  - All intermediate blocks (bb2, bb3) must have only one predecessor.
797   //  - Blocks cannot do other work besides the comparison, see doesOtherWork()
798 
799   // The blocks are not necessarily ordered in the phi, so we start from the
800   // last block and reconstruct the order.
801   BasicBlock *LastBlock = nullptr;
802   for (unsigned I = 0; I < Phi.getNumIncomingValues(); ++I) {
803     if (isa<ConstantInt>(Phi.getIncomingValue(I))) continue;
804     if (LastBlock) {
805       // There are several non-constant values.
806       LLVM_DEBUG(dbgs() << "skip: several non-constant values\n");
807       return false;
808     }
809     if (!isa<ICmpInst>(Phi.getIncomingValue(I)) ||
810         cast<ICmpInst>(Phi.getIncomingValue(I))->getParent() !=
811             Phi.getIncomingBlock(I)) {
812       // Non-constant incoming value is not from a cmp instruction or not
813       // produced by the last block. We could end up processing the value
814       // producing block more than once.
815       //
816       // This is an uncommon case, so we bail.
817       LLVM_DEBUG(
818           dbgs()
819           << "skip: non-constant value not from cmp or not from last block.\n");
820       return false;
821     }
822     LastBlock = Phi.getIncomingBlock(I);
823   }
824   if (!LastBlock) {
825     // There is no non-constant block.
826     LLVM_DEBUG(dbgs() << "skip: no non-constant block\n");
827     return false;
828   }
829   if (LastBlock->getSingleSuccessor() != Phi.getParent()) {
830     LLVM_DEBUG(dbgs() << "skip: last block non-phi successor\n");
831     return false;
832   }
833 
834   const auto Blocks =
835       getOrderedBlocks(Phi, LastBlock, Phi.getNumIncomingValues());
836   if (Blocks.empty()) return false;
837   BCECmpChain CmpChain(Blocks, Phi, AA);
838 
839   if (CmpChain.size() < 2) {
840     LLVM_DEBUG(dbgs() << "skip: only one compare block\n");
841     return false;
842   }
843 
844   return CmpChain.simplify(TLI, AA, DTU);
845 }
846 
847 class MergeICmps : public FunctionPass {
848  public:
849   static char ID;
850 
851   MergeICmps() : FunctionPass(ID) {
852     initializeMergeICmpsPass(*PassRegistry::getPassRegistry());
853   }
854 
855   bool runOnFunction(Function &F) override {
856     if (skipFunction(F)) return false;
857     const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
858     const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
859     // MergeICmps does not need the DominatorTree, but we update it if it's
860     // already available.
861     auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
862     DomTreeUpdater DTU(DTWP ? &DTWP->getDomTree() : nullptr,
863                        /*PostDominatorTree*/ nullptr,
864                        DomTreeUpdater::UpdateStrategy::Eager);
865     AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
866     auto PA = runImpl(F, &TLI, &TTI, AA, DTU);
867     return !PA.areAllPreserved();
868   }
869 
870  private:
871   void getAnalysisUsage(AnalysisUsage &AU) const override {
872     AU.addRequired<TargetLibraryInfoWrapperPass>();
873     AU.addRequired<TargetTransformInfoWrapperPass>();
874     AU.addRequired<AAResultsWrapperPass>();
875     AU.addPreserved<GlobalsAAWrapperPass>();
876     AU.addPreserved<DominatorTreeWrapperPass>();
877   }
878 
879   PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
880                             const TargetTransformInfo *TTI, AliasAnalysis *AA,
881                             DomTreeUpdater &DTU);
882 };
883 
884 PreservedAnalyses MergeICmps::runImpl(Function &F, const TargetLibraryInfo *TLI,
885                                       const TargetTransformInfo *TTI,
886                                       AliasAnalysis *AA, DomTreeUpdater &DTU) {
887   LLVM_DEBUG(dbgs() << "MergeICmpsPass: " << F.getName() << "\n");
888 
889   // We only try merging comparisons if the target wants to expand memcmp later.
890   // The rationale is to avoid turning small chains into memcmp calls.
891   if (!TTI->enableMemCmpExpansion(true)) return PreservedAnalyses::all();
892 
893   // If we don't have memcmp avaiable we can't emit calls to it.
894   if (!TLI->has(LibFunc_memcmp))
895     return PreservedAnalyses::all();
896 
897   bool MadeChange = false;
898 
899   for (auto BBIt = ++F.begin(); BBIt != F.end(); ++BBIt) {
900     // A Phi operation is always first in a basic block.
901     if (auto *const Phi = dyn_cast<PHINode>(&*BBIt->begin()))
902       MadeChange |= processPhi(*Phi, TLI, AA, DTU);
903   }
904 
905   if (!MadeChange)
906     return PreservedAnalyses::all();
907   PreservedAnalyses PA;
908   PA.preserve<GlobalsAA>();
909   PA.preserve<DominatorTreeAnalysis>();
910   return PA;
911 }
912 
913 }  // namespace
914 
// Definition of the static pass identifier that the MergeICmps constructor
// hands to FunctionPass.
char MergeICmps::ID = 0;
// Pass registration under the name "mergeicmps". The dependencies listed
// between BEGIN and END match the analyses requested with addRequired<> in
// MergeICmps::getAnalysisUsage().
// NOTE(review): the trailing `false, false` arguments are presumably the
// CFG-only / is-analysis flags of the INITIALIZE_PASS macros — confirm against
// llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(MergeICmps, "mergeicmps",
                      "Merge contiguous icmps into a memcmp", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(MergeICmps, "mergeicmps",
                    "Merge contiguous icmps into a memcmp", false, false)
923 
924 Pass *llvm::createMergeICmpsPass() { return new MergeICmps(); }
925