1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file is a part of the ORC runtime support library.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H
15
16 #include "orc_rt/c_api.h"
17 #include "common.h"
18 #include "error.h"
19 #include "executor_address.h"
20 #include "simple_packed_serialization.h"
21 #include <type_traits>
22
23 namespace __orc_rt {
24
25 /// C++ wrapper function result: Same as CWrapperFunctionResult but
26 /// auto-releases memory.
27 class WrapperFunctionResult {
28 public:
29 /// Create a default WrapperFunctionResult.
WrapperFunctionResult()30 WrapperFunctionResult() { orc_rt_CWrapperFunctionResultInit(&R); }
31
32 /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This
33 /// instance takes ownership of the result object and will automatically
34 /// call dispose on the result upon destruction.
WrapperFunctionResult(orc_rt_CWrapperFunctionResult R)35 WrapperFunctionResult(orc_rt_CWrapperFunctionResult R) : R(R) {}
36
37 WrapperFunctionResult(const WrapperFunctionResult &) = delete;
38 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
39
WrapperFunctionResult(WrapperFunctionResult && Other)40 WrapperFunctionResult(WrapperFunctionResult &&Other) {
41 orc_rt_CWrapperFunctionResultInit(&R);
42 std::swap(R, Other.R);
43 }
44
45 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
46 orc_rt_CWrapperFunctionResult Tmp;
47 orc_rt_CWrapperFunctionResultInit(&Tmp);
48 std::swap(Tmp, Other.R);
49 std::swap(R, Tmp);
50 return *this;
51 }
52
~WrapperFunctionResult()53 ~WrapperFunctionResult() { orc_rt_DisposeCWrapperFunctionResult(&R); }
54
55 /// Relinquish ownership of and return the
56 /// orc_rt_CWrapperFunctionResult.
release()57 orc_rt_CWrapperFunctionResult release() {
58 orc_rt_CWrapperFunctionResult Tmp;
59 orc_rt_CWrapperFunctionResultInit(&Tmp);
60 std::swap(R, Tmp);
61 return Tmp;
62 }
63
64 /// Get a pointer to the data contained in this instance.
data()65 char *data() { return orc_rt_CWrapperFunctionResultData(&R); }
66
67 /// Returns the size of the data contained in this instance.
size()68 size_t size() const { return orc_rt_CWrapperFunctionResultSize(&R); }
69
70 /// Returns true if this value is equivalent to a default-constructed
71 /// WrapperFunctionResult.
empty()72 bool empty() const { return orc_rt_CWrapperFunctionResultEmpty(&R); }
73
74 /// Create a WrapperFunctionResult with the given size and return a pointer
75 /// to the underlying memory.
allocate(size_t Size)76 static WrapperFunctionResult allocate(size_t Size) {
77 WrapperFunctionResult R;
78 R.R = orc_rt_CWrapperFunctionResultAllocate(Size);
79 return R;
80 }
81
82 /// Copy from the given char range.
copyFrom(const char * Source,size_t Size)83 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
84 return orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size);
85 }
86
87 /// Copy from the given null-terminated string (includes the null-terminator).
copyFrom(const char * Source)88 static WrapperFunctionResult copyFrom(const char *Source) {
89 return orc_rt_CreateCWrapperFunctionResultFromString(Source);
90 }
91
92 /// Copy from the given std::string (includes the null terminator).
copyFrom(const std::string & Source)93 static WrapperFunctionResult copyFrom(const std::string &Source) {
94 return copyFrom(Source.c_str());
95 }
96
97 /// Create an out-of-band error by copying the given string.
createOutOfBandError(const char * Msg)98 static WrapperFunctionResult createOutOfBandError(const char *Msg) {
99 return orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg);
100 }
101
102 /// Create an out-of-band error by copying the given string.
createOutOfBandError(const std::string & Msg)103 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
104 return createOutOfBandError(Msg.c_str());
105 }
106
107 template <typename SPSArgListT, typename... ArgTs>
fromSPSArgs(const ArgTs &...Args)108 static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {
109 auto Result = allocate(SPSArgListT::size(Args...));
110 SPSOutputBuffer OB(Result.data(), Result.size());
111 if (!SPSArgListT::serialize(OB, Args...))
112 return createOutOfBandError(
113 "Error serializing arguments to blob in call");
114 return Result;
115 }
116
117 /// If this value is an out-of-band error then this returns the error message,
118 /// otherwise returns nullptr.
getOutOfBandError()119 const char *getOutOfBandError() const {
120 return orc_rt_CWrapperFunctionResultGetOutOfBandError(&R);
121 }
122
123 private:
124 orc_rt_CWrapperFunctionResult R;
125 };
126
127 namespace detail {
128
129 template <typename RetT> class WrapperFunctionHandlerCaller {
130 public:
131 template <typename HandlerT, typename ArgTupleT, std::size_t... I>
decltype(auto)132 static decltype(auto) call(HandlerT &&H, ArgTupleT &Args,
133 std::index_sequence<I...>) {
134 return std::forward<HandlerT>(H)(std::get<I>(Args)...);
135 }
136 };
137
138 template <> class WrapperFunctionHandlerCaller<void> {
139 public:
140 template <typename HandlerT, typename ArgTupleT, std::size_t... I>
call(HandlerT && H,ArgTupleT & Args,std::index_sequence<I...>)141 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
142 std::index_sequence<I...>) {
143 std::forward<HandlerT>(H)(std::get<I>(Args)...);
144 return SPSEmpty();
145 }
146 };
147
/// Primary template, used for callable object types. CallableT is stripped
/// of references, and the helper recurses on the type of its operator(). One
/// of the member-function-pointer specializations can then recover the
/// return and argument types.
template <typename CallableT,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference_t<CallableT>::operator()),
          ResultSerializer, SPSTagTs...> {};
