xref: /llvm-project/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp (revision 8201ae2aa662a1bcba80751f3ef162f228f626f7)
1 //===------- SimpleRemoteEPC.cpp -- Simple remote executor control --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h"
10 #include "llvm/ExecutionEngine/Orc/EPCGenericJITLinkMemoryManager.h"
11 #include "llvm/ExecutionEngine/Orc/Shared/OrcRTBridge.h"
12 #include "llvm/Support/FormatVariadic.h"
13 
14 #define DEBUG_TYPE "orc"
15 
16 namespace llvm {
17 namespace orc {
18 
19 SimpleRemoteEPC::~SimpleRemoteEPC() {
20 #ifndef NDEBUG
21   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
22   assert(Disconnected && "Destroyed without disconnection");
23 #endif // NDEBUG
24 }
25 
26 Expected<tpctypes::DylibHandle>
27 SimpleRemoteEPC::loadDylib(const char *DylibPath) {
28   return EPCDylibMgr->open(DylibPath, 0);
29 }
30 
31 /// Async helper to chain together calls to DylibMgr::lookupAsync to fulfill all
32 /// all the requests.
33 /// FIXME: The dylib manager should support multiple LookupRequests natively.
34 static void
35 lookupSymbolsAsyncHelper(EPCGenericDylibManager &DylibMgr,
36                          ArrayRef<DylibManager::LookupRequest> Request,
37                          std::vector<tpctypes::LookupResult> Result,
38                          DylibManager::SymbolLookupCompleteFn Complete) {
39   if (Request.empty())
40     return Complete(std::move(Result));
41 
42   auto &Element = Request.front();
43   DylibMgr.lookupAsync(Element.Handle, Element.Symbols,
44                        [&DylibMgr, Request, Complete = std::move(Complete),
45                         Result = std::move(Result)](auto R) mutable {
46                          if (!R)
47                            return Complete(R.takeError());
48                          Result.push_back({});
49                          Result.back().reserve(R->size());
50                          for (auto Addr : *R)
51                            Result.back().push_back(Addr);
52 
53                          lookupSymbolsAsyncHelper(
54                              DylibMgr, Request.drop_front(), std::move(Result),
55                              std::move(Complete));
56                        });
57 }
58 
59 void SimpleRemoteEPC::lookupSymbolsAsync(ArrayRef<LookupRequest> Request,
60                                          SymbolLookupCompleteFn Complete) {
61   lookupSymbolsAsyncHelper(*EPCDylibMgr, Request, {}, std::move(Complete));
62 }
63 
64 Expected<int32_t> SimpleRemoteEPC::runAsMain(ExecutorAddr MainFnAddr,
65                                              ArrayRef<std::string> Args) {
66   int64_t Result = 0;
67   if (auto Err = callSPSWrapper<rt::SPSRunAsMainSignature>(
68           RunAsMainAddr, Result, MainFnAddr, Args))
69     return std::move(Err);
70   return Result;
71 }
72 
73 Expected<int32_t> SimpleRemoteEPC::runAsVoidFunction(ExecutorAddr VoidFnAddr) {
74   int32_t Result = 0;
75   if (auto Err = callSPSWrapper<rt::SPSRunAsVoidFunctionSignature>(
76           RunAsVoidFunctionAddr, Result, VoidFnAddr))
77     return std::move(Err);
78   return Result;
79 }
80 
81 Expected<int32_t> SimpleRemoteEPC::runAsIntFunction(ExecutorAddr IntFnAddr,
82                                                     int Arg) {
83   int32_t Result = 0;
84   if (auto Err = callSPSWrapper<rt::SPSRunAsIntFunctionSignature>(
85           RunAsIntFunctionAddr, Result, IntFnAddr, Arg))
86     return std::move(Err);
87   return Result;
88 }
89 
90 void SimpleRemoteEPC::callWrapperAsync(ExecutorAddr WrapperFnAddr,
91                                        IncomingWFRHandler OnComplete,
92                                        ArrayRef<char> ArgBuffer) {
93   uint64_t SeqNo;
94   {
95     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
96     SeqNo = getNextSeqNo();
97     assert(!PendingCallWrapperResults.count(SeqNo) && "SeqNo already in use");
98     PendingCallWrapperResults[SeqNo] = std::move(OnComplete);
99   }
100 
101   if (auto Err = sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
102                              WrapperFnAddr, ArgBuffer)) {
103     IncomingWFRHandler H;
104 
105     // We just registered OnComplete, but there may be a race between this
106     // thread returning from sendMessage and handleDisconnect being called from
107     // the transport's listener thread. If handleDisconnect gets there first
108     // then it will have failed 'H' for us. If we get there first (or if
109     // handleDisconnect already ran) then we need to take care of it.
110     {
111       std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
112       auto I = PendingCallWrapperResults.find(SeqNo);
113       if (I != PendingCallWrapperResults.end()) {
114         H = std::move(I->second);
115         PendingCallWrapperResults.erase(I);
116       }
117     }
118 
119     if (H)
120       H(shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
121 
122     getExecutionSession().reportError(std::move(Err));
123   }
124 }
125 
126 Error SimpleRemoteEPC::disconnect() {
127   T->disconnect();
128   D->shutdown();
129   std::unique_lock<std::mutex> Lock(SimpleRemoteEPCMutex);
130   DisconnectCV.wait(Lock, [this] { return Disconnected; });
131   return std::move(DisconnectErr);
132 }
133 
134 Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
135 SimpleRemoteEPC::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
136                                ExecutorAddr TagAddr,
137                                SimpleRemoteEPCArgBytesVector ArgBytes) {
138 
139   LLVM_DEBUG({
140     dbgs() << "SimpleRemoteEPC::handleMessage: opc = ";
141     switch (OpC) {
142     case SimpleRemoteEPCOpcode::Setup:
143       dbgs() << "Setup";
144       assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
145       assert(!TagAddr && "Non-zero TagAddr for Setup?");
146       break;
147     case SimpleRemoteEPCOpcode::Hangup:
148       dbgs() << "Hangup";
149       assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
150       assert(!TagAddr && "Non-zero TagAddr for Hangup?");
151       break;
152     case SimpleRemoteEPCOpcode::Result:
153       dbgs() << "Result";
154       assert(!TagAddr && "Non-zero TagAddr for Result?");
155       break;
156     case SimpleRemoteEPCOpcode::CallWrapper:
157       dbgs() << "CallWrapper";
158       break;
159     }
160     dbgs() << ", seqno = " << SeqNo << ", tag-addr = " << TagAddr
161            << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
162            << " bytes\n";
163   });
164 
165   using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
166   if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
167     return make_error<StringError>("Unexpected opcode",
168                                    inconvertibleErrorCode());
169 
170   switch (OpC) {
171   case SimpleRemoteEPCOpcode::Setup:
172     if (auto Err = handleSetup(SeqNo, TagAddr, std::move(ArgBytes)))
173       return std::move(Err);
174     break;
175   case SimpleRemoteEPCOpcode::Hangup:
176     T->disconnect();
177     if (auto Err = handleHangup(std::move(ArgBytes)))
178       return std::move(Err);
179     return EndSession;
180   case SimpleRemoteEPCOpcode::Result:
181     if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
182       return std::move(Err);
183     break;
184   case SimpleRemoteEPCOpcode::CallWrapper:
185     handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
186     break;
187   }
188   return ContinueSession;
189 }
190 
191 void SimpleRemoteEPC::handleDisconnect(Error Err) {
192   LLVM_DEBUG({
193     dbgs() << "SimpleRemoteEPC::handleDisconnect: "
194            << (Err ? "failure" : "success") << "\n";
195   });
196 
197   PendingCallWrapperResultsMap TmpPending;
198 
199   {
200     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
201     std::swap(TmpPending, PendingCallWrapperResults);
202   }
203 
204   for (auto &KV : TmpPending)
205     KV.second(
206         shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
207 
208   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
209   DisconnectErr = joinErrors(std::move(DisconnectErr), std::move(Err));
210   Disconnected = true;
211   DisconnectCV.notify_all();
212 }
213 
214 Expected<std::unique_ptr<jitlink::JITLinkMemoryManager>>
215 SimpleRemoteEPC::createDefaultMemoryManager(SimpleRemoteEPC &SREPC) {
216   EPCGenericJITLinkMemoryManager::SymbolAddrs SAs;
217   if (auto Err = SREPC.getBootstrapSymbols(
218           {{SAs.Allocator, rt::SimpleExecutorMemoryManagerInstanceName},
219            {SAs.Reserve, rt::SimpleExecutorMemoryManagerReserveWrapperName},
220            {SAs.Finalize, rt::SimpleExecutorMemoryManagerFinalizeWrapperName},
221            {SAs.Deallocate,
222             rt::SimpleExecutorMemoryManagerDeallocateWrapperName}}))
223     return std::move(Err);
224 
225   return std::make_unique<EPCGenericJITLinkMemoryManager>(SREPC, SAs);
226 }
227 
228 Expected<std::unique_ptr<ExecutorProcessControl::MemoryAccess>>
229 SimpleRemoteEPC::createDefaultMemoryAccess(SimpleRemoteEPC &SREPC) {
230   EPCGenericMemoryAccess::FuncAddrs FAs;
231   if (auto Err = SREPC.getBootstrapSymbols(
232           {{FAs.WriteUInt8s, rt::MemoryWriteUInt8sWrapperName},
233            {FAs.WriteUInt16s, rt::MemoryWriteUInt16sWrapperName},
234            {FAs.WriteUInt32s, rt::MemoryWriteUInt32sWrapperName},
235            {FAs.WriteUInt64s, rt::MemoryWriteUInt64sWrapperName},
236            {FAs.WriteBuffers, rt::MemoryWriteBuffersWrapperName},
237            {FAs.WritePointers, rt::MemoryWritePointersWrapperName}}))
238     return std::move(Err);
239 
240   return std::make_unique<EPCGenericMemoryAccess>(SREPC, FAs);
241 }
242 
243 Error SimpleRemoteEPC::sendMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
244                                    ExecutorAddr TagAddr,
245                                    ArrayRef<char> ArgBytes) {
246   assert(OpC != SimpleRemoteEPCOpcode::Setup &&
247          "SimpleRemoteEPC sending Setup message? That's the wrong direction.");
248 
249   LLVM_DEBUG({
250     dbgs() << "SimpleRemoteEPC::sendMessage: opc = ";
251     switch (OpC) {
252     case SimpleRemoteEPCOpcode::Hangup:
253       dbgs() << "Hangup";
254       assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
255       assert(!TagAddr && "Non-zero TagAddr for Hangup?");
256       break;
257     case SimpleRemoteEPCOpcode::Result:
258       dbgs() << "Result";
259       assert(!TagAddr && "Non-zero TagAddr for Result?");
260       break;
261     case SimpleRemoteEPCOpcode::CallWrapper:
262       dbgs() << "CallWrapper";
263       break;
264     default:
265       llvm_unreachable("Invalid opcode");
266     }
267     dbgs() << ", seqno = " << SeqNo << ", tag-addr = " << TagAddr
268            << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
269            << " bytes\n";
270   });
271   auto Err = T->sendMessage(OpC, SeqNo, TagAddr, ArgBytes);
272   LLVM_DEBUG({
273     if (Err)
274       dbgs() << "  \\--> SimpleRemoteEPC::sendMessage failed\n";
275   });
276   return Err;
277 }
278 
279 Error SimpleRemoteEPC::handleSetup(uint64_t SeqNo, ExecutorAddr TagAddr,
280                                    SimpleRemoteEPCArgBytesVector ArgBytes) {
281   if (SeqNo != 0)
282     return make_error<StringError>("Setup packet SeqNo not zero",
283                                    inconvertibleErrorCode());
284 
285   if (TagAddr)
286     return make_error<StringError>("Setup packet TagAddr not zero",
287                                    inconvertibleErrorCode());
288 
289   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
290   auto I = PendingCallWrapperResults.find(0);
291   assert(PendingCallWrapperResults.size() == 1 &&
292          I != PendingCallWrapperResults.end() &&
293          "Setup message handler not connectly set up");
294   auto SetupMsgHandler = std::move(I->second);
295   PendingCallWrapperResults.erase(I);
296 
297   auto WFR =
298       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
299   SetupMsgHandler(std::move(WFR));
300   return Error::success();
301 }
302 
303 Error SimpleRemoteEPC::setup(Setup S) {
304   using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
305 
306   std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> EIP;
307   auto EIF = EIP.get_future();
308 
309   // Prepare a handler for the setup packet.
310   PendingCallWrapperResults[0] =
311     RunInPlace()(
312       [&](shared::WrapperFunctionResult SetupMsgBytes) {
313         if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {
314           EIP.set_value(
315               make_error<StringError>(ErrMsg, inconvertibleErrorCode()));
316           return;
317         }
318         using SPSSerialize =
319             shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
320         shared::SPSInputBuffer IB(SetupMsgBytes.data(), SetupMsgBytes.size());
321         SimpleRemoteEPCExecutorInfo EI;
322         if (SPSSerialize::deserialize(IB, EI))
323           EIP.set_value(EI);
324         else
325           EIP.set_value(make_error<StringError>(
326               "Could not deserialize setup message", inconvertibleErrorCode()));
327       });
328 
329   // Start the transport.
330   if (auto Err = T->start())
331     return Err;
332 
333   // Wait for setup packet to arrive.
334   auto EI = EIF.get();
335   if (!EI) {
336     T->disconnect();
337     return EI.takeError();
338   }
339 
340   LLVM_DEBUG({
341     dbgs() << "SimpleRemoteEPC received setup message:\n"
342            << "  Triple: " << EI->TargetTriple << "\n"
343            << "  Page size: " << EI->PageSize << "\n"
344            << "  Bootstrap map" << (EI->BootstrapMap.empty() ? " empty" : ":")
345            << "\n";
346     for (const auto &KV : EI->BootstrapMap)
347       dbgs() << "    " << KV.first() << ": " << KV.second.size()
348              << "-byte SPS encoded buffer\n";
349     dbgs() << "  Bootstrap symbols"
350            << (EI->BootstrapSymbols.empty() ? " empty" : ":") << "\n";
351     for (const auto &KV : EI->BootstrapSymbols)
352       dbgs() << "    " << KV.first() << ": " << KV.second << "\n";
353   });
354   TargetTriple = Triple(EI->TargetTriple);
355   PageSize = EI->PageSize;
356   BootstrapMap = std::move(EI->BootstrapMap);
357   BootstrapSymbols = std::move(EI->BootstrapSymbols);
358 
359   if (auto Err = getBootstrapSymbols(
360           {{JDI.JITDispatchContext, ExecutorSessionObjectName},
361            {JDI.JITDispatchFunction, DispatchFnName},
362            {RunAsMainAddr, rt::RunAsMainWrapperName},
363            {RunAsVoidFunctionAddr, rt::RunAsVoidFunctionWrapperName},
364            {RunAsIntFunctionAddr, rt::RunAsIntFunctionWrapperName}}))
365     return Err;
366 
367   if (auto DM =
368           EPCGenericDylibManager::CreateWithDefaultBootstrapSymbols(*this))
369     EPCDylibMgr = std::make_unique<EPCGenericDylibManager>(std::move(*DM));
370   else
371     return DM.takeError();
372 
373   // Set a default CreateMemoryManager if none is specified.
374   if (!S.CreateMemoryManager)
375     S.CreateMemoryManager = createDefaultMemoryManager;
376 
377   if (auto MemMgr = S.CreateMemoryManager(*this)) {
378     OwnedMemMgr = std::move(*MemMgr);
379     this->MemMgr = OwnedMemMgr.get();
380   } else
381     return MemMgr.takeError();
382 
383   // Set a default CreateMemoryAccess if none is specified.
384   if (!S.CreateMemoryAccess)
385     S.CreateMemoryAccess = createDefaultMemoryAccess;
386 
387   if (auto MemAccess = S.CreateMemoryAccess(*this)) {
388     OwnedMemAccess = std::move(*MemAccess);
389     this->MemAccess = OwnedMemAccess.get();
390   } else
391     return MemAccess.takeError();
392 
393   return Error::success();
394 }
395 
396 Error SimpleRemoteEPC::handleResult(uint64_t SeqNo, ExecutorAddr TagAddr,
397                                     SimpleRemoteEPCArgBytesVector ArgBytes) {
398   IncomingWFRHandler SendResult;
399 
400   if (TagAddr)
401     return make_error<StringError>("Unexpected TagAddr in result message",
402                                    inconvertibleErrorCode());
403 
404   {
405     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
406     auto I = PendingCallWrapperResults.find(SeqNo);
407     if (I == PendingCallWrapperResults.end())
408       return make_error<StringError>("No call for sequence number " +
409                                          Twine(SeqNo),
410                                      inconvertibleErrorCode());
411     SendResult = std::move(I->second);
412     PendingCallWrapperResults.erase(I);
413     releaseSeqNo(SeqNo);
414   }
415 
416   auto WFR =
417       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
418   SendResult(std::move(WFR));
419   return Error::success();
420 }
421 
422 void SimpleRemoteEPC::handleCallWrapper(
423     uint64_t RemoteSeqNo, ExecutorAddr TagAddr,
424     SimpleRemoteEPCArgBytesVector ArgBytes) {
425   assert(ES && "No ExecutionSession attached");
426   D->dispatch(makeGenericNamedTask(
427       [this, RemoteSeqNo, TagAddr, ArgBytes = std::move(ArgBytes)]() {
428         ES->runJITDispatchHandler(
429             [this, RemoteSeqNo](shared::WrapperFunctionResult WFR) {
430               if (auto Err =
431                       sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
432                                   ExecutorAddr(), {WFR.data(), WFR.size()}))
433                 getExecutionSession().reportError(std::move(Err));
434             },
435             TagAddr, ArgBytes);
436       },
437       "callWrapper task"));
438 }
439 
440 Error SimpleRemoteEPC::handleHangup(SimpleRemoteEPCArgBytesVector ArgBytes) {
441   using namespace llvm::orc::shared;
442   auto WFR = WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
443   if (const char *ErrMsg = WFR.getOutOfBandError())
444     return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
445 
446   detail::SPSSerializableError Info;
447   SPSInputBuffer IB(WFR.data(), WFR.size());
448   if (!SPSArgList<SPSError>::deserialize(IB, Info))
449     return make_error<StringError>("Could not deserialize hangup info",
450                                    inconvertibleErrorCode());
451   return fromSPSSerializable(std::move(Info));
452 }
453 
454 } // end namespace orc
455 } // end namespace llvm
456