xref: /llvm-project/compiler-rt/lib/orc/tests/unit/wrapper_function_utils_test.cpp (revision 69f8923efa61034b57805a8d6d859e9c1ca976eb)
1 //===-- wrapper_function_utils_test.cpp -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file is a part of the ORC runtime.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "common.h"
14 #include "jit_dispatch.h"
15 #include "wrapper_function_utils.h"
16 #include "gtest/gtest.h"
17 
18 using namespace orc_rt;
19 
namespace {
/// NUL-terminated payload shared by the WrapperFunctionResult tests below.
constexpr const char *TestString{"test string"};
} // namespace
23 
24 TEST(WrapperFunctionUtilsTest, DefaultWrapperFunctionResult) {
25   WrapperFunctionResult R;
26   EXPECT_TRUE(R.empty());
27   EXPECT_EQ(R.size(), 0U);
28   EXPECT_EQ(R.getOutOfBandError(), nullptr);
29 }
30 
31 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCStruct) {
32   orc_rt_CWrapperFunctionResult CR =
33       orc_rt_CreateCWrapperFunctionResultFromString(TestString);
34   WrapperFunctionResult R(CR);
35   EXPECT_EQ(R.size(), strlen(TestString) + 1);
36   EXPECT_TRUE(strcmp(R.data(), TestString) == 0);
37   EXPECT_FALSE(R.empty());
38   EXPECT_EQ(R.getOutOfBandError(), nullptr);
39 }
40 
41 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromRange) {
42   auto R = WrapperFunctionResult::copyFrom(TestString, strlen(TestString) + 1);
43   EXPECT_EQ(R.size(), strlen(TestString) + 1);
44   EXPECT_TRUE(strcmp(R.data(), TestString) == 0);
45   EXPECT_FALSE(R.empty());
46   EXPECT_EQ(R.getOutOfBandError(), nullptr);
47 }
48 
49 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCString) {
50   auto R = WrapperFunctionResult::copyFrom(TestString);
51   EXPECT_EQ(R.size(), strlen(TestString) + 1);
52   EXPECT_TRUE(strcmp(R.data(), TestString) == 0);
53   EXPECT_FALSE(R.empty());
54   EXPECT_EQ(R.getOutOfBandError(), nullptr);
55 }
56 
57 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromStdString) {
58   auto R = WrapperFunctionResult::copyFrom(std::string(TestString));
59   EXPECT_EQ(R.size(), strlen(TestString) + 1);
60   EXPECT_TRUE(strcmp(R.data(), TestString) == 0);
61   EXPECT_FALSE(R.empty());
62   EXPECT_EQ(R.getOutOfBandError(), nullptr);
63 }
64 
65 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromOutOfBandError) {
66   auto R = WrapperFunctionResult::createOutOfBandError(TestString);
67   EXPECT_FALSE(R.empty());
68   EXPECT_TRUE(strcmp(R.getOutOfBandError(), TestString) == 0);
69 }
70 
71 TEST(WrapperFunctionUtilsTest, WrapperFunctionCCallCreateEmpty) {
72   EXPECT_TRUE(!!WrapperFunctionCall::Create<SPSArgList<>>(ExecutorAddr()));
73 }
74 
/// Target for the void() wrapper-function tests; it intentionally does nothing.
static void voidNoop() {
  // Deliberately empty.
}
76 
77 static orc_rt_CWrapperFunctionResult voidNoopWrapper(const char *ArgData,
78                                                      size_t ArgSize) {
79   return WrapperFunction<void()>::handle(ArgData, ArgSize, voidNoop).release();
80 }
81 
82 static orc_rt_CWrapperFunctionResult addWrapper(const char *ArgData,
83                                                 size_t ArgSize) {
84   return WrapperFunction<int32_t(int32_t, int32_t)>::handle(
85              ArgData, ArgSize,
86              [](int32_t X, int32_t Y) -> int32_t { return X + Y; })
87       .release();
88 }
89 
90 extern "C" __orc_rt_Opaque __orc_rt_jit_dispatch_ctx{};
91 
92 extern "C" orc_rt_CWrapperFunctionResult
93 __orc_rt_jit_dispatch(__orc_rt_Opaque *Ctx, const void *FnTag,
94                       const char *ArgData, size_t ArgSize) {
95   using WrapperFunctionType =
96       orc_rt_CWrapperFunctionResult (*)(const char *, size_t);
97 
98   return reinterpret_cast<WrapperFunctionType>(const_cast<void *>(FnTag))(
99       ArgData, ArgSize);
100 }
101 
102 TEST(WrapperFunctionUtilsTest, WrapperFunctionCallVoidNoopAndHandle) {
103   EXPECT_FALSE(
104       !!WrapperFunction<void()>::call(JITDispatch((void *)&voidNoopWrapper)));
105 }
106 
107 TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAddWrapperAndHandle) {
108   int32_t Result;
109   EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call(
110       JITDispatch((void *)&addWrapper), Result, 1, 2));
111   EXPECT_EQ(Result, (int32_t)3);
112 }
113 
/// Small stateful class used to exercise makeMethodWrapperHandler: it stores
/// one operand and adds a second one in addMethod.
class AddClass {
public:
  /// Stores \p X as the fixed left-hand operand. Marked explicit so an
  /// int32_t is never silently converted into an AddClass.
  explicit AddClass(int32_t X) : X(X) {}

