xref: /llvm-project/llvm/lib/ExecutionEngine/Orc/TargetProcess/SimpleRemoteEPCServer.cpp (revision dfd74db9813b0c7c64038c303726ba43f335e07a)
//===-- SimpleRemoteEPCServer.cpp - EPC over simple abstract channel ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ExecutionEngine/Orc/TargetProcess/SimpleRemoteEPCServer.h"
10 
11 #include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h"
12 #include "llvm/Support/FormatVariadic.h"
13 #include "llvm/Support/Host.h"
14 #include "llvm/Support/Process.h"
15 
16 #include "OrcRTBootstrap.h"
17 
18 #define DEBUG_TYPE "orc"
19 
20 using namespace llvm::orc::shared;
21 
22 namespace llvm {
23 namespace orc {
24 
25 ExecutorBootstrapService::~ExecutorBootstrapService() {}
26 
27 StringMap<ExecutorAddr> SimpleRemoteEPCServer::defaultBootstrapSymbols() {
28   StringMap<ExecutorAddr> DBS;
29   rt_bootstrap::addTo(DBS);
30   return DBS;
31 }
32 
33 Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
34 SimpleRemoteEPCServer::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
35                                      ExecutorAddr TagAddr,
36                                      SimpleRemoteEPCArgBytesVector ArgBytes) {
37 
38   LLVM_DEBUG({
39     dbgs() << "SimpleRemoteEPCServer::handleMessage: opc = ";
40     switch (OpC) {
41     case SimpleRemoteEPCOpcode::Setup:
42       dbgs() << "Setup";
43       assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
44       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Setup?");
45       break;
46     case SimpleRemoteEPCOpcode::Hangup:
47       dbgs() << "Hangup";
48       assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
49       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
50       break;
51     case SimpleRemoteEPCOpcode::Result:
52       dbgs() << "Result";
53       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
54       break;
55     case SimpleRemoteEPCOpcode::CallWrapper:
56       dbgs() << "CallWrapper";
57       break;
58     }
59     dbgs() << ", seqno = " << SeqNo
60            << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
61            << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
62            << " bytes\n";
63   });
64 
65   using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
66   if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
67     return make_error<StringError>("Unexpected opcode",
68                                    inconvertibleErrorCode());
69 
70   // TODO: Clean detach message?
71   switch (OpC) {
72   case SimpleRemoteEPCOpcode::Setup:
73     return make_error<StringError>("Unexpected Setup opcode",
74                                    inconvertibleErrorCode());
75   case SimpleRemoteEPCOpcode::Hangup:
76     return SimpleRemoteEPCTransportClient::EndSession;
77   case SimpleRemoteEPCOpcode::Result:
78     if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
79       return std::move(Err);
80     break;
81   case SimpleRemoteEPCOpcode::CallWrapper:
82     handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
83     break;
84   }
85   return ContinueSession;
86 }
87 
88 Error SimpleRemoteEPCServer::waitForDisconnect() {
89   std::unique_lock<std::mutex> Lock(ServerStateMutex);
90   ShutdownCV.wait(Lock, [this]() { return RunState == ServerShutDown; });
91   return std::move(ShutdownErr);
92 }
93 
94 void SimpleRemoteEPCServer::handleDisconnect(Error Err) {
95   PendingJITDispatchResultsMap TmpPending;
96 
97   {
98     std::lock_guard<std::mutex> Lock(ServerStateMutex);
99     std::swap(TmpPending, PendingJITDispatchResults);
100     RunState = ServerShuttingDown;
101   }
102 
103   // Send out-of-band errors to any waiting threads.
104   for (auto &KV : TmpPending)
105     KV.second->set_value(
106         shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
107 
108   // TODO: Free attached resources.
109   // 1. Close libraries in DylibHandles.
110 
111   // Wait for dispatcher to clear.
112   D->shutdown();
113 
114   // Shut down services.
115   while (!Services.empty()) {
116     ShutdownErr =
117       joinErrors(std::move(ShutdownErr), Services.back()->shutdown());
118     Services.pop_back();
119   }
120 
121   std::lock_guard<std::mutex> Lock(ServerStateMutex);
122   ShutdownErr = joinErrors(std::move(ShutdownErr), std::move(Err));
123   RunState = ServerShutDown;
124   ShutdownCV.notify_all();
125 }
126 
127 Error SimpleRemoteEPCServer::sendMessage(SimpleRemoteEPCOpcode OpC,
128                                          uint64_t SeqNo, ExecutorAddr TagAddr,
129                                          ArrayRef<char> ArgBytes) {
130 
131   LLVM_DEBUG({
132     dbgs() << "SimpleRemoteEPCServer::sendMessage: opc = ";
133     switch (OpC) {
134     case SimpleRemoteEPCOpcode::Setup:
135       dbgs() << "Setup";
136       assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
137       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Setup?");
138       break;
139     case SimpleRemoteEPCOpcode::Hangup:
140       dbgs() << "Hangup";
141       assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
142       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
143       break;
144     case SimpleRemoteEPCOpcode::Result:
145       dbgs() << "Result";
146       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
147       break;
148     case SimpleRemoteEPCOpcode::CallWrapper:
149       dbgs() << "CallWrapper";
150       break;
151     }
152     dbgs() << ", seqno = " << SeqNo
153            << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
154            << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
155            << " bytes\n";
156   });
157   auto Err = T->sendMessage(OpC, SeqNo, TagAddr, ArgBytes);
158   LLVM_DEBUG({
159     if (Err)
160       dbgs() << "  \\--> SimpleRemoteEPC::sendMessage failed\n";
161   });
162   return Err;
163 }
164 
165 Error SimpleRemoteEPCServer::sendSetupMessage(
166     StringMap<ExecutorAddr> BootstrapSymbols) {
167 
168   using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
169 
170   std::vector<char> SetupPacket;
171   SimpleRemoteEPCExecutorInfo EI;
172   EI.TargetTriple = sys::getProcessTriple();
173   if (auto PageSize = sys::Process::getPageSize())
174     EI.PageSize = *PageSize;
175   else
176     return PageSize.takeError();
177   EI.BootstrapSymbols = std::move(BootstrapSymbols);
178 
179   assert(!EI.BootstrapSymbols.count(ExecutorSessionObjectName) &&
180          "Dispatch context name should not be set");
181   assert(!EI.BootstrapSymbols.count(DispatchFnName) &&
182          "Dispatch function name should not be set");
183   EI.BootstrapSymbols[ExecutorSessionObjectName] = ExecutorAddr::fromPtr(this);
184   EI.BootstrapSymbols[DispatchFnName] = ExecutorAddr::fromPtr(jitDispatchEntry);
185 
186   using SPSSerialize =
187       shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
188   auto SetupPacketBytes =
189       shared::WrapperFunctionResult::allocate(SPSSerialize::size(EI));
190   shared::SPSOutputBuffer OB(SetupPacketBytes.data(), SetupPacketBytes.size());
191   if (!SPSSerialize::serialize(OB, EI))
192     return make_error<StringError>("Could not send setup packet",
193                                    inconvertibleErrorCode());
194 
195   return sendMessage(SimpleRemoteEPCOpcode::Setup, 0, ExecutorAddr(),
196                      {SetupPacketBytes.data(), SetupPacketBytes.size()});
197 }
198 
199 Error SimpleRemoteEPCServer::handleResult(
200     uint64_t SeqNo, ExecutorAddr TagAddr,
201     SimpleRemoteEPCArgBytesVector ArgBytes) {
202   std::promise<shared::WrapperFunctionResult> *P = nullptr;
203   {
204     std::lock_guard<std::mutex> Lock(ServerStateMutex);
205     auto I = PendingJITDispatchResults.find(SeqNo);
206     if (I == PendingJITDispatchResults.end())
207       return make_error<StringError>("No call for sequence number " +
208                                          Twine(SeqNo),
209                                      inconvertibleErrorCode());
210     P = I->second;
211     PendingJITDispatchResults.erase(I);
212     releaseSeqNo(SeqNo);
213   }
214   auto R = shared::WrapperFunctionResult::allocate(ArgBytes.size());
215   memcpy(R.data(), ArgBytes.data(), ArgBytes.size());
216   P->set_value(std::move(R));
217   return Error::success();
218 }
219 
220 void SimpleRemoteEPCServer::handleCallWrapper(
221     uint64_t RemoteSeqNo, ExecutorAddr TagAddr,
222     SimpleRemoteEPCArgBytesVector ArgBytes) {
223   D->dispatch([this, RemoteSeqNo, TagAddr, ArgBytes = std::move(ArgBytes)]() {
224     using WrapperFnTy =
225         shared::detail::CWrapperFunctionResult (*)(const char *, size_t);
226     auto *Fn = TagAddr.toPtr<WrapperFnTy>();
227     shared::WrapperFunctionResult ResultBytes(
228         Fn(ArgBytes.data(), ArgBytes.size()));
229     if (auto Err = sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
230                                ExecutorAddr(),
231                                {ResultBytes.data(), ResultBytes.size()}))
232       ReportError(std::move(Err));
233   });
234 }
235 
236 shared::WrapperFunctionResult
237 SimpleRemoteEPCServer::doJITDispatch(const void *FnTag, const char *ArgData,
238                                      size_t ArgSize) {
239   uint64_t SeqNo;
240   std::promise<shared::WrapperFunctionResult> ResultP;
241   auto ResultF = ResultP.get_future();
242   {
243     std::lock_guard<std::mutex> Lock(ServerStateMutex);
244     if (RunState != ServerRunning)
245       return shared::WrapperFunctionResult::createOutOfBandError(
246           "jit_dispatch not available (EPC server shut down)");
247 
248     SeqNo = getNextSeqNo();
249     assert(!PendingJITDispatchResults.count(SeqNo) && "SeqNo already in use");
250     PendingJITDispatchResults[SeqNo] = &ResultP;
251   }
252 
253   if (auto Err = sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
254                              ExecutorAddr::fromPtr(FnTag), {ArgData, ArgSize}))
255     ReportError(std::move(Err));
256 
257   return ResultF.get();
258 }
259 
260 shared::detail::CWrapperFunctionResult
261 SimpleRemoteEPCServer::jitDispatchEntry(void *DispatchCtx, const void *FnTag,
262                                         const char *ArgData, size_t ArgSize) {
263   return reinterpret_cast<SimpleRemoteEPCServer *>(DispatchCtx)
264       ->doJITDispatch(FnTag, ArgData, ArgSize)
265       .release();
266 }
267 
268 } // end namespace orc
269 } // end namespace llvm
270