1 //===- llvm-extract.cpp - LLVM function extraction utility ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This utility changes the input module to only contain a single function, 10 // which is primarily used for debugging transformations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/ADT/SetVector.h" 15 #include "llvm/ADT/SmallPtrSet.h" 16 #include "llvm/Bitcode/BitcodeWriterPass.h" 17 #include "llvm/IR/DataLayout.h" 18 #include "llvm/IR/IRPrintingPasses.h" 19 #include "llvm/IR/Instructions.h" 20 #include "llvm/IR/LLVMContext.h" 21 #include "llvm/IR/LegacyPassManager.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/IRReader/IRReader.h" 24 #include "llvm/Pass.h" 25 #include "llvm/Support/CommandLine.h" 26 #include "llvm/Support/Error.h" 27 #include "llvm/Support/FileSystem.h" 28 #include "llvm/Support/InitLLVM.h" 29 #include "llvm/Support/Regex.h" 30 #include "llvm/Support/SourceMgr.h" 31 #include "llvm/Support/SystemUtils.h" 32 #include "llvm/Support/ToolOutputFile.h" 33 #include "llvm/Transforms/IPO.h" 34 #include <memory> 35 #include <utility> 36 using namespace llvm; 37 38 cl::OptionCategory ExtractCat("llvm-extract Options"); 39 40 // InputFilename - The filename to read from. 41 static cl::opt<std::string> InputFilename(cl::Positional, 42 cl::desc("<input bitcode file>"), 43 cl::init("-"), 44 cl::value_desc("filename")); 45 46 static cl::opt<std::string> OutputFilename("o", 47 cl::desc("Specify output filename"), 48 cl::value_desc("filename"), 49 cl::init("-"), cl::cat(ExtractCat)); 50 51 static cl::opt<bool> Force("f", cl::desc("Enable binary output on terminals"), 52 cl::cat(ExtractCat)); 53 54 static cl::opt<bool> DeleteFn("delete", 55 cl::desc("Delete specified Globals from Module"), 56 cl::cat(ExtractCat)); 57 58 static cl::opt<bool> KeepConstInit("keep-const-init", 59 cl::desc("Keep initializers of constants"), 60 cl::cat(ExtractCat)); 61 62 static cl::opt<bool> 63 Recursive("recursive", cl::desc("Recursively extract all called functions"), 64 cl::cat(ExtractCat)); 65 66 // ExtractFuncs - The functions to extract from the module. 67 static cl::list<std::string> 68 ExtractFuncs("func", cl::desc("Specify function to extract"), 69 cl::value_desc("function"), cl::cat(ExtractCat)); 70 71 // ExtractRegExpFuncs - The functions, matched via regular expression, to 72 // extract from the module. 73 static cl::list<std::string> 74 ExtractRegExpFuncs("rfunc", 75 cl::desc("Specify function(s) to extract using a " 76 "regular expression"), 77 cl::value_desc("rfunction"), cl::cat(ExtractCat)); 78 79 // ExtractBlocks - The blocks to extract from the module. 80 static cl::list<std::string> ExtractBlocks( 81 "bb", 82 cl::desc( 83 "Specify <function, basic block1[;basic block2...]> pairs to extract.\n" 84 "Each pair will create a function.\n" 85 "If multiple basic blocks are specified in one pair,\n" 86 "the first block in the sequence should dominate the rest.\n" 87 "eg:\n" 88 " --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n" 89 " --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one " 90 "with bb2."), 91 cl::ZeroOrMore, cl::value_desc("function:bb1[;bb2...]"), 92 cl::cat(ExtractCat)); 93 94 // ExtractAlias - The alias to extract from the module. 95 static cl::list<std::string> 96 ExtractAliases("alias", cl::desc("Specify alias to extract"), 97 cl::value_desc("alias"), cl::cat(ExtractCat)); 98 99 // ExtractRegExpAliases - The aliases, matched via regular expression, to 100 // extract from the module. 101 static cl::list<std::string> 102 ExtractRegExpAliases("ralias", 103 cl::desc("Specify alias(es) to extract using a " 104 "regular expression"), 105 cl::value_desc("ralias"), cl::cat(ExtractCat)); 106 107 // ExtractGlobals - The globals to extract from the module. 108 static cl::list<std::string> 109 ExtractGlobals("glob", cl::desc("Specify global to extract"), 110 cl::value_desc("global"), cl::cat(ExtractCat)); 111 112 // ExtractRegExpGlobals - The globals, matched via regular expression, to 113 // extract from the module... 114 static cl::list<std::string> 115 ExtractRegExpGlobals("rglob", 116 cl::desc("Specify global(s) to extract using a " 117 "regular expression"), 118 cl::value_desc("rglobal"), cl::cat(ExtractCat)); 119 120 static cl::opt<bool> OutputAssembly("S", 121 cl::desc("Write output as LLVM assembly"), 122 cl::Hidden, cl::cat(ExtractCat)); 123 124 static cl::opt<bool> PreserveBitcodeUseListOrder( 125 "preserve-bc-uselistorder", 126 cl::desc("Preserve use-list order when writing LLVM bitcode."), 127 cl::init(true), cl::Hidden, cl::cat(ExtractCat)); 128 129 static cl::opt<bool> PreserveAssemblyUseListOrder( 130 "preserve-ll-uselistorder", 131 cl::desc("Preserve use-list order when writing LLVM assembly."), 132 cl::init(false), cl::Hidden, cl::cat(ExtractCat)); 133 134 int main(int argc, char **argv) { 135 InitLLVM X(argc, argv); 136 137 LLVMContext Context; 138 cl::HideUnrelatedOptions(ExtractCat); 139 cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n"); 140 141 // Use lazy loading, since we only care about selected global values. 142 SMDiagnostic Err; 143 std::unique_ptr<Module> M = getLazyIRFileModule(InputFilename, Err, Context); 144 145 if (!M.get()) { 146 Err.print(argv[0], errs()); 147 return 1; 148 } 149 150 // Use SetVector to avoid duplicates. 151 SetVector<GlobalValue *> GVs; 152 153 // Figure out which aliases we should extract. 154 for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) { 155 GlobalAlias *GA = M->getNamedAlias(ExtractAliases[i]); 156 if (!GA) { 157 errs() << argv[0] << ": program doesn't contain alias named '" 158 << ExtractAliases[i] << "'!\n"; 159 return 1; 160 } 161 GVs.insert(GA); 162 } 163 164 // Extract aliases via regular expression matching. 165 for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) { 166 std::string Error; 167 Regex RegEx(ExtractRegExpAliases[i]); 168 if (!RegEx.isValid(Error)) { 169 errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' " 170 "invalid regex: " << Error; 171 } 172 bool match = false; 173 for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end(); 174 GA != E; GA++) { 175 if (RegEx.match(GA->getName())) { 176 GVs.insert(&*GA); 177 match = true; 178 } 179 } 180 if (!match) { 181 errs() << argv[0] << ": program doesn't contain global named '" 182 << ExtractRegExpAliases[i] << "'!\n"; 183 return 1; 184 } 185 } 186 187 // Figure out which globals we should extract. 188 for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) { 189 GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]); 190 if (!GV) { 191 errs() << argv[0] << ": program doesn't contain global named '" 192 << ExtractGlobals[i] << "'!\n"; 193 return 1; 194 } 195 GVs.insert(GV); 196 } 197 198 // Extract globals via regular expression matching. 199 for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) { 200 std::string Error; 201 Regex RegEx(ExtractRegExpGlobals[i]); 202 if (!RegEx.isValid(Error)) { 203 errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' " 204 "invalid regex: " << Error; 205 } 206 bool match = false; 207 for (auto &GV : M->globals()) { 208 if (RegEx.match(GV.getName())) { 209 GVs.insert(&GV); 210 match = true; 211 } 212 } 213 if (!match) { 214 errs() << argv[0] << ": program doesn't contain global named '" 215 << ExtractRegExpGlobals[i] << "'!\n"; 216 return 1; 217 } 218 } 219 220 // Figure out which functions we should extract. 221 for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) { 222 GlobalValue *GV = M->getFunction(ExtractFuncs[i]); 223 if (!GV) { 224 errs() << argv[0] << ": program doesn't contain function named '" 225 << ExtractFuncs[i] << "'!\n"; 226 return 1; 227 } 228 GVs.insert(GV); 229 } 230 // Extract functions via regular expression matching. 231 for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) { 232 std::string Error; 233 StringRef RegExStr = ExtractRegExpFuncs[i]; 234 Regex RegEx(RegExStr); 235 if (!RegEx.isValid(Error)) { 236 errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' " 237 "invalid regex: " << Error; 238 } 239 bool match = false; 240 for (Module::iterator F = M->begin(), E = M->end(); F != E; 241 F++) { 242 if (RegEx.match(F->getName())) { 243 GVs.insert(&*F); 244 match = true; 245 } 246 } 247 if (!match) { 248 errs() << argv[0] << ": program doesn't contain global named '" 249 << ExtractRegExpFuncs[i] << "'!\n"; 250 return 1; 251 } 252 } 253 254 // Figure out which BasicBlocks we should extract. 255 SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap; 256 for (StringRef StrPair : ExtractBlocks) { 257 SmallVector<StringRef, 16> BBNames; 258 auto BBInfo = StrPair.split(':'); 259 // Get the function. 260 Function *F = M->getFunction(BBInfo.first); 261 if (!F) { 262 errs() << argv[0] << ": program doesn't contain a function named '" 263 << BBInfo.first << "'!\n"; 264 return 1; 265 } 266 // Add the function to the materialize list, and store the basic block names 267 // to check after materialization. 268 GVs.insert(F); 269 BBInfo.second.split(BBNames, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false); 270 BBMap.push_back({F, std::move(BBNames)}); 271 } 272 273 // Use *argv instead of argv[0] to work around a wrong GCC warning. 274 ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: "); 275 276 if (Recursive) { 277 std::vector<llvm::Function *> Workqueue; 278 for (GlobalValue *GV : GVs) { 279 if (auto *F = dyn_cast<Function>(GV)) { 280 Workqueue.push_back(F); 281 } 282 } 283 while (!Workqueue.empty()) { 284 Function *F = &*Workqueue.back(); 285 Workqueue.pop_back(); 286 ExitOnErr(F->materialize()); 287 for (auto &BB : *F) { 288 for (auto &I : BB) { 289 CallBase *CB = dyn_cast<CallBase>(&I); 290 if (!CB) 291 continue; 292 Function *CF = CB->getCalledFunction(); 293 if (!CF) 294 continue; 295 if (CF->isDeclaration() || GVs.count(CF)) 296 continue; 297 GVs.insert(CF); 298 Workqueue.push_back(CF); 299 } 300 } 301 } 302 } 303 304 auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); }; 305 306 // Materialize requisite global values. 307 if (!DeleteFn) { 308 for (size_t i = 0, e = GVs.size(); i != e; ++i) 309 Materialize(*GVs[i]); 310 } else { 311 // Deleting. Materialize every GV that's *not* in GVs. 312 SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end()); 313 for (auto &F : *M) { 314 if (!GVSet.count(&F)) 315 Materialize(F); 316 } 317 } 318 319 { 320 std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end()); 321 legacy::PassManager Extract; 322 Extract.add(createGVExtractionPass(Gvs, DeleteFn, KeepConstInit)); 323 Extract.run(*M); 324 325 // Now that we have all the GVs we want, mark the module as fully 326 // materialized. 327 // FIXME: should the GVExtractionPass handle this? 328 ExitOnErr(M->materializeAll()); 329 } 330 331 // Extract the specified basic blocks from the module and erase the existing 332 // functions. 333 if (!ExtractBlocks.empty()) { 334 // Figure out which BasicBlocks we should extract. 335 SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupOfBBs; 336 for (auto &P : BBMap) { 337 SmallVector<BasicBlock *, 16> BBs; 338 for (StringRef BBName : P.second) { 339 // The function has been materialized, so add its matching basic blocks 340 // to the block extractor list, or fail if a name is not found. 341 auto Res = llvm::find_if(*P.first, [&](const BasicBlock &BB) { 342 return BB.getName().equals(BBName); 343 }); 344 if (Res == P.first->end()) { 345 errs() << argv[0] << ": function " << P.first->getName() 346 << " doesn't contain a basic block named '" << BBName 347 << "'!\n"; 348 return 1; 349 } 350 BBs.push_back(&*Res); 351 } 352 GroupOfBBs.push_back(BBs); 353 } 354 355 legacy::PassManager PM; 356 PM.add(createBlockExtractorPass(GroupOfBBs, true)); 357 PM.run(*M); 358 } 359 360 // In addition to deleting all other functions, we also want to spiff it 361 // up a little bit. Do this now. 362 legacy::PassManager Passes; 363 364 if (!DeleteFn) 365 Passes.add(createGlobalDCEPass()); // Delete unreachable globals 366 Passes.add(createStripDeadDebugInfoPass()); // Remove dead debug info 367 Passes.add(createStripDeadPrototypesPass()); // Remove dead func decls 368 369 std::error_code EC; 370 ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None); 371 if (EC) { 372 errs() << EC.message() << '\n'; 373 return 1; 374 } 375 376 if (OutputAssembly) 377 Passes.add( 378 createPrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder)); 379 else if (Force || !CheckBitcodeOutputToConsole(Out.os())) 380 Passes.add(createBitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder)); 381 382 Passes.run(*M.get()); 383 384 // Declare success. 385 Out.keep(); 386 387 return 0; 388 } 389