xref: /llvm-project/mlir/test/python/dialects/memref.py (revision 59eadcd28f787a98a2fd5f057beb3df7950654ee)
1# RUN: %PYTHON %s | FileCheck %s
2
3import mlir.dialects.arith as arith
4import mlir.dialects.memref as memref
5import mlir.extras.types as T
6from mlir.dialects.memref import _infer_memref_subview_result_type
7from mlir.ir import *
8
9
def run(f):
    """Announce the test ``f`` on stdout, execute it, and hand it back.

    A blank line followed by ``TEST: <function name>`` is printed before the
    call, so everything ``f`` prints appears after its own banner. Returning
    ``f`` allows this function to be applied as a decorator without rebinding
    the decorated name to ``None``.
    """
    banner = f"\nTEST: {f.__name__}"
    print(banner)
    f()
    return f
14
15
16# CHECK-LABEL: TEST: testSubViewAccessors
@run
def testSubViewAccessors():
    """Check the named accessors of a ``memref.subview`` parsed from text.

    The parsed function body holds six ``arith.constant`` ops followed by a
    single subview that consumes them pairwise as offsets, sizes and strides.
    The accessors are compared against the generic operand/result lists, and
    each accessed value is then printed for FileCheck to match in order.
    """
    context = Context()
    parsed = Module.parse(
        r"""
    func.func @f1(%arg0: memref<?x?xf32>) {
      %0 = arith.constant 0 : index
      %1 = arith.constant 1 : index
      %2 = arith.constant 2 : index
      %3 = arith.constant 3 : index
      %4 = arith.constant 4 : index
      %5 = arith.constant 5 : index
      memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      return
    }
  """,
        context,
    )
    entry_block = parsed.body.operations[0].regions[0].blocks[0]
    # Ops 0-5 are the constants, so the subview sits at index 6.
    subview = entry_block.operations[6]

    # Named accessors must agree with the positional operand/result lists.
    assert subview.source == subview.operands[0]
    assert subview.result == subview.results[0]
    for operand_group in (subview.offsets, subview.sizes, subview.strides):
        assert len(operand_group) == 2

    # CHECK: SubViewOp
    print(type(subview).__name__)

    offsets = subview.offsets
    # CHECK: constant 0
    print(offsets[0])
    # CHECK: constant 1
    print(offsets[1])

    sizes = subview.sizes
    # CHECK: constant 2
    print(sizes[0])
    # CHECK: constant 3
    print(sizes[1])

    strides = subview.strides
    # CHECK: constant 4
    print(strides[0])
    # CHECK: constant 5
    print(strides[1])
59
60
61# CHECK-LABEL: TEST: testCustomBuidlers
@run
def testCustomBuidlers():
    """Build a ``memref.load`` through the Python ``LoadOp`` builder.

    A function taking one memref and two index arguments is parsed, a load of
    the memref indexed by the two index arguments is inserted in front of the
    block terminator, and the module is printed and then verified.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
        return
      }
    """
        )
        func_op = module.body.operations[0]
        entry_block = func_op.regions[0].blocks[0]
        with InsertionPoint.at_block_terminator(entry_block):
            func_args = func_op.arguments
            # First argument is the memref; the remaining ones index into it.
            source, indices = func_args[0], func_args[1:]
            memref.LoadOp(source, indices)

        # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
        # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
        print(module)
        assert module.operation.verify()
