xref: /llvm-project/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp (revision 96f37ae45310885e09195be09d9c05e1c1dff86b)
1d26b43ffSAlexander Shaposhnikov //===- JumpTableToSwitch.cpp ----------------------------------------------===//
2d26b43ffSAlexander Shaposhnikov //
3d26b43ffSAlexander Shaposhnikov // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d26b43ffSAlexander Shaposhnikov // See https://llvm.org/LICENSE.txt for license information.
5d26b43ffSAlexander Shaposhnikov // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d26b43ffSAlexander Shaposhnikov //
7d26b43ffSAlexander Shaposhnikov //===----------------------------------------------------------------------===//
8d26b43ffSAlexander Shaposhnikov 
9d26b43ffSAlexander Shaposhnikov #include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
10d26b43ffSAlexander Shaposhnikov #include "llvm/ADT/SmallVector.h"
11d26b43ffSAlexander Shaposhnikov #include "llvm/Analysis/ConstantFolding.h"
12d26b43ffSAlexander Shaposhnikov #include "llvm/Analysis/DomTreeUpdater.h"
13d26b43ffSAlexander Shaposhnikov #include "llvm/Analysis/OptimizationRemarkEmitter.h"
14d26b43ffSAlexander Shaposhnikov #include "llvm/Analysis/PostDominators.h"
15d26b43ffSAlexander Shaposhnikov #include "llvm/IR/IRBuilder.h"
16d26b43ffSAlexander Shaposhnikov #include "llvm/Support/CommandLine.h"
17d26b43ffSAlexander Shaposhnikov #include "llvm/Transforms/Utils/BasicBlockUtils.h"
18d26b43ffSAlexander Shaposhnikov 
19d26b43ffSAlexander Shaposhnikov using namespace llvm;
20d26b43ffSAlexander Shaposhnikov 
21d26b43ffSAlexander Shaposhnikov static cl::opt<unsigned>
22d26b43ffSAlexander Shaposhnikov     JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
23d26b43ffSAlexander Shaposhnikov                            cl::desc("Only split jump tables with size less or "
24d26b43ffSAlexander Shaposhnikov                                     "equal than JumpTableSizeThreshold."),
25d26b43ffSAlexander Shaposhnikov                            cl::init(10));
26d26b43ffSAlexander Shaposhnikov 
27d26b43ffSAlexander Shaposhnikov // TODO: Consider adding a cost model for profitability analysis of this
28d26b43ffSAlexander Shaposhnikov // transformation. Currently we replace a jump table with a switch if all the
29d26b43ffSAlexander Shaposhnikov // functions in the jump table are smaller than the provided threshold.
30d26b43ffSAlexander Shaposhnikov static cl::opt<unsigned> FunctionSizeThreshold(
31d26b43ffSAlexander Shaposhnikov     "jump-table-to-switch-function-size-threshold", cl::Hidden,
32d26b43ffSAlexander Shaposhnikov     cl::desc("Only split jump tables containing functions whose sizes are less "
33d26b43ffSAlexander Shaposhnikov              "or equal than this threshold."),
34d26b43ffSAlexander Shaposhnikov     cl::init(50));
35d26b43ffSAlexander Shaposhnikov 
36d26b43ffSAlexander Shaposhnikov #define DEBUG_TYPE "jump-table-to-switch"
37d26b43ffSAlexander Shaposhnikov 
38d26b43ffSAlexander Shaposhnikov namespace {
39d26b43ffSAlexander Shaposhnikov struct JumpTableTy {
40d26b43ffSAlexander Shaposhnikov   Value *Index;
41d26b43ffSAlexander Shaposhnikov   SmallVector<Function *, 10> Funcs;
42d26b43ffSAlexander Shaposhnikov };
43d26b43ffSAlexander Shaposhnikov } // anonymous namespace
44d26b43ffSAlexander Shaposhnikov 
45d26b43ffSAlexander Shaposhnikov static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
46d26b43ffSAlexander Shaposhnikov                                                  PointerType *PtrTy) {
47d26b43ffSAlexander Shaposhnikov   Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand());
48d26b43ffSAlexander Shaposhnikov   if (!Ptr)
49d26b43ffSAlexander Shaposhnikov     return std::nullopt;
50d26b43ffSAlexander Shaposhnikov 
51d26b43ffSAlexander Shaposhnikov   GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr);
52d26b43ffSAlexander Shaposhnikov   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
53d26b43ffSAlexander Shaposhnikov     return std::nullopt;
54d26b43ffSAlexander Shaposhnikov 
55d26b43ffSAlexander Shaposhnikov   Function &F = *GEP->getParent()->getParent();
569df71d76SNikita Popov   const DataLayout &DL = F.getDataLayout();
57d26b43ffSAlexander Shaposhnikov   const unsigned BitWidth =
58d26b43ffSAlexander Shaposhnikov       DL.getIndexSizeInBits(GEP->getPointerAddressSpace());
59*96f37ae4SJeremy Morse   SmallMapVector<Value *, APInt, 4> VariableOffsets;
60d26b43ffSAlexander Shaposhnikov   APInt ConstantOffset(BitWidth, 0);
61d26b43ffSAlexander Shaposhnikov   if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
62d26b43ffSAlexander Shaposhnikov     return std::nullopt;
63d26b43ffSAlexander Shaposhnikov   if (VariableOffsets.size() != 1)
64d26b43ffSAlexander Shaposhnikov     return std::nullopt;
65d26b43ffSAlexander Shaposhnikov   // TODO: consider supporting more general patterns
66d26b43ffSAlexander Shaposhnikov   if (!ConstantOffset.isZero())
67d26b43ffSAlexander Shaposhnikov     return std::nullopt;
68d26b43ffSAlexander Shaposhnikov   APInt StrideBytes = VariableOffsets.front().second;
69d26b43ffSAlexander Shaposhnikov   const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType());
70d26b43ffSAlexander Shaposhnikov   if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)
71d26b43ffSAlexander Shaposhnikov     return std::nullopt;
72d26b43ffSAlexander Shaposhnikov   const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();
73d26b43ffSAlexander Shaposhnikov   if (N > JumpTableSizeThreshold)
74d26b43ffSAlexander Shaposhnikov     return std::nullopt;
75d26b43ffSAlexander Shaposhnikov 
76d26b43ffSAlexander Shaposhnikov   JumpTableTy JumpTable;
77d26b43ffSAlexander Shaposhnikov   JumpTable.Index = VariableOffsets.front().first;
78d26b43ffSAlexander Shaposhnikov   JumpTable.Funcs.reserve(N);
79d26b43ffSAlexander Shaposhnikov   for (uint64_t Index = 0; Index < N; ++Index) {
80d26b43ffSAlexander Shaposhnikov     // ConstantOffset is zero.
81d26b43ffSAlexander Shaposhnikov     APInt Offset = Index * StrideBytes;
82d26b43ffSAlexander Shaposhnikov     Constant *C =
83d26b43ffSAlexander Shaposhnikov         ConstantFoldLoadFromConst(GV->getInitializer(), PtrTy, Offset, DL);
84d26b43ffSAlexander Shaposhnikov     auto *Func = dyn_cast_or_null<Function>(C);
85d26b43ffSAlexander Shaposhnikov     if (!Func || Func->isDeclaration() ||
86d26b43ffSAlexander Shaposhnikov         Func->getInstructionCount() > FunctionSizeThreshold)
87d26b43ffSAlexander Shaposhnikov       return std::nullopt;
88d26b43ffSAlexander Shaposhnikov     JumpTable.Funcs.push_back(Func);
89d26b43ffSAlexander Shaposhnikov   }
90d26b43ffSAlexander Shaposhnikov   return JumpTable;
91d26b43ffSAlexander Shaposhnikov }
92d26b43ffSAlexander Shaposhnikov 
93d26b43ffSAlexander Shaposhnikov static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
94d26b43ffSAlexander Shaposhnikov                                   DomTreeUpdater &DTU,
95d26b43ffSAlexander Shaposhnikov                                   OptimizationRemarkEmitter &ORE) {
96d26b43ffSAlexander Shaposhnikov   const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
97d26b43ffSAlexander Shaposhnikov 
98d26b43ffSAlexander Shaposhnikov   SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
99d26b43ffSAlexander Shaposhnikov   BasicBlock *BB = CB->getParent();
100d26b43ffSAlexander Shaposhnikov   BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr,
101d26b43ffSAlexander Shaposhnikov                                 BB->getName() + Twine(".tail"));
102d26b43ffSAlexander Shaposhnikov   DTUpdates.push_back({DominatorTree::Delete, BB, Tail});
103d26b43ffSAlexander Shaposhnikov   BB->getTerminator()->eraseFromParent();
104d26b43ffSAlexander Shaposhnikov 
105d26b43ffSAlexander Shaposhnikov   Function &F = *BB->getParent();
106d26b43ffSAlexander Shaposhnikov   BasicBlock *BBUnreachable = BasicBlock::Create(
107d26b43ffSAlexander Shaposhnikov       F.getContext(), "default.switch.case.unreachable", &F, Tail);
108d26b43ffSAlexander Shaposhnikov   IRBuilder<> BuilderUnreachable(BBUnreachable);
109d26b43ffSAlexander Shaposhnikov   BuilderUnreachable.CreateUnreachable();
110d26b43ffSAlexander Shaposhnikov 
111d26b43ffSAlexander Shaposhnikov   IRBuilder<> Builder(BB);
112d26b43ffSAlexander Shaposhnikov   SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable);
113d26b43ffSAlexander Shaposhnikov   DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable});
114d26b43ffSAlexander Shaposhnikov 
115d26b43ffSAlexander Shaposhnikov   IRBuilder<> BuilderTail(CB);
116d26b43ffSAlexander Shaposhnikov   PHINode *PHI =
117d26b43ffSAlexander Shaposhnikov       IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
118d26b43ffSAlexander Shaposhnikov 
119d26b43ffSAlexander Shaposhnikov   for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
120d26b43ffSAlexander Shaposhnikov     BasicBlock *B = BasicBlock::Create(Func->getContext(),
121d26b43ffSAlexander Shaposhnikov                                        "call." + Twine(Index), &F, Tail);
122d26b43ffSAlexander Shaposhnikov     DTUpdates.push_back({DominatorTree::Insert, BB, B});
123d26b43ffSAlexander Shaposhnikov     DTUpdates.push_back({DominatorTree::Insert, B, Tail});
124d26b43ffSAlexander Shaposhnikov 
125d26b43ffSAlexander Shaposhnikov     CallBase *Call = cast<CallBase>(CB->clone());
126d26b43ffSAlexander Shaposhnikov     Call->setCalledFunction(Func);
127d26b43ffSAlexander Shaposhnikov     Call->insertInto(B, B->end());
128d26b43ffSAlexander Shaposhnikov     Switch->addCase(
129d26b43ffSAlexander Shaposhnikov         cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
130d26b43ffSAlexander Shaposhnikov     BranchInst::Create(Tail, B);
131d26b43ffSAlexander Shaposhnikov     if (PHI)
132d26b43ffSAlexander Shaposhnikov       PHI->addIncoming(Call, B);
133d26b43ffSAlexander Shaposhnikov   }
134d26b43ffSAlexander Shaposhnikov   DTU.applyUpdates(DTUpdates);
135d26b43ffSAlexander Shaposhnikov   ORE.emit([&]() {
136d26b43ffSAlexander Shaposhnikov     return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
137d26b43ffSAlexander Shaposhnikov            << "expanded indirect call into switch";
138d26b43ffSAlexander Shaposhnikov   });
139d26b43ffSAlexander Shaposhnikov   if (PHI)
140d26b43ffSAlexander Shaposhnikov     CB->replaceAllUsesWith(PHI);
141d26b43ffSAlexander Shaposhnikov   CB->eraseFromParent();
142d26b43ffSAlexander Shaposhnikov   return Tail;
143d26b43ffSAlexander Shaposhnikov }
144d26b43ffSAlexander Shaposhnikov 
145d26b43ffSAlexander Shaposhnikov PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
146d26b43ffSAlexander Shaposhnikov                                              FunctionAnalysisManager &AM) {
147d26b43ffSAlexander Shaposhnikov   OptimizationRemarkEmitter &ORE =
148d26b43ffSAlexander Shaposhnikov       AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
149d26b43ffSAlexander Shaposhnikov   DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
150d26b43ffSAlexander Shaposhnikov   PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
151d26b43ffSAlexander Shaposhnikov   DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152d26b43ffSAlexander Shaposhnikov   bool Changed = false;
153d26b43ffSAlexander Shaposhnikov   for (BasicBlock &BB : make_early_inc_range(F)) {
154d26b43ffSAlexander Shaposhnikov     BasicBlock *CurrentBB = &BB;
155d26b43ffSAlexander Shaposhnikov     while (CurrentBB) {
156d26b43ffSAlexander Shaposhnikov       BasicBlock *SplittedOutTail = nullptr;
157d26b43ffSAlexander Shaposhnikov       for (Instruction &I : make_early_inc_range(*CurrentBB)) {
158d26b43ffSAlexander Shaposhnikov         auto *Call = dyn_cast<CallInst>(&I);
159d26b43ffSAlexander Shaposhnikov         if (!Call || Call->getCalledFunction() || Call->isMustTailCall())
160d26b43ffSAlexander Shaposhnikov           continue;
161d26b43ffSAlexander Shaposhnikov         auto *L = dyn_cast<LoadInst>(Call->getCalledOperand());
162d26b43ffSAlexander Shaposhnikov         // Skip atomic or volatile loads.
163d26b43ffSAlexander Shaposhnikov         if (!L || !L->isSimple())
164d26b43ffSAlexander Shaposhnikov           continue;
165d26b43ffSAlexander Shaposhnikov         auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand());
166d26b43ffSAlexander Shaposhnikov         if (!GEP)
167d26b43ffSAlexander Shaposhnikov           continue;
168d26b43ffSAlexander Shaposhnikov         auto *PtrTy = dyn_cast<PointerType>(L->getType());
169d26b43ffSAlexander Shaposhnikov         assert(PtrTy && "call operand must be a pointer");
170d26b43ffSAlexander Shaposhnikov         std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
171d26b43ffSAlexander Shaposhnikov         if (!JumpTable)
172d26b43ffSAlexander Shaposhnikov           continue;
173d26b43ffSAlexander Shaposhnikov         SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);
174d26b43ffSAlexander Shaposhnikov         Changed = true;
175d26b43ffSAlexander Shaposhnikov         break;
176d26b43ffSAlexander Shaposhnikov       }
177d26b43ffSAlexander Shaposhnikov       CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr;
178d26b43ffSAlexander Shaposhnikov     }
179d26b43ffSAlexander Shaposhnikov   }
180d26b43ffSAlexander Shaposhnikov 
181d26b43ffSAlexander Shaposhnikov   if (!Changed)
182d26b43ffSAlexander Shaposhnikov     return PreservedAnalyses::all();
183d26b43ffSAlexander Shaposhnikov 
184d26b43ffSAlexander Shaposhnikov   PreservedAnalyses PA;
185d26b43ffSAlexander Shaposhnikov   if (DT)
186d26b43ffSAlexander Shaposhnikov     PA.preserve<DominatorTreeAnalysis>();
187d26b43ffSAlexander Shaposhnikov   if (PDT)
188d26b43ffSAlexander Shaposhnikov     PA.preserve<PostDominatorTreeAnalysis>();
189d26b43ffSAlexander Shaposhnikov   return PA;
190d26b43ffSAlexander Shaposhnikov }
191