xref: /llvm-project/mlir/test/python/integration/dialects/linalg/opsrun.py (revision eb6c4197d5263ed2e086925b2b2f032a19442d2b)
1# RUN: %PYTHON %s 2>&1 | FileCheck %s
2
3import ctypes
4import sys
5from mlir.ir import *
6from mlir.dialects import builtin
7from mlir.dialects import func
8from mlir.dialects import linalg
9from mlir.passmanager import *
10from mlir.execution_engine import *
11
12from mlir.dialects.linalg.opdsl.lang import *
13
14
15# Log everything to stderr and flush so that we have a unified stream to match
16# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print `args` to stderr and flush immediately.

    Flushing after every call keeps Python output interleaved in order with
    whatever MLIR itself writes to stderr.
    """
    print(*args, file=sys.stderr, flush=True)
20
21
# MLIR driver appended to the elementwise test modules by `transform`.
# `@main` fills a 0-d `lhs` with 1.0, a 4x8 `rhs` with 2.0 and two 4x8 outputs
# with 0.0, calls `@elemwise_exp_add_on_buffers` and
# `@elemwise_log_mul_on_buffers` (one output each), and returns the f32 sum of
# element [0, 0] of both outputs.
elemwise_boiler = """
func.func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0.0 : f32
  %v1 = arith.constant 1.0 : f32
  %v2 = arith.constant 2.0 : f32

  %lhs = memref.alloc() : memref<f32>
  %rhs = memref.alloc() : memref<4x8xf32>
  %O0 = memref.alloc() : memref<4x8xf32>
  %O1 = memref.alloc() : memref<4x8xf32>
  linalg.fill ins(%v1 : f32) outs(%lhs : memref<f32>)
  linalg.fill ins(%v2 : f32) outs(%rhs : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%O0 : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%O1 : memref<4x8xf32>)

  call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
  call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
  %res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32>

  %0 = arith.addf %res0, %res1 : f32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""
