// RUN: mlir-opt %s -canonicalize="test-convergence" --split-input-file -allow-unregistered-dialect | FileCheck %s


// CHECK-LABEL: collapse_shape_identity_fold
// CHECK-NEXT: return
func.func @collapse_shape_identity_fold(%arg0 : memref<5xi8>) -> memref<5xi8> {
  %0 = memref.collapse_shape %arg0 [[0]] : memref<5xi8> into memref<5xi8>
  return %0 : memref<5xi8>
}

// -----

// CHECK-LABEL: expand_shape_identity_fold
// CHECK-NEXT: return
func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8> {
  %0 = memref.expand_shape %arg0 [[0], [1]] output_shape [5, 4] : memref<5x4xi8> into memref<5x4xi8>
  return %0 : memref<5x4xi8>
}

// -----

// CHECK-LABEL: collapse_expand_rank0_cancel
// CHECK-NEXT: return
func.func @collapse_expand_rank0_cancel(%arg0 : memref<1x1xi8>) -> memref<1x1xi8> {
  %0 = memref.collapse_shape %arg0 [] : memref<1x1xi8> into memref<i8>
  %1 = memref.expand_shape %0 [] output_shape [1, 1] : memref<i8> into memref<1x1xi8>
  return %1 : memref<1x1xi8>
}

// -----

// CHECK-LABEL: func @subview_of_size_memcast
//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
//       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>
//       CHECK:   return %[[S]] : memref<16x32xi8, strided{{.*}}>
func.func @subview_of_size_memcast(%arg : memref<4x6x16x32xi8>) ->
  memref<16x32xi8, strided<[32, 1], offset: 512>> {
  %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
  %1 = memref.subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
    memref<?x?x16x32xi8> to
    memref<16x32xi8, strided<[32, 1], offset: 512>>
  return %1 : memref<16x32xi8, strided<[32, 1], offset: 512>>
}

// -----

//       CHECK: func @subview_of_strides_memcast
//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<1x1x?xf32, strided{{.*}}>
//       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 0, 0] [1, 1, 4]
//  CHECK-SAME:                    to memref<1x4xf32, strided<[7, 1], offset: ?>>
//       CHECK:   %[[M:.+]] = memref.cast %[[S]]
//  CHECK-SAME:                    to memref<1x4xf32, strided<[?, ?], offset: ?>>
//       CHECK:   return %[[M]]
func.func @subview_of_strides_memcast(%arg : memref<1x1x?xf32, strided<[35, 7, 1], offset: ?>>) -> memref<1x4xf32, strided<[?, ?], offset: ?>> {
  %0 = memref.cast %arg : memref<1x1x?xf32, strided<[35, 7, 1], offset: ?>> to memref<1x1x?xf32, strided<[?, ?, ?], offset: ?>>
  %1 = memref.subview %0[0, 0, 0] [1, 1, 4] [1, 1, 1] : memref<1x1x?xf32, strided<[?, ?, ?], offset: ?>> to memref<1x4xf32, strided<[?, ?], offset: ?>>
  return %1 : memref<1x4xf32, strided<[?, ?], offset: ?>>
}

// -----

// CHECK-LABEL: func @subview_of_static_full_size
// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
// CHECK-NOT: memref.subview
// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
  %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
  return %0 : memref<4x6x16x32xi8>
}

// -----

// CHECK-LABEL: func @negative_subview_of_static_full_size
//  CHECK-SAME:   %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>>
//  CHECK-SAME:   %[[IDX:.+]]: index
//       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][%[[IDX]], 0] [16, 4] [1, 1]
//  CHECK-SAME:                    to memref<16x4xf32, strided<[4, 1], offset: ?>>
//       CHECK:   return %[[S]] : memref<16x4xf32, strided<[4, 1], offset: ?>>
func.func @negative_subview_of_static_full_size(%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> {
  %0 = memref.subview %arg0[%idx, 0] [16, 4] [1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>>
  return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
}

// -----

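// Test case: Constant offset, size, and stride operands are folded into the
// subview's static attributes; a cast recovers the original dynamic type.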
func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  return %0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
}
// CHECK-LABEL: func @subview_canonicalize
//  CHECK-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>
//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
//  CHECK-SAME:      : memref<?x?x?xf32> to memref<4x1x?xf32
//       CHECK:   %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
//       CHECK:   return %[[RESULT]]

// -----

func.func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
  %arg2 : index) -> memref<?x?xf32, strided<[?, ?], offset: ?>>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  return %0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}
// CHECK-LABEL: func @rank_reducing_subview_canonicalize
//  CHECK-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>
//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
//  CHECK-SAME:      : memref<?x?x?xf32> to memref<4x?xf32
//       CHECK:   %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
//       CHECK:   return %[[RESULT]]

// -----

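// Test case: Folding the constant unit size of a rank-reducing subview into
// its result type, while the rank-reducing behavior is preserved across the
// chain of subviews.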
func.func @multiple_reducing_dims(%arg0 : memref<1x384x384xf32>,
    %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, strided<[1], offset: ?>>
{
  %c1 = arith.constant 1 : index
  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, strided<[384, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
  return %1 : memref<?xf32, strided<[1], offset: ?>>
}
//       CHECK: func @multiple_reducing_dims
//       CHECK:   %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
//  CHECK-SAME:       : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
//       CHECK:   %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
//  CHECK-SAME:       : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>

// -----

func.func @multiple_reducing_dims_dynamic(%arg0 : memref<?x?x?xf32>,
    %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, strided<[1], offset: ?>>
{
  %c1 = arith.constant 1 : index
  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
  return %1 : memref<?xf32, strided<[1], offset: ?>>
}
//       CHECK: func @multiple_reducing_dims_dynamic
//       CHECK:   %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
//  CHECK-SAME:       : memref<?x?x?xf32> to memref<1x?xf32, strided<[?, 1], offset: ?>>
//       CHECK:   %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
//  CHECK-SAME:       : memref<1x?xf32, strided<[?, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>

// -----

func.func @multiple_reducing_dims_all_dynamic(%arg0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>,
    %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, strided<[?], offset: ?>>
{
  %c1 = arith.constant 1 : index
  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1]
      : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
  return %1 : memref<?xf32, strided<[?], offset: ?>>
}
//       CHECK: func @multiple_reducing_dims_all_dynamic
//       CHECK:   %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
//  CHECK-SAME:       : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
//       CHECK:   %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
//  CHECK-SAME:       : memref<1x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>

// -----

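// Test case: A constant negative stride is folded into the subview's static
// strides; a cast recovers the original dynamic type.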
func.func @subview_negative_stride1(%arg0 : memref<?xf32>) -> memref<?xf32, strided<[?], offset: ?>>
{
  %c0 = arith.constant 0 : index
  %cm1 = arith.constant -1 : index
  %1 = memref.dim %arg0, %c0 : memref<?xf32>
  %2 = arith.addi %1, %cm1 : index
  %3 = memref.subview %arg0[%2] [%1] [%cm1] : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
  return %3 : memref<?xf32, strided<[?], offset: ?>>
}
//       CHECK: func @subview_negative_stride1
//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>)
//       CHECK:   %[[C0:.*]] = arith.constant 0
//       CHECK:   %[[CM1:.*]] = arith.constant -1
//       CHECK:   %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xf32>
//       CHECK:   %[[DIM2:.*]] = arith.addi %[[DIM1]], %[[CM1]] : index
//       CHECK:   %[[RES1:.*]] = memref.subview %[[ARG0]][%[[DIM2]]] [%[[DIM1]]] [-1] : memref<?xf32> to memref<?xf32, strided<[-1], offset: ?>>
//       CHECK:   %[[RES2:.*]] = memref.cast %[[RES1]] : memref<?xf32, strided<[-1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
//       CHECK:   return %[[RES2]] : memref<?xf32, strided<[?], offset: ?>>

// -----

func.func @subview_negative_stride2(%arg0 : memref<7xf32>) -> memref<?xf32, strided<[?], offset: ?>>
{
  %c0 = arith.constant 0 : index
  %cm1 = arith.constant -1 : index
  %1 = memref.dim %arg0, %c0 : memref<7xf32>
  %2 = arith.addi %1, %cm1 : index
  %3 = memref.subview %arg0[%2] [%1] [%cm1] : memref<7xf32> to memref<?xf32, strided<[?], offset: ?>>
  return %3 : memref<?xf32, strided<[?], offset: ?>>
}
//       CHECK: func @subview_negative_stride2
//  CHECK-SAME:   (%[[ARG0:.*]]: memref<7xf32>)
//       CHECK:   %[[RES1:.*]] = memref.subview %[[ARG0]][6] [7] [-1] : memref<7xf32> to memref<7xf32, strided<[-1], offset: 6>>
//       CHECK:   %[[RES2:.*]] = memref.cast %[[RES1]] : memref<7xf32, strided<[-1], offset: 6>> to memref<?xf32, strided<[?], offset: ?>>
//       CHECK:   return %[[RES2]] : memref<?xf32, strided<[?], offset: ?>>

// -----

// CHECK-LABEL: func @dim_of_sized_view
//  CHECK-SAME:   %{{[a-z0-9A-Z_]+}}: memref<?xi8>
//  CHECK-SAME:   %[[SIZE:.[a-z0-9A-Z_]+]]: index
//       CHECK:   return %[[SIZE]] : index
func.func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
  %c0 = arith.constant 0 : index
  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref<?xi8> to memref<?xi8>
  %1 = memref.dim %0, %c0 : memref<?xi8>
  return %1 : index
}

// -----

// CHECK-LABEL: func @no_fold_subview_negative_size
//  CHECK:        %[[SUBVIEW:.+]] = memref.subview
//  CHECK:        return %[[SUBVIEW]]
func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<?x256xf32, strided<[1024, 1], offset: 2304>> {
  %cst = arith.constant -13 : index
  %0 = memref.subview %input[2, 256] [%cst, 256] [1, 1] : memref<4x1024xf32> to memref<?x256xf32, strided<[1024, 1], offset: 2304>>
  return %0 : memref<?x256xf32, strided<[1024, 1], offset: 2304>>
}

// -----

// CHECK-LABEL: func @no_fold_subview_zero_stride
//  CHECK:        %[[SUBVIEW:.+]] = memref.subview
//  CHECK:        return %[[SUBVIEW]]
func.func @no_fold_subview_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> {
  %c0 = arith.constant 0 : index
  %1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>>
  return %1 : memref<1xf32, strided<[?], offset: 1>>
}

// -----

// CHECK-LABEL: func @no_fold_of_store
//  CHECK:   %[[cst:.+]] = memref.cast %arg
//  CHECK:   memref.store %[[cst]]
func.func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
  %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
  memref.store %0, %holder[] : memref<memref<?xi8>>
  return
}

// -----

// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
// CHECK-LABEL: func @dim_of_alloca(
//  CHECK-SAME:     %[[SIZE:[0-9a-z]+]]: index
//  CHECK-NEXT:   return %[[SIZE]] : index
func.func @dim_of_alloca(%size: index) -> index {
  %0 = memref.alloca(%size) : memref<?xindex>
  %c0 = arith.constant 0 : index
  %1 = memref.dim %0, %c0 : memref<?xindex>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>
//  CHECK-NEXT:   %[[RANK:.*]] = memref.rank %[[MEM]] : memref<*xf32>
//  CHECK-NEXT:   return %[[RANK]] : index
func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
  %0 = memref.rank %arg0 : memref<*xf32>
  %1 = memref.alloca(%0) : memref<?xindex>
  %c0 = arith.constant 0 : index
  %2 = memref.dim %1, %c0 : memref<?xindex>
  return %2 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape(
//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>
//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
//  CHECK-NEXT:   memref.store
//   CHECK-NOT:   memref.dim
//       CHECK:   return %[[DIM]] : index
func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
    -> index {
  %c3 = arith.constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  // Update the shape to test that the load ends up in the right place.
  memref.store %c3, %arg1[%c3] : memref<?xindex>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_i32(
//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xi32>
//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
//  CHECK-NEXT:   %[[CAST:.*]] = arith.index_cast %[[DIM]]
//   CHECK-NOT:   memref.dim
//       CHECK:   return %[[CAST]] : index
func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
    -> index {
  %c3 = arith.constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
//  CHECK-SAME:   %[[MEM:[0-9a-z]+]]: memref<*xf32>,
//  CHECK-SAME:   %[[SHP:[0-9a-z]+]]: memref<?xindex>,
//  CHECK-SAME:   %[[IDX:[0-9a-z]+]]: index
//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
//   CHECK-NOT:   memref.dim
//       CHECK:   return %[[DIM]] : index
func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %dim = memref.dim %reshape, %arg2 : memref<*xf32>
  return %dim : index
}

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_for(
//       CHECK: memref.reshape
//       CHECK: memref.dim
//   CHECK-NOT: memref.load
func.func @dim_of_memref_reshape_for(%arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index

  %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>

  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
    %2 = memref.dim %0, %arg2 : memref<*xf32>
    %3 = arith.muli %arg3, %2 : index
    scf.yield %3 : index
  }
  return %1 : index
}

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
//       CHECK: memref.reshape
//       CHECK: memref.dim
//   CHECK-NOT: memref.load
func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %c4 = arith.constant 4 : index
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %0 = arith.muli %arg2, %c4 : index
  %dim = memref.dim %reshape, %0 : memref<*xf32>
  return %dim : index
}

// -----

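// Test case: A constant allocation size is folded into the alloc's static
// type; a cast recovers the original dynamic type.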
// CHECK-LABEL: func @alloc_const_fold
func.func @alloc_const_fold() -> memref<?xf32> {
  // CHECK-NEXT: memref.alloc() : memref<4xf32>
  %c4 = arith.constant 4 : index
  %a = memref.alloc(%c4) : memref<?xf32>

  // CHECK-NEXT: memref.cast %{{.*}} : memref<4xf32> to memref<?xf32>
  // CHECK-NEXT: return %{{.*}} : memref<?xf32>
  return %a : memref<?xf32>
}

// -----

// CHECK-LABEL: func @alloc_alignment_const_fold
func.func @alloc_alignment_const_fold() -> memref<?xf32> {
  // CHECK-NEXT: memref.alloc() {alignment = 4096 : i64} : memref<4xf32>
  %c4 = arith.constant 4 : index
  %a = memref.alloc(%c4) {alignment = 4096 : i64} : memref<?xf32>

  // CHECK-NEXT: memref.cast %{{.*}} : memref<4xf32> to memref<?xf32>
  // CHECK-NEXT: return %{{.*}} : memref<?xf32>
  return %a : memref<?xf32>
}

// -----

// CHECK-LABEL: func @alloc_const_fold_with_symbols1(
//  CHECK: %[[c1:.+]] = arith.constant 1 : index
//  CHECK: %[[mem1:.+]] = memref.alloc({{.*}})[%[[c1]], %[[c1]]] : memref<?xi32, strided{{.*}}>
//  CHECK: return %[[mem1]] : memref<?xi32, strided{{.*}}>
func.func @alloc_const_fold_with_symbols1(%arg0 : index) -> memref<?xi32, strided<[?], offset: ?>> {
  %c1 = arith.constant 1 : index
  %0 = memref.alloc(%arg0)[%c1, %c1] : memref<?xi32, strided<[?], offset: ?>>
  return %0 : memref<?xi32, strided<[?], offset: ?>>
}

// -----

// CHECK-LABEL: func @alloc_const_fold_with_symbols2(
//  CHECK: %[[c1:.+]] = arith.constant 1 : index
//  CHECK: %[[mem1:.+]] = memref.alloc()[%[[c1]], %[[c1]]] : memref<1xi32, strided{{.*}}>
//  CHECK: %[[mem2:.+]] = memref.cast %[[mem1]] : memref<1xi32, strided{{.*}}> to memref<?xi32, strided{{.*}}>
//  CHECK: return %[[mem2]] : memref<?xi32, strided{{.*}}>
func.func @alloc_const_fold_with_symbols2() -> memref<?xi32, strided<[?], offset: ?>> {
  %c1 = arith.constant 1 : index
  %0 = memref.alloc(%c1)[%c1, %c1] : memref<?xi32, strided<[?], offset: ?>>
  return %0 : memref<?xi32, strided<[?], offset: ?>>
}

// -----

// CHECK-LABEL: func @allocator
// CHECK:   %[[alloc:.+]] = memref.alloc
// CHECK:   memref.store %[[alloc]], %arg0
func.func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index) {
  %0 = memref.alloc(%arg1) : memref<?xi32>
  memref.store %0, %arg0[] : memref<memref<?xi32>>
  return
}

// -----

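// Test case: Two consecutive collapse_shape ops are composed into a single
// collapse_shape with a merged reassociation.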
func.func @compose_collapse_of_collapse_zero_dim(%arg0 : memref<1x1x1xf32>)
    -> memref<f32> {
  %0 = memref.collapse_shape %arg0 [[0, 1, 2]]
      : memref<1x1x1xf32> into memref<1xf32>
  %1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
  return %1 : memref<f32>
}
// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
//       CHECK:   memref.collapse_shape %{{.*}} []
//  CHECK-SAME:     memref<1x1x1xf32> into memref<f32>

// -----

func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
    -> memref<?x?xf32> {
  %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
      : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<?x?x?xf32> into memref<?x?xf32>
  return %1 : memref<?x?xf32>
}
// CHECK-LABEL: func @compose_collapse_of_collapse
//       CHECK:   memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
//   CHECK-NOT:   memref.collapse_shape

// -----

func.func @do_not_compose_collapse_of_expand_non_identity_layout(
    %arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>, %sz0: index, %sz1: index)
    -> memref<?xf32, strided<[?], offset: 0>> {
  %1 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1] :
    memref<?x?xf32, strided<[?, 1], offset: 0>> into
    memref<?x4x?xf32, strided<[?, ?, 1], offset: 0>>
  %2 = memref.collapse_shape %1 [[0, 1, 2]] :
    memref<?x4x?xf32, strided<[?, ?, 1], offset: 0>> into
    memref<?xf32, strided<[?], offset: 0>>
  return %2 : memref<?xf32, strided<[?], offset: 0>>
}
// CHECK-LABEL: func @do_not_compose_collapse_of_expand_non_identity_layout
// CHECK: expand
// CHECK: collapse

// -----

func.func @compose_expand_of_expand(%arg0 : memref<?x?xf32>, %sz0: index, %sz1: index, %sz2: index, %sz3: index)
    -> memref<?x6x4x5x?xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1]
      : memref<?x?xf32> into memref<?x4x?xf32>
  %1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%sz2, 6, 4, 5, %sz3] : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
  return %1 : memref<?x6x4x5x?xf32>
}
// CHECK-LABEL: func @compose_expand_of_expand
//       CHECK:   memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%{{.*}}, 6, 4, 5, %{{.*}}]
//   CHECK-NOT:   memref.expand_shape

// -----

func.func @compose_expand_of_expand_of_zero_dim(%arg0 : memref<f32>)
    -> memref<1x1x1xf32> {
  %0 = memref.expand_shape %arg0 [] output_shape [1] : memref<f32> into memref<1xf32>
  %1 = memref.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1]
      : memref<1xf32> into memref<1x1x1xf32>
  return %1 : memref<1x1x1xf32>
}
// CHECK-LABEL: func @compose_expand_of_expand_of_zero_dim
//       CHECK:   memref.expand_shape %{{.*}} [] output_shape [1, 1, 1]
//  CHECK-SAME:     memref<f32> into memref<1x1x1xf32>

// -----

func.func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4]
      : memref<12x4xf32> into memref<3x4x4xf32>
  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<3x4x4xf32> into memref<12x4xf32>
  return %1 : memref<12x4xf32>
}
// CHECK-LABEL: func @fold_collapse_of_expand
//   CHECK-NOT:   memref.{{.*}}_shape

// -----

func.func @fold_collapse_collapse_of_expand(%arg0 : memref<?x?xf32>, %sz0: index, %sz1: index)
    -> memref<?x?xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1]
      : memref<?x?xf32> into memref<?x4x?xf32>
  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<?x4x?xf32> into memref<?x?xf32>
  return %1 : memref<?x?xf32>
}
// CHECK-LABEL: @fold_collapse_collapse_of_expand
//   CHECK-NOT:   memref.{{.*}}_shape

// -----

func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32> {
  %0 = memref.cast %arg0 : memref<?x?xf32> to memref<8x4xf32>
  %1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [2, 4, 4]
      : memref<8x4xf32> into memref<2x4x4xf32>
  return %1 : memref<2x4x4xf32>
}

// CHECK-LABEL: @fold_memref_expand_cast
// CHECK: memref.expand_shape

// -----

// CHECK-LABEL:   func @collapse_after_memref_cast_type_change(
// CHECK-SAME:      %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
// CHECK-SAME:         {{\[\[}}0], [1, 2, 3]] : memref<?x512x1x1xf32> into memref<?x512xf32>
// CHECK:           %[[DYNAMIC:.*]] = memref.cast %[[COLLAPSED]] :
// CHECK-SAME:         memref<?x512xf32> to memref<?x?xf32>
// CHECK:           return %[[DYNAMIC]] : memref<?x?xf32>
// CHECK:         }
func.func @collapse_after_memref_cast_type_change(%arg0 : memref<?x512x1x1xf32>) -> memref<?x?xf32> {
  %dynamic = memref.cast %arg0 : memref<?x512x1x1xf32> to memref<?x?x?x?xf32>
  %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
  return %collapsed : memref<?x?xf32>
}

// -----

// CHECK-LABEL:   func @collapse_after_memref_cast(
// CHECK-SAME:      %[[INPUT:.*]]: memref<?x512x1x?xf32>) -> memref<?x?xf32> {
// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
// CHECK-SAME:        {{\[\[}}0], [1, 2, 3]] : memref<?x512x1x?xf32> into memref<?x?xf32>
// CHECK:           return %[[COLLAPSED]] : memref<?x?xf32>
func.func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf32> {
  %dynamic = memref.cast %arg0 : memref<?x512x1x?xf32> to memref<?x?x?x?xf32>
  %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
  return %collapsed : memref<?x?xf32>
}

// -----

// CHECK-LABEL:   func @collapse_after_memref_cast_type_change_dynamic(
// CHECK-SAME:      %[[INPUT:.*]]: memref<1x1x1x?xi64>) -> memref<?x?xi64> {
// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
// CHECK-SAME:        {{\[\[}}0, 1, 2], [3]] : memref<1x1x1x?xi64> into memref<1x?xi64>
// CHECK:           %[[DYNAMIC:.*]] = memref.cast %[[COLLAPSED]] :
// CHECK-SAME:         memref<1x?xi64> to memref<?x?xi64>
// CHECK:           return %[[DYNAMIC]] : memref<?x?xi64>
func.func @collapse_after_memref_cast_type_change_dynamic(%arg0: memref<1x1x1x?xi64>) -> memref<?x?xi64> {
  %casted = memref.cast %arg0 : memref<1x1x1x?xi64> to memref<1x1x?x?xi64>
  %collapsed = memref.collapse_shape %casted [[0, 1, 2], [3]] : memref<1x1x?x?xi64> into memref<?x?xi64>
  return %collapsed : memref<?x?xi64>
}

// -----

func.func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 : index)
    -> memref<1x4x1xf32, strided<[35, 7, 1], offset: ?>> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5 : index
  %c4 = arith.constant 4 : index
  %c2 = arith.constant 2 : index
  %c1 = arith.constant 1 : index
  %0 = memref.subview %arg0[%arg1, %arg1, %arg1, 0] [%c1, %c4, %c1, 1] [1, 1, 1, 1]
      : memref<2x5x7x1xf32> to memref<?x?x?xf32, strided<[35, 7, 1], offset: ?>>
  %1 = memref.cast %0
      : memref<?x?x?xf32, strided<[35, 7, 1], offset: ?>> to
        memref<1x4x1xf32, strided<[35, 7, 1], offset: ?>>
  return %1 : memref<1x4x1xf32, strided<[35, 7, 1], offset: ?>>
}

// CHECK-LABEL: func @reduced_memref
//       CHECK:   %[[RESULT:.+]] = memref.subview
//  CHECK-SAME:       memref<2x5x7x1xf32> to memref<1x4x1xf32, strided{{.+}}>
//       CHECK:   return %[[RESULT]]

// -----

// CHECK-LABEL: func @fold_rank_memref
func.func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) {
  // Fold a rank into a constant
  // CHECK-NEXT: [[C2:%.+]] = arith.constant 2 : index
  %rank_0 = memref.rank %arg0 : memref<?x?xf32>

  // CHECK-NEXT: return [[C2]]
  return %rank_0 : index
}

// -----

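// Test case: A subview that spans the full source with zero offsets and unit
// strides folds to a cast.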
func.func @fold_no_op_subview(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, strided<[42, 1]>> {
  %0 = memref.subview %arg0[0, 0] [20, 42] [1, 1] : memref<20x42xf32> to memref<20x42xf32, strided<[42, 1]>>
  return %0 : memref<20x42xf32, strided<[42, 1]>>
}
// CHECK-LABEL: func @fold_no_op_subview(
//       CHECK:   %[[ARG0:.+]]: memref<20x42xf32>)
//       CHECK:   %[[CAST:.+]] = memref.cast %[[ARG0]]
//       CHECK:   return %[[CAST]]

// -----

func.func @no_fold_subview_with_non_zero_offset(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, strided<[42, 1], offset: 1>> {
  %0 = memref.subview %arg0[0, 1] [20, 42] [1, 1] : memref<20x42xf32> to memref<20x42xf32, strided<[42, 1], offset: 1>>
  return %0 : memref<20x42xf32, strided<[42, 1], offset: 1>>
}
// CHECK-LABEL: func @no_fold_subview_with_non_zero_offset(
//       CHECK:   %[[SUBVIEW:.+]] = memref.subview
//       CHECK:   return %[[SUBVIEW]]

// -----

func.func @no_fold_subview_with_non_unit_stride(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, strided<[42, 2]>> {
  %0 = memref.subview %arg0[0, 0] [20, 42] [1, 2] : memref<20x42xf32> to memref<20x42xf32, strided<[42, 2]>>
  return %0 : memref<20x42xf32, strided<[42, 2]>>
}
// CHECK-LABEL: func @no_fold_subview_with_non_unit_stride(
//       CHECK:   %[[SUBVIEW:.+]] = memref.subview
//       CHECK:   return %[[SUBVIEW]]

// -----

func.func @no_fold_dynamic_no_op_subview(%arg0 : memref<?x?xf32>) -> memref<?x?xf32, strided<[?, 1]>> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = memref.dim %arg0, %c0 : memref<?x?xf32>
  %1 = memref.dim %arg0, %c1 : memref<?x?xf32>
  %2 = memref.subview %arg0[0, 0] [%0, %1] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
  return %2 : memref<?x?xf32, strided<[?, 1]>>
}
// CHECK-LABEL: func @no_fold_dynamic_no_op_subview(
//       CHECK:   %[[SUBVIEW:.+]] = memref.subview
//       CHECK:   return %[[SUBVIEW]]

// -----

func.func @atomicrmw_cast_fold(%arg0 : f32, %arg1 : memref<4xf32>, %c : index) {
  %v = memref.cast %arg1 : memref<4xf32> to memref<?xf32>
  %a = memref.atomic_rmw addf %arg0, %v[%c] : (f32, memref<?xf32>) -> f32
  return
}

// CHECK-LABEL: func @atomicrmw_cast_fold
// CHECK-NEXT: memref.atomic_rmw addf %arg0, %arg1[%arg2] : (f32, memref<4xf32>) -> f32

// -----

func.func @copy_of_cast(%m1: memref<?xf32>, %m2: memref<*xf32>) {
  %casted1 = memref.cast %m1 : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
  %casted2 = memref.cast %m2 : memref<*xf32> to memref<?xf32, strided<[?], offset: ?>>
  memref.copy %casted1, %casted2 : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
  return
}

// CHECK-LABEL: func @copy_of_cast(
//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32>, %[[m2:.*]]: memref<*xf32>
//       CHECK:   %[[casted2:.*]] = memref.cast %[[m2]]
//       CHECK:   memref.copy %[[m1]], %[[casted2]]

// -----

func.func @self_copy(%m1: memref<?xf32>) {
  memref.copy %m1, %m1 : memref<?xf32> to memref<?xf32>
  return
}

// CHECK-LABEL: func @self_copy
//  CHECK-NEXT:   return

// -----

// CHECK-LABEL: func @empty_copy
//  CHECK-NEXT:   return
func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
  memref.copy %m1, %m2 : memref<0x10xf32> to memref<?x10xf32>
  memref.copy %m2, %m1 : memref<?x10xf32> to memref<0x10xf32>
  return
}

// -----

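// Test cases: alloca_scope canonicalizations. A scope whose parent is already
// an automatic allocation scope is inlined away, and an alloca whose operands
// are defined outside the scope may be hoisted out of it.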
func.func @scopeMerge() {
  memref.alloca_scope {
    %cnt = "test.count"() : () -> index
    %a = memref.alloca(%cnt) : memref<?xi64>
    "test.use"(%a) : (memref<?xi64>) -> ()
  }
  return
}
// CHECK:   func @scopeMerge() {
// CHECK-NOT: alloca_scope
// CHECK:     %[[cnt:.+]] = "test.count"() : () -> index
// CHECK:     %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
// CHECK:     "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
// CHECK:     return

func.func @scopeMerge2() {
  "test.region"() ({
    memref.alloca_scope {
      %cnt = "test.count"() : () -> index
      %a = memref.alloca(%cnt) : memref<?xi64>
      "test.use"(%a) : (memref<?xi64>) -> ()
    }
    "test.terminator"() : () -> ()
  }) : () -> ()
  return
}

// CHECK:   func @scopeMerge2() {
// CHECK:     "test.region"() ({
// CHECK:       memref.alloca_scope {
// CHECK:         %[[cnt:.+]] = "test.count"() : () -> index
// CHECK:         %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
// CHECK:         "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
// CHECK:       }
// CHECK:       "test.terminator"() : () -> ()
// CHECK:     }) : () -> ()
// CHECK:     return
// CHECK:   }

func.func @scopeMerge3() {
  %cnt = "test.count"() : () -> index
  "test.region"() ({
    memref.alloca_scope {
      %a = memref.alloca(%cnt) : memref<?xi64>
      "test.use"(%a) : (memref<?xi64>) -> ()
    }
    "test.terminator"() : () -> ()
  }) : () -> ()
  return
}

// CHECK:   func @scopeMerge3() {
// CHECK:     %[[cnt:.+]] = "test.count"() : () -> index
// CHECK:     %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
// CHECK:     "test.region"() ({
// CHECK:       memref.alloca_scope {
// CHECK:         "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
// CHECK:       }
// CHECK:       "test.terminator"() : () -> ()
// CHECK:     }) : () -> ()
// CHECK:     return
// CHECK:   }

func.func @scopeMerge4() {
  %cnt = "test.count"() : () -> index
  "test.region"() ({
    memref.alloca_scope {
      %a = memref.alloca(%cnt) : memref<?xi64>
      "test.use"(%a) : (memref<?xi64>) -> ()
    }
    "test.op"() : () -> ()
    "test.terminator"() : () -> ()
  }) : () -> ()
  return
}

// CHECK:   func @scopeMerge4() {
// CHECK:     %[[cnt:.+]] = "test.count"() : () -> index
// CHECK:     "test.region"() ({
// CHECK:       memref.alloca_scope {
// CHECK:         %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
// CHECK:         "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
// CHECK:       }
// CHECK:       "test.op"() : () -> ()
// CHECK:       "test.terminator"() : () -> ()
// CHECK:     }) : () -> ()
// CHECK:     return
// CHECK:   }

func.func @scopeMerge5() {
  "test.region"() ({
    memref.alloca_scope {
      affine.parallel (%arg) = (0) to (64) {
        %a = memref.alloca(%arg) : memref<?xi64>
        "test.use"(%a) : (memref<?xi64>) -> ()
      }
    }
    "test.op"() : () -> ()
    "test.terminator"() : () -> ()
  }) : () -> ()
  return
}

// CHECK:   func @scopeMerge5() {
// CHECK:     "test.region"() ({
// CHECK:       affine.parallel (%[[cnt:.+]]) = (0) to (64) {
// CHECK:         %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
// CHECK:         "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
// CHECK:       }
// CHECK:       "test.op"() : () -> ()
// CHECK:       "test.terminator"() : () -> ()
// CHECK:     }) : () -> ()
// CHECK:     return
// CHECK:   }

func.func @scopeInline(%arg : memref<index>) {
  %cnt = "test.count"() : () -> index
  "test.region"() ({
    memref.alloca_scope {
      memref.store %cnt, %arg[] : memref<index>
    }
    "test.terminator"() : () -> ()
  }) : () -> ()
  return
}

// CHECK:   func @scopeInline
// CHECK-NOT:  memref.alloca_scope

// -----

// CHECK-LABEL: func @reinterpret_noop
//  CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
//  CHECK-NEXT: return %[[ARG]]
func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [2, 3, 4], strides: [12, 4, 1] : memref<2x3x4xf32> to memref<2x3x4xf32>
  return %0 : memref<2x3x4xf32>
}

// -----

// CHECK-LABEL: func @reinterpret_of_reinterpret
//  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
//       CHECK: return %[[RES]]
func.func @reinterpret_of_reinterpret(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size1], strides: [1] : memref<?xi8> to memref<?xi8>
  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
  return %1 : memref<?xi8>
}

// -----

// CHECK-LABEL: func @reinterpret_of_cast
//  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE:.*]]: index)
//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE]]], strides: [1]
//       CHECK: return %[[RES]]
func.func @reinterpret_of_cast(%arg : memref<?xi8>, %size: index) -> memref<?xi8> {
  %0 = memref.cast %arg : memref<?xi8> to memref<5xi8>
  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size], strides: [1] : memref<5xi8> to memref<?xi8>
  return %1 : memref<?xi8>
}

// -----

// CHECK-LABEL: func @reinterpret_of_subview
//  CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
//       CHECK: return %[[RES]]
func.func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: index) -> memref<?xi8> {
  %0 = memref.subview %arg[0] [%size1] [1] : memref<?xi8> to memref<?xi8>
  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%size2], strides: [1] : memref<?xi8> to memref<?xi8>
  return %1 : memref<?xi8>
}

// -----

// Check that a reinterpret cast of an equivalent extract strided metadata
// is canonicalized to a plain cast when the destination type is different
// from the type of the original memref.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_type_mismatch
//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
//       CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_type_mismatch(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}

// -----

// Similar to reinterpret_of_extract_strided_metadata_w_type_mismatch, except
// that we check that the match happens even after the static information has
// been folded. E.g., in this case, we know that the size of dim 0 is 8 and
// the size of dim 1 is 2. So even if we don't use the values %sizes#0 and
// %sizes#1, the match is valid as long as the operands have the same constant
// values.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants
//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
//       CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  %c8 = arith.constant 8 : index
  %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}

// -----

// Check that a reinterpret cast of an equivalent extract strided metadata
// is completely removed when the original memref has the same type.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type
//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32
//       CHECK: return %[[ARG]]
func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}

// -----

// Check that we don't simplify reinterpret cast of extract strided metadata
// when the strides don't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
//       CHECK: return %[[RES]]
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  return %m2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
}

// -----

// Check that we don't simplify reinterpret cast of extract strided metadata
// when the offset doesn't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
//   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
//       CHECK: return %[[RES]]
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}

// -----

func.func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>,
    %arg1 : index) -> memref<?xf32, strided<[?], offset: ?>> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = memref.subview %arg0[%c0, %c0] [1, %arg1] [%c1, %c1] : memref<8x?xf32> to memref<?xf32, strided<[?], offset: ?>>
  return %0 : memref<?xf32, strided<[?], offset: ?>>
}
//      CHECK: func @canonicalize_rank_reduced_subview
// CHECK-SAME:     %[[ARG0:.+]]: memref<8x?xf32>
// CHECK-SAME:     %[[ARG1:.+]]: index
//      CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, 0] [1, %[[ARG1]]] [1, 1]
// CHECK-SAME:       memref<8x?xf32> to memref<?xf32, strided<[1]>>

// -----

// CHECK-LABEL: func @memref_realloc_dead
// CHECK-SAME: %[[SRC:[0-9a-z]+]]: memref<2xf32>
// CHECK-NOT: memref.realloc
// CHECK: return %[[SRC]]
func.func @memref_realloc_dead(%src : memref<2xf32>, %v : f32) -> memref<2xf32> {
  %0 = memref.realloc %src : memref<2xf32> to memref<4xf32>
  %i2 = arith.constant 2 : index
  memref.store %v, %0[%i2] : memref<4xf32>
  return %src : memref<2xf32>
}

// -----

// CHECK-LABEL: func @collapse_expand_fold_to_cast(
//  CHECK-SAME:     %[[m:.*]]: memref<?xf32, strided<[1]>, 3>
//       CHECK:   %[[casted:.*]] = memref.cast %[[m]] : memref<?xf32, strided<[1]>, 3> to memref<?xf32, 3
//       CHECK:   return %[[casted]]
func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0: index)
    -> (memref<?xf32, 3>)
{
  %0 = memref.expand_shape %m [[0, 1]] output_shape [1, %sz0]
      : memref<?xf32, strided<[1]>, 3> into memref<1x?xf32, 3>
  %1 = memref.collapse_shape %0 [[0, 1]]
      : memref<1x?xf32, 3> into memref<?xf32, 3>
  return %1 : memref<?xf32, 3>
}

// -----

// CHECK-LABEL: func @fold_trivial_subviews(
//  CHECK-SAME:     %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
//       CHECK:   %[[subview:.*]] = memref.subview %[[m]][5]
//       CHECK:   return %[[subview]]
func.func @fold_trivial_subviews(%m: memref<?xf32, strided<[?], offset: ?>>,
                                 %sz: index)
    -> memref<?xf32, strided<[?], offset: ?>>
{
  %0 = memref.subview %m[5] [%sz] [1]
      : memref<?xf32, strided<[?], offset: ?>>
        to memref<?xf32, strided<[?], offset: ?>>
  %1 = memref.subview %0[0] [%sz] [1]
      : memref<?xf32, strided<[?], offset: ?>>
        to memref<?xf32, strided<[?], offset: ?>>
  return %1 : memref<?xf32, strided<[?], offset: ?>>
}

// -----

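// Test case: The identity layout map is folded away from the memref types
// while the nontemporal attribute on the load/store is preserved.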
// CHECK-LABEL: func @load_store_nontemporal(
func.func @load_store_nontemporal(%input : memref<32xf32, affine_map<(d0) -> (d0)>>, %output : memref<32xf32, affine_map<(d0) -> (d0)>>) {
  %1 = arith.constant 7 : index
  // CHECK: memref.load %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32>
  %2 = memref.load %input[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
  // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32>
  memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
  func.return
}

// -----

// CHECK-LABEL: func @fold_trivial_memory_space_cast(
//  CHECK-SAME:     %[[arg:.*]]: memref<?xf32>
//       CHECK:   return %[[arg]]
func.func @fold_trivial_memory_space_cast(%arg : memref<?xf32>) -> memref<?xf32> {
  %0 = memref.memory_space_cast %arg : memref<?xf32> to memref<?xf32>
  return %0 : memref<?xf32>
}

// -----

// CHECK-LABEL: func @fold_multiple_memory_space_cast(
//  CHECK-SAME:     %[[arg:.*]]: memref<?xf32>
//       CHECK:   %[[res:.*]] = memref.memory_space_cast %[[arg]] : memref<?xf32> to memref<?xf32, 2>
//       CHECK:   return %[[res]]
func.func @fold_multiple_memory_space_cast(%arg : memref<?xf32>) -> memref<?xf32, 2> {
  %0 = memref.memory_space_cast %arg : memref<?xf32> to memref<?xf32, 1>
  %1 = memref.memory_space_cast %0 : memref<?xf32, 1> to memref<?xf32, 2>
  return %1 : memref<?xf32, 2>
}

// -----

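// Test case: The non-negative constant sizes are folded into the alloc's
// static type; the negative size stays a dynamic operand rather than
// producing an invalid static shape.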
// CHECK-LABEL: func private @ub_negative_alloc_size
func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
  %idx1 = index.constant 1
  %c-2 = arith.constant -2 : index
  %c15 = arith.constant 15 : index
  // CHECK: %[[ALLOC:.*]] = memref.alloc(%c-2) : memref<15x?x1xi1>
  %alloc = memref.alloc(%c15, %c-2, %idx1) : memref<?x?x?xi1>
  return %alloc : memref<?x?x?xi1>
}

// -----

// CHECK-LABEL: func @subview_rank_reduction(
//  CHECK-SAME:     %[[arg0:.*]]: memref<1x384x384xf32>, %[[arg1:.*]]: index
func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
    -> memref<?x?xf32, strided<[384, 1], offset: ?>> {
  %c1 = arith.constant 1 : index
  // CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
  // CHECK: %[[cast:.*]] = memref.cast %[[subview]] : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?x?xf32, strided<[384, 1], offset: ?>>
  %0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
      : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
  // CHECK: return %[[cast]]
  return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
}

// -----

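// Test case: Two back-to-back transposes are composed into a single transpose
// with the combined permutation.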
// CHECK-LABEL: func @fold_double_transpose(
//  CHECK-SAME:     %[[arg0:.*]]: memref<1x2x3x4x5xf32>
func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
  // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
  %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
  // CHECK: return %[[ONETRANSPOSE]]
  return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
}

// -----

// CHECK-LABEL: func @fold_double_transpose2(
//  CHECK-SAME:     %[[arg0:.*]]: memref<1x2x3x4x5xf32>
func.func @fold_double_transpose2(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
  // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d0, d1, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>>
  %1 = memref.transpose %0 (d0, d1, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
  // CHECK: return %[[ONETRANSPOSE]]
  return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
}

// -----

// CHECK-LABEL: func @fold_identity_transpose(
//  CHECK-SAME:     %[[arg0:.*]]: memref<1x2x3x4x5xf32>
func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> {
  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
  %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d0, d1, d2, d3, d4) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<1x2x3x4x5xf32>
  // CHECK: return %[[arg0]]
  return %1 : memref<1x2x3x4x5xf32>
}

// -----

#transpose_map = affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>

// CHECK-LABEL: func @cannot_fold_transpose_cast(
//  CHECK-SAME:     %[[arg0:.*]]: memref<?x4xf32>
func.func @cannot_fold_transpose_cast(%arg0: memref<?x4xf32>) -> memref<?x?xf32, #transpose_map> {
  // CHECK: %[[CAST:.*]] = memref.cast %[[arg0]] : memref<?x4xf32> to memref<?x?xf32>
  %cast = memref.cast %arg0 : memref<?x4xf32> to memref<?x?xf32>
  // CHECK: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #{{.*}}>
  %transpose = memref.transpose %cast (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #transpose_map>
  // CHECK: return %[[TRANSPOSE]]
  return %transpose : memref<?x?xf32, #transpose_map>
}