1 //===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utilities to support construction of simple RPC APIs.
10 //
11 // The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
12 // programmers, high performance, low memory overhead, and efficient use of the
13 // communications channel.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H
18 #define LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H
19
20 #include <map>
21 #include <thread>
22 #include <vector>
23
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ExecutionEngine/Orc/Shared/OrcError.h"
26 #include "llvm/ExecutionEngine/Orc/Shared/Serialization.h"
27 #include "llvm/Support/MSVCErrorWorkarounds.h"
28
29 #include <future>
30
31 namespace llvm {
32 namespace orc {
33 namespace shared {
34
35 /// Base class of all fatal RPC errors (those that necessarily result in the
36 /// termination of the RPC session).
37 class RPCFatalError : public ErrorInfo<RPCFatalError> {
38 public:
39 static char ID;
40 };
41
42 /// RPCConnectionClosed is returned from RPC operations if the RPC connection
43 /// has already been closed due to either an error or graceful disconnection.
44 class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
45 public:
46 static char ID;
47 std::error_code convertToErrorCode() const override;
48 void log(raw_ostream &OS) const override;
49 };
50
51 /// BadFunctionCall is returned from handleOne when the remote makes a call with
52 /// an unrecognized function id.
53 ///
54 /// This error is fatal because Orc RPC needs to know how to parse a function
55 /// call to know where the next call starts, and if it doesn't recognize the
56 /// function id it cannot parse the call.
57 template <typename FnIdT, typename SeqNoT>
58 class BadFunctionCall
59 : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
60 public:
61 static char ID;
62
BadFunctionCall(FnIdT FnId,SeqNoT SeqNo)63 BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
64 : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
65
convertToErrorCode()66 std::error_code convertToErrorCode() const override {
67 return orcError(OrcErrorCode::UnexpectedRPCCall);
68 }
69
log(raw_ostream & OS)70 void log(raw_ostream &OS) const override {
71 OS << "Call to invalid RPC function id '" << FnId
72 << "' with "
73 "sequence number "
74 << SeqNo;
75 }
76
77 private:
78 FnIdT FnId;
79 SeqNoT SeqNo;
80 };
81
82 template <typename FnIdT, typename SeqNoT>
83 char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
84
85 /// InvalidSequenceNumberForResponse is returned from handleOne when a response
86 /// call arrives with a sequence number that doesn't correspond to any in-flight
87 /// function call.
88 ///
89 /// This error is fatal because Orc RPC needs to know how to parse the rest of
90 /// the response call to know where the next call starts, and if it doesn't have
91 /// a result parser for this sequence number it can't do that.
92 template <typename SeqNoT>
93 class InvalidSequenceNumberForResponse
94 : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>,
95 RPCFatalError> {
96 public:
97 static char ID;
98
InvalidSequenceNumberForResponse(SeqNoT SeqNo)99 InvalidSequenceNumberForResponse(SeqNoT SeqNo) : SeqNo(std::move(SeqNo)) {}
100
convertToErrorCode()101 std::error_code convertToErrorCode() const override {
102 return orcError(OrcErrorCode::UnexpectedRPCCall);
103 };
104
log(raw_ostream & OS)105 void log(raw_ostream &OS) const override {
106 OS << "Response has unknown sequence number " << SeqNo;
107 }
108
109 private:
110 SeqNoT SeqNo;
111 };
112
113 template <typename SeqNoT>
114 char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
115
116 /// This non-fatal error will be passed to asynchronous result handlers in place
117 /// of a result if the connection goes down before a result returns, or if the
118 /// function to be called cannot be negotiated with the remote.
119 class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
120 public:
121 static char ID;
122
123 std::error_code convertToErrorCode() const override;
124 void log(raw_ostream &OS) const override;
125 };
126
127 /// This error is returned if the remote does not have a handler installed for
128 /// the given RPC function.
129 class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
130 public:
131 static char ID;
132
133 CouldNotNegotiate(std::string Signature);
134 std::error_code convertToErrorCode() const override;
135 void log(raw_ostream &OS) const override;
getSignature()136 const std::string &getSignature() const { return Signature; }
137
138 private:
139 std::string Signature;
140 };
141
142 template <typename DerivedFunc, typename FnT> class RPCFunction;
143
144 // RPC Function class.
145 // DerivedFunc should be a user defined class with a static 'getName()' method
146 // returning a const char* representing the function's name.
147 template <typename DerivedFunc, typename RetT, typename... ArgTs>
148 class RPCFunction<DerivedFunc, RetT(ArgTs...)> {
149 public:
150 /// User defined function type.
151 using Type = RetT(ArgTs...);
152
153 /// Return type.
154 using ReturnType = RetT;
155
156 /// Returns the full function prototype as a string.
getPrototype()157 static const char *getPrototype() {
158 static std::string Name = [] {
159 std::string Name;
160 raw_string_ostream(Name)
161 << SerializationTypeName<RetT>::getName() << " "
162 << DerivedFunc::getName() << "("
163 << SerializationTypeNameSequence<ArgTs...>() << ")";
164 return Name;
165 }();
166 return Name.data();
167 }
168 };
169
/// Hands out RPC function ids during autonegotiation.
///
/// Every specialization must supply these four members:
///
///   static T getInvalidId();
///     A reserved id standing for a missing function during autonegotiation.
///
///   static T getResponseId();
///     A reserved id used when sending function responses (return values).
///
///   static T getNegotiateId();
///     A reserved id for the negotiate function itself, which is used to
///     negotiate ids for user-defined functions.
///
///   template <typename Func> T allocate();
///     Returns a unique id for function Func.
template <typename T, typename = void> class RPCFunctionIdAllocator;

/// Default allocator for integral id types. Ids 0, 1 and 2 are reserved for
/// the invalid, response and negotiate ids; allocate() counts upward from 3.
template <typename T>
class RPCFunctionIdAllocator<T, std::enable_if_t<std::is_integral<T>::value>> {
public:
  static T getInvalidId() { return static_cast<T>(0); }
  static T getResponseId() { return static_cast<T>(1); }
  static T getNegotiateId() { return static_cast<T>(2); }

