// RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s | FileCheck %s
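// The -fold-memref-alias-ops pass folds memref alias ops (memref.subview,
// memref.expand_shape, memref.collapse_shape) into the load/store-style ops
// that consume their results, rewriting the access indices with affine.apply
// where an offset, stride, or (de)linearization must be applied.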

func.func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
  %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
  return %1 : f32
}
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
//      CHECK: func @fold_static_stride_subview_with_load
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG3]]]
//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG4]]]
//      CHECK:   memref.load %[[ARG0]][%[[I1]], %[[I2]]]
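// With offsets (%arg1, %arg2) and strides (2, 3), the subview access at
// (%arg3, %arg4) folds to the source access
// (%arg1 + 2 * %arg3, %arg2 + 3 * %arg4), i.e. MAP0 and MAP1 above.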

// -----

func.func @fold_dynamic_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> f32 {
  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
    memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
  %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[?, ?], offset: ?>>
  return %1 : f32
}
//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)>
//      CHECK: func @fold_dynamic_stride_subview_with_load
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: index
//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[ARG5]]]
//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[ARG6]]]
//      CHECK:   memref.load %[[ARG0]][%[[I1]], %[[I2]]]
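// With dynamic strides, a single map s0 + s1 * s2 (offset + index * stride)
// serves both dimensions.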

// -----

func.func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
    memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
  memref.store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
  return
}
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
//      CHECK: func @fold_static_stride_subview_with_store
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG3]]]
//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG4]]]
//      CHECK:   memref.store %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]]

// -----

func.func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : f32) {
  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
    memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
  memref.store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, strided<[?, ?], offset: ?>>
  return
}
//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)>
//      CHECK: func @fold_dynamic_stride_subview_with_store
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: index
//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[ARG5]]]
//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[ARG6]]]
//      CHECK:   memref.store %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]]

// -----

func.func @fold_subview_with_transfer_read_0d(
  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index)
    -> vector<f32> {
  %f1 = arith.constant 1.0 : f32
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  %1 = vector.transfer_read %0[], %f1 : memref<f32, strided<[], offset: ?>>, vector<f32>
  return %1 : vector<f32>
}
//      CHECK: func @fold_subview_with_transfer_read_0d
// CHECK-SAME:   %[[MEM:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME:   %[[SZ0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[SZ1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ST1:[a-zA-Z0-9_]+]]: index
//      CHECK:   vector.transfer_read %[[MEM]][%[[SZ0]], %[[SZ1]]]

// -----

func.func @fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> vector<4xf32> {
  %f1 = arith.constant 1.0 : f32

  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
  %1 = vector.transfer_read %0[%arg3, %arg4], %f1 {in_bounds = [true]} : memref<4x4xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
  return %1 : vector<4xf32>
}
//      CHECK: func @fold_subview_with_transfer_read
// Can't fold this at the moment since we don't emit the proper vector.extract_strided_slice.
//   CHECK: memref.subview

// -----

func.func @fold_static_stride_subview_with_transfer_write_0d(
    %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index,
    %v : vector<f32>) {
  %f1 = arith.constant 1.0 : f32
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  vector.transfer_write %v, %0[] {in_bounds = []} : vector<f32>, memref<f32, strided<[], offset: ?>>
  return
}
//      CHECK: func @fold_static_stride_subview_with_transfer_write_0d
// CHECK-SAME:   %[[MEM:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME:   %[[SZ0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[SZ1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ST1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[V:[a-zA-Z0-9_]+]]: vector<f32>
//      CHECK:   vector.transfer_write %[[V]], %[[MEM]][%[[SZ0]], %[[SZ1]]]

// -----

func.func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5: index, %arg6 : index, %arg7 : vector<4xf32>) {
  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
    memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
  vector.transfer_write %arg7, %0[%arg3, %arg4] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>
  return
}
//      CHECK: func @fold_static_stride_subview_with_transfer_write
// Can't fold this at the moment since we don't emit the proper vector.extract_strided_slice.
//   CHECK: memref.subview

// -----

func.func @fold_rank_reducing_subview_with_load
    (%arg0 : memref<?x?x?x?x?x?xf32>, %arg1 : index, %arg2 : index,
     %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
     %arg7 : index, %arg8 : index, %arg9 : index, %arg10: index,
     %arg11 : index, %arg12 : index, %arg13 : index, %arg14: index,
     %arg15 : index, %arg16 : index) -> f32 {
  %0 = memref.subview %arg0[%arg1, %arg2, %arg3, %arg4, %arg5, %arg6][4, 1, 1, 4, 1, 1][%arg7, %arg8, %arg9, %arg10, %arg11, %arg12] : memref<?x?x?x?x?x?xf32> to memref<4x1x4x1xf32, strided<[?, ?, ?, ?], offset: ?>>
  %1 = memref.load %0[%arg13, %arg14, %arg15, %arg16] : memref<4x1x4x1xf32, strided<[?, ?, ?, ?], offset: ?>>
  return %1 : f32
}
//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)>
//      CHECK: func @fold_rank_reducing_subview_with_load
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?x?x?x?x?xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG7:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG8:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG9:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG10:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG11:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG12:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG13:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG14:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG15:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG16:[a-zA-Z0-9_]+]]: index
//  CHECK-DAG:   %[[I0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG13]], %[[ARG7]]]
//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG14]], %[[ARG9]]]
//  CHECK-DAG:   %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG15]], %[[ARG10]]]
//  CHECK-DAG:   %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG16]], %[[ARG11]]]
//      CHECK:   memref.load %[[ARG0]][%[[I0]], %[[ARG2]], %[[I2]], %[[I3]], %[[I4]], %[[ARG6]]]
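// The rank-reduced unit dims (positions 1 and 5) carry no load index, so
// their subview offsets %arg2 and %arg6 are used directly; only the four
// remaining dims need an affine.apply.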

// -----

func.func @fold_vector_transfer_read_with_rank_reduced_subview(
    %arg0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>,
    %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
    %arg6 : index) -> vector<4xf32> {
  %cst = arith.constant 0.0 : f32
  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %arg3, %arg4] [1, 1, 1]
      : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to
        memref<?x?xf32, strided<[?, ?], offset: ?>>
  %1 = vector.transfer_read %0[%arg5, %arg6], %cst {in_bounds = [true]}
      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
  return %1 : vector<4xf32>
}
//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//       CHECK: func @fold_vector_transfer_read_with_rank_reduced_subview
//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
//   CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
//       CHECK:    vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]], %{{.*}} : memref<?x?x?xf32

// -----

func.func @fold_vector_transfer_write_with_rank_reduced_subview(
    %arg0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>,
    %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
    %arg5: index, %arg6 : index, %arg7 : index) {
  %cst = arith.constant 0.0 : f32
  %0 = memref.subview %arg0[0, %arg2, %arg3] [1, %arg4, %arg5] [1, 1, 1]
      : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to
        memref<?x?xf32, strided<[?, ?], offset: ?>>
  vector.transfer_write %arg1, %0[%arg6, %arg7] {in_bounds = [true]}
      : vector<4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
  return
}
//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//       CHECK: func @fold_vector_transfer_write_with_rank_reduced_subview
//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
//   CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
//   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]] {in_bounds = [true]} : vector<4xf32>, memref<?x?x?xf32

// -----

func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview(
    %arg0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>,
    %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
    %arg5: index, %arg6 : index, %arg7 : index) {
  %cst = arith.constant 0.0 : f32
  %0 = memref.subview %arg0[%arg2, %arg3, 0] [%arg4, %arg5, 1] [1, 1, 1]
      : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to
        memref<?x?xf32, strided<[?, ?], offset: ?>>
  vector.transfer_write %arg1, %0[%arg6, %arg7] {in_bounds = [true]}
      : vector<4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
  return
}
//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
//       CHECK: func @fold_vector_transfer_write_with_inner_rank_reduced_subview
//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
//   CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
//   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]], %[[C0]]]
//  CHECK-SAME:    {in_bounds = [true], permutation_map = #[[MAP2]]} : vector<4xf32>, memref<?x?x?xf32

// -----

func.func @fold_masked_vector_transfer_read_with_subview(
    %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
    %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
    %arg6 : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cst = arith.constant 0.0 : f32
  %0 = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1]
      : memref<?x?xf32, strided<[?, ?], offset: ?>> to
        memref<?x?xf32, strided<[?, ?], offset: ?>>
  %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {in_bounds = [true]}
      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
  return %1 : vector<4xf32>
}
//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//       CHECK: func @fold_masked_vector_transfer_read_with_subview
//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1>
//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
//       CHECK:    vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[MASK]] {{.*}} : memref<?x?xf32

// -----

func.func @fold_masked_vector_transfer_read_with_rank_reducing_subview(
    %arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
    %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
    %arg6 : index, %mask : vector<4x3xi1>) -> vector<3x4xf32> {
  %cst = arith.constant 0.0 : f32
  %0 = memref.subview %arg0[0, %arg1, 0, %arg2] [1, %arg3, 1, %arg4] [1, 1, 1, 1]
      : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
        memref<?x?xf32, strided<[?, ?], offset: ?>>
  %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {
         permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<3x4xf32>
  return %1 : vector<3x4xf32>
}
//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)>
//       CHECK: func @fold_masked_vector_transfer_read_with_rank_reducing_subview
//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1>
//   CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:    %[[PAD:.+]] = arith.constant 0.000000e+00 : f32
//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG5]]]
//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]]
//       CHECK:    vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[PAD]], %[[MASK]] {{.*}} permutation_map = #[[MAP1]]} : memref<?x?x?x?xf32
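// Folding the rank-reducing subview also remaps the permutation map from the
// 2-D subview space into the 4-D source space: (d0, d1) -> (d1, d0) becomes
// (d0, d1, d2, d3) -> (d3, d1).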

// -----

func.func @fold_masked_vector_transfer_write_with_subview(
    %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
    %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
    %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4xi1>) {
  %cst = arith.constant 0.0 : f32
  %0 = memref.subview %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
      : memref<?x?xf32, strided<[?, ?], offset: ?>> to
        memref<?x?xf32, strided<[?, ?], offset: ?>>
  vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {in_bounds = [true]}
      : vector<4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
  return
}
//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//       CHECK: func @fold_masked_vector_transfer_write_with_subview
//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]: vector<4xi1>
//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
//   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[MASK]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32

// -----

func.func @fold_masked_vector_transfer_write_with_rank_reducing_subview(
    %arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
    %arg1 : vector<3x4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
    %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4x3xi1>) {
  %cst = arith.constant 0.0 : f32
  %0 = memref.subview %arg0[0, %arg2, 0, %arg3] [1, %arg4, 1, %arg5] [1, 1, 1, 1]
      : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
        memref<?x?xf32, strided<[?, ?], offset: ?>>
  vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {
        permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
      : vector<3x4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
  return
}
//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d1)>
//       CHECK: func @fold_masked_vector_transfer_write_with_rank_reducing_subview
//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<3x4xf32>
//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
//  CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]: vector<4x3xi1>
//   CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG6]]]
//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP0]]()[%[[ARG3]], %[[ARG7]]]
//   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[C0]], %[[IDX1]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[MAP1]]} : vector<3x4xf32>, memref<?x?x?x?xf32

// -----

//  Test with affine.load/store ops. We only do a basic test here since the
//  logic is identical to that with memref.load/store ops. The same affine.apply
//  ops would be generated.

// CHECK-LABEL: func @fold_static_stride_subview_with_affine_load_store
func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
  %1 = affine.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
  // CHECK-NEXT: affine.apply
  // CHECK-NEXT: affine.apply
  // CHECK-NEXT: affine.load
  affine.store %1, %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
  // CHECK-NEXT: affine.apply
  // CHECK-NEXT: affine.apply
  // CHECK-NEXT: affine.store
  // CHECK-NEXT: return
  return %1 : f32
}

// -----

// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 6 + s1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 6, 32] : memref<12x32xf32> into memref<2x6x32xf32>
  %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
  return %1 : f32
}
// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
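// Folding the expand_shape linearizes the indices over the expanded group
// [2, 6]: (%arg1, %arg2) becomes %arg1 * 6 + %arg2 in the source.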

// -----

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<2x6x32xf32> into memref<12x32xf32>
  %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
  return %1 : f32
}
// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
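// The collapse_shape direction delinearizes instead: an index into the
// collapsed dim of size 12 splits into (%arg1 floordiv 6, %arg1 mod 6).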

// -----

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x6x32xf32> into memref<?x32xf32>
  %1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
  return %1 : f32
}
// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32

// -----

// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
  %0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [2, 2, 3, 32] : memref<12x32xf32> into memref<2x2x3x32xf32>
  %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
  return %1 : f32
}
// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]]]
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32

// -----

// CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32
func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
  %c0 = arith.constant 0 : index
  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
  %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
  return %0 : f32
}
// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return %[[VAL1]] : f32

// -----

// CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0 : index) {
  %c0 = arith.constant 0 : index
  %c1f32 = arith.constant 1.0 : f32
  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
  return
}
// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return

// -----

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
  %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
  %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [%sz0, 1, 8, 2] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
  %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>

  affine.for %arg6 = 0 to %dim step 64 {
    affine.for %arg7 = 0 to 16 step 16 {
      %dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
      affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
    }
  }
  return
}
// CHECK-NEXT:   memref.subview
// CHECK-NEXT:   %[[EXPAND_SHAPE:.*]] = memref.expand_shape
// CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
// CHECK-NEXT:   affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
// CHECK-NEXT:   %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
// CHECK-NEXT:   %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
// CHECK-NEXT:   %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
// CHECK-NEXT:   %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
// CHECK-NEXT:   affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>

// -----

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
  affine.for %arg3 = 0 to 1 {
    affine.for %arg4 = 0 to 1024 {
      affine.for %arg5 = 0 to 1020 {
        affine.for %arg6 = 0 to 1 {
          %1 = affine.load %0[%arg3, %arg4, %arg5, %arg6] : memref<1x1024x1024x1xf32>
          affine.store %1, %arg1[%arg2] : memref<1xf32>
        }
      }
    }
  }
  %2 = affine.load %arg1[%arg2] : memref<1xf32>
  return %2 : f32
}
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
// CHECK-NEXT:  affine.for %[[ARG4:.*]] = 0 to 1024 {
// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 1020 {
// CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
// CHECK-NEXT:     %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]], %[[ARG4]]]
// CHECK-NEXT:     %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
// CHECK-NEXT:     affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>

// -----

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
  affine.for %arg3 = 0 to 1 {
    affine.for %arg4 = 0 to 1024 {
      affine.for %arg5 = 0 to 1020 {
        affine.for %arg6 = 0 to 1 {
          %1 = affine.load %0[%arg3, %arg4 + %arg3, %arg5, %arg6] : memref<1x1024x1024x1xf32>
          affine.store %1, %arg1[%arg2] : memref<1xf32>
        }
      }
    }
  }
  %2 = affine.load %arg1[%arg2] : memref<1xf32>
  return %2 : f32
}
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
// CHECK-NEXT:  affine.for %[[ARG4:.*]] = 0 to 1024 {
// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 1020 {
// CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
// CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]]
// CHECK-NEXT:      %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
// CHECK-NEXT:      affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32>

// -----

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
  %cst = arith.constant 0 : index
  affine.for %arg3 = 0 to 1 {
    affine.for %arg4 = 0 to 1024 {
      affine.for %arg5 = 0 to 1020 {
        affine.for %arg6 = 0 to 1 {
          %1 = memref.load %0[%arg3, %cst, %arg5, %arg6] : memref<1x1024x1024x1xf32>
          memref.store %1, %arg1[%arg2] : memref<1xf32>
        }
      }
    }
  }
  %2 = memref.load %arg1[%arg2] : memref<1xf32>
  return %2 : f32
}
// CHECK-NEXT:   affine.for %[[ARG3:.*]] = 0 to 1 {
// CHECK-NEXT:   affine.for %[[ARG4:.*]] = 0 to 1024 {
// CHECK-NEXT:    affine.for %[[ARG5:.*]] = 0 to 1020 {
// CHECK-NEXT:     affine.for %[[ARG6:.*]] = 0 to 1 {
// CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]]]
// CHECK-NEXT:      %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
// CHECK-NEXT:      memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>

// -----

// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result
// CHECK-SAME: (%[[ARG0:.*]]: memref<1xf32>, %[[ARG1:.*]]: memref<1xf32>)
func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> memref<1xf32> {
  %0 = memref.collapse_shape %arg0 [] : memref<1xf32> into memref<f32>
  affine.for %arg2 = 0 to 3 {
    %1 = affine.load %0[] : memref<f32>
    affine.store %1, %arg1[0] : memref<1xf32>
  }
  return %arg1 : memref<1xf32>
}
// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-NEXT: affine.for %{{.*}} = 0 to 3 {
// CHECK-NEXT:   affine.load %[[ARG0]][%[[ZERO]]] : memref<1xf32>

// -----

//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)>
// CHECK-LABEL: func @subview_of_subview(
//  CHECK-SAME:     %[[m:.*]]: memref<1x1024xf32, 3>, %[[pos:.*]]: index
//       CHECK:   %[[add:.*]] = affine.apply #[[$map]]()[%arg1]
//       CHECK:   memref.subview %arg0[4, %[[add]]] [1, 1] [1, 1] : memref<1x1024xf32, 3> to memref<f32, strided<[], offset: ?>, 3>
func.func @subview_of_subview(%m: memref<1x1024xf32, 3>, %pos: index)
    -> memref<f32, strided<[], offset: ?>, 3>
{
  %0 = memref.subview %m[3, %pos] [1, 2] [1, 1]
      : memref<1x1024xf32, 3>
        to memref<1x2xf32, strided<[1024, 1], offset: ?>, 3>
  %1 = memref.subview %0[1, 2] [1, 1] [1, 1]
      : memref<1x2xf32, strided<[1024, 1], offset: ?>, 3>
        to memref<f32, strided<[], offset: ?>, 3>
  return %1 : memref<f32, strided<[], offset: ?>, 3>
}
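// Subview composition adds offsets: the inner offsets (1, 2) are added to
// the outer offsets (3, %pos), so the folded subview starts at (4, %pos + 2).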

// -----

// CHECK-LABEL: func @subview_of_subview_rank_reducing(
//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>
//       CHECK:   memref.subview %arg0[3, 7, 8] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32> to memref<f32, strided<[], offset: ?>>
func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
                                            %sz: index, %pos: index)
    -> memref<f32, strided<[], offset: ?>>
{
  %0 = memref.subview %m[3, 1, 8] [1, %sz, 1] [1, 1, 1]
      : memref<?x?x?xf32>
        to memref<?xf32, strided<[?], offset: ?>>
  %1 = memref.subview %0[6] [1] [1]
      : memref<?xf32, strided<[?], offset: ?>>
        to memref<f32, strided<[], offset: ?>>
  return %1 : memref<f32, strided<[], offset: ?>>
}

// -----

// CHECK-LABEL: func @fold_load_keep_nontemporal(
//      CHECK:   memref.load %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true}
func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
  %1 = memref.load %0[%arg3, %arg4] {nontemporal = true} : memref<4x4xf32, strided<[64, 3], offset: ?>>
  return %1 : f32
}

// -----

// CHECK-LABEL: func @fold_store_keep_nontemporal(
//      CHECK:   memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
    memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
  memref.store %arg5, %0[%arg3, %arg4] {nontemporal = true} : memref<4x4xf32, strided<[64, 3], offset: ?>>
  return
}

// -----

func.func @fold_gpu_subgroup_mma_load_matrix_1d(%src: memref<?xvector<4xf32>>, %offset: index, %i: index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
  %subview = memref.subview %src[%offset] [81920] [1] : memref<?xvector<4xf32>> to memref<81920xvector<4xf32>, strided<[1], offset: ?>>
  %matrix = gpu.subgroup_mma_load_matrix %subview[%i] {leadDimension = 160 : index} : memref<81920xvector<4xf32>, strided<[1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
  return %matrix: !gpu.mma_matrix<16x16xf16, "COp">
}

//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//      CHECK: func.func @fold_gpu_subgroup_mma_load_matrix_1d
// CHECK-SAME: (%[[SRC:.+]]: memref<?xvector<4xf32>>, %[[OFFSET:.+]]: index, %[[I:.+]]: index)
//      CHECK:   %[[APPLY:.+]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[I]]]
//      CHECK:   %[[LOAD:.+]] = gpu.subgroup_mma_load_matrix %[[SRC]][%[[APPLY]]] {leadDimension = 160 : index} : memref<?xvector<4xf32>> -> !gpu.mma_matrix<16x16xf16, "COp">
//      CHECK:   return %[[LOAD]]

// -----

func.func @fold_gpu_subgroup_mma_store_matrix_1d(%dst: memref<?xvector<4xf32>>, %offset: index, %i: index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) {
  %subview = memref.subview %dst[%offset] [81920] [1] : memref<?xvector<4xf32>> to memref<81920xvector<4xf32>, strided<[1], offset: ?>>
  gpu.subgroup_mma_store_matrix %matrix, %subview[%i] {leadDimension = 160 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<81920xvector<4xf32>, strided<[1], offset: ?>>
  return
}

//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//      CHECK: func.func @fold_gpu_subgroup_mma_store_matrix_1d
// CHECK-SAME: (%[[DST:.+]]: memref<?xvector<4xf32>>, %[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[VAL:.+]]: !gpu.mma_matrix<16x16xf16, "COp">)
//      CHECK:   %[[APPLY:.+]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[I0]]]
//      CHECK:   gpu.subgroup_mma_store_matrix %[[VAL]], %[[DST]][%[[APPLY]]] {leadDimension = 160 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<?xvector<4xf32>>

// -----

// CHECK-LABEL: func.func @fold_gpu_subgroup_mma_load_matrix_2d
//  CHECK-SAME: %[[SRC:.+]]: memref<128x128xf32>
func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
  // CHECK: gpu.subgroup_mma_load_matrix %[[SRC]][{{.+}}] {leadDimension = 32 : index} : memref<128x128xf32> -> !gpu.mma_matrix<16x16xf16, "COp">
  %matrix = gpu.subgroup_mma_load_matrix %subview[%arg3, %arg4] {leadDimension = 32 : index} : memref<64x32xf32, strided<[256, 1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
  return %matrix : !gpu.mma_matrix<16x16xf16, "COp">
}

// -----

// CHECK-LABEL: func.func @fold_gpu_subgroup_mma_store_matrix_2d
//  CHECK-SAME: %[[DST:.+]]: memref<128x128xf32>
func.func @fold_gpu_subgroup_mma_store_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) {
  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
  // CHECK: gpu.subgroup_mma_store_matrix %{{.+}}, %[[DST]][{{.+}}] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf32>
  gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[256, 1], offset: ?>>
  return
}

// -----

func.func @fold_nvgpu_device_async_copy_zero_sub_idx(%gmem_memref_3d : memref<2x128x768xf16>, %idx_1 : index, %idx_2 : index, %idx_3 : index) {
  %c0 = arith.constant 0 : index
  %smem_memref_4d = memref.alloc() : memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
  %gmem_memref_subview_2d = memref.subview %gmem_memref_3d[%idx_1, %idx_2, %idx_3] [1, 1, 8] [1, 1, 1] : memref<2x128x768xf16> to memref<1x8xf16, strided<[98304, 1], offset: ?>>
  %async_token = nvgpu.device_async_copy %gmem_memref_subview_2d[%c0, %c0], %smem_memref_4d[%c0, %c0, %c0, %c0], 8 {bypassL1} : memref<1x8xf16, strided<[98304, 1], offset: ?>> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
  return
}

// CHECK-LABEL: func.func @fold_nvgpu_device_async_copy_zero_sub_idx
//  CHECK-SAME: (%[[GMEM_MEMREF_3d:.+]]: memref<2x128x768xf16>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index, %[[IDX_3:.+]]: index)
//   CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
//   CHECK-DAG: %[[SMEM_MEMREF_4d:.+]] = memref.alloc() : memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
//       CHECK: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[IDX_1]], %[[IDX_2]], %[[IDX_3]]], %[[SMEM_MEMREF_4d]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
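// All subview-relative indices are zero here, so the copy uses the subview's
// own offsets (%idx_1, %idx_2, %idx_3) unchanged.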

// -----

func.func @fold_src_nvgpu_device_async_copy(%gmem_memref_3d : memref<2x128x768xf16>, %src_idx_0 : index, %src_idx_1 : index, %src_idx_2 : index, %src_sub_idx_0 : index, %src_sub_idx_1 : index) {
  %c0 = arith.constant 0 : index
  %smem_memref_4d = memref.alloc() : memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
  %gmem_memref_subview_2d = memref.subview %gmem_memref_3d[%src_idx_0, %src_idx_1, %src_idx_2] [1, 1, 8] [1, 1, 1] : memref<2x128x768xf16> to memref<1x8xf16, strided<[98304, 1], offset: ?>>
  %async_token = nvgpu.device_async_copy %gmem_memref_subview_2d[%src_sub_idx_0, %src_sub_idx_1], %smem_memref_4d[%c0, %c0, %c0, %c0], 8 {bypassL1} : memref<1x8xf16, strided<[98304, 1], offset: ?>> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
  return
}

//   CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//       CHECK: func.func @fold_src_nvgpu_device_async_copy
//  CHECK-SAME: (%[[GMEM_MEMREF_3d:.+]]: memref<2x128x768xf16>, %[[SRC_IDX_0:.+]]: index, %[[SRC_IDX_1:.+]]: index, %[[SRC_IDX_2:.+]]: index, %[[SRC_SUB_IDX_0:.+]]: index, %[[SRC_SUB_IDX_1:.+]]: index)
//   CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
//   CHECK-DAG: %[[SMEM_MEMREF_4d:.+]] = memref.alloc() : memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
//   CHECK-DAG: %[[RESOLVED_SRC_IDX_0:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_0]], %[[SRC_SUB_IDX_0]]]
//   CHECK-DAG: %[[RESOLVED_SRC_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_2]], %[[SRC_SUB_IDX_1]]]
//   CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[RESOLVED_SRC_IDX_0]], %[[SRC_IDX_1]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>

// -----

func.func @fold_src_fold_dest_nvgpu_device_async_copy(%gmem_memref_3d : memref<2x128x768xf16>, %src_idx_0 : index, %src_idx_1 : index, %src_idx_2 : index, %src_sub_idx_0 : index, %src_sub_idx_1 : index, %dest_idx_0 : index, %dest_idx_1 : index, %dest_idx_2 : index, %dest_idx_3 : index, %dest_sub_idx_0 : index, %dest_sub_idx_1 : index) {
  %c0 = arith.constant 0 : index
  %smem_memref_4d = memref.alloc() : memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
  %gmem_memref_subview_2d = memref.subview %gmem_memref_3d[%src_idx_0, %src_idx_1, %src_idx_2] [1, 1, 8] [1, 1, 1] : memref<2x128x768xf16> to memref<1x8xf16, strided<[98304, 1], offset: ?>>
  %smem_memref_2d = memref.subview %smem_memref_4d[%dest_idx_0, %dest_idx_1, %dest_idx_2, %dest_idx_3] [1, 1, 1, 8] [1, 1, 1, 1] : memref<5x1x64x64xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[4096, 1], offset: ?>, #gpu.address_space<workgroup>>
  %async_token = nvgpu.device_async_copy %gmem_memref_subview_2d[%src_sub_idx_0, %src_sub_idx_1], %smem_memref_2d[%dest_sub_idx_0, %dest_sub_idx_1], 8 {bypassL1} : memref<1x8xf16, strided<[98304, 1], offset: ?>> to memref<1x8xf16, strided<[4096, 1], offset: ?>, #gpu.address_space<workgroup>>
  return
}

//   CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//       CHECK: func.func @fold_src_fold_dest_nvgpu_device_async_copy
//  CHECK-SAME: (%[[GMEM_MEMREF_3d:.+]]: memref<2x128x768xf16>, %[[SRC_IDX_0:.+]]: index, %[[SRC_IDX_1:.+]]: index, %[[SRC_IDX_2:.+]]: index, %[[SRC_SUB_IDX_0:.+]]: index, %[[SRC_SUB_IDX_1:.+]]: index, %[[DEST_IDX_0:.+]]: index, %[[DEST_IDX_1:.+]]: index, %[[DEST_IDX_2:.+]]: index, %[[DEST_IDX_3:.+]]: index, %[[DEST_SUB_IDX_0:.+]]: index, %[[DEST_SUB_IDX_1:.+]]: index)
//   CHECK-DAG: %[[SMEM_MEMREF_4d:.+]] = memref.alloc() : memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
//   CHECK-DAG: %[[RESOLVED_SRC_IDX_0:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_0]], %[[SRC_SUB_IDX_0]]]
//   CHECK-DAG: %[[RESOLVED_SRC_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_2]], %[[SRC_SUB_IDX_1]]]
//   CHECK-DAG: %[[RESOLVED_DST_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_1]], %[[DEST_SUB_IDX_0]]]
//   CHECK-DAG: %[[RESOLVED_DST_IDX_3:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_3]], %[[DEST_SUB_IDX_1]]]
//   CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[RESOLVED_SRC_IDX_0]], %[[SRC_IDX_1]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[DEST_IDX_0]], %[[RESOLVED_DST_IDX_1]], %[[DEST_IDX_2]], %[[RESOLVED_DST_IDX_3]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>

// -----

#map = affine_map<()[s0] -> (-s0 + 4)>
#map1 = affine_map<()[s0] -> (-s0 + 32)>

func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: index, %arg3: index) -> vector<4x2xf16> {
  %c0 = arith.constant 0 : index
  %0 = affine.apply #map()[%arg1]
  %1 = affine.apply #map1()[%arg2]
  %2 = affine.apply #map1()[%arg3]
  %subview = memref.subview %arg0[%arg1, %arg2, %arg3] [%0, %1, %2] [1, 1, 1] : memref<4x32x32xf16, 3> to memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3>
  %3 = nvgpu.ldmatrix %subview[%c0, %c0, %c0] {numTiles = 4 : i32, transpose = false} : memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3> -> vector<4x2xf16>
  return %3 : vector<4x2xf16>
}

//      CHECK: func @test_ldmatrix
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x32x32xf16, 3>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
//      CHECK:   nvgpu.ldmatrix %[[ARG0]][%[[ARG1]], %[[ARG2]], %[[ARG3]]] {numTiles = 4 : i32, transpose = false} : memref<4x32x32xf16, 3> -> vector<4x2xf16>

// -----

func.func @fold_vector_load_subview(
  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
  return %1 : vector<12x32xf32>
}

//      CHECK: func @fold_vector_load_subview
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
//      CHECK:   vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>

// -----

func.func @fold_vector_maskedload_subview(
  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
  return %1 : vector<32xf32>
}

//      CHECK: func @fold_vector_maskedload_subview
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
//      CHECK:   vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>

// -----

func.func @fold_vector_store_subview(
  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
  return
}

//      CHECK: func @fold_vector_store_subview
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<2x32xf32>
//      CHECK:   vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<2x32xf32>
//      CHECK:   return

// -----

func.func @fold_vector_maskedstore_subview(
  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
  return
}

//      CHECK: func @fold_vector_maskedstore_subview
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
//      CHECK:   vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
//      CHECK:   return

// -----

func.func @fold_vector_load_expand_shape(
  %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
  %1 = vector.load %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
  return %1 : vector<8xf32>
}

//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-LABEL: func @fold_vector_load_expand_shape
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
//       CHECK:   vector.load %[[ARG0]][%[[IDX]]] {nontemporal = true}
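// The inner index is the constant 0, so the linearized access over the
// [4, 8] expansion reduces to %arg1 * 8.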

// -----

func.func @fold_vector_maskedload_expand_shape(
  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
  %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
  return %1 : vector<8xf32>
}

//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-LABEL: func @fold_vector_maskedload_expand_shape
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
//       CHECK:   vector.maskedload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]

// -----

func.func @fold_vector_store_expand_shape(
  %arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
  %c0 = arith.constant 0 : index
  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
  vector.store %val, %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
  return
}

//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-LABEL: func @fold_vector_store_expand_shape
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
//       CHECK:   vector.store %{{.*}}, %[[ARG0]][%[[IDX]]] {nontemporal = true}

// -----

func.func @fold_vector_maskedstore_expand_shape(
  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
  %c0 = arith.constant 0 : index
  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
  vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
  return
}

//   CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-LABEL: func @fold_vector_maskedstore_expand_shape
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
//       CHECK:   %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
//       CHECK:   vector.maskedstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]

// -----

func.func @fold_vector_load_collapse_shape(
  %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
  %1 = vector.load %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
  return %1 : vector<8xf32>
}

//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
// CHECK-LABEL: func @fold_vector_load_collapse_shape
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
//       CHECK:   vector.load %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}

// -----

func.func @fold_vector_maskedload_collapse_shape(
  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
  %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
  return %1 : vector<8xf32>
}

//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
// CHECK-LABEL: func @fold_vector_maskedload_collapse_shape
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
//       CHECK:   vector.maskedload %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]

// -----

func.func @fold_vector_store_collapse_shape(
  %arg0 : memref<4x8xf32>, %arg1 : index, %val : vector<8xf32>) {
  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
  vector.store %val, %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
  return
}

//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
// CHECK-LABEL: func @fold_vector_store_collapse_shape
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
//       CHECK:   vector.store %{{.*}}, %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}

// -----

func.func @fold_vector_maskedstore_collapse_shape(
  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
  vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
  return
}

//   CHECK-DAG: #[[$MAP:.*]]  = affine_map<()[s0] -> (s0 floordiv 8)>
//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
// CHECK-LABEL: func @fold_vector_maskedstore_collapse_shape
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
//       CHECK:   %[[IDX:.*]] = affine.apply  #[[$MAP]]()[%[[ARG1]]]
//       CHECK:   %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
//       CHECK:   vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
1034