xref: /llvm-project/mlir/test/python/dialects/memref.py (revision 59eadcd28f787a98a2fd5f057beb3df7950654ee)
19f3f6d7bSStella Laurenzo# RUN: %PYTHON %s | FileCheck %s
29f3f6d7bSStella Laurenzo
import mlir.dialects.arith as arith
import mlir.dialects.memref as memref
import mlir.extras.types as T
import numpy as np
from mlir.dialects.memref import _infer_memref_subview_result_type
from mlir.ir import *
89f3f6d7bSStella Laurenzo
9a54f4eaeSMogball
def run(f):
    """Announce test *f* on stdout, execute it once, and hand it back.

    Intended for use as a decorator: the decorated test runs at definition
    time, and the printed ``TEST: <name>`` banner is what the per-test label
    directives in this file anchor on.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    return f
149f3f6d7bSStella Laurenzo
15a54f4eaeSMogball
# CHECK-LABEL: TEST: testSubViewAccessors
@run
def testSubViewAccessors():
    """Exercise the named operand/result accessors of a parsed ``memref.subview``.

    The parsed function body holds six index constants followed by the
    subview, which consumes them pairwise as offsets, sizes and strides.
    """
    ctx = Context()
    asm = r"""
    func.func @f1(%arg0: memref<?x?xf32>) {
      %0 = arith.constant 0 : index
      %1 = arith.constant 1 : index
      %2 = arith.constant 2 : index
      %3 = arith.constant 3 : index
      %4 = arith.constant 4 : index
      %5 = arith.constant 5 : index
      memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      return
    }
  """
    module = Module.parse(asm, ctx)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    # Operations 0-5 are the constants, so the subview is the seventh op.
    subview = entry_block.operations[6]

    assert subview.source == subview.operands[0]
    for operand_group in (subview.offsets, subview.sizes, subview.strides):
        assert len(operand_group) == 2
    assert subview.result == subview.results[0]

    # CHECK: SubViewOp
    print(type(subview).__name__)

    # Offsets, then sizes, then strides: each operand prints as the constant
    # that defines it, in the order the constants were declared.
    # CHECK: constant 0
    # CHECK: constant 1
    # CHECK: constant 2
    # CHECK: constant 3
    # CHECK: constant 4
    # CHECK: constant 5
    for operand_group in (subview.offsets, subview.sizes, subview.strides):
        print(operand_group[0])
        print(operand_group[1])
599f3f6d7bSStella Laurenzo
609f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: testCustomBuidlers
@run
def testCustomBuidlers():
    """Build a ``memref.load`` with the ``LoadOp`` builder and verify the module.

    The load is inserted before the function's terminator, reading the memref
    argument at the two index arguments.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
        return
      }
    """
        )
        func_op = module.body.operations[0]
        entry_block = func_op.regions[0].blocks[0]
        func_args = func_op.arguments
        with InsertionPoint.at_block_terminator(entry_block):
            # First argument is the memref; the remaining ones are the indices.
            memref.LoadOp(func_args[0], func_args[1:])

        # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
        # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
        print(module)
        assert module.operation.verify()
8183be8a74SMaksim Levental
8283be8a74SMaksim Levental
# CHECK-LABEL: TEST: testMemRefAttr
@run
def testMemRefAttr():
    """Emit a ``memref.global`` at module scope and print the resulting IR."""
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            global_type = T.memref(16, T.i32())
            memref.global_("objFifo_in0", global_type)
        # CHECK: memref.global @objFifo_in0 : memref<16xi32>
        print(module)
92404af14fSMaksim Levental
93404af14fSMaksim Levental
# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics
@run
def testSubViewOpInferReturnTypeSemantics():
    """Cover ``memref.subview`` result-type inference across operand kinds.

    Walks through Python-int operands, MLIR constant operands, a dynamically
    computed offset, the two error paths for dynamic sizes, and finally a
    source with a dynamic-offset layout where the result type is passed
    explicitly. Operations are created in a fixed order because the printed
    IR is matched line by line.
    """

    def index_const(value):
        # Shorthand for an `arith.constant` of index type.
        return arith.constant(T.index(), value)

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            base = memref.alloc(T.memref(10, 10, T.i32()), [], [])
            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
            print(base.owner)

            # Plain Python integers everywhere.
            static_view = memref.subview(base, [1, 1], [3, 3], [1, 1])
            assert static_view.owner.verify()
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(static_view.owner)

            # One offset given as an MLIR constant, the other as a Python int.
            mixed_view = memref.subview(
                base,
                [index_const(1), 1],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} =  memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(mixed_view.owner)

            # Both offsets given as MLIR constants.
            const_view = memref.subview(
                base,
                [index_const(3), index_const(4)],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} =  memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
            print(const_view.owner)

            # An offset produced by an addition is not a constant, so the
            # resulting layout offset is dynamic.
            offset_sum = arith.addi(index_const(3), index_const(4))
            dynamic_view = memref.subview(
                base,
                [offset_sum, 0],
                [3, 3],
                [1, 1],
            )
            # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(dynamic_view)

            # Calling the inference helper directly with a dynamic size.
            try:
                _infer_memref_subview_result_type(
                    base.type,
                    [index_const(3), index_const(4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as err:
                # CHECK: Only inferring from python or mlir integer constant is supported
                print(err)

            # Same operands through the public builder, without a result type.
            try:
                memref.subview(
                    base,
                    [index_const(3), index_const(4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as err:
                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
                print(err)

            # Source whose layout has a dynamic offset; the subview result
            # type is supplied explicitly.
            dynamic_layout = StridedLayoutAttr.get(
                ShapedType.get_dynamic_size(), [10, 1]
            )
            dynamic_base = memref.alloc(
                T.memref(
                    10,
                    10,
                    T.i32(),
                    layout=dynamic_layout,
                ),
                [],
                [index_const(42)],
            )
            # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
            print(dynamic_base.owner)
            explicit_view = memref.subview(
                dynamic_base,
                [1, 1],
                [3, 3],
                [1, 1],
                result_type=T.memref(3, 3, T.i32(), layout=dynamic_layout),
            )
            # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(explicit_view.owner)
181404af14fSMaksim Levental
182404af14fSMaksim Levental
# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing
@run
def testSubViewOpInferReturnTypeExtensiveSlicing():
    """Cross-check inferred subview layouts against equivalent NumPy slices.

    For every subview, the strides and offset of the inferred result layout
    are compared with the element strides and element offset of the same
    slice taken from a NumPy array of identical shape and an ``int32`` dtype.
    Nothing is printed; failures surface as assertion errors.
    """

    def check_strides_offset(subview_value, np_view):
        """Assert that the layout of *subview_value* matches *np_view*.

        Args:
            subview_value: result of ``memref.subview``; its type's layout
                must expose ``strides`` and ``offset``.
            np_view: NumPy view (not a copy) of the golden array taken with
                the equivalent slice, so that ``np_view.base`` is set.

        Raises:
            AssertionError: if the strides or the offset differ.
        """
        layout = subview_value.type.layout
        itemsize = np_view.dtype.itemsize
        # NumPy reports strides in bytes; convert them to element counts.
        golden_strides = (np.array(np_view.strides) // itemsize).tolist()
        # Distance from the start of the base buffer to the view, in elements.
        golden_offset = (
            np_view.ctypes.data - np_view.base.ctypes.data
        ) // itemsize

        actual = (layout.strides, layout.offset)
        expected = (golden_strides, golden_offset)
        assert actual == expected, (
            f"subview layout mismatch: got {actual}, expected {expected}"
        )

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            shape = (10, 22, 3, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])

            # Unit-stride windows of width one along every combination of
            # one, two or three dimensions.
            # fmt: off
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 0), (1, 22, 3, 44), (1, 1, 1, 1)), golden_mem[1:2, ...])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 0), (10, 1, 3, 44), (1, 1, 1, 1)), golden_mem[:, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 0), (10, 22, 1, 44), (1, 1, 1, 1)), golden_mem[:, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 0, 1), (10, 22, 3, 1), (1, 1, 1, 1)), golden_mem[:, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 1), (10, 1, 3, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 1), (1, 22, 3, 1), (1, 1, 1, 1)), golden_mem[1:2, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 0), (1, 1, 3, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, :])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem[:, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 0), (10, 1, 1, 44), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 0), (1, 22, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 1), (1, 1, 3, 1), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :])
            # fmt: on

            # A subview spanning the whole buffer with zero offsets and unit
            # strides keeps the default layout: the result type carries an
            # identity affine map rather than a strided layout attribute.
            rank = len(shape)
            full_view = memref.subview(
                mem1, (0, 0, 0, 0), (10, 22, 3, 44), (1, 1, 1, 1)
            )
            identity_layout = AffineMapAttr.get(
                AffineMap.get(
                    rank,
                    0,
                    [AffineDimExpr.get(dim) for dim in range(rank)],
                )
            )
            assert full_view.type.layout == identity_layout

            # Non-unit strides along one, two and three dimensions.
            shape = (7, 22, 30, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            # fmt: off
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 3, 44), (1, 2, 1, 1)), golden_mem[:, 0:22:2])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 44), (1, 2, 30, 1)), golden_mem[:, 0:22:2, 0:330:30])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400])
            # fmt: on

            # Two quadrants of a small 2-D buffer.
            shape = (8, 8)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            # fmt: off
            check_strides_offset(memref.subview(mem3, (0, 0), (4, 4), (1, 1)), golden_mem[0:4, 0:4])
            check_strides_offset(memref.subview(mem3, (4, 4), (4, 4), (1, 1)), golden_mem[4:8, 4:8])
            # fmt: on
252