1 //===- Invoke.cpp ------------------------------------*- C++ -*-===// 2 // 3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" 10 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" 11 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 12 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" 13 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 14 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/ExecutionEngine/CRunnerUtils.h" 18 #include "mlir/ExecutionEngine/ExecutionEngine.h" 19 #include "mlir/ExecutionEngine/MemRefUtils.h" 20 #include "mlir/ExecutionEngine/RunnerUtils.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/InitAllDialects.h" 23 #include "mlir/Parser/Parser.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" 26 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" 27 #include "mlir/Target/LLVMIR/Export.h" 28 #include "llvm/Support/TargetSelect.h" 29 #include "llvm/Support/raw_ostream.h" 30 31 #include "gmock/gmock.h" 32 33 // SPARC currently lacks JIT support. 34 #ifdef __sparc__ 35 #define SKIP_WITHOUT_JIT(x) DISABLED_##x 36 #else 37 #define SKIP_WITHOUT_JIT(x) x 38 #endif 39 40 using namespace mlir; 41 42 // The JIT isn't supported on Windows at that time 43 #ifndef _WIN32 44 45 static struct LLVMInitializer { 46 LLVMInitializer() { 47 llvm::InitializeNativeTarget(); 48 llvm::InitializeNativeTargetAsmPrinter(); 49 } 50 } initializer; 51 52 /// Simple conversion pipeline for the purpose of testing sources written in 53 /// dialects lowering to LLVM Dialect. 54 static LogicalResult lowerToLLVMDialect(ModuleOp module) { 55 PassManager pm(module->getName()); 56 pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); 57 pm.addNestedPass<func::FuncOp>(mlir::createArithToLLVMConversionPass()); 58 pm.addPass(mlir::createConvertFuncToLLVMPass()); 59 pm.addPass(mlir::createReconcileUnrealizedCastsPass()); 60 return pm.run(module); 61 } 62 63 TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(AddInteger)) { 64 #ifdef __s390__ 65 std::string moduleStr = R"mlir( 66 func.func @foo(%arg0 : i32 {llvm.signext}) -> (i32 {llvm.signext}) attributes { llvm.emit_c_interface } { 67 %res = arith.addi %arg0, %arg0 : i32 68 return %res : i32 69 } 70 )mlir"; 71 #else 72 std::string moduleStr = R"mlir( 73 func.func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } { 74 %res = arith.addi %arg0, %arg0 : i32 75 return %res : i32 76 } 77 )mlir"; 78 #endif 79 DialectRegistry registry; 80 registerAllDialects(registry); 81 registerBuiltinDialectTranslation(registry); 82 registerLLVMDialectTranslation(registry); 83 MLIRContext context(registry); 84 OwningOpRef<ModuleOp> module = 85 parseSourceString<ModuleOp>(moduleStr, &context); 86 ASSERT_TRUE(!!module); 87 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 88 auto jitOrError = ExecutionEngine::create(*module); 89 ASSERT_TRUE(!!jitOrError); 90 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 91 // The result of the function must be passed as output argument. 92 int result = 0; 93 llvm::Error error = 94 jit->invoke("foo", 42, ExecutionEngine::Result<int>(result)); 95 ASSERT_TRUE(!error); 96 ASSERT_EQ(result, 42 + 42); 97 } 98 99 TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(SubtractFloat)) { 100 std::string moduleStr = R"mlir( 101 func.func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } { 102 %res = arith.subf %arg0, %arg1 : f32 103 return %res : f32 104 } 105 )mlir"; 106 DialectRegistry registry; 107 registerAllDialects(registry); 108 registerBuiltinDialectTranslation(registry); 109 registerLLVMDialectTranslation(registry); 110 MLIRContext context(registry); 111 OwningOpRef<ModuleOp> module = 112 parseSourceString<ModuleOp>(moduleStr, &context); 113 ASSERT_TRUE(!!module); 114 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 115 auto jitOrError = ExecutionEngine::create(*module); 116 ASSERT_TRUE(!!jitOrError); 117 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 118 // The result of the function must be passed as output argument. 119 float result = -1; 120 llvm::Error error = 121 jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result)); 122 ASSERT_TRUE(!error); 123 ASSERT_EQ(result, 42.f); 124 } 125 126 TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(ZeroRankMemref)) { 127 OwningMemRef<float, 0> a({}); 128 a[{}] = 42.; 129 ASSERT_EQ(*a->data, 42); 130 a[{}] = 0; 131 std::string moduleStr = R"mlir( 132 func.func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } { 133 %cst42 = arith.constant 42.0 : f32 134 memref.store %cst42, %arg0[] : memref<f32> 135 return 136 } 137 )mlir"; 138 DialectRegistry registry; 139 registerAllDialects(registry); 140 registerBuiltinDialectTranslation(registry); 141 registerLLVMDialectTranslation(registry); 142 MLIRContext context(registry); 143 auto module = parseSourceString<ModuleOp>(moduleStr, &context); 144 ASSERT_TRUE(!!module); 145 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 146 auto jitOrError = ExecutionEngine::create(*module); 147 ASSERT_TRUE(!!jitOrError); 148 auto jit = std::move(jitOrError.get()); 149 150 llvm::Error error = jit->invoke("zero_ranked", &*a); 151 ASSERT_TRUE(!error); 152 EXPECT_EQ((a[{}]), 42.); 153 for (float &elt : *a) 154 EXPECT_EQ(&elt, &(a[{}])); 155 } 156 157 TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(RankOneMemref)) { 158 int64_t shape[] = {9}; 159 OwningMemRef<float, 1> a(shape); 160 int count = 1; 161 for (float &elt : *a) { 162 EXPECT_EQ(&elt, &(a[{count - 1}])); 163 elt = count++; 164 } 165 166 std::string moduleStr = R"mlir( 167 func.func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } { 168 %cst42 = arith.constant 42.0 : f32 169 %cst5 = arith.constant 5 : index 170 memref.store %cst42, %arg0[%cst5] : memref<?xf32> 171 return 172 } 173 )mlir"; 174 DialectRegistry registry; 175 registerAllDialects(registry); 176 registerBuiltinDialectTranslation(registry); 177 registerLLVMDialectTranslation(registry); 178 MLIRContext context(registry); 179 auto module = parseSourceString<ModuleOp>(moduleStr, &context); 180 ASSERT_TRUE(!!module); 181 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 182 auto jitOrError = ExecutionEngine::create(*module); 183 ASSERT_TRUE(!!jitOrError); 184 auto jit = std::move(jitOrError.get()); 185 186 llvm::Error error = jit->invoke("one_ranked", &*a); 187 ASSERT_TRUE(!error); 188 count = 1; 189 for (float &elt : *a) { 190 if (count == 6) 191 EXPECT_EQ(elt, 42.); 192 else 193 EXPECT_EQ(elt, count); 194 count++; 195 } 196 } 197 198 TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) { 199 constexpr int k = 3; 200 constexpr int m = 7; 201 // Prepare arguments beforehand. 202 auto init = [=](float &elt, ArrayRef<int64_t> indices) { 203 assert(indices.size() == 2); 204 elt = m * indices[0] + indices[1]; 205 }; 206 int64_t shape[] = {k, m}; 207 int64_t shapeAlloc[] = {k + 1, m + 1}; 208 OwningMemRef<float, 2> a(shape, shapeAlloc, init); 209 ASSERT_EQ(a->sizes[0], k); 210 ASSERT_EQ(a->sizes[1], m); 211 ASSERT_EQ(a->strides[0], m + 1); 212 ASSERT_EQ(a->strides[1], 1); 213 for (int i = 0; i < k; ++i) { 214 for (int j = 0; j < m; ++j) { 215 EXPECT_EQ((a[{i, j}]), i * m + j); 216 EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j])); 217 } 218 } 219 std::string moduleStr = R"mlir( 220 func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } { 221 %x = arith.constant 2 : index 222 %y = arith.constant 1 : index 223 %cst42 = arith.constant 42.0 : f32 224 memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32> 225 memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32> 226 return 227 } 228 )mlir"; 229 DialectRegistry registry; 230 registerAllDialects(registry); 231 registerBuiltinDialectTranslation(registry); 232 registerLLVMDialectTranslation(registry); 233 MLIRContext context(registry); 234 OwningOpRef<ModuleOp> module = 235 parseSourceString<ModuleOp>(moduleStr, &context); 236 ASSERT_TRUE(!!module); 237 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 238 auto jitOrError = ExecutionEngine::create(*module); 239 ASSERT_TRUE(!!jitOrError); 240 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 241 242 llvm::Error error = jit->invoke("rank2_memref", &*a, &*a); 243 ASSERT_TRUE(!error); 244 EXPECT_EQ(((*a)[1][2]), 42.); 245 EXPECT_EQ((a[{2, 1}]), 42.); 246 } 247 248 // A helper function that will be called from the JIT 249 static void memrefMultiply(::StridedMemRefType<float, 2> *memref, 250 int32_t coefficient) { 251 for (float &elt : *memref) 252 elt *= coefficient; 253 } 254 255 // MSAN does not work with JIT. 256 #if __has_feature(memory_sanitizer) 257 #define MAYBE_JITCallback DISABLED_JITCallback 258 #else 259 #define MAYBE_JITCallback SKIP_WITHOUT_JIT(JITCallback) 260 #endif 261 TEST(NativeMemRefJit, MAYBE_JITCallback) { 262 constexpr int k = 2; 263 constexpr int m = 2; 264 int64_t shape[] = {k, m}; 265 int64_t shapeAlloc[] = {k + 1, m + 1}; 266 OwningMemRef<float, 2> a(shape, shapeAlloc); 267 int count = 1; 268 for (float &elt : *a) 269 elt = count++; 270 271 #ifdef __s390__ 272 std::string moduleStr = R"mlir( 273 func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32 {llvm.signext}) attributes { llvm.emit_c_interface } 274 func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32 {llvm.signext}) attributes { llvm.emit_c_interface } { 275 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> 276 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () 277 return 278 } 279 )mlir"; 280 #else 281 std::string moduleStr = R"mlir( 282 func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } 283 func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } { 284 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> 285 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () 286 return 287 } 288 )mlir"; 289 #endif 290 291 DialectRegistry registry; 292 registerAllDialects(registry); 293 registerBuiltinDialectTranslation(registry); 294 registerLLVMDialectTranslation(registry); 295 MLIRContext context(registry); 296 auto module = parseSourceString<ModuleOp>(moduleStr, &context); 297 ASSERT_TRUE(!!module); 298 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 299 auto jitOrError = ExecutionEngine::create(*module); 300 ASSERT_TRUE(!!jitOrError); 301 auto jit = std::move(jitOrError.get()); 302 // Define any extra symbols so they're available at runtime. 303 jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { 304 llvm::orc::SymbolMap symbolMap; 305 symbolMap[interner("_mlir_ciface_callback")] = { 306 llvm::orc::ExecutorAddr::fromPtr(memrefMultiply), 307 llvm::JITSymbolFlags::Exported}; 308 return symbolMap; 309 }); 310 311 int32_t coefficient = 3.; 312 llvm::Error error = jit->invoke("caller_for_callback", &*a, coefficient); 313 ASSERT_TRUE(!error); 314 count = 1; 315 for (float elt : *a) 316 ASSERT_EQ(elt, coefficient * count++); 317 } 318 319 #endif // _WIN32 320