//===- GVNSink.cpp - sink expressions into successors ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file GVNSink.cpp
/// This pass attempts to sink instructions into successors, reducing static
/// instruction count and enabling if-conversion.
///
/// We use a variant of global value numbering to decide what can be sunk.
/// Consider:
///
/// [ %a1 = add i32 %b, 1  ]   [ %c1 = add i32 %d, 1  ]
/// [ %a2 = xor i32 %a1, 1 ]   [ %c2 = xor i32 %c1, 1 ]
///                  \           /
///            [ %e = phi i32 %a2, %c2 ]
///            [ add i32 %e, 4         ]
///
/// GVN would number %a1 and %c1 differently because they compute different
/// results - the VN of an instruction is a function of its opcode and the
/// transitive closure of its operands. This is the key property for hoisting
/// and CSE.
///
/// When sinking, however, we want a numbering that is a function of the *uses*
/// of an instruction, which allows us to answer the question "if I replace %a1
/// with %c1, will it contribute in an equivalent way to all successive
/// instructions?". The PostValueTable class in GVN provides this mapping.
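///
/// With use-based numbering, %a1 and %c1 (and likewise %a2 and %c2) receive
/// equal value numbers, so both pairs can be sunk. As an illustrative sketch
/// (assuming both blocks end in unconditional branches to the successor), the
/// example above becomes:
///
/// [                       ]   [                       ]
///                  \           /
///            [ %e1 = phi i32 %b, %d ]
///            [ %a = add i32 %e1, 1  ]
///            [ %x = xor i32 %a, 1   ]
///            [ add i32 %x, 4        ]
///
/// The original phi of %a2 and %c2 becomes pointless once the xor is sunk,
/// and is folded away.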
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ArrayRecycler.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Scalar/GVNExpression.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "gvn-sink"

STATISTIC(NumRemoved, "Number of instructions removed");

namespace llvm {
namespace GVNExpression {

LLVM_DUMP_METHOD void Expression::dump() const {
  print(dbgs());
  dbgs() << "\n";
}

} // end namespace GVNExpression
} // end namespace llvm

namespace {

static bool isMemoryInst(const Instruction *I) {
  return isa<LoadInst>(I) || isa<StoreInst>(I) ||
         (isa<InvokeInst>(I) && !cast<InvokeInst>(I)->doesNotAccessMemory()) ||
         (isa<CallInst>(I) && !cast<CallInst>(I)->doesNotAccessMemory());
}

/// Iterates through instructions in a set of blocks in reverse order from the
/// first non-terminator. For example (assume all blocks have size n):
///   LockstepReverseIterator I([B1, B2, B3]);
///   *I-- = [B1[n], B2[n], B3[n]];
///   *I-- = [B1[n-1], B2[n-1], B3[n-1]];
///   *I-- = [B1[n-2], B2[n-2], B3[n-2]];
///   ...
///
/// It continues until all blocks have been exhausted. Use \c getActiveBlocks()
/// to determine which blocks are still going and the order they appear in the
/// list returned by operator*.
class LockstepReverseIterator {
  ArrayRef<BasicBlock *> Blocks;
  SmallSetVector<BasicBlock *, 4> ActiveBlocks;
  SmallVector<Instruction *, 4> Insts;
  bool Fail;

public:
  LockstepReverseIterator(ArrayRef<BasicBlock *> Blocks) : Blocks(Blocks) {
    reset();
  }

  void reset() {
    Fail = false;
    ActiveBlocks.clear();
    for (BasicBlock *BB : Blocks)
      ActiveBlocks.insert(BB);
    Insts.clear();
    for (BasicBlock *BB : Blocks) {
      if (BB->size() <= 1) {
        // Block wasn't big enough - only contained a terminator.
        ActiveBlocks.remove(BB);
        continue;
      }
      Insts.push_back(BB->getTerminator()->getPrevNode());
    }
    if (Insts.empty())
      Fail = true;
  }

  bool isValid() const { return !Fail; }
  ArrayRef<Instruction *> operator*() const { return Insts; }

  // Note: This needs to return a SmallSetVector as the elements of
  // ActiveBlocks will be later copied to Blocks using std::copy. The
  // resultant order of elements in Blocks needs to be deterministic.
  // Using SmallPtrSet instead causes non-deterministic order while
  // copying. And we cannot simply sort Blocks as they need to match the
  // corresponding Values.
  SmallSetVector<BasicBlock *, 4> &getActiveBlocks() { return ActiveBlocks; }

  void restrictToBlocks(SmallSetVector<BasicBlock *, 4> &Blocks) {
    for (auto II = Insts.begin(); II != Insts.end();) {
      if (!Blocks.contains((*II)->getParent())) {
        ActiveBlocks.remove((*II)->getParent());
        II = Insts.erase(II);
      } else {
        ++II;
      }
    }
  }

  void operator--() {
    if (Fail)
      return;
    SmallVector<Instruction *, 4> NewInsts;
    for (auto *Inst : Insts) {
      if (Inst == &Inst->getParent()->front())
        ActiveBlocks.remove(Inst->getParent());
      else
        NewInsts.push_back(Inst->getPrevNode());
    }
    if (NewInsts.empty()) {
      Fail = true;
      return;
    }
    Insts = NewInsts;
  }
};

//===----------------------------------------------------------------------===//

/// Candidate solution for sinking. There may be different ways to
/// sink instructions, differing in the number of instructions sunk,
/// the number of predecessors sunk from and the number of PHIs
/// required.
struct SinkingInstructionCandidate {
  unsigned NumBlocks;
  unsigned NumInstructions;
  unsigned NumPHIs;
  unsigned NumMemoryInsts;
  int Cost = -1;
  SmallVector<BasicBlock *, 4> Blocks;

  void calculateCost(unsigned NumOrigPHIs, unsigned NumOrigBlocks) {
    unsigned NumExtraPHIs = NumPHIs - NumOrigPHIs;
    unsigned SplitEdgeCost = (NumOrigBlocks > NumBlocks) ? 2 : 0;
    Cost = (NumInstructions * (NumBlocks - 1)) -
           (NumExtraPHIs *
            NumExtraPHIs) // PHIs are expensive, so make sure they're worth it.
           - SplitEdgeCost;
  }
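
  // For example: sinking 3 instructions from 2 blocks with no extra PHIs and
  // no split edge gives Cost = 3 * (2 - 1) - 0 * 0 - 0 = 3. One extra PHI
  // would subtract 1 * 1, and sinking from a strict subset of the original
  // predecessors would subtract a further 2 for the edge split.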

  bool operator>(const SinkingInstructionCandidate &Other) const {
    return Cost > Other.Cost;
  }
};

#ifndef NDEBUG
raw_ostream &operator<<(raw_ostream &OS,
                        const SinkingInstructionCandidate &C) {
  OS << "<Candidate Cost=" << C.Cost << " #Blocks=" << C.NumBlocks
     << " #Insts=" << C.NumInstructions << " #PHIs=" << C.NumPHIs << ">";
  return OS;
}
#endif

//===----------------------------------------------------------------------===//

/// Describes a PHI node that may or may not exist. These track the PHIs
/// that must be created if we sink a sequence of instructions. It provides
/// a hash function for efficient equality comparisons.
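///
/// For example, if %a (in block B1) and %c (in block B2) were to be sunk into
/// their common successor, the phi this would require is modelled as
/// Values = [%a, %c], Blocks = [B1, B2] (illustrative names).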
class ModelledPHI {
  SmallVector<Value *, 4> Values;
  SmallVector<BasicBlock *, 4> Blocks;

public:
  ModelledPHI() = default;

  ModelledPHI(const PHINode *PN) {
    // BasicBlock comes first so we sort by basic block pointer order, then
    // by value pointer order.
    SmallVector<std::pair<BasicBlock *, Value *>, 4> Ops;
    for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I)
      Ops.push_back({PN->getIncomingBlock(I), PN->getIncomingValue(I)});
    llvm::sort(Ops);
    for (auto &P : Ops) {
      Blocks.push_back(P.first);
      Values.push_back(P.second);
    }
  }

  /// Create a dummy ModelledPHI that will compare unequal to any other
  /// ModelledPHI without the same ID.
  /// \note This is specifically for DenseMapInfo - do not use this!
  static ModelledPHI createDummy(size_t ID) {
    ModelledPHI M;
    M.Values.push_back(reinterpret_cast<Value *>(ID));
    return M;
  }

  /// Create a PHI from an array of incoming values and incoming blocks.
  template <typename VArray, typename BArray>
  ModelledPHI(const VArray &V, const BArray &B) {
    llvm::copy(V, std::back_inserter(Values));
    llvm::copy(B, std::back_inserter(Blocks));
  }

  /// Create a PHI from [I[OpNum] for I in Insts].
  template <typename BArray>
  ModelledPHI(ArrayRef<Instruction *> Insts, unsigned OpNum, const BArray &B) {
    llvm::copy(B, std::back_inserter(Blocks));
    for (auto *I : Insts)
      Values.push_back(I->getOperand(OpNum));
  }

  /// Restrict the PHI's contents down to only \c NewBlocks.
  /// \c NewBlocks must be a subset of \c this->Blocks.
  void restrictToBlocks(const SmallSetVector<BasicBlock *, 4> &NewBlocks) {
    auto BI = Blocks.begin();
    auto VI = Values.begin();
    while (BI != Blocks.end()) {
      assert(VI != Values.end());
      if (!NewBlocks.contains(*BI)) {
        BI = Blocks.erase(BI);
        VI = Values.erase(VI);
      } else {
        ++BI;
        ++VI;
      }
    }
    assert(Blocks.size() == NewBlocks.size());
  }

  ArrayRef<Value *> getValues() const { return Values; }

  bool areAllIncomingValuesSame() const {
    return llvm::all_equal(Values);
  }

  bool areAllIncomingValuesSameType() const {
    return llvm::all_of(
        Values, [&](Value *V) { return V->getType() == Values[0]->getType(); });
  }

  bool areAnyIncomingValuesConstant() const {
    return llvm::any_of(Values, [&](Value *V) { return isa<Constant>(V); });
  }

  // Hash functor
  unsigned hash() const {
    return (unsigned)hash_combine_range(Values.begin(), Values.end());
  }

  bool operator==(const ModelledPHI &Other) const {
    return Values == Other.Values && Blocks == Other.Blocks;
  }
};

template <typename ModelledPHI> struct DenseMapInfo {
  static inline ModelledPHI &getEmptyKey() {
    static ModelledPHI Dummy = ModelledPHI::createDummy(0);
    return Dummy;
  }

  static inline ModelledPHI &getTombstoneKey() {
    static ModelledPHI Dummy = ModelledPHI::createDummy(1);
    return Dummy;
  }

  static unsigned getHashValue(const ModelledPHI &V) { return V.hash(); }

  static bool isEqual(const ModelledPHI &LHS, const ModelledPHI &RHS) {
    return LHS == RHS;
  }
};

using ModelledPHISet = DenseSet<ModelledPHI, DenseMapInfo<ModelledPHI>>;

//===----------------------------------------------------------------------===//
//                             ValueTable
//===----------------------------------------------------------------------===//
// This is a value number table where the value number is a function of the
// *uses* of a value, rather than its operands. Thus, if VN(A) == VN(B) we know
// that the program would be equivalent if we replaced A with PHI(A, B).
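//
// For example (illustrative): in the diagram in the file header, VN(%a2) ==
// VN(%c2), because replacing one with a phi of both leaves every subsequent
// instruction equivalent; classic GVN would number them differently because
// their operands differ.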
//===----------------------------------------------------------------------===//

/// A GVN expression describing how an instruction is used. The operands
/// field of BasicExpression is used to store uses, not operands.
///
/// This class also contains fields for discriminators used when determining
/// equivalence of instructions with side effects.
class InstructionUseExpr : public GVNExpression::BasicExpression {
  unsigned MemoryUseOrder = -1;
  bool Volatile = false;
  ArrayRef<int> ShuffleMask;

public:
  InstructionUseExpr(Instruction *I, ArrayRecycler<Value *> &R,
                     BumpPtrAllocator &A)
      : GVNExpression::BasicExpression(I->getNumUses()) {
    allocateOperands(R, A);
    setOpcode(I->getOpcode());
    setType(I->getType());

    if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(I))
      ShuffleMask = SVI->getShuffleMask().copy(A);

    for (auto &U : I->uses())
      op_push_back(U.getUser());
    llvm::sort(op_begin(), op_end());
  }

  void setMemoryUseOrder(unsigned MUO) { MemoryUseOrder = MUO; }
  void setVolatile(bool V) { Volatile = V; }

  hash_code getHashValue() const override {
    return hash_combine(GVNExpression::BasicExpression::getHashValue(),
                        MemoryUseOrder, Volatile, ShuffleMask);
  }

  template <typename Function> hash_code getHashValue(Function MapFn) {
    hash_code H = hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile,
                               ShuffleMask);
    for (auto *V : operands())
      H = hash_combine(H, MapFn(V));
    return H;
  }
};

using BasicBlocksSet = SmallPtrSet<const BasicBlock *, 32>;

class ValueTable {
  DenseMap<Value *, uint32_t> ValueNumbering;
  DenseMap<GVNExpression::Expression *, uint32_t> ExpressionNumbering;
  DenseMap<size_t, uint32_t> HashNumbering;
  BumpPtrAllocator Allocator;
  ArrayRecycler<Value *> Recycler;
  uint32_t nextValueNumber = 1;
  BasicBlocksSet ReachableBBs;

  /// Create an expression for I based on its opcode and its uses. If I
  /// touches or reads memory, the expression is also based upon its memory
  /// order - see \c getMemoryUseOrder().
  InstructionUseExpr *createExpr(Instruction *I) {
    InstructionUseExpr *E =
        new (Allocator) InstructionUseExpr(I, Recycler, Allocator);
    if (isMemoryInst(I))
      E->setMemoryUseOrder(getMemoryUseOrder(I));

    if (CmpInst *C = dyn_cast<CmpInst>(I)) {
      CmpInst::Predicate Predicate = C->getPredicate();
      E->setOpcode((C->getOpcode() << 8) | Predicate);
    }
    return E;
  }

  /// Helper to compute the value number for a memory instruction
  /// (LoadInst/StoreInst), including checking the memory ordering and
  /// volatility.
  template <class Inst> InstructionUseExpr *createMemoryExpr(Inst *I) {
    if (isStrongerThanUnordered(I->getOrdering()) || I->isAtomic())
      return nullptr;
    InstructionUseExpr *E = createExpr(I);
    E->setVolatile(I->isVolatile());
    return E;
  }

public:
  ValueTable() = default;

  /// Set basic blocks reachable from entry block.
  void setReachableBBs(const BasicBlocksSet &ReachableBBs) {
    this->ReachableBBs = ReachableBBs;
  }

  /// Returns the value number for the specified value, assigning
  /// it a new number if it did not have one before.
  uint32_t lookupOrAdd(Value *V) {
    auto VI = ValueNumbering.find(V);
    if (VI != ValueNumbering.end())
      return VI->second;

    if (!isa<Instruction>(V)) {
      ValueNumbering[V] = nextValueNumber;
      return nextValueNumber++;
    }

    Instruction *I = cast<Instruction>(V);
    if (!ReachableBBs.contains(I->getParent()))
      return ~0U;

    InstructionUseExpr *exp = nullptr;
    switch (I->getOpcode()) {
    case Instruction::Load:
      exp = createMemoryExpr(cast<LoadInst>(I));
      break;
    case Instruction::Store:
      exp = createMemoryExpr(cast<StoreInst>(I));
      break;
    case Instruction::Call:
    case Instruction::Invoke:
    case Instruction::FNeg:
    case Instruction::Add:
    case Instruction::FAdd:
    case Instruction::Sub:
    case Instruction::FSub:
    case Instruction::Mul:
    case Instruction::FMul:
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::FDiv:
    case Instruction::URem:
    case Instruction::SRem:
    case Instruction::FRem:
    case Instruction::Shl:
    case Instruction::LShr:
    case Instruction::AShr:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt:
    case Instruction::FPToUI:
    case Instruction::FPToSI:
    case Instruction::UIToFP:
    case Instruction::SIToFP:
    case Instruction::FPTrunc:
    case Instruction::FPExt:
    case Instruction::PtrToInt:
    case Instruction::IntToPtr:
    case Instruction::BitCast:
    case Instruction::AddrSpaceCast:
    case Instruction::Select:
    case Instruction::ExtractElement:
    case Instruction::InsertElement:
    case Instruction::ShuffleVector:
    case Instruction::InsertValue:
    case Instruction::GetElementPtr:
      exp = createExpr(I);
      break;
    default:
      break;
    }

    if (!exp) {
      ValueNumbering[V] = nextValueNumber;
      return nextValueNumber++;
    }

    uint32_t e = ExpressionNumbering[exp];
    if (!e) {
      hash_code H = exp->getHashValue([=](Value *V) { return lookupOrAdd(V); });
      auto I = HashNumbering.find(H);
      if (I != HashNumbering.end()) {
        e = I->second;
      } else {
        e = nextValueNumber++;
        HashNumbering[H] = e;
        ExpressionNumbering[exp] = e;
      }
    }
    ValueNumbering[V] = e;
    return e;
  }

  /// Returns the value number of the specified value. Fails if the value has
  /// not yet been numbered.
  uint32_t lookup(Value *V) const {
    auto VI = ValueNumbering.find(V);
    assert(VI != ValueNumbering.end() && "Value not numbered?");
    return VI->second;
  }

  /// Removes all value numberings and resets the value table.
  void clear() {
    ValueNumbering.clear();
    ExpressionNumbering.clear();
    HashNumbering.clear();
    Recycler.clear(Allocator);
    nextValueNumber = 1;
  }

  /// \c Inst uses or touches memory. Return an ID describing the memory state
  /// at \c Inst such that if getMemoryUseOrder(I1) == getMemoryUseOrder(I2),
  /// the exact same memory operations happen after I1 and I2.
  ///
  /// This is a very hard problem in general, so we use domain-specific
  /// knowledge that we only ever check for equivalence between blocks sharing
  /// a single common immediate successor, and when determining if I1 == I2 we
  /// will have already determined that next(I1) == next(I2). This inductive
  /// property allows us to simply return the value number of the next
  /// instruction that defines memory.
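  ///
  /// For illustration (hypothetical block):
  ///   %x = load i32, ptr %p
  ///   store i32 %a, ptr %q
  ///   br label %succ
  /// getMemoryUseOrder for the load returns the value number of the store,
  /// the next memory-defining instruction. For the store itself, nothing
  /// memory-defining follows before the terminator, so it returns 0.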
  uint32_t getMemoryUseOrder(Instruction *Inst) {
    auto *BB = Inst->getParent();
    for (auto I = std::next(Inst->getIterator()), E = BB->end();
         I != E && !I->isTerminator(); ++I) {
      if (!isMemoryInst(&*I))
        continue;
      if (isa<LoadInst>(&*I))
        continue;
      CallInst *CI = dyn_cast<CallInst>(&*I);
      if (CI && CI->onlyReadsMemory())
        continue;
      InvokeInst *II = dyn_cast<InvokeInst>(&*I);
      if (II && II->onlyReadsMemory())
        continue;
      return lookupOrAdd(&*I);
    }
    return 0;
  }
};

//===----------------------------------------------------------------------===//

class GVNSink {
public:
  GVNSink() = default;

  bool run(Function &F) {
    LLVM_DEBUG(dbgs() << "GVNSink: running on function @" << F.getName()
                      << "\n");

    unsigned NumSunk = 0;
    ReversePostOrderTraversal<Function *> RPOT(&F);
    VN.setReachableBBs(BasicBlocksSet(RPOT.begin(), RPOT.end()));
    for (auto *N : RPOT)
      NumSunk += sinkBB(N);

    return NumSunk > 0;
  }

private:
  ValueTable VN;

  bool shouldAvoidSinkingInstruction(Instruction *I) {
    // These instructions may change or break semantics if moved.
    if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) ||
        I->getType()->isTokenTy())
      return true;
    return false;
  }

  /// The main heuristic function. Analyze the set of instructions pointed to
  /// by LRI and return a candidate solution if these instructions can be
  /// sunk, or std::nullopt otherwise.
  std::optional<SinkingInstructionCandidate> analyzeInstructionForSinking(
      LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum,
      ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents);

  /// Create a ModelledPHI for each PHI in BB, adding to PHIs.
  void analyzeInitialPHIs(BasicBlock *BB, ModelledPHISet &PHIs,
                          SmallPtrSetImpl<Value *> &PHIContents) {
    for (PHINode &PN : BB->phis()) {
      auto MPHI = ModelledPHI(&PN);
      PHIs.insert(MPHI);
      for (auto *V : MPHI.getValues())
        PHIContents.insert(V);
    }
  }

  /// The main instruction sinking driver. Set up state and try and sink
  /// instructions into BBEnd from its predecessors.
  unsigned sinkBB(BasicBlock *BBEnd);

  /// Perform the actual mechanics of sinking an instruction from Blocks into
  /// BBEnd, which is their only successor.
  void sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, BasicBlock *BBEnd);

  /// Remove PHIs that all have the same incoming value.
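  ///
  /// For example (illustrative IR), %p = phi i32 [ %v, %bb1 ], [ %v, %bb2 ]
  /// is replaced by %v; a phi whose only incoming value is itself is replaced
  /// by poison.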
  void foldPointlessPHINodes(BasicBlock *BB) {
    auto I = BB->begin();
    while (PHINode *PN = dyn_cast<PHINode>(I++)) {
      if (!llvm::all_of(PN->incoming_values(), [&](const Value *V) {
            return V == PN->getIncomingValue(0);
          }))
        continue;
      if (PN->getIncomingValue(0) != PN)
        PN->replaceAllUsesWith(PN->getIncomingValue(0));
      else
        PN->replaceAllUsesWith(PoisonValue::get(PN->getType()));
      PN->eraseFromParent();
    }
  }
};

std::optional<SinkingInstructionCandidate>
GVNSink::analyzeInstructionForSinking(LockstepReverseIterator &LRI,
                                      unsigned &InstNum,
                                      unsigned &MemoryInstNum,
                                      ModelledPHISet &NeededPHIs,
                                      SmallPtrSetImpl<Value *> &PHIContents) {
  auto Insts = *LRI;
  LLVM_DEBUG(dbgs() << " -- Analyzing instruction set: [\n";
             for (auto *I : Insts) I->dump();
             dbgs() << " ]\n";);

  DenseMap<uint32_t, unsigned> VNums;
  for (auto *I : Insts) {
    uint32_t N = VN.lookupOrAdd(I);
    LLVM_DEBUG(dbgs() << " VN=" << Twine::utohexstr(N) << " for" << *I << "\n");
    if (N == ~0U)
      return std::nullopt;
    VNums[N]++;
  }
  unsigned VNumToSink = llvm::max_element(VNums, llvm::less_second())->first;

  if (VNums[VNumToSink] == 1)
    // Can't sink anything!
    return std::nullopt;

  // Now restrict the number of incoming blocks down to only those with
  // VNumToSink.
  auto &ActivePreds = LRI.getActiveBlocks();
  unsigned InitialActivePredSize = ActivePreds.size();
  SmallVector<Instruction *, 4> NewInsts;
  for (auto *I : Insts) {
    if (VN.lookup(I) != VNumToSink)
      ActivePreds.remove(I->getParent());
    else
      NewInsts.push_back(I);
  }
  for (auto *I : NewInsts)
    if (shouldAvoidSinkingInstruction(I))
      return std::nullopt;

  // If we've restricted the incoming blocks, restrict all needed PHIs also
  // to that set.
  bool RecomputePHIContents = false;
  if (ActivePreds.size() != InitialActivePredSize) {
    ModelledPHISet NewNeededPHIs;
    for (auto P : NeededPHIs) {
      P.restrictToBlocks(ActivePreds);
      NewNeededPHIs.insert(P);
    }
    NeededPHIs = NewNeededPHIs;
    LRI.restrictToBlocks(ActivePreds);
    RecomputePHIContents = true;
  }

  // The sunk instruction's results.
  ModelledPHI NewPHI(NewInsts, ActivePreds);

  // Does sinking this instruction render previous PHIs redundant?
  if (NeededPHIs.erase(NewPHI))
    RecomputePHIContents = true;

  if (RecomputePHIContents) {
    // The needed PHIs have changed, so recompute the set of all needed
    // values.
    PHIContents.clear();
    for (auto &PHI : NeededPHIs)
      PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
  }

  // Is this instruction required by a later PHI that doesn't match this PHI?
  // If so, we can't sink this instruction.
  for (auto *V : NewPHI.getValues())
    if (PHIContents.count(V))
      // V exists in this PHI, but the whole PHI is different from NewPHI
      // (else it would have been removed earlier). We cannot continue
      // because this isn't representable.
      return std::nullopt;

  // Which operands need PHIs?
  // FIXME: If any of these fail, we should partition up the candidates to
  // try and continue making progress.
  Instruction *I0 = NewInsts[0];

  // If all instructions that are going to participate don't have the same
  // number of operands, we can't do any useful PHI analysis for all operands.
  auto hasDifferentNumOperands = [&I0](Instruction *I) {
    return I->getNumOperands() != I0->getNumOperands();
  };
  if (any_of(NewInsts, hasDifferentNumOperands))
    return std::nullopt;

  for (unsigned OpNum = 0, E = I0->getNumOperands(); OpNum != E; ++OpNum) {
    ModelledPHI PHI(NewInsts, OpNum, ActivePreds);
    if (PHI.areAllIncomingValuesSame())
      continue;
    if (!canReplaceOperandWithVariable(I0, OpNum))
      // We can't create a PHI from this instruction!
      return std::nullopt;
    if (NeededPHIs.count(PHI))
      continue;
    if (!PHI.areAllIncomingValuesSameType())
      return std::nullopt;
    // Don't create indirect calls! The called value is the final operand.
    if ((isa<CallInst>(I0) || isa<InvokeInst>(I0)) && OpNum == E - 1 &&
        PHI.areAnyIncomingValuesConstant())
      return std::nullopt;

    NeededPHIs.reserve(NeededPHIs.size());
    NeededPHIs.insert(PHI);
    PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
  }

  if (isMemoryInst(NewInsts[0]))
    ++MemoryInstNum;

  SinkingInstructionCandidate Cand;
  Cand.NumInstructions = ++InstNum;
  Cand.NumMemoryInsts = MemoryInstNum;
  Cand.NumBlocks = ActivePreds.size();
  Cand.NumPHIs = NeededPHIs.size();
  append_range(Cand.Blocks, ActivePreds);

  return Cand;
}

unsigned GVNSink::sinkBB(BasicBlock *BBEnd) {
  LLVM_DEBUG(dbgs() << "GVNSink: running on basic block ";
             BBEnd->printAsOperand(dbgs()); dbgs() << "\n");
  SmallVector<BasicBlock *, 4> Preds;
  for (auto *B : predecessors(BBEnd)) {
    auto *T = B->getTerminator();
    if (isa<BranchInst>(T) || isa<SwitchInst>(T))
      Preds.push_back(B);
    else
      return 0;
  }
  if (Preds.size() < 2)
    return 0;
  llvm::sort(Preds);

  unsigned NumOrigPreds = Preds.size();
  // We can only sink instructions through unconditional branches.
  llvm::erase_if(Preds, [](BasicBlock *BB) {
    return BB->getTerminator()->getNumSuccessors() != 1;
  });

  LockstepReverseIterator LRI(Preds);
  SmallVector<SinkingInstructionCandidate, 4> Candidates;
  unsigned InstNum = 0, MemoryInstNum = 0;
  ModelledPHISet NeededPHIs;
  SmallPtrSet<Value *, 4> PHIContents;
  analyzeInitialPHIs(BBEnd, NeededPHIs, PHIContents);
  unsigned NumOrigPHIs = NeededPHIs.size();

  while (LRI.isValid()) {
    auto Cand = analyzeInstructionForSinking(LRI, InstNum, MemoryInstNum,
                                             NeededPHIs, PHIContents);
    if (!Cand)
      break;
    Cand->calculateCost(NumOrigPHIs, Preds.size());
    Candidates.emplace_back(*Cand);
    --LRI;
  }

  llvm::stable_sort(Candidates, std::greater<SinkingInstructionCandidate>());
  LLVM_DEBUG(dbgs() << " -- Sinking candidates:\n";
             for (auto &C : Candidates) dbgs() << "  " << C << "\n";);

  // Pick the top candidate, as long as its cost is positive!
  if (Candidates.empty() || Candidates.front().Cost <= 0)
    return 0;
  auto C = Candidates.front();

  LLVM_DEBUG(dbgs() << " -- Sinking: " << C << "\n");
  BasicBlock *InsertBB = BBEnd;
  if (C.Blocks.size() < NumOrigPreds) {
    LLVM_DEBUG(dbgs() << " -- Splitting edge to ";
               BBEnd->printAsOperand(dbgs()); dbgs() << "\n");
    InsertBB = SplitBlockPredecessors(BBEnd, C.Blocks, ".gvnsink.split");
    if (!InsertBB) {
      LLVM_DEBUG(dbgs() << " -- FAILED to split edge!\n");
      // Edge couldn't be split.
      return 0;
    }
  }

  for (unsigned I = 0; I < C.NumInstructions; ++I)
    sinkLastInstruction(C.Blocks, InsertBB);

  return C.NumInstructions;
}

void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks,
                                  BasicBlock *BBEnd) {
  SmallVector<Instruction *, 4> Insts;
  for (BasicBlock *BB : Blocks)
    Insts.push_back(BB->getTerminator()->getPrevNode());
  Instruction *I0 = Insts.front();

  SmallVector<Value *, 4> NewOperands;
  for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) {
    bool NeedPHI = llvm::any_of(Insts, [&I0, O](const Instruction *I) {
      return I->getOperand(O) != I0->getOperand(O);
    });
    if (!NeedPHI) {
      NewOperands.push_back(I0->getOperand(O));
      continue;
    }

    // Create a new PHI in the successor block and populate it.
    auto *Op = I0->getOperand(O);
    assert(!Op->getType()->isTokenTy() && "Can't PHI tokens!");
    auto *PN =
        PHINode::Create(Op->getType(), Insts.size(), Op->getName() + ".sink");
    PN->insertBefore(BBEnd->begin());
    for (auto *I : Insts)
      PN->addIncoming(I->getOperand(O), I->getParent());
    NewOperands.push_back(PN);
  }

  // Arbitrarily use I0 as the new "common" instruction; remap its operands
  // and move it to the start of the successor block.
  for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O)
    I0->getOperandUse(O).set(NewOperands[O]);
  I0->moveBefore(&*BBEnd->getFirstInsertionPt());

  // Update metadata and IR flags.
  for (auto *I : Insts)
    if (I != I0) {
      combineMetadataForCSE(I0, I, true);
      I0->andIRFlags(I);
    }

  for (auto *I : Insts)
    if (I != I0)
      I->replaceAllUsesWith(I0);
  foldPointlessPHINodes(BBEnd);

  // Finally nuke all instructions apart from the common instruction.
  for (auto *I : Insts)
    if (I != I0)
      I->eraseFromParent();

  NumRemoved += Insts.size() - 1;
}

} // end anonymous namespace

PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) {
  GVNSink G;
  if (!G.run(F))
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}