xref: /netbsd-src/external/apache2/llvm/dist/llvm/include/llvm/ExecutionEngine/Orc/Shared/RPCUtils.h (revision 82d56013d7b633d116a93943de88e08335357a7c)
1 //===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utilities to support construction of simple RPC APIs.
10 //
11 // The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
12 // programmers, high performance, low memory overhead, and efficient use of the
13 // communications channel.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H
18 #define LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H
19 
20 #include <map>
21 #include <thread>
22 #include <vector>
23 
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ExecutionEngine/Orc/Shared/OrcError.h"
26 #include "llvm/ExecutionEngine/Orc/Shared/Serialization.h"
27 #include "llvm/Support/MSVCErrorWorkarounds.h"
28 
29 #include <future>
30 
31 namespace llvm {
32 namespace orc {
33 namespace shared {
34 
35 /// Base class of all fatal RPC errors (those that necessarily result in the
36 /// termination of the RPC session).
37 class RPCFatalError : public ErrorInfo<RPCFatalError> {
38 public:
39   static char ID;
40 };
41 
42 /// RPCConnectionClosed is returned from RPC operations if the RPC connection
43 /// has already been closed due to either an error or graceful disconnection.
44 class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
45 public:
46   static char ID;
47   std::error_code convertToErrorCode() const override;
48   void log(raw_ostream &OS) const override;
49 };
50 
51 /// BadFunctionCall is returned from handleOne when the remote makes a call with
52 /// an unrecognized function id.
53 ///
54 /// This error is fatal because Orc RPC needs to know how to parse a function
55 /// call to know where the next call starts, and if it doesn't recognize the
56 /// function id it cannot parse the call.
57 template <typename FnIdT, typename SeqNoT>
58 class BadFunctionCall
59     : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
60 public:
61   static char ID;
62 
BadFunctionCall(FnIdT FnId,SeqNoT SeqNo)63   BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
64       : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
65 
convertToErrorCode()66   std::error_code convertToErrorCode() const override {
67     return orcError(OrcErrorCode::UnexpectedRPCCall);
68   }
69 
log(raw_ostream & OS)70   void log(raw_ostream &OS) const override {
71     OS << "Call to invalid RPC function id '" << FnId
72        << "' with "
73           "sequence number "
74        << SeqNo;
75   }
76 
77 private:
78   FnIdT FnId;
79   SeqNoT SeqNo;
80 };
81 
82 template <typename FnIdT, typename SeqNoT>
83 char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
84 
85 /// InvalidSequenceNumberForResponse is returned from handleOne when a response
86 /// call arrives with a sequence number that doesn't correspond to any in-flight
87 /// function call.
88 ///
89 /// This error is fatal because Orc RPC needs to know how to parse the rest of
90 /// the response call to know where the next call starts, and if it doesn't have
91 /// a result parser for this sequence number it can't do that.
92 template <typename SeqNoT>
93 class InvalidSequenceNumberForResponse
94     : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>,
95                        RPCFatalError> {
96 public:
97   static char ID;
98 
InvalidSequenceNumberForResponse(SeqNoT SeqNo)99   InvalidSequenceNumberForResponse(SeqNoT SeqNo) : SeqNo(std::move(SeqNo)) {}
100 
convertToErrorCode()101   std::error_code convertToErrorCode() const override {
102     return orcError(OrcErrorCode::UnexpectedRPCCall);
103   };
104 
log(raw_ostream & OS)105   void log(raw_ostream &OS) const override {
106     OS << "Response has unknown sequence number " << SeqNo;
107   }
108 
109 private:
110   SeqNoT SeqNo;
111 };
112 
113 template <typename SeqNoT>
114 char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
115 
116 /// This non-fatal error will be passed to asynchronous result handlers in place
117 /// of a result if the connection goes down before a result returns, or if the
118 /// function to be called cannot be negotiated with the remote.
119 class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
120 public:
121   static char ID;
122 
123   std::error_code convertToErrorCode() const override;
124   void log(raw_ostream &OS) const override;
125 };
126 
127 /// This error is returned if the remote does not have a handler installed for
128 /// the given RPC function.
129 class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
130 public:
131   static char ID;
132 
133   CouldNotNegotiate(std::string Signature);
134   std::error_code convertToErrorCode() const override;
135   void log(raw_ostream &OS) const override;
getSignature()136   const std::string &getSignature() const { return Signature; }
137 
138 private:
139   std::string Signature;
140 };
141 
142 template <typename DerivedFunc, typename FnT> class RPCFunction;
143 
144 // RPC Function class.
145 // DerivedFunc should be a user defined class with a static 'getName()' method
146 // returning a const char* representing the function's name.
147 template <typename DerivedFunc, typename RetT, typename... ArgTs>
148 class RPCFunction<DerivedFunc, RetT(ArgTs...)> {
149 public:
150   /// User defined function type.
151   using Type = RetT(ArgTs...);
152 
153   /// Return type.
154   using ReturnType = RetT;
155 
156   /// Returns the full function prototype as a string.
getPrototype()157   static const char *getPrototype() {
158     static std::string Name = [] {
159       std::string Name;
160       raw_string_ostream(Name)
161           << SerializationTypeName<RetT>::getName() << " "
162           << DerivedFunc::getName() << "("
163           << SerializationTypeNameSequence<ArgTs...>() << ")";
164       return Name;
165     }();
166     return Name.data();
167   }
168 };
169 
170 /// Allocates RPC function ids during autonegotiation.
171 /// Specializations of this class must provide four members:
172 ///
173 /// static T getInvalidId():
174 ///   Should return a reserved id that will be used to represent missing
175 /// functions during autonegotiation.
176 ///
177 /// static T getResponseId():
178 ///   Should return a reserved id that will be used to send function responses
179 /// (return values).
180 ///
181 /// static T getNegotiateId():
182 ///   Should return a reserved id for the negotiate function, which will be used
183 /// to negotiate ids for user defined functions.
184 ///
185 /// template <typename Func> T allocate():
186 ///   Allocate a unique id for function Func.
187 template <typename T, typename = void> class RPCFunctionIdAllocator;
188 
189 /// This specialization of RPCFunctionIdAllocator provides a default
190 /// implementation for integral types.
191 template <typename T>
192 class RPCFunctionIdAllocator<T, std::enable_if_t<std::is_integral<T>::value>> {
193 public:
getInvalidId()194   static T getInvalidId() { return T(0); }
getResponseId()195   static T getResponseId() { return T(1); }
getNegotiateId()196   static T getNegotiateId() { return T(2); }
197 
allocate()198   template <typename Func> T allocate() { return NextId++; }
199 
200 private:
201   T NextId = 3;
202 };
203 
204 namespace detail {
205 
206 /// Provides a typedef for a tuple containing the decayed argument types.
207 template <typename T> class RPCFunctionArgsTuple;
208 
209 template <typename RetT, typename... ArgTs>
210 class RPCFunctionArgsTuple<RetT(ArgTs...)> {
211 public:
212   using Type = std::tuple<std::decay_t<std::remove_reference_t<ArgTs>>...>;
213 };
214 
215 // ResultTraits provides typedefs and utilities specific to the return type
216 // of functions.
217 template <typename RetT> class ResultTraits {
218 public:
219   // The return type wrapped in llvm::Expected.
220   using ErrorReturnType = Expected<RetT>;
221 
222 #ifdef _MSC_VER
223   // The ErrorReturnType wrapped in a std::promise.
224   using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>;
225 
226   // The ErrorReturnType wrapped in a std::future.
227   using ReturnFutureType = std::future<MSVCPExpected<RetT>>;
228 #else
229   // The ErrorReturnType wrapped in a std::promise.
230   using ReturnPromiseType = std::promise<ErrorReturnType>;
231 
232   // The ErrorReturnType wrapped in a std::future.
233   using ReturnFutureType = std::future<ErrorReturnType>;
234 #endif
235 
236   // Create a 'blank' value of the ErrorReturnType, ready and safe to
237   // overwrite.
createBlankErrorReturnValue()238   static ErrorReturnType createBlankErrorReturnValue() {
239     return ErrorReturnType(RetT());
240   }
241 
242   // Consume an abandoned ErrorReturnType.
consumeAbandoned(ErrorReturnType RetOrErr)243   static void consumeAbandoned(ErrorReturnType RetOrErr) {
244     consumeError(RetOrErr.takeError());
245   }
246 
returnError(Error Err)247   static ErrorReturnType returnError(Error Err) { return std::move(Err); }
248 };
249 
250 // ResultTraits specialization for void functions.
251 template <> class ResultTraits<void> {
252 public:
253   // For void functions, ErrorReturnType is llvm::Error.
254   using ErrorReturnType = Error;
255 
256 #ifdef _MSC_VER
257   // The ErrorReturnType wrapped in a std::promise.
258   using ReturnPromiseType = std::promise<MSVCPError>;
259 
260   // The ErrorReturnType wrapped in a std::future.
261   using ReturnFutureType = std::future<MSVCPError>;
262 #else
263   // The ErrorReturnType wrapped in a std::promise.
264   using ReturnPromiseType = std::promise<ErrorReturnType>;
265 
266   // The ErrorReturnType wrapped in a std::future.
267   using ReturnFutureType = std::future<ErrorReturnType>;
268 #endif
269 
270   // Create a 'blank' value of the ErrorReturnType, ready and safe to
271   // overwrite.
createBlankErrorReturnValue()272   static ErrorReturnType createBlankErrorReturnValue() {
273     return ErrorReturnType::success();
274   }
275 
276   // Consume an abandoned ErrorReturnType.
consumeAbandoned(ErrorReturnType Err)277   static void consumeAbandoned(ErrorReturnType Err) {
278     consumeError(std::move(Err));
279   }
280 
returnError(Error Err)281   static ErrorReturnType returnError(Error Err) { return Err; }
282 };
283 
284 // ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
285 // handlers for void RPC functions to return either void (in which case they
286 // implicitly succeed) or Error (in which case their error return is
287 // propagated). See usage in HandlerTraits::runHandlerHelper.
288 template <> class ResultTraits<Error> : public ResultTraits<void> {};
289 
290 // ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
291 // handlers for RPC functions returning a T to return either a T (in which
292 // case they implicitly succeed) or Expected<T> (in which case their error
293 // return is propagated). See usage in HandlerTraits::runHandlerHelper.
294 template <typename RetT>
295 class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
296 
297 // Determines whether an RPC function's defined error return type supports
298 // error return value.
299 template <typename T> class SupportsErrorReturn {
300 public:
301   static const bool value = false;
302 };
303 
304 template <> class SupportsErrorReturn<Error> {
305 public:
306   static const bool value = true;
307 };
308 
309 template <typename T> class SupportsErrorReturn<Expected<T>> {
310 public:
311   static const bool value = true;
312 };
313 
314 // RespondHelper packages return values based on whether or not the declared
315 // RPC function return type supports error returns.
316 template <bool FuncSupportsErrorReturn> class RespondHelper;
317 
318 // RespondHelper specialization for functions that support error returns.
319 template <> class RespondHelper<true> {
320 public:
321   // Send Expected<T>.
322   template <typename WireRetT, typename HandlerRetT, typename ChannelT,
323             typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)324   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
325                           SequenceNumberT SeqNo,
326                           Expected<HandlerRetT> ResultOrErr) {
327     if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
328       return ResultOrErr.takeError();
329 
330     // Open the response message.
331     if (auto Err = C.startSendMessage(ResponseId, SeqNo))
332       return Err;
333 
334     // Serialize the result.
335     if (auto Err =
336             SerializationTraits<ChannelT, WireRetT, Expected<HandlerRetT>>::
337                 serialize(C, std::move(ResultOrErr)))
338       return Err;
339 
340     // Close the response message.
341     if (auto Err = C.endSendMessage())
342       return Err;
343     return C.send();
344   }
345 
346   template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)347   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
348                           SequenceNumberT SeqNo, Error Err) {
349     if (Err && Err.isA<RPCFatalError>())
350       return Err;
351     if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
352       return Err2;
353     if (auto Err2 = serializeSeq(C, std::move(Err)))
354       return Err2;
355     if (auto Err2 = C.endSendMessage())
356       return Err2;
357     return C.send();
358   }
359 };
360 
361 // RespondHelper specialization for functions that do not support error returns.
362 template <> class RespondHelper<false> {
363 public:
364   template <typename WireRetT, typename HandlerRetT, typename ChannelT,
365             typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)366   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
367                           SequenceNumberT SeqNo,
368                           Expected<HandlerRetT> ResultOrErr) {
369     if (auto Err = ResultOrErr.takeError())
370       return Err;
371 
372     // Open the response message.
373     if (auto Err = C.startSendMessage(ResponseId, SeqNo))
374       return Err;
375 
376     // Serialize the result.
377     if (auto Err =
378             SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
379                 C, *ResultOrErr))
380       return Err;
381 
382     // End the response message.
383     if (auto Err = C.endSendMessage())
384       return Err;
385 
386     return C.send();
387   }
388 
389   template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)390   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
391                           SequenceNumberT SeqNo, Error Err) {
392     if (Err)
393       return Err;
394     if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
395       return Err2;
396     if (auto Err2 = C.endSendMessage())
397       return Err2;
398     return C.send();
399   }
400 };
401 
402 // Send a response of the given wire return type (WireRetT) over the
403 // channel, with the given sequence number.
404 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
405           typename FunctionIdT, typename SequenceNumberT>
respond(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)406 Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
407               Expected<HandlerRetT> ResultOrErr) {
408   return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
409       template sendResult<WireRetT>(C, ResponseId, SeqNo,
410                                     std::move(ResultOrErr));
411 }
412 
413 // Send an empty response message on the given channel to indicate that
414 // the handler ran.
415 template <typename WireRetT, typename ChannelT, typename FunctionIdT,
416           typename SequenceNumberT>
respond(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)417 Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
418               Error Err) {
419   return RespondHelper<SupportsErrorReturn<WireRetT>::value>::sendResult(
420       C, ResponseId, SeqNo, std::move(Err));
421 }
422 
423 // Converts a given type to the equivalent error return type.
424 template <typename T> class WrappedHandlerReturn {
425 public:
426   using Type = Expected<T>;
427 };
428 
429 template <typename T> class WrappedHandlerReturn<Expected<T>> {
430 public:
431   using Type = Expected<T>;
432 };
433 
434 template <> class WrappedHandlerReturn<void> {
435 public:
436   using Type = Error;
437 };
438 
439 template <> class WrappedHandlerReturn<Error> {
440 public:
441   using Type = Error;
442 };
443 
444 template <> class WrappedHandlerReturn<ErrorSuccess> {
445 public:
446   using Type = Error;
447 };
448 
449 // Traits class that strips the response function from the list of handler
450 // arguments.
451 template <typename FnT> class AsyncHandlerTraits;
452 
453 template <typename ResultT, typename... ArgTs>
454 class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>,
455                                ArgTs...)> {
456 public:
457   using Type = Error(ArgTs...);
458   using ResultType = Expected<ResultT>;
459 };
460 
461 template <typename... ArgTs>
462 class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
463 public:
464   using Type = Error(ArgTs...);
465   using ResultType = Error;
466 };
467 
468 template <typename... ArgTs>
469 class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
470 public:
471   using Type = Error(ArgTs...);
472   using ResultType = Error;
473 };
474 
475 template <typename... ArgTs>
476 class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
477 public:
478   using Type = Error(ArgTs...);
479   using ResultType = Error;
480 };
481 
482 template <typename ResponseHandlerT, typename... ArgTs>
483 class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)>
484     : public AsyncHandlerTraits<Error(std::decay_t<ResponseHandlerT>,
485                                       ArgTs...)> {};
486 
487 // This template class provides utilities related to RPC function handlers.
488 // The base case applies to non-function types (the template class is
489 // specialized for function types) and inherits from the appropriate
490 // speciilization for the given non-function type's call operator.
491 template <typename HandlerT>
492 class HandlerTraits
493     : public HandlerTraits<
494           decltype(&std::remove_reference<HandlerT>::type::operator())> {};
495 
496 // Traits for handlers with a given function type.
497 template <typename RetT, typename... ArgTs>
498 class HandlerTraits<RetT(ArgTs...)> {
499 public:
500   // Function type of the handler.
501   using Type = RetT(ArgTs...);
502 
503   // Return type of the handler.
504   using ReturnType = RetT;
505 
506   // Call the given handler with the given arguments.
507   template <typename HandlerT, typename... TArgTs>
508   static typename WrappedHandlerReturn<RetT>::Type
unpackAndRun(HandlerT & Handler,std::tuple<TArgTs...> & Args)509   unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
510     return unpackAndRunHelper(Handler, Args,
511                               std::index_sequence_for<TArgTs...>());
512   }
513 
514   // Call the given handler with the given arguments.
515   template <typename HandlerT, typename ResponderT, typename... TArgTs>
unpackAndRunAsync(HandlerT & Handler,ResponderT & Responder,std::tuple<TArgTs...> & Args)516   static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
517                                  std::tuple<TArgTs...> &Args) {
518     return unpackAndRunAsyncHelper(Handler, Responder, Args,
519                                    std::index_sequence_for<TArgTs...>());
520   }
521 
522   // Call the given handler with the given arguments.
523   template <typename HandlerT>
524   static std::enable_if_t<
525       std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, Error>
run(HandlerT & Handler,ArgTs &&...Args)526   run(HandlerT &Handler, ArgTs &&...Args) {
527     Handler(std::move(Args)...);
528     return Error::success();
529   }
530 
531   template <typename HandlerT, typename... TArgTs>
532   static std::enable_if_t<
533       !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
534       typename HandlerTraits<HandlerT>::ReturnType>
run(HandlerT & Handler,TArgTs...Args)535   run(HandlerT &Handler, TArgTs... Args) {
536     return Handler(std::move(Args)...);
537   }
538 
539   // Serialize arguments to the channel.
540   template <typename ChannelT, typename... CArgTs>
serializeArgs(ChannelT & C,const CArgTs...CArgs)541   static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
542     return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
543   }
544 
545   // Deserialize arguments from the channel.
546   template <typename ChannelT, typename... CArgTs>
deserializeArgs(ChannelT & C,std::tuple<CArgTs...> & Args)547   static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
548     return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>());
549   }
550 
551 private:
552   template <typename ChannelT, typename... CArgTs, size_t... Indexes>
deserializeArgsHelper(ChannelT & C,std::tuple<CArgTs...> & Args,std::index_sequence<Indexes...> _)553   static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
554                                      std::index_sequence<Indexes...> _) {
555     return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
556         C, std::get<Indexes>(Args)...);
557   }
558 
559   template <typename HandlerT, typename ArgTuple, size_t... Indexes>
560   static typename WrappedHandlerReturn<
561       typename HandlerTraits<HandlerT>::ReturnType>::Type
unpackAndRunHelper(HandlerT & Handler,ArgTuple & Args,std::index_sequence<Indexes...>)562   unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
563                      std::index_sequence<Indexes...>) {
564     return run(Handler, std::move(std::get<Indexes>(Args))...);
565   }
566 
567   template <typename HandlerT, typename ResponderT, typename ArgTuple,
568             size_t... Indexes>
569   static typename WrappedHandlerReturn<
570       typename HandlerTraits<HandlerT>::ReturnType>::Type
unpackAndRunAsyncHelper(HandlerT & Handler,ResponderT & Responder,ArgTuple & Args,std::index_sequence<Indexes...>)571   unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
572                           ArgTuple &Args, std::index_sequence<Indexes...>) {
573     return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
574   }
575 };
576 
577 // Handler traits for free functions.
578 template <typename RetT, typename... ArgTs>
579 class HandlerTraits<RetT (*)(ArgTs...)> : public HandlerTraits<RetT(ArgTs...)> {
580 };
581 
582 // Handler traits for class methods (especially call operators for lambdas).
583 template <typename Class, typename RetT, typename... ArgTs>
584 class HandlerTraits<RetT (Class::*)(ArgTs...)>
585     : public HandlerTraits<RetT(ArgTs...)> {};
586 
587 // Handler traits for const class methods (especially call operators for
588 // lambdas).
589 template <typename Class, typename RetT, typename... ArgTs>
590 class HandlerTraits<RetT (Class::*)(ArgTs...) const>
591     : public HandlerTraits<RetT(ArgTs...)> {};
592 
593 // Utility to peel the Expected wrapper off a response handler error type.
594 template <typename HandlerT> class ResponseHandlerArg;
595 
596 template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
597 public:
598   using ArgType = Expected<ArgT>;
599   using UnwrappedArgType = ArgT;
600 };
601 
602 template <typename ArgT>
603 class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
604 public:
605   using ArgType = Expected<ArgT>;
606   using UnwrappedArgType = ArgT;
607 };
608 
609 template <> class ResponseHandlerArg<Error(Error)> {
610 public:
611   using ArgType = Error;
612 };
613 
614 template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
615 public:
616   using ArgType = Error;
617 };
618 
619 // ResponseHandler represents a handler for a not-yet-received function call
620 // result.
621 template <typename ChannelT> class ResponseHandler {
622 public:
~ResponseHandler()623   virtual ~ResponseHandler() {}
624 
625   // Reads the function result off the wire and acts on it. The meaning of
626   // "act" will depend on how this method is implemented in any given
627   // ResponseHandler subclass but could, for example, mean running a
628   // user-specified handler or setting a promise value.
629   virtual Error handleResponse(ChannelT &C) = 0;
630 
631   // Abandons this outstanding result.
632   virtual void abandon() = 0;
633 
634   // Create an error instance representing an abandoned response.
createAbandonedResponseError()635   static Error createAbandonedResponseError() {
636     return make_error<ResponseAbandoned>();
637   }
638 };
639 
640 // ResponseHandler subclass for RPC functions with non-void returns.
641 template <typename ChannelT, typename FuncRetT, typename HandlerT>
642 class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
643 public:
ResponseHandlerImpl(HandlerT Handler)644   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
645 
646   // Handle the result by deserializing it from the channel then passing it
647   // to the user defined handler.
handleResponse(ChannelT & C)648   Error handleResponse(ChannelT &C) override {
649     using UnwrappedArgType = typename ResponseHandlerArg<
650         typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
651     UnwrappedArgType Result;
652     if (auto Err =
653             SerializationTraits<ChannelT, FuncRetT,
654                                 UnwrappedArgType>::deserialize(C, Result))
655       return Err;
656     if (auto Err = C.endReceiveMessage())
657       return Err;
658     return Handler(std::move(Result));
659   }
660 
661   // Abandon this response by calling the handler with an 'abandoned response'
662   // error.
abandon()663   void abandon() override {
664     if (auto Err = Handler(this->createAbandonedResponseError())) {
665       // Handlers should not fail when passed an abandoned response error.
666       report_fatal_error(std::move(Err));
667     }
668   }
669 
670 private:
671   HandlerT Handler;
672 };
673 
674 // ResponseHandler subclass for RPC functions with void returns.
675 template <typename ChannelT, typename HandlerT>
676 class ResponseHandlerImpl<ChannelT, void, HandlerT>
677     : public ResponseHandler<ChannelT> {
678 public:
ResponseHandlerImpl(HandlerT Handler)679   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
680 
681   // Handle the result (no actual value, just a notification that the function
682   // has completed on the remote end) by calling the user-defined handler with
683   // Error::success().
handleResponse(ChannelT & C)684   Error handleResponse(ChannelT &C) override {
685     if (auto Err = C.endReceiveMessage())
686       return Err;
687     return Handler(Error::success());
688   }
689 
690   // Abandon this response by calling the handler with an 'abandoned response'
691   // error.
abandon()692   void abandon() override {
693     if (auto Err = Handler(this->createAbandonedResponseError())) {
694       // Handlers should not fail when passed an abandoned response error.
695       report_fatal_error(std::move(Err));
696     }
697   }
698 
699 private:
700   HandlerT Handler;
701 };
702 
703 template <typename ChannelT, typename FuncRetT, typename HandlerT>
704 class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
705     : public ResponseHandler<ChannelT> {
706 public:
ResponseHandlerImpl(HandlerT Handler)707   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
708 
709   // Handle the result by deserializing it from the channel then passing it
710   // to the user defined handler.
handleResponse(ChannelT & C)711   Error handleResponse(ChannelT &C) override {
712     using HandlerArgType = typename ResponseHandlerArg<
713         typename HandlerTraits<HandlerT>::Type>::ArgType;
714     HandlerArgType Result((typename HandlerArgType::value_type()));
715 
716     if (auto Err = SerializationTraits<ChannelT, Expected<FuncRetT>,
717                                        HandlerArgType>::deserialize(C, Result))
718       return Err;
719     if (auto Err = C.endReceiveMessage())
720       return Err;
721     return Handler(std::move(Result));
722   }
723 
724   // Abandon this response by calling the handler with an 'abandoned response'
725   // error.
abandon()726   void abandon() override {
727     if (auto Err = Handler(this->createAbandonedResponseError())) {
728       // Handlers should not fail when passed an abandoned response error.
729       report_fatal_error(std::move(Err));
730     }
731   }
732 
733 private:
734   HandlerT Handler;
735 };
736 
737 template <typename ChannelT, typename HandlerT>
738 class ResponseHandlerImpl<ChannelT, Error, HandlerT>
739     : public ResponseHandler<ChannelT> {
740 public:
ResponseHandlerImpl(HandlerT Handler)741   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
742 
743   // Handle the result by deserializing it from the channel then passing it
744   // to the user defined handler.
handleResponse(ChannelT & C)745   Error handleResponse(ChannelT &C) override {
746     Error Result = Error::success();
747     if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize(
748             C, Result)) {
749       consumeError(std::move(Result));
750       return Err;
751     }
752     if (auto Err = C.endReceiveMessage()) {
753       consumeError(std::move(Result));
754       return Err;
755     }
756     return Handler(std::move(Result));
757   }
758 
759   // Abandon this response by calling the handler with an 'abandoned response'
760   // error.
abandon()761   void abandon() override {
762     if (auto Err = Handler(this->createAbandonedResponseError())) {
763       // Handlers should not fail when passed an abandoned response error.
764       report_fatal_error(std::move(Err));
765     }
766   }
767 
768 private:
769   HandlerT Handler;
770 };
771 
772 // Create a ResponseHandler from a given user handler.
773 template <typename ChannelT, typename FuncRetT, typename HandlerT>
createResponseHandler(HandlerT H)774 std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
775   return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
776       std::move(H));
777 }
778 
779 // Helper for wrapping member functions up as functors. This is useful for
780 // installing methods as result handlers.
781 template <typename ClassT, typename RetT, typename... ArgTs>
782 class MemberFnWrapper {
783 public:
784   using MethodT = RetT (ClassT::*)(ArgTs...);
MemberFnWrapper(ClassT & Instance,MethodT Method)785   MemberFnWrapper(ClassT &Instance, MethodT Method)
786       : Instance(Instance), Method(Method) {}
operator()787   RetT operator()(ArgTs &&...Args) {
788     return (Instance.*Method)(std::move(Args)...);
789   }
790 
791 private:
792   ClassT &Instance;
793   MethodT Method;
794 };
795 
796 // Helper that provides a Functor for deserializing arguments.
797 template <typename... ArgTs> class ReadArgs {
798 public:
operator()799   Error operator()() { return Error::success(); }
800 };
801 
802 template <typename ArgT, typename... ArgTs>
803 class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
804 public:
ReadArgs(ArgT & Arg,ArgTs &...Args)805   ReadArgs(ArgT &Arg, ArgTs &...Args) : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
806 
operator()807   Error operator()(ArgT &ArgVal, ArgTs &...ArgVals) {
808     this->Arg = std::move(ArgVal);
809     return ReadArgs<ArgTs...>::operator()(ArgVals...);
810   }
811 
812 private:
813   ArgT &Arg;
814 };
815 
816 // Manage sequence numbers.
817 template <typename SequenceNumberT> class SequenceNumberManager {
818 public:
819   // Reset, making all sequence numbers available.
reset()820   void reset() {
821     std::lock_guard<std::mutex> Lock(SeqNoLock);
822     NextSequenceNumber = 0;
823     FreeSequenceNumbers.clear();
824   }
825 
826   // Get the next available sequence number. Will re-use numbers that have
827   // been released.
getSequenceNumber()828   SequenceNumberT getSequenceNumber() {
829     std::lock_guard<std::mutex> Lock(SeqNoLock);
830     if (FreeSequenceNumbers.empty())
831       return NextSequenceNumber++;
832     auto SequenceNumber = FreeSequenceNumbers.back();
833     FreeSequenceNumbers.pop_back();
834     return SequenceNumber;
835   }
836 
837   // Release a sequence number, making it available for re-use.
releaseSequenceNumber(SequenceNumberT SequenceNumber)838   void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
839     std::lock_guard<std::mutex> Lock(SeqNoLock);
840     FreeSequenceNumbers.push_back(SequenceNumber);
841   }
842 
843 private:
844   std::mutex SeqNoLock;
845   SequenceNumberT NextSequenceNumber = 0;
846   std::vector<SequenceNumberT> FreeSequenceNumbers;
847 };
848 
849 // Checks that predicate P holds for each corresponding pair of type arguments
850 // from T1 and T2 tuple.
851 template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
852 class RPCArgTypeCheckHelper;
853 
854 template <template <class, class> class P>
855 class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
856 public:
857   static const bool value = true;
858 };
859 
860 template <template <class, class> class P, typename T, typename... Ts,
861           typename U, typename... Us>
862 class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
863 public:
864   static const bool value =
865       P<T, U>::value &&
866       RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value;
867 };
868 
869 template <template <class, class> class P, typename T1Sig, typename T2Sig>
870 class RPCArgTypeCheck {
871 public:
872   using T1Tuple = typename RPCFunctionArgsTuple<T1Sig>::Type;
873   using T2Tuple = typename RPCFunctionArgsTuple<T2Sig>::Type;
874 
875   static_assert(std::tuple_size<T1Tuple>::value >=
876                     std::tuple_size<T2Tuple>::value,
877                 "Too many arguments to RPC call");
878   static_assert(std::tuple_size<T1Tuple>::value <=
879                     std::tuple_size<T2Tuple>::value,
880                 "Too few arguments to RPC call");
881 
882   static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
883 };
884 
885 template <typename ChannelT, typename WireT, typename ConcreteT>
886 class CanSerialize {
887 private:
888   using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
889 
890   template <typename T>
891   static std::true_type check(
892       std::enable_if_t<std::is_same<decltype(T::serialize(
893                                         std::declval<ChannelT &>(),
894                                         std::declval<const ConcreteT &>())),
895                                     Error>::value,
896                        void *>);
897 
898   template <typename> static std::false_type check(...);
899 
900 public:
901   static const bool value = decltype(check<S>(0))::value;
902 };
903 
904 template <typename ChannelT, typename WireT, typename ConcreteT>
905 class CanDeserialize {
906 private:
907   using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
908 
909   template <typename T>
910   static std::true_type
911       check(std::enable_if_t<
912             std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
913                                                  std::declval<ConcreteT &>())),
914                          Error>::value,
915             void *>);
916 
917   template <typename> static std::false_type check(...);
918 
919 public:
920   static const bool value = decltype(check<S>(0))::value;
921 };
922 
923 /// Contains primitive utilities for defining, calling and handling calls to
924 /// remote procedures. ChannelT is a bidirectional stream conforming to the
925 /// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
926 /// identifier type that must be serializable on ChannelT, and SequenceNumberT
927 /// is an integral type that will be used to number in-flight function calls.
928 ///
929 /// These utilities support the construction of very primitive RPC utilities.
930 /// Their intent is to ensure correct serialization and deserialization of
931 /// procedure arguments, and to keep the client and server's view of the API in
932 /// sync.
933 template <typename ImplT, typename ChannelT, typename FunctionIdT,
934           typename SequenceNumberT>
935 class RPCEndpointBase {
936 protected:
937   class OrcRPCInvalid : public RPCFunction<OrcRPCInvalid, void()> {
938   public:
getName()939     static const char *getName() { return "__orc_rpc$invalid"; }
940   };
941 
942   class OrcRPCResponse : public RPCFunction<OrcRPCResponse, void()> {
943   public:
getName()944     static const char *getName() { return "__orc_rpc$response"; }
945   };
946 
947   class OrcRPCNegotiate
948       : public RPCFunction<OrcRPCNegotiate, FunctionIdT(std::string)> {
949   public:
getName()950     static const char *getName() { return "__orc_rpc$negotiate"; }
951   };
952 
953   // Helper predicate for testing for the presence of SerializeTraits
954   // serializers.
955   template <typename WireT, typename ConcreteT>
956   class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
957   public:
958     using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;
959 
960     static_assert(value, "Missing serializer for argument (Can't serialize the "
961                          "first template type argument of CanSerializeCheck "
962                          "from the second)");
963   };
964 
965   // Helper predicate for testing for the presence of SerializeTraits
966   // deserializers.
967   template <typename WireT, typename ConcreteT>
968   class CanDeserializeCheck
969       : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
970   public:
971     using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;
972 
973     static_assert(value, "Missing deserializer for argument (Can't deserialize "
974                          "the second template type argument of "
975                          "CanDeserializeCheck from the first)");
976   };
977 
978 public:
979   /// Construct an RPC instance on a channel.
RPCEndpointBase(ChannelT & C,bool LazyAutoNegotiation)980   RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
981       : C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
982     // Hold ResponseId in a special variable, since we expect Response to be
983     // called relatively frequently, and want to avoid the map lookup.
984     ResponseId = FnIdAllocator.getResponseId();
985     RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;
986 
987     // Register the negotiate function id and handler.
988     auto NegotiateId = FnIdAllocator.getNegotiateId();
989     RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
990     Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
991         [this](const std::string &Name) { return handleNegotiate(Name); });
992   }
993 
994   /// Negotiate a function id for Func with the other end of the channel.
995   template <typename Func> Error negotiateFunction(bool Retry = false) {
996     return getRemoteFunctionId<Func>(true, Retry).takeError();
997   }
998 
999   /// Append a call Func, does not call send on the channel.
1000   /// The first argument specifies a user-defined handler to be run when the
1001   /// function returns. The handler should take an Expected<Func::ReturnType>,
1002   /// or an Error (if Func::ReturnType is void). The handler will be called
1003   /// with an error if the return value is abandoned due to a channel error.
1004   template <typename Func, typename HandlerT, typename... ArgTs>
appendCallAsync(HandlerT Handler,const ArgTs &...Args)1005   Error appendCallAsync(HandlerT Handler, const ArgTs &...Args) {
1006 
1007     static_assert(
1008         detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
1009                                 void(ArgTs...)>::value,
1010         "");
1011 
1012     // Look up the function ID.
1013     FunctionIdT FnId;
1014     if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
1015       FnId = *FnIdOrErr;
1016     else {
1017       // Negotiation failed. Notify the handler then return the negotiate-failed
1018       // error.
1019       cantFail(Handler(make_error<ResponseAbandoned>()));
1020       return FnIdOrErr.takeError();
1021     }
1022 
1023     SequenceNumberT SeqNo; // initialized in locked scope below.
1024     {
1025       // Lock the pending responses map and sequence number manager.
1026       std::lock_guard<std::mutex> Lock(ResponsesMutex);
1027 
1028       // Allocate a sequence number.
1029       SeqNo = SequenceNumberMgr.getSequenceNumber();
1030       assert(!PendingResponses.count(SeqNo) &&
1031              "Sequence number already allocated");
1032 
1033       // Install the user handler.
1034       PendingResponses[SeqNo] =
1035           detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
1036               std::move(Handler));
1037     }
1038 
1039     // Open the function call message.
1040     if (auto Err = C.startSendMessage(FnId, SeqNo)) {
1041       abandonPendingResponses();
1042       return Err;
1043     }
1044 
1045     // Serialize the call arguments.
1046     if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
1047             C, Args...)) {
1048       abandonPendingResponses();
1049       return Err;
1050     }
1051 
1052     // Close the function call messagee.
1053     if (auto Err = C.endSendMessage()) {
1054       abandonPendingResponses();
1055       return Err;
1056     }
1057 
1058     return Error::success();
1059   }
1060 
sendAppendedCalls()1061   Error sendAppendedCalls() { return C.send(); };
1062 
1063   template <typename Func, typename HandlerT, typename... ArgTs>
callAsync(HandlerT Handler,const ArgTs &...Args)1064   Error callAsync(HandlerT Handler, const ArgTs &...Args) {
1065     if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
1066       return Err;
1067     return C.send();
1068   }
1069 
1070   /// Handle one incoming call.
handleOne()1071   Error handleOne() {
1072     FunctionIdT FnId;
1073     SequenceNumberT SeqNo;
1074     if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
1075       abandonPendingResponses();
1076       return Err;
1077     }
1078     if (FnId == ResponseId)
1079       return handleResponse(SeqNo);
1080     auto I = Handlers.find(FnId);
1081     if (I != Handlers.end())
1082       return I->second(C, SeqNo);
1083 
1084     // else: No handler found. Report error to client?
1085     return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
1086                                                                      SeqNo);
1087   }
1088 
1089   /// Helper for handling setter procedures - this method returns a functor that
1090   /// sets the variables referred to by Args... to values deserialized from the
1091   /// channel.
1092   /// E.g.
1093   ///
1094   ///   typedef Function<0, bool, int> Func1;
1095   ///
1096   ///   ...
1097   ///   bool B;
1098   ///   int I;
1099   ///   if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
1100   ///     /* Handle Args */ ;
1101   ///
1102   template <typename... ArgTs>
readArgs(ArgTs &...Args)1103   static detail::ReadArgs<ArgTs...> readArgs(ArgTs &...Args) {
1104     return detail::ReadArgs<ArgTs...>(Args...);
1105   }
1106 
1107   /// Abandon all outstanding result handlers.
1108   ///
1109   /// This will call all currently registered result handlers to receive an
1110   /// "abandoned" error as their argument. This is used internally by the RPC
1111   /// in error situations, but can also be called directly by clients who are
1112   /// disconnecting from the remote and don't or can't expect responses to their
1113   /// outstanding calls. (Especially for outstanding blocking calls, calling
1114   /// this function may be necessary to avoid dead threads).
abandonPendingResponses()1115   void abandonPendingResponses() {
1116     // Lock the pending responses map and sequence number manager.
1117     std::lock_guard<std::mutex> Lock(ResponsesMutex);
1118 
1119     for (auto &KV : PendingResponses)
1120       KV.second->abandon();
1121     PendingResponses.clear();
1122     SequenceNumberMgr.reset();
1123   }
1124 
1125   /// Remove the handler for the given function.
1126   /// A handler must currently be registered for this function.
removeHandler()1127   template <typename Func> void removeHandler() {
1128     auto IdItr = LocalFunctionIds.find(Func::getPrototype());
1129     assert(IdItr != LocalFunctionIds.end() &&
1130            "Function does not have a registered handler");
1131     auto HandlerItr = Handlers.find(IdItr->second);
1132     assert(HandlerItr != Handlers.end() &&
1133            "Function does not have a registered handler");
1134     Handlers.erase(HandlerItr);
1135   }
1136 
1137   /// Clear all handlers.
clearHandlers()1138   void clearHandlers() { Handlers.clear(); }
1139 
1140 protected:
getInvalidFunctionId()1141   FunctionIdT getInvalidFunctionId() const {
1142     return FnIdAllocator.getInvalidId();
1143   }
1144 
1145   /// Add the given handler to the handler map and make it available for
1146   /// autonegotiation and execution.
1147   template <typename Func, typename HandlerT>
addHandlerImpl(HandlerT Handler)1148   void addHandlerImpl(HandlerT Handler) {
1149 
1150     static_assert(detail::RPCArgTypeCheck<
1151                       CanDeserializeCheck, typename Func::Type,
1152                       typename detail::HandlerTraits<HandlerT>::Type>::value,
1153                   "");
1154 
1155     FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1156     LocalFunctionIds[Func::getPrototype()] = NewFnId;
1157     Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
1158   }
1159 
1160   template <typename Func, typename HandlerT>
addAsyncHandlerImpl(HandlerT Handler)1161   void addAsyncHandlerImpl(HandlerT Handler) {
1162 
1163     static_assert(
1164         detail::RPCArgTypeCheck<
1165             CanDeserializeCheck, typename Func::Type,
1166             typename detail::AsyncHandlerTraits<
1167                 typename detail::HandlerTraits<HandlerT>::Type>::Type>::value,
1168         "");
1169 
1170     FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1171     LocalFunctionIds[Func::getPrototype()] = NewFnId;
1172     Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
1173   }
1174 
handleResponse(SequenceNumberT SeqNo)1175   Error handleResponse(SequenceNumberT SeqNo) {
1176     using Handler = typename decltype(PendingResponses)::mapped_type;
1177     Handler PRHandler;
1178 
1179     {
1180       // Lock the pending responses map and sequence number manager.
1181       std::unique_lock<std::mutex> Lock(ResponsesMutex);
1182       auto I = PendingResponses.find(SeqNo);
1183 
1184       if (I != PendingResponses.end()) {
1185         PRHandler = std::move(I->second);
1186         PendingResponses.erase(I);
1187         SequenceNumberMgr.releaseSequenceNumber(SeqNo);
1188       } else {
1189         // Unlock the pending results map to prevent recursive lock.
1190         Lock.unlock();
1191         abandonPendingResponses();
1192         return make_error<InvalidSequenceNumberForResponse<SequenceNumberT>>(
1193             SeqNo);
1194       }
1195     }
1196 
1197     assert(PRHandler &&
1198            "If we didn't find a response handler we should have bailed out");
1199 
1200     if (auto Err = PRHandler->handleResponse(C)) {
1201       abandonPendingResponses();
1202       return Err;
1203     }
1204 
1205     return Error::success();
1206   }
1207 
handleNegotiate(const std::string & Name)1208   FunctionIdT handleNegotiate(const std::string &Name) {
1209     auto I = LocalFunctionIds.find(Name);
1210     if (I == LocalFunctionIds.end())
1211       return getInvalidFunctionId();
1212     return I->second;
1213   }
1214 
1215   // Find the remote FunctionId for the given function.
1216   template <typename Func>
getRemoteFunctionId(bool NegotiateIfNotInMap,bool NegotiateIfInvalid)1217   Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
1218                                             bool NegotiateIfInvalid) {
1219     bool DoNegotiate;
1220 
1221     // Check if we already have a function id...
1222     auto I = RemoteFunctionIds.find(Func::getPrototype());
1223     if (I != RemoteFunctionIds.end()) {
1224       // If it's valid there's nothing left to do.
1225       if (I->second != getInvalidFunctionId())
1226         return I->second;
1227       DoNegotiate = NegotiateIfInvalid;
1228     } else
1229       DoNegotiate = NegotiateIfNotInMap;
1230 
1231     // We don't have a function id for Func yet, but we're allowed to try to
1232     // negotiate one.
1233     if (DoNegotiate) {
1234       auto &Impl = static_cast<ImplT &>(*this);
1235       if (auto RemoteIdOrErr =
1236               Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
1237         RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
1238         if (*RemoteIdOrErr == getInvalidFunctionId())
1239           return make_error<CouldNotNegotiate>(Func::getPrototype());
1240         return *RemoteIdOrErr;
1241       } else
1242         return RemoteIdOrErr.takeError();
1243     }
1244 
1245     // No key was available in the map and we weren't allowed to try to
1246     // negotiate one, so return an unknown function error.
1247     return make_error<CouldNotNegotiate>(Func::getPrototype());
1248   }
1249 
1250   using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;
1251 
1252   // Wrap the given user handler in the necessary argument-deserialization code,
1253   // result-serialization code, and call to the launch policy (if present).
1254   template <typename Func, typename HandlerT>
wrapHandler(HandlerT Handler)1255   WrappedHandlerFn wrapHandler(HandlerT Handler) {
1256     return [this, Handler](ChannelT &Channel,
1257                            SequenceNumberT SeqNo) mutable -> Error {
1258       // Start by deserializing the arguments.
1259       using ArgsTuple = typename detail::RPCFunctionArgsTuple<
1260           typename detail::HandlerTraits<HandlerT>::Type>::Type;
1261       auto Args = std::make_shared<ArgsTuple>();
1262 
1263       if (auto Err =
1264               detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1265                   Channel, *Args))
1266         return Err;
1267 
1268       // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1269       // for RPCArgs. Void cast RPCArgs to work around this for now.
1270       // FIXME: Remove this workaround once we can assume a working GCC version.
1271       (void)Args;
1272 
1273       // End receieve message, unlocking the channel for reading.
1274       if (auto Err = Channel.endReceiveMessage())
1275         return Err;
1276 
1277       using HTraits = detail::HandlerTraits<HandlerT>;
1278       using FuncReturn = typename Func::ReturnType;
1279       return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
1280                                          HTraits::unpackAndRun(Handler, *Args));
1281     };
1282   }
1283 
1284   // Wrap the given user handler in the necessary argument-deserialization code,
1285   // result-serialization code, and call to the launch policy (if present).
1286   template <typename Func, typename HandlerT>
wrapAsyncHandler(HandlerT Handler)1287   WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
1288     return [this, Handler](ChannelT &Channel,
1289                            SequenceNumberT SeqNo) mutable -> Error {
1290       // Start by deserializing the arguments.
1291       using AHTraits = detail::AsyncHandlerTraits<
1292           typename detail::HandlerTraits<HandlerT>::Type>;
1293       using ArgsTuple =
1294           typename detail::RPCFunctionArgsTuple<typename AHTraits::Type>::Type;
1295       auto Args = std::make_shared<ArgsTuple>();
1296 
1297       if (auto Err =
1298               detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1299                   Channel, *Args))
1300         return Err;
1301 
1302       // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1303       // for RPCArgs. Void cast RPCArgs to work around this for now.
1304       // FIXME: Remove this workaround once we can assume a working GCC version.
1305       (void)Args;
1306 
1307       // End receieve message, unlocking the channel for reading.
1308       if (auto Err = Channel.endReceiveMessage())
1309         return Err;
1310 
1311       using HTraits = detail::HandlerTraits<HandlerT>;
1312       using FuncReturn = typename Func::ReturnType;
1313       auto Responder = [this,
1314                         SeqNo](typename AHTraits::ResultType RetVal) -> Error {
1315         return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
1316                                            std::move(RetVal));
1317       };
1318 
1319       return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
1320     };
1321   }
1322 
1323   ChannelT &C;
1324 
1325   bool LazyAutoNegotiation;
1326 
1327   RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;
1328 
1329   FunctionIdT ResponseId;
1330   std::map<std::string, FunctionIdT> LocalFunctionIds;
1331   std::map<const char *, FunctionIdT> RemoteFunctionIds;
1332 
1333   std::map<FunctionIdT, WrappedHandlerFn> Handlers;
1334 
1335   std::mutex ResponsesMutex;
1336   detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
1337   std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
1338       PendingResponses;
1339 };
1340 
1341 } // end namespace detail
1342 
1343 template <typename ChannelT, typename FunctionIdT = uint32_t,
1344           typename SequenceNumberT = uint32_t>
1345 class MultiThreadedRPCEndpoint
1346     : public detail::RPCEndpointBase<
1347           MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1348           ChannelT, FunctionIdT, SequenceNumberT> {
1349 private:
1350   using BaseClass = detail::RPCEndpointBase<
1351       MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1352       ChannelT, FunctionIdT, SequenceNumberT>;
1353 
1354 public:
MultiThreadedRPCEndpoint(ChannelT & C,bool LazyAutoNegotiation)1355   MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1356       : BaseClass(C, LazyAutoNegotiation) {}
1357 
1358   /// Add a handler for the given RPC function.
1359   /// This installs the given handler functor for the given RPCFunction, and
1360   /// makes the RPC function available for negotiation/calling from the remote.
1361   template <typename Func, typename HandlerT>
addHandler(HandlerT Handler)1362   void addHandler(HandlerT Handler) {
1363     return this->template addHandlerImpl<Func>(std::move(Handler));
1364   }
1365 
1366   /// Add a class-method as a handler.
1367   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1368   void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1369     addHandler<Func>(
1370         detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1371   }
1372 
1373   template <typename Func, typename HandlerT>
addAsyncHandler(HandlerT Handler)1374   void addAsyncHandler(HandlerT Handler) {
1375     return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1376   }
1377 
1378   /// Add a class-method as a handler.
1379   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addAsyncHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1380   void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1381     addAsyncHandler<Func>(
1382         detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1383   }
1384 
1385   /// Return type for non-blocking call primitives.
1386   template <typename Func>
1387   using NonBlockingCallResult = typename detail::ResultTraits<
1388       typename Func::ReturnType>::ReturnFutureType;
1389 
1390   /// Call Func on Channel C. Does not block, does not call send. Returns a pair
1391   /// of a future result and the sequence number assigned to the result.
1392   ///
1393   /// This utility function is primarily used for single-threaded mode support,
1394   /// where the sequence number can be used to wait for the corresponding
1395   /// result. In multi-threaded mode the appendCallNB method, which does not
1396   /// return the sequence numeber, should be preferred.
1397   template <typename Func, typename... ArgTs>
appendCallNB(const ArgTs &...Args)1398   Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &...Args) {
1399     using RTraits = detail::ResultTraits<typename Func::ReturnType>;
1400     using ErrorReturn = typename RTraits::ErrorReturnType;
1401     using ErrorReturnPromise = typename RTraits::ReturnPromiseType;
1402 
1403     ErrorReturnPromise Promise;
1404     auto FutureResult = Promise.get_future();
1405 
1406     if (auto Err = this->template appendCallAsync<Func>(
1407             [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable {
1408               Promise.set_value(std::move(RetOrErr));
1409               return Error::success();
1410             },
1411             Args...)) {
1412       RTraits::consumeAbandoned(FutureResult.get());
1413       return std::move(Err);
1414     }
1415     return std::move(FutureResult);
1416   }
1417 
1418   /// The same as appendCallNBWithSeq, except that it calls C.send() to
1419   /// flush the channel after serializing the call.
1420   template <typename Func, typename... ArgTs>
callNB(const ArgTs &...Args)1421   Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &...Args) {
1422     auto Result = appendCallNB<Func>(Args...);
1423     if (!Result)
1424       return Result;
1425     if (auto Err = this->C.send()) {
1426       this->abandonPendingResponses();
1427       detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1428           std::move(Result->get()));
1429       return std::move(Err);
1430     }
1431     return Result;
1432   }
1433 
1434   /// Call Func on Channel C. Blocks waiting for a result. Returns an Error
1435   /// for void functions or an Expected<T> for functions returning a T.
1436   ///
1437   /// This function is for use in threaded code where another thread is
1438   /// handling responses and incoming calls.
1439   template <typename Func, typename... ArgTs,
1440             typename AltRetT = typename Func::ReturnType>
1441   typename detail::ResultTraits<AltRetT>::ErrorReturnType
callB(const ArgTs &...Args)1442   callB(const ArgTs &...Args) {
1443     if (auto FutureResOrErr = callNB<Func>(Args...))
1444       return FutureResOrErr->get();
1445     else
1446       return FutureResOrErr.takeError();
1447   }
1448 
1449   /// Handle incoming RPC calls.
handlerLoop()1450   Error handlerLoop() {
1451     while (true)
1452       if (auto Err = this->handleOne())
1453         return Err;
1454     return Error::success();
1455   }
1456 };
1457 
1458 template <typename ChannelT, typename FunctionIdT = uint32_t,
1459           typename SequenceNumberT = uint32_t>
1460 class SingleThreadedRPCEndpoint
1461     : public detail::RPCEndpointBase<
1462           SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1463           ChannelT, FunctionIdT, SequenceNumberT> {
1464 private:
1465   using BaseClass = detail::RPCEndpointBase<
1466       SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1467       ChannelT, FunctionIdT, SequenceNumberT>;
1468 
1469 public:
SingleThreadedRPCEndpoint(ChannelT & C,bool LazyAutoNegotiation)1470   SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1471       : BaseClass(C, LazyAutoNegotiation) {}
1472 
1473   template <typename Func, typename HandlerT>
addHandler(HandlerT Handler)1474   void addHandler(HandlerT Handler) {
1475     return this->template addHandlerImpl<Func>(std::move(Handler));
1476   }
1477 
1478   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1479   void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1480     addHandler<Func>(
1481         detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1482   }
1483 
1484   template <typename Func, typename HandlerT>
addAsyncHandler(HandlerT Handler)1485   void addAsyncHandler(HandlerT Handler) {
1486     return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1487   }
1488 
1489   /// Add a class-method as a handler.
1490   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addAsyncHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1491   void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1492     addAsyncHandler<Func>(
1493         detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1494   }
1495 
1496   template <typename Func, typename... ArgTs,
1497             typename AltRetT = typename Func::ReturnType>
1498   typename detail::ResultTraits<AltRetT>::ErrorReturnType
callB(const ArgTs &...Args)1499   callB(const ArgTs &...Args) {
1500     bool ReceivedResponse = false;
1501     using AltRetTraits = detail::ResultTraits<AltRetT>;
1502     using ResultType = typename AltRetTraits::ErrorReturnType;
1503     ResultType Result = AltRetTraits::createBlankErrorReturnValue();
1504 
1505     // We have to 'Check' result (which we know is in a success state at this
1506     // point) so that it can be overwritten in the async handler.
1507     (void)!!Result;
1508 
1509     if (Error Err = this->template appendCallAsync<Func>(
1510             [&](ResultType R) {
1511               Result = std::move(R);
1512               ReceivedResponse = true;
1513               return Error::success();
1514             },
1515             Args...)) {
1516       AltRetTraits::consumeAbandoned(std::move(Result));
1517       return AltRetTraits::returnError(std::move(Err));
1518     }
1519 
1520     if (Error Err = this->C.send()) {
1521       AltRetTraits::consumeAbandoned(std::move(Result));
1522       return AltRetTraits::returnError(std::move(Err));
1523     }
1524 
1525     while (!ReceivedResponse) {
1526       if (Error Err = this->handleOne()) {
1527         AltRetTraits::consumeAbandoned(std::move(Result));
1528         return AltRetTraits::returnError(std::move(Err));
1529       }
1530     }
1531 
1532     return Result;
1533   }
1534 };
1535 
1536 /// Asynchronous dispatch for a function on an RPC endpoint.
1537 template <typename RPCClass, typename Func> class RPCAsyncDispatch {
1538 public:
RPCAsyncDispatch(RPCClass & Endpoint)1539   RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
1540 
1541   template <typename HandlerT, typename... ArgTs>
operator()1542   Error operator()(HandlerT Handler, const ArgTs &...Args) const {
1543     return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
1544   }
1545 
1546 private:
1547   RPCClass &Endpoint;
1548 };
1549 
1550 /// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
1551 template <typename Func, typename RPCEndpointT>
rpcAsyncDispatch(RPCEndpointT & Endpoint)1552 RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
1553   return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
1554 }
1555 
1556 /// Allows a set of asynchrounous calls to be dispatched, and then
1557 ///        waited on as a group.
1558 class ParallelCallGroup {
1559 public:
1560   ParallelCallGroup() = default;
1561   ParallelCallGroup(const ParallelCallGroup &) = delete;
1562   ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
1563 
1564   /// Make as asynchronous call.
1565   template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
call(const AsyncDispatcher & AsyncDispatch,HandlerT Handler,const ArgTs &...Args)1566   Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
1567              const ArgTs &...Args) {
1568     // Increment the count of outstanding calls. This has to happen before
1569     // we invoke the call, as the handler may (depending on scheduling)
1570     // be run immediately on another thread, and we don't want the decrement
1571     // in the wrapped handler below to run before the increment.
1572     {
1573       std::unique_lock<std::mutex> Lock(M);
1574       ++NumOutstandingCalls;
1575     }
1576 
1577     // Wrap the user handler in a lambda that will decrement the
1578     // outstanding calls count, then poke the condition variable.
1579     using ArgType = typename detail::ResponseHandlerArg<
1580         typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
1581     auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) {
1582       auto Err = Handler(std::move(Arg));
1583       std::unique_lock<std::mutex> Lock(M);
1584       --NumOutstandingCalls;
1585       CV.notify_all();
1586       return Err;
1587     };
1588 
1589     return AsyncDispatch(std::move(WrappedHandler), Args...);
1590   }
1591 
1592   /// Blocks until all calls have been completed and their return value
1593   ///        handlers run.
wait()1594   void wait() {
1595     std::unique_lock<std::mutex> Lock(M);
1596     while (NumOutstandingCalls > 0)
1597       CV.wait(Lock);
1598   }
1599 
1600 private:
1601   std::mutex M;
1602   std::condition_variable CV;
1603   uint32_t NumOutstandingCalls = 0;
1604 };
1605 
1606 /// Convenience class for grouping RPCFunctions into APIs that can be
1607 ///        negotiated as a block.
1608 ///
1609 template <typename... Funcs> class APICalls {
1610 public:
1611   /// Test whether this API contains Function F.
1612   template <typename F> class Contains {
1613   public:
1614     static const bool value = false;
1615   };
1616 
1617   /// Negotiate all functions in this API.
negotiate(RPCEndpoint & R)1618   template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) {
1619     return Error::success();
1620   }
1621 };
1622 
1623 template <typename Func, typename... Funcs> class APICalls<Func, Funcs...> {
1624 public:
1625   template <typename F> class Contains {
1626   public:
1627     static const bool value = std::is_same<F, Func>::value |
1628                               APICalls<Funcs...>::template Contains<F>::value;
1629   };
1630 
negotiate(RPCEndpoint & R)1631   template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) {
1632     if (auto Err = R.template negotiateFunction<Func>())
1633       return Err;
1634     return APICalls<Funcs...>::negotiate(R);
1635   }
1636 };
1637 
1638 template <typename... InnerFuncs, typename... Funcs>
1639 class APICalls<APICalls<InnerFuncs...>, Funcs...> {
1640 public:
1641   template <typename F> class Contains {
1642   public:
1643     static const bool value =
1644         APICalls<InnerFuncs...>::template Contains<F>::value |
1645         APICalls<Funcs...>::template Contains<F>::value;
1646   };
1647 
negotiate(RPCEndpoint & R)1648   template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) {
1649     if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
1650       return Err;
1651     return APICalls<Funcs...>::negotiate(R);
1652   }
1653 };
1654 
1655 } // end namespace shared
1656 } // end namespace orc
1657 } // end namespace llvm
1658 
1659 #endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H
1660