xref: /llvm-project/compiler-rt/lib/orc/tests/unit/wrapper_function_utils_test.cpp (revision 69f8923efa61034b57805a8d6d859e9c1ca976eb)
11169586dSLang Hames //===-- wrapper_function_utils_test.cpp -----------------------------------===//
21169586dSLang Hames //
31169586dSLang Hames // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41169586dSLang Hames // See https://llvm.org/LICENSE.txt for license information.
51169586dSLang Hames // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61169586dSLang Hames //
71169586dSLang Hames //===----------------------------------------------------------------------===//
81169586dSLang Hames //
91169586dSLang Hames // This file is a part of the ORC runtime.
101169586dSLang Hames //
111169586dSLang Hames //===----------------------------------------------------------------------===//
121169586dSLang Hames 
#include "common.h"
#include "jit_dispatch.h"
#include "wrapper_function_utils.h"
#include "gtest/gtest.h"

#include <cstring>
#include <string>

using namespace orc_rt;
191169586dSLang Hames 
// Shared NUL-terminated payload for the WrapperFunctionResult tests.
// `static` gives it the same internal linkage the anonymous namespace did.
static constexpr const char *TestString = "test string";
231169586dSLang Hames 
241169586dSLang Hames TEST(WrapperFunctionUtilsTest, DefaultWrapperFunctionResult) {
251169586dSLang Hames   WrapperFunctionResult R;
261169586dSLang Hames   EXPECT_TRUE(R.empty());
271169586dSLang Hames   EXPECT_EQ(R.size(), 0U);
281169586dSLang Hames   EXPECT_EQ(R.getOutOfBandError(), nullptr);
291169586dSLang Hames }
301169586dSLang Hames 
311169586dSLang Hames TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCStruct) {
3234fccfb2SLang Hames   orc_rt_CWrapperFunctionResult CR =
3334fccfb2SLang Hames       orc_rt_CreateCWrapperFunctionResultFromString(TestString);
341169586dSLang Hames   WrapperFunctionResult R(CR);
351169586dSLang Hames   EXPECT_EQ(R.size(), strlen(TestString) + 1);
361169586dSLang Hames   EXPECT_TRUE(strcmp(R.data(), TestString) == 0);
371169586dSLang Hames   EXPECT_FALSE(R.empty());
381169586dSLang Hames   EXPECT_EQ(R.getOutOfBandError(), nullptr);
391169586dSLang Hames }
401169586dSLang Hames 
411169586dSLang Hames TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromRange) {
421169586dSLang Hames   auto R = WrapperFunctionResult::copyFrom(TestString, strlen(TestString) + 1);
431169586dSLang Hames   EXPECT_EQ(R.size(), strlen(TestString) + 1);
441169586dSLang Hames   EXPECT_TRUE(strcmp(R.data(), TestString) == 0);
451169586dSLang Hames   EXPECT_FALSE(R.empty());
461169586dSLang Hames   EXPECT_EQ(R.getOutOfBandError(), nullptr);
471169586dSLang Hames }
481169586dSLang Hames 
491169586dSLang Hames TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCString) {
501169586dSLang Hames   auto R = WrapperFunctionResult::copyFrom(TestString);
511169586dSLang Hames   EXPECT_EQ(R.size(), strlen(TestString) + 1);
521169586dSLang Hames   EXPECT_TRUE(strcmp(R.data(), TestString) == 0);
531169586dSLang Hames   EXPECT_FALSE(R.empty());
541169586dSLang Hames   EXPECT_EQ(R.getOutOfBandError(), nullptr);
551169586dSLang Hames }
561169586dSLang Hames 
571169586dSLang Hames TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromStdString) {
581169586dSLang Hames   auto R = WrapperFunctionResult::copyFrom(std::string(TestString));
591169586dSLang Hames   EXPECT_EQ(R.size(), strlen(TestString) + 1);
601169586dSLang Hames   EXPECT_TRUE(strcmp(R.data(), TestString) == 0);
611169586dSLang Hames   EXPECT_FALSE(R.empty());
621169586dSLang Hames   EXPECT_EQ(R.getOutOfBandError(), nullptr);
631169586dSLang Hames }
641169586dSLang Hames 
651169586dSLang Hames TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromOutOfBandError) {
661169586dSLang Hames   auto R = WrapperFunctionResult::createOutOfBandError(TestString);
671169586dSLang Hames   EXPECT_FALSE(R.empty());
681169586dSLang Hames   EXPECT_TRUE(strcmp(R.getOutOfBandError(), TestString) == 0);
691169586dSLang Hames }
701169586dSLang Hames 
710e43f3b0SLang Hames TEST(WrapperFunctionUtilsTest, WrapperFunctionCCallCreateEmpty) {
720e43f3b0SLang Hames   EXPECT_TRUE(!!WrapperFunctionCall::Create<SPSArgList<>>(ExecutorAddr()));
730e43f3b0SLang Hames }
740e43f3b0SLang Hames 
/// Takes no arguments and returns nothing. It is the target of the void
/// wrapper-function tests.
static void voidNoop() {
  // Intentionally empty.
}
761169586dSLang Hames 
7734fccfb2SLang Hames static orc_rt_CWrapperFunctionResult voidNoopWrapper(const char *ArgData,
781169586dSLang Hames                                                      size_t ArgSize) {
791169586dSLang Hames   return WrapperFunction<void()>::handle(ArgData, ArgSize, voidNoop).release();
801169586dSLang Hames }
811169586dSLang Hames 
8234fccfb2SLang Hames static orc_rt_CWrapperFunctionResult addWrapper(const char *ArgData,
831169586dSLang Hames                                                 size_t ArgSize) {
841169586dSLang Hames   return WrapperFunction<int32_t(int32_t, int32_t)>::handle(
851169586dSLang Hames              ArgData, ArgSize,
861169586dSLang Hames              [](int32_t X, int32_t Y) -> int32_t { return X + Y; })
871169586dSLang Hames       .release();
881169586dSLang Hames }
891169586dSLang Hames 
901169586dSLang Hames extern "C" __orc_rt_Opaque __orc_rt_jit_dispatch_ctx{};
911169586dSLang Hames 
9234fccfb2SLang Hames extern "C" orc_rt_CWrapperFunctionResult
931169586dSLang Hames __orc_rt_jit_dispatch(__orc_rt_Opaque *Ctx, const void *FnTag,
941169586dSLang Hames                       const char *ArgData, size_t ArgSize) {
951169586dSLang Hames   using WrapperFunctionType =
9634fccfb2SLang Hames       orc_rt_CWrapperFunctionResult (*)(const char *, size_t);
971169586dSLang Hames 
981169586dSLang Hames   return reinterpret_cast<WrapperFunctionType>(const_cast<void *>(FnTag))(
991169586dSLang Hames       ArgData, ArgSize);
1001169586dSLang Hames }
1011169586dSLang Hames 
1021169586dSLang Hames TEST(WrapperFunctionUtilsTest, WrapperFunctionCallVoidNoopAndHandle) {
103*69f8923eSLang Hames   EXPECT_FALSE(
104*69f8923eSLang Hames       !!WrapperFunction<void()>::call(JITDispatch((void *)&voidNoopWrapper)));
1051169586dSLang Hames }
1061169586dSLang Hames 
1071169586dSLang Hames TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAddWrapperAndHandle) {
1081169586dSLang Hames   int32_t Result;
1091169586dSLang Hames   EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call(
110*69f8923eSLang Hames       JITDispatch((void *)&addWrapper), Result, 1, 2));
1111169586dSLang Hames   EXPECT_EQ(Result, (int32_t)3);
1121169586dSLang Hames }
1131169586dSLang Hames 
/// Small stateful class used to test method wrappers: addMethod(Y) returns
/// the value supplied at construction plus Y.
class AddClass {
public:
  AddClass(int32_t X) : Base(X) {}

