xref: /llvm-project/compiler-rt/lib/orc/wrapper_function_utils.h (revision 69f8923efa61034b57805a8d6d859e9c1ca976eb)
1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file is a part of the ORC runtime support library.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H
15 
16 #include "error.h"
17 #include "executor_address.h"
18 #include "orc_rt/c_api.h"
19 #include "simple_packed_serialization.h"
20 #include <type_traits>
21 
22 namespace orc_rt {
23 
24 /// C++ wrapper function result: Same as CWrapperFunctionResult but
25 /// auto-releases memory.
26 class WrapperFunctionResult {
27 public:
28   /// Create a default WrapperFunctionResult.
29   WrapperFunctionResult() { orc_rt_CWrapperFunctionResultInit(&R); }
30 
31   /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This
32   /// instance takes ownership of the result object and will automatically
33   /// call dispose on the result upon destruction.
34   WrapperFunctionResult(orc_rt_CWrapperFunctionResult R) : R(R) {}
35 
36   WrapperFunctionResult(const WrapperFunctionResult &) = delete;
37   WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
38 
39   WrapperFunctionResult(WrapperFunctionResult &&Other) {
40     orc_rt_CWrapperFunctionResultInit(&R);
41     std::swap(R, Other.R);
42   }
43 
44   WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
45     orc_rt_CWrapperFunctionResult Tmp;
46     orc_rt_CWrapperFunctionResultInit(&Tmp);
47     std::swap(Tmp, Other.R);
48     std::swap(R, Tmp);
49     return *this;
50   }
51 
52   ~WrapperFunctionResult() { orc_rt_DisposeCWrapperFunctionResult(&R); }
53 
54   /// Relinquish ownership of and return the
55   /// orc_rt_CWrapperFunctionResult.
56   orc_rt_CWrapperFunctionResult release() {
57     orc_rt_CWrapperFunctionResult Tmp;
58     orc_rt_CWrapperFunctionResultInit(&Tmp);
59     std::swap(R, Tmp);
60     return Tmp;
61   }
62 
63   /// Get a pointer to the data contained in this instance.
64   char *data() { return orc_rt_CWrapperFunctionResultData(&R); }
65 
66   /// Returns the size of the data contained in this instance.
67   size_t size() const { return orc_rt_CWrapperFunctionResultSize(&R); }
68 
69   /// Returns true if this value is equivalent to a default-constructed
70   /// WrapperFunctionResult.
71   bool empty() const { return orc_rt_CWrapperFunctionResultEmpty(&R); }
72 
73   /// Create a WrapperFunctionResult with the given size and return a pointer
74   /// to the underlying memory.
75   static WrapperFunctionResult allocate(size_t Size) {
76     WrapperFunctionResult R;
77     R.R = orc_rt_CWrapperFunctionResultAllocate(Size);
78     return R;
79   }
80 
81   /// Copy from the given char range.
82   static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
83     return orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size);
84   }
85 
86   /// Copy from the given null-terminated string (includes the null-terminator).
87   static WrapperFunctionResult copyFrom(const char *Source) {
88     return orc_rt_CreateCWrapperFunctionResultFromString(Source);
89   }
90 
91   /// Copy from the given std::string (includes the null terminator).
92   static WrapperFunctionResult copyFrom(const std::string &Source) {
93     return copyFrom(Source.c_str());
94   }
95 
96   /// Create an out-of-band error by copying the given string.
97   static WrapperFunctionResult createOutOfBandError(const char *Msg) {
98     return orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg);
99   }
100 
101   /// Create an out-of-band error by copying the given string.
102   static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
103     return createOutOfBandError(Msg.c_str());
104   }
105 
106   template <typename SPSArgListT, typename... ArgTs>
107   static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {
108     auto Result = allocate(SPSArgListT::size(Args...));
109     SPSOutputBuffer OB(Result.data(), Result.size());
110     if (!SPSArgListT::serialize(OB, Args...))
111       return createOutOfBandError(
112           "Error serializing arguments to blob in call");
113     return Result;
114   }
115 
116   /// If this value is an out-of-band error then this returns the error message,
117   /// otherwise returns nullptr.
118   const char *getOutOfBandError() const {
119     return orc_rt_CWrapperFunctionResultGetOutOfBandError(&R);
120   }
121 
122 private:
123   orc_rt_CWrapperFunctionResult R;
124 };
125 
126 namespace detail {
127 
/// Invokes a wrapper-function handler, expanding the tuple of deserialized
/// arguments (selected by the index sequence) into individual call
/// arguments. The handler's result is returned as-is; decltype(auto)
/// preserves reference return types.
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  template <typename HandlerT, typename ArgTupleT, std::size_t... Idx>
  static decltype(auto) call(HandlerT &&Handler, ArgTupleT &ArgTuple,
                             std::index_sequence<Idx...>) {
    // Every tuple element is handed to the handler as an lvalue.
    return std::forward<HandlerT>(Handler)(std::get<Idx>(ArgTuple)...);
  }
};
136 
137 template <> class WrapperFunctionHandlerCaller<void> {
138 public:
139   template <typename HandlerT, typename ArgTupleT, std::size_t... I>
140   static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
141                        std::index_sequence<I...>) {
142     std::forward<HandlerT>(H)(std::get<I>(Args)...);
143     return SPSEmpty();
144   }
145 };
146 
/// Primary template, selected when no specialization below matches (i.e. for
/// class-type handlers such as lambdas and other function objects). It
/// deduces the handler's signature from the type of its call operator and
/// defers to the specialization for that member-function-pointer type.
template <typename CallableT, template <typename> class ResultSerializer,
          typename... SPSTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference_t<CallableT>::operator()),
          ResultSerializer, SPSTagTs...> {};