154
155 template <typename RetT, typename... ArgTs,
156 template <typename> class ResultSerializer, typename... SPSTagTs>
157 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
158 SPSTagTs...> {
159 public:
160 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
161 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
162
163 template <typename HandlerT>
apply(HandlerT && H,const char * ArgData,size_t ArgSize)164 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
165 size_t ArgSize) {
166 ArgTuple Args;
167 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
168 return WrapperFunctionResult::createOutOfBandError(
169 "Could not deserialize arguments for wrapper function call");
170
171 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
172 std::forward<HandlerT>(H), Args, ArgIndices{});
173
174 return ResultSerializer<decltype(HandlerResult)>::serialize(
175 std::move(HandlerResult));
176 }
177
178 private:
179 template <std::size_t... I>
deserialize(const char * ArgData,size_t ArgSize,ArgTuple & Args,std::index_sequence<I...>)180 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
181 std::index_sequence<I...>) {
182 SPSInputBuffer IB(ArgData, ArgSize);
183 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
184 }
185 };
186
187 // Map function pointers to function types.
188 template <typename RetT, typename... ArgTs,
189 template <typename> class ResultSerializer, typename... SPSTagTs>
190 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
191 SPSTagTs...>
192 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
193 SPSTagTs...> {};
194
195 // Map non-const member function types to function types.
196 template <typename ClassT, typename RetT, typename... ArgTs,
197 template <typename> class ResultSerializer, typename... SPSTagTs>
198 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
199 SPSTagTs...>
200 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
201 SPSTagTs...> {};
202
203 // Map const member function types to function types.
204 template <typename ClassT, typename RetT, typename... ArgTs,
205 template <typename> class ResultSerializer, typename... SPSTagTs>
206 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
207 ResultSerializer, SPSTagTs...>
208 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
209 SPSTagTs...> {};
210
211 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
212 public:
serialize(RetT Result)213 static WrapperFunctionResult serialize(RetT Result) {
214 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);
215 }
216 };
217
218 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
219 public:
serialize(Error Err)220 static WrapperFunctionResult serialize(Error Err) {
221 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
222 toSPSSerializable(std::move(Err)));
223 }
224 };
225
226 template <typename SPSRetTagT, typename T>
227 class ResultSerializer<SPSRetTagT, Expected<T>> {
228 public:
serialize(Expected<T> E)229 static WrapperFunctionResult serialize(Expected<T> E) {
230 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
231 toSPSSerializable(std::move(E)));
232 }
233 };
234
235 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
236 public:
makeSafe(RetT & Result)237 static void makeSafe(RetT &Result) {}
238
deserialize(RetT & Result,const char * ArgData,size_t ArgSize)239 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
240 SPSInputBuffer IB(ArgData, ArgSize);
241 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
242 return make_error<StringError>(
243 "Error deserializing return value from blob in call");
244 return Error::success();
245 }
246 };
247
248 template <> class ResultDeserializer<SPSError, Error> {
249 public:
makeSafe(Error & Err)250 static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
251
deserialize(Error & Err,const char * ArgData,size_t ArgSize)252 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
253 SPSInputBuffer IB(ArgData, ArgSize);
254 SPSSerializableError BSE;
255 if (!SPSArgList<SPSError>::deserialize(IB, BSE))
256 return make_error<StringError>(
257 "Error deserializing return value from blob in call");
258 Err = fromSPSSerializable(std::move(BSE));
259 return Error::success();
260 }
261 };
262
263 template <typename SPSTagT, typename T>
264 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
265 public:
makeSafe(Expected<T> & E)266 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
267
deserialize(Expected<T> & E,const char * ArgData,size_t ArgSize)268 static Error deserialize(Expected<T> &E, const char *ArgData,
269 size_t ArgSize) {
270 SPSInputBuffer IB(ArgData, ArgSize);
271 SPSSerializableExpected<T> BSE;
272 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
273 return make_error<StringError>(
274 "Error deserializing return value from blob in call");
275 E = fromSPSSerializable(std::move(BSE));
276 return Error::success();
277 }
278 };
279
280 } // end namespace detail
281
282 template <typename SPSSignature> class WrapperFunction;
283
284 template <typename SPSRetTagT, typename... SPSTagTs>
285 class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
286 private:
287 template <typename RetT>
288 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
289
290 public:
291 template <typename RetT, typename... ArgTs>
call(const void * FnTag,RetT & Result,const ArgTs &...Args)292 static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) {
293
294 // RetT might be an Error or Expected value. Set the checked flag now:
295 // we don't want the user to have to check the unused result if this
296 // operation fails.
297 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
298
299 // Since the functions cannot be zero/unresolved on Windows, the following
300 // reference taking would always be non-zero, thus generating a compiler
301 // warning otherwise.
302 #if !defined(_WIN32)
303 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx))
304 return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set");
305 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch))
306 return make_error<StringError>("__orc_rt_jit_dispatch not set");
307 #endif
308 auto ArgBuffer =
309 WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
310 if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
311 return make_error<StringError>(ErrMsg);
312
313 WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch(
314 &__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size());
315 if (auto ErrMsg = ResultBuffer.getOutOfBandError())
316 return make_error<StringError>(ErrMsg);
317
318 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
319 Result, ResultBuffer.data(), ResultBuffer.size());
320 }
321
322 template <typename HandlerT>
handle(const char * ArgData,size_t ArgSize,HandlerT && Handler)323 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
324 HandlerT &&Handler) {
325 using WFHH =
326 detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
327 ResultSerializer, SPSTagTs...>;
328 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
329 }
330
331 private:
makeSerializable(const T & Value)332 template <typename T> static const T &makeSerializable(const T &Value) {
333 return Value;
334 }
335
makeSerializable(Error Err)336 static detail::SPSSerializableError makeSerializable(Error Err) {
337 return detail::toSPSSerializable(std::move(Err));
338 }
339
340 template <typename T>
makeSerializable(Expected<T> E)341 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
342 return detail::toSPSSerializable(std::move(E));
343 }
344 };
345
346 template <typename... SPSTagTs>
347 class WrapperFunction<void(SPSTagTs...)>
348 : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
349 public:
350 template <typename... ArgTs>
call(const void * FnTag,const ArgTs &...Args)351 static Error call(const void *FnTag, const ArgTs &...Args) {
352 SPSEmpty BE;
353 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...);
354 }
355
356 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
357 };
358
359 /// A function object that takes an ExecutorAddr as its first argument,
360 /// casts that address to a ClassT*, then calls the given method on that
361 /// pointer passing in the remaining function arguments. This utility
362 /// removes some of the boilerplate from writing wrappers for method calls.
363 ///
364 /// @code{.cpp}
365 /// class MyClass {
366 /// public:
367 /// void myMethod(uint32_t, bool) { ... }
368 /// };
369 ///
370 /// // SPS Method signature -- note MyClass object address as first argument.
371 /// using SPSMyMethodWrapperSignature =
372 /// SPSTuple<SPSExecutorAddr, uint32_t, bool>;
373 ///
374 /// WrapperFunctionResult
375 /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
376 /// return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
377 /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
378 /// }
379 /// @endcode
380 ///
381 template <typename RetT, typename ClassT, typename... ArgTs>
382 class MethodWrapperHandler {
383 public:
384 using MethodT = RetT (ClassT::*)(ArgTs...);
MethodWrapperHandler(MethodT M)385 MethodWrapperHandler(MethodT M) : M(M) {}
operator()386 RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
387 return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...);
388 }
389
390 private:
391 MethodT M;
392 };
393
394 /// Create a MethodWrapperHandler object from the given method pointer.
395 template <typename RetT, typename ClassT, typename... ArgTs>
396 MethodWrapperHandler<RetT, ClassT, ArgTs...>
makeMethodWrapperHandler(RetT (ClassT::* Method)(ArgTs...))397 makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
398 return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
399 }
400
401 /// Represents a call to a wrapper function.
402 class WrapperFunctionCall {
403 public:
404 // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a
405 // smallvector.
406 using ArgDataBufferType = std::vector<char>;
407
408 /// Create a WrapperFunctionCall using the given SPS serializer to serialize
409 /// the arguments.
410 template <typename SPSSerializer, typename... ArgTs>
Create(ExecutorAddr FnAddr,const ArgTs &...Args)411 static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
412 const ArgTs &...Args) {
413 ArgDataBufferType ArgData;
414 ArgData.resize(SPSSerializer::size(Args...));
415 SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(),
416 ArgData.size());
417 if (SPSSerializer::serialize(OB, Args...))
418 return WrapperFunctionCall(FnAddr, std::move(ArgData));
419 return make_error<StringError>("Cannot serialize arguments for "
420 "AllocActionCall");
421 }
422
423 WrapperFunctionCall() = default;
424
425 /// Create a WrapperFunctionCall from a target function and arg buffer.
WrapperFunctionCall(ExecutorAddr FnAddr,ArgDataBufferType ArgData)426 WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
427 : FnAddr(FnAddr), ArgData(std::move(ArgData)) {}
428
429 /// Returns the address to be called.
getCallee()430 const ExecutorAddr &getCallee() const { return FnAddr; }
431
432 /// Returns the argument data.
getArgData()433 const ArgDataBufferType &getArgData() const { return ArgData; }
434
435 /// WrapperFunctionCalls convert to true if the callee is non-null.
436 explicit operator bool() const { return !!FnAddr; }
437
438 /// Run call returning raw WrapperFunctionResult.
run()439 WrapperFunctionResult run() const {
440 using FnTy =
441 orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize);
442 return WrapperFunctionResult(
443 FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
444 }
445
446 /// Run call and deserialize result using SPS.
447 template <typename SPSRetT, typename RetT>
448 std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
runWithSPSRet(RetT & RetVal)449 runWithSPSRet(RetT &RetVal) const {
450 auto WFR = run();
451 if (const char *ErrMsg = WFR.getOutOfBandError())
452 return make_error<StringError>(ErrMsg);
453 SPSInputBuffer IB(WFR.data(), WFR.size());
454 if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
455 return make_error<StringError>("Could not deserialize result from "
456 "serialized wrapper function call");
457 return Error::success();
458 }
459
460 /// Overload for SPS functions returning void.
461 template <typename SPSRetT>
462 std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
runWithSPSRet()463 runWithSPSRet() const {
464 SPSEmpty E;
465 return runWithSPSRet<SPSEmpty>(E);
466 }
467
468 /// Run call and deserialize an SPSError result. SPSError returns and
469 /// deserialization failures are merged into the returned error.
runWithSPSRetErrorMerged()470 Error runWithSPSRetErrorMerged() const {
471 detail::SPSSerializableError RetErr;
472 if (auto Err = runWithSPSRet<SPSError>(RetErr))
473 return Err;
474 return detail::fromSPSSerializable(std::move(RetErr));
475 }
476
477 private:
478 ExecutorAddr FnAddr;
479 std::vector<char> ArgData;
480 };
481
482 using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
483
484 template <>
485 class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
486 public:
size(const WrapperFunctionCall & WFC)487 static size_t size(const WrapperFunctionCall &WFC) {
488 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size(
489 WFC.getCallee(), WFC.getArgData());
490 }
491
serialize(SPSOutputBuffer & OB,const WrapperFunctionCall & WFC)492 static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
493 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize(
494 OB, WFC.getCallee(), WFC.getArgData());
495 }
496
deserialize(SPSInputBuffer & IB,WrapperFunctionCall & WFC)497 static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
498 ExecutorAddr FnAddr;
499 WrapperFunctionCall::ArgDataBufferType ArgData;
500 if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData))
501 return false;
502 WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));
503 return true;
504 }
505 };
506
507 } // end namespace __orc_rt
508
509 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H
510