1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file is a part of the ORC runtime support library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H 14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H 15 16 #include "error.h" 17 #include "executor_address.h" 18 #include "orc_rt/c_api.h" 19 #include "simple_packed_serialization.h" 20 #include <type_traits> 21 22 namespace orc_rt { 23 24 /// C++ wrapper function result: Same as CWrapperFunctionResult but 25 /// auto-releases memory. 26 class WrapperFunctionResult { 27 public: 28 /// Create a default WrapperFunctionResult. 29 WrapperFunctionResult() { orc_rt_CWrapperFunctionResultInit(&R); } 30 31 /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This 32 /// instance takes ownership of the result object and will automatically 33 /// call dispose on the result upon destruction. 34 WrapperFunctionResult(orc_rt_CWrapperFunctionResult R) : R(R) {} 35 36 WrapperFunctionResult(const WrapperFunctionResult &) = delete; 37 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; 38 39 WrapperFunctionResult(WrapperFunctionResult &&Other) { 40 orc_rt_CWrapperFunctionResultInit(&R); 41 std::swap(R, Other.R); 42 } 43 44 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { 45 orc_rt_CWrapperFunctionResult Tmp; 46 orc_rt_CWrapperFunctionResultInit(&Tmp); 47 std::swap(Tmp, Other.R); 48 std::swap(R, Tmp); 49 return *this; 50 } 51 52 ~WrapperFunctionResult() { orc_rt_DisposeCWrapperFunctionResult(&R); } 53 54 /// Relinquish ownership of and return the 55 /// orc_rt_CWrapperFunctionResult. 56 orc_rt_CWrapperFunctionResult release() { 57 orc_rt_CWrapperFunctionResult Tmp; 58 orc_rt_CWrapperFunctionResultInit(&Tmp); 59 std::swap(R, Tmp); 60 return Tmp; 61 } 62 63 /// Get a pointer to the data contained in this instance. 64 char *data() { return orc_rt_CWrapperFunctionResultData(&R); } 65 66 /// Returns the size of the data contained in this instance. 67 size_t size() const { return orc_rt_CWrapperFunctionResultSize(&R); } 68 69 /// Returns true if this value is equivalent to a default-constructed 70 /// WrapperFunctionResult. 71 bool empty() const { return orc_rt_CWrapperFunctionResultEmpty(&R); } 72 73 /// Create a WrapperFunctionResult with the given size and return a pointer 74 /// to the underlying memory. 75 static WrapperFunctionResult allocate(size_t Size) { 76 WrapperFunctionResult R; 77 R.R = orc_rt_CWrapperFunctionResultAllocate(Size); 78 return R; 79 } 80 81 /// Copy from the given char range. 82 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { 83 return orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size); 84 } 85 86 /// Copy from the given null-terminated string (includes the null-terminator). 87 static WrapperFunctionResult copyFrom(const char *Source) { 88 return orc_rt_CreateCWrapperFunctionResultFromString(Source); 89 } 90 91 /// Copy from the given std::string (includes the null terminator). 92 static WrapperFunctionResult copyFrom(const std::string &Source) { 93 return copyFrom(Source.c_str()); 94 } 95 96 /// Create an out-of-band error by copying the given string. 97 static WrapperFunctionResult createOutOfBandError(const char *Msg) { 98 return orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg); 99 } 100 101 /// Create an out-of-band error by copying the given string. 102 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { 103 return createOutOfBandError(Msg.c_str()); 104 } 105 106 template <typename SPSArgListT, typename... ArgTs> 107 static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) { 108 auto Result = allocate(SPSArgListT::size(Args...)); 109 SPSOutputBuffer OB(Result.data(), Result.size()); 110 if (!SPSArgListT::serialize(OB, Args...)) 111 return createOutOfBandError( 112 "Error serializing arguments to blob in call"); 113 return Result; 114 } 115 116 /// If this value is an out-of-band error then this returns the error message, 117 /// otherwise returns nullptr. 118 const char *getOutOfBandError() const { 119 return orc_rt_CWrapperFunctionResultGetOutOfBandError(&R); 120 } 121 122 private: 123 orc_rt_CWrapperFunctionResult R; 124 }; 125 126 namespace detail { 127 128 template <typename RetT> class WrapperFunctionHandlerCaller { 129 public: 130 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 131 static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, 132 std::index_sequence<I...>) { 133 return std::forward<HandlerT>(H)(std::get<I>(Args)...); 134 } 135 }; 136 137 template <> class WrapperFunctionHandlerCaller<void> { 138 public: 139 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 140 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, 141 std::index_sequence<I...>) { 142 std::forward<HandlerT>(H)(std::get<I>(Args)...); 143 return SPSEmpty(); 144 } 145 }; 146 147 template <typename WrapperFunctionImplT, 148 template <typename> class ResultSerializer, typename... SPSTagTs> 149 class WrapperFunctionHandlerHelper 150 : public WrapperFunctionHandlerHelper< 151 decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), 152 ResultSerializer, SPSTagTs...> {}; 153 154 template <typename RetT, typename... ArgTs, 155 template <typename> class ResultSerializer, typename... SPSTagTs> 156 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 157 SPSTagTs...> { 158 public: 159 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; 160 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; 161 162 template <typename HandlerT> 163 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, 164 size_t ArgSize) { 165 ArgTuple Args; 166 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) 167 return WrapperFunctionResult::createOutOfBandError( 168 "Could not deserialize arguments for wrapper function call"); 169 170 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call( 171 std::forward<HandlerT>(H), Args, ArgIndices{}); 172 173 return ResultSerializer<decltype(HandlerResult)>::serialize( 174 std::move(HandlerResult)); 175 } 176 177 private: 178 template <std::size_t... I> 179 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, 180 std::index_sequence<I...>) { 181 SPSInputBuffer IB(ArgData, ArgSize); 182 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); 183 } 184 }; 185 186 // Map function pointers to function types. 187 template <typename RetT, typename... ArgTs, 188 template <typename> class ResultSerializer, typename... SPSTagTs> 189 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, 190 SPSTagTs...> 191 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 192 SPSTagTs...> {}; 193 194 // Map non-const member function types to function types. 195 template <typename ClassT, typename RetT, typename... ArgTs, 196 template <typename> class ResultSerializer, typename... SPSTagTs> 197 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer, 198 SPSTagTs...> 199 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 200 SPSTagTs...> {}; 201 202 // Map const member function types to function types. 203 template <typename ClassT, typename RetT, typename... ArgTs, 204 template <typename> class ResultSerializer, typename... SPSTagTs> 205 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, 206 ResultSerializer, SPSTagTs...> 207 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 208 SPSTagTs...> {}; 209 210 template <typename SPSRetTagT, typename RetT> class ResultSerializer { 211 public: 212 static WrapperFunctionResult serialize(RetT Result) { 213 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result); 214 } 215 }; 216 217 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { 218 public: 219 static WrapperFunctionResult serialize(Error Err) { 220 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>( 221 toSPSSerializable(std::move(Err))); 222 } 223 }; 224 225 template <typename SPSRetTagT, typename T> 226 class ResultSerializer<SPSRetTagT, Expected<T>> { 227 public: 228 static WrapperFunctionResult serialize(Expected<T> E) { 229 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>( 230 toSPSSerializable(std::move(E))); 231 } 232 }; 233 234 template <typename SPSRetTagT, typename RetT> class ResultDeserializer { 235 public: 236 static void makeSafe(RetT &Result) {} 237 238 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { 239 SPSInputBuffer IB(ArgData, ArgSize); 240 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result)) 241 return make_error<StringError>( 242 "Error deserializing return value from blob in call"); 243 return Error::success(); 244 } 245 }; 246 247 template <> class ResultDeserializer<SPSError, Error> { 248 public: 249 static void makeSafe(Error &Err) { cantFail(std::move(Err)); } 250 251 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { 252 SPSInputBuffer IB(ArgData, ArgSize); 253 SPSSerializableError BSE; 254 if (!SPSArgList<SPSError>::deserialize(IB, BSE)) 255 return make_error<StringError>( 256 "Error deserializing return value from blob in call"); 257 Err = fromSPSSerializable(std::move(BSE)); 258 return Error::success(); 259 } 260 }; 261 262 template <typename SPSTagT, typename T> 263 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> { 264 public: 265 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); } 266 267 static Error deserialize(Expected<T> &E, const char *ArgData, 268 size_t ArgSize) { 269 SPSInputBuffer IB(ArgData, ArgSize); 270 SPSSerializableExpected<T> BSE; 271 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE)) 272 return make_error<StringError>( 273 "Error deserializing return value from blob in call"); 274 E = fromSPSSerializable(std::move(BSE)); 275 return Error::success(); 276 } 277 }; 278 279 } // end namespace detail 280 281 template <typename SPSSignature> class WrapperFunction; 282 283 template <typename SPSRetTagT, typename... SPSTagTs> 284 class WrapperFunction<SPSRetTagT(SPSTagTs...)> { 285 private: 286 template <typename RetT> 287 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>; 288 289 public: 290 template <typename DispatchFn, typename RetT, typename... ArgTs> 291 static Error call(DispatchFn &&Dispatch, RetT &Result, const ArgTs &...Args) { 292 293 // RetT might be an Error or Expected value. Set the checked flag now: 294 // we don't want the user to have to check the unused result if this 295 // operation fails. 296 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result); 297 298 auto ArgBuffer = 299 WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...); 300 if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) 301 return make_error<StringError>(ErrMsg); 302 303 WrapperFunctionResult ResultBuffer = 304 Dispatch(ArgBuffer.data(), ArgBuffer.size()); 305 306 if (auto ErrMsg = ResultBuffer.getOutOfBandError()) 307 return make_error<StringError>(ErrMsg); 308 309 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( 310 Result, ResultBuffer.data(), ResultBuffer.size()); 311 } 312 313 template <typename HandlerT> 314 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, 315 HandlerT &&Handler) { 316 using WFHH = 317 detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>, 318 ResultSerializer, SPSTagTs...>; 319 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); 320 } 321 322 private: 323 template <typename T> static const T &makeSerializable(const T &Value) { 324 return Value; 325 } 326 327 static detail::SPSSerializableError makeSerializable(Error Err) { 328 return detail::toSPSSerializable(std::move(Err)); 329 } 330 331 template <typename T> 332 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) { 333 return detail::toSPSSerializable(std::move(E)); 334 } 335 }; 336 337 template <typename... SPSTagTs> 338 class WrapperFunction<void(SPSTagTs...)> 339 : private WrapperFunction<SPSEmpty(SPSTagTs...)> { 340 public: 341 template <typename DispatchFn, typename... ArgTs> 342 static Error call(DispatchFn &&Dispatch, const ArgTs &...Args) { 343 SPSEmpty BE; 344 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call( 345 std::forward<DispatchFn>(Dispatch), BE, Args...); 346 } 347 348 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; 349 }; 350 351 /// A function object that takes an ExecutorAddr as its first argument, 352 /// casts that address to a ClassT*, then calls the given method on that 353 /// pointer passing in the remaining function arguments. This utility 354 /// removes some of the boilerplate from writing wrappers for method calls. 355 /// 356 /// @code{.cpp} 357 /// class MyClass { 358 /// public: 359 /// void myMethod(uint32_t, bool) { ... } 360 /// }; 361 /// 362 /// // SPS Method signature -- note MyClass object address as first argument. 363 /// using SPSMyMethodWrapperSignature = 364 /// SPSTuple<SPSExecutorAddr, uint32_t, bool>; 365 /// 366 /// WrapperFunctionResult 367 /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) { 368 /// return WrapperFunction<SPSMyMethodWrapperSignature>::handle( 369 /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod)); 370 /// } 371 /// @endcode 372 /// 373 template <typename RetT, typename ClassT, typename... ArgTs> 374 class MethodWrapperHandler { 375 public: 376 using MethodT = RetT (ClassT::*)(ArgTs...); 377 MethodWrapperHandler(MethodT M) : M(M) {} 378 RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) { 379 return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...); 380 } 381 382 private: 383 MethodT M; 384 }; 385 386 /// Create a MethodWrapperHandler object from the given method pointer. 387 template <typename RetT, typename ClassT, typename... ArgTs> 388 MethodWrapperHandler<RetT, ClassT, ArgTs...> 389 makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { 390 return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method); 391 } 392 393 /// Represents a call to a wrapper function. 394 class WrapperFunctionCall { 395 public: 396 // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a 397 // smallvector. 398 using ArgDataBufferType = std::vector<char>; 399 400 /// Create a WrapperFunctionCall using the given SPS serializer to serialize 401 /// the arguments. 402 template <typename SPSSerializer, typename... ArgTs> 403 static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr, 404 const ArgTs &...Args) { 405 ArgDataBufferType ArgData; 406 ArgData.resize(SPSSerializer::size(Args...)); 407 SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(), 408 ArgData.size()); 409 if (SPSSerializer::serialize(OB, Args...)) 410 return WrapperFunctionCall(FnAddr, std::move(ArgData)); 411 return make_error<StringError>("Cannot serialize arguments for " 412 "AllocActionCall"); 413 } 414 415 WrapperFunctionCall() = default; 416 417 /// Create a WrapperFunctionCall from a target function and arg buffer. 418 WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData) 419 : FnAddr(FnAddr), ArgData(std::move(ArgData)) {} 420 421 /// Returns the address to be called. 422 const ExecutorAddr &getCallee() const { return FnAddr; } 423 424 /// Returns the argument data. 425 const ArgDataBufferType &getArgData() const { return ArgData; } 426 427 /// WrapperFunctionCalls convert to true if the callee is non-null. 428 explicit operator bool() const { return !!FnAddr; } 429 430 /// Run call returning raw WrapperFunctionResult. 431 WrapperFunctionResult run() const { 432 using FnTy = 433 orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize); 434 return WrapperFunctionResult( 435 FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size())); 436 } 437 438 /// Run call and deserialize result using SPS. 439 template <typename SPSRetT, typename RetT> 440 std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error> 441 runWithSPSRet(RetT &RetVal) const { 442 auto WFR = run(); 443 if (const char *ErrMsg = WFR.getOutOfBandError()) 444 return make_error<StringError>(ErrMsg); 445 SPSInputBuffer IB(WFR.data(), WFR.size()); 446 if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal)) 447 return make_error<StringError>("Could not deserialize result from " 448 "serialized wrapper function call"); 449 return Error::success(); 450 } 451 452 /// Overload for SPS functions returning void. 453 template <typename SPSRetT> 454 std::enable_if_t<std::is_same<SPSRetT, void>::value, Error> 455 runWithSPSRet() const { 456 SPSEmpty E; 457 return runWithSPSRet<SPSEmpty>(E); 458 } 459 460 /// Run call and deserialize an SPSError result. SPSError returns and 461 /// deserialization failures are merged into the returned error. 462 Error runWithSPSRetErrorMerged() const { 463 detail::SPSSerializableError RetErr; 464 if (auto Err = runWithSPSRet<SPSError>(RetErr)) 465 return Err; 466 return detail::fromSPSSerializable(std::move(RetErr)); 467 } 468 469 private: 470 ExecutorAddr FnAddr; 471 std::vector<char> ArgData; 472 }; 473 474 using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>; 475 476 template <> 477 class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> { 478 public: 479 static size_t size(const WrapperFunctionCall &WFC) { 480 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size( 481 WFC.getCallee(), WFC.getArgData()); 482 } 483 484 static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) { 485 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize( 486 OB, WFC.getCallee(), WFC.getArgData()); 487 } 488 489 static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) { 490 ExecutorAddr FnAddr; 491 WrapperFunctionCall::ArgDataBufferType ArgData; 492 if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData)) 493 return false; 494 WFC = WrapperFunctionCall(FnAddr, std::move(ArgData)); 495 return true; 496 } 497 }; 498 499 } // namespace orc_rt 500 501 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H 502