  /// Returns the next unused id. The Func template argument does not influence
  /// the value returned; ids are issued in call order.
  template <typename Func> T allocate() {
    T Allocated = NextId;
    ++NextId;
    return Allocated;
  }

private:
  T NextId = 3;
};
203
204 namespace detail {
205
/// Maps a function type to a std::tuple of its argument types, with each
/// argument type decayed (references and top-level cv-qualifiers stripped).
template <typename T> class RPCFunctionArgsTuple;

template <typename RetT, typename... ArgTs>
class RPCFunctionArgsTuple<RetT(ArgTs...)> {
public:
  // std::decay_t strips references itself, so no separate
  // std::remove_reference_t step is needed.
  using Type = std::tuple<std::decay_t<ArgTs>...>;
};
214
215 // ResultTraits provides typedefs and utilities specific to the return type
216 // of functions.
217 template <typename RetT> class ResultTraits {
218 public:
219 // The return type wrapped in llvm::Expected.
220 using ErrorReturnType = Expected<RetT>;
221
222 #ifdef _MSC_VER
223 // The ErrorReturnType wrapped in a std::promise.
224 using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>;
225
226 // The ErrorReturnType wrapped in a std::future.
227 using ReturnFutureType = std::future<MSVCPExpected<RetT>>;
228 #else
229 // The ErrorReturnType wrapped in a std::promise.
230 using ReturnPromiseType = std::promise<ErrorReturnType>;
231
232 // The ErrorReturnType wrapped in a std::future.
233 using ReturnFutureType = std::future<ErrorReturnType>;
234 #endif
235
236 // Create a 'blank' value of the ErrorReturnType, ready and safe to
237 // overwrite.
createBlankErrorReturnValue()238 static ErrorReturnType createBlankErrorReturnValue() {
239 return ErrorReturnType(RetT());
240 }
241
242 // Consume an abandoned ErrorReturnType.
consumeAbandoned(ErrorReturnType RetOrErr)243 static void consumeAbandoned(ErrorReturnType RetOrErr) {
244 consumeError(RetOrErr.takeError());
245 }
246
returnError(Error Err)247 static ErrorReturnType returnError(Error Err) { return std::move(Err); }
248 };
249
250 // ResultTraits specialization for void functions.
251 template <> class ResultTraits<void> {
252 public:
253 // For void functions, ErrorReturnType is llvm::Error.
254 using ErrorReturnType = Error;
255
256 #ifdef _MSC_VER
257 // The ErrorReturnType wrapped in a std::promise.
258 using ReturnPromiseType = std::promise<MSVCPError>;
259
260 // The ErrorReturnType wrapped in a std::future.
261 using ReturnFutureType = std::future<MSVCPError>;
262 #else
263 // The ErrorReturnType wrapped in a std::promise.
264 using ReturnPromiseType = std::promise<ErrorReturnType>;
265
266 // The ErrorReturnType wrapped in a std::future.
267 using ReturnFutureType = std::future<ErrorReturnType>;
268 #endif
269
270 // Create a 'blank' value of the ErrorReturnType, ready and safe to
271 // overwrite.
createBlankErrorReturnValue()272 static ErrorReturnType createBlankErrorReturnValue() {
273 return ErrorReturnType::success();
274 }
275
276 // Consume an abandoned ErrorReturnType.
consumeAbandoned(ErrorReturnType Err)277 static void consumeAbandoned(ErrorReturnType Err) {
278 consumeError(std::move(Err));
279 }
280
returnError(Error Err)281 static ErrorReturnType returnError(Error Err) { return Err; }
282 };
283
284 // ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
285 // handlers for void RPC functions to return either void (in which case they
286 // implicitly succeed) or Error (in which case their error return is
287 // propagated). See usage in HandlerTraits::runHandlerHelper.
288 template <> class ResultTraits<Error> : public ResultTraits<void> {};
289
290 // ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
291 // handlers for RPC functions returning a T to return either a T (in which
292 // case they implicitly succeed) or Expected<T> (in which case their error
293 // return is propagated). See usage in HandlerTraits::runHandlerHelper.
294 template <typename RetT>
295 class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
296
297 // Determines whether an RPC function's defined error return type supports
298 // error return value.
299 template <typename T> class SupportsErrorReturn {
300 public:
301 static const bool value = false;
302 };
303
304 template <> class SupportsErrorReturn<Error> {
305 public:
306 static const bool value = true;
307 };
308
309 template <typename T> class SupportsErrorReturn<Expected<T>> {
310 public:
311 static const bool value = true;
312 };
313
314 // RespondHelper packages return values based on whether or not the declared
315 // RPC function return type supports error returns.
316 template <bool FuncSupportsErrorReturn> class RespondHelper;
317
318 // RespondHelper specialization for functions that support error returns.
319 template <> class RespondHelper<true> {
320 public:
321 // Send Expected<T>.
322 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
323 typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)324 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
325 SequenceNumberT SeqNo,
326 Expected<HandlerRetT> ResultOrErr) {
327 if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
328 return ResultOrErr.takeError();
329
330 // Open the response message.
331 if (auto Err = C.startSendMessage(ResponseId, SeqNo))
332 return Err;
333
334 // Serialize the result.
335 if (auto Err =
336 SerializationTraits<ChannelT, WireRetT, Expected<HandlerRetT>>::
337 serialize(C, std::move(ResultOrErr)))
338 return Err;
339
340 // Close the response message.
341 if (auto Err = C.endSendMessage())
342 return Err;
343 return C.send();
344 }
345
346 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)347 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
348 SequenceNumberT SeqNo, Error Err) {
349 if (Err && Err.isA<RPCFatalError>())
350 return Err;
351 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
352 return Err2;
353 if (auto Err2 = serializeSeq(C, std::move(Err)))
354 return Err2;
355 if (auto Err2 = C.endSendMessage())
356 return Err2;
357 return C.send();
358 }
359 };
360
361 // RespondHelper specialization for functions that do not support error returns.
362 template <> class RespondHelper<false> {
363 public:
364 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
365 typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)366 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
367 SequenceNumberT SeqNo,
368 Expected<HandlerRetT> ResultOrErr) {
369 if (auto Err = ResultOrErr.takeError())
370 return Err;
371
372 // Open the response message.
373 if (auto Err = C.startSendMessage(ResponseId, SeqNo))
374 return Err;
375
376 // Serialize the result.
377 if (auto Err =
378 SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
379 C, *ResultOrErr))
380 return Err;
381
382 // End the response message.
383 if (auto Err = C.endSendMessage())
384 return Err;
385
386 return C.send();
387 }
388
389 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)390 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
391 SequenceNumberT SeqNo, Error Err) {
392 if (Err)
393 return Err;
394 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
395 return Err2;
396 if (auto Err2 = C.endSendMessage())
397 return Err2;
398 return C.send();
399 }
400 };
401
402 // Send a response of the given wire return type (WireRetT) over the
403 // channel, with the given sequence number.
404 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
405 typename FunctionIdT, typename SequenceNumberT>
respond(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)406 Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
407 Expected<HandlerRetT> ResultOrErr) {
408 return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
409 template sendResult<WireRetT>(C, ResponseId, SeqNo,
410 std::move(ResultOrErr));
411 }
412
413 // Send an empty response message on the given channel to indicate that
414 // the handler ran.
415 template <typename WireRetT, typename ChannelT, typename FunctionIdT,
416 typename SequenceNumberT>
respond(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)417 Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
418 Error Err) {
419 return RespondHelper<SupportsErrorReturn<WireRetT>::value>::sendResult(
420 C, ResponseId, SeqNo, std::move(Err));
421 }
422
423 // Converts a given type to the equivalent error return type.
424 template <typename T> class WrappedHandlerReturn {
425 public:
426 using Type = Expected<T>;
427 };
428
429 template <typename T> class WrappedHandlerReturn<Expected<T>> {
430 public:
431 using Type = Expected<T>;
432 };
433
434 template <> class WrappedHandlerReturn<void> {
435 public:
436 using Type = Error;
437 };
438
439 template <> class WrappedHandlerReturn<Error> {
440 public:
441 using Type = Error;
442 };
443
444 template <> class WrappedHandlerReturn<ErrorSuccess> {
445 public:
446 using Type = Error;
447 };
448
449 // Traits class that strips the response function from the list of handler
450 // arguments.
451 template <typename FnT> class AsyncHandlerTraits;
452
453 template <typename ResultT, typename... ArgTs>
454 class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>,
455 ArgTs...)> {
456 public:
457 using Type = Error(ArgTs...);
458 using ResultType = Expected<ResultT>;
459 };
460
461 template <typename... ArgTs>
462 class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
463 public:
464 using Type = Error(ArgTs...);
465 using ResultType = Error;
466 };
467
468 template <typename... ArgTs>
469 class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
470 public:
471 using Type = Error(ArgTs...);
472 using ResultType = Error;
473 };
474
475 template <typename... ArgTs>
476 class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
477 public:
478 using Type = Error(ArgTs...);
479 using ResultType = Error;
480 };
481
482 template <typename ResponseHandlerT, typename... ArgTs>
483 class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)>
484 : public AsyncHandlerTraits<Error(std::decay_t<ResponseHandlerT>,
485 ArgTs...)> {};
486
// HandlerTraits provides utilities related to RPC function handlers. This
// primary template covers non-function types (function types, function
// pointers and member-function pointers have their own specializations). It
// inherits from the specialization for the type's call operator, so lambdas
// and other functors are described by the signature of their operator().
template <typename HandlerT>
class HandlerTraits
    : public HandlerTraits<
          decltype(&std::remove_reference<HandlerT>::type::operator())> {};