81
82
83# CHECK-LABEL: TEST: testMemRefAttr
@run
def testMemRefAttr():
    """Create a ``memref.global`` from a symbol name and a memref type.

    The global is emitted into a fresh module and the module is printed so
    FileCheck can match the resulting declaration.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            buffer_type = T.memref(16, T.i32())
            memref.global_("objFifo_in0", buffer_type)
        # CHECK: memref.global @objFifo_in0 : memref<16xi32>
        print(module)
92
93
94# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics
@run
def testSubViewOpInferReturnTypeSemantics():
    """Exercise result-type inference of the ``memref.subview`` builder.

    Covered, in order: all-static operands, offsets given as index constants
    (alone or mixed with Python ints), an offset computed by ``arith.addi``,
    two rejected calls whose ``ValueError`` messages are printed, and a
    subview of a dynamically-offset allocation with an explicit result type.

    Ops are created in a fixed order on purpose: the printed SSA names that
    FileCheck matches depend on it.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            static_alloc = memref.alloc(T.memref(10, 10, T.i32()), [], [])
            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
            print(static_alloc.owner)

            # Everything static: offset and strides are fully inferred.
            all_static = memref.subview(static_alloc, [1, 1], [3, 3], [1, 1])
            assert all_static.owner.verify()
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(all_static.owner)

            # One offset is an index constant, the other a Python int.
            const_one = arith.constant(T.index(), 1)
            mixed_offsets = memref.subview(
                static_alloc, [const_one, 1], [3, 3], [1, 1]
            )
            # CHECK: %{{.*}} =  memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(mixed_offsets.owner)

            # Both offsets are index constants.
            const_three = arith.constant(T.index(), 3)
            const_four = arith.constant(T.index(), 4)
            constant_offsets = memref.subview(
                static_alloc, [const_three, const_four], [3, 3], [1, 1]
            )
            # CHECK: %{{.*}} =  memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
            print(constant_offsets.owner)

            # An offset produced by arith.addi is not a constant, so the
            # expected result offset below is dynamic.
            lhs = arith.constant(T.index(), 3)
            rhs = arith.constant(T.index(), 4)
            computed_offset = arith.addi(lhs, rhs)
            computed_subview = memref.subview(
                static_alloc, [computed_offset, 0], [3, 3], [1, 1]
            )
            # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(computed_subview)

            dynamic = ShapedType.get_dynamic_size()

            # Calling the inference helper directly with a dynamic size.
            try:
                _infer_memref_subview_result_type(
                    static_alloc.type,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [dynamic, 3],
                    [1, 1],
                )
            except ValueError as err:
                # CHECK: Only inferring from python or mlir integer constant is supported
                print(err)

            # Same operands through the builder, without a result type.
            try:
                memref.subview(
                    static_alloc,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [dynamic, 3],
                    [1, 1],
                )
            except ValueError as err:
                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
                print(err)

            # NOTE(review): the dynamic-size sentinel is passed as the layout
            # offset here - confirm it is interchangeable with the dynamic
            # stride/offset sentinel.
            layout = StridedLayoutAttr.get(dynamic, [10, 1])
            dynamic_alloc_type = T.memref(10, 10, T.i32(), layout=layout)
            dynamic_alloc = memref.alloc(
                dynamic_alloc_type,
                [],
                [arith.constant(T.index(), 42)],
            )
            # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
            print(dynamic_alloc.owner)

            # The source offset is dynamic, so the result type is passed
            # explicitly.
            explicit_result = memref.subview(
                dynamic_alloc,
                [1, 1],
                [3, 3],
                [1, 1],
                result_type=T.memref(3, 3, T.i32(), layout=layout),
            )
            # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(explicit_result.owner)
