xref: /llvm-project/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp (revision 2c8e784915887f72f13ee49cd513efb446eb23be)
1 //===------- SimpleRemoteEPC.cpp -- Simple remote executor control --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h"
10 #include "llvm/ExecutionEngine/Orc/EPCGenericJITLinkMemoryManager.h"
11 #include "llvm/ExecutionEngine/Orc/EPCGenericMemoryAccess.h"
12 #include "llvm/ExecutionEngine/Orc/Shared/OrcRTBridge.h"
13 #include "llvm/Support/FormatVariadic.h"
14 
15 #define DEBUG_TYPE "orc"
16 
17 namespace llvm {
18 namespace orc {
19 namespace shared {
20 
21 template <>
22 class SPSSerializationTraits<SPSRemoteSymbolLookupSetElement,
23                              SymbolLookupSet::value_type> {
24 public:
25   static size_t size(const SymbolLookupSet::value_type &V) {
26     return SPSArgList<SPSString, bool>::size(
27         *V.first, V.second == SymbolLookupFlags::RequiredSymbol);
28   }
29 
30   static bool serialize(SPSOutputBuffer &OB,
31                         const SymbolLookupSet::value_type &V) {
32     return SPSArgList<SPSString, bool>::serialize(
33         OB, *V.first, V.second == SymbolLookupFlags::RequiredSymbol);
34   }
35 };
36 
37 template <>
38 class TrivialSPSSequenceSerialization<SPSRemoteSymbolLookupSetElement,
39                                       SymbolLookupSet> {
40 public:
41   static constexpr bool available = true;
42 };
43 
44 template <>
45 class SPSSerializationTraits<SPSRemoteSymbolLookup,
46                              ExecutorProcessControl::LookupRequest> {
47   using MemberSerialization =
48       SPSArgList<SPSExecutorAddress, SPSRemoteSymbolLookupSet>;
49 
50 public:
51   static size_t size(const ExecutorProcessControl::LookupRequest &LR) {
52     return MemberSerialization::size(ExecutorAddress(LR.Handle), LR.Symbols);
53   }
54 
55   static bool serialize(SPSOutputBuffer &OB,
56                         const ExecutorProcessControl::LookupRequest &LR) {
57     return MemberSerialization::serialize(OB, ExecutorAddress(LR.Handle),
58                                           LR.Symbols);
59   }
60 };
61 
62 } // end namespace shared
63 
64 SimpleRemoteEPC::~SimpleRemoteEPC() {
65   assert(Disconnected && "Destroyed without disconnection");
66 }
67 
68 Expected<tpctypes::DylibHandle>
69 SimpleRemoteEPC::loadDylib(const char *DylibPath) {
70   Expected<tpctypes::DylibHandle> H((tpctypes::DylibHandle()));
71   if (auto Err = callSPSWrapper<shared::SPSLoadDylibSignature>(
72           LoadDylibAddr.getValue(), H, JDI.JITDispatchContextAddress,
73           StringRef(DylibPath), (uint64_t)0))
74     return std::move(Err);
75   return H;
76 }
77 
78 Expected<std::vector<tpctypes::LookupResult>>
79 SimpleRemoteEPC::lookupSymbols(ArrayRef<LookupRequest> Request) {
80   Expected<std::vector<tpctypes::LookupResult>> R(
81       (std::vector<tpctypes::LookupResult>()));
82 
83   if (auto Err = callSPSWrapper<shared::SPSLookupSymbolsSignature>(
84           LookupSymbolsAddr.getValue(), R, JDI.JITDispatchContextAddress,
85           Request))
86     return std::move(Err);
87   return R;
88 }
89 
90 Expected<int32_t> SimpleRemoteEPC::runAsMain(JITTargetAddress MainFnAddr,
91                                              ArrayRef<std::string> Args) {
92   int64_t Result = 0;
93   if (auto Err = callSPSWrapper<rt::SPSRunAsMainSignature>(
94           RunAsMainAddr.getValue(), Result, ExecutorAddress(MainFnAddr), Args))
95     return std::move(Err);
96   return Result;
97 }
98 
99 void SimpleRemoteEPC::callWrapperAsync(SendResultFunction OnComplete,
100                                        JITTargetAddress WrapperFnAddr,
101                                        ArrayRef<char> ArgBuffer) {
102   uint64_t SeqNo;
103   {
104     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
105     SeqNo = getNextSeqNo();
106     assert(!PendingCallWrapperResults.count(SeqNo) && "SeqNo already in use");
107     PendingCallWrapperResults[SeqNo] = std::move(OnComplete);
108   }
109 
110   if (auto Err = T->sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
111                                 ExecutorAddress(WrapperFnAddr), ArgBuffer)) {
112     getExecutionSession().reportError(std::move(Err));
113   }
114 }
115 
116 Error SimpleRemoteEPC::disconnect() {
117   Disconnected = true;
118   T->disconnect();
119   return Error::success();
120 }
121 
122 Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
123 SimpleRemoteEPC::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
124                                ExecutorAddress TagAddr,
125                                SimpleRemoteEPCArgBytesVector ArgBytes) {
126   using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
127   if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
128     return make_error<StringError>("Unexpected opcode",
129                                    inconvertibleErrorCode());
130 
131   switch (OpC) {
132   case SimpleRemoteEPCOpcode::Setup:
133     if (auto Err = handleSetup(SeqNo, TagAddr, std::move(ArgBytes)))
134       return std::move(Err);
135     break;
136   case SimpleRemoteEPCOpcode::Hangup:
137     // FIXME: Put EPC into 'detached' state.
138     return SimpleRemoteEPCTransportClient::EndSession;
139   case SimpleRemoteEPCOpcode::Result:
140     if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
141       return std::move(Err);
142     break;
143   case SimpleRemoteEPCOpcode::CallWrapper:
144     handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
145     break;
146   }
147   return ContinueSession;
148 }
149 
150 void SimpleRemoteEPC::handleDisconnect(Error Err) {
151   PendingCallWrapperResultsMap TmpPending;
152 
153   {
154     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
155     std::swap(TmpPending, PendingCallWrapperResults);
156   }
157 
158   for (auto &KV : TmpPending)
159     KV.second(
160         shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
161 
162   if (Err) {
163     // FIXME: Move ReportError to EPC.
164     if (ES)
165       ES->reportError(std::move(Err));
166     else
167       logAllUnhandledErrors(std::move(Err), errs(), "SimpleRemoteEPC: ");
168   }
169 }
170 
171 Expected<std::unique_ptr<jitlink::JITLinkMemoryManager>>
172 SimpleRemoteEPC::createMemoryManager() {
173   EPCGenericJITLinkMemoryManager::FuncAddrs FAs;
174   if (auto Err = getBootstrapSymbols(
175           {{FAs.Reserve, rt::MemoryReserveWrapperName},
176            {FAs.Finalize, rt::MemoryFinalizeWrapperName},
177            {FAs.Deallocate, rt::MemoryDeallocateWrapperName}}))
178     return std::move(Err);
179 
180   return std::make_unique<EPCGenericJITLinkMemoryManager>(*this, FAs);
181 }
182 
183 Expected<std::unique_ptr<ExecutorProcessControl::MemoryAccess>>
184 SimpleRemoteEPC::createMemoryAccess() {
185 
186   return nullptr;
187 }
188 
189 Error SimpleRemoteEPC::handleSetup(uint64_t SeqNo, ExecutorAddress TagAddr,
190                                    SimpleRemoteEPCArgBytesVector ArgBytes) {
191   if (SeqNo != 0)
192     return make_error<StringError>("Setup packet SeqNo not zero",
193                                    inconvertibleErrorCode());
194 
195   if (TagAddr)
196     return make_error<StringError>("Setup packet TagAddr not zero",
197                                    inconvertibleErrorCode());
198 
199   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
200   auto I = PendingCallWrapperResults.find(0);
201   assert(PendingCallWrapperResults.size() == 1 &&
202          I != PendingCallWrapperResults.end() &&
203          "Setup message handler not connectly set up");
204   auto SetupMsgHandler = std::move(I->second);
205   PendingCallWrapperResults.erase(I);
206 
207   auto WFR =
208       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
209   SetupMsgHandler(std::move(WFR));
210   return Error::success();
211 }
212 
213 void SimpleRemoteEPC::prepareToReceiveSetupMessage(
214     std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> &ExecInfoP) {
215   PendingCallWrapperResults[0] =
216       [&](shared::WrapperFunctionResult SetupMsgBytes) {
217         if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {
218           ExecInfoP.set_value(
219               make_error<StringError>(ErrMsg, inconvertibleErrorCode()));
220           return;
221         }
222         using SPSSerialize =
223             shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
224         shared::SPSInputBuffer IB(SetupMsgBytes.data(), SetupMsgBytes.size());
225         SimpleRemoteEPCExecutorInfo EI;
226         if (SPSSerialize::deserialize(IB, EI))
227           ExecInfoP.set_value(EI);
228         else
229           ExecInfoP.set_value(make_error<StringError>(
230               "Could not deserialize setup message", inconvertibleErrorCode()));
231       };
232 }
233 
234 Error SimpleRemoteEPC::setup(std::unique_ptr<SimpleRemoteEPCTransport> T,
235                              SimpleRemoteEPCExecutorInfo EI) {
236   using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
237   LLVM_DEBUG({
238     dbgs() << "SimpleRemoteEPC received setup message:\n"
239            << "  Triple: " << EI.TargetTriple << "\n"
240            << "  Page size: " << EI.PageSize << "\n"
241            << "  Bootstrap symbols:\n";
242     for (const auto &KV : EI.BootstrapSymbols)
243       dbgs() << "    " << KV.first() << ": "
244              << formatv("{0:x16}", KV.second.getValue()) << "\n";
245   });
246   this->T = std::move(T);
247   TargetTriple = Triple(EI.TargetTriple);
248   PageSize = EI.PageSize;
249   BootstrapSymbols = std::move(EI.BootstrapSymbols);
250 
251   if (auto Err = getBootstrapSymbols(
252           {{JDI.JITDispatchContextAddress, ExecutorSessionObjectName},
253            {JDI.JITDispatchFunctionAddress, DispatchFnName},
254            {LoadDylibAddr, "__llvm_orc_load_dylib"},
255            {LookupSymbolsAddr, "__llvm_orc_lookup_symbols"},
256            {RunAsMainAddr, rt::RunAsMainWrapperName}}))
257     return Err;
258 
259   if (auto MemMgr = createMemoryManager()) {
260     OwnedMemMgr = std::move(*MemMgr);
261     this->MemMgr = OwnedMemMgr.get();
262   } else
263     return MemMgr.takeError();
264 
265   if (auto MemAccess = createMemoryAccess()) {
266     OwnedMemAccess = std::move(*MemAccess);
267     this->MemAccess = OwnedMemAccess.get();
268   } else
269     return MemAccess.takeError();
270 
271   return Error::success();
272 }
273 
274 Error SimpleRemoteEPC::handleResult(uint64_t SeqNo, ExecutorAddress TagAddr,
275                                     SimpleRemoteEPCArgBytesVector ArgBytes) {
276   SendResultFunction SendResult;
277 
278   if (TagAddr)
279     return make_error<StringError>("Unexpected TagAddr in result message",
280                                    inconvertibleErrorCode());
281 
282   {
283     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
284     auto I = PendingCallWrapperResults.find(SeqNo);
285     if (I == PendingCallWrapperResults.end())
286       return make_error<StringError>("No call for sequence number " +
287                                          Twine(SeqNo),
288                                      inconvertibleErrorCode());
289     SendResult = std::move(I->second);
290     PendingCallWrapperResults.erase(I);
291     releaseSeqNo(SeqNo);
292   }
293 
294   auto WFR =
295       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
296   SendResult(std::move(WFR));
297   return Error::success();
298 }
299 
300 void SimpleRemoteEPC::handleCallWrapper(
301     uint64_t RemoteSeqNo, ExecutorAddress TagAddr,
302     SimpleRemoteEPCArgBytesVector ArgBytes) {
303   assert(ES && "No ExecutionSession attached");
304   ES->runJITDispatchHandler(
305       [this, RemoteSeqNo](shared::WrapperFunctionResult WFR) {
306         if (auto Err =
307                 T->sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
308                                ExecutorAddress(), {WFR.data(), WFR.size()}))
309           getExecutionSession().reportError(std::move(Err));
310       },
311       TagAddr.getValue(), ArgBytes);
312 }
313 
314 } // end namespace orc
315 } // end namespace llvm
316