//===- ExecutionSessionWrapperFunctionCallsTest.cpp -- Test wrapper calls -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "llvm/ExecutionEngine/Orc/AbsoluteSymbols.h" #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" #include "llvm/Support/MSVCErrorWorkarounds.h" #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" #include using namespace llvm; using namespace llvm::orc; using namespace llvm::orc::shared; static llvm::orc::shared::CWrapperFunctionResult addWrapper(const char *ArgData, size_t ArgSize) { return WrapperFunction::handle( ArgData, ArgSize, [](int32_t X, int32_t Y) { return X + Y; }) .release(); } static void addAsyncWrapper(unique_function SendResult, int32_t X, int32_t Y) { SendResult(X + Y); } static llvm::orc::shared::CWrapperFunctionResult voidWrapper(const char *ArgData, size_t ArgSize) { return WrapperFunction::handle(ArgData, ArgSize, []() {}).release(); } TEST(ExecutionSessionWrapperFunctionCalls, RunWrapperTemplate) { ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create())); int32_t Result; EXPECT_THAT_ERROR(ES.callSPSWrapper( ExecutorAddr::fromPtr(addWrapper), Result, 2, 3), Succeeded()); EXPECT_EQ(Result, 5); cantFail(ES.endSession()); } TEST(ExecutionSessionWrapperFunctionCalls, RunVoidWrapperAsyncTemplate) { ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create())); std::promise RP; ES.callSPSWrapperAsync(ExecutorAddr::fromPtr(voidWrapper), [&](Error SerializationErr) { RP.set_value(std::move(SerializationErr)); }); Error Err = RP.get_future().get(); EXPECT_THAT_ERROR(std::move(Err), Succeeded()); cantFail(ES.endSession()); } TEST(ExecutionSessionWrapperFunctionCalls, RunNonVoidWrapperAsyncTemplate) { ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create())); std::promise> RP; ES.callSPSWrapperAsync( ExecutorAddr::fromPtr(addWrapper), [&](Error SerializationErr, int32_t R) { if (SerializationErr) RP.set_value(std::move(SerializationErr)); RP.set_value(std::move(R)); }, 2, 3); Expected Result = RP.get_future().get(); EXPECT_THAT_EXPECTED(Result, HasValue(5)); cantFail(ES.endSession()); } TEST(ExecutionSessionWrapperFunctionCalls, RegisterAsyncHandlerAndRun) { constexpr ExecutorAddr AddAsyncTagAddr(0x01); ExecutionSession ES(cantFail(SelfExecutorProcessControl::Create())); auto &JD = ES.createBareJITDylib("JD"); auto AddAsyncTag = ES.intern("addAsync_tag"); cantFail(JD.define(absoluteSymbols( {{AddAsyncTag, {AddAsyncTagAddr, JITSymbolFlags::Exported}}}))); ExecutionSession::JITDispatchHandlerAssociationMap Associations; Associations[AddAsyncTag] = ES.wrapAsyncWithSPS(addAsyncWrapper); cantFail(ES.registerJITDispatchHandlers(JD, std::move(Associations))); std::promise RP; auto RF = RP.get_future(); using ArgSerialization = SPSArgList; size_t ArgBufferSize = ArgSerialization::size(1, 2); auto ArgBuffer = WrapperFunctionResult::allocate(ArgBufferSize); SPSOutputBuffer OB(ArgBuffer.data(), ArgBuffer.size()); EXPECT_TRUE(ArgSerialization::serialize(OB, 1, 2)); ES.runJITDispatchHandler( [&](WrapperFunctionResult ResultBuffer) { int32_t Result; SPSInputBuffer IB(ResultBuffer.data(), ResultBuffer.size()); EXPECT_TRUE(SPSArgList::deserialize(IB, Result)); RP.set_value(Result); }, AddAsyncTagAddr, ArrayRef(ArgBuffer.data(), ArgBuffer.size())); EXPECT_EQ(RF.get(), (int32_t)3); cantFail(ES.endSession()); }