1*0fca6ea1SDimitry Andric //===- JumpTableToSwitch.cpp ----------------------------------------------===// 2*0fca6ea1SDimitry Andric // 3*0fca6ea1SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*0fca6ea1SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5*0fca6ea1SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*0fca6ea1SDimitry Andric // 7*0fca6ea1SDimitry Andric //===----------------------------------------------------------------------===// 8*0fca6ea1SDimitry Andric 9*0fca6ea1SDimitry Andric #include "llvm/Transforms/Scalar/JumpTableToSwitch.h" 10*0fca6ea1SDimitry Andric #include "llvm/ADT/SmallVector.h" 11*0fca6ea1SDimitry Andric #include "llvm/Analysis/ConstantFolding.h" 12*0fca6ea1SDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h" 13*0fca6ea1SDimitry Andric #include "llvm/Analysis/OptimizationRemarkEmitter.h" 14*0fca6ea1SDimitry Andric #include "llvm/Analysis/PostDominators.h" 15*0fca6ea1SDimitry Andric #include "llvm/IR/IRBuilder.h" 16*0fca6ea1SDimitry Andric #include "llvm/Support/CommandLine.h" 17*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h" 18*0fca6ea1SDimitry Andric 19*0fca6ea1SDimitry Andric using namespace llvm; 20*0fca6ea1SDimitry Andric 21*0fca6ea1SDimitry Andric static cl::opt<unsigned> 22*0fca6ea1SDimitry Andric JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden, 23*0fca6ea1SDimitry Andric cl::desc("Only split jump tables with size less or " 24*0fca6ea1SDimitry Andric "equal than JumpTableSizeThreshold."), 25*0fca6ea1SDimitry Andric cl::init(10)); 26*0fca6ea1SDimitry Andric 27*0fca6ea1SDimitry Andric // TODO: Consider adding a cost model for profitability analysis of this 28*0fca6ea1SDimitry Andric // transformation. Currently we replace a jump table with a switch if all the 29*0fca6ea1SDimitry Andric // functions in the jump table are smaller than the provided threshold. 30*0fca6ea1SDimitry Andric static cl::opt<unsigned> FunctionSizeThreshold( 31*0fca6ea1SDimitry Andric "jump-table-to-switch-function-size-threshold", cl::Hidden, 32*0fca6ea1SDimitry Andric cl::desc("Only split jump tables containing functions whose sizes are less " 33*0fca6ea1SDimitry Andric "or equal than this threshold."), 34*0fca6ea1SDimitry Andric cl::init(50)); 35*0fca6ea1SDimitry Andric 36*0fca6ea1SDimitry Andric #define DEBUG_TYPE "jump-table-to-switch" 37*0fca6ea1SDimitry Andric 38*0fca6ea1SDimitry Andric namespace { 39*0fca6ea1SDimitry Andric struct JumpTableTy { 40*0fca6ea1SDimitry Andric Value *Index; 41*0fca6ea1SDimitry Andric SmallVector<Function *, 10> Funcs; 42*0fca6ea1SDimitry Andric }; 43*0fca6ea1SDimitry Andric } // anonymous namespace 44*0fca6ea1SDimitry Andric 45*0fca6ea1SDimitry Andric static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP, 46*0fca6ea1SDimitry Andric PointerType *PtrTy) { 47*0fca6ea1SDimitry Andric Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand()); 48*0fca6ea1SDimitry Andric if (!Ptr) 49*0fca6ea1SDimitry Andric return std::nullopt; 50*0fca6ea1SDimitry Andric 51*0fca6ea1SDimitry Andric GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr); 52*0fca6ea1SDimitry Andric if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) 53*0fca6ea1SDimitry Andric return std::nullopt; 54*0fca6ea1SDimitry Andric 55*0fca6ea1SDimitry Andric Function &F = *GEP->getParent()->getParent(); 56*0fca6ea1SDimitry Andric const DataLayout &DL = F.getDataLayout(); 57*0fca6ea1SDimitry Andric const unsigned BitWidth = 58*0fca6ea1SDimitry Andric DL.getIndexSizeInBits(GEP->getPointerAddressSpace()); 59*0fca6ea1SDimitry Andric MapVector<Value *, APInt> VariableOffsets; 60*0fca6ea1SDimitry Andric APInt ConstantOffset(BitWidth, 0); 61*0fca6ea1SDimitry Andric if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) 62*0fca6ea1SDimitry Andric return std::nullopt; 63*0fca6ea1SDimitry Andric if (VariableOffsets.size() != 1) 64*0fca6ea1SDimitry Andric return std::nullopt; 65*0fca6ea1SDimitry Andric // TODO: consider supporting more general patterns 66*0fca6ea1SDimitry Andric if (!ConstantOffset.isZero()) 67*0fca6ea1SDimitry Andric return std::nullopt; 68*0fca6ea1SDimitry Andric APInt StrideBytes = VariableOffsets.front().second; 69*0fca6ea1SDimitry Andric const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType()); 70*0fca6ea1SDimitry Andric if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0) 71*0fca6ea1SDimitry Andric return std::nullopt; 72*0fca6ea1SDimitry Andric const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue(); 73*0fca6ea1SDimitry Andric if (N > JumpTableSizeThreshold) 74*0fca6ea1SDimitry Andric return std::nullopt; 75*0fca6ea1SDimitry Andric 76*0fca6ea1SDimitry Andric JumpTableTy JumpTable; 77*0fca6ea1SDimitry Andric JumpTable.Index = VariableOffsets.front().first; 78*0fca6ea1SDimitry Andric JumpTable.Funcs.reserve(N); 79*0fca6ea1SDimitry Andric for (uint64_t Index = 0; Index < N; ++Index) { 80*0fca6ea1SDimitry Andric // ConstantOffset is zero. 81*0fca6ea1SDimitry Andric APInt Offset = Index * StrideBytes; 82*0fca6ea1SDimitry Andric Constant *C = 83*0fca6ea1SDimitry Andric ConstantFoldLoadFromConst(GV->getInitializer(), PtrTy, Offset, DL); 84*0fca6ea1SDimitry Andric auto *Func = dyn_cast_or_null<Function>(C); 85*0fca6ea1SDimitry Andric if (!Func || Func->isDeclaration() || 86*0fca6ea1SDimitry Andric Func->getInstructionCount() > FunctionSizeThreshold) 87*0fca6ea1SDimitry Andric return std::nullopt; 88*0fca6ea1SDimitry Andric JumpTable.Funcs.push_back(Func); 89*0fca6ea1SDimitry Andric } 90*0fca6ea1SDimitry Andric return JumpTable; 91*0fca6ea1SDimitry Andric } 92*0fca6ea1SDimitry Andric 93*0fca6ea1SDimitry Andric static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT, 94*0fca6ea1SDimitry Andric DomTreeUpdater &DTU, 95*0fca6ea1SDimitry Andric OptimizationRemarkEmitter &ORE) { 96*0fca6ea1SDimitry Andric const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext()); 97*0fca6ea1SDimitry Andric 98*0fca6ea1SDimitry Andric SmallVector<DominatorTree::UpdateType, 8> DTUpdates; 99*0fca6ea1SDimitry Andric BasicBlock *BB = CB->getParent(); 100*0fca6ea1SDimitry Andric BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr, 101*0fca6ea1SDimitry Andric BB->getName() + Twine(".tail")); 102*0fca6ea1SDimitry Andric DTUpdates.push_back({DominatorTree::Delete, BB, Tail}); 103*0fca6ea1SDimitry Andric BB->getTerminator()->eraseFromParent(); 104*0fca6ea1SDimitry Andric 105*0fca6ea1SDimitry Andric Function &F = *BB->getParent(); 106*0fca6ea1SDimitry Andric BasicBlock *BBUnreachable = BasicBlock::Create( 107*0fca6ea1SDimitry Andric F.getContext(), "default.switch.case.unreachable", &F, Tail); 108*0fca6ea1SDimitry Andric IRBuilder<> BuilderUnreachable(BBUnreachable); 109*0fca6ea1SDimitry Andric BuilderUnreachable.CreateUnreachable(); 110*0fca6ea1SDimitry Andric 111*0fca6ea1SDimitry Andric IRBuilder<> Builder(BB); 112*0fca6ea1SDimitry Andric SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable); 113*0fca6ea1SDimitry Andric DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable}); 114*0fca6ea1SDimitry Andric 115*0fca6ea1SDimitry Andric IRBuilder<> BuilderTail(CB); 116*0fca6ea1SDimitry Andric PHINode *PHI = 117*0fca6ea1SDimitry Andric IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size()); 118*0fca6ea1SDimitry Andric 119*0fca6ea1SDimitry Andric for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) { 120*0fca6ea1SDimitry Andric BasicBlock *B = BasicBlock::Create(Func->getContext(), 121*0fca6ea1SDimitry Andric "call." + Twine(Index), &F, Tail); 122*0fca6ea1SDimitry Andric DTUpdates.push_back({DominatorTree::Insert, BB, B}); 123*0fca6ea1SDimitry Andric DTUpdates.push_back({DominatorTree::Insert, B, Tail}); 124*0fca6ea1SDimitry Andric 125*0fca6ea1SDimitry Andric CallBase *Call = cast<CallBase>(CB->clone()); 126*0fca6ea1SDimitry Andric Call->setCalledFunction(Func); 127*0fca6ea1SDimitry Andric Call->insertInto(B, B->end()); 128*0fca6ea1SDimitry Andric Switch->addCase( 129*0fca6ea1SDimitry Andric cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B); 130*0fca6ea1SDimitry Andric BranchInst::Create(Tail, B); 131*0fca6ea1SDimitry Andric if (PHI) 132*0fca6ea1SDimitry Andric PHI->addIncoming(Call, B); 133*0fca6ea1SDimitry Andric } 134*0fca6ea1SDimitry Andric DTU.applyUpdates(DTUpdates); 135*0fca6ea1SDimitry Andric ORE.emit([&]() { 136*0fca6ea1SDimitry Andric return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB) 137*0fca6ea1SDimitry Andric << "expanded indirect call into switch"; 138*0fca6ea1SDimitry Andric }); 139*0fca6ea1SDimitry Andric if (PHI) 140*0fca6ea1SDimitry Andric CB->replaceAllUsesWith(PHI); 141*0fca6ea1SDimitry Andric CB->eraseFromParent(); 142*0fca6ea1SDimitry Andric return Tail; 143*0fca6ea1SDimitry Andric } 144*0fca6ea1SDimitry Andric 145*0fca6ea1SDimitry Andric PreservedAnalyses JumpTableToSwitchPass::run(Function &F, 146*0fca6ea1SDimitry Andric FunctionAnalysisManager &AM) { 147*0fca6ea1SDimitry Andric OptimizationRemarkEmitter &ORE = 148*0fca6ea1SDimitry Andric AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 149*0fca6ea1SDimitry Andric DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); 150*0fca6ea1SDimitry Andric PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F); 151*0fca6ea1SDimitry Andric DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy); 152*0fca6ea1SDimitry Andric bool Changed = false; 153*0fca6ea1SDimitry Andric for (BasicBlock &BB : make_early_inc_range(F)) { 154*0fca6ea1SDimitry Andric BasicBlock *CurrentBB = &BB; 155*0fca6ea1SDimitry Andric while (CurrentBB) { 156*0fca6ea1SDimitry Andric BasicBlock *SplittedOutTail = nullptr; 157*0fca6ea1SDimitry Andric for (Instruction &I : make_early_inc_range(*CurrentBB)) { 158*0fca6ea1SDimitry Andric auto *Call = dyn_cast<CallInst>(&I); 159*0fca6ea1SDimitry Andric if (!Call || Call->getCalledFunction() || Call->isMustTailCall()) 160*0fca6ea1SDimitry Andric continue; 161*0fca6ea1SDimitry Andric auto *L = dyn_cast<LoadInst>(Call->getCalledOperand()); 162*0fca6ea1SDimitry Andric // Skip atomic or volatile loads. 163*0fca6ea1SDimitry Andric if (!L || !L->isSimple()) 164*0fca6ea1SDimitry Andric continue; 165*0fca6ea1SDimitry Andric auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand()); 166*0fca6ea1SDimitry Andric if (!GEP) 167*0fca6ea1SDimitry Andric continue; 168*0fca6ea1SDimitry Andric auto *PtrTy = dyn_cast<PointerType>(L->getType()); 169*0fca6ea1SDimitry Andric assert(PtrTy && "call operand must be a pointer"); 170*0fca6ea1SDimitry Andric std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy); 171*0fca6ea1SDimitry Andric if (!JumpTable) 172*0fca6ea1SDimitry Andric continue; 173*0fca6ea1SDimitry Andric SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE); 174*0fca6ea1SDimitry Andric Changed = true; 175*0fca6ea1SDimitry Andric break; 176*0fca6ea1SDimitry Andric } 177*0fca6ea1SDimitry Andric CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr; 178*0fca6ea1SDimitry Andric } 179*0fca6ea1SDimitry Andric } 180*0fca6ea1SDimitry Andric 181*0fca6ea1SDimitry Andric if (!Changed) 182*0fca6ea1SDimitry Andric return PreservedAnalyses::all(); 183*0fca6ea1SDimitry Andric 184*0fca6ea1SDimitry Andric PreservedAnalyses PA; 185*0fca6ea1SDimitry Andric if (DT) 186*0fca6ea1SDimitry Andric PA.preserve<DominatorTreeAnalysis>(); 187*0fca6ea1SDimitry Andric if (PDT) 188*0fca6ea1SDimitry Andric PA.preserve<PostDominatorTreeAnalysis>(); 189*0fca6ea1SDimitry Andric return PA; 190*0fca6ea1SDimitry Andric } 191