1 //===--------------- LLJITWithCustomObjectLinkingLayer.cpp ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
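// This file shows how to switch LLJIT to use a custom object linking layer (we
// use ObjectLinkingLayer, which is backed by JITLink, as an example), and how
// to attach an ObjectLinkingLayer::Plugin that logs each object as it is
// loaded and emitted and dumps its link graph before and after fixups.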
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/StringMap.h"
15 #include "llvm/ExecutionEngine/JITLink/JITLink.h"
16 #include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h"
17 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
18 #include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
19 #include "llvm/Support/InitLLVM.h"
20 #include "llvm/Support/TargetSelect.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 #include "../ExampleModules.h"
24 
25 using namespace llvm;
26 using namespace llvm::orc;
27 
// File-wide error handler. Wrapping an Error/Expected-returning call in
// ExitOnErr(...) yields the success value; on failure the error is reported
// (prefixed with the banner installed in main) and the program exits.
ExitOnError ExitOnErr;

// Textual LLVM IR for the module that main JIT-compiles and runs: @entry calls
// @callee, which returns 7, so entry() is expected to return 7.
const llvm::StringRef TestMod =
    R"(
  define i32 @callee() {
  entry:
    ret i32 7
  }

  define i32 @entry() {
  entry:
    %0 = call i32 @callee()
    ret i32 %0
  }
)";
43 
44 class MyPlugin : public ObjectLinkingLayer::Plugin {
45 public:
46   // The modifyPassConfig callback gives us a chance to inspect the
47   // MaterializationResponsibility and target triple for the object being
48   // linked, then add any JITLink passes that we would like to run on the
49   // link graph. A pass is just a function object that is callable as
50   // Error(jitlink::LinkGraph&). In this case we will add two passes
51   // defined as lambdas that call the printLinkerGraph method on our
52   // plugin: One to run before the linker applies fixups and another to
53   // run afterwards.
54   void modifyPassConfig(MaterializationResponsibility &MR, const Triple &TT,
55                         jitlink::PassConfiguration &Config) override {
56     Config.PostPrunePasses.push_back([this](jitlink::LinkGraph &G) -> Error {
57       printLinkGraph(G, "Before fixup:");
58       return Error::success();
59     });
60     Config.PostFixupPasses.push_back([this](jitlink::LinkGraph &G) -> Error {
61       printLinkGraph(G, "After fixup:");
62       return Error::success();
63     });
64   }
65 
66   void notifyLoaded(MaterializationResponsibility &MR) override {
67     dbgs() << "Loading object defining " << MR.getSymbols() << "\n";
68   }
69 
70   Error notifyEmitted(MaterializationResponsibility &MR) override {
71     dbgs() << "Emitted object defining " << MR.getSymbols() << "\n";
72     return Error::success();
73   }
74 
75 private:
76   void printLinkGraph(jitlink::LinkGraph &G, StringRef Title) {
77     constexpr JITTargetAddress LineWidth = 16;
78 
79     dbgs() << "--- " << Title << "---\n";
80     for (auto &S : G.sections()) {
81       dbgs() << "  section: " << S.getName() << "\n";
82       for (auto *B : S.blocks()) {
83         dbgs() << "    block@" << formatv("{0:x16}", B->getAddress()) << ":\n";
84 
85         if (B->isZeroFill())
86           continue;
87 
88         JITTargetAddress InitAddr = B->getAddress() & ~(LineWidth - 1);
89         JITTargetAddress StartAddr = B->getAddress();
90         JITTargetAddress EndAddr = B->getAddress() + B->getSize();
91         auto *Data = reinterpret_cast<const uint8_t *>(B->getContent().data());
92 
93         for (JITTargetAddress CurAddr = InitAddr; CurAddr != EndAddr;
94              ++CurAddr) {
95           if (CurAddr % LineWidth == 0)
96             dbgs() << "    " << formatv("{0:x16}", CurAddr) << ": ";
97           if (CurAddr < StartAddr)
98             dbgs() << "   ";
99           else
100             dbgs() << formatv("{0:x-2}", Data[CurAddr - StartAddr]) << " ";
101           if (CurAddr % LineWidth == LineWidth - 1)
102             dbgs() << "\n";
103         }
104         if (EndAddr % LineWidth != 0)
105           dbgs() << "\n";
106         dbgs() << "\n";
107       }
108     }
109   }
110 };
111 
112 int main(int argc, char *argv[]) {
113   // Initialize LLVM.
114   InitLLVM X(argc, argv);
115 
116   InitializeNativeTarget();
117   InitializeNativeTargetAsmPrinter();
118 
119   cl::ParseCommandLineOptions(argc, argv, "LLJITWithObjectLinkingLayerPlugin");
120   ExitOnErr.setBanner(std::string(argv[0]) + ": ");
121 
122   // Detect the host and set code model to small.
123   auto JTMB = ExitOnErr(JITTargetMachineBuilder::detectHost());
124   JTMB.setCodeModel(CodeModel::Small);
125 
126   // Create an LLJIT instance with an ObjectLinkingLayer as the base layer.
127   // We attach our plugin in to the newly created ObjectLinkingLayer before
128   // returning it.
129   auto J = ExitOnErr(
130       LLJITBuilder()
131           .setJITTargetMachineBuilder(std::move(JTMB))
132           .setObjectLinkingLayerCreator(
133               [&](ExecutionSession &ES, const Triple &TT) {
134                 // Create ObjectLinkingLayer.
135                 auto ObjLinkingLayer = std::make_unique<ObjectLinkingLayer>(
136                     ES, std::make_unique<jitlink::InProcessMemoryManager>());
137                 // Add an instance of our plugin.
138                 ObjLinkingLayer->addPlugin(std::make_unique<MyPlugin>());
139                 return ObjLinkingLayer;
140               })
141           .create());
142 
143   auto M = ExitOnErr(parseExampleModule(TestMod, "test-module"));
144 
145   ExitOnErr(J->addIRModule(std::move(M)));
146 
147   // Look up the JIT'd function, cast it to a function pointer, then call it.
148   auto EntrySym = ExitOnErr(J->lookup("entry"));
149   auto *Entry = (int (*)())EntrySym.getAddress();
150 
151   int Result = Entry();
152   outs() << "---Result---\n"
153          << "entry() = " << Result << "\n";
154 
155   return 0;
156 }
157