  int32_t addMethod(int32_t Y) {
    const int32_t Sum = Base + Y;
    return Sum;
  }

private:
  int32_t Base;
};
1221169586dSLang Hames 
12334fccfb2SLang Hames static orc_rt_CWrapperFunctionResult addMethodWrapper(const char *ArgData,
1241169586dSLang Hames                                                       size_t ArgSize) {
1251169586dSLang Hames   return WrapperFunction<int32_t(SPSExecutorAddr, int32_t)>::handle(
1261169586dSLang Hames              ArgData, ArgSize, makeMethodWrapperHandler(&AddClass::addMethod))
1271169586dSLang Hames       .release();
1281169586dSLang Hames }
1291169586dSLang Hames 
1301169586dSLang Hames TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) {
1311169586dSLang Hames   int32_t Result;
1321169586dSLang Hames   AddClass AddObj(1);
1331169586dSLang Hames   EXPECT_FALSE(!!WrapperFunction<int32_t(SPSExecutorAddr, int32_t)>::call(
134*69f8923eSLang Hames       JITDispatch((void *)&addMethodWrapper), Result,
135*69f8923eSLang Hames       ExecutorAddr::fromPtr(&AddObj), 2));
1361169586dSLang Hames   EXPECT_EQ(Result, (int32_t)3);
1371169586dSLang Hames }
1381169586dSLang Hames 
13934fccfb2SLang Hames static orc_rt_CWrapperFunctionResult sumArrayWrapper(const char *ArgData,
1401169586dSLang Hames                                                      size_t ArgSize) {
1411169586dSLang Hames   return WrapperFunction<int8_t(SPSExecutorAddrRange)>::handle(
1421169586dSLang Hames              ArgData, ArgSize,
1431169586dSLang Hames              [](ExecutorAddrRange R) {
1441169586dSLang Hames                int8_t Sum = 0;
1451169586dSLang Hames                for (char C : R.toSpan<char>())
1461169586dSLang Hames                  Sum += C;
1471169586dSLang Hames                return Sum;
1481169586dSLang Hames              })
1491169586dSLang Hames       .release();
1501169586dSLang Hames }
1511169586dSLang Hames 
1521169586dSLang Hames TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) {
1531169586dSLang Hames   {
1541169586dSLang Hames     // Check wrapper function calls.
1551169586dSLang Hames     char A[] = {1, 2, 3, 4};
1561169586dSLang Hames 
1571169586dSLang Hames     auto WFC =
1581169586dSLang Hames         cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>(
1591169586dSLang Hames             ExecutorAddr::fromPtr(sumArrayWrapper),
1601169586dSLang Hames             ExecutorAddrRange(ExecutorAddr::fromPtr(A),
1611169586dSLang Hames                               ExecutorAddrDiff(sizeof(A)))));
1621169586dSLang Hames 
1631169586dSLang Hames     WrapperFunctionResult WFR(WFC.run());
1641169586dSLang Hames     EXPECT_EQ(WFR.size(), 1U);
1651169586dSLang Hames     EXPECT_EQ(WFR.data()[0], 10);
1661169586dSLang Hames   }
1671169586dSLang Hames 
1681169586dSLang Hames   {
1691169586dSLang Hames     // Check calls to void functions.
1701169586dSLang Hames     auto WFC =
1711169586dSLang Hames         cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>(
1721169586dSLang Hames             ExecutorAddr::fromPtr(voidNoopWrapper), ExecutorAddrRange()));
1731169586dSLang Hames     auto Err = WFC.runWithSPSRet<void>();
1741169586dSLang Hames     EXPECT_FALSE(!!Err);
1751169586dSLang Hames   }
1761169586dSLang Hames 
1771169586dSLang Hames   {
1781169586dSLang Hames     // Check calls with arguments and return values.
1791169586dSLang Hames     auto WFC =
1801169586dSLang Hames         cantFail(WrapperFunctionCall::Create<SPSArgList<int32_t, int32_t>>(
1811169586dSLang Hames             ExecutorAddr::fromPtr(addWrapper), 2, 4));
1821169586dSLang Hames 
1831169586dSLang Hames     int32_t Result = 0;
1841169586dSLang Hames     auto Err = WFC.runWithSPSRet<int32_t>(Result);
1851169586dSLang Hames     EXPECT_FALSE(!!Err);
1861169586dSLang Hames     EXPECT_EQ(Result, 6);
1871169586dSLang Hames   }
1881169586dSLang Hames }
189