xref: /llvm-project/llvm/unittests/ExecutionEngine/Orc/ExecutionSessionWrapperFunctionCallsTest.cpp (revision dc11c0601577afb8f67513d041ee25dabe3555b9)
12487db1fSLang Hames //===- ExecutionSessionWrapperFunctionCallsTest.cpp -- Test wrapper calls -===//
22487db1fSLang Hames //
32487db1fSLang Hames // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42487db1fSLang Hames // See https://llvm.org/LICENSE.txt for license information.
52487db1fSLang Hames // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62487db1fSLang Hames //
72487db1fSLang Hames //===----------------------------------------------------------------------===//
82487db1fSLang Hames 
9*dc11c060SLang Hames #include "llvm/ExecutionEngine/Orc/AbsoluteSymbols.h"
102487db1fSLang Hames #include "llvm/ExecutionEngine/Orc/Core.h"
112487db1fSLang Hames #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
122487db1fSLang Hames #include "llvm/Support/MSVCErrorWorkarounds.h"
132487db1fSLang Hames #include "llvm/Testing/Support/Error.h"
142487db1fSLang Hames #include "gtest/gtest.h"
152487db1fSLang Hames 
162487db1fSLang Hames #include <future>
172487db1fSLang Hames 
182487db1fSLang Hames using namespace llvm;
192487db1fSLang Hames using namespace llvm::orc;
202487db1fSLang Hames using namespace llvm::orc::shared;
212487db1fSLang Hames 
22213666f8SLang Hames static llvm::orc::shared::CWrapperFunctionResult addWrapper(const char *ArgData,
23213666f8SLang Hames                                                             size_t ArgSize) {
242487db1fSLang Hames   return WrapperFunction<int32_t(int32_t, int32_t)>::handle(
252487db1fSLang Hames              ArgData, ArgSize, [](int32_t X, int32_t Y) { return X + Y; })
262487db1fSLang Hames       .release();
272487db1fSLang Hames }
282487db1fSLang Hames 
292487db1fSLang Hames static void addAsyncWrapper(unique_function<void(int32_t)> SendResult,
302487db1fSLang Hames                             int32_t X, int32_t Y) {
312487db1fSLang Hames   SendResult(X + Y);
322487db1fSLang Hames }
332487db1fSLang Hames 
34213666f8SLang Hames static llvm::orc::shared::CWrapperFunctionResult
358a367502SLang Hames voidWrapper(const char *ArgData, size_t ArgSize) {
368a367502SLang Hames   return WrapperFunction<void()>::handle(ArgData, ArgSize, []() {}).release();
378a367502SLang Hames }
388a367502SLang Hames 
392487db1fSLang Hames TEST(ExecutionSessionWrapperFunctionCalls, RunWrapperTemplate) {
402487db1fSLang Hames   ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create()));
412487db1fSLang Hames 
422487db1fSLang Hames   int32_t Result;
432487db1fSLang Hames   EXPECT_THAT_ERROR(ES.callSPSWrapper<int32_t(int32_t, int32_t)>(
4421a06254SLang Hames                         ExecutorAddr::fromPtr(addWrapper), Result, 2, 3),
452487db1fSLang Hames                     Succeeded());
462487db1fSLang Hames   EXPECT_EQ(Result, 5);
4719b4e3cfSLang Hames   cantFail(ES.endSession());
482487db1fSLang Hames }
492487db1fSLang Hames 
508a367502SLang Hames TEST(ExecutionSessionWrapperFunctionCalls, RunVoidWrapperAsyncTemplate) {
518a367502SLang Hames   ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create()));
528a367502SLang Hames 
538a367502SLang Hames   std::promise<MSVCPError> RP;
54da7f993aSLang Hames   ES.callSPSWrapperAsync<void()>(ExecutorAddr::fromPtr(voidWrapper),
558a367502SLang Hames                                  [&](Error SerializationErr) {
568a367502SLang Hames                                    RP.set_value(std::move(SerializationErr));
57da7f993aSLang Hames                                  });
588a367502SLang Hames   Error Err = RP.get_future().get();
598a367502SLang Hames   EXPECT_THAT_ERROR(std::move(Err), Succeeded());
6019b4e3cfSLang Hames   cantFail(ES.endSession());
618a367502SLang Hames }
628a367502SLang Hames 
638a367502SLang Hames TEST(ExecutionSessionWrapperFunctionCalls, RunNonVoidWrapperAsyncTemplate) {
642487db1fSLang Hames   ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create()));
652487db1fSLang Hames 
662487db1fSLang Hames   std::promise<MSVCPExpected<int32_t>> RP;
678a367502SLang Hames   ES.callSPSWrapperAsync<int32_t(int32_t, int32_t)>(
68da7f993aSLang Hames       ExecutorAddr::fromPtr(addWrapper),
692487db1fSLang Hames       [&](Error SerializationErr, int32_t R) {
702487db1fSLang Hames         if (SerializationErr)
712487db1fSLang Hames           RP.set_value(std::move(SerializationErr));
722487db1fSLang Hames         RP.set_value(std::move(R));
732487db1fSLang Hames       },
74da7f993aSLang Hames       2, 3);
752487db1fSLang Hames   Expected<int32_t> Result = RP.get_future().get();
762487db1fSLang Hames   EXPECT_THAT_EXPECTED(Result, HasValue(5));
7719b4e3cfSLang Hames   cantFail(ES.endSession());
782487db1fSLang Hames }
792487db1fSLang Hames 
802487db1fSLang Hames TEST(ExecutionSessionWrapperFunctionCalls, RegisterAsyncHandlerAndRun) {
812487db1fSLang Hames 
828b1771bdSLang Hames   constexpr ExecutorAddr AddAsyncTagAddr(0x01);
832487db1fSLang Hames 
842487db1fSLang Hames   ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create()));
852487db1fSLang Hames   auto &JD = ES.createBareJITDylib("JD");
862487db1fSLang Hames 
872487db1fSLang Hames   auto AddAsyncTag = ES.intern("addAsync_tag");
882487db1fSLang Hames   cantFail(JD.define(absoluteSymbols(
898b1771bdSLang Hames       {{AddAsyncTag, {AddAsyncTagAddr, JITSymbolFlags::Exported}}})));
902487db1fSLang Hames 
912487db1fSLang Hames   ExecutionSession::JITDispatchHandlerAssociationMap Associations;
922487db1fSLang Hames 
932487db1fSLang Hames   Associations[AddAsyncTag] =
942487db1fSLang Hames       ES.wrapAsyncWithSPS<int32_t(int32_t, int32_t)>(addAsyncWrapper);
952487db1fSLang Hames 
962487db1fSLang Hames   cantFail(ES.registerJITDispatchHandlers(JD, std::move(Associations)));
972487db1fSLang Hames 
982487db1fSLang Hames   std::promise<int32_t> RP;
992487db1fSLang Hames   auto RF = RP.get_future();
1002487db1fSLang Hames 
1012487db1fSLang Hames   using ArgSerialization = SPSArgList<int32_t, int32_t>;
1022487db1fSLang Hames   size_t ArgBufferSize = ArgSerialization::size(1, 2);
1038b117830SLang Hames   auto ArgBuffer = WrapperFunctionResult::allocate(ArgBufferSize);
1048b117830SLang Hames   SPSOutputBuffer OB(ArgBuffer.data(), ArgBuffer.size());
1052487db1fSLang Hames   EXPECT_TRUE(ArgSerialization::serialize(OB, 1, 2));
1062487db1fSLang Hames 
1072487db1fSLang Hames   ES.runJITDispatchHandler(
1082487db1fSLang Hames       [&](WrapperFunctionResult ResultBuffer) {
1092487db1fSLang Hames         int32_t Result;
1102487db1fSLang Hames         SPSInputBuffer IB(ResultBuffer.data(), ResultBuffer.size());
1112487db1fSLang Hames         EXPECT_TRUE(SPSArgList<int32_t>::deserialize(IB, Result));
1122487db1fSLang Hames         RP.set_value(Result);
1132487db1fSLang Hames       },
1142487db1fSLang Hames       AddAsyncTagAddr, ArrayRef<char>(ArgBuffer.data(), ArgBuffer.size()));
1152487db1fSLang Hames 
1162487db1fSLang Hames   EXPECT_EQ(RF.get(), (int32_t)3);
1172487db1fSLang Hames 
1182487db1fSLang Hames   cantFail(ES.endSession());
1192487db1fSLang Hames }
120