181
182
183# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing
@run
def testSubViewOpInferReturnTypeExtensiveSlicing():
    """Compare inferred subview layouts against equivalent numpy views.

    For each case a ``memref.subview`` is built from static offsets, sizes
    and strides, and the strides/offset of its inferred result layout are
    asserted equal to those of a numpy view sliced the same way (numpy's
    byte counts are converted to element counts first).
    """
    # Imported here because this file has no module-level numpy import;
    # otherwise ``np`` would resolve only if a wildcard import leaked it.
    import numpy as np

    def check_strides_offset(subview, np_view):
        """Assert that ``subview``'s layout matches ``np_view``'s geometry."""
        layout = subview.type.layout
        itemsize = np_view.dtype.itemsize
        # numpy reports strides and addresses in bytes; convert to elements.
        golden_strides = (np.array(np_view.strides) // itemsize).tolist()
        golden_offset = (
            np_view.ctypes.data - np_view.base.ctypes.data
        ) // itemsize

        assert (layout.strides, layout.offset) == (golden_strides, golden_offset)

    def check_cases(source, cases):
        """Build one subview of ``source`` per case and check its layout."""
        for offsets, sizes, strides, np_view in cases:
            check_strides_offset(
                memref.subview(source, offsets, sizes, strides), np_view
            )

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            shape = (10, 22, 3, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            unit = (1, 1, 1, 1)

            # Unit-stride slices that narrow one, two or three dimensions.
            # fmt: off
            check_cases(mem1, [
                ((1, 0, 0, 0), (1, 22, 3, 44), unit, golden_mem[1:2, ...]),
                ((0, 1, 0, 0), (10, 1, 3, 44), unit, golden_mem[:, 1:2]),
                ((0, 0, 1, 0), (10, 22, 1, 44), unit, golden_mem[:, :, 1:2]),
                ((0, 0, 0, 1), (10, 22, 3, 1), unit, golden_mem[:, :, :, 1:2]),
                ((0, 1, 0, 1), (10, 1, 3, 1), unit, golden_mem[:, 1:2, :, 1:2]),
                ((1, 0, 0, 1), (1, 22, 3, 1), unit, golden_mem[1:2, :, :, 1:2]),
                ((1, 1, 0, 0), (1, 1, 3, 44), unit, golden_mem[1:2, 1:2, :, :]),
                ((0, 0, 1, 1), (10, 22, 1, 1), unit, golden_mem[:, :, 1:2, 1:2]),
                ((0, 1, 1, 0), (10, 1, 1, 44), unit, golden_mem[:, 1:2, 1:2, :]),
                ((1, 0, 1, 0), (1, 22, 1, 44), unit, golden_mem[1:2, :, 1:2, :]),
                ((1, 1, 0, 1), (1, 1, 3, 1), unit, golden_mem[1:2, 1:2, :, 1:2]),
                ((1, 0, 1, 1), (1, 22, 1, 1), unit, golden_mem[1:2, :, 1:2, 1:2]),
                ((0, 1, 1, 1), (10, 1, 1, 1), unit, golden_mem[:, 1:2, 1:2, 1:2]),
                ((1, 1, 1, 0), (1, 1, 1, 44), unit, golden_mem[1:2, 1:2, 1:2, :]),
            ])
            # fmt: on

            # Zero offsets, full sizes and unit strides: the result layout is
            # expected to be the 4-d identity affine map rather than a strided
            # layout attribute.
            identity_layout = AffineMapAttr.get(
                AffineMap.get(4, 0, [AffineDimExpr.get(dim) for dim in range(4)])
            )
            full_view = memref.subview(
                mem1, (0, 0, 0, 0), (10, 22, 3, 44), (1, 1, 1, 1)
            )
            assert full_view.type.layout == identity_layout

            # Non-unit strides.
            shape = (7, 22, 30, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            origin = (0, 0, 0, 0)
            # fmt: off
            check_cases(mem2, [
                (origin, (7, 11, 3, 44), (1, 2, 1, 1), golden_mem[:, 0:22:2]),
                (origin, (7, 11, 11, 44), (1, 2, 30, 1), golden_mem[:, 0:22:2, 0:330:30]),
                (origin, (7, 11, 11, 11), (1, 2, 30, 400), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400]),
            ])
            # fmt: on

            # Two 4x4 tiles at opposite corners of an 8x8 buffer.
            shape = (8, 8)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            # fmt: off
            check_cases(mem3, [
                ((0, 0), (4, 4), (1, 1), golden_mem[0:4, 0:4]),
                ((4, 4), (4, 4), (1, 1), golden_mem[4:8, 4:8]),
            ])
            # fmt: on
252