  /// Returns the stored operand plus \p Y.
  int32_t addMethod(int32_t Y) { return X + Y; }

private:
  int32_t X;
};
122 
123 static orc_rt_CWrapperFunctionResult addMethodWrapper(const char *ArgData,
124                                                       size_t ArgSize) {
125   return WrapperFunction<int32_t(SPSExecutorAddr, int32_t)>::handle(
126              ArgData, ArgSize, makeMethodWrapperHandler(&AddClass::addMethod))
127       .release();
128 }
129 
130 TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) {
131   int32_t Result;
132   AddClass AddObj(1);
133   EXPECT_FALSE(!!WrapperFunction<int32_t(SPSExecutorAddr, int32_t)>::call(
134       JITDispatch((void *)&addMethodWrapper), Result,
135       ExecutorAddr::fromPtr(&AddObj), 2));
136   EXPECT_EQ(Result, (int32_t)3);
137 }
138 
139 static orc_rt_CWrapperFunctionResult sumArrayWrapper(const char *ArgData,
140                                                      size_t ArgSize) {
141   return WrapperFunction<int8_t(SPSExecutorAddrRange)>::handle(
142              ArgData, ArgSize,
143              [](ExecutorAddrRange R) {
144                int8_t Sum = 0;
145                for (char C : R.toSpan<char>())
146                  Sum += C;
147                return Sum;
148              })
149       .release();
150 }
151 
152 TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) {
153   {
154     // Check wrapper function calls.
155     char A[] = {1, 2, 3, 4};
156 
157     auto WFC =
158         cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>(
159             ExecutorAddr::fromPtr(sumArrayWrapper),
160             ExecutorAddrRange(ExecutorAddr::fromPtr(A),
161                               ExecutorAddrDiff(sizeof(A)))));
162 
163     WrapperFunctionResult WFR(WFC.run());
164     EXPECT_EQ(WFR.size(), 1U);
165     EXPECT_EQ(WFR.data()[0], 10);
166   }
167 
168   {
169     // Check calls to void functions.
170     auto WFC =
171         cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>(
172             ExecutorAddr::fromPtr(voidNoopWrapper), ExecutorAddrRange()));
173     auto Err = WFC.runWithSPSRet<void>();
174     EXPECT_FALSE(!!Err);
175   }
176 
177   {
178     // Check calls with arguments and return values.
179     auto WFC =
180         cantFail(WrapperFunctionCall::Create<SPSArgList<int32_t, int32_t>>(
181             ExecutorAddr::fromPtr(addWrapper), 2, 4));
182 
183     int32_t Result = 0;
184     auto Err = WFC.runWithSPSRet<int32_t>(Result);
185     EXPECT_FALSE(!!Err);
186     EXPECT_EQ(Result, 6);
187   }
188 }
189