//===- GVNSink.cpp - sink expressions into successors ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file GVNSink.cpp
/// This pass attempts to sink instructions into successors, reducing static
/// instruction count and enabling if-conversion.
///
/// We use a variant of global value numbering to decide what can be sunk.
/// Consider:
///
/// [ %a1 = add i32 %b, 1  ]   [ %c1 = add i32 %d, 1  ]
/// [ %a2 = xor i32 %a1, 1 ]   [ %c2 = xor i32 %c1, 1 ]
///                  \           /
///            [ %e = phi i32 %a2, %c2 ]
///            [ add i32 %e, 4         ]
///
/// GVN would number %a1 and %c1 differently because they compute different
/// results - the VN of an instruction is a function of its opcode and the
/// transitive closure of its operands. This is the key property for hoisting
/// and CSE.
///
/// What we want when sinking however is for a numbering that is a function of
/// the *uses* of an instruction, which allows us to answer the question "if I
/// replace %a1 with %c1, will it contribute in an equivalent way to all
/// successive instructions?". The PostValueTable class in GVN provides this
/// mapping.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ArrayRecycler.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Scalar/GVNExpression.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "gvn-sink"

STATISTIC(NumRemoved, "Number of instructions removed");

namespace llvm {
namespace GVNExpression {

LLVM_DUMP_METHOD void Expression::dump() const {
  print(dbgs());
  dbgs() << "\n";
}

} // end namespace GVNExpression
} // end namespace llvm

namespace {

static bool isMemoryInst(const Instruction *I) {
  return isa<LoadInst>(I) || isa<StoreInst>(I) ||
         (isa<InvokeInst>(I) && !cast<InvokeInst>(I)->doesNotAccessMemory()) ||
         (isa<CallInst>(I) && !cast<CallInst>(I)->doesNotAccessMemory());
}

/// Iterates through instructions in a set of blocks in reverse order from the
/// first non-terminator. For example (assume all blocks have size n):
///   LockstepReverseIterator I([B1, B2, B3]);
///   *I-- = [B1[n], B2[n], B3[n]];
///   *I-- = [B1[n-1], B2[n-1], B3[n-1]];
///   *I-- = [B1[n-2], B2[n-2], B3[n-2]];
///   ...
///
/// It continues until all blocks have been exhausted. Use \c getActiveBlocks()
/// to determine which blocks are still going and the order they appear in the
/// list returned by operator*.
class LockstepReverseIterator {
  ArrayRef<BasicBlock *> Blocks;
  SmallSetVector<BasicBlock *, 4> ActiveBlocks;
  SmallVector<Instruction *, 4> Insts;
  bool Fail;

public:
  LockstepReverseIterator(ArrayRef<BasicBlock *> Blocks) : Blocks(Blocks) {
    reset();
  }

  void reset() {
    Fail = false;
    ActiveBlocks.clear();
    for (BasicBlock *BB : Blocks)
      ActiveBlocks.insert(BB);
    Insts.clear();
    for (BasicBlock *BB : Blocks) {
      if (BB->size() <= 1) {
        // Block wasn't big enough - only contained a terminator.
        ActiveBlocks.remove(BB);
        continue;
      }
      Insts.push_back(BB->getTerminator()->getPrevNonDebugInstruction());
    }
    if (Insts.empty())
      Fail = true;
  }

  bool isValid() const { return !Fail; }
  ArrayRef<Instruction *> operator*() const { return Insts; }

  // Note: This needs to return a SmallSetVector as the elements of
  // ActiveBlocks will be later copied to Blocks using std::copy. The
  // resultant order of elements in Blocks needs to be deterministic.
  // Using SmallPtrSet instead causes non-deterministic order while
  // copying. And we cannot simply sort Blocks as they need to match the
  // corresponding Values.
  SmallSetVector<BasicBlock *, 4> &getActiveBlocks() { return ActiveBlocks; }

  void restrictToBlocks(SmallSetVector<BasicBlock *, 4> &Blocks) {
    for (auto II = Insts.begin(); II != Insts.end();) {
      if (!Blocks.contains((*II)->getParent())) {
        ActiveBlocks.remove((*II)->getParent());
        II = Insts.erase(II);
      } else {
        ++II;
      }
    }
  }

  void operator--() {
    if (Fail)
      return;
    SmallVector<Instruction *, 4> NewInsts;
    for (auto *Inst : Insts) {
      if (Inst == &Inst->getParent()->front())
        ActiveBlocks.remove(Inst->getParent());
      else
        NewInsts.push_back(Inst->getPrevNonDebugInstruction());
    }
    if (NewInsts.empty()) {
      Fail = true;
      return;
    }
    Insts = NewInsts;
  }
};
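
// Typical use, mirroring the loop in sinkBB() below:
//   LockstepReverseIterator LRI(Preds);
//   while (LRI.isValid()) {
//     ArrayRef<Instruction *> Insts = *LRI;
//     // ... analyze the instructions at this position ...
//     --LRI;
//   }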

//===----------------------------------------------------------------------===//

/// Candidate solution for sinking. There may be different ways to
/// sink instructions, differing in the number of instructions sunk,
/// the number of predecessors sunk from and the number of PHIs
/// required.
struct SinkingInstructionCandidate {
  unsigned NumBlocks;
  unsigned NumInstructions;
  unsigned NumPHIs;
  unsigned NumMemoryInsts;
  int Cost = -1;
  SmallVector<BasicBlock *, 4> Blocks;

  void calculateCost(unsigned NumOrigPHIs, unsigned NumOrigBlocks) {
    unsigned NumExtraPHIs = NumPHIs - NumOrigPHIs;
    unsigned SplitEdgeCost = (NumOrigBlocks > NumBlocks) ? 2 : 0;
    Cost = (NumInstructions * (NumBlocks - 1)) -
           (NumExtraPHIs *
            NumExtraPHIs) // PHIs are expensive, so make sure they're worth it.
           - SplitEdgeCost;
  }
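
  // Worked example of the formula above: sinking three instructions from
  // both of two predecessors, with no extra PHIs and no edge split, gives
  //   Cost = 3 * (2 - 1) - 0 * 0 - 0 = 3.
  // One extra PHI would instead give 3 - 1 = 2; two would give 3 - 4 = -1,
  // at which point the candidate no longer pays for itself.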

  bool operator>(const SinkingInstructionCandidate &Other) const {
    return Cost > Other.Cost;
  }
};

#ifndef NDEBUG
raw_ostream &operator<<(raw_ostream &OS, const SinkingInstructionCandidate &C) {
  OS << "<Candidate Cost=" << C.Cost << " #Blocks=" << C.NumBlocks
     << " #Insts=" << C.NumInstructions << " #PHIs=" << C.NumPHIs << ">";
  return OS;
}
#endif

//===----------------------------------------------------------------------===//

/// Describes a PHI node that may or may not exist. These track the PHIs
/// that would have to be created if we sank a sequence of instructions. It
/// provides a hash function for efficient equality comparisons.
class ModelledPHI {
  SmallVector<Value *, 4> Values;
  SmallVector<BasicBlock *, 4> Blocks;

public:
  ModelledPHI() = default;

  ModelledPHI(const PHINode *PN,
              const DenseMap<const BasicBlock *, unsigned> &BlockOrder) {
    // BasicBlock comes first in the pair so we sort by the blocks' traversal
    // order; Values and Blocks are then populated in a deterministic order,
    // so there is no need to call `verifyModelledPHI` afterwards.
    using OpsType = std::pair<BasicBlock *, Value *>;
    SmallVector<OpsType, 4> Ops;
    for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I)
      Ops.push_back({PN->getIncomingBlock(I), PN->getIncomingValue(I)});

    auto ComesBefore = [BlockOrder](OpsType O1, OpsType O2) {
      return BlockOrder.lookup(O1.first) < BlockOrder.lookup(O2.first);
    };
    // Sort in a deterministic order.
    llvm::sort(Ops, ComesBefore);

    for (auto &P : Ops) {
      Blocks.push_back(P.first);
      Values.push_back(P.second);
    }
  }

  /// Create a dummy ModelledPHI that will compare unequal to any other
  /// ModelledPHI without the same ID.
  /// \note This is specifically for DenseMapInfo - do not use this!
  static ModelledPHI createDummy(size_t ID) {
    ModelledPHI M;
    M.Values.push_back(reinterpret_cast<Value *>(ID));
    return M;
  }

  void
  verifyModelledPHI(const DenseMap<const BasicBlock *, unsigned> &BlockOrder) {
    assert(Values.size() > 1 && Blocks.size() > 1 &&
           "Modelling PHI with less than 2 values");
    auto ComesBefore = [BlockOrder](const BasicBlock *BB1,
                                    const BasicBlock *BB2) {
      return BlockOrder.lookup(BB1) < BlockOrder.lookup(BB2);
    };
    assert(llvm::is_sorted(Blocks, ComesBefore));
    int C = 0;
    for (const Value *V : Values) {
      if (!isa<UndefValue>(V)) {
        assert(cast<Instruction>(V)->getParent() == Blocks[C]);
        (void)C;
      }
      C++;
    }
  }

  /// Create a PHI from an array of incoming values and incoming blocks.
  ModelledPHI(SmallVectorImpl<Instruction *> &V,
              SmallSetVector<BasicBlock *, 4> &B,
              const DenseMap<const BasicBlock *, unsigned> &BlockOrder) {
    // The Values and Blocks are already ordered by the caller.
    llvm::copy(V, std::back_inserter(Values));
    llvm::copy(B, std::back_inserter(Blocks));
    verifyModelledPHI(BlockOrder);
  }

  /// Create a PHI from [I[OpNum] for I in Insts].
  /// TODO: Figure out a way to verifyModelledPHI in this constructor.
  ModelledPHI(ArrayRef<Instruction *> Insts, unsigned OpNum,
              SmallSetVector<BasicBlock *, 4> &B) {
    llvm::copy(B, std::back_inserter(Blocks));
    for (auto *I : Insts)
      Values.push_back(I->getOperand(OpNum));
  }

  /// Restrict the PHI's contents down to only \c NewBlocks.
  /// \c NewBlocks must be a subset of \c this->Blocks.
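  /// For example, restricting { B1: %a, B2: %b, B3: %c } to { B1, B3 }
  /// leaves { B1: %a, B3: %c }.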
  void restrictToBlocks(const SmallSetVector<BasicBlock *, 4> &NewBlocks) {
    auto BI = Blocks.begin();
    auto VI = Values.begin();
    while (BI != Blocks.end()) {
      assert(VI != Values.end());
      if (!NewBlocks.contains(*BI)) {
        BI = Blocks.erase(BI);
        VI = Values.erase(VI);
      } else {
        ++BI;
        ++VI;
      }
    }
    assert(Blocks.size() == NewBlocks.size());
  }

  ArrayRef<Value *> getValues() const { return Values; }

  bool areAllIncomingValuesSame() const { return llvm::all_equal(Values); }

  bool areAllIncomingValuesSameType() const {
    return llvm::all_of(
        Values, [&](Value *V) { return V->getType() == Values[0]->getType(); });
  }

  bool areAnyIncomingValuesConstant() const {
    return llvm::any_of(Values, [&](Value *V) { return isa<Constant>(V); });
  }

  // Hash functor
  unsigned hash() const {
    // Is deterministic because Values are saved in a specific order.
    return (unsigned)hash_combine_range(Values.begin(), Values.end());
  }

  bool operator==(const ModelledPHI &Other) const {
    return Values == Other.Values && Blocks == Other.Blocks;
  }
};

template <typename ModelledPHI> struct DenseMapInfo {
  static inline ModelledPHI &getEmptyKey() {
    static ModelledPHI Dummy = ModelledPHI::createDummy(0);
    return Dummy;
  }

  static inline ModelledPHI &getTombstoneKey() {
    static ModelledPHI Dummy = ModelledPHI::createDummy(1);
    return Dummy;
  }

  static unsigned getHashValue(const ModelledPHI &V) { return V.hash(); }

  static bool isEqual(const ModelledPHI &LHS, const ModelledPHI &RHS) {
    return LHS == RHS;
  }
};

using ModelledPHISet = DenseSet<ModelledPHI, DenseMapInfo<ModelledPHI>>;

//===----------------------------------------------------------------------===//
//                             ValueTable
//===----------------------------------------------------------------------===//
// This is a value number table where the value number is a function of the
// *uses* of a value, rather than its operands. Thus, if VN(A) == VN(B) we know
// that the program would be equivalent if we replaced A with PHI(A, B).
//===----------------------------------------------------------------------===//
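
// For example (illustrative IR), given two predecessors feeding one PHI:
//
//   %a1 = add i32 %x, 1   ; in %bb1, used only by %p
//   %c1 = add i32 %y, 1   ; in %bb2, used only by %p
//   %p = phi i32 [ %a1, %bb1 ], [ %c1, %bb2 ]
//
// VN(%a1) == VN(%c1): the two adds are used in equivalent ways, even though
// they compute different results.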

/// A GVN expression describing how an instruction is used. The operands
/// field of BasicExpression is used to store uses, not operands.
///
/// This class also contains fields for discriminators used when determining
/// equivalence of instructions with side effects.
class InstructionUseExpr : public GVNExpression::BasicExpression {
  unsigned MemoryUseOrder = -1;
  bool Volatile = false;
  ArrayRef<int> ShuffleMask;

public:
  InstructionUseExpr(Instruction *I, ArrayRecycler<Value *> &R,
                     BumpPtrAllocator &A)
      : GVNExpression::BasicExpression(I->getNumUses()) {
    allocateOperands(R, A);
    setOpcode(I->getOpcode());
    setType(I->getType());

    if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(I))
      ShuffleMask = SVI->getShuffleMask().copy(A);

    for (auto &U : I->uses())
      op_push_back(U.getUser());
    llvm::sort(op_begin(), op_end());
  }

  void setMemoryUseOrder(unsigned MUO) { MemoryUseOrder = MUO; }
  void setVolatile(bool V) { Volatile = V; }

  hash_code getHashValue() const override {
    return hash_combine(GVNExpression::BasicExpression::getHashValue(),
                        MemoryUseOrder, Volatile, ShuffleMask);
  }

  template <typename Function> hash_code getHashValue(Function MapFn) {
    hash_code H = hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile,
                               ShuffleMask);
    for (auto *V : operands())
      H = hash_combine(H, MapFn(V));
    return H;
  }
};

using BasicBlocksSet = SmallPtrSet<const BasicBlock *, 32>;

class ValueTable {
  DenseMap<Value *, uint32_t> ValueNumbering;
  DenseMap<GVNExpression::Expression *, uint32_t> ExpressionNumbering;
  DenseMap<size_t, uint32_t> HashNumbering;
  BumpPtrAllocator Allocator;
  ArrayRecycler<Value *> Recycler;
  uint32_t nextValueNumber = 1;
  BasicBlocksSet ReachableBBs;

  /// Create an expression for I based on its opcode and its uses. If I
  /// touches or reads memory, the expression is also based upon its memory
  /// order - see \c getMemoryUseOrder().
  InstructionUseExpr *createExpr(Instruction *I) {
    InstructionUseExpr *E =
        new (Allocator) InstructionUseExpr(I, Recycler, Allocator);
    if (isMemoryInst(I))
      E->setMemoryUseOrder(getMemoryUseOrder(I));

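    // Fold the predicate into the opcode so that, for instance, icmp eq and
    // icmp ne never share an expression despite having the same opcode.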
    if (CmpInst *C = dyn_cast<CmpInst>(I)) {
      CmpInst::Predicate Predicate = C->getPredicate();
      E->setOpcode((C->getOpcode() << 8) | Predicate);
    }
    return E;
  }

  /// Helper to compute the value number for a memory instruction
  /// (LoadInst/StoreInst), including checking the memory ordering and
  /// volatility.
  template <class Inst> InstructionUseExpr *createMemoryExpr(Inst *I) {
    if (isStrongerThanUnordered(I->getOrdering()) || I->isAtomic())
      return nullptr;
    InstructionUseExpr *E = createExpr(I);
    E->setVolatile(I->isVolatile());
    return E;
  }

public:
  ValueTable() = default;

  /// Set basic blocks reachable from entry block.
  void setReachableBBs(const BasicBlocksSet &ReachableBBs) {
    this->ReachableBBs = ReachableBBs;
  }

  /// Returns the value number for the specified value, assigning
  /// it a new number if it did not have one before.
  uint32_t lookupOrAdd(Value *V) {
    auto VI = ValueNumbering.find(V);
    if (VI != ValueNumbering.end())
      return VI->second;

    if (!isa<Instruction>(V)) {
      ValueNumbering[V] = nextValueNumber;
      return nextValueNumber++;
    }

    Instruction *I = cast<Instruction>(V);
    if (!ReachableBBs.contains(I->getParent()))
      return ~0U;

    InstructionUseExpr *exp = nullptr;
    switch (I->getOpcode()) {
    case Instruction::Load:
      exp = createMemoryExpr(cast<LoadInst>(I));
      break;
    case Instruction::Store:
      exp = createMemoryExpr(cast<StoreInst>(I));
      break;
    case Instruction::Call:
    case Instruction::Invoke:
    case Instruction::FNeg:
    case Instruction::Add:
    case Instruction::FAdd:
    case Instruction::Sub:
    case Instruction::FSub:
    case Instruction::Mul:
    case Instruction::FMul:
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::FDiv:
    case Instruction::URem:
    case Instruction::SRem:
    case Instruction::FRem:
    case Instruction::Shl:
    case Instruction::LShr:
    case Instruction::AShr:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt:
    case Instruction::FPToUI:
    case Instruction::FPToSI:
    case Instruction::UIToFP:
    case Instruction::SIToFP:
    case Instruction::FPTrunc:
    case Instruction::FPExt:
    case Instruction::PtrToInt:
    case Instruction::IntToPtr:
    case Instruction::BitCast:
    case Instruction::AddrSpaceCast:
    case Instruction::Select:
    case Instruction::ExtractElement:
    case Instruction::InsertElement:
    case Instruction::ShuffleVector:
    case Instruction::InsertValue:
    case Instruction::GetElementPtr:
      exp = createExpr(I);
      break;
    default:
      break;
    }

    if (!exp) {
      ValueNumbering[V] = nextValueNumber;
      return nextValueNumber++;
    }

    uint32_t e = ExpressionNumbering[exp];
    if (!e) {
      hash_code H = exp->getHashValue([=](Value *V) { return lookupOrAdd(V); });
      auto [I, Inserted] = HashNumbering.try_emplace(H, nextValueNumber);
      e = I->second;
      if (Inserted)
        ExpressionNumbering[exp] = nextValueNumber++;
    }
    ValueNumbering[V] = e;
    return e;
  }

  /// Returns the value number of the specified value. Fails if the value has
  /// not yet been numbered.
  uint32_t lookup(Value *V) const {
    auto VI = ValueNumbering.find(V);
    assert(VI != ValueNumbering.end() && "Value not numbered?");
    return VI->second;
  }

  /// Removes all value numberings and resets the value table.
  void clear() {
    ValueNumbering.clear();
    ExpressionNumbering.clear();
    HashNumbering.clear();
    Recycler.clear(Allocator);
    nextValueNumber = 1;
  }

  /// \c Inst uses or touches memory. Return an ID describing the memory state
  /// at \c Inst such that if getMemoryUseOrder(I1) == getMemoryUseOrder(I2),
  /// the exact same memory operations happen after I1 and I2.
  ///
  /// This is a very hard problem in general, so we use domain-specific
  /// knowledge: we only ever check for equivalence between blocks sharing a
  /// single common immediate successor, and when determining if I1 == I2 we
  /// will have already determined that next(I1) == next(I2). This inductive
  /// property allows us to simply return the value number of the next
  /// instruction that defines memory.
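  ///
  /// For example (illustrative):
  ///   %v = load i32, ptr %p        ; Inst
  ///   %w = add i32 %v, 1
  ///   store i32 %w, ptr %q         ; next memory-defining instruction
  /// Here getMemoryUseOrder(%v) is the value number of the store; loads and
  /// read-only calls are skipped, and 0 is returned if no memory-defining
  /// instruction follows before the terminator.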
  uint32_t getMemoryUseOrder(Instruction *Inst) {
    auto *BB = Inst->getParent();
    for (auto I = std::next(Inst->getIterator()), E = BB->end();
         I != E && !I->isTerminator(); ++I) {
      if (!isMemoryInst(&*I))
        continue;
      if (isa<LoadInst>(&*I))
        continue;
      CallInst *CI = dyn_cast<CallInst>(&*I);
      if (CI && CI->onlyReadsMemory())
        continue;
      InvokeInst *II = dyn_cast<InvokeInst>(&*I);
      if (II && II->onlyReadsMemory())
        continue;
      return lookupOrAdd(&*I);
    }
    return 0;
  }
};

//===----------------------------------------------------------------------===//

class GVNSink {
public:
  GVNSink() {}

  bool run(Function &F) {
    LLVM_DEBUG(dbgs() << "GVNSink: running on function @" << F.getName()
                      << "\n");

    unsigned NumSunk = 0;
    ReversePostOrderTraversal<Function *> RPOT(&F);
    VN.setReachableBBs(BasicBlocksSet(RPOT.begin(), RPOT.end()));
    // Assign each basic block a number in reverse post-order so blocks can be
    // ordered deterministically. Any arbitrary ordering would work as long as
    // it is deterministic. The ordering of newly created basic blocks is
    // irrelevant because the RPOT used for computing sinkable candidates is
    // also obtained ahead of time, and only the order of the original blocks
    // matters to this pass.
    unsigned NodeOrdering = 0;
    RPOTOrder[*RPOT.begin()] = ++NodeOrdering;
    for (auto *BB : RPOT)
      if (!pred_empty(BB))
        RPOTOrder[BB] = ++NodeOrdering;
    for (auto *N : RPOT)
      NumSunk += sinkBB(N);

    return NumSunk > 0;
  }

private:
  ValueTable VN;
  DenseMap<const BasicBlock *, unsigned> RPOTOrder;

  bool shouldAvoidSinkingInstruction(Instruction *I) {
    // These instructions may change or break semantics if moved.
    if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) ||
        I->getType()->isTokenTy())
      return true;
    return false;
  }

  /// The main heuristic function. Analyze the set of instructions pointed to by
  /// LRI and return a candidate solution if these instructions can be sunk, or
  /// std::nullopt otherwise.
  std::optional<SinkingInstructionCandidate> analyzeInstructionForSinking(
      LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum,
      ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents);

  /// Create a ModelledPHI for each PHI in BB, adding to PHIs.
  void analyzeInitialPHIs(BasicBlock *BB, ModelledPHISet &PHIs,
                          SmallPtrSetImpl<Value *> &PHIContents) {
    for (PHINode &PN : BB->phis()) {
      auto MPHI = ModelledPHI(&PN, RPOTOrder);
      PHIs.insert(MPHI);
      for (auto *V : MPHI.getValues())
        PHIContents.insert(V);
    }
  }

  /// The main instruction sinking driver. Set up state and try to sink
  /// instructions into BBEnd from its predecessors.
  unsigned sinkBB(BasicBlock *BBEnd);

  /// Perform the actual mechanics of sinking an instruction from Blocks into
  /// BBEnd, which is their only successor.
  void sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, BasicBlock *BBEnd);

  /// Remove PHIs that all have the same incoming value.
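  /// For example, %p = phi i32 [ %v, %bb1 ], [ %v, %bb2 ] is replaced
  /// everywhere by %v and then erased.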
  void foldPointlessPHINodes(BasicBlock *BB) {
    auto I = BB->begin();
    while (PHINode *PN = dyn_cast<PHINode>(I++)) {
      if (!llvm::all_of(PN->incoming_values(), [&](const Value *V) {
            return V == PN->getIncomingValue(0);
          }))
        continue;
      if (PN->getIncomingValue(0) != PN)
        PN->replaceAllUsesWith(PN->getIncomingValue(0));
      else
        PN->replaceAllUsesWith(PoisonValue::get(PN->getType()));
      PN->eraseFromParent();
    }
  }
};

std::optional<SinkingInstructionCandidate>
GVNSink::analyzeInstructionForSinking(LockstepReverseIterator &LRI,
                                      unsigned &InstNum,
                                      unsigned &MemoryInstNum,
                                      ModelledPHISet &NeededPHIs,
                                      SmallPtrSetImpl<Value *> &PHIContents) {
  auto Insts = *LRI;
  LLVM_DEBUG(dbgs() << " -- Analyzing instruction set: [\n";
             for (auto *I : Insts) I->dump();
             dbgs() << " ]\n";);

  DenseMap<uint32_t, unsigned> VNums;
  for (auto *I : Insts) {
    uint32_t N = VN.lookupOrAdd(I);
    LLVM_DEBUG(dbgs() << " VN=" << Twine::utohexstr(N) << " for" << *I << "\n");
    if (N == ~0U)
      return std::nullopt;
    VNums[N]++;
  }
  unsigned VNumToSink = llvm::max_element(VNums, llvm::less_second())->first;

  if (VNums[VNumToSink] == 1)
    // Can't sink anything!
    return std::nullopt;

  // Now restrict the number of incoming blocks down to only those with
  // VNumToSink.
  auto &ActivePreds = LRI.getActiveBlocks();
  unsigned InitialActivePredSize = ActivePreds.size();
  SmallVector<Instruction *, 4> NewInsts;
  for (auto *I : Insts) {
    if (VN.lookup(I) != VNumToSink)
      ActivePreds.remove(I->getParent());
    else
      NewInsts.push_back(I);
  }
  for (auto *I : NewInsts)
    if (shouldAvoidSinkingInstruction(I))
      return std::nullopt;

  // If we've restricted the incoming blocks, restrict all needed PHIs also
  // to that set.
  bool RecomputePHIContents = false;
  if (ActivePreds.size() != InitialActivePredSize) {
    ModelledPHISet NewNeededPHIs;
    for (auto P : NeededPHIs) {
      P.restrictToBlocks(ActivePreds);
      NewNeededPHIs.insert(P);
    }
    NeededPHIs = NewNeededPHIs;
    LRI.restrictToBlocks(ActivePreds);
    RecomputePHIContents = true;
  }

  // The sunk instruction's results.
  ModelledPHI NewPHI(NewInsts, ActivePreds, RPOTOrder);

  // Does sinking this instruction render previous PHIs redundant?
  if (NeededPHIs.erase(NewPHI))
    RecomputePHIContents = true;

  if (RecomputePHIContents) {
    // The needed PHIs have changed, so recompute the set of all needed
    // values.
    PHIContents.clear();
    for (auto &PHI : NeededPHIs)
      PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
  }

  // Is this instruction required by a later PHI that doesn't match this PHI?
  // If so, we can't sink this instruction.
  for (auto *V : NewPHI.getValues())
    if (PHIContents.count(V))
      // V exists in this PHI, but the whole PHI is different to NewPHI
      // (else it would have been removed earlier). We cannot continue
      // because this isn't representable.
      return std::nullopt;

  // Which operands need PHIs?
  // FIXME: If any of these fail, we should partition up the candidates to
  // try and continue making progress.
  Instruction *I0 = NewInsts[0];

  auto isNotSameOperation = [&I0](Instruction *I) {
    return !I0->isSameOperationAs(I);
  };

  if (any_of(NewInsts, isNotSameOperation))
    return std::nullopt;

  for (unsigned OpNum = 0, E = I0->getNumOperands(); OpNum != E; ++OpNum) {
    ModelledPHI PHI(NewInsts, OpNum, ActivePreds);
    if (PHI.areAllIncomingValuesSame())
      continue;
    if (!canReplaceOperandWithVariable(I0, OpNum))
      // We can't create a PHI from this instruction!
      return std::nullopt;
    if (NeededPHIs.count(PHI))
      continue;
    if (!PHI.areAllIncomingValuesSameType())
      return std::nullopt;
    // Don't create indirect calls! The called value is the final operand.
    if ((isa<CallInst>(I0) || isa<InvokeInst>(I0)) && OpNum == E - 1 &&
        PHI.areAnyIncomingValuesConstant())
      return std::nullopt;

    NeededPHIs.reserve(NeededPHIs.size());
    NeededPHIs.insert(PHI);
    PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
  }

  if (isMemoryInst(NewInsts[0]))
    ++MemoryInstNum;

  SinkingInstructionCandidate Cand;
  Cand.NumInstructions = ++InstNum;
  Cand.NumMemoryInsts = MemoryInstNum;
  Cand.NumBlocks = ActivePreds.size();
  Cand.NumPHIs = NeededPHIs.size();
  append_range(Cand.Blocks, ActivePreds);

  return Cand;
}

unsigned GVNSink::sinkBB(BasicBlock *BBEnd) {
  LLVM_DEBUG(dbgs() << "GVNSink: running on basic block ";
             BBEnd->printAsOperand(dbgs()); dbgs() << "\n");
  SmallVector<BasicBlock *, 4> Preds;
  for (auto *B : predecessors(BBEnd)) {
    // Bail out on basic blocks without a predecessor (PR42346).
    if (!RPOTOrder.count(B))
      return 0;
    auto *T = B->getTerminator();
    if (isa<BranchInst>(T) || isa<SwitchInst>(T))
      Preds.push_back(B);
    else
      return 0;
  }
  if (Preds.size() < 2)
    return 0;
  auto ComesBefore = [this](const BasicBlock *BB1, const BasicBlock *BB2) {
    return RPOTOrder.lookup(BB1) < RPOTOrder.lookup(BB2);
  };
  // Sort in a deterministic order.
  llvm::sort(Preds, ComesBefore);

  unsigned NumOrigPreds = Preds.size();
  // We can only sink instructions through unconditional branches.
  llvm::erase_if(Preds, [](BasicBlock *BB) {
    return BB->getTerminator()->getNumSuccessors() != 1;
  });

  LockstepReverseIterator LRI(Preds);
  SmallVector<SinkingInstructionCandidate, 4> Candidates;
  unsigned InstNum = 0, MemoryInstNum = 0;
  ModelledPHISet NeededPHIs;
  SmallPtrSet<Value *, 4> PHIContents;
  analyzeInitialPHIs(BBEnd, NeededPHIs, PHIContents);
  unsigned NumOrigPHIs = NeededPHIs.size();

  while (LRI.isValid()) {
    auto Cand = analyzeInstructionForSinking(LRI, InstNum, MemoryInstNum,
                                             NeededPHIs, PHIContents);
    if (!Cand)
      break;
    Cand->calculateCost(NumOrigPHIs, Preds.size());
    Candidates.emplace_back(*Cand);
    --LRI;
  }

  llvm::stable_sort(Candidates, std::greater<SinkingInstructionCandidate>());
  LLVM_DEBUG(dbgs() << " -- Sinking candidates:\n";
             for (auto &C : Candidates) dbgs() << "  " << C << "\n";);

  // Pick the top candidate, as long as its cost is positive!
  if (Candidates.empty() || Candidates.front().Cost <= 0)
    return 0;
  auto C = Candidates.front();

  LLVM_DEBUG(dbgs() << " -- Sinking: " << C << "\n");
  BasicBlock *InsertBB = BBEnd;
  if (C.Blocks.size() < NumOrigPreds) {
    LLVM_DEBUG(dbgs() << " -- Splitting edge to ";
               BBEnd->printAsOperand(dbgs()); dbgs() << "\n");
    InsertBB = SplitBlockPredecessors(BBEnd, C.Blocks, ".gvnsink.split");
    if (!InsertBB) {
      LLVM_DEBUG(dbgs() << " -- FAILED to split edge!\n");
      // Edge couldn't be split.
      return 0;
    }
  }

  for (unsigned I = 0; I < C.NumInstructions; ++I)
    sinkLastInstruction(C.Blocks, InsertBB);

  return C.NumInstructions;
}

void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks,
                                  BasicBlock *BBEnd) {
  SmallVector<Instruction *, 4> Insts;
  for (BasicBlock *BB : Blocks)
    Insts.push_back(BB->getTerminator()->getPrevNonDebugInstruction());
  Instruction *I0 = Insts.front();

  SmallVector<Value *, 4> NewOperands;
  for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) {
    bool NeedPHI = llvm::any_of(Insts, [&I0, O](const Instruction *I) {
      return I->getOperand(O) != I0->getOperand(O);
    });
    if (!NeedPHI) {
      NewOperands.push_back(I0->getOperand(O));
      continue;
    }

    // Create a new PHI in the successor block and populate it.
    auto *Op = I0->getOperand(O);
    assert(!Op->getType()->isTokenTy() && "Can't PHI tokens!");
    auto *PN =
        PHINode::Create(Op->getType(), Insts.size(), Op->getName() + ".sink");
    PN->insertBefore(BBEnd->begin());
    for (auto *I : Insts)
      PN->addIncoming(I->getOperand(O), I->getParent());
    NewOperands.push_back(PN);
  }

  // Arbitrarily use I0 as the new "common" instruction; remap its operands
  // and move it to the start of the successor block.
  for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O)
    I0->getOperandUse(O).set(NewOperands[O]);
  I0->moveBefore(BBEnd->getFirstInsertionPt());

  // Update metadata and IR flags.
  for (auto *I : Insts)
    if (I != I0) {
      combineMetadataForCSE(I0, I, true);
      I0->andIRFlags(I);
    }

  for (auto *I : Insts)
    if (I != I0) {
      I->replaceAllUsesWith(I0);
      I0->applyMergedLocation(I0->getDebugLoc(), I->getDebugLoc());
    }
  foldPointlessPHINodes(BBEnd);

  // Finally nuke all instructions apart from the common instruction.
  for (auto *I : Insts)
    if (I != I0)
      I->eraseFromParent();

  NumRemoved += Insts.size() - 1;
}

} // end anonymous namespace

PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) {
  GVNSink G;
  if (!G.run(F))
    return PreservedAnalyses::all();

  return PreservedAnalyses::none();
}