xref: /llvm-project/mlir/unittests/ExecutionEngine/Invoke.cpp (revision 0d9dc421143a0acd414a23f343b555c965a471f1)
1 //===- Invoke.cpp ------------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
30 
31 #include "gmock/gmock.h"
32 
// SPARC currently lacks JIT support. On that target every test name wrapped in
// SKIP_WITHOUT_JIT gets gtest's DISABLED_ prefix, so the test is still compiled
// but not run by default. Everywhere else the name is passed through unchanged.
#if defined(__sparc__)
#define SKIP_WITHOUT_JIT(x) DISABLED_##x
#else
#define SKIP_WITHOUT_JIT(x) x
#endif
39 
40 using namespace mlir;
41 
// The JIT isn't supported on Windows at this time.
43 #ifndef _WIN32
44 
45 static struct LLVMInitializer {
46   LLVMInitializer() {
47     llvm::InitializeNativeTarget();
48     llvm::InitializeNativeTargetAsmPrinter();
49   }
50 } initializer;
51 
52 /// Simple conversion pipeline for the purpose of testing sources written in
53 /// dialects lowering to LLVM Dialect.
54 static LogicalResult lowerToLLVMDialect(ModuleOp module) {
55   PassManager pm(module->getName());
56   pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
57   pm.addNestedPass<func::FuncOp>(mlir::createArithToLLVMConversionPass());
58   pm.addPass(mlir::createConvertFuncToLLVMPass());
59   pm.addPass(mlir::createReconcileUnrealizedCastsPass());
60   return pm.run(module);
61 }
62 
63 TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(AddInteger)) {
64 #ifdef __s390__
65   std::string moduleStr = R"mlir(
66   func.func @foo(%arg0 : i32 {llvm.signext}) -> (i32 {llvm.signext}) attributes { llvm.emit_c_interface } {
67     %res = arith.addi %arg0, %arg0 : i32
68     return %res : i32
69   }
70   )mlir";
71 #else
72   std::string moduleStr = R"mlir(
73   func.func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
74     %res = arith.addi %arg0, %arg0 : i32
75     return %res : i32
76   }
77   )mlir";
78 #endif
79   DialectRegistry registry;
80   registerAllDialects(registry);
81   registerBuiltinDialectTranslation(registry);
82   registerLLVMDialectTranslation(registry);
83   MLIRContext context(registry);
84   OwningOpRef<ModuleOp> module =
85       parseSourceString<ModuleOp>(moduleStr, &context);
86   ASSERT_TRUE(!!module);
87   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
88   auto jitOrError = ExecutionEngine::create(*module);
89   ASSERT_TRUE(!!jitOrError);
90   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
91   // The result of the function must be passed as output argument.
92   int result = 0;
93   llvm::Error error =
94       jit->invoke("foo", 42, ExecutionEngine::Result<int>(result));
95   ASSERT_TRUE(!error);
96   ASSERT_EQ(result, 42 + 42);
97 }
98 
99 TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(SubtractFloat)) {
100   std::string moduleStr = R"mlir(
101   func.func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
102     %res = arith.subf %arg0, %arg1 : f32
103     return %res : f32
104   }
105   )mlir";
106   DialectRegistry registry;
107   registerAllDialects(registry);
108   registerBuiltinDialectTranslation(registry);
109   registerLLVMDialectTranslation(registry);
110   MLIRContext context(registry);
111   OwningOpRef<ModuleOp> module =
112       parseSourceString<ModuleOp>(moduleStr, &context);
113   ASSERT_TRUE(!!module);
114   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
115   auto jitOrError = ExecutionEngine::create(*module);
116   ASSERT_TRUE(!!jitOrError);
117   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
118   // The result of the function must be passed as output argument.
119   float result = -1;
120   llvm::Error error =
121       jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
122   ASSERT_TRUE(!error);
123   ASSERT_EQ(result, 42.f);
124 }
125 
126 TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(ZeroRankMemref)) {
127   OwningMemRef<float, 0> a({});
128   a[{}] = 42.;
129   ASSERT_EQ(*a->data, 42);
130   a[{}] = 0;
131   std::string moduleStr = R"mlir(
132   func.func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
133     %cst42 = arith.constant 42.0 : f32
134     memref.store %cst42, %arg0[] : memref<f32>
135     return
136   }
137   )mlir";
138   DialectRegistry registry;
139   registerAllDialects(registry);
140   registerBuiltinDialectTranslation(registry);
141   registerLLVMDialectTranslation(registry);
142   MLIRContext context(registry);
143   auto module = parseSourceString<ModuleOp>(moduleStr, &context);
144   ASSERT_TRUE(!!module);
145   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
146   auto jitOrError = ExecutionEngine::create(*module);
147   ASSERT_TRUE(!!jitOrError);
148   auto jit = std::move(jitOrError.get());
149 
150   llvm::Error error = jit->invoke("zero_ranked", &*a);
151   ASSERT_TRUE(!error);
152   EXPECT_EQ((a[{}]), 42.);
153   for (float &elt : *a)
154     EXPECT_EQ(&elt, &(a[{}]));
155 }
156 
157 TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(RankOneMemref)) {
158   int64_t shape[] = {9};
159   OwningMemRef<float, 1> a(shape);
160   int count = 1;
161   for (float &elt : *a) {
162     EXPECT_EQ(&elt, &(a[{count - 1}]));
163     elt = count++;
164   }
165 
166   std::string moduleStr = R"mlir(
167   func.func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
168     %cst42 = arith.constant 42.0 : f32
169     %cst5 = arith.constant 5 : index
170     memref.store %cst42, %arg0[%cst5] : memref<?xf32>
171     return
172   }
173   )mlir";
174   DialectRegistry registry;
175   registerAllDialects(registry);
176   registerBuiltinDialectTranslation(registry);
177   registerLLVMDialectTranslation(registry);
178   MLIRContext context(registry);
179   auto module = parseSourceString<ModuleOp>(moduleStr, &context);
180   ASSERT_TRUE(!!module);
181   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
182   auto jitOrError = ExecutionEngine::create(*module);
183   ASSERT_TRUE(!!jitOrError);
184   auto jit = std::move(jitOrError.get());
185 
186   llvm::Error error = jit->invoke("one_ranked", &*a);
187   ASSERT_TRUE(!error);
188   count = 1;
189   for (float &elt : *a) {
190     if (count == 6)
191       EXPECT_EQ(elt, 42.);
192     else
193       EXPECT_EQ(elt, count);
194     count++;
195   }
196 }
197 
198 TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) {
199   constexpr int k = 3;
200   constexpr int m = 7;
201   // Prepare arguments beforehand.
202   auto init = [=](float &elt, ArrayRef<int64_t> indices) {
203     assert(indices.size() == 2);
204     elt = m * indices[0] + indices[1];
205   };
206   int64_t shape[] = {k, m};
207   int64_t shapeAlloc[] = {k + 1, m + 1};
208   OwningMemRef<float, 2> a(shape, shapeAlloc, init);
209   ASSERT_EQ(a->sizes[0], k);
210   ASSERT_EQ(a->sizes[1], m);
211   ASSERT_EQ(a->strides[0], m + 1);
212   ASSERT_EQ(a->strides[1], 1);
213   for (int i = 0; i < k; ++i) {
214     for (int j = 0; j < m; ++j) {
215       EXPECT_EQ((a[{i, j}]), i * m + j);
216       EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
217     }
218   }
219   std::string moduleStr = R"mlir(
220   func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
221     %x = arith.constant 2 : index
222     %y = arith.constant 1 : index
223     %cst42 = arith.constant 42.0 : f32
224     memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
225     memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
226     return
227   }
228   )mlir";
229   DialectRegistry registry;
230   registerAllDialects(registry);
231   registerBuiltinDialectTranslation(registry);
232   registerLLVMDialectTranslation(registry);
233   MLIRContext context(registry);
234   OwningOpRef<ModuleOp> module =
235       parseSourceString<ModuleOp>(moduleStr, &context);
236   ASSERT_TRUE(!!module);
237   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
238   auto jitOrError = ExecutionEngine::create(*module);
239   ASSERT_TRUE(!!jitOrError);
240   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
241 
242   llvm::Error error = jit->invoke("rank2_memref", &*a, &*a);
243   ASSERT_TRUE(!error);
244   EXPECT_EQ(((*a)[1][2]), 42.);
245   EXPECT_EQ((a[{2, 1}]), 42.);
246 }
247 
248 // A helper function that will be called from the JIT
249 static void memrefMultiply(::StridedMemRefType<float, 2> *memref,
250                            int32_t coefficient) {
251   for (float &elt : *memref)
252     elt *= coefficient;
253 }
254 
255 // MSAN does not work with JIT.
256 #if __has_feature(memory_sanitizer)
257 #define MAYBE_JITCallback DISABLED_JITCallback
258 #else
259 #define MAYBE_JITCallback SKIP_WITHOUT_JIT(JITCallback)
260 #endif
261 TEST(NativeMemRefJit, MAYBE_JITCallback) {
262   constexpr int k = 2;
263   constexpr int m = 2;
264   int64_t shape[] = {k, m};
265   int64_t shapeAlloc[] = {k + 1, m + 1};
266   OwningMemRef<float, 2> a(shape, shapeAlloc);
267   int count = 1;
268   for (float &elt : *a)
269     elt = count++;
270 
271 #ifdef __s390__
272   std::string moduleStr = R"mlir(
273   func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32 {llvm.signext})  attributes { llvm.emit_c_interface }
274   func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32 {llvm.signext}) attributes { llvm.emit_c_interface } {
275     %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
276     call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
277     return
278   }
279   )mlir";
280 #else
281   std::string moduleStr = R"mlir(
282   func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32)  attributes { llvm.emit_c_interface }
283   func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
284     %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
285     call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
286     return
287   }
288   )mlir";
289 #endif
290 
291   DialectRegistry registry;
292   registerAllDialects(registry);
293   registerBuiltinDialectTranslation(registry);
294   registerLLVMDialectTranslation(registry);
295   MLIRContext context(registry);
296   auto module = parseSourceString<ModuleOp>(moduleStr, &context);
297   ASSERT_TRUE(!!module);
298   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
299   auto jitOrError = ExecutionEngine::create(*module);
300   ASSERT_TRUE(!!jitOrError);
301   auto jit = std::move(jitOrError.get());
302   // Define any extra symbols so they're available at runtime.
303   jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
304     llvm::orc::SymbolMap symbolMap;
305     symbolMap[interner("_mlir_ciface_callback")] = {
306         llvm::orc::ExecutorAddr::fromPtr(memrefMultiply),
307         llvm::JITSymbolFlags::Exported};
308     return symbolMap;
309   });
310 
311   int32_t coefficient = 3.;
312   llvm::Error error = jit->invoke("caller_for_callback", &*a, coefficient);
313   ASSERT_TRUE(!error);
314   count = 1;
315   for (float elt : *a)
316     ASSERT_EQ(elt, coefficient * count++);
317 }
318 
319 #endif // _WIN32
320