495
// Traits for handlers with a given function type RetT(ArgTs...). ArgTs are the
// handler's declared parameter types; they also name the wire types used by
// serializeArgs / deserializeArgs below.
template <typename RetT, typename... ArgTs>
class HandlerTraits<RetT(ArgTs...)> {
public:
  // Function type of the handler.
  using Type = RetT(ArgTs...);

  // Return type of the handler.
  using ReturnType = RetT;

  // Call Handler with the elements of the Args tuple, each moved out of the
  // tuple and passed as a separate argument (via unpackAndRunHelper -> run).
  // The result is wrapped as WrappedHandlerReturn<RetT>::Type.
  template <typename HandlerT, typename... TArgTs>
  static typename WrappedHandlerReturn<RetT>::Type
  unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
    return unpackAndRunHelper(Handler, Args,
                              std::index_sequence_for<TArgTs...>());
  }

  // Like unpackAndRun, but passes Responder as the handler's first argument
  // ahead of the unpacked tuple elements. The helper's wrapped return value is
  // converted to Error here.
  template <typename HandlerT, typename ResponderT, typename... TArgTs>
  static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
                                 std::tuple<TArgTs...> &Args) {
    return unpackAndRunAsyncHelper(Handler, Responder, Args,
                                   std::index_sequence_for<TArgTs...>());
  }

  // Call the given handler with the given arguments. This overload is enabled
  // only when the handler returns void: it invokes Handler and then reports
  // Error::success(). Parameters are rvalue references to ArgTs.
  template <typename HandlerT>
  static std::enable_if_t<
      std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, Error>
  run(HandlerT &Handler, ArgTs &&...Args) {
    Handler(std::move(Args)...);
    return Error::success();
  }

  // Overload for handlers with a non-void return type. Arguments are taken by
  // value (types deduced as TArgTs), and the handler's result is returned
  // unchanged.
  template <typename HandlerT, typename... TArgTs>
  static std::enable_if_t<
      !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
      typename HandlerTraits<HandlerT>::ReturnType>
  run(HandlerT &Handler, TArgTs... Args) {
    return Handler(std::move(Args)...);
  }

  // Serialize arguments to the channel, using ArgTs as the wire types.
  // NOTE(review): CArgs are taken by value, so each argument is copied before
  // being serialized.
  template <typename ChannelT, typename... CArgTs>
  static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
    return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
  }

  // Deserialize arguments from the channel into the elements of Args, using
  // ArgTs as the wire types.
  template <typename ChannelT, typename... CArgTs>
  static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
    return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>());
  }

