// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s

// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>

// CHECK-LABEL: func @alloc() {
func.func @alloc() {
^bb0:
  // Test simple alloc.
  // CHECK: %{{.*}} = memref.alloc() : memref<1024x64xf32, 1>
  %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>

  %c0 = "arith.constant"() {value = 0: index} : () -> index
  %c1 = "arith.constant"() {value = 1: index} : () -> index

  // Test alloc with dynamic dimensions.
  // CHECK: %{{.*}} = memref.alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32, 1>
  %1 = memref.alloc(%c0, %c1) : memref<?x?xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>

  // Test alloc with no dynamic dimensions and one symbol.
  // CHECK: %{{.*}} = memref.alloc()[%{{.*}}] : memref<2x4xf32, #[[$MAP]], 1>
  %2 = memref.alloc()[%c0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1>

  // Test alloc with dynamic dimensions and one symbol.
  // CHECK: %{{.*}} = memref.alloc(%{{.*}})[%{{.*}}] : memref<2x?xf32, #[[$MAP]], 1>
  %3 = memref.alloc(%c1)[%c0] : memref<2x?xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1>

  // Alloc with no mappings.
  // b/116054838 Parser crash while parsing ill-formed AllocOp
  // CHECK: %{{.*}} = memref.alloc() : memref<2xi32>
  %4 = memref.alloc() : memref<2 x i32>

  // CHECK:   return
  return
}

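// Supplementary sketch: like memref.alloca below, memref.alloc also accepts
// an optional alignment attribute (illustrative; values chosen arbitrarily).
// CHECK-LABEL: func @alloc_aligned() {
func.func @alloc_aligned() {
  // CHECK: %{{.*}} = memref.alloc() {alignment = 32 : i64} : memref<64xf32>
  %0 = memref.alloc() {alignment = 32 : i64} : memref<64xf32>
  return
}
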
// CHECK-LABEL: func @alloca() {
func.func @alloca() {
^bb0:
  // Test simple alloca.
  // CHECK: %{{.*}} = memref.alloca() : memref<1024x64xf32, 1>
  %0 = memref.alloca() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>

  %c0 = "arith.constant"() {value = 0: index} : () -> index
  %c1 = "arith.constant"() {value = 1: index} : () -> index

  // Test alloca with dynamic dimensions.
  // CHECK: %{{.*}} = memref.alloca(%{{.*}}, %{{.*}}) : memref<?x?xf32, 1>
  %1 = memref.alloca(%c0, %c1) : memref<?x?xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>

  // Test alloca with no dynamic dimensions and one symbol.
  // CHECK: %{{.*}} = memref.alloca()[%{{.*}}] : memref<2x4xf32, #[[$MAP]], 1>
  %2 = memref.alloca()[%c0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1>

  // Test alloca with dynamic dimensions and one symbol.
  // CHECK: %{{.*}} = memref.alloca(%{{.*}})[%{{.*}}] : memref<2x?xf32, #[[$MAP]], 1>
  %3 = memref.alloca(%c1)[%c0] : memref<2x?xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1>

  // Alloca with no mappings, but with alignment.
  // CHECK: %{{.*}} = memref.alloca() {alignment = 64 : i64} : memref<2xi32>
  %4 = memref.alloca() {alignment = 64} : memref<2 x i32>

  return
}

// CHECK-LABEL: func @dealloc() {
func.func @dealloc() {
^bb0:
  // CHECK: %{{.*}} = memref.alloc() : memref<1024x64xf32>
  %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 0>

  // CHECK: memref.dealloc %{{.*}} : memref<1024x64xf32>
  memref.dealloc %0 : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 0>
  return
}

// CHECK-LABEL: func @load_store
func.func @load_store() {
^bb0:
  // CHECK: %{{.*}} = memref.alloc() : memref<1024x64xf32, 1>
  %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>

  %1 = arith.constant 0 : index
  %2 = arith.constant 1 : index

  // CHECK: %{{.*}} = memref.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x64xf32, 1>
  %3 = memref.load %0[%1, %2] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>

  // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x64xf32, 1>
  memref.store %3, %0[%1, %2] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>

  return
}

// CHECK-LABEL: func @dma_ops()
func.func @dma_ops() {
  %c0 = arith.constant 0 : index
  %stride = arith.constant 32 : index
  %elt_per_stride = arith.constant 16 : index

  %A = memref.alloc() : memref<256 x f32, affine_map<(d0) -> (d0)>, 0>
  %Ah = memref.alloc() : memref<256 x f32, affine_map<(d0) -> (d0)>, 1>
  %tag = memref.alloc() : memref<1 x f32>

  %num_elements = arith.constant 256 : index

  memref.dma_start %A[%c0], %Ah[%c0], %num_elements, %tag[%c0] : memref<256 x f32>, memref<256 x f32, 1>, memref<1 x f32>
  memref.dma_wait %tag[%c0], %num_elements : memref<1 x f32>
  // CHECK: dma_start %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}}[%{{.*}}] : memref<256xf32>, memref<256xf32, 1>, memref<1xf32>
  // CHECK-NEXT:  dma_wait %{{.*}}[%{{.*}}], %{{.*}} : memref<1xf32>

  // DMA with strides
  memref.dma_start %A[%c0], %Ah[%c0], %num_elements, %tag[%c0], %stride, %elt_per_stride : memref<256 x f32>, memref<256 x f32, 1>, memref<1 x f32>
  memref.dma_wait %tag[%c0], %num_elements : memref<1 x f32>
  // CHECK-NEXT:  dma_start %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref<256xf32>, memref<256xf32, 1>, memref<1xf32>
  // CHECK-NEXT:  dma_wait %{{.*}}[%{{.*}}], %{{.*}} : memref<1xf32>

  return
}

// CHECK-LABEL: func @memref_reinterpret_cast
func.func @memref_reinterpret_cast(%in: memref<?xf32>)
    -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
  %c0 = arith.constant 0 : index
  %c10 = arith.constant 10 : index
  %out = memref.reinterpret_cast %in to
           offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
           : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
  return %out : memref<10x?xf32, strided<[?, 1], offset: ?>>
}

// CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes
func.func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>)
    -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
  %out = memref.reinterpret_cast %in to
           offset: [1], sizes: [10, 10], strides: [1, 1]
           : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
  return %out : memref<10x?xf32, strided<[?, 1], offset: ?>>
}

// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset
func.func @memref_reinterpret_cast_dynamic_offset(%in: memref<?xf32>, %offset: index)
    -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
  %out = memref.reinterpret_cast %in to
           offset: [%offset], sizes: [10, 10], strides: [1, 1]
           : memref<?xf32> to memref<10x?xf32, strided<[?, 1], offset: ?>>
  return %out : memref<10x?xf32, strided<[?, 1], offset: ?>>
}

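// Supplementary sketch: with fully static offset, sizes, and strides the
// result layout can be spelled statically as well (illustrative; values
// chosen arbitrarily).
// CHECK-LABEL: func @memref_reinterpret_cast_fully_static
func.func @memref_reinterpret_cast_fully_static(%in: memref<?xf32>)
    -> memref<10x10xf32, strided<[10, 1], offset: 5>> {
  %out = memref.reinterpret_cast %in to
           offset: [5], sizes: [10, 10], strides: [10, 1]
           : memref<?xf32> to memref<10x10xf32, strided<[10, 1], offset: 5>>
  return %out : memref<10x10xf32, strided<[10, 1], offset: 5>>
}
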
// CHECK-LABEL: func @memref_reshape(
func.func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
         %shape2: memref<2xi32>, %shape3: memref<?xi32>) -> memref<*xf32> {
  %dyn_vec = memref.reshape %unranked(%shape1)
               : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
  %dyn_mat = memref.reshape %dyn_vec(%shape2)
               : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
  %new_unranked = memref.reshape %dyn_mat(%shape3)
               : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
  return %new_unranked : memref<*xf32>
}

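// Supplementary sketch (following the reshape op's documented form): fully
// static source and result shapes are accepted as well.
// CHECK-LABEL: func @memref_reshape_static
func.func @memref_reshape_static(%src: memref<4x1xf32>, %shape: memref<1xi32>)
    -> memref<4xf32> {
  %dst = memref.reshape %src(%shape)
           : (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>
  return %dst : memref<4xf32>
}
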
// CHECK-LABEL: memref.global @memref0 : memref<2xf32>
memref.global @memref0 : memref<2xf32>

// CHECK-LABEL: memref.global constant @memref1 : memref<2xf32> = dense<[0.000000e+00, 1.000000e+00]>
memref.global constant @memref1 : memref<2xf32> = dense<[0.0, 1.0]>

// CHECK-LABEL: memref.global @memref2 : memref<2xf32> = uninitialized
memref.global @memref2 : memref<2xf32> = uninitialized

// CHECK-LABEL: memref.global "private" @memref3 : memref<2xf32> = uninitialized
memref.global "private" @memref3 : memref<2xf32> = uninitialized

// CHECK-LABEL: memref.global "private" constant @memref4 : memref<2xf32> = uninitialized
memref.global "private" constant @memref4 : memref<2xf32> = uninitialized

// CHECK-LABEL: func @read_global_memref
func.func @read_global_memref() {
  %0 = memref.get_global @memref0 : memref<2xf32>
  return
}

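// Supplementary sketch: the result of get_global is an ordinary memref, so it
// composes with memref.load as usual (illustrative; reads @memref1 above).
// CHECK-LABEL: func @read_global_memref_elt
func.func @read_global_memref_elt() -> f32 {
  %0 = memref.get_global @memref1 : memref<2xf32>
  %c0 = arith.constant 0 : index
  %1 = memref.load %0[%c0] : memref<2xf32>
  return %1 : f32
}
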
// CHECK-LABEL: func @memref_copy
func.func @memref_copy() {
  %0 = memref.alloc() : memref<2xf32>
  %1 = memref.cast %0 : memref<2xf32> to memref<*xf32>
  %2 = memref.alloc() : memref<2xf32>
  %3 = memref.cast %2 : memref<2xf32> to memref<*xf32>
  memref.copy %1, %3 : memref<*xf32> to memref<*xf32>
  return
}

// CHECK-LABEL: func @memref_dealloc
func.func @memref_dealloc() {
  %0 = memref.alloc() : memref<2xf32>
  %1 = memref.cast %0 : memref<2xf32> to memref<*xf32>
  memref.dealloc %1 : memref<*xf32>
  return
}

// CHECK-LABEL: func @memref_alloca_scope
func.func @memref_alloca_scope() {
  memref.alloca_scope {
    memref.alloca_scope.return
  }
  return
}

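// Supplementary sketch (per the alloca_scope op's documented form): the scope
// can also yield values through memref.alloca_scope.return.
// CHECK-LABEL: func @memref_alloca_scope_result
func.func @memref_alloca_scope_result() -> f32 {
  %res = memref.alloca_scope -> f32 {
    %c1 = arith.constant 1.0 : f32
    memref.alloca_scope.return %c1 : f32
  }
  return %res : f32
}
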
// CHECK-LABEL: func @memref_cast(%arg0
func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64x16x4xf32, strided<[64, 4, 1], offset: 0>>) {
  // CHECK: memref.cast %{{.*}} : memref<4xf32> to memref<?xf32>
  %0 = memref.cast %arg0 : memref<4xf32> to memref<?xf32>

  // CHECK: memref.cast %{{.*}} : memref<?xf32> to memref<4xf32>
  %1 = memref.cast %arg1 : memref<?xf32> to memref<4xf32>

  // CHECK: memref.cast %{{.*}} : memref<64x16x4xf32, strided<[64, 4, 1]>> to memref<64x16x4xf32, strided<[?, ?, ?], offset: ?>>
  %2 = memref.cast %arg2 : memref<64x16x4xf32, strided<[64, 4, 1], offset: 0>> to memref<64x16x4xf32, strided<[?, ?, ?], offset: ?>>

  // CHECK: memref.cast {{%.*}} : memref<64x16x4xf32, strided<[?, ?, ?], offset: ?>> to memref<64x16x4xf32, strided<[64, 4, 1]>>
  %3 = memref.cast %2 : memref<64x16x4xf32, strided<[?, ?, ?], offset: ?>> to memref<64x16x4xf32, strided<[64, 4, 1], offset: 0>>

  // CHECK: memref.cast %{{.*}} : memref<4xf32> to memref<*xf32>
  %4 = memref.cast %1 : memref<4xf32> to memref<*xf32>

  // CHECK: memref.cast %{{.*}} : memref<*xf32> to memref<4xf32>
  %5 = memref.cast %4 : memref<*xf32> to memref<4xf32>
  return
}

// Check that unranked memrefs with non-default memory space roundtrip
// properly.
// CHECK-LABEL: @unranked_memref_roundtrip(memref<*xf32, 4>)
func.func private @unranked_memref_roundtrip(memref<*xf32, 4>)

// CHECK-LABEL: func @load_store_prefetch
func.func @load_store_prefetch(memref<4x4xi32>, index) {
^bb0(%0: memref<4x4xi32>, %1: index):
  // CHECK: %0 = memref.load %arg0[%arg1, %arg1] : memref<4x4xi32>
  %2 = "memref.load"(%0, %1, %1) : (memref<4x4xi32>, index, index) -> i32

  // CHECK: %{{.*}} = memref.load %arg0[%arg1, %arg1] : memref<4x4xi32>
  %3 = memref.load %0[%1, %1] : memref<4x4xi32>

  // CHECK: memref.prefetch %arg0[%arg1, %arg1], write, locality<1>, data : memref<4x4xi32>
  memref.prefetch %0[%1, %1], write, locality<1>, data : memref<4x4xi32>

  // CHECK: memref.prefetch %arg0[%arg1, %arg1], read, locality<3>, instr : memref<4x4xi32>
  memref.prefetch %0[%1, %1], read, locality<3>, instr : memref<4x4xi32>

  return
}

// Test with zero-dimensional operands using no index in load/store.
// CHECK-LABEL: func @zero_dim_no_idx
func.func @zero_dim_no_idx(%arg0 : memref<i32>, %arg1 : memref<i32>, %arg2 : memref<i32>) {
  %0 = memref.load %arg0[] : memref<i32>
  memref.store %0, %arg1[] : memref<i32>
  return
  // CHECK: %0 = memref.load %{{.*}}[] : memref<i32>
  // CHECK: memref.store %{{.*}}, %{{.*}}[] : memref<i32>
}

// CHECK-LABEL: func @memref_view(%arg0
func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
  %0 = memref.alloc() : memref<2048xi8>
  // Test two dynamic sizes and dynamic offset.
  // CHECK: memref.view {{.*}} : memref<2048xi8> to memref<?x?xf32>
  %1 = memref.view %0[%arg2][%arg0, %arg1] : memref<2048xi8> to memref<?x?xf32>

  // Test one dynamic size and dynamic offset.
  // CHECK: memref.view {{.*}} : memref<2048xi8> to memref<4x?xf32>
  %3 = memref.view %0[%arg2][%arg1] : memref<2048xi8> to memref<4x?xf32>

  // Test static sizes and static offset.
  // CHECK: memref.view {{.*}} : memref<2048xi8> to memref<64x4xf32>
  %c0 = arith.constant 0 : index
  %5 = memref.view %0[%c0][] : memref<2048xi8> to memref<64x4xf32>
  return
}

// CHECK-LABEL: func @assume_alignment
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
func.func @assume_alignment(%0: memref<4x4xf16>) {
  // CHECK: memref.assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
  memref.assume_alignment %0, 16 : memref<4x4xf16>
  return
}

// CHECK-LABEL: func @expand_collapse_shape_static
func.func @expand_collapse_shape_static(
    %arg0: memref<3x4x5xf32>,
    %arg1: tensor<3x4x5xf32>,
    %arg2: tensor<3x?x5xf32>,
    %arg3: memref<30x20xf32, strided<[4000, 2], offset: 100>>,
    %arg4: memref<1x5xf32, strided<[5, 1], offset: ?>>,
    %arg5: memref<f32>,
    %arg6: memref<3x4x5xf32, strided<[240, 60, 10], offset: 0>>,
    %arg7: memref<1x2049xi64, strided<[?, ?], offset: ?>>,
    %arg8: memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>>,
    %arg9: memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>>) {
  // Reshapes that collapse and expand back a contiguous buffer.
//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
//  CHECK-SAME:     memref<3x4x5xf32> into memref<12x5xf32>
  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
    memref<3x4x5xf32> into memref<12x5xf32>

//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [3, 4, 5]
//  CHECK-SAME:     memref<12x5xf32> into memref<3x4x5xf32>
  %r0 = memref.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 5] :
    memref<12x5xf32> into memref<3x4x5xf32>

//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
//  CHECK-SAME:     memref<3x4x5xf32> into memref<3x20xf32>
  %1 = memref.collapse_shape %arg0 [[0], [1, 2]] :
    memref<3x4x5xf32> into memref<3x20xf32>

//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [3, 4, 5]
//  CHECK-SAME:     memref<3x20xf32> into memref<3x4x5xf32>
  %r1 = memref.expand_shape %1 [[0], [1, 2]] output_shape [3, 4, 5] :
    memref<3x20xf32> into memref<3x4x5xf32>

//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
//  CHECK-SAME:     memref<3x4x5xf32> into memref<60xf32>
  %2 = memref.collapse_shape %arg0 [[0, 1, 2]] :
    memref<3x4x5xf32> into memref<60xf32>

//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1, 2]] output_shape [3, 4, 5]
//  CHECK-SAME:     memref<60xf32> into memref<3x4x5xf32>
  %r2 = memref.expand_shape %2 [[0, 1, 2]] output_shape [3, 4, 5] :
      memref<60xf32> into memref<3x4x5xf32>

//       CHECK:   memref.expand_shape {{.*}} [] output_shape [1, 1]
//  CHECK-SAME:     memref<f32> into memref<1x1xf32>
  %r5 = memref.expand_shape %arg5 [] output_shape [1, 1] :
      memref<f32> into memref<1x1xf32>

// Reshapes with a custom layout map.
//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [30, 4, 5]
  %l0 = memref.expand_shape %arg3 [[0], [1, 2]] output_shape [30, 4, 5] :
      memref<30x20xf32, strided<[4000, 2], offset: 100>>
      into memref<30x4x5xf32, strided<[4000, 10, 2], offset: 100>>

//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [2, 15, 20]
  %l1 = memref.expand_shape %arg3 [[0, 1], [2]] output_shape [2, 15, 20] :
      memref<30x20xf32, strided<[4000, 2], offset: 100>>
      into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>

//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [1, 1, 5]
  %r4 = memref.expand_shape %arg4 [[0], [1, 2]] output_shape [1, 1, 5] :
      memref<1x5xf32, strided<[5, 1], offset: ?>> into
      memref<1x1x5xf32, strided<[5, 5, 1], offset: ?>>

  // Note: Only the two collapsed dimensions are contiguous in the following test case.
//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
  %r6 = memref.collapse_shape %arg6 [[0, 1], [2]] :
      memref<3x4x5xf32, strided<[240, 60, 10], offset: 0>> into
      memref<12x5xf32, strided<[60, 10], offset: 0>>

//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1]]
  %r7 = memref.collapse_shape %arg7 [[0, 1]] :
      memref<1x2049xi64, strided<[?, ?], offset: ?>> into
      memref<2049xi64, strided<[?], offset: ?>>

//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
  %r8 = memref.collapse_shape %arg8 [[0, 1, 2]] :
      memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>> into
      memref<1024xi8, strided<[1], offset: 0>>

//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0], [1, 2, 3]]
  %r9 = memref.collapse_shape %arg9 [[0], [1, 2, 3]] :
      memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>> into
      memref<24x1024xi8, strided<[40960, 1], offset: 0>>

  // Reshapes that expand and collapse back a contiguous buffer with some 1's.
//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
//  CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
  %3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] :
    memref<3x4x5xf32> into memref<1x3x4x1x5xf32>

//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
//  CHECK-SAME:     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
  %r3 = memref.collapse_shape %3 [[0, 1], [2], [3, 4]] :
    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>

  // Reshapes on tensors.
//       CHECK:   tensor.expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
  %t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] :
    tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>

//       CHECK:   tensor.collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
  %rt0 = tensor.collapse_shape %t0 [[0, 1], [2], [3, 4]] :
    tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>

//       CHECK:   tensor.dim %arg2, {{.*}} : tensor<3x?x5xf32>
//       CHECK:   tensor.expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
  %c1 = arith.constant 1 : index
  %sz1 = tensor.dim %arg2, %c1 : tensor<3x?x5xf32>
  %t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] output_shape [1, 3, %sz1, 1, 5] :
    tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>

//       CHECK:   tensor.collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
  %rt1 = tensor.collapse_shape %t1 [[0], [1, 2], [3, 4]] :
    tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
  return
}

// CHECK-LABEL: func @expand_collapse_shape_dynamic
func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
         %arg1: memref<?x?x?xf32, strided<[?, ?, 1], offset: 0>>,
         %arg2: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>,
         %arg3: memref<?x42xf32, strided<[42, 1], offset: 0>>,
         %arg4: index,
         %arg5: index,
         %arg6: index,
         %arg7: memref<4x?x4xf32>) {
//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
//  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
    memref<?x?x?xf32> into memref<?x?xf32>

//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5]
//  CHECK-SAME:     memref<?x?xf32> into memref<?x4x?xf32>
  %r0 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] :
    memref<?x?xf32> into memref<?x4x?xf32>

//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
//  CHECK-SAME:     memref<?x?x?xf32, strided<[?, ?, 1]>> into memref<?x?xf32, strided<[?, 1]>>
  %1 = memref.collapse_shape %arg1 [[0, 1], [2]] :
    memref<?x?x?xf32, strided<[?, ?, 1], offset: 0>> into
    memref<?x?xf32, strided<[?, 1], offset: 0>>

//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5]
//  CHECK-SAME:     memref<?x?xf32, strided<[?, 1]>> into memref<?x4x?xf32, strided<[?, ?, 1]>>
  %r1 = memref.expand_shape %1 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] :
    memref<?x?xf32, strided<[?, 1], offset: 0>> into
    memref<?x4x?xf32, strided<[?, ?, 1], offset: 0>>

//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
//  CHECK-SAME:     memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> into memref<?x?xf32, strided<[?, 1], offset: ?>>
  %2 = memref.collapse_shape %arg2 [[0, 1], [2]] :
    memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> into
    memref<?x?xf32, strided<[?, 1], offset: ?>>

//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5]
//  CHECK-SAME:     memref<?x?xf32, strided<[?, 1], offset: ?>> into memref<?x4x?xf32, strided<[?, ?, 1], offset: ?>>
  %r2 = memref.expand_shape %2 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] :
    memref<?x?xf32, strided<[?, 1], offset: ?>> into
    memref<?x4x?xf32, strided<[?, ?, 1], offset: ?>>

//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1]]
//  CHECK-SAME:     memref<?x42xf32, strided<[42, 1]>> into memref<?xf32, strided<[1]>>
  %3 = memref.collapse_shape %arg3 [[0, 1]] :
    memref<?x42xf32, strided<[42, 1], offset: 0>> into
    memref<?xf32, strided<[1]>>

//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1]] output_shape [%arg6, 42]
//  CHECK-SAME:     memref<?xf32, strided<[1]>> into memref<?x42xf32>
  %r3 = memref.expand_shape %3 [[0, 1]] output_shape [%arg6, 42] :
    memref<?xf32, strided<[1]>> into memref<?x42xf32>

//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
  %4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
        : memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
  return
}

func.func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>)
    -> (memref<f32>, memref<1x1xf32>) {
  %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
  %1 = memref.expand_shape %0 [] output_shape [1, 1] : memref<f32> into memref<1x1xf32>
  return %0, %1 : memref<f32>, memref<1x1xf32>
}
// CHECK-LABEL: func @expand_collapse_shape_zero_dim
//       CHECK:   memref.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref<f32>
//       CHECK:   memref.expand_shape %{{.*}} [] output_shape [1, 1] : memref<f32> into memref<1x1xf32>

func.func @collapse_shape_to_dynamic
  (%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32> {
  %0 = memref.collapse_shape %arg0 [[0], [1], [2, 3, 4]] :
    memref<?x?x?x4x?xf32> into memref<?x?x?xf32>
  return %0 : memref<?x?x?xf32>
}
//      CHECK: func @collapse_shape_to_dynamic
//      CHECK:   memref.collapse_shape
// CHECK-SAME:    [0], [1], [2, 3, 4]

// -----

// CHECK-LABEL: func @expand_collapse_shape_transposed_layout
func.func @expand_collapse_shape_transposed_layout(
    %m0: memref<?x?xf32, strided<[1, 10], offset: 0>>,
    %m1: memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>>,
    %sz0: index,
    %sz1: index) {

  %r0 = memref.expand_shape %m0 [[0], [1, 2]] output_shape [%sz0, %sz1, 5] :
    memref<?x?xf32, strided<[1, 10], offset: 0>> into
    memref<?x?x5xf32, strided<[1, 50, 10], offset: 0>>
  %rr0 = memref.collapse_shape %r0 [[0], [1, 2]] :
    memref<?x?x5xf32, strided<[1, 50, 10], offset: 0>> into
    memref<?x?xf32, strided<[1, 10], offset: 0>>

  %r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] output_shape [2, 2, 5, 2, 3] :
    memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>> into
    memref<2x2x5x2x3xf32, strided<[2, 1, ?, 3000, 1000], offset: 0>>
  %rr1 = memref.collapse_shape %r1 [[0, 1], [2], [3, 4]] :
    memref<2x2x5x2x3xf32, strided<[2, 1, ?, 3000, 1000], offset: 0>> into
    memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>>
  return
}

// -----

func.func @rank(%t : memref<4x4x?xf32>) {
  // CHECK: %{{.*}} = memref.rank %{{.*}} : memref<4x4x?xf32>
  %0 = "memref.rank"(%t) : (memref<4x4x?xf32>) -> index

  // CHECK: %{{.*}} = memref.rank %{{.*}} : memref<4x4x?xf32>
  %1 = memref.rank %t : memref<4x4x?xf32>
  return
}

// -----

// CHECK-LABEL: func @atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
func.func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
  %x = memref.atomic_rmw addf %val, %I[%i] : (f32, memref<10xf32>) -> f32
  // CHECK: memref.atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
  return
}

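// Supplementary sketch: atomic_rmw supports integer kinds as well; addi is
// shown here (illustrative).
// CHECK-LABEL: func @atomic_rmw_addi
func.func @atomic_rmw_addi(%I: memref<10xi32>, %val: i32, %i : index) {
  // CHECK: memref.atomic_rmw addi %{{.*}}, %{{.*}}[%{{.*}}]
  %x = memref.atomic_rmw addi %val, %I[%i] : (i32, memref<10xi32>) -> i32
  return
}
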
// CHECK-LABEL: func @generic_atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
func.func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
  %x = memref.generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
  // CHECK-NEXT: memref.generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref
    ^bb0(%old_value : f32):
      %c1 = arith.constant 1.0 : f32
      %out = arith.addf %c1, %old_value : f32
      memref.atomic_yield %out : f32
  // CHECK: index_attr = 8 : index
  } { index_attr = 8 : index }
  return
}

// -----

func.func @extract_strided_metadata(%memref : memref<10x?xf32>)
  -> memref<?x?xf32, strided<[?, ?], offset: ?>> {

  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %memref
    : memref<10x?xf32> -> memref<f32>, index, index, index, index, index

  %m2 = memref.reinterpret_cast %base to
      offset: [%offset],
      sizes: [%sizes#0, %sizes#1],
      strides: [%strides#0, %strides#1]
    : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>

  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}

// -----

// CHECK-LABEL: func @memref_realloc_ss
func.func @memref_realloc_ss(%src : memref<2xf32>) -> memref<4xf32> {
  %0 = memref.realloc %src : memref<2xf32> to memref<4xf32>
  return %0 : memref<4xf32>
}

// CHECK-LABEL: func @memref_realloc_sd
func.func @memref_realloc_sd(%src : memref<2xf32>, %d : index) -> memref<?xf32> {
  %0 = memref.realloc %src(%d) : memref<2xf32> to memref<?xf32>
  return %0 : memref<?xf32>
}

// CHECK-LABEL: func @memref_realloc_ds
func.func @memref_realloc_ds(%src : memref<?xf32>) -> memref<4xf32> {
  %0 = memref.realloc %src : memref<?xf32> to memref<4xf32>
  return %0 : memref<4xf32>
}

// CHECK-LABEL: func @memref_realloc_dd
func.func @memref_realloc_dd(%src : memref<?xf32>, %d : index)
  -> memref<?xf32> {
  %0 = memref.realloc %src(%d) : memref<?xf32> to memref<?xf32>
  return %0 : memref<?xf32>
}

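// Supplementary sketch (per the realloc op's documented form): an optional
// alignment attribute may be supplied for the reallocated buffer.
// CHECK-LABEL: func @memref_realloc_aligned
func.func @memref_realloc_aligned(%src : memref<2xf32>) -> memref<4xf32> {
  %0 = memref.realloc %src {alignment = 1024} : memref<2xf32> to memref<4xf32>
  return %0 : memref<4xf32>
}
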
// CHECK-LABEL: func @memref_extract_aligned_pointer
func.func @memref_extract_aligned_pointer(%src : memref<?xf32>) -> index {
  %0 = memref.extract_aligned_pointer_as_index %src : memref<?xf32> -> index
  return %0 : index
}

// CHECK-LABEL: func @memref_memory_space_cast
func.func @memref_memory_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
  %dst = memref.memory_space_cast %src : memref<?xf32> to memref<?xf32, 1>
  return %dst : memref<?xf32, 1>
}

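// Supplementary sketch (per the op's documented form): memory_space_cast also
// applies to unranked memrefs.
// CHECK-LABEL: func @memref_memory_space_cast_unranked
func.func @memref_memory_space_cast_unranked(%src : memref<*xf32>) -> memref<*xf32, 1> {
  %dst = memref.memory_space_cast %src : memref<*xf32> to memref<*xf32, 1>
  return %dst : memref<*xf32, 1>
}
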
// CHECK-LABEL: func @memref_transpose_map
func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>> {
  %dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
  return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
}