52
# MLIR driver appended to the fill test modules by `transform`.
# `@main` allocates 0-d, 1-d (16) and 2-d (4x16) i32 buffers, calls
# `@fill_0d_on_buffers`, `@fill_1d_on_buffers` and `@fill_2d_on_buffers` with
# the f32 values 1.0, 2.0 and 3.0 respectively, then returns the i32 sum of
# one element loaded from each buffer.
fill_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O0 = memref.alloc() : memref<i32>
  %O1 = memref.alloc() : memref<16xi32>
  %O2 = memref.alloc() : memref<4x16xi32>

  %val0 = arith.constant 1.0 : f32
  %val1 = arith.constant 2.0 : f32
  %val2 = arith.constant 3.0 : f32

  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[] : memref<i32>
  %c8 = arith.constant 8 : index
  %res1 = memref.load %O1[%c8] : memref<16xi32>
  %c2 = arith.constant 2 : index
  %res2 = memref.load %O2[%c2, %c8] : memref<4x16xi32>

  %0 = arith.addi %res0, %res1 : i32
  %1 = arith.addi %0, %res2 : i32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %1 : i32
}
"""
81
# MLIR driver appended to the fill_rng test modules by `transform`.
# `@main` allocates a 4x16 i32 buffer, calls `@fill_rng_on_buffers` with
# min = -1000.0, max = 1000.0 (f64) and seed = 42 (i32), and returns the
# element at [0, 0].
fill_rng_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O = memref.alloc() : memref<4x16xi32>
  %min = arith.constant -1000.0 : f64
  %max = arith.constant 1000.0 : f64
  %seed = arith.constant 42 : i32

  call @fill_rng_on_buffers(%min, %max, %seed, %O) :
    (f64, f64, i32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
99
# MLIR driver for a convolution test.
# `@main` fills a 1x4x16x1 f64 input with 1.0, a 2x2x1 f64 filter with 2.0 and
# a 1x2x4x1 i32 output with 0, calls `@conv_on_buffers`, and returns the output
# element at [0, 0, 0, 0].
# NOTE(review): no Python function named `conv_on_buffers` is defined in the
# portion of the file visible here -- confirm where this boilerplate is used.
conv_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v1 = arith.constant 1.0 : f64
  %v2 = arith.constant 2.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %filter = memref.alloc() : memref<2x2x1xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
  linalg.fill ins(%v2 : f64) outs(%filter : memref<2x2x1xf64>)
  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)

  call @conv_on_buffers(%input, %filter, %output) :
    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
123
# MLIR driver appended to the pooling test modules by `transform`.
# `@main` fills a 1x4x16x1 f64 input with 1.0, then overwrites three elements:
# 42.0 at [0, 0, 0, 0], 77.0 at [0, 0, 1, 0] and -13.0 at [0, 1, 0, 0]. The
# 2x2 `shape` operand is filled with 1.0 and the 1x2x4x1 i32 output with 0.
# After calling `@pooling_on_buffers` it returns the output element at
# [0, 0, 0, 0].
pooling_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v42 = arith.constant 42.0 : f64
  %v77 = arith.constant 77.0 : f64
  %v-13 = arith.constant -13.0 : f64
  %v1 = arith.constant 1.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %shape = memref.alloc() : memref<2x2xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
  linalg.fill ins(%v1 : f64) outs(%shape : memref<2x2xf64>)
  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
  memref.store %v-13, %input[%c0, %c1, %c0, %c0] : memref<1x4x16x1xf64>

  call @pooling_on_buffers(%input, %shape, %output) :
    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()

  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
155
156
def transform(module, boilerplate):
    """Combine `module` with `boilerplate` and lower the result.

    The top-level operations of `module` are printed, joined with newlines and
    concatenated with the `boilerplate` MLIR text; the combined text is parsed
    into a fresh module, the pass pipeline below is run on it, and that new
    module is returned. `module` itself is not modified.
    """
    # TODO: Allow cloning functions from one module to another.
    # Atm we have to resort to string concatenation.
    top_level_ops = module.operation.regions[0].blocks[0].operations
    combined_text = "\n".join(map(str, top_level_ops)) + boilerplate
    lowered = Module.parse(combined_text)

    # Passes are added one by one, in exactly this order.
    pipeline = (
        "func.func(convert-linalg-to-loops)",
        "func.func(lower-affine)",
        "func.func(convert-math-to-llvm)",
        "func.func(convert-scf-to-cf)",
        "func.func(arith-expand)",
        "func.func(memref-expand)",
        "convert-vector-to-llvm",
        "finalize-memref-to-llvm",
        "convert-func-to-llvm",
        "convert-arith-to-llvm",
        "convert-cf-to-llvm",
        "reconcile-unrealized-casts",
    )
    pass_manager = PassManager("builtin.module")
    for stage in pipeline:
        pass_manager.add(stage)
    pass_manager.run(lowered.operation)
    return lowered
178
179
def test_elemwise_builtin():
    """Run the elementwise unary/binary linalg builders end to end.

    Defines `elemwise_exp_add_on_buffers` (builders called with their default
    functions) and `elemwise_log_mul_on_buffers` (`UnaryFn.log` and
    `BinaryFn.mul`), lowers them together with `elemwise_boiler` through
    `transform`, invokes `main` in the execution engine and logs the returned
    f32 so FileCheck can match it.
    """
    # Context and location are only needed as ambient state, so neither is
    # bound to a name.
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_exp_add_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out])
                linalg.elemwise_binary(out, rhs, outs=[out])

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_log_mul_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
                linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)

        execution_engine = ExecutionEngine(transform(module, elemwise_boiler))

        # TODO: FFI-based solution to allow testing and printing with python code.
        # Prepare arguments: one result f32.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        res = c_float_p(-1.0)
        execution_engine.invoke("main", res)

        log("RESULT: ", res[0])
        # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
        # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
        # CHECK: RESULT: 4.71828


test_elemwise_builtin()
221
222
def test_elemwise_generic():
    """Run the elementwise builders with `emit_generic=True` end to end.

    Same scenario as the builtin elementwise test, except that every
    `linalg.elemwise_unary` / `linalg.elemwise_binary` call is passed
    `emit_generic=True`. The functions are lowered together with
    `elemwise_boiler` through `transform`, `main` is invoked in the execution
    engine and the returned f32 is logged so FileCheck can match it.
    """
    # Context and location are only needed as ambient state, so neither is
    # bound to a name.
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_exp_add_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
                linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_log_mul_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(
                    lhs, outs=[out], fun=UnaryFn.log, emit_generic=True
                )
                linalg.elemwise_binary(
                    out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True
                )

        execution_engine = ExecutionEngine(transform(module, elemwise_boiler))

        # TODO: FFI-based solution to allow testing and printing with python code.
        # Prepare arguments: one result f32.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        res = c_float_p(-1.0)
        execution_engine.invoke("main", res)

        log("RESULT: ", res[0])
        # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
        # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
        # CHECK: RESULT: 4.71828


test_elemwise_generic()
268
269
def test_fill_builtin():
    """Fill 0-d, 1-d and 2-d i32 buffers from f32 scalars and run the result.

    The three `fill_*_on_buffers` functions are lowered together with
    `fill_boiler` through `transform`; `main` is invoked in the execution
    engine and its i32 return value is logged for FileCheck.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        i32 = IntegerType.get_signless(32)
        rank0_buffer = MemRefType.get([], i32)
        rank1_buffer = MemRefType.get([16], i32)
        rank2_buffer = MemRefType.get([4, 16], i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f32, rank0_buffer)
            def fill_0d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

            @func.FuncOp.from_py_func(f32, rank1_buffer)
            def fill_1d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

            @func.FuncOp.from_py_func(f32, rank2_buffer)
            def fill_2d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

        lowered = transform(module, fill_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # CHECK: RESULT: 6


test_fill_builtin()
303
304
def test_fill_generic():
    """Fill 0-d, 1-d and 2-d i32 buffers using `emit_generic=True`.

    Mirrors the builtin fill test, but every `linalg.fill` call is passed
    `emit_generic=True`. The functions are lowered together with `fill_boiler`
    through `transform`; `main` is invoked in the execution engine and its
    i32 return value is logged for FileCheck.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        i32 = IntegerType.get_signless(32)
        rank0_buffer = MemRefType.get([], i32)
        rank1_buffer = MemRefType.get([16], i32)
        rank2_buffer = MemRefType.get([4, 16], i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f32, rank0_buffer)
            def fill_0d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(f32, rank1_buffer)
            def fill_1d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(f32, rank2_buffer)
            def fill_2d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

        lowered = transform(module, fill_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # CHECK: RESULT: 6


test_fill_generic()
338
339
def test_fill_rng_builtin():
    """Run `linalg.fill_rng_2d` on a 4x16 i32 buffer and log one element.

    `fill_rng_on_buffers` is lowered together with `fill_rng_boiler` through
    `transform`; `main` is invoked in the execution engine and its i32 return
    value is logged for FileCheck.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        out_buffer = MemRefType.get((4, 16), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f64, f64, i32, out_buffer)
            def fill_rng_on_buffers(min, max, seed, out):
                linalg.fill_rng_2d(min, max, seed, outs=[out])

        lowered = transform(module, fill_rng_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # CHECK: RESULT: -480


test_fill_rng_builtin()
365
366
def test_fill_rng_generic():
    """Run `linalg.fill_rng_2d` with `emit_generic=True` and log one element.

    Mirrors the builtin fill_rng test, but passes `emit_generic=True` to the
    op builder. `fill_rng_on_buffers` is lowered together with
    `fill_rng_boiler` through `transform`; `main` is invoked in the execution
    engine and its i32 return value is logged for FileCheck.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        out_buffer = MemRefType.get((4, 16), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f64, f64, i32, out_buffer)
            def fill_rng_on_buffers(min, max, seed, out):
                linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)

        lowered = transform(module, fill_rng_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # CHECK: RESULT: -480


test_fill_rng_generic()
392
393
def test_max_pooling_builtin():
    """Run `linalg.pooling_nhwc_max` with explicit strides and dilations.

    `pooling_on_buffers` is lowered together with `pooling_boiler` through
    `transform`; `main` is invoked in the execution engine and its i32 return
    value is logged for FileCheck.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        input_type = MemRefType.get((1, 4, 16, 1), f64)
        window_type = MemRefType.get((2, 2), f64)
        output_type = MemRefType.get((1, 2, 4, 1), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(input_type, window_type, output_type)
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_max(
                    input, shape, outs=[output], strides=[2, 4], dilations=[1, 2]
                )

        lowered = transform(module, pooling_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # 77 is not selected due to the dilation 2 in the second dimension.
        # CHECK: RESULT: 42


test_max_pooling_builtin()
426
427
def test_max_pooling_generic():
    """Run `linalg.pooling_nhwc_max` with `emit_generic=True`.

    Mirrors the builtin max-pooling test (same strides and dilations), but
    passes `emit_generic=True` to the op builder. `pooling_on_buffers` is
    lowered together with `pooling_boiler` through `transform`; `main` is
    invoked in the execution engine and its i32 return value is logged for
    FileCheck.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        input_type = MemRefType.get((1, 4, 16, 1), f64)
        window_type = MemRefType.get((2, 2), f64)
        output_type = MemRefType.get((1, 2, 4, 1), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(input_type, window_type, output_type)
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_max(
                    input,
                    shape,
                    outs=[output],
                    strides=[2, 4],
                    dilations=[1, 2],
                    emit_generic=True,
                )

        lowered = transform(module, pooling_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # 77 is not selected due to the dilation 2 in the second dimension.
        # CHECK: RESULT: 42


test_max_pooling_generic()
465
466
def test_min_pooling_builtin():
    """Run `linalg.pooling_nhwc_min` with explicit strides only.

    `pooling_on_buffers` is lowered together with `pooling_boiler` through
    `transform`; `main` is invoked in the execution engine and its i32 return
    value is logged for FileCheck.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        input_type = MemRefType.get((1, 4, 16, 1), f64)
        window_type = MemRefType.get((2, 2), f64)
        output_type = MemRefType.get((1, 2, 4, 1), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(input_type, window_type, output_type)
            def pooling_on_buffers(input, shape, output):
                # Set the strides and use the default dilations.
                linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4])

        lowered = transform(module, pooling_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # CHECK: RESULT: -13


test_min_pooling_builtin()
497
498
def test_min_pooling_generic():
    """Run `linalg.pooling_nhwc_min` with `emit_generic=True`.

    Mirrors the builtin min-pooling test (explicit strides only), but passes
    `emit_generic=True` to the op builder. `pooling_on_buffers` is lowered
    together with `pooling_boiler` through `transform`; `main` is invoked in
    the execution engine and its i32 return value is logged for FileCheck.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        input_type = MemRefType.get((1, 4, 16, 1), f64)
        window_type = MemRefType.get((2, 2), f64)
        output_type = MemRefType.get((1, 2, 4, 1), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(input_type, window_type, output_type)
            def pooling_on_buffers(input, shape, output):
                # Set the strides and use the default dilations.
                linalg.pooling_nhwc_min(
                    input, shape, outs=[output], strides=[2, 4], emit_generic=True
                )

        lowered = transform(module, pooling_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result)

        log("RESULT: ", result[0])
        # CHECK: RESULT: -13


test_min_pooling_generic()
531