xref: /llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp (revision 5fb57131b744c52f74919f9487f4a9fa69f455fb)
1 //===- DFAJumpThreading.cpp - Threads a switch statement inside a loop ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transform each threading path to effectively jump thread the DFA. For
10 // example, the CFG below could be transformed as follows, where the cloned
11 // blocks unconditionally branch to the next correct case based on what is
12 // identified in the analysis.
13 //
14 //          sw.bb                        sw.bb
15 //        /   |   \                    /   |   \
16 //   case1  case2  case3          case1  case2  case3
17 //        \   |   /                 |      |      |
18 //       determinator            det.2   det.3  det.1
19 //        br sw.bb                /        |        \
20 //                          sw.bb.2     sw.bb.3     sw.bb.1
21 //                           br case2    br case3    br case1
22 //
23 // Definitions and Terminology:
24 //
25 // * Threading path:
26 //   a list of basic blocks, the exit state, and the block that determines
27 //   the next state, for which the following notation will be used:
28 //   < path of BBs that form a cycle > [ state, determinator ]
29 //
30 // * Predictable switch:
31 //   The switch variable is always a known constant so that all conditional
32 //   jumps based on switch variable can be converted to unconditional jump.
33 //
34 // * Determinator:
35 //   The basic block that determines the next state of the DFA.
36 //
37 // Representing the optimization in C-like pseudocode: the code pattern on the
38 // left could functionally be transformed to the right pattern if the switch
39 // condition is predictable.
40 //
41 //  X = A                       goto A
42 //  for (...)                   A:
43 //    switch (X)                  ...
44 //      case A                    goto B
45 //        X = B                 B:
46 //      case B                    ...
47 //        X = C                   goto C
48 //
49 // The pass first checks that switch variable X is decided by the control flow
50 // path taken in the loop; for example, in case B, the next value of X is
51 // decided to be C. It then enumerates through all paths in the loop and labels
52 // the basic blocks where the next state is decided.
53 //
54 // Using this information it creates new paths that unconditionally branch to
55 // the next case. This involves cloning code, so it only gets triggered if the
56 // amount of code duplicated is below a threshold.
57 //
58 //===----------------------------------------------------------------------===//
59 
60 #include "llvm/Transforms/Scalar/DFAJumpThreading.h"
61 #include "llvm/ADT/APInt.h"
62 #include "llvm/ADT/DenseMap.h"
63 #include "llvm/ADT/SmallSet.h"
64 #include "llvm/ADT/Statistic.h"
65 #include "llvm/Analysis/AssumptionCache.h"
66 #include "llvm/Analysis/CodeMetrics.h"
67 #include "llvm/Analysis/DomTreeUpdater.h"
68 #include "llvm/Analysis/LoopInfo.h"
69 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
70 #include "llvm/Analysis/TargetTransformInfo.h"
71 #include "llvm/IR/CFG.h"
72 #include "llvm/IR/Constants.h"
73 #include "llvm/IR/IntrinsicInst.h"
74 #include "llvm/Support/CommandLine.h"
75 #include "llvm/Support/Debug.h"
76 #include "llvm/Transforms/Utils/Cloning.h"
77 #include "llvm/Transforms/Utils/SSAUpdaterBulk.h"
78 #include "llvm/Transforms/Utils/ValueMapper.h"
79 #include <deque>
80 
81 #ifdef EXPENSIVE_CHECKS
82 #include "llvm/IR/Verifier.h"
83 #endif
84 
85 using namespace llvm;
86 
87 #define DEBUG_TYPE "dfa-jump-threading"
88 
// Pass-wide counters registered through the STATISTIC macro.
STATISTIC(NumTransforms, "Number of transformations done");
STATISTIC(NumCloned, "Number of blocks cloned");
STATISTIC(NumPaths, "Number of individual paths threaded");

// Debugging aid; per its description, shows the CFG before the pass runs.
static cl::opt<bool>
    ClViewCfgBefore("dfa-jump-view-cfg-before",
                    cl::desc("View the CFG before DFA Jump Threading"),
                    cl::Hidden, cl::init(false));

// Consulted by MainSwitch::isCandidate(): when enabled, a non-constant,
// non-phi, non-select state value coming from the switch's own loop rejects
// the switch immediately instead of after path enumeration.
static cl::opt<bool> EarlyExitHeuristic(
    "dfa-early-exit-heuristic",
    cl::desc("Exit early if an unpredictable value come from the same loop"),
    cl::Hidden, cl::init(true));

// Recursion depth limit checked at the top of AllSwitchPaths::paths().
static cl::opt<unsigned> MaxPathLength(
    "dfa-max-path-length",
    cl::desc("Max number of blocks searched to find a threading path"),
    cl::Hidden, cl::init(20));

// Budget for the running NumVisited counter kept by AllSwitchPaths::paths().
static cl::opt<unsigned> MaxNumVisitiedPaths(
    "dfa-max-num-visited-paths",
    cl::desc(
        "Max number of blocks visited while enumerating paths around a switch"),
    cl::Hidden, cl::init(2500));

// Cap on the number of paths a single AllSwitchPaths::paths() call collects.
static cl::opt<unsigned>
    MaxNumPaths("dfa-max-num-paths",
                cl::desc("Max number of paths enumerated around a switch"),
                cl::Hidden, cl::init(200));

// Per its description, the cost limit for accepting the transformation; the
// comparison itself is not part of this excerpt.
static cl::opt<unsigned>
    CostThreshold("dfa-cost-threshold",
                  cl::desc("Maximum cost accepted for the transformation"),
                  cl::Hidden, cl::init(50));
123 
124 namespace {
125 
126 class SelectInstToUnfold {
127   SelectInst *SI;
128   PHINode *SIUse;
129 
130 public:
131   SelectInstToUnfold(SelectInst *SI, PHINode *SIUse) : SI(SI), SIUse(SIUse) {}
132 
133   SelectInst *getInst() { return SI; }
134   PHINode *getUse() { return SIUse; }
135 
136   explicit operator bool() const { return SI && SIUse; }
137 };
138 
// Forward declaration: DFAJumpThreading::unfoldSelectInstrs() calls unfold()
// before the definition, which appears later in this file.
void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold,
            std::vector<SelectInstToUnfold> *NewSIsToUnfold,
            std::vector<BasicBlock *> *NewBBs);
