1 //===--------------- LLJITWithCustomObjectLinkingLayer.cpp ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
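// This file shows how to switch LLJIT to use a custom object linking layer (we
// use ObjectLinkingLayer, which is backed by JITLink, as an example) and how to
// attach a plugin to that layer which prints each link graph before and after
// the linker applies fixups.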
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "llvm/ADT/StringMap.h"
#include "llvm/ExecutionEngine/JITLink/JITLink.h"
#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h"
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"

#include "../ExampleModules.h"

#include <algorithm>
#include <cstdint>
#include <vector>
24 
25 using namespace llvm;
26 using namespace llvm::orc;
27 
28 ExitOnError ExitOnErr;
29 
30 const llvm::StringRef TestMod =
31     R"(
32   define i32 @callee() {
33   entry:
34     ret i32 7
35   }
36 
37   define i32 @entry() {
38   entry:
39     %0 = call i32 @callee()
40     ret i32 %0
41   }
42 )";
43 
44 class MyPlugin : public ObjectLinkingLayer::Plugin {
45 public:
46   // The modifyPassConfig callback gives us a chance to inspect the
47   // MaterializationResponsibility and target triple for the object being
48   // linked, then add any JITLink passes that we would like to run on the
49   // link graph. A pass is just a function object that is callable as
50   // Error(jitlink::LinkGraph&). In this case we will add two passes
51   // defined as lambdas that call the printLinkerGraph method on our
52   // plugin: One to run before the linker applies fixups and another to
53   // run afterwards.
54   void modifyPassConfig(MaterializationResponsibility &MR, const Triple &TT,
55                         jitlink::PassConfiguration &Config) override {
56     Config.PostPrunePasses.push_back([this](jitlink::LinkGraph &G) -> Error {
57       printLinkGraph(G, "Before fixup:");
58       return Error::success();
59     });
60     Config.PostFixupPasses.push_back([this](jitlink::LinkGraph &G) -> Error {
61       printLinkGraph(G, "After fixup:");
62       return Error::success();
63     });
64   }
65 
66   void notifyLoaded(MaterializationResponsibility &MR) override {
67     dbgs() << "Loading object defining " << MR.getSymbols() << "\n";
68   }
69 
70   Error notifyEmitted(MaterializationResponsibility &MR) override {
71     dbgs() << "Emitted object defining " << MR.getSymbols() << "\n";
72     return Error::success();
73   }
74 
75   Error notifyFailed(MaterializationResponsibility &MR) override {
76     return Error::success();
77   }
78 
79   Error notifyRemovingResources(ResourceKey K) override {
80     return Error::success();
81   }
82 
83   void notifyTransferringResources(ResourceKey DstKey,
84                                    ResourceKey SrcKey) override {}
85 
86 private:
87   void printLinkGraph(jitlink::LinkGraph &G, StringRef Title) {
88     constexpr JITTargetAddress LineWidth = 16;
89 
90     dbgs() << "--- " << Title << "---\n";
91     for (auto &S : G.sections()) {
92       dbgs() << "  section: " << S.getName() << "\n";
93       for (auto *B : S.blocks()) {
94         dbgs() << "    block@" << formatv("{0:x16}", B->getAddress()) << ":\n";
95 
96         if (B->isZeroFill())
97           continue;
98 
99         JITTargetAddress InitAddr = B->getAddress() & ~(LineWidth - 1);
100         JITTargetAddress StartAddr = B->getAddress();
101         JITTargetAddress EndAddr = B->getAddress() + B->getSize();
102         auto *Data = reinterpret_cast<const uint8_t *>(B->getContent().data());
103 
104         for (JITTargetAddress CurAddr = InitAddr; CurAddr != EndAddr;
105              ++CurAddr) {
106           if (CurAddr % LineWidth == 0)
107             dbgs() << "    " << formatv("{0:x16}", CurAddr) << ": ";
108           if (CurAddr < StartAddr)
109             dbgs() << "   ";
110           else
111             dbgs() << formatv("{0:x-2}", Data[CurAddr - StartAddr]) << " ";
112           if (CurAddr % LineWidth == LineWidth - 1)
113             dbgs() << "\n";
114         }
115         if (EndAddr % LineWidth != 0)
116           dbgs() << "\n";
117         dbgs() << "\n";
118       }
119     }
120   }
121 };
122 
123 int main(int argc, char *argv[]) {
124   // Initialize LLVM.
125   InitLLVM X(argc, argv);
126 
127   InitializeNativeTarget();
128   InitializeNativeTargetAsmPrinter();
129 
130   cl::ParseCommandLineOptions(argc, argv, "LLJITWithObjectLinkingLayerPlugin");
131   ExitOnErr.setBanner(std::string(argv[0]) + ": ");
132 
133   // Detect the host and set code model to small.
134   auto JTMB = ExitOnErr(JITTargetMachineBuilder::detectHost());
135   JTMB.setCodeModel(CodeModel::Small);
136 
137   // Create an LLJIT instance with an ObjectLinkingLayer as the base layer.
138   // We attach our plugin in to the newly created ObjectLinkingLayer before
139   // returning it.
140   auto J = ExitOnErr(
141       LLJITBuilder()
142           .setJITTargetMachineBuilder(std::move(JTMB))
143           .setObjectLinkingLayerCreator(
144               [&](ExecutionSession &ES, const Triple &TT) {
145                 // Create ObjectLinkingLayer.
146                 auto ObjLinkingLayer = std::make_unique<ObjectLinkingLayer>(
147                     ES, std::make_unique<jitlink::InProcessMemoryManager>());
148                 // Add an instance of our plugin.
149                 ObjLinkingLayer->addPlugin(std::make_unique<MyPlugin>());
150                 return ObjLinkingLayer;
151               })
152           .create());
153 
154   auto M = ExitOnErr(parseExampleModule(TestMod, "test-module"));
155 
156   ExitOnErr(J->addIRModule(std::move(M)));
157 
158   // Look up the JIT'd function, cast it to a function pointer, then call it.
159   auto EntrySym = ExitOnErr(J->lookup("entry"));
160   auto *Entry = (int (*)())EntrySym.getAddress();
161 
162   int Result = Entry();
163   outs() << "---Result---\n"
164          << "entry() = " << Result << "\n";
165 
166   return 0;
167 }
168