153 
154 template <typename RetT, typename... ArgTs,
155           template <typename> class ResultSerializer, typename... SPSTagTs>
156 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
157                                    SPSTagTs...> {
158 public:
159   using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
160   using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
161 
162   template <typename HandlerT>
163   static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
164                                      size_t ArgSize) {
165     ArgTuple Args;
166     if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
167       return WrapperFunctionResult::createOutOfBandError(
168           "Could not deserialize arguments for wrapper function call");
169 
170     auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
171         std::forward<HandlerT>(H), Args, ArgIndices{});
172 
173     return ResultSerializer<decltype(HandlerResult)>::serialize(
174         std::move(HandlerResult));
175   }
176 
177 private:
178   template <std::size_t... I>
179   static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
180                           std::index_sequence<I...>) {
181     SPSInputBuffer IB(ArgData, ArgSize);
182     return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
183   }
184 };
185 
186 // Map function pointers to function types.
187 template <typename RetT, typename... ArgTs,
188           template <typename> class ResultSerializer, typename... SPSTagTs>
189 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
190                                    SPSTagTs...>
191     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
192                                           SPSTagTs...> {};
193 
194 // Map non-const member function types to function types.
195 template <typename ClassT, typename RetT, typename... ArgTs,
196           template <typename> class ResultSerializer, typename... SPSTagTs>
197 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
198                                    SPSTagTs...>
199     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
200                                           SPSTagTs...> {};
201 
202 // Map const member function types to function types.
203 template <typename ClassT, typename RetT, typename... ArgTs,
204           template <typename> class ResultSerializer, typename... SPSTagTs>
205 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
206                                    ResultSerializer, SPSTagTs...>
207     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
208                                           SPSTagTs...> {};
209 
210 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
211 public:
212   static WrapperFunctionResult serialize(RetT Result) {
213     return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);
214   }
215 };
216 
217 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
218 public:
219   static WrapperFunctionResult serialize(Error Err) {
220     return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
221         toSPSSerializable(std::move(Err)));
222   }
223 };
224 
225 template <typename SPSRetTagT, typename T>
226 class ResultSerializer<SPSRetTagT, Expected<T>> {
227 public:
228   static WrapperFunctionResult serialize(Expected<T> E) {
229     return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
230         toSPSSerializable(std::move(E)));
231   }
232 };
233 
234 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
235 public:
236   static void makeSafe(RetT &Result) {}
237 
238   static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
239     SPSInputBuffer IB(ArgData, ArgSize);
240     if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
241       return make_error<StringError>(
242           "Error deserializing return value from blob in call");
243     return Error::success();
244   }
245 };
246 
247 template <> class ResultDeserializer<SPSError, Error> {
248 public:
249   static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
250 
251   static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
252     SPSInputBuffer IB(ArgData, ArgSize);
253     SPSSerializableError BSE;
254     if (!SPSArgList<SPSError>::deserialize(IB, BSE))
255       return make_error<StringError>(
256           "Error deserializing return value from blob in call");
257     Err = fromSPSSerializable(std::move(BSE));
258     return Error::success();
259   }
260 };
261 
262 template <typename SPSTagT, typename T>
263 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
264 public:
265   static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
266 
267   static Error deserialize(Expected<T> &E, const char *ArgData,
268                            size_t ArgSize) {
269     SPSInputBuffer IB(ArgData, ArgSize);
270     SPSSerializableExpected<T> BSE;
271     if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
272       return make_error<StringError>(
273           "Error deserializing return value from blob in call");
274     E = fromSPSSerializable(std::move(BSE));
275     return Error::success();
276   }
277 };
278 
279 } // end namespace detail
280 
281 template <typename SPSSignature> class WrapperFunction;
282 
283 template <typename SPSRetTagT, typename... SPSTagTs>
284 class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
285 private:
286   template <typename RetT>
287   using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
288 
289 public:
290   template <typename DispatchFn, typename RetT, typename... ArgTs>
291   static Error call(DispatchFn &&Dispatch, RetT &Result, const ArgTs &...Args) {
292 
293     // RetT might be an Error or Expected value. Set the checked flag now:
294     // we don't want the user to have to check the unused result if this
295     // operation fails.
296     detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
297 
298     auto ArgBuffer =
299         WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
300     if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
301       return make_error<StringError>(ErrMsg);
302 
303     WrapperFunctionResult ResultBuffer =
304         Dispatch(ArgBuffer.data(), ArgBuffer.size());
305 
306     if (auto ErrMsg = ResultBuffer.getOutOfBandError())
307       return make_error<StringError>(ErrMsg);
308 
309     return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
310         Result, ResultBuffer.data(), ResultBuffer.size());
311   }
312 
313   template <typename HandlerT>
314   static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
315                                       HandlerT &&Handler) {
316     using WFHH =
317         detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
318                                              ResultSerializer, SPSTagTs...>;
319     return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
320   }
321 
322 private:
323   template <typename T> static const T &makeSerializable(const T &Value) {
324     return Value;
325   }
326 
327   static detail::SPSSerializableError makeSerializable(Error Err) {
328     return detail::toSPSSerializable(std::move(Err));
329   }
330 
331   template <typename T>
332   static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
333     return detail::toSPSSerializable(std::move(E));
334   }
335 };
336 
337 template <typename... SPSTagTs>
338 class WrapperFunction<void(SPSTagTs...)>
339     : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
340 public:
341   template <typename DispatchFn, typename... ArgTs>
342   static Error call(DispatchFn &&Dispatch, const ArgTs &...Args) {
343     SPSEmpty BE;
344     return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(
345         std::forward<DispatchFn>(Dispatch), BE, Args...);
346   }
347 
348   using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
349 };
350 
351 /// A function object that takes an ExecutorAddr as its first argument,
352 /// casts that address to a ClassT*, then calls the given method on that
353 /// pointer passing in the remaining function arguments. This utility
354 /// removes some of the boilerplate from writing wrappers for method calls.
355 ///
356 ///   @code{.cpp}
357 ///   class MyClass {
358 ///   public:
359 ///     void myMethod(uint32_t, bool) { ... }
360 ///   };
361 ///
362 ///   // SPS Method signature -- note MyClass object address as first argument.
///   using SPSMyMethodWrapperSignature =
///     void(SPSExecutorAddr, uint32_t, bool);
365 ///
366 ///   WrapperFunctionResult
367 ///   myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
368 ///     return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
369 ///        ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
370 ///   }
371 ///   @endcode
372 ///
373 template <typename RetT, typename ClassT, typename... ArgTs>
374 class MethodWrapperHandler {
375 public:
376   using MethodT = RetT (ClassT::*)(ArgTs...);
377   MethodWrapperHandler(MethodT M) : M(M) {}
378   RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
379     return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...);
380   }
381 
382 private:
383   MethodT M;
384 };
385 
386 /// Create a MethodWrapperHandler object from the given method pointer.
387 template <typename RetT, typename ClassT, typename... ArgTs>
388 MethodWrapperHandler<RetT, ClassT, ArgTs...>
389 makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
390   return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
391 }
392 
393 /// Represents a call to a wrapper function.
394 class WrapperFunctionCall {
395 public:
396   // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a
397   // smallvector.
398   using ArgDataBufferType = std::vector<char>;
399 
400   /// Create a WrapperFunctionCall using the given SPS serializer to serialize
401   /// the arguments.
402   template <typename SPSSerializer, typename... ArgTs>
403   static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
404                                               const ArgTs &...Args) {
405     ArgDataBufferType ArgData;
406     ArgData.resize(SPSSerializer::size(Args...));
407     SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(),
408                        ArgData.size());
409     if (SPSSerializer::serialize(OB, Args...))
410       return WrapperFunctionCall(FnAddr, std::move(ArgData));
411     return make_error<StringError>("Cannot serialize arguments for "
412                                    "AllocActionCall");
413   }
414 
415   WrapperFunctionCall() = default;
416 
417   /// Create a WrapperFunctionCall from a target function and arg buffer.
418   WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
419       : FnAddr(FnAddr), ArgData(std::move(ArgData)) {}
420 
421   /// Returns the address to be called.
422   const ExecutorAddr &getCallee() const { return FnAddr; }
423 
424   /// Returns the argument data.
425   const ArgDataBufferType &getArgData() const { return ArgData; }
426 
427   /// WrapperFunctionCalls convert to true if the callee is non-null.
428   explicit operator bool() const { return !!FnAddr; }
429 
430   /// Run call returning raw WrapperFunctionResult.
431   WrapperFunctionResult run() const {
432     using FnTy =
433         orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize);
434     return WrapperFunctionResult(
435         FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
436   }
437 
438   /// Run call and deserialize result using SPS.
439   template <typename SPSRetT, typename RetT>
440   std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
441   runWithSPSRet(RetT &RetVal) const {
442     auto WFR = run();
443     if (const char *ErrMsg = WFR.getOutOfBandError())
444       return make_error<StringError>(ErrMsg);
445     SPSInputBuffer IB(WFR.data(), WFR.size());
446     if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
447       return make_error<StringError>("Could not deserialize result from "
448                                      "serialized wrapper function call");
449     return Error::success();
450   }
451 
452   /// Overload for SPS functions returning void.
453   template <typename SPSRetT>
454   std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
455   runWithSPSRet() const {
456     SPSEmpty E;
457     return runWithSPSRet<SPSEmpty>(E);
458   }
459 
460   /// Run call and deserialize an SPSError result. SPSError returns and
461   /// deserialization failures are merged into the returned error.
462   Error runWithSPSRetErrorMerged() const {
463     detail::SPSSerializableError RetErr;
464     if (auto Err = runWithSPSRet<SPSError>(RetErr))
465       return Err;
466     return detail::fromSPSSerializable(std::move(RetErr));
467   }
468 
469 private:
470   ExecutorAddr FnAddr;
471   std::vector<char> ArgData;
472 };
473 
474 using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
475 
476 template <>
477 class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
478 public:
479   static size_t size(const WrapperFunctionCall &WFC) {
480     return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size(
481         WFC.getCallee(), WFC.getArgData());
482   }
483 
484   static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
485     return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize(
486         OB, WFC.getCallee(), WFC.getArgData());
487   }
488 
489   static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
490     ExecutorAddr FnAddr;
491     WrapperFunctionCall::ArgDataBufferType ArgData;
492     if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData))
493       return false;
494     WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));
495     return true;
496   }
497 };
498 
499 } // namespace orc_rt
500 
501 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H
502