142 
143 class DFAJumpThreading {
144 public:
145   DFAJumpThreading(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
146                    TargetTransformInfo *TTI, OptimizationRemarkEmitter *ORE)
147       : AC(AC), DT(DT), LI(LI), TTI(TTI), ORE(ORE) {}
148 
149   bool run(Function &F);
150   bool LoopInfoBroken;
151 
152 private:
153   void
154   unfoldSelectInstrs(DominatorTree *DT,
155                      const SmallVector<SelectInstToUnfold, 4> &SelectInsts) {
156     DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
157     SmallVector<SelectInstToUnfold, 4> Stack(SelectInsts);
158 
159     while (!Stack.empty()) {
160       SelectInstToUnfold SIToUnfold = Stack.pop_back_val();
161 
162       std::vector<SelectInstToUnfold> NewSIsToUnfold;
163       std::vector<BasicBlock *> NewBBs;
164       unfold(&DTU, LI, SIToUnfold, &NewSIsToUnfold, &NewBBs);
165 
166       // Put newly discovered select instructions into the work list.
167       for (const SelectInstToUnfold &NewSIToUnfold : NewSIsToUnfold)
168         Stack.push_back(NewSIToUnfold);
169     }
170   }
171 
172   AssumptionCache *AC;
173   DominatorTree *DT;
174   LoopInfo *LI;
175   TargetTransformInfo *TTI;
176   OptimizationRemarkEmitter *ORE;
177 };
178 
179 } // end anonymous namespace
180 
181 namespace {
182 
/// Unfold the select instruction held in \p SIToUnfold by replacing it with
/// control flow.
///
/// Put newly discovered select instructions into \p NewSIsToUnfold. Put newly
/// created basic blocks into \p NewBBs.
///
/// Two shapes are handled, depending on the terminator of the block that
/// contains the select:
///  * unconditional branch: one new block is created for the 'false' value and
///    the terminator is replaced by a conditional branch on the select's
///    condition;
///  * conditional branch: two new blocks are inserted on the edge leading to
///    the block of the select's PHI user.
///
/// \param DTU            receives an update for each CFG edge changed here.
/// \param LI             used to add the new blocks to the loop containing the
///                       select's block, if any.
/// \param SIToUnfold     the select and the PHI that uses it; the select must
///                       have exactly one use (asserted).
/// \param NewSIsToUnfold output: operands of the select that are themselves
///                       selects, paired with the PHI that now uses them.
/// \param NewBBs         output: every block created by this call is appended.
///
/// The select instruction is erased before returning.
///
/// TODO: merge it with CodeGenPrepare::optimizeSelectInst() if possible.
void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold,
            std::vector<SelectInstToUnfold> *NewSIsToUnfold,
            std::vector<BasicBlock *> *NewBBs) {
  SelectInst *SI = SIToUnfold.getInst();
  PHINode *SIUse = SIToUnfold.getUse();
  BasicBlock *StartBlock = SI->getParent();
  BranchInst *StartBlockTerm =
      dyn_cast<BranchInst>(StartBlock->getTerminator());

  // The block holding the select must end in a branch, and the select must be
  // consumed by a single user.
  assert(StartBlockTerm);
  assert(SI->hasOneUse());

  if (StartBlockTerm->isUnconditional()) {
    BasicBlock *EndBlock = StartBlock->getUniqueSuccessor();
    // Arbitrarily choose the 'false' side for a new input value to the PHI.
    BasicBlock *NewBlock = BasicBlock::Create(
        SI->getContext(), Twine(SI->getName(), ".si.unfold.false"),
        EndBlock->getParent(), EndBlock);
    NewBBs->push_back(NewBlock);
    // NewBlock simply falls through to EndBlock; the edge from StartBlock into
    // NewBlock is created further down with the conditional branch.
    BranchInst::Create(EndBlock, NewBlock);
    DTU->applyUpdates({{DominatorTree::Insert, NewBlock, EndBlock}});

    // StartBlock
    //   |  \
    //   |  NewBlock
    //   |  /
    // EndBlock
    Value *SIOp1 = SI->getTrueValue();
    Value *SIOp2 = SI->getFalseValue();

    // Single-entry PHI in NewBlock carrying the 'false' value.
    PHINode *NewPhi = PHINode::Create(SIUse->getType(), 1,
                                      Twine(SIOp2->getName(), ".si.unfold.phi"),
                                      NewBlock->getFirstInsertionPt());
    NewPhi->addIncoming(SIOp2, StartBlock);

    // Update any other PHI nodes in EndBlock.
    // They see the same value from NewBlock as they do from StartBlock.
    for (PHINode &Phi : EndBlock->phis()) {
      if (SIUse == &Phi)
        continue;
      Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), NewBlock);
    }

    // Update the phi node of SI, which is its only use.
    if (EndBlock == SIUse->getParent()) {
      // The user lives in EndBlock: 'false' arrives via NewBlock and the
      // existing select operand is replaced by the 'true' value.
      SIUse->addIncoming(NewPhi, NewBlock);
      SIUse->replaceUsesOfWith(SI, SIOp1);
    } else {
      // The user is elsewhere: merge both values in a fresh PHI in EndBlock
      // and make the user read that PHI instead of the select.
      PHINode *EndPhi = PHINode::Create(SIUse->getType(), pred_size(EndBlock),
                                        Twine(SI->getName(), ".si.unfold.phi"),
                                        EndBlock->getFirstInsertionPt());
      // Every predecessor other than StartBlock/NewBlock feeds the new PHI
      // back into itself.
      for (BasicBlock *Pred : predecessors(EndBlock)) {
        if (Pred != StartBlock && Pred != NewBlock)
          EndPhi->addIncoming(EndPhi, Pred);
      }

      EndPhi->addIncoming(SIOp1, StartBlock);
      EndPhi->addIncoming(NewPhi, NewBlock);
      SIUse->replaceUsesOfWith(SI, EndPhi);
      SIUse = EndPhi;
    }

    // Operands that are selects themselves must be unfolded too; each is
    // paired with the PHI that now uses it.
    if (auto *OpSi = dyn_cast<SelectInst>(SIOp1))
      NewSIsToUnfold->push_back(SelectInstToUnfold(OpSi, SIUse));
    if (auto *OpSi = dyn_cast<SelectInst>(SIOp2))
      NewSIsToUnfold->push_back(SelectInstToUnfold(OpSi, NewPhi));

    // Insert the real conditional branch based on the original condition.
    StartBlockTerm->eraseFromParent();
    BranchInst::Create(EndBlock, NewBlock, SI->getCondition(), StartBlock);
    DTU->applyUpdates({{DominatorTree::Insert, StartBlock, EndBlock},
                       {DominatorTree::Insert, StartBlock, NewBlock}});
  } else {
    BasicBlock *EndBlock = SIUse->getParent();
    BasicBlock *NewBlockT = BasicBlock::Create(
        SI->getContext(), Twine(SI->getName(), ".si.unfold.true"),
        EndBlock->getParent(), EndBlock);
    BasicBlock *NewBlockF = BasicBlock::Create(
        SI->getContext(), Twine(SI->getName(), ".si.unfold.false"),
        EndBlock->getParent(), EndBlock);

    NewBBs->push_back(NewBlockT);
    NewBBs->push_back(NewBlockF);

    // Def only has one use in EndBlock.
    // Before transformation:
    // StartBlock(Def)
    //   |      \
    // EndBlock  OtherBlock
    //  (Use)
    //
    // After transformation:
    // StartBlock(Def)
    //   |      \
    //   |       OtherBlock
    // NewBlockT
    //   |     \
    //   |   NewBlockF
    //   |      /
    //   |     /
    // EndBlock
    //  (Use)
    BranchInst::Create(EndBlock, NewBlockF);
    // Insert the real conditional branch based on the original condition.
    BranchInst::Create(EndBlock, NewBlockF, SI->getCondition(), NewBlockT);
    DTU->applyUpdates({{DominatorTree::Insert, NewBlockT, NewBlockF},
                       {DominatorTree::Insert, NewBlockT, EndBlock},
                       {DominatorTree::Insert, NewBlockF, EndBlock}});

    Value *TrueVal = SI->getTrueValue();
    Value *FalseVal = SI->getFalseValue();

    // Single-entry PHIs carrying the 'true' and 'false' values through the
    // two new blocks.
    PHINode *NewPhiT = PHINode::Create(
        SIUse->getType(), 1, Twine(TrueVal->getName(), ".si.unfold.phi"),
        NewBlockT->getFirstInsertionPt());
    PHINode *NewPhiF = PHINode::Create(
        SIUse->getType(), 1, Twine(FalseVal->getName(), ".si.unfold.phi"),
        NewBlockF->getFirstInsertionPt());
    NewPhiT->addIncoming(TrueVal, StartBlock);
    NewPhiF->addIncoming(FalseVal, NewBlockT);

    // Nested selects are queued with the new PHI that now consumes them.
    if (auto *TrueSI = dyn_cast<SelectInst>(TrueVal))
      NewSIsToUnfold->push_back(SelectInstToUnfold(TrueSI, NewPhiT));
    if (auto *FalseSi = dyn_cast<SelectInst>(FalseVal))
      NewSIsToUnfold->push_back(SelectInstToUnfold(FalseSi, NewPhiF));

    // The user PHI now receives its value from the two new blocks rather than
    // directly from StartBlock.
    SIUse->addIncoming(NewPhiT, NewBlockT);
    SIUse->addIncoming(NewPhiF, NewBlockF);
    SIUse->removeIncomingValue(StartBlock);

    // Update any other PHI nodes in EndBlock.
    // The StartBlock entry is duplicated for both new predecessors before it
    // is removed.
    for (PHINode &Phi : EndBlock->phis()) {
      if (SIUse == &Phi)
        continue;
      Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), NewBlockT);
      Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), NewBlockF);
      Phi.removeIncomingValue(StartBlock);
    }

    // Update the appropriate successor of the start block to point to the new
    // unfolded block.
    unsigned SuccNum = StartBlockTerm->getSuccessor(1) == EndBlock ? 1 : 0;
    StartBlockTerm->setSuccessor(SuccNum, NewBlockT);
    DTU->applyUpdates({{DominatorTree::Delete, StartBlock, EndBlock},
                       {DominatorTree::Insert, StartBlock, NewBlockT}});
  }

  // Preserve loop info
  // All blocks listed in NewBBs are added to the loop of the select's block.
  if (Loop *L = LI->getLoopFor(SI->getParent())) {
    for (BasicBlock *NewBB : *NewBBs)
      L->addBasicBlockToLoop(NewBB, *LI);
  }

  // The select is now dead.
  assert(SI->use_empty() && "Select must be dead now");
  SI->eraseFromParent();
}
346 
/// A (block, state) pair; lists of these are stored per original block in
/// DuplicateBlockMap.
struct ClonedBlock {
  BasicBlock *BB; ///< The block recorded for \p State.
  APInt State; ///< \p State corresponds to the next value of a switch stmnt.
};
351 
352 typedef std::deque<BasicBlock *> PathType;
353 typedef std::vector<PathType> PathsType;
354 typedef SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
355 typedef std::vector<ClonedBlock> CloneList;
356 
357 // This data structure keeps track of all blocks that have been cloned.  If two
358 // different ThreadingPaths clone the same block for a certain state it should
359 // be reused, and it can be looked up in this map.
360 typedef DenseMap<BasicBlock *, CloneList> DuplicateBlockMap;
361 
362 // This map keeps track of all the new definitions for an instruction. This
363 // information is needed when restoring SSA form after cloning blocks.
364 typedef MapVector<Instruction *, std::vector<Instruction *>> DefMap;
365 
366 inline raw_ostream &operator<<(raw_ostream &OS, const PathType &Path) {
367   OS << "< ";
368   for (const BasicBlock *BB : Path) {
369     std::string BBName;
370     if (BB->hasName())
371       raw_string_ostream(BBName) << BB->getName();
372     else
373       raw_string_ostream(BBName) << BB;
374     OS << BBName << " ";
375   }
376   OS << ">";
377   return OS;
378 }
379 
380 /// ThreadingPath is a path in the control flow of a loop that can be threaded
381 /// by cloning necessary basic blocks and replacing conditional branches with
382 /// unconditional ones. A threading path includes a list of basic blocks, the
383 /// exit state, and the block that determines the next state.
384 struct ThreadingPath {
385   /// Exit value is DFA's exit state for the given path.
386   APInt getExitValue() const { return ExitVal; }
387   void setExitValue(const ConstantInt *V) {
388     ExitVal = V->getValue();
389     IsExitValSet = true;
390   }
391   bool isExitValueSet() const { return IsExitValSet; }
392 
393   /// Determinator is the basic block that determines the next state of the DFA.
394   const BasicBlock *getDeterminatorBB() const { return DBB; }
395   void setDeterminator(const BasicBlock *BB) { DBB = BB; }
396 
397   /// Path is a list of basic blocks.
398   const PathType &getPath() const { return Path; }
399   void setPath(const PathType &NewPath) { Path = NewPath; }
400   void push_back(BasicBlock *BB) { Path.push_back(BB); }
401   void push_front(BasicBlock *BB) { Path.push_front(BB); }
402   void appendExcludingFirst(const PathType &OtherPath) {
403     Path.insert(Path.end(), OtherPath.begin() + 1, OtherPath.end());
404   }
405 
406   void print(raw_ostream &OS) const {
407     OS << Path << " [ " << ExitVal << ", " << DBB->getName() << " ]";
408   }
409 
410 private:
411   PathType Path;
412   APInt ExitVal;
413   const BasicBlock *DBB = nullptr;
414   bool IsExitValSet = false;
415 };
416 
#ifndef NDEBUG
/// Stream a ThreadingPath by forwarding to ThreadingPath::print(). Compiled
/// only when NDEBUG is not defined.
inline raw_ostream &operator<<(raw_ostream &OS, const ThreadingPath &TPath) {
  TPath.print(OS);
  return OS;
}
#endif
423 
424 struct MainSwitch {
425   MainSwitch(SwitchInst *SI, LoopInfo *LI, OptimizationRemarkEmitter *ORE)
426       : LI(LI) {
427     if (isCandidate(SI)) {
428       Instr = SI;
429     } else {
430       ORE->emit([&]() {
431         return OptimizationRemarkMissed(DEBUG_TYPE, "SwitchNotPredictable", SI)
432                << "Switch instruction is not predictable.";
433       });
434     }
435   }
436 
437   virtual ~MainSwitch() = default;
438 
439   SwitchInst *getInstr() const { return Instr; }
440   const SmallVector<SelectInstToUnfold, 4> getSelectInsts() {
441     return SelectInsts;
442   }
443 
444 private:
445   /// Do a use-def chain traversal starting from the switch condition to see if
446   /// \p SI is a potential condidate.
447   ///
448   /// Also, collect select instructions to unfold.
449   bool isCandidate(const SwitchInst *SI) {
450     std::deque<std::pair<Value *, BasicBlock *>> Q;
451     SmallSet<Value *, 16> SeenValues;
452     SelectInsts.clear();
453 
454     Value *SICond = SI->getCondition();
455     LLVM_DEBUG(dbgs() << "\tSICond: " << *SICond << "\n");
456     if (!isa<PHINode>(SICond))
457       return false;
458 
459     // The switch must be in a loop.
460     const Loop *L = LI->getLoopFor(SI->getParent());
461     if (!L)
462       return false;
463 
464     addToQueue(SICond, nullptr, Q, SeenValues);
465 
466     while (!Q.empty()) {
467       Value *Current = Q.front().first;
468       BasicBlock *CurrentIncomingBB = Q.front().second;
469       Q.pop_front();
470 
471       if (auto *Phi = dyn_cast<PHINode>(Current)) {
472         for (BasicBlock *IncomingBB : Phi->blocks()) {
473           Value *Incoming = Phi->getIncomingValueForBlock(IncomingBB);
474           addToQueue(Incoming, IncomingBB, Q, SeenValues);
475         }
476         LLVM_DEBUG(dbgs() << "\tphi: " << *Phi << "\n");
477       } else if (SelectInst *SelI = dyn_cast<SelectInst>(Current)) {
478         if (!isValidSelectInst(SelI))
479           return false;
480         addToQueue(SelI->getTrueValue(), CurrentIncomingBB, Q, SeenValues);
481         addToQueue(SelI->getFalseValue(), CurrentIncomingBB, Q, SeenValues);
482         LLVM_DEBUG(dbgs() << "\tselect: " << *SelI << "\n");
483         if (auto *SelIUse = dyn_cast<PHINode>(SelI->user_back()))
484           SelectInsts.push_back(SelectInstToUnfold(SelI, SelIUse));
485       } else if (isa<Constant>(Current)) {
486         LLVM_DEBUG(dbgs() << "\tconst: " << *Current << "\n");
487         continue;
488       } else {
489         LLVM_DEBUG(dbgs() << "\tother: " << *Current << "\n");
490         // Allow unpredictable values. The hope is that those will be the
491         // initial switch values that can be ignored (they will hit the
492         // unthreaded switch) but this assumption will get checked later after
493         // paths have been enumerated (in function getStateDefMap).
494 
495         // If the unpredictable value comes from the same inner loop it is
496         // likely that it will also be on the enumerated paths, causing us to
497         // exit after we have enumerated all the paths. This heuristic save
498         // compile time because a search for all the paths can become expensive.
499         if (EarlyExitHeuristic &&
500             L->contains(LI->getLoopFor(CurrentIncomingBB))) {
501           LLVM_DEBUG(dbgs()
502                      << "\tExiting early due to unpredictability heuristic.\n");
503           return false;
504         }
505 
506         continue;
507       }
508     }
509 
510     return true;
511   }
512 
513   void addToQueue(Value *Val, BasicBlock *BB,
514                   std::deque<std::pair<Value *, BasicBlock *>> &Q,
515                   SmallSet<Value *, 16> &SeenValues) {
516     if (SeenValues.insert(Val).second)
517       Q.push_back({Val, BB});
518   }
519 
520   bool isValidSelectInst(SelectInst *SI) {
521     if (!SI->hasOneUse())
522       return false;
523 
524     Instruction *SIUse = dyn_cast<Instruction>(SI->user_back());
525     // The use of the select inst should be either a phi or another select.
526     if (!SIUse && !(isa<PHINode>(SIUse) || isa<SelectInst>(SIUse)))
527       return false;
528 
529     BasicBlock *SIBB = SI->getParent();
530 
531     // Currently, we can only expand select instructions in basic blocks with
532     // one successor.
533     BranchInst *SITerm = dyn_cast<BranchInst>(SIBB->getTerminator());
534     if (!SITerm || !SITerm->isUnconditional())
535       return false;
536 
537     // Only fold the select coming from directly where it is defined.
538     PHINode *PHIUser = dyn_cast<PHINode>(SIUse);
539     if (PHIUser && PHIUser->getIncomingBlock(*SI->use_begin()) != SIBB)
540       return false;
541 
542     // If select will not be sunk during unfolding, and it is in the same basic
543     // block as another state defining select, then cannot unfold both.
544     for (SelectInstToUnfold SIToUnfold : SelectInsts) {
545       SelectInst *PrevSI = SIToUnfold.getInst();
546       if (PrevSI->getTrueValue() != SI && PrevSI->getFalseValue() != SI &&
547           PrevSI->getParent() == SI->getParent())
548         return false;
549     }
550 
551     return true;
552   }
553 
554   LoopInfo *LI;
555   SwitchInst *Instr = nullptr;
556   SmallVector<SelectInstToUnfold, 4> SelectInsts;
557 };
558 
559 struct AllSwitchPaths {
560   AllSwitchPaths(const MainSwitch *MSwitch, OptimizationRemarkEmitter *ORE,
561                  LoopInfo *LI, Loop *L)
562       : Switch(MSwitch->getInstr()), SwitchBlock(Switch->getParent()), ORE(ORE),
563         LI(LI), SwitchOuterLoop(L) {}
564 
565   std::vector<ThreadingPath> &getThreadingPaths() { return TPaths; }
566   unsigned getNumThreadingPaths() { return TPaths.size(); }
567   SwitchInst *getSwitchInst() { return Switch; }
568   BasicBlock *getSwitchBlock() { return SwitchBlock; }
569 
570   void run() {
571     StateDefMap StateDef = getStateDefMap();
572     if (StateDef.empty()) {
573       ORE->emit([&]() {
574         return OptimizationRemarkMissed(DEBUG_TYPE, "SwitchNotPredictable",
575                                         Switch)
576                << "Switch instruction is not predictable.";
577       });
578       return;
579     }
580 
581     auto *SwitchPhi = cast<PHINode>(Switch->getOperand(0));
582     auto *SwitchPhiDefBB = SwitchPhi->getParent();
583     VisitedBlocks VB;
584     // Get paths from the determinator BBs to SwitchPhiDefBB
585     std::vector<ThreadingPath> PathsToPhiDef =
586         getPathsFromStateDefMap(StateDef, SwitchPhi, VB);
587     if (SwitchPhiDefBB == SwitchBlock) {
588       TPaths = std::move(PathsToPhiDef);
589       return;
590     }
591 
592     // Find and append paths from SwitchPhiDefBB to SwitchBlock.
593     PathsType PathsToSwitchBB =
594         paths(SwitchPhiDefBB, SwitchBlock, VB, /* PathDepth = */ 1);
595     if (PathsToSwitchBB.empty())
596       return;
597 
598     std::vector<ThreadingPath> TempList;
599     for (const ThreadingPath &Path : PathsToPhiDef) {
600       for (const PathType &PathToSw : PathsToSwitchBB) {
601         ThreadingPath PathCopy(Path);
602         PathCopy.appendExcludingFirst(PathToSw);
603         TempList.push_back(PathCopy);
604       }
605     }
606     TPaths = std::move(TempList);
607   }
608 
609 private:
610   // Value: an instruction that defines a switch state;
611   // Key: the parent basic block of that instruction.
612   typedef DenseMap<const BasicBlock *, const PHINode *> StateDefMap;
613   std::vector<ThreadingPath> getPathsFromStateDefMap(StateDefMap &StateDef,
614                                                      PHINode *Phi,
615                                                      VisitedBlocks &VB) {
616     std::vector<ThreadingPath> Res;
617     auto *PhiBB = Phi->getParent();
618     VB.insert(PhiBB);
619 
620     VisitedBlocks UniqueBlocks;
621     for (auto *IncomingBB : Phi->blocks()) {
622       if (!UniqueBlocks.insert(IncomingBB).second)
623         continue;
624       if (!SwitchOuterLoop->contains(IncomingBB))
625         continue;
626 
627       Value *IncomingValue = Phi->getIncomingValueForBlock(IncomingBB);
628       // We found the determinator. This is the start of our path.
629       if (auto *C = dyn_cast<ConstantInt>(IncomingValue)) {
630         // SwitchBlock is the determinator, unsupported unless its also the def.
631         if (PhiBB == SwitchBlock &&
632             SwitchBlock != cast<PHINode>(Switch->getOperand(0))->getParent())
633           continue;
634         ThreadingPath NewPath;
635         NewPath.setDeterminator(PhiBB);
636         NewPath.setExitValue(C);
637         // Don't add SwitchBlock at the start, this is handled later.
638         if (IncomingBB != SwitchBlock)
639           NewPath.push_back(IncomingBB);
640         NewPath.push_back(PhiBB);
641         Res.push_back(NewPath);
642         continue;
643       }
644       // Don't get into a cycle.
645       if (VB.contains(IncomingBB) || IncomingBB == SwitchBlock)
646         continue;
647       // Recurse up the PHI chain.
648       auto *IncomingPhi = dyn_cast<PHINode>(IncomingValue);
649       if (!IncomingPhi)
650         continue;
651       auto *IncomingPhiDefBB = IncomingPhi->getParent();
652       if (!StateDef.contains(IncomingPhiDefBB))
653         continue;
654 
655       // Direct predecessor, just add to the path.
656       if (IncomingPhiDefBB == IncomingBB) {
657         std::vector<ThreadingPath> PredPaths =
658             getPathsFromStateDefMap(StateDef, IncomingPhi, VB);
659         for (ThreadingPath &Path : PredPaths) {
660           Path.push_back(PhiBB);
661           Res.push_back(std::move(Path));
662         }
663         continue;
664       }
665       // Not a direct predecessor, find intermediate paths to append to the
666       // existing path.
667       if (VB.contains(IncomingPhiDefBB))
668         continue;
669 
670       PathsType IntermediatePaths;
671       IntermediatePaths =
672           paths(IncomingPhiDefBB, IncomingBB, VB, /* PathDepth = */ 1);
673       if (IntermediatePaths.empty())
674         continue;
675 
676       std::vector<ThreadingPath> PredPaths =
677           getPathsFromStateDefMap(StateDef, IncomingPhi, VB);
678       for (const ThreadingPath &Path : PredPaths) {
679         for (const PathType &IPath : IntermediatePaths) {
680           ThreadingPath NewPath(Path);
681           NewPath.appendExcludingFirst(IPath);
682           NewPath.push_back(PhiBB);
683           Res.push_back(NewPath);
684         }
685       }
686     }
687     VB.erase(PhiBB);
688     return Res;
689   }
690 
691   PathsType paths(BasicBlock *BB, BasicBlock *ToBB, VisitedBlocks &Visited,
692                   unsigned PathDepth) {
693     PathsType Res;
694 
695     // Stop exploring paths after visiting MaxPathLength blocks
696     if (PathDepth > MaxPathLength) {
697       ORE->emit([&]() {
698         return OptimizationRemarkAnalysis(DEBUG_TYPE, "MaxPathLengthReached",
699                                           Switch)
700                << "Exploration stopped after visiting MaxPathLength="
701                << ore::NV("MaxPathLength", MaxPathLength) << " blocks.";
702       });
703       return Res;
704     }
705 
706     Visited.insert(BB);
707     if (++NumVisited > MaxNumVisitiedPaths)
708       return Res;
709 
710     // Stop if we have reached the BB out of loop, since its successors have no
711     // impact on the DFA.
712     if (!SwitchOuterLoop->contains(BB))
713       return Res;
714 
715     // Some blocks have multiple edges to the same successor, and this set
716     // is used to prevent a duplicate path from being generated
717     SmallSet<BasicBlock *, 4> Successors;
718     for (BasicBlock *Succ : successors(BB)) {
719       if (!Successors.insert(Succ).second)
720         continue;
721 
722       // Found a cycle through the final block.
723       if (Succ == ToBB) {
724         Res.push_back({BB, ToBB});
725         continue;
726       }
727 
728       // We have encountered a cycle, do not get caught in it
729       if (Visited.contains(Succ))
730         continue;
731 
732       auto *CurrLoop = LI->getLoopFor(BB);
733       // Unlikely to be beneficial.
734       if (Succ == CurrLoop->getHeader())
735         continue;
736       // Skip for now, revisit this condition later to see the impact on
737       // coverage and compile time.
738       if (LI->getLoopFor(Succ) != CurrLoop)
739         continue;
740 
741       PathsType SuccPaths = paths(Succ, ToBB, Visited, PathDepth + 1);
742       for (PathType &Path : SuccPaths) {
743         Path.push_front(BB);
744         Res.push_back(Path);
745         if (Res.size() >= MaxNumPaths) {
746           return Res;
747         }
748       }
749     }
750     // This block could now be visited again from a different predecessor. Note
751     // that this will result in exponential runtime. Subpaths could possibly be
752     // cached but it takes a lot of memory to store them.
753     Visited.erase(BB);
754     return Res;
755   }
756 
757   /// Walk the use-def chain and collect all the state-defining blocks and the
758   /// PHI nodes in those blocks that define the state.
759   StateDefMap getStateDefMap() const {
760     StateDefMap Res;
761     PHINode *FirstDef = dyn_cast<PHINode>(Switch->getOperand(0));
762     assert(FirstDef && "The first definition must be a phi.");
763 
764     SmallVector<PHINode *, 8> Stack;
765     Stack.push_back(FirstDef);
766     SmallSet<Value *, 16> SeenValues;
767 
768     while (!Stack.empty()) {
769       PHINode *CurPhi = Stack.pop_back_val();
770 
771       Res[CurPhi->getParent()] = CurPhi;
772       SeenValues.insert(CurPhi);
773 
774       for (BasicBlock *IncomingBB : CurPhi->blocks()) {
775         PHINode *IncomingPhi =
776             dyn_cast<PHINode>(CurPhi->getIncomingValueForBlock(IncomingBB));
777         if (!IncomingPhi)
778           continue;
779         bool IsOutsideLoops = !SwitchOuterLoop->contains(IncomingBB);
780         if (SeenValues.contains(IncomingPhi) || IsOutsideLoops)
781           continue;
782 
783         Stack.push_back(IncomingPhi);
784       }
785     }
786 
787     return Res;
788   }
789 
790   unsigned NumVisited = 0;
791   SwitchInst *Switch;
792   BasicBlock *SwitchBlock;
793   OptimizationRemarkEmitter *ORE;
794   std::vector<ThreadingPath> TPaths;
795   LoopInfo *LI;
796   Loop *SwitchOuterLoop;
797 };
798 
799 struct TransformDFA {
800   TransformDFA(AllSwitchPaths *SwitchPaths, DominatorTree *DT,
801                AssumptionCache *AC, TargetTransformInfo *TTI,
802                OptimizationRemarkEmitter *ORE,
803                SmallPtrSet<const Value *, 32> EphValues)
804       : SwitchPaths(SwitchPaths), DT(DT), AC(AC), TTI(TTI), ORE(ORE),
805         EphValues(EphValues) {}
806 
807   void run() {
808     if (isLegalAndProfitableToTransform()) {
809       createAllExitPaths();
810       NumTransforms++;
811     }
812   }
813 
814 private:
815   /// This function performs both a legality check and profitability check at
816   /// the same time since it is convenient to do so. It iterates through all
817   /// blocks that will be cloned, and keeps track of the duplication cost. It
818   /// also returns false if it is illegal to clone some required block.
819   bool isLegalAndProfitableToTransform() {
820     CodeMetrics Metrics;
821     SwitchInst *Switch = SwitchPaths->getSwitchInst();
822 
823     // Don't thread switch without multiple successors.
824     if (Switch->getNumSuccessors() <= 1)
825       return false;
826 
827     // Note that DuplicateBlockMap is not being used as intended here. It is
828     // just being used to ensure (BB, State) pairs are only counted once.
829     DuplicateBlockMap DuplicateMap;
830 
831     for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
832       PathType PathBBs = TPath.getPath();
833       APInt NextState = TPath.getExitValue();
834       const BasicBlock *Determinator = TPath.getDeterminatorBB();
835 
836       // Update Metrics for the Switch block, this is always cloned
837       BasicBlock *BB = SwitchPaths->getSwitchBlock();
838       BasicBlock *VisitedBB = getClonedBB(BB, NextState, DuplicateMap);
839       if (!VisitedBB) {
840         Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
841         DuplicateMap[BB].push_back({BB, NextState});
842       }
843 
844       // If the Switch block is the Determinator, then we can continue since
845       // this is the only block that is cloned and we already counted for it.
846       if (PathBBs.front() == Determinator)
847         continue;
848 
849       // Otherwise update Metrics for all blocks that will be cloned. If any
850       // block is already cloned and would be reused, don't double count it.
851       auto DetIt = llvm::find(PathBBs, Determinator);
852       for (auto BBIt = DetIt; BBIt != PathBBs.end(); BBIt++) {
853         BB = *BBIt;
854         VisitedBB = getClonedBB(BB, NextState, DuplicateMap);
855         if (VisitedBB)
856           continue;
857         Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
858         DuplicateMap[BB].push_back({BB, NextState});
859       }
860 
861       if (Metrics.notDuplicatable) {
862         LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains "
863                           << "non-duplicatable instructions.\n");
864         ORE->emit([&]() {
865           return OptimizationRemarkMissed(DEBUG_TYPE, "NonDuplicatableInst",
866                                           Switch)
867                  << "Contains non-duplicatable instructions.";
868         });
869         return false;
870       }
871 
872       // FIXME: Allow jump threading with controlled convergence.
873       if (Metrics.Convergence != ConvergenceKind::None) {
874         LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains "
875                           << "convergent instructions.\n");
876         ORE->emit([&]() {
877           return OptimizationRemarkMissed(DEBUG_TYPE, "ConvergentInst", Switch)
878                  << "Contains convergent instructions.";
879         });
880         return false;
881       }
882 
883       if (!Metrics.NumInsts.isValid()) {
884         LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains "
885                           << "instructions with invalid cost.\n");
886         ORE->emit([&]() {
887           return OptimizationRemarkMissed(DEBUG_TYPE, "ConvergentInst", Switch)
888                  << "Contains instructions with invalid cost.";
889         });
890         return false;
891       }
892     }
893 
894     InstructionCost DuplicationCost = 0;
895 
896     unsigned JumpTableSize = 0;
897     TTI->getEstimatedNumberOfCaseClusters(*Switch, JumpTableSize, nullptr,
898                                           nullptr);
899     if (JumpTableSize == 0) {
900       // Factor in the number of conditional branches reduced from jump
901       // threading. Assume that lowering the switch block is implemented by
902       // using binary search, hence the LogBase2().
903       unsigned CondBranches =
904           APInt(32, Switch->getNumSuccessors()).ceilLogBase2();
905       assert(CondBranches > 0 &&
906              "The threaded switch must have multiple branches");
907       DuplicationCost = Metrics.NumInsts / CondBranches;
908     } else {
909       // Compared with jump tables, the DFA optimizer removes an indirect branch
910       // on each loop iteration, thus making branch prediction more precise. The
911       // more branch targets there are, the more likely it is for the branch
912       // predictor to make a mistake, and the more benefit there is in the DFA
913       // optimizer. Thus, the more branch targets there are, the lower is the
914       // cost of the DFA opt.
915       DuplicationCost = Metrics.NumInsts / JumpTableSize;
916     }
917 
918     LLVM_DEBUG(dbgs() << "\nDFA Jump Threading: Cost to jump thread block "
919                       << SwitchPaths->getSwitchBlock()->getName()
920                       << " is: " << DuplicationCost << "\n\n");
921 
922     if (DuplicationCost > CostThreshold) {
923       LLVM_DEBUG(dbgs() << "Not jump threading, duplication cost exceeds the "
924                         << "cost threshold.\n");
925       ORE->emit([&]() {
926         return OptimizationRemarkMissed(DEBUG_TYPE, "NotProfitable", Switch)
927                << "Duplication cost exceeds the cost threshold (cost="
928                << ore::NV("Cost", DuplicationCost)
929                << ", threshold=" << ore::NV("Threshold", CostThreshold) << ").";
930       });
931       return false;
932     }
933 
934     ORE->emit([&]() {
935       return OptimizationRemark(DEBUG_TYPE, "JumpThreaded", Switch)
936              << "Switch statement jump-threaded.";
937     });
938 
939     return true;
940   }
941 
942   /// Transform each threading path to effectively jump thread the DFA.
943   void createAllExitPaths() {
944     DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Eager);
945 
946     // Move the switch block to the end of the path, since it will be duplicated
947     BasicBlock *SwitchBlock = SwitchPaths->getSwitchBlock();
948     for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
949       LLVM_DEBUG(dbgs() << TPath << "\n");
950       // TODO: Fix exit path creation logic so that we dont need this
951       // placeholder.
952       TPath.push_front(SwitchBlock);
953     }
954 
955     // Transform the ThreadingPaths and keep track of the cloned values
956     DuplicateBlockMap DuplicateMap;
957     DefMap NewDefs;
958 
959     SmallSet<BasicBlock *, 16> BlocksToClean;
960     for (BasicBlock *BB : successors(SwitchBlock))
961       BlocksToClean.insert(BB);
962 
963     for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
964       createExitPath(NewDefs, TPath, DuplicateMap, BlocksToClean, &DTU);
965       NumPaths++;
966     }
967 
968     // After all paths are cloned, now update the last successor of the cloned
969     // path so it skips over the switch statement
970     for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths())
971       updateLastSuccessor(TPath, DuplicateMap, &DTU);
972 
973     // For each instruction that was cloned and used outside, update its uses
974     updateSSA(NewDefs);
975 
976     // Clean PHI Nodes for the newly created blocks
977     for (BasicBlock *BB : BlocksToClean)
978       cleanPhiNodes(BB);
979   }
980 
981   /// For a specific ThreadingPath \p Path, create an exit path starting from
982   /// the determinator block.
983   ///
984   /// To remember the correct destination, we have to duplicate blocks
985   /// corresponding to each state. Also update the terminating instruction of
986   /// the predecessors, and phis in the successor blocks.
987   void createExitPath(DefMap &NewDefs, ThreadingPath &Path,
988                       DuplicateBlockMap &DuplicateMap,
989                       SmallSet<BasicBlock *, 16> &BlocksToClean,
990                       DomTreeUpdater *DTU) {
991     APInt NextState = Path.getExitValue();
992     const BasicBlock *Determinator = Path.getDeterminatorBB();
993     PathType PathBBs = Path.getPath();
994 
995     // Don't select the placeholder block in front
996     if (PathBBs.front() == Determinator)
997       PathBBs.pop_front();
998 
999     auto DetIt = llvm::find(PathBBs, Determinator);
1000     // When there is only one BB in PathBBs, the determinator takes itself as a
1001     // direct predecessor.
1002     BasicBlock *PrevBB = PathBBs.size() == 1 ? *DetIt : *std::prev(DetIt);
1003     for (auto BBIt = DetIt; BBIt != PathBBs.end(); BBIt++) {
1004       BasicBlock *BB = *BBIt;
1005       BlocksToClean.insert(BB);
1006 
1007       // We already cloned BB for this NextState, now just update the branch
1008       // and continue.
1009       BasicBlock *NextBB = getClonedBB(BB, NextState, DuplicateMap);
1010       if (NextBB) {
1011         updatePredecessor(PrevBB, BB, NextBB, DTU);
1012         PrevBB = NextBB;
1013         continue;
1014       }
1015 
1016       // Clone the BB and update the successor of Prev to jump to the new block
1017       BasicBlock *NewBB = cloneBlockAndUpdatePredecessor(
1018           BB, PrevBB, NextState, DuplicateMap, NewDefs, DTU);
1019       DuplicateMap[BB].push_back({NewBB, NextState});
1020       BlocksToClean.insert(NewBB);
1021       PrevBB = NewBB;
1022     }
1023   }
1024 
1025   /// Restore SSA form after cloning blocks.
1026   ///
1027   /// Each cloned block creates new defs for a variable, and the uses need to be
1028   /// updated to reflect this. The uses may be replaced with a cloned value, or
1029   /// some derived phi instruction. Note that all uses of a value defined in the
1030   /// same block were already remapped when cloning the block.
1031   void updateSSA(DefMap &NewDefs) {
1032     SSAUpdaterBulk SSAUpdate;
1033     SmallVector<Use *, 16> UsesToRename;
1034 
1035     for (const auto &KV : NewDefs) {
1036       Instruction *I = KV.first;
1037       BasicBlock *BB = I->getParent();
1038       std::vector<Instruction *> Cloned = KV.second;
1039 
1040       // Scan all uses of this instruction to see if it is used outside of its
1041       // block, and if so, record them in UsesToRename.
1042       for (Use &U : I->uses()) {
1043         Instruction *User = cast<Instruction>(U.getUser());
1044         if (PHINode *UserPN = dyn_cast<PHINode>(User)) {
1045           if (UserPN->getIncomingBlock(U) == BB)
1046             continue;
1047         } else if (User->getParent() == BB) {
1048           continue;
1049         }
1050 
1051         UsesToRename.push_back(&U);
1052       }
1053 
1054       // If there are no uses outside the block, we're done with this
1055       // instruction.
1056       if (UsesToRename.empty())
1057         continue;
1058       LLVM_DEBUG(dbgs() << "DFA-JT: Renaming non-local uses of: " << *I
1059                         << "\n");
1060 
1061       // We found a use of I outside of BB.  Rename all uses of I that are
1062       // outside its block to be uses of the appropriate PHI node etc.  See
1063       // ValuesInBlocks with the values we know.
1064       unsigned VarNum = SSAUpdate.AddVariable(I->getName(), I->getType());
1065       SSAUpdate.AddAvailableValue(VarNum, BB, I);
1066       for (Instruction *New : Cloned)
1067         SSAUpdate.AddAvailableValue(VarNum, New->getParent(), New);
1068 
1069       while (!UsesToRename.empty())
1070         SSAUpdate.AddUse(VarNum, UsesToRename.pop_back_val());
1071 
1072       LLVM_DEBUG(dbgs() << "\n");
1073     }
1074     // SSAUpdater handles phi placement and renaming uses with the appropriate
1075     // value.
1076     SSAUpdate.RewriteAllUses(DT);
1077   }
1078 
1079   /// Clones a basic block, and adds it to the CFG.
1080   ///
1081   /// This function also includes updating phi nodes in the successors of the
1082   /// BB, and remapping uses that were defined locally in the cloned BB.
1083   BasicBlock *cloneBlockAndUpdatePredecessor(BasicBlock *BB, BasicBlock *PrevBB,
1084                                              const APInt &NextState,
1085                                              DuplicateBlockMap &DuplicateMap,
1086                                              DefMap &NewDefs,
1087                                              DomTreeUpdater *DTU) {
1088     ValueToValueMapTy VMap;
1089     BasicBlock *NewBB = CloneBasicBlock(
1090         BB, VMap, ".jt" + std::to_string(NextState.getLimitedValue()),
1091         BB->getParent());
1092     NewBB->moveAfter(BB);
1093     NumCloned++;
1094 
1095     for (Instruction &I : *NewBB) {
1096       // Do not remap operands of PHINode in case a definition in BB is an
1097       // incoming value to a phi in the same block. This incoming value will
1098       // be renamed later while restoring SSA.
1099       if (isa<PHINode>(&I))
1100         continue;
1101       RemapInstruction(&I, VMap,
1102                        RF_IgnoreMissingLocals | RF_NoModuleLevelChanges);
1103       if (AssumeInst *II = dyn_cast<AssumeInst>(&I))
1104         AC->registerAssumption(II);
1105     }
1106 
1107     updateSuccessorPhis(BB, NewBB, NextState, VMap, DuplicateMap);
1108     updatePredecessor(PrevBB, BB, NewBB, DTU);
1109     updateDefMap(NewDefs, VMap);
1110 
1111     // Add all successors to the DominatorTree
1112     SmallPtrSet<BasicBlock *, 4> SuccSet;
1113     for (auto *SuccBB : successors(NewBB)) {
1114       if (SuccSet.insert(SuccBB).second)
1115         DTU->applyUpdates({{DominatorTree::Insert, NewBB, SuccBB}});
1116     }
1117     SuccSet.clear();
1118     return NewBB;
1119   }
1120 
1121   /// Update the phi nodes in BB's successors.
1122   ///
1123   /// This means creating a new incoming value from NewBB with the new
1124   /// instruction wherever there is an incoming value from BB.
1125   void updateSuccessorPhis(BasicBlock *BB, BasicBlock *ClonedBB,
1126                            const APInt &NextState, ValueToValueMapTy &VMap,
1127                            DuplicateBlockMap &DuplicateMap) {
1128     std::vector<BasicBlock *> BlocksToUpdate;
1129 
1130     // If BB is the last block in the path, we can simply update the one case
1131     // successor that will be reached.
1132     if (BB == SwitchPaths->getSwitchBlock()) {
1133       SwitchInst *Switch = SwitchPaths->getSwitchInst();
1134       BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
1135       BlocksToUpdate.push_back(NextCase);
1136       BasicBlock *ClonedSucc = getClonedBB(NextCase, NextState, DuplicateMap);
1137       if (ClonedSucc)
1138         BlocksToUpdate.push_back(ClonedSucc);
1139     }
1140     // Otherwise update phis in all successors.
1141     else {
1142       for (BasicBlock *Succ : successors(BB)) {
1143         BlocksToUpdate.push_back(Succ);
1144 
1145         // Check if a successor has already been cloned for the particular exit
1146         // value. In this case if a successor was already cloned, the phi nodes
1147         // in the cloned block should be updated directly.
1148         BasicBlock *ClonedSucc = getClonedBB(Succ, NextState, DuplicateMap);
1149         if (ClonedSucc)
1150           BlocksToUpdate.push_back(ClonedSucc);
1151       }
1152     }
1153 
1154     // If there is a phi with an incoming value from BB, create a new incoming
1155     // value for the new predecessor ClonedBB. The value will either be the same
1156     // value from BB or a cloned value.
1157     for (BasicBlock *Succ : BlocksToUpdate) {
1158       for (auto II = Succ->begin(); PHINode *Phi = dyn_cast<PHINode>(II);
1159            ++II) {
1160         Value *Incoming = Phi->getIncomingValueForBlock(BB);
1161         if (Incoming) {
1162           if (isa<Constant>(Incoming)) {
1163             Phi->addIncoming(Incoming, ClonedBB);
1164             continue;
1165           }
1166           Value *ClonedVal = VMap[Incoming];
1167           if (ClonedVal)
1168             Phi->addIncoming(ClonedVal, ClonedBB);
1169           else
1170             Phi->addIncoming(Incoming, ClonedBB);
1171         }
1172       }
1173     }
1174   }
1175 
1176   /// Sets the successor of PrevBB to be NewBB instead of OldBB. Note that all
1177   /// other successors are kept as well.
1178   void updatePredecessor(BasicBlock *PrevBB, BasicBlock *OldBB,
1179                          BasicBlock *NewBB, DomTreeUpdater *DTU) {
1180     // When a path is reused, there is a chance that predecessors were already
1181     // updated before. Check if the predecessor needs to be updated first.
1182     if (!isPredecessor(OldBB, PrevBB))
1183       return;
1184 
1185     Instruction *PrevTerm = PrevBB->getTerminator();
1186     for (unsigned Idx = 0; Idx < PrevTerm->getNumSuccessors(); Idx++) {
1187       if (PrevTerm->getSuccessor(Idx) == OldBB) {
1188         OldBB->removePredecessor(PrevBB, /* KeepOneInputPHIs = */ true);
1189         PrevTerm->setSuccessor(Idx, NewBB);
1190       }
1191     }
1192     DTU->applyUpdates({{DominatorTree::Delete, PrevBB, OldBB},
1193                        {DominatorTree::Insert, PrevBB, NewBB}});
1194   }
1195 
1196   /// Add new value mappings to the DefMap to keep track of all new definitions
1197   /// for a particular instruction. These will be used while updating SSA form.
1198   void updateDefMap(DefMap &NewDefs, ValueToValueMapTy &VMap) {
1199     SmallVector<std::pair<Instruction *, Instruction *>> NewDefsVector;
1200     NewDefsVector.reserve(VMap.size());
1201 
1202     for (auto Entry : VMap) {
1203       Instruction *Inst =
1204           dyn_cast<Instruction>(const_cast<Value *>(Entry.first));
1205       if (!Inst || !Entry.second || isa<BranchInst>(Inst) ||
1206           isa<SwitchInst>(Inst)) {
1207         continue;
1208       }
1209 
1210       Instruction *Cloned = dyn_cast<Instruction>(Entry.second);
1211       if (!Cloned)
1212         continue;
1213 
1214       NewDefsVector.push_back({Inst, Cloned});
1215     }
1216 
1217     // Sort the defs to get deterministic insertion order into NewDefs.
1218     sort(NewDefsVector, [](const auto &LHS, const auto &RHS) {
1219       if (LHS.first == RHS.first)
1220         return LHS.second->comesBefore(RHS.second);
1221       return LHS.first->comesBefore(RHS.first);
1222     });
1223 
1224     for (const auto &KV : NewDefsVector)
1225       NewDefs[KV.first].push_back(KV.second);
1226   }
1227 
1228   /// Update the last branch of a particular cloned path to point to the correct
1229   /// case successor.
1230   ///
1231   /// Note that this is an optional step and would have been done in later
1232   /// optimizations, but it makes the CFG significantly easier to work with.
1233   void updateLastSuccessor(ThreadingPath &TPath,
1234                            DuplicateBlockMap &DuplicateMap,
1235                            DomTreeUpdater *DTU) {
1236     APInt NextState = TPath.getExitValue();
1237     BasicBlock *BB = TPath.getPath().back();
1238     BasicBlock *LastBlock = getClonedBB(BB, NextState, DuplicateMap);
1239 
1240     // Note multiple paths can end at the same block so check that it is not
1241     // updated yet
1242     if (!isa<SwitchInst>(LastBlock->getTerminator()))
1243       return;
1244     SwitchInst *Switch = cast<SwitchInst>(LastBlock->getTerminator());
1245     BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
1246 
1247     std::vector<DominatorTree::UpdateType> DTUpdates;
1248     SmallPtrSet<BasicBlock *, 4> SuccSet;
1249     for (BasicBlock *Succ : successors(LastBlock)) {
1250       if (Succ != NextCase && SuccSet.insert(Succ).second)
1251         DTUpdates.push_back({DominatorTree::Delete, LastBlock, Succ});
1252     }
1253 
1254     Switch->eraseFromParent();
1255     BranchInst::Create(NextCase, LastBlock);
1256 
1257     DTU->applyUpdates(DTUpdates);
1258   }
1259 
1260   /// After cloning blocks, some of the phi nodes have extra incoming values
1261   /// that are no longer used. This function removes them.
1262   void cleanPhiNodes(BasicBlock *BB) {
1263     // If BB is no longer reachable, remove any remaining phi nodes
1264     if (pred_empty(BB)) {
1265       std::vector<PHINode *> PhiToRemove;
1266       for (auto II = BB->begin(); PHINode *Phi = dyn_cast<PHINode>(II); ++II) {
1267         PhiToRemove.push_back(Phi);
1268       }
1269       for (PHINode *PN : PhiToRemove) {
1270         PN->replaceAllUsesWith(PoisonValue::get(PN->getType()));
1271         PN->eraseFromParent();
1272       }
1273       return;
1274     }
1275 
1276     // Remove any incoming values that come from an invalid predecessor
1277     for (auto II = BB->begin(); PHINode *Phi = dyn_cast<PHINode>(II); ++II) {
1278       std::vector<BasicBlock *> BlocksToRemove;
1279       for (BasicBlock *IncomingBB : Phi->blocks()) {
1280         if (!isPredecessor(BB, IncomingBB))
1281           BlocksToRemove.push_back(IncomingBB);
1282       }
1283       for (BasicBlock *BB : BlocksToRemove)
1284         Phi->removeIncomingValue(BB);
1285     }
1286   }
1287 
1288   /// Checks if BB was already cloned for a particular next state value. If it
1289   /// was then it returns this cloned block, and otherwise null.
1290   BasicBlock *getClonedBB(BasicBlock *BB, const APInt &NextState,
1291                           DuplicateBlockMap &DuplicateMap) {
1292     CloneList ClonedBBs = DuplicateMap[BB];
1293 
1294     // Find an entry in the CloneList with this NextState. If it exists then
1295     // return the corresponding BB
1296     auto It = llvm::find_if(ClonedBBs, [NextState](const ClonedBlock &C) {
1297       return C.State == NextState;
1298     });
1299     return It != ClonedBBs.end() ? (*It).BB : nullptr;
1300   }
1301 
1302   /// Helper to get the successor corresponding to a particular case value for
1303   /// a switch statement.
1304   BasicBlock *getNextCaseSuccessor(SwitchInst *Switch, const APInt &NextState) {
1305     BasicBlock *NextCase = nullptr;
1306     for (auto Case : Switch->cases()) {
1307       if (Case.getCaseValue()->getValue() == NextState) {
1308         NextCase = Case.getCaseSuccessor();
1309         break;
1310       }
1311     }
1312     if (!NextCase)
1313       NextCase = Switch->getDefaultDest();
1314     return NextCase;
1315   }
1316 
1317   /// Returns true if IncomingBB is a predecessor of BB.
1318   bool isPredecessor(BasicBlock *BB, BasicBlock *IncomingBB) {
1319     return llvm::is_contained(predecessors(BB), IncomingBB);
1320   }
1321 
1322   AllSwitchPaths *SwitchPaths;
1323   DominatorTree *DT;
1324   AssumptionCache *AC;
1325   TargetTransformInfo *TTI;
1326   OptimizationRemarkEmitter *ORE;
1327   SmallPtrSet<const Value *, 32> EphValues;
1328   std::vector<ThreadingPath> TPaths;
1329 };
1330 
1331 bool DFAJumpThreading::run(Function &F) {
1332   LLVM_DEBUG(dbgs() << "\nDFA Jump threading: " << F.getName() << "\n");
1333 
1334   if (F.hasOptSize()) {
1335     LLVM_DEBUG(dbgs() << "Skipping due to the 'minsize' attribute\n");
1336     return false;
1337   }
1338 
1339   if (ClViewCfgBefore)
1340     F.viewCFG();
1341 
1342   SmallVector<AllSwitchPaths, 2> ThreadableLoops;
1343   bool MadeChanges = false;
1344   LoopInfoBroken = false;
1345 
1346   for (BasicBlock &BB : F) {
1347     auto *SI = dyn_cast<SwitchInst>(BB.getTerminator());
1348     if (!SI)
1349       continue;
1350 
1351     LLVM_DEBUG(dbgs() << "\nCheck if SwitchInst in BB " << BB.getName()
1352                       << " is a candidate\n");
1353     MainSwitch Switch(SI, LI, ORE);
1354 
1355     if (!Switch.getInstr()) {
1356       LLVM_DEBUG(dbgs() << "\nSwitchInst in BB " << BB.getName() << " is not a "
1357                         << "candidate for jump threading\n");
1358       continue;
1359     }
1360 
1361     LLVM_DEBUG(dbgs() << "\nSwitchInst in BB " << BB.getName() << " is a "
1362                       << "candidate for jump threading\n");
1363     LLVM_DEBUG(SI->dump());
1364 
1365     unfoldSelectInstrs(DT, Switch.getSelectInsts());
1366     if (!Switch.getSelectInsts().empty())
1367       MadeChanges = true;
1368 
1369     AllSwitchPaths SwitchPaths(&Switch, ORE, LI,
1370                                LI->getLoopFor(&BB)->getOutermostLoop());
1371     SwitchPaths.run();
1372 
1373     if (SwitchPaths.getNumThreadingPaths() > 0) {
1374       ThreadableLoops.push_back(SwitchPaths);
1375 
1376       // For the time being limit this optimization to occurring once in a
1377       // function since it can change the CFG significantly. This is not a
1378       // strict requirement but it can cause buggy behavior if there is an
1379       // overlap of blocks in different opportunities. There is a lot of room to
1380       // experiment with catching more opportunities here.
1381       // NOTE: To release this contraint, we must handle LoopInfo invalidation
1382       break;
1383     }
1384   }
1385 
1386 #ifdef NDEBUG
1387   LI->verify(*DT);
1388 #endif
1389 
1390   SmallPtrSet<const Value *, 32> EphValues;
1391   if (ThreadableLoops.size() > 0)
1392     CodeMetrics::collectEphemeralValues(&F, AC, EphValues);
1393 
1394   for (AllSwitchPaths SwitchPaths : ThreadableLoops) {
1395     TransformDFA Transform(&SwitchPaths, DT, AC, TTI, ORE, EphValues);
1396     Transform.run();
1397     MadeChanges = true;
1398     LoopInfoBroken = true;
1399   }
1400 
1401 #ifdef EXPENSIVE_CHECKS
1402   assert(DT->verify(DominatorTree::VerificationLevel::Full));
1403   verifyFunction(F, &dbgs());
1404 #endif
1405 
1406   return MadeChanges;
1407 }
1408 
1409 } // end anonymous namespace
1410 
1411 /// Integrate with the new Pass Manager
1412 PreservedAnalyses DFAJumpThreadingPass::run(Function &F,
1413                                             FunctionAnalysisManager &AM) {
1414   AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
1415   DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
1416   LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
1417   TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
1418   OptimizationRemarkEmitter ORE(&F);
1419   DFAJumpThreading ThreadImpl(&AC, &DT, &LI, &TTI, &ORE);
1420   if (!ThreadImpl.run(F))
1421     return PreservedAnalyses::all();
1422 
1423   PreservedAnalyses PA;
1424   PA.preserve<DominatorTreeAnalysis>();
1425   if (!ThreadImpl.LoopInfoBroken)
1426     PA.preserve<LoopAnalysis>();
1427   return PA;
1428 }
1429