12487db1fSLang Hames //===- ExecutionSessionWrapperFunctionCallsTest.cpp -- Test wrapper calls -===// 22487db1fSLang Hames // 32487db1fSLang Hames // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42487db1fSLang Hames // See https://llvm.org/LICENSE.txt for license information. 52487db1fSLang Hames // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62487db1fSLang Hames // 72487db1fSLang Hames //===----------------------------------------------------------------------===// 82487db1fSLang Hames 9*dc11c060SLang Hames #include "llvm/ExecutionEngine/Orc/AbsoluteSymbols.h" 102487db1fSLang Hames #include "llvm/ExecutionEngine/Orc/Core.h" 112487db1fSLang Hames #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" 122487db1fSLang Hames #include "llvm/Support/MSVCErrorWorkarounds.h" 132487db1fSLang Hames #include "llvm/Testing/Support/Error.h" 142487db1fSLang Hames #include "gtest/gtest.h" 152487db1fSLang Hames 162487db1fSLang Hames #include <future> 172487db1fSLang Hames 182487db1fSLang Hames using namespace llvm; 192487db1fSLang Hames using namespace llvm::orc; 202487db1fSLang Hames using namespace llvm::orc::shared; 212487db1fSLang Hames 22213666f8SLang Hames static llvm::orc::shared::CWrapperFunctionResult addWrapper(const char *ArgData, 23213666f8SLang Hames size_t ArgSize) { 242487db1fSLang Hames return WrapperFunction<int32_t(int32_t, int32_t)>::handle( 252487db1fSLang Hames ArgData, ArgSize, [](int32_t X, int32_t Y) { return X + Y; }) 262487db1fSLang Hames .release(); 272487db1fSLang Hames } 282487db1fSLang Hames 292487db1fSLang Hames static void addAsyncWrapper(unique_function<void(int32_t)> SendResult, 302487db1fSLang Hames int32_t X, int32_t Y) { 312487db1fSLang Hames SendResult(X + Y); 322487db1fSLang Hames } 332487db1fSLang Hames 34213666f8SLang Hames static llvm::orc::shared::CWrapperFunctionResult 358a367502SLang Hames voidWrapper(const char *ArgData, size_t ArgSize) { 368a367502SLang Hames return WrapperFunction<void()>::handle(ArgData, ArgSize, []() {}).release(); 378a367502SLang Hames } 388a367502SLang Hames 392487db1fSLang Hames TEST(ExecutionSessionWrapperFunctionCalls, RunWrapperTemplate) { 402487db1fSLang Hames ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create())); 412487db1fSLang Hames 422487db1fSLang Hames int32_t Result; 432487db1fSLang Hames EXPECT_THAT_ERROR(ES.callSPSWrapper<int32_t(int32_t, int32_t)>( 4421a06254SLang Hames ExecutorAddr::fromPtr(addWrapper), Result, 2, 3), 452487db1fSLang Hames Succeeded()); 462487db1fSLang Hames EXPECT_EQ(Result, 5); 4719b4e3cfSLang Hames cantFail(ES.endSession()); 482487db1fSLang Hames } 492487db1fSLang Hames 508a367502SLang Hames TEST(ExecutionSessionWrapperFunctionCalls, RunVoidWrapperAsyncTemplate) { 518a367502SLang Hames ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create())); 528a367502SLang Hames 538a367502SLang Hames std::promise<MSVCPError> RP; 54da7f993aSLang Hames ES.callSPSWrapperAsync<void()>(ExecutorAddr::fromPtr(voidWrapper), 558a367502SLang Hames [&](Error SerializationErr) { 568a367502SLang Hames RP.set_value(std::move(SerializationErr)); 57da7f993aSLang Hames }); 588a367502SLang Hames Error Err = RP.get_future().get(); 598a367502SLang Hames EXPECT_THAT_ERROR(std::move(Err), Succeeded()); 6019b4e3cfSLang Hames cantFail(ES.endSession()); 618a367502SLang Hames } 628a367502SLang Hames 638a367502SLang Hames TEST(ExecutionSessionWrapperFunctionCalls, RunNonVoidWrapperAsyncTemplate) { 642487db1fSLang Hames ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create())); 652487db1fSLang Hames 662487db1fSLang Hames std::promise<MSVCPExpected<int32_t>> RP; 678a367502SLang Hames ES.callSPSWrapperAsync<int32_t(int32_t, int32_t)>( 68da7f993aSLang Hames ExecutorAddr::fromPtr(addWrapper), 692487db1fSLang Hames [&](Error SerializationErr, int32_t R) { 702487db1fSLang Hames if (SerializationErr) 712487db1fSLang Hames RP.set_value(std::move(SerializationErr)); 722487db1fSLang Hames RP.set_value(std::move(R)); 732487db1fSLang Hames }, 74da7f993aSLang Hames 2, 3); 752487db1fSLang Hames Expected<int32_t> Result = RP.get_future().get(); 762487db1fSLang Hames EXPECT_THAT_EXPECTED(Result, HasValue(5)); 7719b4e3cfSLang Hames cantFail(ES.endSession()); 782487db1fSLang Hames } 792487db1fSLang Hames 802487db1fSLang Hames TEST(ExecutionSessionWrapperFunctionCalls, RegisterAsyncHandlerAndRun) { 812487db1fSLang Hames 828b1771bdSLang Hames constexpr ExecutorAddr AddAsyncTagAddr(0x01); 832487db1fSLang Hames 842487db1fSLang Hames ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create())); 852487db1fSLang Hames auto &JD = ES.createBareJITDylib("JD"); 862487db1fSLang Hames 872487db1fSLang Hames auto AddAsyncTag = ES.intern("addAsync_tag"); 882487db1fSLang Hames cantFail(JD.define(absoluteSymbols( 898b1771bdSLang Hames {{AddAsyncTag, {AddAsyncTagAddr, JITSymbolFlags::Exported}}}))); 902487db1fSLang Hames 912487db1fSLang Hames ExecutionSession::JITDispatchHandlerAssociationMap Associations; 922487db1fSLang Hames 932487db1fSLang Hames Associations[AddAsyncTag] = 942487db1fSLang Hames ES.wrapAsyncWithSPS<int32_t(int32_t, int32_t)>(addAsyncWrapper); 952487db1fSLang Hames 962487db1fSLang Hames cantFail(ES.registerJITDispatchHandlers(JD, std::move(Associations))); 972487db1fSLang Hames 982487db1fSLang Hames std::promise<int32_t> RP; 992487db1fSLang Hames auto RF = RP.get_future(); 1002487db1fSLang Hames 1012487db1fSLang Hames using ArgSerialization = SPSArgList<int32_t, int32_t>; 1022487db1fSLang Hames size_t ArgBufferSize = ArgSerialization::size(1, 2); 1038b117830SLang Hames auto ArgBuffer = WrapperFunctionResult::allocate(ArgBufferSize); 1048b117830SLang Hames SPSOutputBuffer OB(ArgBuffer.data(), ArgBuffer.size()); 1052487db1fSLang Hames EXPECT_TRUE(ArgSerialization::serialize(OB, 1, 2)); 1062487db1fSLang Hames 1072487db1fSLang Hames ES.runJITDispatchHandler( 1082487db1fSLang Hames [&](WrapperFunctionResult ResultBuffer) { 1092487db1fSLang Hames int32_t Result; 1102487db1fSLang Hames SPSInputBuffer IB(ResultBuffer.data(), ResultBuffer.size()); 1112487db1fSLang Hames EXPECT_TRUE(SPSArgList<int32_t>::deserialize(IB, Result)); 1122487db1fSLang Hames RP.set_value(Result); 1132487db1fSLang Hames }, 1142487db1fSLang Hames AddAsyncTagAddr, ArrayRef<char>(ArgBuffer.data(), ArgBuffer.size())); 1152487db1fSLang Hames 1162487db1fSLang Hames EXPECT_EQ(RF.get(), (int32_t)3); 1172487db1fSLang Hames 1182487db1fSLang Hames cantFail(ES.endSession()); 1192487db1fSLang Hames } 120