private:
  // The index_sequence parameter exists only to deduce Indexes; its value is
  // unused.
  template <typename ChannelT, typename... CArgTs, size_t... Indexes>
  static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
                                     std::index_sequence<Indexes...> _) {
    return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
        C, std::get<Indexes>(Args)...);
  }

  // Expands the tuple into individual (moved) arguments for run().
  template <typename HandlerT, typename ArgTuple, size_t... Indexes>
  static typename WrappedHandlerReturn<
      typename HandlerTraits<HandlerT>::ReturnType>::Type
  unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
                     std::index_sequence<Indexes...>) {
    return run(Handler, std::move(std::get<Indexes>(Args))...);
  }

  // As unpackAndRunHelper, with Responder passed (as an lvalue) before the
  // unpacked arguments.
  template <typename HandlerT, typename ResponderT, typename ArgTuple,
            size_t... Indexes>
  static typename WrappedHandlerReturn<
      typename HandlerTraits<HandlerT>::ReturnType>::Type
  unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
                          ArgTuple &Args, std::index_sequence<Indexes...>) {
    return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
  }
};
576
577 // Handler traits for free functions.
578 template <typename RetT, typename... ArgTs>
579 class HandlerTraits<RetT (*)(ArgTs...)> : public HandlerTraits<RetT(ArgTs...)> {
580 };
581
582 // Handler traits for class methods (especially call operators for lambdas).
583 template <typename Class, typename RetT, typename... ArgTs>
584 class HandlerTraits<RetT (Class::*)(ArgTs...)>
585 : public HandlerTraits<RetT(ArgTs...)> {};
586
587 // Handler traits for const class methods (especially call operators for
588 // lambdas).
589 template <typename Class, typename RetT, typename... ArgTs>
590 class HandlerTraits<RetT (Class::*)(ArgTs...) const>
591 : public HandlerTraits<RetT(ArgTs...)> {};
592
593 // Utility to peel the Expected wrapper off a response handler error type.
594 template <typename HandlerT> class ResponseHandlerArg;
595
596 template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
597 public:
598 using ArgType = Expected<ArgT>;
599 using UnwrappedArgType = ArgT;
600 };
601
602 template <typename ArgT>
603 class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
604 public:
605 using ArgType = Expected<ArgT>;
606 using UnwrappedArgType = ArgT;
607 };
608
609 template <> class ResponseHandlerArg<Error(Error)> {
610 public:
611 using ArgType = Error;
612 };
613
614 template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
615 public:
616 using ArgType = Error;
617 };
618
619 // ResponseHandler represents a handler for a not-yet-received function call
620 // result.
621 template <typename ChannelT> class ResponseHandler {
622 public:
~ResponseHandler()623 virtual ~ResponseHandler() {}
624
625 // Reads the function result off the wire and acts on it. The meaning of
626 // "act" will depend on how this method is implemented in any given
627 // ResponseHandler subclass but could, for example, mean running a
628 // user-specified handler or setting a promise value.
629 virtual Error handleResponse(ChannelT &C) = 0;
630
631 // Abandons this outstanding result.
632 virtual void abandon() = 0;
633
634 // Create an error instance representing an abandoned response.
createAbandonedResponseError()635 static Error createAbandonedResponseError() {
636 return make_error<ResponseAbandoned>();
637 }
638 };
639
640 // ResponseHandler subclass for RPC functions with non-void returns.
641 template <typename ChannelT, typename FuncRetT, typename HandlerT>
642 class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
643 public:
ResponseHandlerImpl(HandlerT Handler)644 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
645
646 // Handle the result by deserializing it from the channel then passing it
647 // to the user defined handler.
handleResponse(ChannelT & C)648 Error handleResponse(ChannelT &C) override {
649 using UnwrappedArgType = typename ResponseHandlerArg<
650 typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
651 UnwrappedArgType Result;
652 if (auto Err =
653 SerializationTraits<ChannelT, FuncRetT,
654 UnwrappedArgType>::deserialize(C, Result))
655 return Err;
656 if (auto Err = C.endReceiveMessage())
657 return Err;
658 return Handler(std::move(Result));
659 }
660
661 // Abandon this response by calling the handler with an 'abandoned response'
662 // error.
abandon()663 void abandon() override {
664 if (auto Err = Handler(this->createAbandonedResponseError())) {
665 // Handlers should not fail when passed an abandoned response error.
666 report_fatal_error(std::move(Err));
667 }
668 }
669
670 private:
671 HandlerT Handler;
672 };
673
674 // ResponseHandler subclass for RPC functions with void returns.
675 template <typename ChannelT, typename HandlerT>
676 class ResponseHandlerImpl<ChannelT, void, HandlerT>
677 : public ResponseHandler<ChannelT> {
678 public:
ResponseHandlerImpl(HandlerT Handler)679 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
680
681 // Handle the result (no actual value, just a notification that the function
682 // has completed on the remote end) by calling the user-defined handler with
683 // Error::success().
handleResponse(ChannelT & C)684 Error handleResponse(ChannelT &C) override {
685 if (auto Err = C.endReceiveMessage())
686 return Err;
687 return Handler(Error::success());
688 }
689
690 // Abandon this response by calling the handler with an 'abandoned response'
691 // error.
abandon()692 void abandon() override {
693 if (auto Err = Handler(this->createAbandonedResponseError())) {
694 // Handlers should not fail when passed an abandoned response error.
695 report_fatal_error(std::move(Err));
696 }
697 }
698
699 private:
700 HandlerT Handler;
701 };
702
703 template <typename ChannelT, typename FuncRetT, typename HandlerT>
704 class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
705 : public ResponseHandler<ChannelT> {
706 public:
ResponseHandlerImpl(HandlerT Handler)707 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
708
709 // Handle the result by deserializing it from the channel then passing it
710 // to the user defined handler.
handleResponse(ChannelT & C)711 Error handleResponse(ChannelT &C) override {
712 using HandlerArgType = typename ResponseHandlerArg<
713 typename HandlerTraits<HandlerT>::Type>::ArgType;
714 HandlerArgType Result((typename HandlerArgType::value_type()));
715
716 if (auto Err = SerializationTraits<ChannelT, Expected<FuncRetT>,
717 HandlerArgType>::deserialize(C, Result))
718 return Err;
719 if (auto Err = C.endReceiveMessage())
720 return Err;
721 return Handler(std::move(Result));
722 }
723
724 // Abandon this response by calling the handler with an 'abandoned response'
725 // error.
abandon()726 void abandon() override {
727 if (auto Err = Handler(this->createAbandonedResponseError())) {
728 // Handlers should not fail when passed an abandoned response error.
729 report_fatal_error(std::move(Err));
730 }
731 }
732
733 private:
734 HandlerT Handler;
735 };
736
737 template <typename ChannelT, typename HandlerT>
738 class ResponseHandlerImpl<ChannelT, Error, HandlerT>
739 : public ResponseHandler<ChannelT> {
740 public:
ResponseHandlerImpl(HandlerT Handler)741 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
742
743 // Handle the result by deserializing it from the channel then passing it
744 // to the user defined handler.
handleResponse(ChannelT & C)745 Error handleResponse(ChannelT &C) override {
746 Error Result = Error::success();
747 if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize(
748 C, Result)) {
749 consumeError(std::move(Result));
750 return Err;
751 }
752 if (auto Err = C.endReceiveMessage()) {
753 consumeError(std::move(Result));
754 return Err;
755 }
756 return Handler(std::move(Result));
757 }
758
759 // Abandon this response by calling the handler with an 'abandoned response'
760 // error.
abandon()761 void abandon() override {
762 if (auto Err = Handler(this->createAbandonedResponseError())) {
763 // Handlers should not fail when passed an abandoned response error.
764 report_fatal_error(std::move(Err));
765 }
766 }
767
768 private:
769 HandlerT Handler;
770 };
771
772 // Create a ResponseHandler from a given user handler.
773 template <typename ChannelT, typename FuncRetT, typename HandlerT>
createResponseHandler(HandlerT H)774 std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
775 return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
776 std::move(H));
777 }
778
// MemberFnWrapper turns an (object, member-function pointer) pair into a
// callable, e.g. for installing methods as result handlers. It stores a
// reference to the object, so the object must outlive the wrapper.
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);

  MemberFnWrapper(ClassT &Instance, MethodT Method)
      : Target(Instance), Fn(Method) {}

  // Invokes the wrapped method on the stored object, moving each argument in.
  RetT operator()(ArgTs &&...Args) {
    ClassT &Obj = Target;
    return (Obj.*Fn)(std::move(Args)...);
  }

private:
  ClassT &Target;
  MethodT Fn;
};
795
796 // Helper that provides a Functor for deserializing arguments.
797 template <typename... ArgTs> class ReadArgs {
798 public:
operator()799 Error operator()() { return Error::success(); }
800 };
801
802 template <typename ArgT, typename... ArgTs>
803 class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
804 public:
ReadArgs(ArgT & Arg,ArgTs &...Args)805 ReadArgs(ArgT &Arg, ArgTs &...Args) : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
806
operator()807 Error operator()(ArgT &ArgVal, ArgTs &...ArgVals) {
808 this->Arg = std::move(ArgVal);
809 return ReadArgs<ArgTs...>::operator()(ArgVals...);
810 }
811
812 private:
813 ArgT &Arg;
814 };
815
// Hands out sequence numbers. Fresh numbers count up from zero, and released
// numbers are pooled and reused (most recently released first) before any
// fresh number is issued. Every member function holds SeqNoLock while it runs.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Makes every sequence number available again: the pool of released numbers
  // is emptied and the counter restarts at zero.
  void reset() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.clear();
    NextSequenceNumber = 0;
  }

  // Returns an available sequence number, preferring a released one.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    if (!FreeSequenceNumbers.empty()) {
      SequenceNumberT Reused = FreeSequenceNumbers.back();
      FreeSequenceNumbers.pop_back();
      return Reused;
    }
    return NextSequenceNumber++;
  }

  // Returns SequenceNumber to the pool so it can be handed out again.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  SequenceNumberT NextSequenceNumber = 0;
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};
848
// RPCArgTypeCheckHelper<P, std::tuple<T...>, std::tuple<U...>>::value is true
// when P<Ti, Ui>::value holds for every corresponding pair of element types.
//
// The primary template yields false for any shape that the specializations
// below do not match, including tuples of different lengths. Because of this,
// an arity mismatch in RPCArgTypeCheck is reported only through its own
// "Too many/few arguments" static_asserts, instead of also failing on an
// incomplete RPCArgTypeCheckHelper type.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper {
public:
  static const bool value = false;
};

// Two empty tuples trivially satisfy the predicate.
template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static const bool value = true;
};

// Check the head pair, then recurse on the tails.
template <template <class, class> class P, typename T, typename... Ts,
          typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
