xref: /freebsd-src/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
10b57cec5SDimitry Andric //===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===//
20b57cec5SDimitry Andric //
30b57cec5SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40b57cec5SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
50b57cec5SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60b57cec5SDimitry Andric //
70b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
80b57cec5SDimitry Andric //
90b57cec5SDimitry Andric // The LowerSwitch transformation rewrites switch instructions with a sequence
100b57cec5SDimitry Andric // of branches, which allows targets to get away with not implementing the
110b57cec5SDimitry Andric // switch instruction until it is convenient.
120b57cec5SDimitry Andric //
130b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
140b57cec5SDimitry Andric 
15e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/LowerSwitch.h"
160b57cec5SDimitry Andric #include "llvm/ADT/DenseMap.h"
170b57cec5SDimitry Andric #include "llvm/ADT/STLExtras.h"
180b57cec5SDimitry Andric #include "llvm/ADT/SmallPtrSet.h"
190b57cec5SDimitry Andric #include "llvm/ADT/SmallVector.h"
200b57cec5SDimitry Andric #include "llvm/Analysis/AssumptionCache.h"
210b57cec5SDimitry Andric #include "llvm/Analysis/LazyValueInfo.h"
220b57cec5SDimitry Andric #include "llvm/Analysis/ValueTracking.h"
230b57cec5SDimitry Andric #include "llvm/IR/BasicBlock.h"
240b57cec5SDimitry Andric #include "llvm/IR/CFG.h"
250b57cec5SDimitry Andric #include "llvm/IR/ConstantRange.h"
260b57cec5SDimitry Andric #include "llvm/IR/Constants.h"
270b57cec5SDimitry Andric #include "llvm/IR/Function.h"
280b57cec5SDimitry Andric #include "llvm/IR/InstrTypes.h"
290b57cec5SDimitry Andric #include "llvm/IR/Instructions.h"
30e8d8bef9SDimitry Andric #include "llvm/IR/PassManager.h"
310b57cec5SDimitry Andric #include "llvm/IR/Value.h"
32480093f4SDimitry Andric #include "llvm/InitializePasses.h"
330b57cec5SDimitry Andric #include "llvm/Pass.h"
340b57cec5SDimitry Andric #include "llvm/Support/Casting.h"
350b57cec5SDimitry Andric #include "llvm/Support/Compiler.h"
360b57cec5SDimitry Andric #include "llvm/Support/Debug.h"
370b57cec5SDimitry Andric #include "llvm/Support/KnownBits.h"
380b57cec5SDimitry Andric #include "llvm/Support/raw_ostream.h"
390b57cec5SDimitry Andric #include "llvm/Transforms/Utils.h"
400b57cec5SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h"
410b57cec5SDimitry Andric #include <algorithm>
420b57cec5SDimitry Andric #include <cassert>
430b57cec5SDimitry Andric #include <cstdint>
440b57cec5SDimitry Andric #include <iterator>
450b57cec5SDimitry Andric #include <vector>
460b57cec5SDimitry Andric 
470b57cec5SDimitry Andric using namespace llvm;
480b57cec5SDimitry Andric 
490b57cec5SDimitry Andric #define DEBUG_TYPE "lower-switch"
500b57cec5SDimitry Andric 
510b57cec5SDimitry Andric namespace {
520b57cec5SDimitry Andric 
530b57cec5SDimitry Andric struct IntRange {
54bdd1243dSDimitry Andric   APInt Low, High;
550b57cec5SDimitry Andric };
560b57cec5SDimitry Andric 
570b57cec5SDimitry Andric } // end anonymous namespace
580b57cec5SDimitry Andric 
59e8d8bef9SDimitry Andric namespace {
600b57cec5SDimitry Andric // Return true iff R is covered by Ranges.
61e8d8bef9SDimitry Andric bool IsInRanges(const IntRange &R, const std::vector<IntRange> &Ranges) {
620b57cec5SDimitry Andric   // Note: Ranges must be sorted, non-overlapping and non-adjacent.
630b57cec5SDimitry Andric 
640b57cec5SDimitry Andric   // Find the first range whose High field is >= R.High,
650b57cec5SDimitry Andric   // then check if the Low field is <= R.Low. If so, we
660b57cec5SDimitry Andric   // have a Range that covers R.
670b57cec5SDimitry Andric   auto I = llvm::lower_bound(
68bdd1243dSDimitry Andric       Ranges, R, [](IntRange A, IntRange B) { return A.High.slt(B.High); });
69bdd1243dSDimitry Andric   return I != Ranges.end() && I->Low.sle(R.Low);
700b57cec5SDimitry Andric }
710b57cec5SDimitry Andric 
720b57cec5SDimitry Andric struct CaseRange {
730b57cec5SDimitry Andric   ConstantInt *Low;
740b57cec5SDimitry Andric   ConstantInt *High;
750b57cec5SDimitry Andric   BasicBlock *BB;
760b57cec5SDimitry Andric 
770b57cec5SDimitry Andric   CaseRange(ConstantInt *low, ConstantInt *high, BasicBlock *bb)
780b57cec5SDimitry Andric       : Low(low), High(high), BB(bb) {}
790b57cec5SDimitry Andric };
800b57cec5SDimitry Andric 
810b57cec5SDimitry Andric using CaseVector = std::vector<CaseRange>;
820b57cec5SDimitry Andric using CaseItr = std::vector<CaseRange>::iterator;
830b57cec5SDimitry Andric 
840b57cec5SDimitry Andric /// The comparison function for sorting the switch case values in the vector.
850b57cec5SDimitry Andric /// WARNING: Case ranges should be disjoint!
860b57cec5SDimitry Andric struct CaseCmp {
87e8d8bef9SDimitry Andric   bool operator()(const CaseRange &C1, const CaseRange &C2) {
880b57cec5SDimitry Andric     const ConstantInt *CI1 = cast<const ConstantInt>(C1.Low);
890b57cec5SDimitry Andric     const ConstantInt *CI2 = cast<const ConstantInt>(C2.High);
900b57cec5SDimitry Andric     return CI1->getValue().slt(CI2->getValue());
910b57cec5SDimitry Andric   }
920b57cec5SDimitry Andric };
930b57cec5SDimitry Andric 
940b57cec5SDimitry Andric /// Used for debugging purposes.
950b57cec5SDimitry Andric LLVM_ATTRIBUTE_USED
96e8d8bef9SDimitry Andric raw_ostream &operator<<(raw_ostream &O, const CaseVector &C) {
970b57cec5SDimitry Andric   O << "[";
980b57cec5SDimitry Andric 
99e8d8bef9SDimitry Andric   for (CaseVector::const_iterator B = C.begin(), E = C.end(); B != E;) {
1000b57cec5SDimitry Andric     O << "[" << B->Low->getValue() << ", " << B->High->getValue() << "]";
1010b57cec5SDimitry Andric     if (++B != E)
1020b57cec5SDimitry Andric       O << ", ";
1030b57cec5SDimitry Andric   }
1040b57cec5SDimitry Andric 
1050b57cec5SDimitry Andric   return O << "]";
1060b57cec5SDimitry Andric }
1070b57cec5SDimitry Andric 
1080b57cec5SDimitry Andric /// Update the first occurrence of the "switch statement" BB in the PHI
1090b57cec5SDimitry Andric /// node with the "new" BB. The other occurrences will:
1100b57cec5SDimitry Andric ///
1110b57cec5SDimitry Andric /// 1) Be updated by subsequent calls to this function.  Switch statements may
1120b57cec5SDimitry Andric /// have more than one outcoming edge into the same BB if they all have the same
1130b57cec5SDimitry Andric /// value. When the switch statement is converted these incoming edges are now
1140b57cec5SDimitry Andric /// coming from multiple BBs.
1150b57cec5SDimitry Andric /// 2) Removed if subsequent incoming values now share the same case, i.e.,
1160b57cec5SDimitry Andric /// multiple outcome edges are condensed into one. This is necessary to keep the
1170b57cec5SDimitry Andric /// number of phi values equal to the number of branches to SuccBB.
118bdd1243dSDimitry Andric void FixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB,
119bdd1243dSDimitry Andric              const APInt &NumMergedCases) {
12081ad6265SDimitry Andric   for (auto &I : SuccBB->phis()) {
12181ad6265SDimitry Andric     PHINode *PN = cast<PHINode>(&I);
1220b57cec5SDimitry Andric 
12381ad6265SDimitry Andric     // Only update the first occurrence if NewBB exists.
1240b57cec5SDimitry Andric     unsigned Idx = 0, E = PN->getNumIncomingValues();
125bdd1243dSDimitry Andric     APInt LocalNumMergedCases = NumMergedCases;
12681ad6265SDimitry Andric     for (; Idx != E && NewBB; ++Idx) {
1270b57cec5SDimitry Andric       if (PN->getIncomingBlock(Idx) == OrigBB) {
1280b57cec5SDimitry Andric         PN->setIncomingBlock(Idx, NewBB);
1290b57cec5SDimitry Andric         break;
1300b57cec5SDimitry Andric       }
1310b57cec5SDimitry Andric     }
1320b57cec5SDimitry Andric 
13381ad6265SDimitry Andric     // Skip the updated incoming block so that it will not be removed.
13481ad6265SDimitry Andric     if (NewBB)
13581ad6265SDimitry Andric       ++Idx;
13681ad6265SDimitry Andric 
1370b57cec5SDimitry Andric     // Remove additional occurrences coming from condensed cases and keep the
1380b57cec5SDimitry Andric     // number of incoming values equal to the number of branches to SuccBB.
1390b57cec5SDimitry Andric     SmallVector<unsigned, 8> Indices;
140bdd1243dSDimitry Andric     for (; LocalNumMergedCases.ugt(0) && Idx < E; ++Idx)
1410b57cec5SDimitry Andric       if (PN->getIncomingBlock(Idx) == OrigBB) {
1420b57cec5SDimitry Andric         Indices.push_back(Idx);
143bdd1243dSDimitry Andric         LocalNumMergedCases -= 1;
1440b57cec5SDimitry Andric       }
1450b57cec5SDimitry Andric     // Remove incoming values in the reverse order to prevent invalidating
1460b57cec5SDimitry Andric     // *successive* index.
1470b57cec5SDimitry Andric     for (unsigned III : llvm::reverse(Indices))
1480b57cec5SDimitry Andric       PN->removeIncomingValue(III);
1490b57cec5SDimitry Andric   }
1500b57cec5SDimitry Andric }
1510b57cec5SDimitry Andric 
152e8d8bef9SDimitry Andric /// Create a new leaf block for the binary lookup tree. It checks if the
153e8d8bef9SDimitry Andric /// switch's value == the case's value. If not, then it jumps to the default
154e8d8bef9SDimitry Andric /// branch. At this point in the tree, the value can't be another valid case
155e8d8bef9SDimitry Andric /// value, so the jump to the "default" branch is warranted.
156e8d8bef9SDimitry Andric BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound,
157e8d8bef9SDimitry Andric                          ConstantInt *UpperBound, BasicBlock *OrigBlock,
158e8d8bef9SDimitry Andric                          BasicBlock *Default) {
159e8d8bef9SDimitry Andric   Function *F = OrigBlock->getParent();
160e8d8bef9SDimitry Andric   BasicBlock *NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock");
161bdd1243dSDimitry Andric   F->insert(++OrigBlock->getIterator(), NewLeaf);
162e8d8bef9SDimitry Andric 
163e8d8bef9SDimitry Andric   // Emit comparison
164e8d8bef9SDimitry Andric   ICmpInst *Comp = nullptr;
165e8d8bef9SDimitry Andric   if (Leaf.Low == Leaf.High) {
166e8d8bef9SDimitry Andric     // Make the seteq instruction...
167e8d8bef9SDimitry Andric     Comp =
168*0fca6ea1SDimitry Andric         new ICmpInst(NewLeaf, ICmpInst::ICMP_EQ, Val, Leaf.Low, "SwitchLeaf");
169e8d8bef9SDimitry Andric   } else {
170e8d8bef9SDimitry Andric     // Make range comparison
171e8d8bef9SDimitry Andric     if (Leaf.Low == LowerBound) {
172e8d8bef9SDimitry Andric       // Val >= Min && Val <= Hi --> Val <= Hi
173*0fca6ea1SDimitry Andric       Comp = new ICmpInst(NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High,
174e8d8bef9SDimitry Andric                           "SwitchLeaf");
175e8d8bef9SDimitry Andric     } else if (Leaf.High == UpperBound) {
176e8d8bef9SDimitry Andric       // Val <= Max && Val >= Lo --> Val >= Lo
177*0fca6ea1SDimitry Andric       Comp = new ICmpInst(NewLeaf, ICmpInst::ICMP_SGE, Val, Leaf.Low,
178e8d8bef9SDimitry Andric                           "SwitchLeaf");
179e8d8bef9SDimitry Andric     } else if (Leaf.Low->isZero()) {
180e8d8bef9SDimitry Andric       // Val >= 0 && Val <= Hi --> Val <=u Hi
181*0fca6ea1SDimitry Andric       Comp = new ICmpInst(NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High,
182e8d8bef9SDimitry Andric                           "SwitchLeaf");
183e8d8bef9SDimitry Andric     } else {
184e8d8bef9SDimitry Andric       // Emit V-Lo <=u Hi-Lo
185e8d8bef9SDimitry Andric       Constant *NegLo = ConstantExpr::getNeg(Leaf.Low);
186e8d8bef9SDimitry Andric       Instruction *Add = BinaryOperator::CreateAdd(
187e8d8bef9SDimitry Andric           Val, NegLo, Val->getName() + ".off", NewLeaf);
188e8d8bef9SDimitry Andric       Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High);
189*0fca6ea1SDimitry Andric       Comp = new ICmpInst(NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound,
190e8d8bef9SDimitry Andric                           "SwitchLeaf");
191e8d8bef9SDimitry Andric     }
192e8d8bef9SDimitry Andric   }
193e8d8bef9SDimitry Andric 
194e8d8bef9SDimitry Andric   // Make the conditional branch...
195e8d8bef9SDimitry Andric   BasicBlock *Succ = Leaf.BB;
196e8d8bef9SDimitry Andric   BranchInst::Create(Succ, Default, Comp, NewLeaf);
197e8d8bef9SDimitry Andric 
19881ad6265SDimitry Andric   // Update the PHI incoming value/block for the default.
19981ad6265SDimitry Andric   for (auto &I : Default->phis()) {
20081ad6265SDimitry Andric     PHINode *PN = cast<PHINode>(&I);
20181ad6265SDimitry Andric     auto *V = PN->getIncomingValueForBlock(OrigBlock);
20281ad6265SDimitry Andric     PN->addIncoming(V, NewLeaf);
20381ad6265SDimitry Andric   }
20481ad6265SDimitry Andric 
205e8d8bef9SDimitry Andric   // If there were any PHI nodes in this successor, rewrite one entry
206e8d8bef9SDimitry Andric   // from OrigBlock to come from NewLeaf.
207e8d8bef9SDimitry Andric   for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) {
208e8d8bef9SDimitry Andric     PHINode *PN = cast<PHINode>(I);
209e8d8bef9SDimitry Andric     // Remove all but one incoming entries from the cluster
210bdd1243dSDimitry Andric     APInt Range = Leaf.High->getValue() - Leaf.Low->getValue();
211*0fca6ea1SDimitry Andric     for (APInt j(Range.getBitWidth(), 0, false); j.ult(Range); ++j) {
212e8d8bef9SDimitry Andric       PN->removeIncomingValue(OrigBlock);
213e8d8bef9SDimitry Andric     }
214e8d8bef9SDimitry Andric 
215e8d8bef9SDimitry Andric     int BlockIdx = PN->getBasicBlockIndex(OrigBlock);
216e8d8bef9SDimitry Andric     assert(BlockIdx != -1 && "Switch didn't go to this successor??");
217e8d8bef9SDimitry Andric     PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf);
218e8d8bef9SDimitry Andric   }
219e8d8bef9SDimitry Andric 
220e8d8bef9SDimitry Andric   return NewLeaf;
221e8d8bef9SDimitry Andric }
222e8d8bef9SDimitry Andric 
2230b57cec5SDimitry Andric /// Convert the switch statement into a binary lookup of the case values.
2240b57cec5SDimitry Andric /// The function recursively builds this tree. LowerBound and UpperBound are
2250b57cec5SDimitry Andric /// used to keep track of the bounds for Val that have already been checked by
2260b57cec5SDimitry Andric /// a block emitted by one of the previous calls to switchConvert in the call
2270b57cec5SDimitry Andric /// stack.
228e8d8bef9SDimitry Andric BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound,
2290b57cec5SDimitry Andric                           ConstantInt *UpperBound, Value *Val,
2300b57cec5SDimitry Andric                           BasicBlock *Predecessor, BasicBlock *OrigBlock,
2310b57cec5SDimitry Andric                           BasicBlock *Default,
2320b57cec5SDimitry Andric                           const std::vector<IntRange> &UnreachableRanges) {
2330b57cec5SDimitry Andric   assert(LowerBound && UpperBound && "Bounds must be initialized");
2340b57cec5SDimitry Andric   unsigned Size = End - Begin;
2350b57cec5SDimitry Andric 
2360b57cec5SDimitry Andric   if (Size == 1) {
2370b57cec5SDimitry Andric     // Check if the Case Range is perfectly squeezed in between
2380b57cec5SDimitry Andric     // already checked Upper and Lower bounds. If it is then we can avoid
2390b57cec5SDimitry Andric     // emitting the code that checks if the value actually falls in the range
2400b57cec5SDimitry Andric     // because the bounds already tell us so.
2410b57cec5SDimitry Andric     if (Begin->Low == LowerBound && Begin->High == UpperBound) {
242bdd1243dSDimitry Andric       APInt NumMergedCases = UpperBound->getValue() - LowerBound->getValue();
243e8d8bef9SDimitry Andric       FixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases);
2440b57cec5SDimitry Andric       return Begin->BB;
2450b57cec5SDimitry Andric     }
246e8d8bef9SDimitry Andric     return NewLeafBlock(*Begin, Val, LowerBound, UpperBound, OrigBlock,
2470b57cec5SDimitry Andric                         Default);
2480b57cec5SDimitry Andric   }
2490b57cec5SDimitry Andric 
2500b57cec5SDimitry Andric   unsigned Mid = Size / 2;
2510b57cec5SDimitry Andric   std::vector<CaseRange> LHS(Begin, Begin + Mid);
2520b57cec5SDimitry Andric   LLVM_DEBUG(dbgs() << "LHS: " << LHS << "\n");
2530b57cec5SDimitry Andric   std::vector<CaseRange> RHS(Begin + Mid, End);
2540b57cec5SDimitry Andric   LLVM_DEBUG(dbgs() << "RHS: " << RHS << "\n");
2550b57cec5SDimitry Andric 
2560b57cec5SDimitry Andric   CaseRange &Pivot = *(Begin + Mid);
2570b57cec5SDimitry Andric   LLVM_DEBUG(dbgs() << "Pivot ==> [" << Pivot.Low->getValue() << ", "
2580b57cec5SDimitry Andric                     << Pivot.High->getValue() << "]\n");
2590b57cec5SDimitry Andric 
2600b57cec5SDimitry Andric   // NewLowerBound here should never be the integer minimal value.
2610b57cec5SDimitry Andric   // This is because it is computed from a case range that is never
2620b57cec5SDimitry Andric   // the smallest, so there is always a case range that has at least
2630b57cec5SDimitry Andric   // a smaller value.
2640b57cec5SDimitry Andric   ConstantInt *NewLowerBound = Pivot.Low;
2650b57cec5SDimitry Andric 
2660b57cec5SDimitry Andric   // Because NewLowerBound is never the smallest representable integer
2670b57cec5SDimitry Andric   // it is safe here to subtract one.
2680b57cec5SDimitry Andric   ConstantInt *NewUpperBound = ConstantInt::get(NewLowerBound->getContext(),
2690b57cec5SDimitry Andric                                                 NewLowerBound->getValue() - 1);
2700b57cec5SDimitry Andric 
2710b57cec5SDimitry Andric   if (!UnreachableRanges.empty()) {
2720b57cec5SDimitry Andric     // Check if the gap between LHS's highest and NewLowerBound is unreachable.
273bdd1243dSDimitry Andric     APInt GapLow = LHS.back().High->getValue() + 1;
274bdd1243dSDimitry Andric     APInt GapHigh = NewLowerBound->getValue() - 1;
2750b57cec5SDimitry Andric     IntRange Gap = {GapLow, GapHigh};
276bdd1243dSDimitry Andric     if (GapHigh.sge(GapLow) && IsInRanges(Gap, UnreachableRanges))
2770b57cec5SDimitry Andric       NewUpperBound = LHS.back().High;
2780b57cec5SDimitry Andric   }
2790b57cec5SDimitry Andric 
280bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getValue() << ", "
281bdd1243dSDimitry Andric                     << NewUpperBound->getValue() << "]\n"
282bdd1243dSDimitry Andric                     << "RHS Bounds ==> [" << NewLowerBound->getValue() << ", "
283bdd1243dSDimitry Andric                     << UpperBound->getValue() << "]\n");
2840b57cec5SDimitry Andric 
2850b57cec5SDimitry Andric   // Create a new node that checks if the value is < pivot. Go to the
2860b57cec5SDimitry Andric   // left branch if it is and right branch if not.
2870b57cec5SDimitry Andric   Function *F = OrigBlock->getParent();
2880b57cec5SDimitry Andric   BasicBlock *NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock");
2890b57cec5SDimitry Andric 
290bdd1243dSDimitry Andric   ICmpInst *Comp = new ICmpInst(ICmpInst::ICMP_SLT, Val, Pivot.Low, "Pivot");
2910b57cec5SDimitry Andric 
292e8d8bef9SDimitry Andric   BasicBlock *LBranch =
293e8d8bef9SDimitry Andric       SwitchConvert(LHS.begin(), LHS.end(), LowerBound, NewUpperBound, Val,
294e8d8bef9SDimitry Andric                     NewNode, OrigBlock, Default, UnreachableRanges);
295e8d8bef9SDimitry Andric   BasicBlock *RBranch =
296e8d8bef9SDimitry Andric       SwitchConvert(RHS.begin(), RHS.end(), NewLowerBound, UpperBound, Val,
297e8d8bef9SDimitry Andric                     NewNode, OrigBlock, Default, UnreachableRanges);
2980b57cec5SDimitry Andric 
299bdd1243dSDimitry Andric   F->insert(++OrigBlock->getIterator(), NewNode);
300bdd1243dSDimitry Andric   Comp->insertInto(NewNode, NewNode->end());
3010b57cec5SDimitry Andric 
3020b57cec5SDimitry Andric   BranchInst::Create(LBranch, RBranch, Comp, NewNode);
3030b57cec5SDimitry Andric   return NewNode;
3040b57cec5SDimitry Andric }
3050b57cec5SDimitry Andric 
3060b57cec5SDimitry Andric /// Transform simple list of \p SI's cases into list of CaseRange's \p Cases.
3070b57cec5SDimitry Andric /// \post \p Cases wouldn't contain references to \p SI's default BB.
3080b57cec5SDimitry Andric /// \returns Number of \p SI's cases that do not reference \p SI's default BB.
309e8d8bef9SDimitry Andric unsigned Clusterify(CaseVector &Cases, SwitchInst *SI) {
3100b57cec5SDimitry Andric   unsigned NumSimpleCases = 0;
3110b57cec5SDimitry Andric 
3120b57cec5SDimitry Andric   // Start with "simple" cases
3130b57cec5SDimitry Andric   for (auto Case : SI->cases()) {
3140b57cec5SDimitry Andric     if (Case.getCaseSuccessor() == SI->getDefaultDest())
3150b57cec5SDimitry Andric       continue;
3160b57cec5SDimitry Andric     Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(),
3170b57cec5SDimitry Andric                               Case.getCaseSuccessor()));
3180b57cec5SDimitry Andric     ++NumSimpleCases;
3190b57cec5SDimitry Andric   }
3200b57cec5SDimitry Andric 
3210b57cec5SDimitry Andric   llvm::sort(Cases, CaseCmp());
3220b57cec5SDimitry Andric 
3230b57cec5SDimitry Andric   // Merge case into clusters
3240b57cec5SDimitry Andric   if (Cases.size() >= 2) {
3250b57cec5SDimitry Andric     CaseItr I = Cases.begin();
3260b57cec5SDimitry Andric     for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) {
327bdd1243dSDimitry Andric       const APInt &nextValue = J->Low->getValue();
328bdd1243dSDimitry Andric       const APInt &currentValue = I->High->getValue();
3290b57cec5SDimitry Andric       BasicBlock *nextBB = J->BB;
3300b57cec5SDimitry Andric       BasicBlock *currentBB = I->BB;
3310b57cec5SDimitry Andric 
3320b57cec5SDimitry Andric       // If the two neighboring cases go to the same destination, merge them
3330b57cec5SDimitry Andric       // into a single case.
334bdd1243dSDimitry Andric       assert(nextValue.sgt(currentValue) &&
335bdd1243dSDimitry Andric              "Cases should be strictly ascending");
3360b57cec5SDimitry Andric       if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
3370b57cec5SDimitry Andric         I->High = J->High;
3380b57cec5SDimitry Andric         // FIXME: Combine branch weights.
3390b57cec5SDimitry Andric       } else if (++I != J) {
3400b57cec5SDimitry Andric         *I = *J;
3410b57cec5SDimitry Andric       }
3420b57cec5SDimitry Andric     }
3430b57cec5SDimitry Andric     Cases.erase(std::next(I), Cases.end());
3440b57cec5SDimitry Andric   }
3450b57cec5SDimitry Andric 
3460b57cec5SDimitry Andric   return NumSimpleCases;
3470b57cec5SDimitry Andric }
3480b57cec5SDimitry Andric 
3490b57cec5SDimitry Andric /// Replace the specified switch instruction with a sequence of chained if-then
3500b57cec5SDimitry Andric /// insts in a balanced binary search.
351e8d8bef9SDimitry Andric void ProcessSwitchInst(SwitchInst *SI,
3520b57cec5SDimitry Andric                        SmallPtrSetImpl<BasicBlock *> &DeleteList,
3530b57cec5SDimitry Andric                        AssumptionCache *AC, LazyValueInfo *LVI) {
3540b57cec5SDimitry Andric   BasicBlock *OrigBlock = SI->getParent();
3550b57cec5SDimitry Andric   Function *F = OrigBlock->getParent();
3560b57cec5SDimitry Andric   Value *Val = SI->getCondition(); // The value we are switching on...
3570b57cec5SDimitry Andric   BasicBlock *Default = SI->getDefaultDest();
3580b57cec5SDimitry Andric 
3590b57cec5SDimitry Andric   // Don't handle unreachable blocks. If there are successors with phis, this
3600b57cec5SDimitry Andric   // would leave them behind with missing predecessors.
3610b57cec5SDimitry Andric   if ((OrigBlock != &F->getEntryBlock() && pred_empty(OrigBlock)) ||
3620b57cec5SDimitry Andric       OrigBlock->getSinglePredecessor() == OrigBlock) {
3630b57cec5SDimitry Andric     DeleteList.insert(OrigBlock);
3640b57cec5SDimitry Andric     return;
3650b57cec5SDimitry Andric   }
3660b57cec5SDimitry Andric 
3670b57cec5SDimitry Andric   // Prepare cases vector.
3680b57cec5SDimitry Andric   CaseVector Cases;
3690b57cec5SDimitry Andric   const unsigned NumSimpleCases = Clusterify(Cases, SI);
370bdd1243dSDimitry Andric   IntegerType *IT = cast<IntegerType>(SI->getCondition()->getType());
371bdd1243dSDimitry Andric   const unsigned BitWidth = IT->getBitWidth();
372*0fca6ea1SDimitry Andric   // Explicitly use higher precision to prevent unsigned overflow where
373bdd1243dSDimitry Andric   // `UnsignedMax - 0 + 1 == 0`
374bdd1243dSDimitry Andric   APInt UnsignedZero(BitWidth + 1, 0);
375bdd1243dSDimitry Andric   APInt UnsignedMax = APInt::getMaxValue(BitWidth);
3760b57cec5SDimitry Andric   LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size()
3770b57cec5SDimitry Andric                     << ". Total non-default cases: " << NumSimpleCases
3780b57cec5SDimitry Andric                     << "\nCase clusters: " << Cases << "\n");
3790b57cec5SDimitry Andric 
3800b57cec5SDimitry Andric   // If there is only the default destination, just branch.
3810b57cec5SDimitry Andric   if (Cases.empty()) {
3820b57cec5SDimitry Andric     BranchInst::Create(Default, OrigBlock);
3830b57cec5SDimitry Andric     // Remove all the references from Default's PHIs to OrigBlock, but one.
384bdd1243dSDimitry Andric     FixPhis(Default, OrigBlock, OrigBlock, UnsignedMax);
3850b57cec5SDimitry Andric     SI->eraseFromParent();
3860b57cec5SDimitry Andric     return;
3870b57cec5SDimitry Andric   }
3880b57cec5SDimitry Andric 
3890b57cec5SDimitry Andric   ConstantInt *LowerBound = nullptr;
3900b57cec5SDimitry Andric   ConstantInt *UpperBound = nullptr;
3910b57cec5SDimitry Andric   bool DefaultIsUnreachableFromSwitch = false;
3920b57cec5SDimitry Andric 
3930b57cec5SDimitry Andric   if (isa<UnreachableInst>(Default->getFirstNonPHIOrDbg())) {
3940b57cec5SDimitry Andric     // Make the bounds tightly fitted around the case value range, because we
3950b57cec5SDimitry Andric     // know that the value passed to the switch must be exactly one of the case
3960b57cec5SDimitry Andric     // values.
3970b57cec5SDimitry Andric     LowerBound = Cases.front().Low;
3980b57cec5SDimitry Andric     UpperBound = Cases.back().High;
3990b57cec5SDimitry Andric     DefaultIsUnreachableFromSwitch = true;
4000b57cec5SDimitry Andric   } else {
4010b57cec5SDimitry Andric     // Constraining the range of the value being switched over helps eliminating
4020b57cec5SDimitry Andric     // unreachable BBs and minimizing the number of `add` instructions
4030b57cec5SDimitry Andric     // newLeafBlock ends up emitting. Running CorrelatedValuePropagation after
4040b57cec5SDimitry Andric     // LowerSwitch isn't as good, and also much more expensive in terms of
4050b57cec5SDimitry Andric     // compile time for the following reasons:
4060b57cec5SDimitry Andric     // 1. it processes many kinds of instructions, not just switches;
4070b57cec5SDimitry Andric     // 2. even if limited to icmp instructions only, it will have to process
4080b57cec5SDimitry Andric     //    roughly C icmp's per switch, where C is the number of cases in the
4090b57cec5SDimitry Andric     //    switch, while LowerSwitch only needs to call LVI once per switch.
410*0fca6ea1SDimitry Andric     const DataLayout &DL = F->getDataLayout();
4110b57cec5SDimitry Andric     KnownBits Known = computeKnownBits(Val, DL, /*Depth=*/0, AC, SI);
4120b57cec5SDimitry Andric     // TODO Shouldn't this create a signed range?
4130b57cec5SDimitry Andric     ConstantRange KnownBitsRange =
4140b57cec5SDimitry Andric         ConstantRange::fromKnownBits(Known, /*IsSigned=*/false);
4155f757f3fSDimitry Andric     const ConstantRange LVIRange =
4165f757f3fSDimitry Andric         LVI->getConstantRange(Val, SI, /*UndefAllowed*/ false);
4170b57cec5SDimitry Andric     ConstantRange ValRange = KnownBitsRange.intersectWith(LVIRange);
4180b57cec5SDimitry Andric     // We delegate removal of unreachable non-default cases to other passes. In
4190b57cec5SDimitry Andric     // the unlikely event that some of them survived, we just conservatively
4200b57cec5SDimitry Andric     // maintain the invariant that all the cases lie between the bounds. This
4210b57cec5SDimitry Andric     // may, however, still render the default case effectively unreachable.
422bdd1243dSDimitry Andric     const APInt &Low = Cases.front().Low->getValue();
423bdd1243dSDimitry Andric     const APInt &High = Cases.back().High->getValue();
4240b57cec5SDimitry Andric     APInt Min = APIntOps::smin(ValRange.getSignedMin(), Low);
4250b57cec5SDimitry Andric     APInt Max = APIntOps::smax(ValRange.getSignedMax(), High);
4260b57cec5SDimitry Andric 
4270b57cec5SDimitry Andric     LowerBound = ConstantInt::get(SI->getContext(), Min);
4280b57cec5SDimitry Andric     UpperBound = ConstantInt::get(SI->getContext(), Max);
4290b57cec5SDimitry Andric     DefaultIsUnreachableFromSwitch = (Min + (NumSimpleCases - 1) == Max);
4300b57cec5SDimitry Andric   }
4310b57cec5SDimitry Andric 
4320b57cec5SDimitry Andric   std::vector<IntRange> UnreachableRanges;
4330b57cec5SDimitry Andric 
4340b57cec5SDimitry Andric   if (DefaultIsUnreachableFromSwitch) {
435bdd1243dSDimitry Andric     DenseMap<BasicBlock *, APInt> Popularity;
436bdd1243dSDimitry Andric     APInt MaxPop(UnsignedZero);
4370b57cec5SDimitry Andric     BasicBlock *PopSucc = nullptr;
4380b57cec5SDimitry Andric 
439bdd1243dSDimitry Andric     APInt SignedMax = APInt::getSignedMaxValue(BitWidth);
440bdd1243dSDimitry Andric     APInt SignedMin = APInt::getSignedMinValue(BitWidth);
441bdd1243dSDimitry Andric     IntRange R = {SignedMin, SignedMax};
4420b57cec5SDimitry Andric     UnreachableRanges.push_back(R);
4430b57cec5SDimitry Andric     for (const auto &I : Cases) {
444bdd1243dSDimitry Andric       const APInt &Low = I.Low->getValue();
445bdd1243dSDimitry Andric       const APInt &High = I.High->getValue();
4460b57cec5SDimitry Andric 
4470b57cec5SDimitry Andric       IntRange &LastRange = UnreachableRanges.back();
448bdd1243dSDimitry Andric       if (LastRange.Low.eq(Low)) {
4490b57cec5SDimitry Andric         // There is nothing left of the previous range.
4500b57cec5SDimitry Andric         UnreachableRanges.pop_back();
4510b57cec5SDimitry Andric       } else {
4520b57cec5SDimitry Andric         // Terminate the previous range.
453bdd1243dSDimitry Andric         assert(Low.sgt(LastRange.Low));
4540b57cec5SDimitry Andric         LastRange.High = Low - 1;
4550b57cec5SDimitry Andric       }
456bdd1243dSDimitry Andric       if (High.ne(SignedMax)) {
457bdd1243dSDimitry Andric         IntRange R = {High + 1, SignedMax};
4580b57cec5SDimitry Andric         UnreachableRanges.push_back(R);
4590b57cec5SDimitry Andric       }
4600b57cec5SDimitry Andric 
4610b57cec5SDimitry Andric       // Count popularity.
462bdd1243dSDimitry Andric       assert(High.sge(Low) && "Popularity shouldn't be negative.");
463bdd1243dSDimitry Andric       APInt N = High.sext(BitWidth + 1) - Low.sext(BitWidth + 1) + 1;
464bdd1243dSDimitry Andric       // Explict insert to make sure the bitwidth of APInts match
465bdd1243dSDimitry Andric       APInt &Pop = Popularity.insert({I.BB, APInt(UnsignedZero)}).first->second;
466bdd1243dSDimitry Andric       if ((Pop += N).ugt(MaxPop)) {
4670b57cec5SDimitry Andric         MaxPop = Pop;
4680b57cec5SDimitry Andric         PopSucc = I.BB;
4690b57cec5SDimitry Andric       }
4700b57cec5SDimitry Andric     }
4710b57cec5SDimitry Andric #ifndef NDEBUG
4720b57cec5SDimitry Andric     /* UnreachableRanges should be sorted and the ranges non-adjacent. */
4730b57cec5SDimitry Andric     for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end();
4740b57cec5SDimitry Andric          I != E; ++I) {
475bdd1243dSDimitry Andric       assert(I->Low.sle(I->High));
4760b57cec5SDimitry Andric       auto Next = I + 1;
4770b57cec5SDimitry Andric       if (Next != E) {
478bdd1243dSDimitry Andric         assert(Next->Low.sgt(I->High));
4790b57cec5SDimitry Andric       }
4800b57cec5SDimitry Andric     }
4810b57cec5SDimitry Andric #endif
4820b57cec5SDimitry Andric 
4830b57cec5SDimitry Andric     // As the default block in the switch is unreachable, update the PHI nodes
4840b57cec5SDimitry Andric     // (remove all of the references to the default block) to reflect this.
4850b57cec5SDimitry Andric     const unsigned NumDefaultEdges = SI->getNumCases() + 1 - NumSimpleCases;
4860b57cec5SDimitry Andric     for (unsigned I = 0; I < NumDefaultEdges; ++I)
4870b57cec5SDimitry Andric       Default->removePredecessor(OrigBlock);
4880b57cec5SDimitry Andric 
4890b57cec5SDimitry Andric     // Use the most popular block as the new default, reducing the number of
4900b57cec5SDimitry Andric     // cases.
4910b57cec5SDimitry Andric     Default = PopSucc;
492e8d8bef9SDimitry Andric     llvm::erase_if(Cases,
493e8d8bef9SDimitry Andric                    [PopSucc](const CaseRange &R) { return R.BB == PopSucc; });
4940b57cec5SDimitry Andric 
4950b57cec5SDimitry Andric     // If there are no cases left, just branch.
4960b57cec5SDimitry Andric     if (Cases.empty()) {
4970b57cec5SDimitry Andric       BranchInst::Create(Default, OrigBlock);
4980b57cec5SDimitry Andric       SI->eraseFromParent();
4990b57cec5SDimitry Andric       // As all the cases have been replaced with a single branch, only keep
5000b57cec5SDimitry Andric       // one entry in the PHI nodes.
501bdd1243dSDimitry Andric       if (!MaxPop.isZero())
502bdd1243dSDimitry Andric         for (APInt I(UnsignedZero); I.ult(MaxPop - 1); ++I)
5030b57cec5SDimitry Andric           PopSucc->removePredecessor(OrigBlock);
5040b57cec5SDimitry Andric       return;
5050b57cec5SDimitry Andric     }
5060b57cec5SDimitry Andric 
5070b57cec5SDimitry Andric     // If the condition was a PHI node with the switch block as a predecessor
5080b57cec5SDimitry Andric     // removing predecessors may have caused the condition to be erased.
5090b57cec5SDimitry Andric     // Getting the condition value again here protects against that.
5100b57cec5SDimitry Andric     Val = SI->getCondition();
5110b57cec5SDimitry Andric   }
5120b57cec5SDimitry Andric 
5130b57cec5SDimitry Andric   BasicBlock *SwitchBlock =
514e8d8bef9SDimitry Andric       SwitchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val,
51581ad6265SDimitry Andric                     OrigBlock, OrigBlock, Default, UnreachableRanges);
5160b57cec5SDimitry Andric 
51781ad6265SDimitry Andric   // We have added incoming values for newly-created predecessors in
51881ad6265SDimitry Andric   // NewLeafBlock(). The only meaningful work we offload to FixPhis() is to
51981ad6265SDimitry Andric   // remove the incoming values from OrigBlock. There might be a special case
52081ad6265SDimitry Andric   // that SwitchBlock is the same as Default, under which the PHIs in Default
52181ad6265SDimitry Andric   // are fixed inside SwitchConvert().
52281ad6265SDimitry Andric   if (SwitchBlock != Default)
523bdd1243dSDimitry Andric     FixPhis(Default, OrigBlock, nullptr, UnsignedMax);
5240b57cec5SDimitry Andric 
5250b57cec5SDimitry Andric   // Branch to our shiny new if-then stuff...
5260b57cec5SDimitry Andric   BranchInst::Create(SwitchBlock, OrigBlock);
5270b57cec5SDimitry Andric 
5280b57cec5SDimitry Andric   // We are now done with the switch instruction, delete it.
5290b57cec5SDimitry Andric   BasicBlock *OldDefault = SI->getDefaultDest();
530bdd1243dSDimitry Andric   SI->eraseFromParent();
5310b57cec5SDimitry Andric 
5320b57cec5SDimitry Andric   // If the Default block has no more predecessors just add it to DeleteList.
533e8d8bef9SDimitry Andric   if (pred_empty(OldDefault))
5340b57cec5SDimitry Andric     DeleteList.insert(OldDefault);
5350b57cec5SDimitry Andric }
536e8d8bef9SDimitry Andric 
537e8d8bef9SDimitry Andric bool LowerSwitch(Function &F, LazyValueInfo *LVI, AssumptionCache *AC) {
538e8d8bef9SDimitry Andric   bool Changed = false;
539e8d8bef9SDimitry Andric   SmallPtrSet<BasicBlock *, 8> DeleteList;
540e8d8bef9SDimitry Andric 
541349cc55cSDimitry Andric   // We use make_early_inc_range here so that we don't traverse new blocks.
542349cc55cSDimitry Andric   for (BasicBlock &Cur : llvm::make_early_inc_range(F)) {
543e8d8bef9SDimitry Andric     // If the block is a dead Default block that will be deleted later, don't
544e8d8bef9SDimitry Andric     // waste time processing it.
545349cc55cSDimitry Andric     if (DeleteList.count(&Cur))
546e8d8bef9SDimitry Andric       continue;
547e8d8bef9SDimitry Andric 
548349cc55cSDimitry Andric     if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur.getTerminator())) {
549e8d8bef9SDimitry Andric       Changed = true;
550e8d8bef9SDimitry Andric       ProcessSwitchInst(SI, DeleteList, AC, LVI);
551e8d8bef9SDimitry Andric     }
552e8d8bef9SDimitry Andric   }
553e8d8bef9SDimitry Andric 
554e8d8bef9SDimitry Andric   for (BasicBlock *BB : DeleteList) {
555e8d8bef9SDimitry Andric     LVI->eraseBlock(BB);
556e8d8bef9SDimitry Andric     DeleteDeadBlock(BB);
557e8d8bef9SDimitry Andric   }
558e8d8bef9SDimitry Andric 
559e8d8bef9SDimitry Andric   return Changed;
560e8d8bef9SDimitry Andric }
561e8d8bef9SDimitry Andric 
562e8d8bef9SDimitry Andric /// Replace all SwitchInst instructions with chained branch instructions.
563e8d8bef9SDimitry Andric class LowerSwitchLegacyPass : public FunctionPass {
564e8d8bef9SDimitry Andric public:
565e8d8bef9SDimitry Andric   // Pass identification, replacement for typeid
566e8d8bef9SDimitry Andric   static char ID;
567e8d8bef9SDimitry Andric 
568e8d8bef9SDimitry Andric   LowerSwitchLegacyPass() : FunctionPass(ID) {
569e8d8bef9SDimitry Andric     initializeLowerSwitchLegacyPassPass(*PassRegistry::getPassRegistry());
570e8d8bef9SDimitry Andric   }
571e8d8bef9SDimitry Andric 
572e8d8bef9SDimitry Andric   bool runOnFunction(Function &F) override;
573e8d8bef9SDimitry Andric 
574e8d8bef9SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
575e8d8bef9SDimitry Andric     AU.addRequired<LazyValueInfoWrapperPass>();
576e8d8bef9SDimitry Andric   }
577e8d8bef9SDimitry Andric };
578e8d8bef9SDimitry Andric 
579e8d8bef9SDimitry Andric } // end anonymous namespace
580e8d8bef9SDimitry Andric 
581e8d8bef9SDimitry Andric char LowerSwitchLegacyPass::ID = 0;
582e8d8bef9SDimitry Andric 
583e8d8bef9SDimitry Andric // Publicly exposed interface to pass...
584e8d8bef9SDimitry Andric char &llvm::LowerSwitchID = LowerSwitchLegacyPass::ID;
585e8d8bef9SDimitry Andric 
586e8d8bef9SDimitry Andric INITIALIZE_PASS_BEGIN(LowerSwitchLegacyPass, "lowerswitch",
587e8d8bef9SDimitry Andric                       "Lower SwitchInst's to branches", false, false)
588e8d8bef9SDimitry Andric INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
589e8d8bef9SDimitry Andric INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
590e8d8bef9SDimitry Andric INITIALIZE_PASS_END(LowerSwitchLegacyPass, "lowerswitch",
591e8d8bef9SDimitry Andric                     "Lower SwitchInst's to branches", false, false)
592e8d8bef9SDimitry Andric 
593e8d8bef9SDimitry Andric // createLowerSwitchPass - Interface to this file...
594e8d8bef9SDimitry Andric FunctionPass *llvm::createLowerSwitchPass() {
595e8d8bef9SDimitry Andric   return new LowerSwitchLegacyPass();
596e8d8bef9SDimitry Andric }
597e8d8bef9SDimitry Andric 
598e8d8bef9SDimitry Andric bool LowerSwitchLegacyPass::runOnFunction(Function &F) {
599e8d8bef9SDimitry Andric   LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
600e8d8bef9SDimitry Andric   auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>();
601e8d8bef9SDimitry Andric   AssumptionCache *AC = ACT ? &ACT->getAssumptionCache(F) : nullptr;
602e8d8bef9SDimitry Andric   return LowerSwitch(F, LVI, AC);
603e8d8bef9SDimitry Andric }
604e8d8bef9SDimitry Andric 
605e8d8bef9SDimitry Andric PreservedAnalyses LowerSwitchPass::run(Function &F,
606e8d8bef9SDimitry Andric                                        FunctionAnalysisManager &AM) {
607e8d8bef9SDimitry Andric   LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F);
608e8d8bef9SDimitry Andric   AssumptionCache *AC = AM.getCachedResult<AssumptionAnalysis>(F);
609e8d8bef9SDimitry Andric   return LowerSwitch(F, LVI, AC) ? PreservedAnalyses::none()
610e8d8bef9SDimitry Andric                                  : PreservedAnalyses::all();
611e8d8bef9SDimitry Andric }
612