xref: /llvm-project/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp (revision 78b083dbb725e1ec568d1b8ee523f5f025d25798)
1 //===------- SimpleRemoteEPC.cpp -- Simple remote executor control --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h"
10 #include "llvm/ExecutionEngine/Orc/EPCGenericJITLinkMemoryManager.h"
11 #include "llvm/ExecutionEngine/Orc/EPCGenericMemoryAccess.h"
12 #include "llvm/ExecutionEngine/Orc/Shared/OrcRTBridge.h"
13 #include "llvm/Support/FormatVariadic.h"
14 
15 #define DEBUG_TYPE "orc"
16 
17 namespace llvm {
18 namespace orc {
19 namespace shared {
20 
21 template <>
22 class SPSSerializationTraits<SPSRemoteSymbolLookupSetElement,
23                              SymbolLookupSet::value_type> {
24 public:
25   static size_t size(const SymbolLookupSet::value_type &V) {
26     return SPSArgList<SPSString, bool>::size(
27         *V.first, V.second == SymbolLookupFlags::RequiredSymbol);
28   }
29 
30   static bool serialize(SPSOutputBuffer &OB,
31                         const SymbolLookupSet::value_type &V) {
32     return SPSArgList<SPSString, bool>::serialize(
33         OB, *V.first, V.second == SymbolLookupFlags::RequiredSymbol);
34   }
35 };
36 
37 template <>
38 class TrivialSPSSequenceSerialization<SPSRemoteSymbolLookupSetElement,
39                                       SymbolLookupSet> {
40 public:
41   static constexpr bool available = true;
42 };
43 
44 template <>
45 class SPSSerializationTraits<SPSRemoteSymbolLookup,
46                              ExecutorProcessControl::LookupRequest> {
47   using MemberSerialization =
48       SPSArgList<SPSExecutorAddress, SPSRemoteSymbolLookupSet>;
49 
50 public:
51   static size_t size(const ExecutorProcessControl::LookupRequest &LR) {
52     return MemberSerialization::size(ExecutorAddress(LR.Handle), LR.Symbols);
53   }
54 
55   static bool serialize(SPSOutputBuffer &OB,
56                         const ExecutorProcessControl::LookupRequest &LR) {
57     return MemberSerialization::serialize(OB, ExecutorAddress(LR.Handle),
58                                           LR.Symbols);
59   }
60 };
61 
62 } // end namespace shared
63 
64 SimpleRemoteEPC::~SimpleRemoteEPC() {
65   assert(Disconnected && "Destroyed without disconnection");
66 }
67 
68 Expected<tpctypes::DylibHandle>
69 SimpleRemoteEPC::loadDylib(const char *DylibPath) {
70   Expected<tpctypes::DylibHandle> H((tpctypes::DylibHandle()));
71   if (auto Err = callSPSWrapper<shared::SPSLoadDylibSignature>(
72           LoadDylibAddr.getValue(), H, JDI.JITDispatchContextAddress,
73           StringRef(DylibPath), (uint64_t)0))
74     return std::move(Err);
75   return H;
76 }
77 
78 Expected<std::vector<tpctypes::LookupResult>>
79 SimpleRemoteEPC::lookupSymbols(ArrayRef<LookupRequest> Request) {
80   Expected<std::vector<tpctypes::LookupResult>> R(
81       (std::vector<tpctypes::LookupResult>()));
82 
83   if (auto Err = callSPSWrapper<shared::SPSLookupSymbolsSignature>(
84           LookupSymbolsAddr.getValue(), R, JDI.JITDispatchContextAddress,
85           Request))
86     return std::move(Err);
87   return R;
88 }
89 
90 Expected<int32_t> SimpleRemoteEPC::runAsMain(JITTargetAddress MainFnAddr,
91                                              ArrayRef<std::string> Args) {
92   int64_t Result = 0;
93   if (auto Err = callSPSWrapper<rt::SPSRunAsMainSignature>(
94           RunAsMainAddr.getValue(), Result, ExecutorAddress(MainFnAddr), Args))
95     return std::move(Err);
96   return Result;
97 }
98 
99 void SimpleRemoteEPC::callWrapperAsync(SendResultFunction OnComplete,
100                                        JITTargetAddress WrapperFnAddr,
101                                        ArrayRef<char> ArgBuffer) {
102   uint64_t SeqNo;
103   {
104     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
105     SeqNo = getNextSeqNo();
106     assert(!PendingCallWrapperResults.count(SeqNo) && "SeqNo already in use");
107     PendingCallWrapperResults[SeqNo] = std::move(OnComplete);
108   }
109 
110   if (auto Err = T->sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
111                                 ExecutorAddress(WrapperFnAddr), ArgBuffer)) {
112     getExecutionSession().reportError(std::move(Err));
113   }
114 }
115 
116 Error SimpleRemoteEPC::disconnect() {
117   Disconnected = true;
118   T->disconnect();
119   return Error::success();
120 }
121 
122 Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
123 SimpleRemoteEPC::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
124                                ExecutorAddress TagAddr,
125                                SimpleRemoteEPCArgBytesVector ArgBytes) {
126   using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
127   if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
128     return make_error<StringError>("Unexpected opcode",
129                                    inconvertibleErrorCode());
130 
131   switch (OpC) {
132   case SimpleRemoteEPCOpcode::Setup:
133     if (auto Err = handleSetup(SeqNo, TagAddr, std::move(ArgBytes)))
134       return std::move(Err);
135     break;
136   case SimpleRemoteEPCOpcode::Hangup:
137     // FIXME: Put EPC into 'detached' state.
138     return SimpleRemoteEPCTransportClient::EndSession;
139   case SimpleRemoteEPCOpcode::Result:
140     if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
141       return std::move(Err);
142     break;
143   case SimpleRemoteEPCOpcode::CallWrapper:
144     handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
145     break;
146   }
147   return ContinueSession;
148 }
149 
150 void SimpleRemoteEPC::handleDisconnect(Error Err) {
151   PendingCallWrapperResultsMap TmpPending;
152 
153   {
154     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
155     std::swap(TmpPending, PendingCallWrapperResults);
156   }
157 
158   for (auto &KV : TmpPending)
159     KV.second(
160         shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
161 
162   if (Err) {
163     // FIXME: Move ReportError to EPC.
164     if (ES)
165       ES->reportError(std::move(Err));
166     else
167       logAllUnhandledErrors(std::move(Err), errs(), "SimpleRemoteEPC: ");
168   }
169 }
170 
171 Expected<std::unique_ptr<jitlink::JITLinkMemoryManager>>
172 SimpleRemoteEPC::createMemoryManager() {
173   EPCGenericJITLinkMemoryManager::SymbolAddrs SAs;
174   if (auto Err = getBootstrapSymbols(
175           {{SAs.Allocator, rt::SimpleExecutorMemoryManagerInstanceName},
176            {SAs.Reserve, rt::SimpleExecutorMemoryManagerReserveWrapperName},
177            {SAs.Finalize, rt::SimpleExecutorMemoryManagerFinalizeWrapperName},
178            {SAs.Deallocate,
179             rt::SimpleExecutorMemoryManagerDeallocateWrapperName}}))
180     return std::move(Err);
181 
182   return std::make_unique<EPCGenericJITLinkMemoryManager>(*this, SAs);
183 }
184 
185 Expected<std::unique_ptr<ExecutorProcessControl::MemoryAccess>>
186 SimpleRemoteEPC::createMemoryAccess() {
187 
188   return nullptr;
189 }
190 
191 Error SimpleRemoteEPC::handleSetup(uint64_t SeqNo, ExecutorAddress TagAddr,
192                                    SimpleRemoteEPCArgBytesVector ArgBytes) {
193   if (SeqNo != 0)
194     return make_error<StringError>("Setup packet SeqNo not zero",
195                                    inconvertibleErrorCode());
196 
197   if (TagAddr)
198     return make_error<StringError>("Setup packet TagAddr not zero",
199                                    inconvertibleErrorCode());
200 
201   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
202   auto I = PendingCallWrapperResults.find(0);
203   assert(PendingCallWrapperResults.size() == 1 &&
204          I != PendingCallWrapperResults.end() &&
205          "Setup message handler not connectly set up");
206   auto SetupMsgHandler = std::move(I->second);
207   PendingCallWrapperResults.erase(I);
208 
209   auto WFR =
210       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
211   SetupMsgHandler(std::move(WFR));
212   return Error::success();
213 }
214 
215 void SimpleRemoteEPC::prepareToReceiveSetupMessage(
216     std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> &ExecInfoP) {
217   PendingCallWrapperResults[0] =
218       [&](shared::WrapperFunctionResult SetupMsgBytes) {
219         if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {
220           ExecInfoP.set_value(
221               make_error<StringError>(ErrMsg, inconvertibleErrorCode()));
222           return;
223         }
224         using SPSSerialize =
225             shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
226         shared::SPSInputBuffer IB(SetupMsgBytes.data(), SetupMsgBytes.size());
227         SimpleRemoteEPCExecutorInfo EI;
228         if (SPSSerialize::deserialize(IB, EI))
229           ExecInfoP.set_value(EI);
230         else
231           ExecInfoP.set_value(make_error<StringError>(
232               "Could not deserialize setup message", inconvertibleErrorCode()));
233       };
234 }
235 
236 Error SimpleRemoteEPC::setup(std::unique_ptr<SimpleRemoteEPCTransport> T,
237                              SimpleRemoteEPCExecutorInfo EI) {
238   using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
239   LLVM_DEBUG({
240     dbgs() << "SimpleRemoteEPC received setup message:\n"
241            << "  Triple: " << EI.TargetTriple << "\n"
242            << "  Page size: " << EI.PageSize << "\n"
243            << "  Bootstrap symbols:\n";
244     for (const auto &KV : EI.BootstrapSymbols)
245       dbgs() << "    " << KV.first() << ": "
246              << formatv("{0:x16}", KV.second.getValue()) << "\n";
247   });
248   this->T = std::move(T);
249   TargetTriple = Triple(EI.TargetTriple);
250   PageSize = EI.PageSize;
251   BootstrapSymbols = std::move(EI.BootstrapSymbols);
252 
253   if (auto Err = getBootstrapSymbols(
254           {{JDI.JITDispatchContextAddress, ExecutorSessionObjectName},
255            {JDI.JITDispatchFunctionAddress, DispatchFnName},
256            {LoadDylibAddr, "__llvm_orc_load_dylib"},
257            {LookupSymbolsAddr, "__llvm_orc_lookup_symbols"},
258            {RunAsMainAddr, rt::RunAsMainWrapperName}}))
259     return Err;
260 
261   if (auto MemMgr = createMemoryManager()) {
262     OwnedMemMgr = std::move(*MemMgr);
263     this->MemMgr = OwnedMemMgr.get();
264   } else
265     return MemMgr.takeError();
266 
267   if (auto MemAccess = createMemoryAccess()) {
268     OwnedMemAccess = std::move(*MemAccess);
269     this->MemAccess = OwnedMemAccess.get();
270   } else
271     return MemAccess.takeError();
272 
273   return Error::success();
274 }
275 
276 Error SimpleRemoteEPC::handleResult(uint64_t SeqNo, ExecutorAddress TagAddr,
277                                     SimpleRemoteEPCArgBytesVector ArgBytes) {
278   SendResultFunction SendResult;
279 
280   if (TagAddr)
281     return make_error<StringError>("Unexpected TagAddr in result message",
282                                    inconvertibleErrorCode());
283 
284   {
285     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
286     auto I = PendingCallWrapperResults.find(SeqNo);
287     if (I == PendingCallWrapperResults.end())
288       return make_error<StringError>("No call for sequence number " +
289                                          Twine(SeqNo),
290                                      inconvertibleErrorCode());
291     SendResult = std::move(I->second);
292     PendingCallWrapperResults.erase(I);
293     releaseSeqNo(SeqNo);
294   }
295 
296   auto WFR =
297       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
298   SendResult(std::move(WFR));
299   return Error::success();
300 }
301 
302 void SimpleRemoteEPC::handleCallWrapper(
303     uint64_t RemoteSeqNo, ExecutorAddress TagAddr,
304     SimpleRemoteEPCArgBytesVector ArgBytes) {
305   assert(ES && "No ExecutionSession attached");
306   ES->runJITDispatchHandler(
307       [this, RemoteSeqNo](shared::WrapperFunctionResult WFR) {
308         if (auto Err =
309                 T->sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
310                                ExecutorAddress(), {WFR.data(), WFR.size()}))
311           getExecutionSession().reportError(std::move(Err));
312       },
313       TagAddr.getValue(), ArgBytes);
314 }
315 
316 } // end namespace orc
317 } // end namespace llvm
318