1d26b43ffSAlexander Shaposhnikov //===- JumpTableToSwitch.cpp ----------------------------------------------===// 2d26b43ffSAlexander Shaposhnikov // 3d26b43ffSAlexander Shaposhnikov // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4d26b43ffSAlexander Shaposhnikov // See https://llvm.org/LICENSE.txt for license information. 5d26b43ffSAlexander Shaposhnikov // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6d26b43ffSAlexander Shaposhnikov // 7d26b43ffSAlexander Shaposhnikov //===----------------------------------------------------------------------===// 8d26b43ffSAlexander Shaposhnikov 9d26b43ffSAlexander Shaposhnikov #include "llvm/Transforms/Scalar/JumpTableToSwitch.h" 10d26b43ffSAlexander Shaposhnikov #include "llvm/ADT/SmallVector.h" 11d26b43ffSAlexander Shaposhnikov #include "llvm/Analysis/ConstantFolding.h" 12d26b43ffSAlexander Shaposhnikov #include "llvm/Analysis/DomTreeUpdater.h" 13d26b43ffSAlexander Shaposhnikov #include "llvm/Analysis/OptimizationRemarkEmitter.h" 14d26b43ffSAlexander Shaposhnikov #include "llvm/Analysis/PostDominators.h" 15d26b43ffSAlexander Shaposhnikov #include "llvm/IR/IRBuilder.h" 16d26b43ffSAlexander Shaposhnikov #include "llvm/Support/CommandLine.h" 17d26b43ffSAlexander Shaposhnikov #include "llvm/Transforms/Utils/BasicBlockUtils.h" 18d26b43ffSAlexander Shaposhnikov 19d26b43ffSAlexander Shaposhnikov using namespace llvm; 20d26b43ffSAlexander Shaposhnikov 21d26b43ffSAlexander Shaposhnikov static cl::opt<unsigned> 22d26b43ffSAlexander Shaposhnikov JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden, 23d26b43ffSAlexander Shaposhnikov cl::desc("Only split jump tables with size less or " 24d26b43ffSAlexander Shaposhnikov "equal than JumpTableSizeThreshold."), 25d26b43ffSAlexander Shaposhnikov cl::init(10)); 26d26b43ffSAlexander Shaposhnikov 27d26b43ffSAlexander Shaposhnikov // TODO: Consider adding a cost model for profitability analysis of this 28d26b43ffSAlexander Shaposhnikov // transformation. Currently we replace a jump table with a switch if all the 29d26b43ffSAlexander Shaposhnikov // functions in the jump table are smaller than the provided threshold. 30d26b43ffSAlexander Shaposhnikov static cl::opt<unsigned> FunctionSizeThreshold( 31d26b43ffSAlexander Shaposhnikov "jump-table-to-switch-function-size-threshold", cl::Hidden, 32d26b43ffSAlexander Shaposhnikov cl::desc("Only split jump tables containing functions whose sizes are less " 33d26b43ffSAlexander Shaposhnikov "or equal than this threshold."), 34d26b43ffSAlexander Shaposhnikov cl::init(50)); 35d26b43ffSAlexander Shaposhnikov 36d26b43ffSAlexander Shaposhnikov #define DEBUG_TYPE "jump-table-to-switch" 37d26b43ffSAlexander Shaposhnikov 38d26b43ffSAlexander Shaposhnikov namespace { 39d26b43ffSAlexander Shaposhnikov struct JumpTableTy { 40d26b43ffSAlexander Shaposhnikov Value *Index; 41d26b43ffSAlexander Shaposhnikov SmallVector<Function *, 10> Funcs; 42d26b43ffSAlexander Shaposhnikov }; 43d26b43ffSAlexander Shaposhnikov } // anonymous namespace 44d26b43ffSAlexander Shaposhnikov 45d26b43ffSAlexander Shaposhnikov static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP, 46d26b43ffSAlexander Shaposhnikov PointerType *PtrTy) { 47d26b43ffSAlexander Shaposhnikov Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand()); 48d26b43ffSAlexander Shaposhnikov if (!Ptr) 49d26b43ffSAlexander Shaposhnikov return std::nullopt; 50d26b43ffSAlexander Shaposhnikov 51d26b43ffSAlexander Shaposhnikov GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr); 52d26b43ffSAlexander Shaposhnikov if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) 53d26b43ffSAlexander Shaposhnikov return std::nullopt; 54d26b43ffSAlexander Shaposhnikov 55d26b43ffSAlexander Shaposhnikov Function &F = *GEP->getParent()->getParent(); 569df71d76SNikita Popov const DataLayout &DL = F.getDataLayout(); 57d26b43ffSAlexander Shaposhnikov const unsigned BitWidth = 58d26b43ffSAlexander Shaposhnikov DL.getIndexSizeInBits(GEP->getPointerAddressSpace()); 59*96f37ae4SJeremy Morse SmallMapVector<Value *, APInt, 4> VariableOffsets; 60d26b43ffSAlexander Shaposhnikov APInt ConstantOffset(BitWidth, 0); 61d26b43ffSAlexander Shaposhnikov if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) 62d26b43ffSAlexander Shaposhnikov return std::nullopt; 63d26b43ffSAlexander Shaposhnikov if (VariableOffsets.size() != 1) 64d26b43ffSAlexander Shaposhnikov return std::nullopt; 65d26b43ffSAlexander Shaposhnikov // TODO: consider supporting more general patterns 66d26b43ffSAlexander Shaposhnikov if (!ConstantOffset.isZero()) 67d26b43ffSAlexander Shaposhnikov return std::nullopt; 68d26b43ffSAlexander Shaposhnikov APInt StrideBytes = VariableOffsets.front().second; 69d26b43ffSAlexander Shaposhnikov const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType()); 70d26b43ffSAlexander Shaposhnikov if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0) 71d26b43ffSAlexander Shaposhnikov return std::nullopt; 72d26b43ffSAlexander Shaposhnikov const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue(); 73d26b43ffSAlexander Shaposhnikov if (N > JumpTableSizeThreshold) 74d26b43ffSAlexander Shaposhnikov return std::nullopt; 75d26b43ffSAlexander Shaposhnikov 76d26b43ffSAlexander Shaposhnikov JumpTableTy JumpTable; 77d26b43ffSAlexander Shaposhnikov JumpTable.Index = VariableOffsets.front().first; 78d26b43ffSAlexander Shaposhnikov JumpTable.Funcs.reserve(N); 79d26b43ffSAlexander Shaposhnikov for (uint64_t Index = 0; Index < N; ++Index) { 80d26b43ffSAlexander Shaposhnikov // ConstantOffset is zero. 81d26b43ffSAlexander Shaposhnikov APInt Offset = Index * StrideBytes; 82d26b43ffSAlexander Shaposhnikov Constant *C = 83d26b43ffSAlexander Shaposhnikov ConstantFoldLoadFromConst(GV->getInitializer(), PtrTy, Offset, DL); 84d26b43ffSAlexander Shaposhnikov auto *Func = dyn_cast_or_null<Function>(C); 85d26b43ffSAlexander Shaposhnikov if (!Func || Func->isDeclaration() || 86d26b43ffSAlexander Shaposhnikov Func->getInstructionCount() > FunctionSizeThreshold) 87d26b43ffSAlexander Shaposhnikov return std::nullopt; 88d26b43ffSAlexander Shaposhnikov JumpTable.Funcs.push_back(Func); 89d26b43ffSAlexander Shaposhnikov } 90d26b43ffSAlexander Shaposhnikov return JumpTable; 91d26b43ffSAlexander Shaposhnikov } 92d26b43ffSAlexander Shaposhnikov 93d26b43ffSAlexander Shaposhnikov static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT, 94d26b43ffSAlexander Shaposhnikov DomTreeUpdater &DTU, 95d26b43ffSAlexander Shaposhnikov OptimizationRemarkEmitter &ORE) { 96d26b43ffSAlexander Shaposhnikov const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext()); 97d26b43ffSAlexander Shaposhnikov 98d26b43ffSAlexander Shaposhnikov SmallVector<DominatorTree::UpdateType, 8> DTUpdates; 99d26b43ffSAlexander Shaposhnikov BasicBlock *BB = CB->getParent(); 100d26b43ffSAlexander Shaposhnikov BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr, 101d26b43ffSAlexander Shaposhnikov BB->getName() + Twine(".tail")); 102d26b43ffSAlexander Shaposhnikov DTUpdates.push_back({DominatorTree::Delete, BB, Tail}); 103d26b43ffSAlexander Shaposhnikov BB->getTerminator()->eraseFromParent(); 104d26b43ffSAlexander Shaposhnikov 105d26b43ffSAlexander Shaposhnikov Function &F = *BB->getParent(); 106d26b43ffSAlexander Shaposhnikov BasicBlock *BBUnreachable = BasicBlock::Create( 107d26b43ffSAlexander Shaposhnikov F.getContext(), "default.switch.case.unreachable", &F, Tail); 108d26b43ffSAlexander Shaposhnikov IRBuilder<> BuilderUnreachable(BBUnreachable); 109d26b43ffSAlexander Shaposhnikov BuilderUnreachable.CreateUnreachable(); 110d26b43ffSAlexander Shaposhnikov 111d26b43ffSAlexander Shaposhnikov IRBuilder<> Builder(BB); 112d26b43ffSAlexander Shaposhnikov SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable); 113d26b43ffSAlexander Shaposhnikov DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable}); 114d26b43ffSAlexander Shaposhnikov 115d26b43ffSAlexander Shaposhnikov IRBuilder<> BuilderTail(CB); 116d26b43ffSAlexander Shaposhnikov PHINode *PHI = 117d26b43ffSAlexander Shaposhnikov IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size()); 118d26b43ffSAlexander Shaposhnikov 119d26b43ffSAlexander Shaposhnikov for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) { 120d26b43ffSAlexander Shaposhnikov BasicBlock *B = BasicBlock::Create(Func->getContext(), 121d26b43ffSAlexander Shaposhnikov "call." + Twine(Index), &F, Tail); 122d26b43ffSAlexander Shaposhnikov DTUpdates.push_back({DominatorTree::Insert, BB, B}); 123d26b43ffSAlexander Shaposhnikov DTUpdates.push_back({DominatorTree::Insert, B, Tail}); 124d26b43ffSAlexander Shaposhnikov 125d26b43ffSAlexander Shaposhnikov CallBase *Call = cast<CallBase>(CB->clone()); 126d26b43ffSAlexander Shaposhnikov Call->setCalledFunction(Func); 127d26b43ffSAlexander Shaposhnikov Call->insertInto(B, B->end()); 128d26b43ffSAlexander Shaposhnikov Switch->addCase( 129d26b43ffSAlexander Shaposhnikov cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B); 130d26b43ffSAlexander Shaposhnikov BranchInst::Create(Tail, B); 131d26b43ffSAlexander Shaposhnikov if (PHI) 132d26b43ffSAlexander Shaposhnikov PHI->addIncoming(Call, B); 133d26b43ffSAlexander Shaposhnikov } 134d26b43ffSAlexander Shaposhnikov DTU.applyUpdates(DTUpdates); 135d26b43ffSAlexander Shaposhnikov ORE.emit([&]() { 136d26b43ffSAlexander Shaposhnikov return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB) 137d26b43ffSAlexander Shaposhnikov << "expanded indirect call into switch"; 138d26b43ffSAlexander Shaposhnikov }); 139d26b43ffSAlexander Shaposhnikov if (PHI) 140d26b43ffSAlexander Shaposhnikov CB->replaceAllUsesWith(PHI); 141d26b43ffSAlexander Shaposhnikov CB->eraseFromParent(); 142d26b43ffSAlexander Shaposhnikov return Tail; 143d26b43ffSAlexander Shaposhnikov } 144d26b43ffSAlexander Shaposhnikov 145d26b43ffSAlexander Shaposhnikov PreservedAnalyses JumpTableToSwitchPass::run(Function &F, 146d26b43ffSAlexander Shaposhnikov FunctionAnalysisManager &AM) { 147d26b43ffSAlexander Shaposhnikov OptimizationRemarkEmitter &ORE = 148d26b43ffSAlexander Shaposhnikov AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 149d26b43ffSAlexander Shaposhnikov DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); 150d26b43ffSAlexander Shaposhnikov PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F); 151d26b43ffSAlexander Shaposhnikov DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy); 152d26b43ffSAlexander Shaposhnikov bool Changed = false; 153d26b43ffSAlexander Shaposhnikov for (BasicBlock &BB : make_early_inc_range(F)) { 154d26b43ffSAlexander Shaposhnikov BasicBlock *CurrentBB = &BB; 155d26b43ffSAlexander Shaposhnikov while (CurrentBB) { 156d26b43ffSAlexander Shaposhnikov BasicBlock *SplittedOutTail = nullptr; 157d26b43ffSAlexander Shaposhnikov for (Instruction &I : make_early_inc_range(*CurrentBB)) { 158d26b43ffSAlexander Shaposhnikov auto *Call = dyn_cast<CallInst>(&I); 159d26b43ffSAlexander Shaposhnikov if (!Call || Call->getCalledFunction() || Call->isMustTailCall()) 160d26b43ffSAlexander Shaposhnikov continue; 161d26b43ffSAlexander Shaposhnikov auto *L = dyn_cast<LoadInst>(Call->getCalledOperand()); 162d26b43ffSAlexander Shaposhnikov // Skip atomic or volatile loads. 163d26b43ffSAlexander Shaposhnikov if (!L || !L->isSimple()) 164d26b43ffSAlexander Shaposhnikov continue; 165d26b43ffSAlexander Shaposhnikov auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand()); 166d26b43ffSAlexander Shaposhnikov if (!GEP) 167d26b43ffSAlexander Shaposhnikov continue; 168d26b43ffSAlexander Shaposhnikov auto *PtrTy = dyn_cast<PointerType>(L->getType()); 169d26b43ffSAlexander Shaposhnikov assert(PtrTy && "call operand must be a pointer"); 170d26b43ffSAlexander Shaposhnikov std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy); 171d26b43ffSAlexander Shaposhnikov if (!JumpTable) 172d26b43ffSAlexander Shaposhnikov continue; 173d26b43ffSAlexander Shaposhnikov SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE); 174d26b43ffSAlexander Shaposhnikov Changed = true; 175d26b43ffSAlexander Shaposhnikov break; 176d26b43ffSAlexander Shaposhnikov } 177d26b43ffSAlexander Shaposhnikov CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr; 178d26b43ffSAlexander Shaposhnikov } 179d26b43ffSAlexander Shaposhnikov } 180d26b43ffSAlexander Shaposhnikov 181d26b43ffSAlexander Shaposhnikov if (!Changed) 182d26b43ffSAlexander Shaposhnikov return PreservedAnalyses::all(); 183d26b43ffSAlexander Shaposhnikov 184d26b43ffSAlexander Shaposhnikov PreservedAnalyses PA; 185d26b43ffSAlexander Shaposhnikov if (DT) 186d26b43ffSAlexander Shaposhnikov PA.preserve<DominatorTreeAnalysis>(); 187d26b43ffSAlexander Shaposhnikov if (PDT) 188d26b43ffSAlexander Shaposhnikov PA.preserve<PostDominatorTreeAnalysis>(); 189d26b43ffSAlexander Shaposhnikov return PA; 190d26b43ffSAlexander Shaposhnikov } 191