// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B

// TODO: Align naming and format with e.g. vector-transfer-permutation-lowering.mlir

///----------------------------------------------------------------------------------------
/// vector.transfer_read
/// [Pattern: FlattenContiguousRowMajorTransferReadPattern]
///
/// NOTE: Scalable vectors are not supported
///----------------------------------------------------------------------------------------

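// The memref and vector shapes match exactly and the layout is contiguous,
// so the read collapses to a single 1-D transfer.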
func.func @transfer_read_dims_match_contiguous(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
  return %res : vector<5x4x3x2xi8>
}

// CHECK-LABEL: func @transfer_read_dims_match_contiguous
// CHECK-SAME:    %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{.}}[0, 1, 2, 3]
// CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
// CHECK:         %[[VEC4D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
// CHECK:         return %[[VEC4D]]

// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous
//       CHECK-128B:   memref.collapse_shape

// -----

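// Scalable vectors are not supported (see the NOTE above), so no flattening
// is expected.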
func.func @transfer_read_dims_match_contiguous_scalable(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x[2]xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x[2]xi8>
  return %res : vector<5x4x3x[2]xi8>
}

// CHECK-LABEL: func @transfer_read_dims_match_contiguous_scalable
// CHECK-NOT: memref.collapse_shape

// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous_scalable
//   CHECK-128B-NOT:   memref.collapse_shape

// -----

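// Same as @transfer_read_dims_match_contiguous, but with the default
// (identity) memref layout.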
func.func @transfer_read_dims_match_contiguous_empty_stride(
    %mem : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
  return %res : vector<5x4x3x2xi8>
}

// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
// CHECK-SAME:    %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{.}}[0, 1, 2, 3]
// CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
// CHECK:         %[[VEC4D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
// CHECK:         return %[[VEC4D]]

// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
//       CHECK-128B:   memref.collapse_shape

// -----

// The shapes of the memref and the vector don't match, but the vector is a
// contiguous subset of the memref, so the read is "flattenable".

func.func @transfer_read_dims_mismatch_contiguous(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
  return %res : vector<1x1x2x2xi8>
}

// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
// CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
// CHECK:           %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
// CHECK:           return %[[VAL_5]] : vector<1x1x2x2xi8>

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
//       CHECK-128B:   memref.collapse_shape

// -----

func.func @transfer_read_dims_mismatch_non_zero_indices(
    %idx_1: index,
    %idx_2: index,
    %mem: memref<1x43x4x6xi32>) -> vector<1x2x6xi32> {

  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 {
    in_bounds = [true, true, true]
  } : memref<1x43x4x6xi32>, vector<1x2x6xi32>
  return %res : vector<1x2x6xi32>
}

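// The collapsed 1-D index is linearized over the trailing 4x6 dims:
// %idx_1 * (4 * 6) + %idx_2 * 6, i.e. the affine_map below.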
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>

// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
// CHECK-SAME:      %[[MEM:.*]]: memref<1x43x4x6xi32>
// CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
// CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
// CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
//   CHECK-128B-NOT:   memref.collapse_shape

// -----

// Overall, the source memref is non-contiguous. However, the slice from which
// the output vector is to be read _is_ contiguous. Hence the flattening works fine.

func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
    %mem : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
    %idx_1 : index,
    %idx_2 : index) -> vector<2x2xf32> {

  %c0 = arith.constant 0 : index
  %cst_1 = arith.constant 0.000000e+00 : f32
  %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %cst_1 {
    in_bounds = [true, true]
  } : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
  return %res : vector<2x2xf32>
}

// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>

// CHECK-LABEL:  func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
// CHECK:         %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]]
// CHECK-SAME:      : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
// CHECK:         %[[APPLY:.*]] = affine.apply #[[$MAP]]()

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
//       CHECK-128B:   memref.collapse_shape

// -----

// The leading dynamic dims don't affect whether this example is flattenable.
// Indeed, those dynamic dims are not candidates for flattening anyway.

func.func @transfer_read_leading_dynamic_dims(
    %mem : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
    %idx_1 : index,
    %idx_2 : index) -> vector<8x4xi8> {

  %c0_i8 = arith.constant 0 : i8
  %c0 = arith.constant 0 : index
  %res = vector.transfer_read %mem[%idx_1, %idx_2, %c0, %c0], %c0_i8 {
    in_bounds = [true, true]
  } : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
  return %res : vector<8x4xi8>
}

// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
// CHECK-SAME:    %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
// CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
// CHECK:         %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
// CHECK-SAME:    [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
// CHECK-SAME:    {in_bounds = [true]}
// CHECK-SAME:      : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
// CHECK:         %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
// CHECK:         return %[[RES]] : vector<8x4xi8>

// CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
//       CHECK-128B:   memref.collapse_shape

// -----

// One of the dims to be flattened is dynamic - not supported ATM.

func.func @negative_transfer_read_dynamic_dim_to_flatten(
    %idx_1: index,
    %idx_2: index,
    %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {

  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 {
    in_bounds = [true, true, true]
  } : memref<1x?x4x6xi32>, vector<1x2x6xi32>
  return %res : vector<1x2x6xi32>
}

// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
//   CHECK-128B-NOT:   memref.collapse_shape

// -----

// The vector to be read represents a _non-contiguous_ slice of the input
// memref.

func.func @transfer_read_dims_mismatch_non_contiguous_slice(
    %mem : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
  return %res : vector<2x1x2x2xi8>
}

// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_slice(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_slice(
//   CHECK-128B-NOT:   memref.collapse_shape

// -----

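// 0-D transfers are left untouched - there are no dims to collapse.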
func.func @transfer_read_0d(
    %mem : memref<i8>) -> vector<i8> {

  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[], %cst : memref<i8>, vector<i8>
  return %res : vector<i8>
}

// CHECK-LABEL: func.func @transfer_read_0d
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_read_0d(
//   CHECK-128B-NOT:   memref.collapse_shape
//   CHECK-128B-NOT:   vector.shape_cast

// -----

// Strides make the input memref non-contiguous, hence non-flattenable.

func.func @transfer_read_non_contiguous_src(
    %mem : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
  return %res : vector<5x4x3x2xi8>
}

// CHECK-LABEL: func.func @transfer_read_non_contiguous_src
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
//   CHECK-128B-NOT:   memref.collapse_shape
//   CHECK-128B-NOT:   vector.shape_cast

// -----

///----------------------------------------------------------------------------------------
/// vector.transfer_write
/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
///
/// NOTE: Scalable vectors are not supported
///----------------------------------------------------------------------------------------

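// As in the matching read case above, the shapes match and the layout is
// contiguous, so the write collapses to a single 1-D transfer.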
func.func @transfer_write_dims_match_contiguous(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
    %vec : vector<5x4x3x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
    vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
  return
}

// CHECK-LABEL: func @transfer_write_dims_match_contiguous(
// CHECK-SAME:    %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
// CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
// CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]

// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous(
//       CHECK-128B:   memref.collapse_shape

// -----

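// Scalable vectors are not supported, so no flattening is expected.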
func.func @transfer_write_dims_match_contiguous_scalable(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
    %vec : vector<5x4x3x[2]xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
    vector<5x4x3x[2]xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
  return
}

// CHECK-LABEL: func @transfer_write_dims_match_contiguous_scalable(
// CHECK-NOT:   memref.collapse_shape

// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous_scalable
//   CHECK-128B-NOT:   memref.collapse_shape

// -----

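// Same as @transfer_write_dims_match_contiguous, but with the default
// (identity) memref layout.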
func.func @transfer_write_dims_match_contiguous_empty_stride(
    %mem : memref<5x4x3x2xi8>,
    %vec : vector<5x4x3x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
    vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
  return
}

// CHECK-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
// CHECK-SAME:    %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
// CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
// CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]

// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
//       CHECK-128B:   memref.collapse_shape

// -----

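// The vector is a contiguous subset of the memref, so this write is
// flattenable (cf. @transfer_read_dims_mismatch_contiguous).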
func.func @transfer_write_dims_mismatch_contiguous(
    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
    %vec : vector<1x1x2x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
    vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
  return
}

// CHECK-LABEL:   func.func @transfer_write_dims_mismatch_contiguous
// CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
// CHECK-SAME:      %[[VEC:.*]]: vector<1x1x2x2xi8>) {
// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
// CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
// CHECK:           vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>

// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
//       CHECK-128B:   memref.collapse_shape

// -----

func.func @transfer_write_dims_mismatch_non_zero_indices(
    %idx_1: index,
    %idx_2: index,
    %mem: memref<1x43x4x6xi32>,
    %vec: vector<1x2x6xi32>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
    vector<1x2x6xi32>, memref<1x43x4x6xi32>
  return
}

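// As in the read case, the collapsed index is %idx_1 * 24 + %idx_2 * 6.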
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>

// CHECK-LABEL:   func.func @transfer_write_dims_mismatch_non_zero_indices(
// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
// CHECK-SAME:      %[[MEM:.*]]: memref<1x43x4x6xi32>,
// CHECK-SAME:      %[[VEC:.*]]: vector<1x2x6xi32>) {
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
// CHECK-DAG:       %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
// CHECK:           %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
// CHECK:           vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>

// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
//   CHECK-128B-NOT:   memref.collapse_shape

// -----

// Overall, the destination memref is non-contiguous. However, the slice to
// which the input vector is to be written _is_ contiguous. Hence the
// flattening works fine.

func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
    %vec : vector<2x2xf32>,
    %mem : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
    %idx_1 : index,
    %idx_2 : index) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
  return
}

// CHECK:  #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>

// CHECK-LABEL:  func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
// CHECK-DAG:      %[[APPLY:.*]] = affine.apply #[[$MAP]]()
// CHECK-DAG:      %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>

// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
//       CHECK-128B:   memref.collapse_shape

// -----

// The leading dynamic dims don't affect whether this example is flattenable.
// Indeed, those dynamic dims are not candidates for flattening anyway.

func.func @transfer_write_leading_dynamic_dims(
    %vec : vector<8x4xi8>,
    %mem : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
    %idx_1 : index,
    %idx_2 : index) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%idx_1, %idx_2, %c0, %c0] {in_bounds = [true, true]} :
    vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
  return
}

// CHECK-LABEL: func @transfer_write_leading_dynamic_dims
// CHECK-SAME:    %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
// CHECK:         %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
// CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
// CHECK-SAME:      [%[[ARG2]], %[[ARG3]], %[[C0]]]
// CHECK-SAME:      {in_bounds = [true]}
// CHECK-SAME:      : vector<32xi8>, memref<?x?x32xi8, {{.+}}>

// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
//       CHECK-128B:   memref.collapse_shape

// -----

// One of the dims to be flattened is dynamic - not supported ATM.

func.func @negative_transfer_write_dynamic_to_flatten(
    %idx_1: index,
    %idx_2: index,
    %vec : vector<1x2x6xi32>,
    %mem: memref<1x?x4x6xi32>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
    vector<1x2x6xi32>, memref<1x?x4x6xi32>
  return
}

// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
//   CHECK-128B-NOT:   memref.collapse_shape

// -----

// The vector to be written represents a _non-contiguous_ slice of the output
// memref.

func.func @transfer_write_dims_mismatch_non_contiguous_slice(
    %mem : memref<5x4x3x2xi8>,
    %vec : vector<2x1x2x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
    vector<2x1x2x2xi8>, memref<5x4x3x2xi8>
  return
}

// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous_slice(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous_slice(
//   CHECK-128B-NOT:   memref.collapse_shape

// -----

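// As with the 0-D read, there are no dims to collapse, so nothing changes.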
func.func @transfer_write_0d(
    %mem : memref<i8>,
    %vec : vector<i8>) {

  vector.transfer_write %vec, %mem[] : vector<i8>, memref<i8>
  return
}

// CHECK-LABEL: func.func @transfer_write_0d
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_write_0d(
//   CHECK-128B-NOT:   memref.collapse_shape
//   CHECK-128B-NOT:   vector.shape_cast

// -----

// The strides make the destination memref non-contiguous, hence non-flattenable.

func.func @transfer_write_non_contiguous_src(
    %mem : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>,
    %vec : vector<5x4x3x2xi8>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
    vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>
  return
}

// CHECK-LABEL: func.func @transfer_write_non_contiguous_src
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_write_non_contiguous_src
//   CHECK-128B-NOT:   memref.collapse_shape
//   CHECK-128B-NOT:   vector.shape_cast

// -----

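// The transfer is not fully in-bounds (in_bounds = [false, true, true, true]),
// so it is not flattened.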
func.func @negative_out_of_bound_transfer_read(
    %mem : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst {in_bounds = [false, true, true, true]} :
    memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
  return %res : vector<5x4x3x2xi8>
}

// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
// CHECK-NOT:   memref.collapse_shape

// -----

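// Likewise, the out-of-bounds write is left unflattened.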
func.func @negative_out_of_bound_transfer_write(
    %mem : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x3x2xi8>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] {in_bounds = [false, true, true, true]} :
    vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
  return
}

// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
// CHECK-NOT:   memref.collapse_shape
563