public:
  static const bool value =
      P<T, U>::value &&
      RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value;
};
868
869 template <template <class, class> class P, typename T1Sig, typename T2Sig>
870 class RPCArgTypeCheck {
871 public:
872 using T1Tuple = typename RPCFunctionArgsTuple<T1Sig>::Type;
873 using T2Tuple = typename RPCFunctionArgsTuple<T2Sig>::Type;
874
875 static_assert(std::tuple_size<T1Tuple>::value >=
876 std::tuple_size<T2Tuple>::value,
877 "Too many arguments to RPC call");
878 static_assert(std::tuple_size<T1Tuple>::value <=
879 std::tuple_size<T2Tuple>::value,
880 "Too few arguments to RPC call");
881
882 static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
883 };
884
885 template <typename ChannelT, typename WireT, typename ConcreteT>
886 class CanSerialize {
887 private:
888 using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
889
890 template <typename T>
891 static std::true_type check(
892 std::enable_if_t<std::is_same<decltype(T::serialize(
893 std::declval<ChannelT &>(),
894 std::declval<const ConcreteT &>())),
895 Error>::value,
896 void *>);
897
898 template <typename> static std::false_type check(...);
899
900 public:
901 static const bool value = decltype(check<S>(0))::value;
902 };
903
904 template <typename ChannelT, typename WireT, typename ConcreteT>
905 class CanDeserialize {
906 private:
907 using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
908
909 template <typename T>
910 static std::true_type
911 check(std::enable_if_t<
912 std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
913 std::declval<ConcreteT &>())),
914 Error>::value,
915 void *>);
916
917 template <typename> static std::false_type check(...);
918
919 public:
920 static const bool value = decltype(check<S>(0))::value;
921 };
922
923 /// Contains primitive utilities for defining, calling and handling calls to
924 /// remote procedures. ChannelT is a bidirectional stream conforming to the
925 /// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
926 /// identifier type that must be serializable on ChannelT, and SequenceNumberT
927 /// is an integral type that will be used to number in-flight function calls.
928 ///
929 /// These utilities support the construction of very primitive RPC utilities.
930 /// Their intent is to ensure correct serialization and deserialization of
931 /// procedure arguments, and to keep the client and server's view of the API in
932 /// sync.
933 template <typename ImplT, typename ChannelT, typename FunctionIdT,
934 typename SequenceNumberT>
935 class RPCEndpointBase {
936 protected:
937 class OrcRPCInvalid : public RPCFunction<OrcRPCInvalid, void()> {
938 public:
getName()939 static const char *getName() { return "__orc_rpc$invalid"; }
940 };
941
942 class OrcRPCResponse : public RPCFunction<OrcRPCResponse, void()> {
943 public:
getName()944 static const char *getName() { return "__orc_rpc$response"; }
945 };
946
947 class OrcRPCNegotiate
948 : public RPCFunction<OrcRPCNegotiate, FunctionIdT(std::string)> {
949 public:
getName()950 static const char *getName() { return "__orc_rpc$negotiate"; }
951 };
952
953 // Helper predicate for testing for the presence of SerializeTraits
954 // serializers.
955 template <typename WireT, typename ConcreteT>
956 class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
957 public:
958 using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;
959
960 static_assert(value, "Missing serializer for argument (Can't serialize the "
961 "first template type argument of CanSerializeCheck "
962 "from the second)");
963 };
964
965 // Helper predicate for testing for the presence of SerializeTraits
966 // deserializers.
967 template <typename WireT, typename ConcreteT>
968 class CanDeserializeCheck
969 : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
970 public:
971 using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;
972
973 static_assert(value, "Missing deserializer for argument (Can't deserialize "
974 "the second template type argument of "
975 "CanDeserializeCheck from the first)");
976 };
977
978 public:
979 /// Construct an RPC instance on a channel.
RPCEndpointBase(ChannelT & C,bool LazyAutoNegotiation)980 RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
981 : C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
982 // Hold ResponseId in a special variable, since we expect Response to be
983 // called relatively frequently, and want to avoid the map lookup.
984 ResponseId = FnIdAllocator.getResponseId();
985 RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;
986
987 // Register the negotiate function id and handler.
988 auto NegotiateId = FnIdAllocator.getNegotiateId();
989 RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
990 Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
991 [this](const std::string &Name) { return handleNegotiate(Name); });
992 }
993
994 /// Negotiate a function id for Func with the other end of the channel.
995 template <typename Func> Error negotiateFunction(bool Retry = false) {
996 return getRemoteFunctionId<Func>(true, Retry).takeError();
997 }
998
999 /// Append a call Func, does not call send on the channel.
1000 /// The first argument specifies a user-defined handler to be run when the
1001 /// function returns. The handler should take an Expected<Func::ReturnType>,
1002 /// or an Error (if Func::ReturnType is void). The handler will be called
1003 /// with an error if the return value is abandoned due to a channel error.
1004 template <typename Func, typename HandlerT, typename... ArgTs>
appendCallAsync(HandlerT Handler,const ArgTs &...Args)1005 Error appendCallAsync(HandlerT Handler, const ArgTs &...Args) {
1006
1007 static_assert(
1008 detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
1009 void(ArgTs...)>::value,
1010 "");
1011
1012 // Look up the function ID.
1013 FunctionIdT FnId;
1014 if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
1015 FnId = *FnIdOrErr;
1016 else {
1017 // Negotiation failed. Notify the handler then return the negotiate-failed
1018 // error.
1019 cantFail(Handler(make_error<ResponseAbandoned>()));
1020 return FnIdOrErr.takeError();
1021 }
1022
1023 SequenceNumberT SeqNo; // initialized in locked scope below.
1024 {
1025 // Lock the pending responses map and sequence number manager.
1026 std::lock_guard<std::mutex> Lock(ResponsesMutex);
1027
1028 // Allocate a sequence number.
1029 SeqNo = SequenceNumberMgr.getSequenceNumber();
1030 assert(!PendingResponses.count(SeqNo) &&
1031 "Sequence number already allocated");
1032
1033 // Install the user handler.
1034 PendingResponses[SeqNo] =
1035 detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
1036 std::move(Handler));
1037 }
1038
1039 // Open the function call message.
1040 if (auto Err = C.startSendMessage(FnId, SeqNo)) {
1041 abandonPendingResponses();
1042 return Err;
1043 }
1044
1045 // Serialize the call arguments.
1046 if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
1047 C, Args...)) {
1048 abandonPendingResponses();
1049 return Err;
1050 }
1051
1052 // Close the function call messagee.
1053 if (auto Err = C.endSendMessage()) {
1054 abandonPendingResponses();
1055 return Err;
1056 }
1057
1058 return Error::success();
1059 }
1060
sendAppendedCalls()1061 Error sendAppendedCalls() { return C.send(); };
1062
1063 template <typename Func, typename HandlerT, typename... ArgTs>
callAsync(HandlerT Handler,const ArgTs &...Args)1064 Error callAsync(HandlerT Handler, const ArgTs &...Args) {
1065 if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
1066 return Err;
1067 return C.send();
1068 }
1069
1070 /// Handle one incoming call.
handleOne()1071 Error handleOne() {
1072 FunctionIdT FnId;
1073 SequenceNumberT SeqNo;
1074 if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
1075 abandonPendingResponses();
1076 return Err;
1077 }
1078 if (FnId == ResponseId)
1079 return handleResponse(SeqNo);
1080 auto I = Handlers.find(FnId);
1081 if (I != Handlers.end())
1082 return I->second(C, SeqNo);
1083
1084 // else: No handler found. Report error to client?
1085 return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
1086 SeqNo);
1087 }
1088
1089 /// Helper for handling setter procedures - this method returns a functor that
1090 /// sets the variables referred to by Args... to values deserialized from the
1091 /// channel.
1092 /// E.g.
1093 ///
1094 /// typedef Function<0, bool, int> Func1;
1095 ///
1096 /// ...
1097 /// bool B;
1098 /// int I;
1099 /// if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
1100 /// /* Handle Args */ ;
1101 ///
1102 template <typename... ArgTs>
readArgs(ArgTs &...Args)1103 static detail::ReadArgs<ArgTs...> readArgs(ArgTs &...Args) {
1104 return detail::ReadArgs<ArgTs...>(Args...);
1105 }
1106
1107 /// Abandon all outstanding result handlers.
1108 ///
1109 /// This will call all currently registered result handlers to receive an
1110 /// "abandoned" error as their argument. This is used internally by the RPC
1111 /// in error situations, but can also be called directly by clients who are
1112 /// disconnecting from the remote and don't or can't expect responses to their
1113 /// outstanding calls. (Especially for outstanding blocking calls, calling
1114 /// this function may be necessary to avoid dead threads).
abandonPendingResponses()1115 void abandonPendingResponses() {
1116 // Lock the pending responses map and sequence number manager.
1117 std::lock_guard<std::mutex> Lock(ResponsesMutex);
1118
1119 for (auto &KV : PendingResponses)
1120 KV.second->abandon();
1121 PendingResponses.clear();
1122 SequenceNumberMgr.reset();
1123 }
1124
1125 /// Remove the handler for the given function.
1126 /// A handler must currently be registered for this function.
removeHandler()1127 template <typename Func> void removeHandler() {
1128 auto IdItr = LocalFunctionIds.find(Func::getPrototype());
1129 assert(IdItr != LocalFunctionIds.end() &&
1130 "Function does not have a registered handler");
1131 auto HandlerItr = Handlers.find(IdItr->second);
1132 assert(HandlerItr != Handlers.end() &&
1133 "Function does not have a registered handler");
1134 Handlers.erase(HandlerItr);
1135 }
1136
1137 /// Clear all handlers.
clearHandlers()1138 void clearHandlers() { Handlers.clear(); }
1139
1140 protected:
getInvalidFunctionId()1141 FunctionIdT getInvalidFunctionId() const {
1142 return FnIdAllocator.getInvalidId();
1143 }
1144
1145 /// Add the given handler to the handler map and make it available for
1146 /// autonegotiation and execution.
1147 template <typename Func, typename HandlerT>
addHandlerImpl(HandlerT Handler)1148 void addHandlerImpl(HandlerT Handler) {
1149
1150 static_assert(detail::RPCArgTypeCheck<
1151 CanDeserializeCheck, typename Func::Type,
1152 typename detail::HandlerTraits<HandlerT>::Type>::value,
1153 "");
1154
1155 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1156 LocalFunctionIds[Func::getPrototype()] = NewFnId;
1157 Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
1158 }
1159
1160 template <typename Func, typename HandlerT>
addAsyncHandlerImpl(HandlerT Handler)1161 void addAsyncHandlerImpl(HandlerT Handler) {
1162
1163 static_assert(
1164 detail::RPCArgTypeCheck<
1165 CanDeserializeCheck, typename Func::Type,
1166 typename detail::AsyncHandlerTraits<
1167 typename detail::HandlerTraits<HandlerT>::Type>::Type>::value,
1168 "");
1169
1170 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1171 LocalFunctionIds[Func::getPrototype()] = NewFnId;
1172 Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
1173 }
1174
handleResponse(SequenceNumberT SeqNo)1175 Error handleResponse(SequenceNumberT SeqNo) {
1176 using Handler = typename decltype(PendingResponses)::mapped_type;
1177 Handler PRHandler;
1178
1179 {
1180 // Lock the pending responses map and sequence number manager.
1181 std::unique_lock<std::mutex> Lock(ResponsesMutex);
1182 auto I = PendingResponses.find(SeqNo);
1183
1184 if (I != PendingResponses.end()) {
1185 PRHandler = std::move(I->second);
1186 PendingResponses.erase(I);
1187 SequenceNumberMgr.releaseSequenceNumber(SeqNo);
1188 } else {
1189 // Unlock the pending results map to prevent recursive lock.
1190 Lock.unlock();
1191 abandonPendingResponses();
1192 return make_error<InvalidSequenceNumberForResponse<SequenceNumberT>>(
1193 SeqNo);
1194 }
1195 }
1196
1197 assert(PRHandler &&
1198 "If we didn't find a response handler we should have bailed out");
1199
1200 if (auto Err = PRHandler->handleResponse(C)) {
1201 abandonPendingResponses();
1202 return Err;
1203 }
1204
1205 return Error::success();
1206 }
1207
handleNegotiate(const std::string & Name)1208 FunctionIdT handleNegotiate(const std::string &Name) {
1209 auto I = LocalFunctionIds.find(Name);
1210 if (I == LocalFunctionIds.end())
1211 return getInvalidFunctionId();
1212 return I->second;
1213 }
1214
1215 // Find the remote FunctionId for the given function.
1216 template <typename Func>
getRemoteFunctionId(bool NegotiateIfNotInMap,bool NegotiateIfInvalid)1217 Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
1218 bool NegotiateIfInvalid) {
1219 bool DoNegotiate;
1220
1221 // Check if we already have a function id...
1222 auto I = RemoteFunctionIds.find(Func::getPrototype());
1223 if (I != RemoteFunctionIds.end()) {
1224 // If it's valid there's nothing left to do.
1225 if (I->second != getInvalidFunctionId())
1226 return I->second;
1227 DoNegotiate = NegotiateIfInvalid;
1228 } else
1229 DoNegotiate = NegotiateIfNotInMap;
1230
1231 // We don't have a function id for Func yet, but we're allowed to try to
1232 // negotiate one.
1233 if (DoNegotiate) {
1234 auto &Impl = static_cast<ImplT &>(*this);
1235 if (auto RemoteIdOrErr =
1236 Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
1237 RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
1238 if (*RemoteIdOrErr == getInvalidFunctionId())
1239 return make_error<CouldNotNegotiate>(Func::getPrototype());
1240 return *RemoteIdOrErr;
1241 } else
1242 return RemoteIdOrErr.takeError();
1243 }
1244
1245 // No key was available in the map and we weren't allowed to try to
1246 // negotiate one, so return an unknown function error.
1247 return make_error<CouldNotNegotiate>(Func::getPrototype());
1248 }
1249
1250 using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;
1251
1252 // Wrap the given user handler in the necessary argument-deserialization code,
1253 // result-serialization code, and call to the launch policy (if present).
1254 template <typename Func, typename HandlerT>
wrapHandler(HandlerT Handler)1255 WrappedHandlerFn wrapHandler(HandlerT Handler) {
1256 return [this, Handler](ChannelT &Channel,
1257 SequenceNumberT SeqNo) mutable -> Error {
1258 // Start by deserializing the arguments.
1259 using ArgsTuple = typename detail::RPCFunctionArgsTuple<
1260 typename detail::HandlerTraits<HandlerT>::Type>::Type;
1261 auto Args = std::make_shared<ArgsTuple>();
1262
1263 if (auto Err =
1264 detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1265 Channel, *Args))
1266 return Err;
1267
1268 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1269 // for RPCArgs. Void cast RPCArgs to work around this for now.
1270 // FIXME: Remove this workaround once we can assume a working GCC version.
1271 (void)Args;
1272
1273 // End receieve message, unlocking the channel for reading.
1274 if (auto Err = Channel.endReceiveMessage())
1275 return Err;
1276
1277 using HTraits = detail::HandlerTraits<HandlerT>;
1278 using FuncReturn = typename Func::ReturnType;
1279 return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
1280 HTraits::unpackAndRun(Handler, *Args));
1281 };
1282 }
1283
1284 // Wrap the given user handler in the necessary argument-deserialization code,
1285 // result-serialization code, and call to the launch policy (if present).
1286 template <typename Func, typename HandlerT>
wrapAsyncHandler(HandlerT Handler)1287 WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
1288 return [this, Handler](ChannelT &Channel,
1289 SequenceNumberT SeqNo) mutable -> Error {
1290 // Start by deserializing the arguments.
1291 using AHTraits = detail::AsyncHandlerTraits<
1292 typename detail::HandlerTraits<HandlerT>::Type>;
1293 using ArgsTuple =
1294 typename detail::RPCFunctionArgsTuple<typename AHTraits::Type>::Type;
1295 auto Args = std::make_shared<ArgsTuple>();
1296
1297 if (auto Err =
1298 detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1299 Channel, *Args))
1300 return Err;
1301
1302 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1303 // for RPCArgs. Void cast RPCArgs to work around this for now.
1304 // FIXME: Remove this workaround once we can assume a working GCC version.
1305 (void)Args;
1306
1307 // End receieve message, unlocking the channel for reading.
1308 if (auto Err = Channel.endReceiveMessage())
1309 return Err;
1310
1311 using HTraits = detail::HandlerTraits<HandlerT>;
1312 using FuncReturn = typename Func::ReturnType;
1313 auto Responder = [this,
1314 SeqNo](typename AHTraits::ResultType RetVal) -> Error {
1315 return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
1316 std::move(RetVal));
1317 };
1318
1319 return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
1320 };
1321 }
1322
1323 ChannelT &C;
1324
1325 bool LazyAutoNegotiation;
1326
1327 RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;
1328
1329 FunctionIdT ResponseId;
1330 std::map<std::string, FunctionIdT> LocalFunctionIds;
1331 std::map<const char *, FunctionIdT> RemoteFunctionIds;
1332
1333 std::map<FunctionIdT, WrappedHandlerFn> Handlers;
1334
1335 std::mutex ResponsesMutex;
1336 detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
1337 std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
1338 PendingResponses;
1339 };
1340
1341 } // end namespace detail
1342
1343 template <typename ChannelT, typename FunctionIdT = uint32_t,
1344 typename SequenceNumberT = uint32_t>
1345 class MultiThreadedRPCEndpoint
1346 : public detail::RPCEndpointBase<
1347 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1348 ChannelT, FunctionIdT, SequenceNumberT> {
1349 private:
1350 using BaseClass = detail::RPCEndpointBase<
1351 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1352 ChannelT, FunctionIdT, SequenceNumberT>;
1353
1354 public:
MultiThreadedRPCEndpoint(ChannelT & C,bool LazyAutoNegotiation)1355 MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1356 : BaseClass(C, LazyAutoNegotiation) {}
1357
1358 /// Add a handler for the given RPC function.
1359 /// This installs the given handler functor for the given RPCFunction, and
1360 /// makes the RPC function available for negotiation/calling from the remote.
1361 template <typename Func, typename HandlerT>
addHandler(HandlerT Handler)1362 void addHandler(HandlerT Handler) {
1363 return this->template addHandlerImpl<Func>(std::move(Handler));
1364 }
1365
1366 /// Add a class-method as a handler.
1367 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1368 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1369 addHandler<Func>(
1370 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1371 }
1372
1373 template <typename Func, typename HandlerT>
addAsyncHandler(HandlerT Handler)1374 void addAsyncHandler(HandlerT Handler) {
1375 return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1376 }
1377
1378 /// Add a class-method as a handler.
1379 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addAsyncHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1380 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1381 addAsyncHandler<Func>(
1382 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1383 }
1384
1385 /// Return type for non-blocking call primitives.
1386 template <typename Func>
1387 using NonBlockingCallResult = typename detail::ResultTraits<
1388 typename Func::ReturnType>::ReturnFutureType;
1389
1390 /// Call Func on Channel C. Does not block, does not call send. Returns a pair
1391 /// of a future result and the sequence number assigned to the result.
1392 ///
1393 /// This utility function is primarily used for single-threaded mode support,
1394 /// where the sequence number can be used to wait for the corresponding
1395 /// result. In multi-threaded mode the appendCallNB method, which does not
1396 /// return the sequence numeber, should be preferred.
1397 template <typename Func, typename... ArgTs>
appendCallNB(const ArgTs &...Args)1398 Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &...Args) {
1399 using RTraits = detail::ResultTraits<typename Func::ReturnType>;
1400 using ErrorReturn = typename RTraits::ErrorReturnType;
1401 using ErrorReturnPromise = typename RTraits::ReturnPromiseType;
1402
1403 ErrorReturnPromise Promise;
1404 auto FutureResult = Promise.get_future();
1405
1406 if (auto Err = this->template appendCallAsync<Func>(
1407 [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable {
1408 Promise.set_value(std::move(RetOrErr));
1409 return Error::success();
1410 },
1411 Args...)) {
1412 RTraits::consumeAbandoned(FutureResult.get());
1413 return std::move(Err);
1414 }
1415 return std::move(FutureResult);
1416 }
1417
1418 /// The same as appendCallNBWithSeq, except that it calls C.send() to
1419 /// flush the channel after serializing the call.
1420 template <typename Func, typename... ArgTs>
callNB(const ArgTs &...Args)1421 Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &...Args) {
1422 auto Result = appendCallNB<Func>(Args...);
1423 if (!Result)
1424 return Result;
1425 if (auto Err = this->C.send()) {
1426 this->abandonPendingResponses();
1427 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1428 std::move(Result->get()));
1429 return std::move(Err);
1430 }
1431 return Result;
1432 }
1433
1434 /// Call Func on Channel C. Blocks waiting for a result. Returns an Error
1435 /// for void functions or an Expected<T> for functions returning a T.
1436 ///
1437 /// This function is for use in threaded code where another thread is
1438 /// handling responses and incoming calls.
1439 template <typename Func, typename... ArgTs,
1440 typename AltRetT = typename Func::ReturnType>
1441 typename detail::ResultTraits<AltRetT>::ErrorReturnType
callB(const ArgTs &...Args)1442 callB(const ArgTs &...Args) {
1443 if (auto FutureResOrErr = callNB<Func>(Args...))
1444 return FutureResOrErr->get();
1445 else
1446 return FutureResOrErr.takeError();
1447 }
1448
1449 /// Handle incoming RPC calls.
handlerLoop()1450 Error handlerLoop() {
1451 while (true)
1452 if (auto Err = this->handleOne())
1453 return Err;
1454 return Error::success();
1455 }
1456 };
1457
1458 template <typename ChannelT, typename FunctionIdT = uint32_t,
1459 typename SequenceNumberT = uint32_t>
1460 class SingleThreadedRPCEndpoint
1461 : public detail::RPCEndpointBase<
1462 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1463 ChannelT, FunctionIdT, SequenceNumberT> {
1464 private:
1465 using BaseClass = detail::RPCEndpointBase<
1466 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1467 ChannelT, FunctionIdT, SequenceNumberT>;
1468
1469 public:
SingleThreadedRPCEndpoint(ChannelT & C,bool LazyAutoNegotiation)1470 SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1471 : BaseClass(C, LazyAutoNegotiation) {}
1472
1473 template <typename Func, typename HandlerT>
addHandler(HandlerT Handler)1474 void addHandler(HandlerT Handler) {
1475 return this->template addHandlerImpl<Func>(std::move(Handler));
1476 }
1477
1478 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1479 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1480 addHandler<Func>(
1481 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1482 }
1483
1484 template <typename Func, typename HandlerT>
addAsyncHandler(HandlerT Handler)1485 void addAsyncHandler(HandlerT Handler) {
1486 return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1487 }
1488
1489 /// Add a class-method as a handler.
1490 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addAsyncHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1491 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1492 addAsyncHandler<Func>(
1493 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1494 }
1495
1496 template <typename Func, typename... ArgTs,
1497 typename AltRetT = typename Func::ReturnType>
1498 typename detail::ResultTraits<AltRetT>::ErrorReturnType
callB(const ArgTs &...Args)1499 callB(const ArgTs &...Args) {
1500 bool ReceivedResponse = false;
1501 using AltRetTraits = detail::ResultTraits<AltRetT>;
1502 using ResultType = typename AltRetTraits::ErrorReturnType;
1503 ResultType Result = AltRetTraits::createBlankErrorReturnValue();
1504
1505 // We have to 'Check' result (which we know is in a success state at this
1506 // point) so that it can be overwritten in the async handler.
1507 (void)!!Result;
1508
1509 if (Error Err = this->template appendCallAsync<Func>(
1510 [&](ResultType R) {
1511 Result = std::move(R);
1512 ReceivedResponse = true;
1513 return Error::success();
1514 },
1515 Args...)) {
1516 AltRetTraits::consumeAbandoned(std::move(Result));
1517 return AltRetTraits::returnError(std::move(Err));
1518 }
1519
1520 if (Error Err = this->C.send()) {
1521 AltRetTraits::consumeAbandoned(std::move(Result));
1522 return AltRetTraits::returnError(std::move(Err));
1523 }
1524
1525 while (!ReceivedResponse) {
1526 if (Error Err = this->handleOne()) {
1527 AltRetTraits::consumeAbandoned(std::move(Result));
1528 return AltRetTraits::returnError(std::move(Err));
1529 }
1530 }
1531
1532 return Result;
1533 }
1534 };
1535
1536 /// Asynchronous dispatch for a function on an RPC endpoint.
1537 template <typename RPCClass, typename Func> class RPCAsyncDispatch {
1538 public:
RPCAsyncDispatch(RPCClass & Endpoint)1539 RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
1540
1541 template <typename HandlerT, typename... ArgTs>
operator()1542 Error operator()(HandlerT Handler, const ArgTs &...Args) const {
1543 return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
1544 }
1545
1546 private:
1547 RPCClass &Endpoint;
1548 };
1549
1550 /// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
1551 template <typename Func, typename RPCEndpointT>
rpcAsyncDispatch(RPCEndpointT & Endpoint)1552 RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
1553 return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
1554 }
1555
1556 /// Allows a set of asynchrounous calls to be dispatched, and then
1557 /// waited on as a group.
1558 class ParallelCallGroup {
1559 public:
1560 ParallelCallGroup() = default;
1561 ParallelCallGroup(const ParallelCallGroup &) = delete;
1562 ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
1563
1564 /// Make as asynchronous call.
1565 template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
call(const AsyncDispatcher & AsyncDispatch,HandlerT Handler,const ArgTs &...Args)1566 Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
1567 const ArgTs &...Args) {
1568 // Increment the count of outstanding calls. This has to happen before
1569 // we invoke the call, as the handler may (depending on scheduling)
1570 // be run immediately on another thread, and we don't want the decrement
1571 // in the wrapped handler below to run before the increment.
1572 {
1573 std::unique_lock<std::mutex> Lock(M);
1574 ++NumOutstandingCalls;
1575 }
1576
1577 // Wrap the user handler in a lambda that will decrement the
1578 // outstanding calls count, then poke the condition variable.
1579 using ArgType = typename detail::ResponseHandlerArg<
1580 typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
1581 auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) {
1582 auto Err = Handler(std::move(Arg));
1583 std::unique_lock<std::mutex> Lock(M);
1584 --NumOutstandingCalls;
1585 CV.notify_all();
1586 return Err;
1587 };
1588
1589 return AsyncDispatch(std::move(WrappedHandler), Args...);
1590 }
1591
1592 /// Blocks until all calls have been completed and their return value
1593 /// handlers run.
wait()1594 void wait() {
1595 std::unique_lock<std::mutex> Lock(M);
1596 while (NumOutstandingCalls > 0)
1597 CV.wait(Lock);
1598 }
1599
1600 private:
1601 std::mutex M;
1602 std::condition_variable CV;
1603 uint32_t NumOutstandingCalls = 0;
1604 };
1605
1606 /// Convenience class for grouping RPCFunctions into APIs that can be
1607 /// negotiated as a block.
1608 ///
1609 template <typename... Funcs> class APICalls {
1610 public:
1611 /// Test whether this API contains Function F.
1612 template <typename F> class Contains {
1613 public:
1614 static const bool value = false;
1615 };
1616
1617 /// Negotiate all functions in this API.
negotiate(RPCEndpoint & R)1618 template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) {
1619 return Error::success();
1620 }
1621 };
1622
1623 template <typename Func, typename... Funcs> class APICalls<Func, Funcs...> {
1624 public:
1625 template <typename F> class Contains {
1626 public:
1627 static const bool value = std::is_same<F, Func>::value |
1628 APICalls<Funcs...>::template Contains<F>::value;
1629 };
1630
negotiate(RPCEndpoint & R)1631 template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) {
1632 if (auto Err = R.template negotiateFunction<Func>())
1633 return Err;
1634 return APICalls<Funcs...>::negotiate(R);
1635 }
1636 };
1637
1638 template <typename... InnerFuncs, typename... Funcs>
1639 class APICalls<APICalls<InnerFuncs...>, Funcs...> {
1640 public:
1641 template <typename F> class Contains {
1642 public:
1643 static const bool value =
1644 APICalls<InnerFuncs...>::template Contains<F>::value |
1645 APICalls<Funcs...>::template Contains<F>::value;
1646 };
1647
negotiate(RPCEndpoint & R)1648 template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) {
1649 if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
1650 return Err;
1651 return APICalls<Funcs...>::negotiate(R);
1652 }
1653 };
1654
1655 } // end namespace shared
1656 } // end namespace orc
1657 } // end namespace llvm
1658
1659 #endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H
1660