//===-- LLJITWithOptimizingIRTransform.cpp -- LLJIT with IR optimization --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// In this example we will use an IR transform to optimize a module as it
// passes through LLJIT's IRTransformLayer.
//
//===----------------------------------------------------------------------===//

#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Scalar.h"

#include "../ExampleModules.h"

using namespace llvm;
using namespace llvm::orc;

// Tool-style error helper used throughout main().
//
// Wrapping an llvm::Error or Expected<T> in ExitOnErr(...) yields the
// contained value on success. Per llvm::ExitOnError's documented behavior, a
// failure is reported to stderr, prefixed by the banner that main() installs
// via setBanner(), and the process exits. This is appropriate for a
// standalone example tool but not for library code.
ExitOnError ExitOnErr;

// Example IR module.
//
// This IR contains a recursive definition of the factorial function:
//
//   fac(n) | n == 0    = 1
//          | otherwise = n * fac(n - 1)
//
// It also contains an entry function which calls the factorial function with
// an input value of 5.
//
// We expect the IR optimization transform that we build below to transform
// this into a non-recursive factorial function and an entry function that
// returns a constant value of 5!, or 120.
43 44 const llvm::StringRef MainMod = 45 R"( 46 47 define i32 @fac(i32 %n) { 48 entry: 49 %tobool = icmp eq i32 %n, 0 50 br i1 %tobool, label %return, label %if.then 51 52 if.then: ; preds = %entry 53 %arg = add nsw i32 %n, -1 54 %call_result = call i32 @fac(i32 %arg) 55 %result = mul nsw i32 %n, %call_result 56 br label %return 57 58 return: ; preds = %entry, %if.then 59 %final_result = phi i32 [ %result, %if.then ], [ 1, %entry ] 60 ret i32 %final_result 61 } 62 63 define i32 @entry() { 64 entry: 65 %result = call i32 @fac(i32 5) 66 ret i32 %result 67 } 68 69 )"; 70 71 // A function object that creates a simple pass pipeline to apply to each 72 // module as it passes through the IRTransformLayer. 73 class MyOptimizationTransform { 74 public: 75 MyOptimizationTransform() : PM(std::make_unique<legacy::PassManager>()) { 76 PM->add(createTailCallEliminationPass()); 77 PM->add(createFunctionInliningPass()); 78 PM->add(createIndVarSimplifyPass()); 79 PM->add(createCFGSimplificationPass()); 80 } 81 82 Expected<ThreadSafeModule> operator()(ThreadSafeModule TSM, 83 MaterializationResponsibility &R) { 84 TSM.withModuleDo([this](Module &M) { 85 dbgs() << "--- BEFORE OPTIMIZATION ---\n" << M << "\n"; 86 PM->run(M); 87 dbgs() << "--- AFTER OPTIMIZATION ---\n" << M << "\n"; 88 }); 89 return std::move(TSM); 90 } 91 92 private: 93 std::unique_ptr<legacy::PassManager> PM; 94 }; 95 96 int main(int argc, char *argv[]) { 97 // Initialize LLVM. 98 InitLLVM X(argc, argv); 99 100 InitializeNativeTarget(); 101 InitializeNativeTargetAsmPrinter(); 102 103 ExitOnErr.setBanner(std::string(argv[0]) + ": "); 104 105 // (1) Create LLJIT instance. 106 auto J = ExitOnErr(LLJITBuilder().create()); 107 108 // (2) Install transform to optimize modules when they're materialized. 109 J->getIRTransformLayer().setTransform(MyOptimizationTransform()); 110 111 // (3) Add modules. 
112 ExitOnErr(J->addIRModule(ExitOnErr(parseExampleModule(MainMod, "MainMod")))); 113 114 // (4) Look up the JIT'd function and call it. 115 auto EntrySym = ExitOnErr(J->lookup("entry")); 116 auto *Entry = (int (*)())EntrySym.getAddress(); 117 118 int Result = Entry(); 119 outs() << "--- Result ---\n" 120 << "entry() = " << Result << "\n"; 121 122 return 0; 123 } 124