109467b48Spatrick //===- BlockExtractor.cpp - Extracts blocks into their own functions ------===//
209467b48Spatrick //
309467b48Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409467b48Spatrick // See https://llvm.org/LICENSE.txt for license information.
509467b48Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609467b48Spatrick //
709467b48Spatrick //===----------------------------------------------------------------------===//
809467b48Spatrick //
909467b48Spatrick // This pass extracts the specified basic blocks from the module into their
1009467b48Spatrick // own functions.
1109467b48Spatrick //
1209467b48Spatrick //===----------------------------------------------------------------------===//
1309467b48Spatrick
1473471bf0Spatrick #include "llvm/Transforms/IPO/BlockExtractor.h"
1509467b48Spatrick #include "llvm/ADT/STLExtras.h"
1609467b48Spatrick #include "llvm/ADT/Statistic.h"
1709467b48Spatrick #include "llvm/IR/Instructions.h"
1809467b48Spatrick #include "llvm/IR/Module.h"
1973471bf0Spatrick #include "llvm/IR/PassManager.h"
2009467b48Spatrick #include "llvm/InitializePasses.h"
2109467b48Spatrick #include "llvm/Pass.h"
2209467b48Spatrick #include "llvm/Support/CommandLine.h"
2309467b48Spatrick #include "llvm/Support/Debug.h"
2409467b48Spatrick #include "llvm/Support/MemoryBuffer.h"
2509467b48Spatrick #include "llvm/Transforms/IPO.h"
2609467b48Spatrick #include "llvm/Transforms/Utils/BasicBlockUtils.h"
2709467b48Spatrick #include "llvm/Transforms/Utils/CodeExtractor.h"
2809467b48Spatrick
2909467b48Spatrick using namespace llvm;
3009467b48Spatrick
3109467b48Spatrick #define DEBUG_TYPE "block-extractor"
3209467b48Spatrick
3309467b48Spatrick STATISTIC(NumExtracted, "Number of basic blocks extracted");
3409467b48Spatrick
3509467b48Spatrick static cl::opt<std::string> BlockExtractorFile(
3609467b48Spatrick "extract-blocks-file", cl::value_desc("filename"),
3709467b48Spatrick cl::desc("A file containing list of basic blocks to extract"), cl::Hidden);
3809467b48Spatrick
3973471bf0Spatrick static cl::opt<bool>
4073471bf0Spatrick BlockExtractorEraseFuncs("extract-blocks-erase-funcs",
4109467b48Spatrick cl::desc("Erase the existing functions"),
4209467b48Spatrick cl::Hidden);
4309467b48Spatrick namespace {
4473471bf0Spatrick class BlockExtractor {
4573471bf0Spatrick public:
BlockExtractor(bool EraseFunctions)4673471bf0Spatrick BlockExtractor(bool EraseFunctions) : EraseFunctions(EraseFunctions) {}
4773471bf0Spatrick bool runOnModule(Module &M);
48*d415bd75Srobert void
init(const std::vector<std::vector<BasicBlock * >> & GroupsOfBlocksToExtract)49*d415bd75Srobert init(const std::vector<std::vector<BasicBlock *>> &GroupsOfBlocksToExtract) {
50*d415bd75Srobert GroupsOfBlocks = GroupsOfBlocksToExtract;
5109467b48Spatrick if (!BlockExtractorFile.empty())
5209467b48Spatrick loadFile();
5309467b48Spatrick }
5409467b48Spatrick
5573471bf0Spatrick private:
56*d415bd75Srobert std::vector<std::vector<BasicBlock *>> GroupsOfBlocks;
5773471bf0Spatrick bool EraseFunctions;
5873471bf0Spatrick /// Map a function name to groups of blocks.
5973471bf0Spatrick SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4>
6073471bf0Spatrick BlocksByName;
6173471bf0Spatrick
6273471bf0Spatrick void loadFile();
6373471bf0Spatrick void splitLandingPadPreds(Function &F);
6473471bf0Spatrick };
6573471bf0Spatrick
6609467b48Spatrick } // end anonymous namespace
6709467b48Spatrick
6809467b48Spatrick /// Gets all of the blocks specified in the input file.
loadFile()6909467b48Spatrick void BlockExtractor::loadFile() {
7009467b48Spatrick auto ErrOrBuf = MemoryBuffer::getFile(BlockExtractorFile);
7109467b48Spatrick if (ErrOrBuf.getError())
7209467b48Spatrick report_fatal_error("BlockExtractor couldn't load the file.");
7309467b48Spatrick // Read the file.
7409467b48Spatrick auto &Buf = *ErrOrBuf;
7509467b48Spatrick SmallVector<StringRef, 16> Lines;
7609467b48Spatrick Buf->getBuffer().split(Lines, '\n', /*MaxSplit=*/-1,
7709467b48Spatrick /*KeepEmpty=*/false);
7809467b48Spatrick for (const auto &Line : Lines) {
7909467b48Spatrick SmallVector<StringRef, 4> LineSplit;
8009467b48Spatrick Line.split(LineSplit, ' ', /*MaxSplit=*/-1,
8109467b48Spatrick /*KeepEmpty=*/false);
8209467b48Spatrick if (LineSplit.empty())
8309467b48Spatrick continue;
8409467b48Spatrick if (LineSplit.size()!=2)
85*d415bd75Srobert report_fatal_error("Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'",
86*d415bd75Srobert /*GenCrashDiag=*/false);
8709467b48Spatrick SmallVector<StringRef, 4> BBNames;
8809467b48Spatrick LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1,
8909467b48Spatrick /*KeepEmpty=*/false);
9009467b48Spatrick if (BBNames.empty())
9109467b48Spatrick report_fatal_error("Missing bbs name");
92097a140dSpatrick BlocksByName.push_back(
93097a140dSpatrick {std::string(LineSplit[0]), {BBNames.begin(), BBNames.end()}});
9409467b48Spatrick }
9509467b48Spatrick }
9609467b48Spatrick
9709467b48Spatrick /// Extracts the landing pads to make sure all of them have only one
9809467b48Spatrick /// predecessor.
splitLandingPadPreds(Function & F)9909467b48Spatrick void BlockExtractor::splitLandingPadPreds(Function &F) {
10009467b48Spatrick for (BasicBlock &BB : F) {
10109467b48Spatrick for (Instruction &I : BB) {
10209467b48Spatrick if (!isa<InvokeInst>(&I))
10309467b48Spatrick continue;
10409467b48Spatrick InvokeInst *II = cast<InvokeInst>(&I);
10509467b48Spatrick BasicBlock *Parent = II->getParent();
10609467b48Spatrick BasicBlock *LPad = II->getUnwindDest();
10709467b48Spatrick
10809467b48Spatrick // Look through the landing pad's predecessors. If one of them ends in an
10909467b48Spatrick // 'invoke', then we want to split the landing pad.
11009467b48Spatrick bool Split = false;
111*d415bd75Srobert for (auto *PredBB : predecessors(LPad)) {
11209467b48Spatrick if (PredBB->isLandingPad() && PredBB != Parent &&
11309467b48Spatrick isa<InvokeInst>(Parent->getTerminator())) {
11409467b48Spatrick Split = true;
11509467b48Spatrick break;
11609467b48Spatrick }
11709467b48Spatrick }
11809467b48Spatrick
11909467b48Spatrick if (!Split)
12009467b48Spatrick continue;
12109467b48Spatrick
12209467b48Spatrick SmallVector<BasicBlock *, 2> NewBBs;
12309467b48Spatrick SplitLandingPadPredecessors(LPad, Parent, ".1", ".2", NewBBs);
12409467b48Spatrick }
12509467b48Spatrick }
12609467b48Spatrick }
12709467b48Spatrick
runOnModule(Module & M)12809467b48Spatrick bool BlockExtractor::runOnModule(Module &M) {
12909467b48Spatrick bool Changed = false;
13009467b48Spatrick
13109467b48Spatrick // Get all the functions.
13209467b48Spatrick SmallVector<Function *, 4> Functions;
13309467b48Spatrick for (Function &F : M) {
13409467b48Spatrick splitLandingPadPreds(F);
13509467b48Spatrick Functions.push_back(&F);
13609467b48Spatrick }
13709467b48Spatrick
13809467b48Spatrick // Get all the blocks specified in the input file.
13909467b48Spatrick unsigned NextGroupIdx = GroupsOfBlocks.size();
14009467b48Spatrick GroupsOfBlocks.resize(NextGroupIdx + BlocksByName.size());
14109467b48Spatrick for (const auto &BInfo : BlocksByName) {
14209467b48Spatrick Function *F = M.getFunction(BInfo.first);
14309467b48Spatrick if (!F)
144*d415bd75Srobert report_fatal_error("Invalid function name specified in the input file",
145*d415bd75Srobert /*GenCrashDiag=*/false);
14609467b48Spatrick for (const auto &BBInfo : BInfo.second) {
14709467b48Spatrick auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) {
14809467b48Spatrick return BB.getName().equals(BBInfo);
14909467b48Spatrick });
15009467b48Spatrick if (Res == F->end())
151*d415bd75Srobert report_fatal_error("Invalid block name specified in the input file",
152*d415bd75Srobert /*GenCrashDiag=*/false);
15309467b48Spatrick GroupsOfBlocks[NextGroupIdx].push_back(&*Res);
15409467b48Spatrick }
15509467b48Spatrick ++NextGroupIdx;
15609467b48Spatrick }
15709467b48Spatrick
15809467b48Spatrick // Extract each group of basic blocks.
15909467b48Spatrick for (auto &BBs : GroupsOfBlocks) {
16009467b48Spatrick SmallVector<BasicBlock *, 32> BlocksToExtractVec;
16109467b48Spatrick for (BasicBlock *BB : BBs) {
16209467b48Spatrick // Check if the module contains BB.
16309467b48Spatrick if (BB->getParent()->getParent() != &M)
164*d415bd75Srobert report_fatal_error("Invalid basic block", /*GenCrashDiag=*/false);
16509467b48Spatrick LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting "
16609467b48Spatrick << BB->getParent()->getName() << ":" << BB->getName()
16709467b48Spatrick << "\n");
16809467b48Spatrick BlocksToExtractVec.push_back(BB);
16909467b48Spatrick if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator()))
17009467b48Spatrick BlocksToExtractVec.push_back(II->getUnwindDest());
17109467b48Spatrick ++NumExtracted;
17209467b48Spatrick Changed = true;
17309467b48Spatrick }
17409467b48Spatrick CodeExtractorAnalysisCache CEAC(*BBs[0]->getParent());
17509467b48Spatrick Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(CEAC);
17609467b48Spatrick if (F)
17709467b48Spatrick LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName()
17809467b48Spatrick << "' in: " << F->getName() << '\n');
17909467b48Spatrick else
18009467b48Spatrick LLVM_DEBUG(dbgs() << "Failed to extract for group '"
18109467b48Spatrick << (*BBs.begin())->getName() << "'\n");
18209467b48Spatrick }
18309467b48Spatrick
18409467b48Spatrick // Erase the functions.
18509467b48Spatrick if (EraseFunctions || BlockExtractorEraseFuncs) {
18609467b48Spatrick for (Function *F : Functions) {
18709467b48Spatrick LLVM_DEBUG(dbgs() << "BlockExtractor: Trying to delete " << F->getName()
18809467b48Spatrick << "\n");
18909467b48Spatrick F->deleteBody();
19009467b48Spatrick }
19109467b48Spatrick // Set linkage as ExternalLinkage to avoid erasing unreachable functions.
19209467b48Spatrick for (Function &F : M)
19309467b48Spatrick F.setLinkage(GlobalValue::ExternalLinkage);
19409467b48Spatrick Changed = true;
19509467b48Spatrick }
19609467b48Spatrick
19709467b48Spatrick return Changed;
19809467b48Spatrick }
19973471bf0Spatrick
BlockExtractorPass(std::vector<std::vector<BasicBlock * >> && GroupsOfBlocks,bool EraseFunctions)200*d415bd75Srobert BlockExtractorPass::BlockExtractorPass(
201*d415bd75Srobert std::vector<std::vector<BasicBlock *>> &&GroupsOfBlocks,
202*d415bd75Srobert bool EraseFunctions)
203*d415bd75Srobert : GroupsOfBlocks(GroupsOfBlocks), EraseFunctions(EraseFunctions) {}
20473471bf0Spatrick
run(Module & M,ModuleAnalysisManager & AM)20573471bf0Spatrick PreservedAnalyses BlockExtractorPass::run(Module &M,
20673471bf0Spatrick ModuleAnalysisManager &AM) {
207*d415bd75Srobert BlockExtractor BE(EraseFunctions);
208*d415bd75Srobert BE.init(GroupsOfBlocks);
20973471bf0Spatrick return BE.runOnModule(M) ? PreservedAnalyses::none()
21073471bf0Spatrick : PreservedAnalyses::all();
21173471bf0Spatrick }
212