xref: /freebsd-src/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1*0fca6ea1SDimitry Andric //===- JumpTableToSwitch.cpp ----------------------------------------------===//
2*0fca6ea1SDimitry Andric //
3*0fca6ea1SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*0fca6ea1SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*0fca6ea1SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*0fca6ea1SDimitry Andric //
7*0fca6ea1SDimitry Andric //===----------------------------------------------------------------------===//
8*0fca6ea1SDimitry Andric 
9*0fca6ea1SDimitry Andric #include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
10*0fca6ea1SDimitry Andric #include "llvm/ADT/SmallVector.h"
11*0fca6ea1SDimitry Andric #include "llvm/Analysis/ConstantFolding.h"
12*0fca6ea1SDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h"
13*0fca6ea1SDimitry Andric #include "llvm/Analysis/OptimizationRemarkEmitter.h"
14*0fca6ea1SDimitry Andric #include "llvm/Analysis/PostDominators.h"
15*0fca6ea1SDimitry Andric #include "llvm/IR/IRBuilder.h"
16*0fca6ea1SDimitry Andric #include "llvm/Support/CommandLine.h"
17*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h"
18*0fca6ea1SDimitry Andric 
19*0fca6ea1SDimitry Andric using namespace llvm;
20*0fca6ea1SDimitry Andric 
21*0fca6ea1SDimitry Andric static cl::opt<unsigned>
22*0fca6ea1SDimitry Andric     JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
23*0fca6ea1SDimitry Andric                            cl::desc("Only split jump tables with size less or "
24*0fca6ea1SDimitry Andric                                     "equal than JumpTableSizeThreshold."),
25*0fca6ea1SDimitry Andric                            cl::init(10));
26*0fca6ea1SDimitry Andric 
27*0fca6ea1SDimitry Andric // TODO: Consider adding a cost model for profitability analysis of this
28*0fca6ea1SDimitry Andric // transformation. Currently we replace a jump table with a switch if all the
29*0fca6ea1SDimitry Andric // functions in the jump table are smaller than the provided threshold.
30*0fca6ea1SDimitry Andric static cl::opt<unsigned> FunctionSizeThreshold(
31*0fca6ea1SDimitry Andric     "jump-table-to-switch-function-size-threshold", cl::Hidden,
32*0fca6ea1SDimitry Andric     cl::desc("Only split jump tables containing functions whose sizes are less "
33*0fca6ea1SDimitry Andric              "or equal than this threshold."),
34*0fca6ea1SDimitry Andric     cl::init(50));
35*0fca6ea1SDimitry Andric 
36*0fca6ea1SDimitry Andric #define DEBUG_TYPE "jump-table-to-switch"
37*0fca6ea1SDimitry Andric 
38*0fca6ea1SDimitry Andric namespace {
39*0fca6ea1SDimitry Andric struct JumpTableTy {
40*0fca6ea1SDimitry Andric   Value *Index;
41*0fca6ea1SDimitry Andric   SmallVector<Function *, 10> Funcs;
42*0fca6ea1SDimitry Andric };
43*0fca6ea1SDimitry Andric } // anonymous namespace
44*0fca6ea1SDimitry Andric 
45*0fca6ea1SDimitry Andric static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
46*0fca6ea1SDimitry Andric                                                  PointerType *PtrTy) {
47*0fca6ea1SDimitry Andric   Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand());
48*0fca6ea1SDimitry Andric   if (!Ptr)
49*0fca6ea1SDimitry Andric     return std::nullopt;
50*0fca6ea1SDimitry Andric 
51*0fca6ea1SDimitry Andric   GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr);
52*0fca6ea1SDimitry Andric   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
53*0fca6ea1SDimitry Andric     return std::nullopt;
54*0fca6ea1SDimitry Andric 
55*0fca6ea1SDimitry Andric   Function &F = *GEP->getParent()->getParent();
56*0fca6ea1SDimitry Andric   const DataLayout &DL = F.getDataLayout();
57*0fca6ea1SDimitry Andric   const unsigned BitWidth =
58*0fca6ea1SDimitry Andric       DL.getIndexSizeInBits(GEP->getPointerAddressSpace());
59*0fca6ea1SDimitry Andric   MapVector<Value *, APInt> VariableOffsets;
60*0fca6ea1SDimitry Andric   APInt ConstantOffset(BitWidth, 0);
61*0fca6ea1SDimitry Andric   if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
62*0fca6ea1SDimitry Andric     return std::nullopt;
63*0fca6ea1SDimitry Andric   if (VariableOffsets.size() != 1)
64*0fca6ea1SDimitry Andric     return std::nullopt;
65*0fca6ea1SDimitry Andric   // TODO: consider supporting more general patterns
66*0fca6ea1SDimitry Andric   if (!ConstantOffset.isZero())
67*0fca6ea1SDimitry Andric     return std::nullopt;
68*0fca6ea1SDimitry Andric   APInt StrideBytes = VariableOffsets.front().second;
69*0fca6ea1SDimitry Andric   const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType());
70*0fca6ea1SDimitry Andric   if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)
71*0fca6ea1SDimitry Andric     return std::nullopt;
72*0fca6ea1SDimitry Andric   const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();
73*0fca6ea1SDimitry Andric   if (N > JumpTableSizeThreshold)
74*0fca6ea1SDimitry Andric     return std::nullopt;
75*0fca6ea1SDimitry Andric 
76*0fca6ea1SDimitry Andric   JumpTableTy JumpTable;
77*0fca6ea1SDimitry Andric   JumpTable.Index = VariableOffsets.front().first;
78*0fca6ea1SDimitry Andric   JumpTable.Funcs.reserve(N);
79*0fca6ea1SDimitry Andric   for (uint64_t Index = 0; Index < N; ++Index) {
80*0fca6ea1SDimitry Andric     // ConstantOffset is zero.
81*0fca6ea1SDimitry Andric     APInt Offset = Index * StrideBytes;
82*0fca6ea1SDimitry Andric     Constant *C =
83*0fca6ea1SDimitry Andric         ConstantFoldLoadFromConst(GV->getInitializer(), PtrTy, Offset, DL);
84*0fca6ea1SDimitry Andric     auto *Func = dyn_cast_or_null<Function>(C);
85*0fca6ea1SDimitry Andric     if (!Func || Func->isDeclaration() ||
86*0fca6ea1SDimitry Andric         Func->getInstructionCount() > FunctionSizeThreshold)
87*0fca6ea1SDimitry Andric       return std::nullopt;
88*0fca6ea1SDimitry Andric     JumpTable.Funcs.push_back(Func);
89*0fca6ea1SDimitry Andric   }
90*0fca6ea1SDimitry Andric   return JumpTable;
91*0fca6ea1SDimitry Andric }
92*0fca6ea1SDimitry Andric 
93*0fca6ea1SDimitry Andric static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
94*0fca6ea1SDimitry Andric                                   DomTreeUpdater &DTU,
95*0fca6ea1SDimitry Andric                                   OptimizationRemarkEmitter &ORE) {
96*0fca6ea1SDimitry Andric   const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
97*0fca6ea1SDimitry Andric 
98*0fca6ea1SDimitry Andric   SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
99*0fca6ea1SDimitry Andric   BasicBlock *BB = CB->getParent();
100*0fca6ea1SDimitry Andric   BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr,
101*0fca6ea1SDimitry Andric                                 BB->getName() + Twine(".tail"));
102*0fca6ea1SDimitry Andric   DTUpdates.push_back({DominatorTree::Delete, BB, Tail});
103*0fca6ea1SDimitry Andric   BB->getTerminator()->eraseFromParent();
104*0fca6ea1SDimitry Andric 
105*0fca6ea1SDimitry Andric   Function &F = *BB->getParent();
106*0fca6ea1SDimitry Andric   BasicBlock *BBUnreachable = BasicBlock::Create(
107*0fca6ea1SDimitry Andric       F.getContext(), "default.switch.case.unreachable", &F, Tail);
108*0fca6ea1SDimitry Andric   IRBuilder<> BuilderUnreachable(BBUnreachable);
109*0fca6ea1SDimitry Andric   BuilderUnreachable.CreateUnreachable();
110*0fca6ea1SDimitry Andric 
111*0fca6ea1SDimitry Andric   IRBuilder<> Builder(BB);
112*0fca6ea1SDimitry Andric   SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable);
113*0fca6ea1SDimitry Andric   DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable});
114*0fca6ea1SDimitry Andric 
115*0fca6ea1SDimitry Andric   IRBuilder<> BuilderTail(CB);
116*0fca6ea1SDimitry Andric   PHINode *PHI =
117*0fca6ea1SDimitry Andric       IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
118*0fca6ea1SDimitry Andric 
119*0fca6ea1SDimitry Andric   for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
120*0fca6ea1SDimitry Andric     BasicBlock *B = BasicBlock::Create(Func->getContext(),
121*0fca6ea1SDimitry Andric                                        "call." + Twine(Index), &F, Tail);
122*0fca6ea1SDimitry Andric     DTUpdates.push_back({DominatorTree::Insert, BB, B});
123*0fca6ea1SDimitry Andric     DTUpdates.push_back({DominatorTree::Insert, B, Tail});
124*0fca6ea1SDimitry Andric 
125*0fca6ea1SDimitry Andric     CallBase *Call = cast<CallBase>(CB->clone());
126*0fca6ea1SDimitry Andric     Call->setCalledFunction(Func);
127*0fca6ea1SDimitry Andric     Call->insertInto(B, B->end());
128*0fca6ea1SDimitry Andric     Switch->addCase(
129*0fca6ea1SDimitry Andric         cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
130*0fca6ea1SDimitry Andric     BranchInst::Create(Tail, B);
131*0fca6ea1SDimitry Andric     if (PHI)
132*0fca6ea1SDimitry Andric       PHI->addIncoming(Call, B);
133*0fca6ea1SDimitry Andric   }
134*0fca6ea1SDimitry Andric   DTU.applyUpdates(DTUpdates);
135*0fca6ea1SDimitry Andric   ORE.emit([&]() {
136*0fca6ea1SDimitry Andric     return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
137*0fca6ea1SDimitry Andric            << "expanded indirect call into switch";
138*0fca6ea1SDimitry Andric   });
139*0fca6ea1SDimitry Andric   if (PHI)
140*0fca6ea1SDimitry Andric     CB->replaceAllUsesWith(PHI);
141*0fca6ea1SDimitry Andric   CB->eraseFromParent();
142*0fca6ea1SDimitry Andric   return Tail;
143*0fca6ea1SDimitry Andric }
144*0fca6ea1SDimitry Andric 
145*0fca6ea1SDimitry Andric PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
146*0fca6ea1SDimitry Andric                                              FunctionAnalysisManager &AM) {
147*0fca6ea1SDimitry Andric   OptimizationRemarkEmitter &ORE =
148*0fca6ea1SDimitry Andric       AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
149*0fca6ea1SDimitry Andric   DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
150*0fca6ea1SDimitry Andric   PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
151*0fca6ea1SDimitry Andric   DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152*0fca6ea1SDimitry Andric   bool Changed = false;
153*0fca6ea1SDimitry Andric   for (BasicBlock &BB : make_early_inc_range(F)) {
154*0fca6ea1SDimitry Andric     BasicBlock *CurrentBB = &BB;
155*0fca6ea1SDimitry Andric     while (CurrentBB) {
156*0fca6ea1SDimitry Andric       BasicBlock *SplittedOutTail = nullptr;
157*0fca6ea1SDimitry Andric       for (Instruction &I : make_early_inc_range(*CurrentBB)) {
158*0fca6ea1SDimitry Andric         auto *Call = dyn_cast<CallInst>(&I);
159*0fca6ea1SDimitry Andric         if (!Call || Call->getCalledFunction() || Call->isMustTailCall())
160*0fca6ea1SDimitry Andric           continue;
161*0fca6ea1SDimitry Andric         auto *L = dyn_cast<LoadInst>(Call->getCalledOperand());
162*0fca6ea1SDimitry Andric         // Skip atomic or volatile loads.
163*0fca6ea1SDimitry Andric         if (!L || !L->isSimple())
164*0fca6ea1SDimitry Andric           continue;
165*0fca6ea1SDimitry Andric         auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand());
166*0fca6ea1SDimitry Andric         if (!GEP)
167*0fca6ea1SDimitry Andric           continue;
168*0fca6ea1SDimitry Andric         auto *PtrTy = dyn_cast<PointerType>(L->getType());
169*0fca6ea1SDimitry Andric         assert(PtrTy && "call operand must be a pointer");
170*0fca6ea1SDimitry Andric         std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
171*0fca6ea1SDimitry Andric         if (!JumpTable)
172*0fca6ea1SDimitry Andric           continue;
173*0fca6ea1SDimitry Andric         SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);
174*0fca6ea1SDimitry Andric         Changed = true;
175*0fca6ea1SDimitry Andric         break;
176*0fca6ea1SDimitry Andric       }
177*0fca6ea1SDimitry Andric       CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr;
178*0fca6ea1SDimitry Andric     }
179*0fca6ea1SDimitry Andric   }
180*0fca6ea1SDimitry Andric 
181*0fca6ea1SDimitry Andric   if (!Changed)
182*0fca6ea1SDimitry Andric     return PreservedAnalyses::all();
183*0fca6ea1SDimitry Andric 
184*0fca6ea1SDimitry Andric   PreservedAnalyses PA;
185*0fca6ea1SDimitry Andric   if (DT)
186*0fca6ea1SDimitry Andric     PA.preserve<DominatorTreeAnalysis>();
187*0fca6ea1SDimitry Andric   if (PDT)
188*0fca6ea1SDimitry Andric     PA.preserve<PostDominatorTreeAnalysis>();
189*0fca6ea1SDimitry Andric   return PA;
190*0fca6ea1SDimitry Andric }
191