1 //===-- wrapper_function_utils_test.cpp -----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file is a part of the ORC runtime. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "common.h" 14 #include "jit_dispatch.h" 15 #include "wrapper_function_utils.h" 16 #include "gtest/gtest.h" 17 18 using namespace orc_rt; 19 20 namespace { 21 constexpr const char *TestString = "test string"; 22 } // end anonymous namespace 23 24 TEST(WrapperFunctionUtilsTest, DefaultWrapperFunctionResult) { 25 WrapperFunctionResult R; 26 EXPECT_TRUE(R.empty()); 27 EXPECT_EQ(R.size(), 0U); 28 EXPECT_EQ(R.getOutOfBandError(), nullptr); 29 } 30 31 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCStruct) { 32 orc_rt_CWrapperFunctionResult CR = 33 orc_rt_CreateCWrapperFunctionResultFromString(TestString); 34 WrapperFunctionResult R(CR); 35 EXPECT_EQ(R.size(), strlen(TestString) + 1); 36 EXPECT_TRUE(strcmp(R.data(), TestString) == 0); 37 EXPECT_FALSE(R.empty()); 38 EXPECT_EQ(R.getOutOfBandError(), nullptr); 39 } 40 41 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromRange) { 42 auto R = WrapperFunctionResult::copyFrom(TestString, strlen(TestString) + 1); 43 EXPECT_EQ(R.size(), strlen(TestString) + 1); 44 EXPECT_TRUE(strcmp(R.data(), TestString) == 0); 45 EXPECT_FALSE(R.empty()); 46 EXPECT_EQ(R.getOutOfBandError(), nullptr); 47 } 48 49 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCString) { 50 auto R = WrapperFunctionResult::copyFrom(TestString); 51 EXPECT_EQ(R.size(), strlen(TestString) + 1); 52 EXPECT_TRUE(strcmp(R.data(), TestString) == 0); 53 EXPECT_FALSE(R.empty()); 54 EXPECT_EQ(R.getOutOfBandError(), nullptr); 55 } 56 57 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromStdString) { 58 auto R = WrapperFunctionResult::copyFrom(std::string(TestString)); 59 EXPECT_EQ(R.size(), strlen(TestString) + 1); 60 EXPECT_TRUE(strcmp(R.data(), TestString) == 0); 61 EXPECT_FALSE(R.empty()); 62 EXPECT_EQ(R.getOutOfBandError(), nullptr); 63 } 64 65 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromOutOfBandError) { 66 auto R = WrapperFunctionResult::createOutOfBandError(TestString); 67 EXPECT_FALSE(R.empty()); 68 EXPECT_TRUE(strcmp(R.getOutOfBandError(), TestString) == 0); 69 } 70 71 TEST(WrapperFunctionUtilsTest, WrapperFunctionCCallCreateEmpty) { 72 EXPECT_TRUE(!!WrapperFunctionCall::Create<SPSArgList<>>(ExecutorAddr())); 73 } 74 75 static void voidNoop() {} 76 77 static orc_rt_CWrapperFunctionResult voidNoopWrapper(const char *ArgData, 78 size_t ArgSize) { 79 return WrapperFunction<void()>::handle(ArgData, ArgSize, voidNoop).release(); 80 } 81 82 static orc_rt_CWrapperFunctionResult addWrapper(const char *ArgData, 83 size_t ArgSize) { 84 return WrapperFunction<int32_t(int32_t, int32_t)>::handle( 85 ArgData, ArgSize, 86 [](int32_t X, int32_t Y) -> int32_t { return X + Y; }) 87 .release(); 88 } 89 90 extern "C" __orc_rt_Opaque __orc_rt_jit_dispatch_ctx{}; 91 92 extern "C" orc_rt_CWrapperFunctionResult 93 __orc_rt_jit_dispatch(__orc_rt_Opaque *Ctx, const void *FnTag, 94 const char *ArgData, size_t ArgSize) { 95 using WrapperFunctionType = 96 orc_rt_CWrapperFunctionResult (*)(const char *, size_t); 97 98 return reinterpret_cast<WrapperFunctionType>(const_cast<void *>(FnTag))( 99 ArgData, ArgSize); 100 } 101 102 TEST(WrapperFunctionUtilsTest, WrapperFunctionCallVoidNoopAndHandle) { 103 EXPECT_FALSE( 104 !!WrapperFunction<void()>::call(JITDispatch((void *)&voidNoopWrapper))); 105 } 106 107 TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAddWrapperAndHandle) { 108 int32_t Result; 109 EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call( 110 JITDispatch((void *)&addWrapper), Result, 1, 2)); 111 EXPECT_EQ(Result, (int32_t)3); 112 } 113 114 class AddClass { 115 public: 116 AddClass(int32_t X) : X(X) {} 117 int32_t addMethod(int32_t Y) { return X + Y; } 118 119 private: 120 int32_t X; 121 }; 122 123 static orc_rt_CWrapperFunctionResult addMethodWrapper(const char *ArgData, 124 size_t ArgSize) { 125 return WrapperFunction<int32_t(SPSExecutorAddr, int32_t)>::handle( 126 ArgData, ArgSize, makeMethodWrapperHandler(&AddClass::addMethod)) 127 .release(); 128 } 129 130 TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) { 131 int32_t Result; 132 AddClass AddObj(1); 133 EXPECT_FALSE(!!WrapperFunction<int32_t(SPSExecutorAddr, int32_t)>::call( 134 JITDispatch((void *)&addMethodWrapper), Result, 135 ExecutorAddr::fromPtr(&AddObj), 2)); 136 EXPECT_EQ(Result, (int32_t)3); 137 } 138 139 static orc_rt_CWrapperFunctionResult sumArrayWrapper(const char *ArgData, 140 size_t ArgSize) { 141 return WrapperFunction<int8_t(SPSExecutorAddrRange)>::handle( 142 ArgData, ArgSize, 143 [](ExecutorAddrRange R) { 144 int8_t Sum = 0; 145 for (char C : R.toSpan<char>()) 146 Sum += C; 147 return Sum; 148 }) 149 .release(); 150 } 151 152 TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) { 153 { 154 // Check wrapper function calls. 155 char A[] = {1, 2, 3, 4}; 156 157 auto WFC = 158 cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>( 159 ExecutorAddr::fromPtr(sumArrayWrapper), 160 ExecutorAddrRange(ExecutorAddr::fromPtr(A), 161 ExecutorAddrDiff(sizeof(A))))); 162 163 WrapperFunctionResult WFR(WFC.run()); 164 EXPECT_EQ(WFR.size(), 1U); 165 EXPECT_EQ(WFR.data()[0], 10); 166 } 167 168 { 169 // Check calls to void functions. 170 auto WFC = 171 cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>( 172 ExecutorAddr::fromPtr(voidNoopWrapper), ExecutorAddrRange())); 173 auto Err = WFC.runWithSPSRet<void>(); 174 EXPECT_FALSE(!!Err); 175 } 176 177 { 178 // Check calls with arguments and return values. 179 auto WFC = 180 cantFail(WrapperFunctionCall::Create<SPSArgList<int32_t, int32_t>>( 181 ExecutorAddr::fromPtr(addWrapper), 2, 4)); 182 183 int32_t Result = 0; 184 auto Err = WFC.runWithSPSRet<int32_t>(Result); 185 EXPECT_FALSE(!!Err); 186 EXPECT_EQ(Result, 6); 187 } 188 } 189