// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
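// Test splitting of vector.transfer_read / vector.transfer_write into a fast,
// in-bounds path and a slow, out-of-bounds path that goes through a temporary
// memref.alloca buffer, using the "linalg-copy" strategy (the partial tile is
// moved with memref.copy).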

// CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK-DAG: #[[$bounds_map_4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
// CHECK-DAG: #[[$bounds_map_8:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>

// CHECK-LABEL: split_vector_transfer_read_2d(
//  CHECK-SAME: %[[A:[a-zA-Z0-9_]*]]: memref
//  CHECK-SAME: %[[i:[a-zA-Z0-9_]*]]: index
//  CHECK-SAME: %[[j:[a-zA-Z0-9_]*]]: index
func.func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -> vector<4x8xf32> {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32

  //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  //  CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
  //  CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
  // alloca for boundary full tile
  //      CHECK: %[[alloc:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
  // %i + 4 <= dim(%A, 0)
  //      CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
  //      CHECK: %[[d0:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x8xf32>
  //      CHECK: %[[cmp0:.*]] = arith.cmpi sle, %[[idx0]], %[[d0]] : index
  // %j + 8 <= dim(%A, 1)
  //      CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
  //      CHECK: %[[cmp1:.*]] = arith.cmpi sle, %[[idx1]], %[[c8]] : index
  // are both conditions true?
  //      CHECK: %[[cond:.*]] = arith.andi %[[cmp0]], %[[cmp1]] : i1
  //      CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32>, index, index) {
  //               inBounds, just yield %A
  //      CHECK:   scf.yield %[[A]], %[[i]], %[[j]] : memref<?x8xf32>, index, index
  //      CHECK: } else {
  //               slow path, fill tmp alloc and yield a memref_casted version of it
  //      CHECK:   linalg.fill ins(%cst : f32) outs(%[[alloc]] : memref<4x8xf32>)
  //      CHECK:   %[[d0:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x8xf32>
  //      CHECK:   %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[d0]], %[[i]], %[[c4]])
  //      CHECK:   %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
  //      CHECK:   %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]
 // CHECK-SAME:     memref<?x8xf32> to memref<?x?xf32, strided<[8, 1], offset: ?>>
  //      CHECK:   %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1]
  //      CHECK:   memref.copy %[[sv]], %[[alloc_view]] : memref<?x?xf32, strided<[8, 1], offset: ?>> to memref<?x?xf32, strided{{.*}}>
  //      CHECK:   %[[yielded:.*]] = memref.cast %[[alloc]] :
 // CHECK-SAME:     memref<4x8xf32> to memref<?x8xf32>
  //      CHECK:   scf.yield %[[yielded]], %[[c0]], %[[c0]] :
 // CHECK-SAME:     memref<?x8xf32>, index, index
  //      CHECK: }
  //      CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %cst
 // CHECK-SAME:   {in_bounds = [true, true]} : memref<?x8xf32>, vector<4x8xf32>
  %1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32>, vector<4x8xf32>

  // CHECK: return %[[res]] : vector<4x8xf32>
  return %1: vector<4x8xf32>
}

// CHECK-LABEL: split_vector_transfer_read_strided_2d(
//  CHECK-SAME: %[[A:[a-zA-Z0-9_]*]]: memref
//  CHECK-SAME: %[[i:[a-zA-Z0-9_]*]]: index
//  CHECK-SAME: %[[j:[a-zA-Z0-9_]*]]: index
func.func @split_vector_transfer_read_strided_2d(
    %A: memref<7x8xf32, strided<[?, 1], offset: ?>>,
    %i: index, %j: index) -> vector<4x8xf32> {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32


  //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  //  CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
  //  CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
  //  CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
  // alloca for boundary full tile
  //      CHECK: %[[alloc:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
  // %i + 4 <= dim(%A, 0)
  //      CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
  //      CHECK: %[[cmp0:.*]] = arith.cmpi sle, %[[idx0]], %[[c7]] : index
  // %j + 8 <= dim(%A, 1)
  //      CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
  //      CHECK: %[[cmp1:.*]] = arith.cmpi sle, %[[idx1]], %[[c8]] : index
  // are both conditions true?
  //      CHECK: %[[cond:.*]] = arith.andi %[[cmp0]], %[[cmp1]] : i1
  //      CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index) {
  //               inBounds but not cast-compatible: yield a memref_casted form of %A
  //      CHECK:   %[[casted:.*]] = memref.cast %[[A]] :
 // CHECK-SAME:     memref<7x8xf32, strided<[?, 1], offset: ?>> to memref<?x8xf32, strided<[?, 1], offset: ?>>
  //      CHECK:   scf.yield %[[casted]], %[[i]], %[[j]] :
 // CHECK-SAME:     memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
  //      CHECK: } else {
  //               slow path, fill tmp alloc and yield a memref_casted version of it
  //      CHECK:   linalg.fill ins(%cst : f32) outs(%[[alloc]] : memref<4x8xf32>)
  //      CHECK:   %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[c7]], %[[i]], %[[c4]])
  //      CHECK:   %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
  //      CHECK:   %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]
 // CHECK-SAME:     memref<7x8xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
  //      CHECK:   %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1]
  //      CHECK:   memref.copy %[[sv]], %[[alloc_view]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided{{.*}}>
  //      CHECK:   %[[yielded:.*]] = memref.cast %[[alloc]] :
 // CHECK-SAME:     memref<4x8xf32> to memref<?x8xf32, strided<[?, 1], offset: ?>>
  //      CHECK:   scf.yield %[[yielded]], %[[c0]], %[[c0]] :
 // CHECK-SAME:     memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
  //      CHECK: }
  //      CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {in_bounds = [true, true]} :
 // CHECK-SAME:   memref<?x8xf32, strided<[?, 1], offset: ?>>, vector<4x8xf32>
  %1 = vector.transfer_read %A[%i, %j], %f0 :
    memref<7x8xf32, strided<[?, 1], offset: ?>>, vector<4x8xf32>

  return %1 : vector<4x8xf32>
}

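// Apply the full/partial split patterns with the "linalg-copy" strategy to
// every func.func in the payload.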
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----

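// Same splitting applied to vector.transfer_write: the write targets either the
// original memref (in-bounds) or the temporary buffer, whose relevant part is
// copied back to the destination on the out-of-bounds path.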
func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf32>, %i: index, %j: index) {
  vector.transfer_write %V, %A[%i, %j] :
    vector<4x8xf32>, memref<?x8xf32>
  return
}

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>

// CHECK-LABEL:     func @split_vector_transfer_write_2d(
// CHECK-SAME:                                         %[[VEC:.*]]: vector<4x8xf32>,
// CHECK-SAME:                                         %[[DEST:.*]]: memref<?x8xf32>,
// CHECK-SAME:                                         %[[I:.*]]: index,
// CHECK-SAME:                                         %[[J:.*]]: index) {
// CHECK-DAG:       %[[CT:.*]] = arith.constant true
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG:       %[[C8:.*]] = arith.constant 8 : index
// CHECK:           %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
// CHECK:           %[[IDX0:.*]] = affine.apply #[[$MAP0]]()[%[[I]]]
// CHECK:           %[[DIM0:.*]] = memref.dim %[[DEST]], %[[C0]] : memref<?x8xf32>
// CHECK:           %[[DIM0_IN:.*]] = arith.cmpi sle, %[[IDX0]], %[[DIM0]] : index
// CHECK:           %[[DIM1:.*]] = affine.apply #[[$MAP1]]()[%[[J]]]
// CHECK:           %[[DIM1_IN:.*]] = arith.cmpi sle, %[[DIM1]], %[[C8]] : index
// CHECK:           %[[IN_BOUNDS:.*]] = arith.andi %[[DIM0_IN]], %[[DIM1_IN]] : i1
// CHECK:           %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]]
// CHECK-SAME:          -> (memref<?x8xf32>, index, index) {
// CHECK:             scf.yield %[[DEST]], %[[I]], %[[J]] : memref<?x8xf32>, index, index
// CHECK:           } else {
// CHECK:             %[[VAL_16:.*]] = memref.cast %[[TEMP]] : memref<4x8xf32> to memref<?x8xf32>
// CHECK:             scf.yield %[[VAL_16]], %[[C0]], %[[C0]] : memref<?x8xf32>, index, index
// CHECK:           }
// CHECK:           vector.transfer_write %[[VEC]],
// CHECK-SAME:          %[[IN_BOUND_DEST:.*]]#0[%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
// CHECK-SAME:          {in_bounds = [true, true]} : vector<4x8xf32>, memref<?x8xf32>
// CHECK:           %[[OUT_BOUNDS:.*]] = arith.xori %[[IN_BOUNDS]], %[[CT]] : i1
// CHECK:           scf.if %[[OUT_BOUNDS]] {
// CHECK:             %[[VAL_19:.*]] = memref.dim %[[DEST]], %[[C0]] : memref<?x8xf32>
// CHECK-DAG:         %[[VAL_20:.*]] = affine.min #[[$MAP2]](%[[VAL_19]], %[[I]], %[[C4]])
// CHECK-DAG:         %[[VAL_21:.*]] = affine.min #[[$MAP3]](%[[C8]], %[[J]], %[[C8]])
// CHECK:             %[[VAL_22:.*]] = memref.subview %[[TEMP]]
// CHECK-SAME:            [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
// CHECK-SAME:            [1, 1] : memref<4x8xf32> to memref<?x?xf32, strided<[8, 1], offset: ?>>
// CHECK:             %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1]
// CHECK:             memref.copy %[[VAL_22]], %[[DEST_VIEW]]
// CHECK-SAME:            : memref<?x?xf32, strided<[8, 1], offset: ?>> to memref<?x?xf32, strided{{.*}}>
// CHECK:           }
// CHECK:           return
// CHECK:         }

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
    } : !transform.op<"func.func">
    transform.yield
  }
}

// -----

func.func @split_vector_transfer_write_strided_2d(
    %V: vector<4x8xf32>, %A: memref<7x8xf32, strided<[?, 1], offset: ?>>,
    %i: index, %j: index) {
  vector.transfer_write %V, %A[%i, %j] :
    vector<4x8xf32>, memref<7x8xf32, strided<[?, 1], offset: ?>>
  return
}

// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>
// CHECK-LABEL:   func @split_vector_transfer_write_strided_2d(
// CHECK-SAME:                                                 %[[VEC:.*]]: vector<4x8xf32>,
// CHECK-SAME:                                                 %[[DEST:.*]]: memref<7x8xf32, strided<[?, 1], offset: ?>>,
// CHECK-SAME:                                                 %[[I:.*]]: index,
// CHECK-SAME:                                                 %[[J:.*]]: index) {
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[CT:.*]] = arith.constant true
// CHECK-DAG:       %[[C7:.*]] = arith.constant 7 : index
// CHECK-DAG:       %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG:       %[[C8:.*]] = arith.constant 8 : index
// CHECK:           %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
// CHECK:           %[[DIM0:.*]] = affine.apply #[[$MAP1]]()[%[[I]]]
// CHECK:           %[[DIM0_IN:.*]] = arith.cmpi sle, %[[DIM0]], %[[C7]] : index
// CHECK:           %[[DIM1:.*]] = affine.apply #[[$MAP2]]()[%[[J]]]
// CHECK:           %[[DIM1_IN:.*]] = arith.cmpi sle, %[[DIM1]], %[[C8]] : index
// CHECK:           %[[IN_BOUNDS:.*]] = arith.andi %[[DIM0_IN]], %[[DIM1_IN]] : i1
// CHECK:           %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]]
// CHECK-SAME:          -> (memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index) {
// CHECK:             %[[VAL_16:.*]] = memref.cast %[[DEST]]
// CHECK-SAME:            : memref<7x8xf32, strided<[?, 1], offset: ?>> to memref<?x8xf32, strided<[?, 1], offset: ?>>
// CHECK:             scf.yield %[[VAL_16]], %[[I]], %[[J]]
// CHECK-SAME:            : memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
// CHECK:           } else {
// CHECK:             %[[VAL_17:.*]] = memref.cast %[[TEMP]]
// CHECK-SAME:            : memref<4x8xf32> to memref<?x8xf32, strided<[?, 1], offset: ?>>
// CHECK:             scf.yield %[[VAL_17]], %[[C0]], %[[C0]]
// CHECK-SAME:            : memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
// CHECK:           }
// CHECK:           vector.transfer_write %[[VEC]],
// CHECK-SAME:          %[[IN_BOUND_DEST:.*]]#0
// CHECK-SAME:          [%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
// CHECK-SAME:          {in_bounds = [true, true]}
// CHECK-SAME:          : vector<4x8xf32>, memref<?x8xf32, strided<[?, 1], offset: ?>>
// CHECK:           %[[OUT_BOUNDS:.*]] = arith.xori %[[IN_BOUNDS]], %[[CT]] : i1
// CHECK:           scf.if %[[OUT_BOUNDS]] {
// CHECK-DAG:         %[[VAL_20:.*]] = affine.min #[[$MAP3]](%[[C7]], %[[I]], %[[C4]])
// CHECK-DAG:         %[[VAL_21:.*]] = affine.min #[[$MAP4]](%[[C8]], %[[J]], %[[C8]])
// CHECK:             %[[VAL_22:.*]] = memref.subview %[[TEMP]]
// CHECK-SAME:            [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
// CHECK-SAME:            [1, 1] : memref<4x8xf32> to memref<?x?xf32, strided<[8, 1], offset: ?>>
// CHECK:             %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1]
// CHECK:             memref.copy %[[VAL_22]], %[[DEST_VIEW]]
// CHECK-SAME:            : memref<?x?xf32, strided<[8, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
// CHECK:           }
// CHECK:           return
// CHECK:         }

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
    } : !transform.op<"func.func">
    transform.yield
  }
}