//===- GVNSink.cpp - sink expressions into successors ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file GVNSink.cpp
/// This pass attempts to sink instructions into successors, reducing static
/// instruction count and enabling if-conversion.
///
/// We use a variant of global value numbering to decide what can be sunk.
/// Consider:
///
/// [ %a1 = add i32 %b, 1  ]   [ %c1 = add i32 %d, 1  ]
/// [ %a2 = xor i32 %a1, 1 ]   [ %c2 = xor i32 %c1, 1 ]
///                  \           /
///            [ %e = phi i32 %a2, %c2 ]
///            [ add i32 %e, 4         ]
///
/// GVN would number %a1 and %c1 differently because they compute different
/// results - the VN of an instruction is a function of its opcode and the
/// transitive closure of its operands. This is the key property for hoisting
/// and CSE.
///
/// What we want when sinking, however, is a numbering that is a function of
/// the *uses* of an instruction, which allows us to answer the question "if I
/// replace %a1 with %c1, will it contribute in an equivalent way to all
/// successive instructions?". The PostValueTable class in GVN provides this
/// mapping.
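///
/// For illustration, under such a numbering %a1/%c1 and %a2/%c2 pair up, and
/// the example above can sink to (value names chosen for exposition):
///
///            [ %f  = phi i32 %b, %d ]
///            [ %a1 = add i32 %f, 1  ]
///            [ %a2 = xor i32 %a1, 1 ]
///            [ add i32 %a2, 4       ]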
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ArrayRecycler.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Scalar/GVNExpression.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <utility>
using namespace llvm;

#define DEBUG_TYPE "gvn-sink"

STATISTIC(NumRemoved, "Number of instructions removed");

namespace llvm {
namespace GVNExpression {

LLVM_DUMP_METHOD void Expression::dump() const {
  print(dbgs());
  dbgs() << "\n";
}

} // end namespace GVNExpression
} // end namespace llvm

namespace {

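/// Returns true if \c I may read from or write to memory; used to decide
/// whether an expression needs a memory-use-order discriminator.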
static bool isMemoryInst(const Instruction *I) {
  return isa<LoadInst>(I) || isa<StoreInst>(I) ||
         (isa<InvokeInst>(I) && !cast<InvokeInst>(I)->doesNotAccessMemory()) ||
         (isa<CallInst>(I) && !cast<CallInst>(I)->doesNotAccessMemory());
}

/// Iterates through instructions in a set of blocks in reverse order from the
/// first non-terminator. For example (assume all blocks have size n):
///   LockstepReverseIterator I([B1, B2, B3]);
///   *I-- = [B1[n], B2[n], B3[n]];
///   *I-- = [B1[n-1], B2[n-1], B3[n-1]];
///   *I-- = [B1[n-2], B2[n-2], B3[n-2]];
///   ...
///
/// It continues until all blocks have been exhausted. Use \c getActiveBlocks()
/// to determine which blocks are still going and the order they appear in the
/// list returned by operator*.
class LockstepReverseIterator {
  ArrayRef<BasicBlock *> Blocks;
  SmallSetVector<BasicBlock *, 4> ActiveBlocks;
  SmallVector<Instruction *, 4> Insts;
  bool Fail;

public:
  LockstepReverseIterator(ArrayRef<BasicBlock *> Blocks) : Blocks(Blocks) {
    reset();
  }

  void reset() {
    Fail = false;
    ActiveBlocks.clear();
    for (BasicBlock *BB : Blocks)
      ActiveBlocks.insert(BB);
    Insts.clear();
    for (BasicBlock *BB : Blocks) {
      if (BB->size() <= 1) {
        // Block wasn't big enough - only contained a terminator.
        ActiveBlocks.remove(BB);
        continue;
      }
      Insts.push_back(BB->getTerminator()->getPrevNode());
    }
    if (Insts.empty())
      Fail = true;
  }

  bool isValid() const { return !Fail; }
  ArrayRef<Instruction *> operator*() const { return Insts; }

  // Note: This needs to return a SmallSetVector as the elements of
  // ActiveBlocks will be later copied to Blocks using std::copy. The
  // resultant order of elements in Blocks needs to be deterministic.
  // Using SmallPtrSet instead causes non-deterministic order while
  // copying. And we cannot simply sort Blocks as they need to match the
  // corresponding Values.
  SmallSetVector<BasicBlock *, 4> &getActiveBlocks() { return ActiveBlocks; }

  void restrictToBlocks(SmallSetVector<BasicBlock *, 4> &Blocks) {
    for (auto II = Insts.begin(); II != Insts.end();) {
      if (std::find(Blocks.begin(), Blocks.end(), (*II)->getParent()) ==
          Blocks.end()) {
        ActiveBlocks.remove((*II)->getParent());
        II = Insts.erase(II);
      } else {
        ++II;
      }
    }
  }

  void operator--() {
    if (Fail)
      return;
    SmallVector<Instruction *, 4> NewInsts;
    for (auto *Inst : Insts) {
      if (Inst == &Inst->getParent()->front())
        ActiveBlocks.remove(Inst->getParent());
      else
        NewInsts.push_back(Inst->getPrevNode());
    }
    if (NewInsts.empty()) {
      Fail = true;
      return;
    }
    Insts = NewInsts;
  }
};
187 
188 //===----------------------------------------------------------------------===//
189 
190 /// Candidate solution for sinking. There may be different ways to
191 /// sink instructions, differing in the number of instructions sunk,
192 /// the number of predecessors sunk from and the number of PHIs
193 /// required.
194 struct SinkingInstructionCandidate {
195   unsigned NumBlocks;
196   unsigned NumInstructions;
197   unsigned NumPHIs;
198   unsigned NumMemoryInsts;
199   int Cost = -1;
200   SmallVector<BasicBlock *, 4> Blocks;
201 
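  /// Compute a benefit score: instructions saved by sinking, penalized
  /// quadratically by the extra PHIs required and by any edge split needed.
  /// Illustrative example (values chosen for exposition): sinking 3
  /// instructions from 2 blocks that need one extra PHI and no edge split
  /// scores 3 * (2 - 1) - 1 * 1 - 0 = 2.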
  void calculateCost(unsigned NumOrigPHIs, unsigned NumOrigBlocks) {
    unsigned NumExtraPHIs = NumPHIs - NumOrigPHIs;
    unsigned SplitEdgeCost = (NumOrigBlocks > NumBlocks) ? 2 : 0;
    Cost = (NumInstructions * (NumBlocks - 1)) -
           (NumExtraPHIs *
            NumExtraPHIs) // PHIs are expensive, so make sure they're worth it.
           - SplitEdgeCost;
  }

  bool operator>(const SinkingInstructionCandidate &Other) const {
    return Cost > Other.Cost;
  }
};

#ifndef NDEBUG
raw_ostream &operator<<(raw_ostream &OS,
                        const SinkingInstructionCandidate &C) {
  OS << "<Candidate Cost=" << C.Cost << " #Blocks=" << C.NumBlocks
     << " #Insts=" << C.NumInstructions << " #PHIs=" << C.NumPHIs << ">";
  return OS;
}
#endif

//===----------------------------------------------------------------------===//

/// Describes a PHI node that may or may not exist. These track the PHIs
/// that would need to be created if we sink a sequence of instructions. It
/// provides a hash function for efficient equality comparisons.
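///
/// For example (illustrative), the IR
///   %p = phi i32 [ %x, %bb1 ], [ %y, %bb2 ]
/// is modelled with Blocks = [%bb1, %bb2] and Values = [%x, %y], with both
/// arrays sorted by incoming block (then value) pointer order so that
/// equivalent PHIs compare equal.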
class ModelledPHI {
  SmallVector<Value *, 4> Values;
  SmallVector<BasicBlock *, 4> Blocks;

public:
  ModelledPHI() = default;

  ModelledPHI(const PHINode *PN) {
    // BasicBlock comes first so we sort by basic block pointer order,
    // then by value pointer order.
    SmallVector<std::pair<BasicBlock *, Value *>, 4> Ops;
    for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I)
      Ops.push_back({PN->getIncomingBlock(I), PN->getIncomingValue(I)});
    llvm::sort(Ops);
    for (auto &P : Ops) {
      Blocks.push_back(P.first);
      Values.push_back(P.second);
    }
  }

  /// Create a dummy ModelledPHI that will compare unequal to any other
  /// ModelledPHI without the same ID.
  /// \note This is specifically for DenseMapInfo - do not use this!
  static ModelledPHI createDummy(size_t ID) {
    ModelledPHI M;
    M.Values.push_back(reinterpret_cast<Value *>(ID));
    return M;
  }

  /// Create a PHI from an array of incoming values and incoming blocks.
  template <typename VArray, typename BArray>
  ModelledPHI(const VArray &V, const BArray &B) {
    llvm::copy(V, std::back_inserter(Values));
    llvm::copy(B, std::back_inserter(Blocks));
  }

  /// Create a PHI from [I[OpNum] for I in Insts].
  template <typename BArray>
  ModelledPHI(ArrayRef<Instruction *> Insts, unsigned OpNum, const BArray &B) {
    llvm::copy(B, std::back_inserter(Blocks));
    for (auto *I : Insts)
      Values.push_back(I->getOperand(OpNum));
  }

  /// Restrict the PHI's contents down to only \c NewBlocks.
  /// \c NewBlocks must be a subset of \c this->Blocks.
  void restrictToBlocks(const SmallSetVector<BasicBlock *, 4> &NewBlocks) {
    auto BI = Blocks.begin();
    auto VI = Values.begin();
    while (BI != Blocks.end()) {
      assert(VI != Values.end());
      if (std::find(NewBlocks.begin(), NewBlocks.end(), *BI) ==
          NewBlocks.end()) {
        BI = Blocks.erase(BI);
        VI = Values.erase(VI);
      } else {
        ++BI;
        ++VI;
      }
    }
    assert(Blocks.size() == NewBlocks.size());
  }

  ArrayRef<Value *> getValues() const { return Values; }

  bool areAllIncomingValuesSame() const {
    return llvm::all_of(Values, [&](Value *V) { return V == Values[0]; });
  }

  bool areAllIncomingValuesSameType() const {
    return llvm::all_of(
        Values, [&](Value *V) { return V->getType() == Values[0]->getType(); });
  }

  bool areAnyIncomingValuesConstant() const {
    return llvm::any_of(Values, [&](Value *V) { return isa<Constant>(V); });
  }

  // Hash functor
  unsigned hash() const {
    return (unsigned)hash_combine_range(Values.begin(), Values.end());
  }

  bool operator==(const ModelledPHI &Other) const {
    return Values == Other.Values && Blocks == Other.Blocks;
  }
};

template <typename ModelledPHI> struct DenseMapInfo {
  static inline ModelledPHI &getEmptyKey() {
    static ModelledPHI Dummy = ModelledPHI::createDummy(0);
    return Dummy;
  }

  static inline ModelledPHI &getTombstoneKey() {
    static ModelledPHI Dummy = ModelledPHI::createDummy(1);
    return Dummy;
  }

  static unsigned getHashValue(const ModelledPHI &V) { return V.hash(); }

  static bool isEqual(const ModelledPHI &LHS, const ModelledPHI &RHS) {
    return LHS == RHS;
  }
};

using ModelledPHISet = DenseSet<ModelledPHI, DenseMapInfo<ModelledPHI>>;

//===----------------------------------------------------------------------===//
//                             ValueTable
//===----------------------------------------------------------------------===//
// This is a value number table where the value number is a function of the
// *uses* of a value, rather than its operands. Thus, if VN(A) == VN(B) we know
// that the program would be equivalent if we replaced A with PHI(A, B).
//===----------------------------------------------------------------------===//
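//
// In the header example, %a1 and %c1 receive the same number because their
// use chains (%a2 and %c2, then the shared PHI) number identically, even
// though their operands (%b and %d) differ.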

/// A GVN expression describing how an instruction is used. The operands
/// field of BasicExpression is used to store uses, not operands.
///
/// This class also contains fields for discriminators used when determining
/// equivalence of instructions with side effects.
class InstructionUseExpr : public GVNExpression::BasicExpression {
  unsigned MemoryUseOrder = -1;
  bool Volatile = false;

public:
  InstructionUseExpr(Instruction *I, ArrayRecycler<Value *> &R,
                     BumpPtrAllocator &A)
      : GVNExpression::BasicExpression(I->getNumUses()) {
    allocateOperands(R, A);
    setOpcode(I->getOpcode());
    setType(I->getType());

    for (auto &U : I->uses())
      op_push_back(U.getUser());
    llvm::sort(op_begin(), op_end());
  }

  void setMemoryUseOrder(unsigned MUO) { MemoryUseOrder = MUO; }
  void setVolatile(bool V) { Volatile = V; }

  hash_code getHashValue() const override {
    return hash_combine(GVNExpression::BasicExpression::getHashValue(),
                        MemoryUseOrder, Volatile);
  }

  template <typename Function> hash_code getHashValue(Function MapFn) {
    hash_code H =
        hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile);
    for (auto *V : operands())
      H = hash_combine(H, MapFn(V));
    return H;
  }
};

class ValueTable {
  DenseMap<Value *, uint32_t> ValueNumbering;
  DenseMap<GVNExpression::Expression *, uint32_t> ExpressionNumbering;
  DenseMap<size_t, uint32_t> HashNumbering;
  BumpPtrAllocator Allocator;
  ArrayRecycler<Value *> Recycler;
  uint32_t nextValueNumber = 1;

  /// Create an expression for I based on its opcode and its uses. If I
  /// touches or reads memory, the expression is also based upon its memory
  /// order - see \c getMemoryUseOrder().
  InstructionUseExpr *createExpr(Instruction *I) {
    InstructionUseExpr *E =
        new (Allocator) InstructionUseExpr(I, Recycler, Allocator);
    if (isMemoryInst(I))
      E->setMemoryUseOrder(getMemoryUseOrder(I));

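    // Fold the predicate into the opcode so that, e.g., icmp eq and icmp ne
    // of the same uses cannot be numbered equal.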
    if (CmpInst *C = dyn_cast<CmpInst>(I)) {
      CmpInst::Predicate Predicate = C->getPredicate();
      E->setOpcode((C->getOpcode() << 8) | Predicate);
    }
    return E;
  }

  /// Helper to compute the value number for a memory instruction
  /// (LoadInst/StoreInst), including checking the memory ordering and
  /// volatility.
  template <class Inst> InstructionUseExpr *createMemoryExpr(Inst *I) {
    if (isStrongerThanUnordered(I->getOrdering()) || I->isAtomic())
      return nullptr;
    InstructionUseExpr *E = createExpr(I);
    E->setVolatile(I->isVolatile());
    return E;
  }

public:
  ValueTable() = default;

  /// Returns the value number for the specified value, assigning
  /// it a new number if it did not have one before.
  uint32_t lookupOrAdd(Value *V) {
    auto VI = ValueNumbering.find(V);
    if (VI != ValueNumbering.end())
      return VI->second;

    if (!isa<Instruction>(V)) {
      ValueNumbering[V] = nextValueNumber;
      return nextValueNumber++;
    }

    Instruction *I = cast<Instruction>(V);
    InstructionUseExpr *exp = nullptr;
    switch (I->getOpcode()) {
    case Instruction::Load:
      exp = createMemoryExpr(cast<LoadInst>(I));
      break;
    case Instruction::Store:
      exp = createMemoryExpr(cast<StoreInst>(I));
      break;
    case Instruction::Call:
    case Instruction::Invoke:
    case Instruction::Add:
    case Instruction::FAdd:
    case Instruction::Sub:
    case Instruction::FSub:
    case Instruction::Mul:
    case Instruction::FMul:
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::FDiv:
    case Instruction::URem:
    case Instruction::SRem:
    case Instruction::FRem:
    case Instruction::Shl:
    case Instruction::LShr:
    case Instruction::AShr:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt:
    case Instruction::FPToUI:
    case Instruction::FPToSI:
    case Instruction::UIToFP:
    case Instruction::SIToFP:
    case Instruction::FPTrunc:
    case Instruction::FPExt:
    case Instruction::PtrToInt:
    case Instruction::IntToPtr:
    case Instruction::BitCast:
    case Instruction::Select:
    case Instruction::ExtractElement:
    case Instruction::InsertElement:
    case Instruction::ShuffleVector:
    case Instruction::InsertValue:
    case Instruction::GetElementPtr:
      exp = createExpr(I);
      break;
    default:
      break;
    }

    if (!exp) {
      ValueNumbering[V] = nextValueNumber;
      return nextValueNumber++;
    }

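    // Intern the expression by its recursive hash: two instructions whose
    // use chains number identically share a value number even though their
    // Expression objects are distinct allocations.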
    uint32_t e = ExpressionNumbering[exp];
    if (!e) {
      hash_code H = exp->getHashValue([=](Value *V) { return lookupOrAdd(V); });
      auto I = HashNumbering.find(H);
      if (I != HashNumbering.end()) {
        e = I->second;
      } else {
        e = nextValueNumber++;
        HashNumbering[H] = e;
        ExpressionNumbering[exp] = e;
      }
    }
    ValueNumbering[V] = e;
    return e;
  }

  /// Returns the value number of the specified value. Fails if the value has
  /// not yet been numbered.
  uint32_t lookup(Value *V) const {
    auto VI = ValueNumbering.find(V);
    assert(VI != ValueNumbering.end() && "Value not numbered?");
    return VI->second;
  }

  /// Removes all value numberings and resets the value table.
  void clear() {
    ValueNumbering.clear();
    ExpressionNumbering.clear();
    HashNumbering.clear();
    Recycler.clear(Allocator);
    nextValueNumber = 1;
  }

  /// \c Inst uses or touches memory. Return an ID describing the memory state
  /// at \c Inst such that if getMemoryUseOrder(I1) == getMemoryUseOrder(I2),
  /// the exact same memory operations happen after I1 and I2.
  ///
  /// This is a very hard problem in general, so we use domain-specific
  /// knowledge: we only ever check for equivalence between blocks sharing a
  /// single common immediate successor, and when determining if I1 == I2 we
  /// will have already determined that next(I1) == next(I2). This inductive
  /// property allows us to simply return the value number of the next
  /// instruction that defines memory.
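  ///
  /// For example (illustrative): in two predecessors each ending
  ///   store i32 %v, ptr %p
  ///   store i32 0, ptr %q
  ///   br label %succ
  /// the two leading stores get the same memory-use order exactly when the
  /// trailing stores (the "next" memory definitions) number equally.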
  uint32_t getMemoryUseOrder(Instruction *Inst) {
    auto *BB = Inst->getParent();
    for (auto I = std::next(Inst->getIterator()), E = BB->end();
         I != E && !I->isTerminator(); ++I) {
      if (!isMemoryInst(&*I))
        continue;
      if (isa<LoadInst>(&*I))
        continue;
      CallInst *CI = dyn_cast<CallInst>(&*I);
      if (CI && CI->onlyReadsMemory())
        continue;
      InvokeInst *II = dyn_cast<InvokeInst>(&*I);
      if (II && II->onlyReadsMemory())
        continue;
      return lookupOrAdd(&*I);
    }
    return 0;
  }
};

//===----------------------------------------------------------------------===//

class GVNSink {
public:
  GVNSink() = default;

  bool run(Function &F) {
    LLVM_DEBUG(dbgs() << "GVNSink: running on function @" << F.getName()
                      << "\n");

    unsigned NumSunk = 0;
    ReversePostOrderTraversal<Function *> RPOT(&F);
    for (auto *N : RPOT)
      NumSunk += sinkBB(N);

    return NumSunk > 0;
  }

private:
  ValueTable VN;

  bool isInstructionBlacklisted(Instruction *I) {
    // These instructions may change or break semantics if moved.
    if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) ||
        I->getType()->isTokenTy())
      return true;
    return false;
  }

  /// The main heuristic function. Analyze the set of instructions pointed to
  /// by LRI and return a candidate solution if these instructions can be sunk,
  /// or None otherwise.
  Optional<SinkingInstructionCandidate> analyzeInstructionForSinking(
      LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum,
      ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents);

  /// Create a ModelledPHI for each PHI in BB, adding to PHIs.
  void analyzeInitialPHIs(BasicBlock *BB, ModelledPHISet &PHIs,
                          SmallPtrSetImpl<Value *> &PHIContents) {
    for (PHINode &PN : BB->phis()) {
      auto MPHI = ModelledPHI(&PN);
      PHIs.insert(MPHI);
      for (auto *V : MPHI.getValues())
        PHIContents.insert(V);
    }
  }

  /// The main instruction sinking driver. Set up state and try to sink
  /// instructions into BBEnd from its predecessors.
  unsigned sinkBB(BasicBlock *BBEnd);

  /// Perform the actual mechanics of sinking an instruction from Blocks into
  /// BBEnd, which is their only successor.
  void sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, BasicBlock *BBEnd);

  /// Remove PHIs that all have the same incoming value.
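  /// For example (illustrative), %p = phi i32 [ %x, %a ], [ %x, %b ] is
  /// replaced by %x; such PHIs appear when sinking replaces all of a PHI's
  /// incoming values with the single sunk instruction.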
  void foldPointlessPHINodes(BasicBlock *BB) {
    auto I = BB->begin();
    while (PHINode *PN = dyn_cast<PHINode>(I++)) {
      if (!llvm::all_of(PN->incoming_values(), [&](const Value *V) {
            return V == PN->getIncomingValue(0);
          }))
        continue;
      if (PN->getIncomingValue(0) != PN)
        PN->replaceAllUsesWith(PN->getIncomingValue(0));
      else
        PN->replaceAllUsesWith(UndefValue::get(PN->getType()));
      PN->eraseFromParent();
    }
  }
};

Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking(
    LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum,
    ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents) {
  auto Insts = *LRI;
  LLVM_DEBUG(dbgs() << " -- Analyzing instruction set: [\n";
             for (auto *I : Insts) I->dump();
             dbgs() << " ]\n";);

  DenseMap<uint32_t, unsigned> VNums;
  for (auto *I : Insts) {
    uint32_t N = VN.lookupOrAdd(I);
    LLVM_DEBUG(dbgs() << " VN=" << Twine::utohexstr(N) << " for" << *I << "\n");
    if (N == ~0U)
      return None;
    VNums[N]++;
  }
  unsigned VNumToSink =
      std::max_element(VNums.begin(), VNums.end(),
                       [](const std::pair<uint32_t, unsigned> &I,
                          const std::pair<uint32_t, unsigned> &J) {
                         return I.second < J.second;
                       })
          ->first;

  if (VNums[VNumToSink] == 1)
    // Can't sink anything!
    return None;

  // Now restrict the number of incoming blocks down to only those with
  // VNumToSink.
  auto &ActivePreds = LRI.getActiveBlocks();
  unsigned InitialActivePredSize = ActivePreds.size();
  SmallVector<Instruction *, 4> NewInsts;
  for (auto *I : Insts) {
    if (VN.lookup(I) != VNumToSink)
      ActivePreds.remove(I->getParent());
    else
      NewInsts.push_back(I);
  }
  for (auto *I : NewInsts)
    if (isInstructionBlacklisted(I))
      return None;

  // If we've restricted the incoming blocks, restrict all needed PHIs also
  // to that set.
  bool RecomputePHIContents = false;
  if (ActivePreds.size() != InitialActivePredSize) {
    ModelledPHISet NewNeededPHIs;
    for (auto P : NeededPHIs) {
      P.restrictToBlocks(ActivePreds);
      NewNeededPHIs.insert(P);
    }
    NeededPHIs = NewNeededPHIs;
    LRI.restrictToBlocks(ActivePreds);
    RecomputePHIContents = true;
  }

  // The sunk instruction's results.
  ModelledPHI NewPHI(NewInsts, ActivePreds);

  // Does sinking this instruction render previous PHIs redundant?
  if (NeededPHIs.find(NewPHI) != NeededPHIs.end()) {
    NeededPHIs.erase(NewPHI);
    RecomputePHIContents = true;
  }

  if (RecomputePHIContents) {
    // The needed PHIs have changed, so recompute the set of all needed
    // values.
    PHIContents.clear();
    for (auto &PHI : NeededPHIs)
      PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
  }

  // Is this instruction required by a later PHI that doesn't match this PHI?
  // If so, we can't sink this instruction.
  for (auto *V : NewPHI.getValues())
    if (PHIContents.count(V))
      // V exists in this PHI, but the whole PHI is different to NewPHI
      // (else it would have been removed earlier). We cannot continue
      // because this isn't representable.
      return None;

  // Which operands need PHIs?
  // FIXME: If any of these fail, we should partition up the candidates to
  // try and continue making progress.
  Instruction *I0 = NewInsts[0];
  for (unsigned OpNum = 0, E = I0->getNumOperands(); OpNum != E; ++OpNum) {
    ModelledPHI PHI(NewInsts, OpNum, ActivePreds);
    if (PHI.areAllIncomingValuesSame())
      continue;
    if (!canReplaceOperandWithVariable(I0, OpNum))
      // We can't create a PHI from this instruction!
      return None;
    if (NeededPHIs.count(PHI))
      continue;
    if (!PHI.areAllIncomingValuesSameType())
      return None;
    // Don't create indirect calls! The called value is the final operand.
    if ((isa<CallInst>(I0) || isa<InvokeInst>(I0)) && OpNum == E - 1 &&
        PHI.areAnyIncomingValuesConstant())
      return None;

    NeededPHIs.reserve(NeededPHIs.size());
    NeededPHIs.insert(PHI);
    PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
  }

  if (isMemoryInst(NewInsts[0]))
    ++MemoryInstNum;

  SinkingInstructionCandidate Cand;
  Cand.NumInstructions = ++InstNum;
  Cand.NumMemoryInsts = MemoryInstNum;
  Cand.NumBlocks = ActivePreds.size();
  Cand.NumPHIs = NeededPHIs.size();
  for (auto *C : ActivePreds)
    Cand.Blocks.push_back(C);

  return Cand;
}

unsigned GVNSink::sinkBB(BasicBlock *BBEnd) {
  LLVM_DEBUG(dbgs() << "GVNSink: running on basic block ";
             BBEnd->printAsOperand(dbgs()); dbgs() << "\n");
  SmallVector<BasicBlock *, 4> Preds;
  for (auto *B : predecessors(BBEnd)) {
    auto *T = B->getTerminator();
    if (isa<BranchInst>(T) || isa<SwitchInst>(T))
      Preds.push_back(B);
    else
      return 0;
  }
  if (Preds.size() < 2)
    return 0;
  llvm::sort(Preds);

  unsigned NumOrigPreds = Preds.size();
  // We can only sink instructions through unconditional branches.
  for (auto I = Preds.begin(); I != Preds.end();) {
    if ((*I)->getTerminator()->getNumSuccessors() != 1)
      I = Preds.erase(I);
    else
      ++I;
  }

  LockstepReverseIterator LRI(Preds);
  SmallVector<SinkingInstructionCandidate, 4> Candidates;
  unsigned InstNum = 0, MemoryInstNum = 0;
  ModelledPHISet NeededPHIs;
  SmallPtrSet<Value *, 4> PHIContents;
  analyzeInitialPHIs(BBEnd, NeededPHIs, PHIContents);
  unsigned NumOrigPHIs = NeededPHIs.size();

  while (LRI.isValid()) {
    auto Cand = analyzeInstructionForSinking(LRI, InstNum, MemoryInstNum,
                                             NeededPHIs, PHIContents);
    if (!Cand)
      break;
    Cand->calculateCost(NumOrigPHIs, Preds.size());
    Candidates.emplace_back(*Cand);
    --LRI;
  }

  llvm::stable_sort(Candidates, std::greater<SinkingInstructionCandidate>());
  LLVM_DEBUG(dbgs() << " -- Sinking candidates:\n";
             for (auto &C : Candidates) dbgs() << "  " << C << "\n";);

  // Pick the top candidate, as long as it is positive!
  if (Candidates.empty() || Candidates.front().Cost <= 0)
    return 0;
  auto C = Candidates.front();

  LLVM_DEBUG(dbgs() << " -- Sinking: " << C << "\n");
  BasicBlock *InsertBB = BBEnd;
  if (C.Blocks.size() < NumOrigPreds) {
    LLVM_DEBUG(dbgs() << " -- Splitting edge to ";
               BBEnd->printAsOperand(dbgs()); dbgs() << "\n");
    InsertBB = SplitBlockPredecessors(BBEnd, C.Blocks, ".gvnsink.split");
    if (!InsertBB) {
      LLVM_DEBUG(dbgs() << " -- FAILED to split edge!\n");
      // Edge couldn't be split.
      return 0;
    }
  }

  for (unsigned I = 0; I < C.NumInstructions; ++I)
    sinkLastInstruction(C.Blocks, InsertBB);

  return C.NumInstructions;
}

void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks,
                                  BasicBlock *BBEnd) {
  SmallVector<Instruction *, 4> Insts;
  for (BasicBlock *BB : Blocks)
    Insts.push_back(BB->getTerminator()->getPrevNode());
  Instruction *I0 = Insts.front();

  SmallVector<Value *, 4> NewOperands;
  for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) {
    bool NeedPHI = llvm::any_of(Insts, [&I0, O](const Instruction *I) {
      return I->getOperand(O) != I0->getOperand(O);
    });
    if (!NeedPHI) {
      NewOperands.push_back(I0->getOperand(O));
      continue;
    }

    // Create a new PHI in the successor block and populate it.
    auto *Op = I0->getOperand(O);
    assert(!Op->getType()->isTokenTy() && "Can't PHI tokens!");
    auto *PN = PHINode::Create(Op->getType(), Insts.size(),
                               Op->getName() + ".sink", &BBEnd->front());
    for (auto *I : Insts)
      PN->addIncoming(I->getOperand(O), I->getParent());
    NewOperands.push_back(PN);
  }

  // Arbitrarily use I0 as the new "common" instruction; remap its operands
  // and move it to the start of the successor block.
  for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O)
    I0->getOperandUse(O).set(NewOperands[O]);
  I0->moveBefore(&*BBEnd->getFirstInsertionPt());

  // Update metadata and IR flags.
  for (auto *I : Insts)
    if (I != I0) {
      combineMetadataForCSE(I0, I, true);
      I0->andIRFlags(I);
    }

  for (auto *I : Insts)
    if (I != I0)
      I->replaceAllUsesWith(I0);
  foldPointlessPHINodes(BBEnd);

  // Finally nuke all instructions apart from the common instruction.
  for (auto *I : Insts)
    if (I != I0)
      I->eraseFromParent();

  NumRemoved += Insts.size() - 1;
}

////////////////////////////////////////////////////////////////////////////////
// Pass machinery / boilerplate

class GVNSinkLegacyPass : public FunctionPass {
public:
  static char ID;

  GVNSinkLegacyPass() : FunctionPass(ID) {
    initializeGVNSinkLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    GVNSink G;
    return G.run(F);
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<GlobalsAAWrapperPass>();
  }
};

} // end anonymous namespace

PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) {
  GVNSink G;
  if (!G.run(F))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserve<GlobalsAA>();
  return PA;
}

char GVNSinkLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(GVNSinkLegacyPass, "gvn-sink",
                      "Early GVN sinking of Expressions", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_END(GVNSinkLegacyPass, "gvn-sink",
                    "Early GVN sinking of Expressions", false, false)

FunctionPass *llvm::createGVNSinkPass() { return new GVNSinkLegacyPass(); }