// Upstream: llvm-project/mlir/test/Dialect/Linalg/hoisting.mlir (revision 2bff9d9ffe3a4813961c1cf3af2e9ac5a20190bd)
// RUN: mlir-opt  -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s

// Exercises `transform.structured.hoist_redundant_vector_transfers` on a
// two-deep scf.for nest. The directives below expect:
//   * %memref1 (constant indices, vector<1xf32>): the read/write pair is
//     hoisted out of both loops and carried as an iter_arg of each loop.
//   * %memref0 (indexed by the outer induction variable %i, vector<2xf32>):
//     the pair is hoisted out of the inner loop only; the "unrelated_use"
//     that follows the inner loop stays in place.
//   * %memref2 is an operand of "some_use", and %memref3 / %memref4 are
//     operands of "some_crippling_use"; their vector<3xf32>, vector<4xf32> and
//     vector<5xf32> transfers are expected to stay inside the inner loop.
//   * The %memref5 (vector<6xf32>) pair is not covered by the directives.

// CHECK-LABEL: func @hoist_vector_transfer_pairs(
//  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
//  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
//  CHECK-SAME:   %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
//  CHECK-SAME:   %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
//  CHECK-SAME:   %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
//  CHECK-SAME:   %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
//  CHECK-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
//  CHECK-SAME:   %[[LB:[a-zA-Z0-9]*]]: index,
//  CHECK-SAME:   %[[UB:[a-zA-Z0-9]*]]: index,
//  CHECK-SAME:   %[[STEP:[a-zA-Z0-9]*]]: index,
//  CHECK-SAME:   %[[CMP:[a-zA-Z0-9]*]]: i1
func.func @hoist_vector_transfer_pairs(
    %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
    %memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
    %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32

// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
// CHECK:   vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
// CHECK:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
// CHECK:     "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
// CHECK:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
// CHECK:     "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK:     "some_use"(%[[MEMREF2]], %{{.*}}) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
// CHECK:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK:     "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
// CHECK:     vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
// CHECK:     vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32>
// CHECK:     vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32>
// CHECK:     "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> ()
// CHECK:     scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
// CHECK:   }
// CHECK:   vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
// CHECK:   "unrelated_use"(%[[MEMREF0]]) : (memref<?x?xf32>) -> ()
// CHECK:   scf.yield {{.*}} : vector<1xf32>
// CHECK: }
// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
// CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> ()
  scf.for %i = %lb to %ub step %step {
    scf.for %j = %lb to %ub step %step {
      %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
      %r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
      %r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
      %r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32>
      "some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
      %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
      %r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref<?x?xf32>, vector<6xf32>
      "some_crippling_use"(%memref5) : (memref<?x?xf32>) -> ()
      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
      %u2 = "some_use"(%memref2, %r2) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
      %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
      %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
      %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
      vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
      vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
      vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
      vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
      vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32>
      vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32>
      "some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
    }
    "unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()
  }
  "unrelated_use"(%memref1) : (memref<?x?xf32>) -> ()
  return
}

// Apply the hoisting transform to every func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// This test expects read/write pairs on the same memref to be hoisted out of
// both loops when they provably access disjoint elements, and to be kept in
// place otherwise:
//   * %memref2: the vector<3xf32> pairs at [0, 0] and [0, 3] cover columns
//     0-2 and 3-5 of row 0. They are disjoint, so they are hoisted.
//   * %memref3: the vector<4xf32> pairs are in rows 0 and 1 (same column
//     %random_index). They are disjoint, so they are hoisted.
//   * %memref1: the vector<2xf32> pairs at [0, 0] and [0, 1] both touch
//     column 1. They overlap, so they stay inside the inner loop.
//   * %memref0: the pairs at [%i, %i] and [%random_index, %random_index] may
//     alias; they are not covered by the directives below.

// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint(
//  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
//  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
//  CHECK-SAME:   %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
//  CHECK-SAME:   %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
//  CHECK-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
//  CHECK-SAME:   %[[LB:[a-zA-Z0-9]*]]: index,
//  CHECK-SAME:   %[[UB:[a-zA-Z0-9]*]]: index,
//  CHECK-SAME:   %[[STEP:[a-zA-Z0-9]*]]: index,
//  CHECK-SAME:   %[[RANDOM:[a-zA-Z0-9]*]]: index,
//  CHECK-SAME:   %[[CMP:[a-zA-Z0-9]*]]: i1
func.func @hoist_vector_transfer_pairs_disjoint(
    %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>,
    %memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index, %lb : index, %ub : index,
    %step: index, %random_index : index, %cmp: i1) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %cst = arith.constant 0.0 : f32

// CHECK: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32>
// CHECK: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32>
// CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
// CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
//  CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
// CHECK:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
//  CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
// CHECK:     vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
// CHECK:     vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK:     "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
// CHECK:     "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
// CHECK:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
// CHECK:     vector.transfer_write %{{.*}}, %[[MEMREF1]]{{.*}} : vector<2xf32>, memref<?x?xf32>
// CHECK:     vector.transfer_write %{{.*}}, %[[MEMREF1]]{{.*}} : vector<2xf32>, memref<?x?xf32>
// CHECK:     scf.yield {{.*}} : vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
// CHECK:   }
// CHECK:   scf.yield {{.*}} : vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
// CHECK: }
// CHECK: vector.transfer_write %{{.*}}, %[[MEMREF3]]{{.*}} : vector<4xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[MEMREF3]]{{.*}} : vector<4xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[MEMREF2]]{{.*}} : vector<3xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[MEMREF2]]{{.*}} : vector<3xf32>, memref<?x?xf32>
  scf.for %i = %lb to %ub step %step {
    scf.for %j = %lb to %ub step %step {
      %r00 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<2xf32>
      %r01 = vector.transfer_read %memref1[%c0, %c1], %cst: memref<?x?xf32>, vector<2xf32>
      %r20 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
      %r21 = vector.transfer_read %memref2[%c0, %c3], %cst: memref<?x?xf32>, vector<3xf32>
      %r30 = vector.transfer_read %memref3[%c0, %random_index], %cst: memref<?x?xf32>, vector<4xf32>
      %r31 = vector.transfer_read %memref3[%c1, %random_index], %cst: memref<?x?xf32>, vector<4xf32>
      %r10 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
      %r11 = vector.transfer_read %memref0[%random_index, %random_index], %cst: memref<?x?xf32>, vector<2xf32>
      %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
      %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
      %u20 = "some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
      %u21 = "some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
      %u30 = "some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
      %u31 = "some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
      %u10 = "some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
      %u11 = "some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
      vector.transfer_write %u00, %memref1[%c0, %c0] : vector<2xf32>, memref<?x?xf32>
      vector.transfer_write %u01, %memref1[%c0, %c1] : vector<2xf32>, memref<?x?xf32>
      vector.transfer_write %u20, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
      vector.transfer_write %u21, %memref2[%c0, %c3] : vector<3xf32>, memref<?x?xf32>
      vector.transfer_write %u30, %memref3[%c0, %random_index] : vector<4xf32>, memref<?x?xf32>
      vector.transfer_write %u31, %memref3[%c1, %random_index] : vector<4xf32>, memref<?x?xf32>
      vector.transfer_write %u10, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
      vector.transfer_write %u11, %memref0[%random_index, %random_index] : vector<2xf32>, memref<?x?xf32>
    }
  }
  return
}

// Apply the hoisting transform to every func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Hoisting also works on affine.for. The read/write of the accumulator
// %memref2[%arg3, %arg4] does not depend on the innermost induction variable
// %arg5. It is therefore expected to be hoisted around the innermost loop,
// with the running sum carried as a vector<16xi32> iter_arg. The operand reads
// of %memref0 and %memref1 depend on %arg5 and stay inside the loop.

// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops(
//  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>,
//  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>,
//  CHECK-SAME:   %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) {
//       CHECK:   %[[C0:.*]] = arith.constant 0 : i32
//       CHECK:   affine.for %[[I:.*]] = 0 to 64 {
//       CHECK:     affine.for %[[J:.*]] = 0 to 64 step 16 {
//       CHECK:       %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32>
//       CHECK:       %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) {
//       CHECK:         %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
//       CHECK:         %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
//       CHECK:         %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32>
//       CHECK:         %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32>
//       CHECK:         affine.yield %[[T1]] : vector<16xi32>
//       CHECK:       }
//       CHECK:       vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32>
//       CHECK:     }
//       CHECK:   }
func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) {
  %c0_i32 = arith.constant 0 : i32
  affine.for %arg3 = 0 to 64 {
    affine.for %arg4 = 0 to 64 step 16 {
      affine.for %arg5 = 0 to 64 {
        %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32>
        %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
        %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
        %3 = arith.muli %0, %1 : vector<16xi32>
        %4 = arith.addi %2, %3 : vector<16xi32>
        vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32>
      }
    }
  }
  return
}

// Apply the hoisting transform to every func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// A loop-invariant vector.transfer_read with no matching write is expected to
// be hoisted out of the loop when nothing in the loop may write its memref:
//   * The read of %memref0 is hoisted in front of the loop.
//   * The read of %memref2 stays inside the loop, because a subview of
//     %memref2 is passed to the opaque "some_use" op.
// After -canonicalize, the full-size subview is printed as a memref.cast.

// CHECK-LABEL:  func.func @hoist_vector_transfer_read(
// CHECK-DAG:      %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:      %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG:      %[[C1024:.+]] = arith.constant 1024 : index
// CHECK-DAG:      %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:          %[[ALLOC:.+]] = memref.alloc() : memref<32x64xf32>
// CHECK:          %[[ALLOC_0:.+]] = memref.alloc() : memref<32x128xf32>
// CHECK:          %[[CAST:.+]] = memref.cast %[[ALLOC_0]] : memref<32x128xf32> to memref<32x128xf32, strided<[128, 1],
// CHECK-SAME:       offset: ?>>
// CHECK:          %[[D0:.+]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} :
// CHECK-SAME:       memref<32x64xf32>, vector<32x64xf32>
// CHECK:          scf.for %[[ARG0:.+]] = %[[C0]] to %[[C1024]] step %[[C128]] {
// CHECK:            %[[D1:.+]] = vector.transfer_read %[[ALLOC_0]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true]}
// CHECK-SAME:         : memref<32x128xf32>, vector<32x128xf32>
// CHECK:            "some_use"(%[[D0]], %[[D1]], %[[CAST]]) : (vector<32x64xf32>, vector<32x128xf32>, memref<32x128xf32,
// CHECK-SAME:         strided<[128, 1], offset: ?>>) -> ()
// CHECK:          }
// CHECK:          memref.dealloc %[[ALLOC]] : memref<32x64xf32>
// CHECK:          return
func.func @hoist_vector_transfer_read() {
  %c0 = arith.constant 0 : index
  %c128 = arith.constant 128 : index
  %c1024 = arith.constant 1024 : index
  %cst_2 = arith.constant 0.000000e+00 : f32
  %memref0 = memref.alloc() : memref<32x64xf32>
  %memref2 = memref.alloc() : memref<32x128xf32>
  %subview2 = memref.subview %memref2[%c0, %c0] [32, 128] [1, 1]: memref<32x128xf32> to memref<32x128xf32, strided<[128, 1], offset: ?>>
  scf.for %arg0 = %c0 to %c1024 step %c128 {
    %2 = vector.transfer_read %memref2[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<32x128xf32>, vector<32x128xf32>
    %3 = vector.transfer_read %memref0[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<32x64xf32>, vector<32x64xf32>
    "some_use"(%3, %2, %subview2) : (vector<32x64xf32>, vector<32x128xf32>, memref<32x128xf32, strided<[128, 1], offset: ?>>) -> ()
  }
  memref.dealloc %memref0 : memref<32x64xf32>
  return
}

// Apply the hoisting transform to every func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// The transfers in this test case cannot be hoisted and replaced by a vector
// iter_arg because they do not match. The read drops the unit dimension of %m
// through its permutation_map and yields a vector<6x7x32xf32>. The write
// stores a vector<6x1x7x32xf32> covering all four dimensions.

// CHECK-LABEL:  func.func @non_matching_transfers(
//       CHECK:    scf.for {{.*}} {
//       CHECK:      vector.transfer_read
//       CHECK:      vector.transfer_write
//       CHECK:    }
func.func @non_matching_transfers(%m: memref<6x1x7x32xf32>) {
  %c0 = arith.constant 0 : index
  %c1024 = arith.constant 1024 : index
  %c128 = arith.constant 128 : index
  %cst = arith.constant dense<5.5> : vector<6x7x32xf32>
  %cst_0 = arith.constant 0.0 : f32
  scf.for %iv = %c0 to %c1024 step %c128 {
    %read = vector.transfer_read %m[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>} : memref<6x1x7x32xf32>, vector<6x7x32xf32>
    %added = arith.addf %read, %cst : vector<6x7x32xf32>
    %bc = vector.broadcast %added : vector<6x7x32xf32> to vector<1x6x7x32xf32>
    %tr = vector.transpose %bc, [1, 0, 2, 3] : vector<1x6x7x32xf32> to vector<6x1x7x32xf32>
    vector.transfer_write %tr, %m[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<6x1x7x32xf32>, memref<6x1x7x32xf32>
  }
  return
}

// Apply the hoisting transform to every func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// With `verify_non_zero_trip`, hoisting is only expected when the loop is
// known to execute at least once.

// CHECK-LABEL:  func.func @no_hoisting_unknown_bound_loop
func.func @no_hoisting_unknown_bound_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // Nothing is known about %lb and %ub, so the loop may run zero times and
  // the read must not be hoisted.
  // CHECK:       scf.for {{.*}} {
  // CHECK-NEXT:    vector.transfer_read
  // CHECK-NEXT:    "test.some_use"
  scf.for %arg2 = %lb to %ub step %c1 {
    %read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
    "test.some_use"(%read) : (vector<4xi32>) ->()
  }
  return
}

// Apply the hoisting transform, requiring a provably non-zero trip count.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-LABEL:  func.func @no_hoisting_possibly_zero_trip_loop
func.func @no_hoisting_possibly_zero_trip_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // %lb_0 = min(%lb, 8) is at most 8, and %ub_0 = max(%ub, 4) is at least 4.
  // %lb_0 can still exceed %ub_0 (e.g. %lb = 8 and %ub = 0 give 8 > 4), so the
  // loop may run zero times and the read must not be hoisted.
  %lb_0 = affine.min affine_map<(d0) -> (d0, 8)>(%lb)
  %ub_0 = affine.max affine_map<(d0) -> (d0, 4)>(%ub)

  // CHECK:       scf.for {{.*}} {
  // CHECK-NEXT:    vector.transfer_read
  // CHECK-NEXT:    "test.some_use"
  scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
    %read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
    "test.some_use"(%read) : (vector<4xi32>) ->()
  }
  return
}

// Apply the hoisting transform, requiring a provably non-zero trip count.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-LABEL:  func.func @no_hoisting_possibly_zero_trip_loop_eq_lb_and_ub
func.func @no_hoisting_possibly_zero_trip_loop_eq_lb_and_ub(%memref0: memref<20xi32>, %lb: index, %ub: index) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // %lb_0 = min(%lb, 8) is at most 8, and %ub_0 = max(%ub, 8) is at least 8.
  // Both equal 8 whenever %lb >= 8 and %ub <= 8, which gives a zero-trip loop,
  // so the read must not be hoisted.
  %lb_0 = affine.min affine_map<(d0) -> (d0, 8)>(%lb)
  %ub_0 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)

  // CHECK:       scf.for {{.*}} {
  // CHECK-NEXT:    vector.transfer_read
  // CHECK-NEXT:    "test.some_use"
  scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
    %read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
    "test.some_use"(%read) : (vector<4xi32>) ->()
  }
  return
}

// Apply the hoisting transform, requiring a provably non-zero trip count.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-LABEL:  func.func @hoisting_non_zero_trip_loop
func.func @hoisting_non_zero_trip_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // %lb_0 = min(%lb, 4) is at most 4, and %ub_0 = max(%ub, 8) is at least 8.
  // So %lb_0 < %ub_0 always holds, the step-1 loop runs at least once, and
  // hoisting is allowed.
  %lb_0 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
  %ub_0 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)

  // CHECK:       vector.transfer_read
  // CHECK:       scf.for {{.*}} {
  // CHECK-NEXT:    "test.some_use"
  scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
    %read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
    "test.some_use"(%read) : (vector<4xi32>) ->()
  }
  return
}

// Apply the hoisting transform, requiring a provably non-zero trip count.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Regression test - `vector.transfer_read` below should not be hoisted.
// Indeed, %collapse_shape (written to by `vector.transfer_write`) and %alloca
// (read by `vector.transfer_read`) alias: %collapse_shape is a
// memref.collapse_shape view of %alloca.

// CHECK-LABEL:  func.func @no_hoisting_collapse_shape
//       CHECK:    scf.for {{.*}} {
//       CHECK:      vector.transfer_write {{.*}} : vector<4xi32>, memref<4xi32>
//       CHECK-NEXT:      vector.transfer_read {{.*}} : memref<1x4x1xi32>, vector<1x4x1xi32>
//       CHECK-NEXT:      vector.transfer_write {{.*}} : vector<1x4x1xi32>, memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
//       CHECK-NEXT:    }

func.func @no_hoisting_collapse_shape(%in_0: memref<1x20x1xi32>, %1: memref<9x1xi32>, %vec: vector<4xi32>) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c20 = arith.constant 20 : index
  %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x4x1xi32>
  scf.for %arg0 = %c0 to %c20 step %c4 {
    %subview = memref.subview %in_0[0, %arg0, 0] [1, 4, 1] [1, 1, 1] : memref<1x20x1xi32> to memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
    %collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x4x1xi32> into memref<4xi32>
    vector.transfer_write %vec, %collapse_shape[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32>
    %read = vector.transfer_read %alloca[%c0, %c0, %c0], %c0_i32 {in_bounds = [true, true, true]} : memref<1x4x1xi32>, vector<1x4x1xi32>
    vector.transfer_write %read, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
  }
  return
}

// Apply the hoisting transform to every func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Regression test - `vector.transfer_read` below should not be hoisted.
// Indeed, %collapse_shape (read by `vector.transfer_read`) and %alloca
// (written to by `vector.transfer_write`) alias: %collapse_shape is a
// memref.collapse_shape view of %alloca.

// CHECK-LABEL:  func.func @no_hoisting_collapse_shape_2
//       CHECK:    scf.for {{.*}} {
//       CHECK:      vector.transfer_write
//       CHECK:      vector.transfer_read

func.func @no_hoisting_collapse_shape_2(%vec: vector<1x12x1xi32>) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c20 = arith.constant 20 : index
  %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x12x1xi32>
  scf.for %arg0 = %c0 to %c20 step %c4 {
    %collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x12x1xi32> into memref<12xi32>
    vector.transfer_write %vec, %alloca[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x12x1xi32>, memref<1x12x1xi32>
    %read = vector.transfer_read %collapse_shape[%c0], %c0_i32 {in_bounds = [true]} : memref<12xi32>, vector<12xi32>
    "test.some_use"(%read) : (vector<12xi32>) ->()
  }
  return
}

// Apply the hoisting transform to every func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Regression test - hoisting the following `vector.transfer_{read|write}` pair
// would not be safe:
//    %lhs = vector.transfer_read %collapsed_1[%c0]
//    vector.transfer_write %op, %collapsed_1[%c0]
// That's because the following `vector.transfer_read` reads from the same
// memory (i.e. `%collapsed_1` and `%collapsed_2` alias, both being
// memref.collapse_shape views of %alloca):
//    %acc = vector.transfer_read %collapsed_2[%c0]

// CHECK-LABEL:  func.func @no_hoisting_write_to_memref
//       CHECK:    scf.for {{.*}} {
//       CHECK:      vector.transfer_read {{.*}} :  memref<2xi32>, vector<1xi32>
//       CHECK-NEXT:      vector.transfer_read {{.*}} :  memref<2xi32>, vector<1xi32>
//       CHECK-NEXT:      vector.outerproduct {{.*}} : vector<1xi32>, i32
//       CHECK-NEXT:      vector.transfer_write {{.*}} : vector<1xi32>, memref<2xi32>
//       CHECK-NEXT:    }

func.func @no_hoisting_write_to_memref(%rhs: i32, %arg1: vector<1xi32>) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c20 = arith.constant 20 : index
  %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x1x2xi32>
  %cast = memref.cast %alloca : memref<1x1x2xi32> to memref<1x1x2xi32>
  %collapsed_1 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32>
  scf.for %_ = %c0 to %c20 step %c4 {
    %collapsed_2 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32>
    %lhs = vector.transfer_read %collapsed_1[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32>
    %acc = vector.transfer_read %collapsed_2[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32>
    %op = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<1xi32>, i32
    vector.transfer_write %op, %collapsed_1[%c0] {in_bounds = [true]} : vector<1xi32>, memref<2xi32>
  }
  return
}

// Apply the hoisting transform to every func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

559// -----
560
// Test that we can hoist out 1-D read-write pairs whose indices are dynamic values.
// The three accesses are provably disjoint: %r1 differs from %r0 in the leading
// (non-vectorized) index, and %r2 starts 4 elements past %r1 along the trailing
// dim, which is exactly the vector<4xf32> length.
// The hoisted reads' result names are not pinned, so the checks do not depend on
// the printer's SSA numbering.

// CHECK: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 1)>
// CHECK: #[[$MAP4:.+]] = affine_map<()[s0] -> (s0 + 4)>

//   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
//    CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[I0:.+]]: index)

//         CHECK:   %[[PLUS1:.+]] = affine.apply #[[$MAP1]]()[%[[I0]]]
//         CHECK:   %[[PLUS4:.+]] = affine.apply #[[$MAP4]]()[%[[I0]]]
//         CHECK:   %{{.+}} = vector.transfer_read %[[BUFFER]][%[[I0]], %[[I0]]]
//         CHECK:   %{{.+}} = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[I0]]]
//         CHECK:   %{{.+}} = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]]
// CHECK-COUNT-2:   scf.for %{{.+}} = {{.+}} -> (vector<4xf32>, vector<4xf32>, vector<4xf32>)
// CHECK-COUNT-3:     "some_use"
// CHECK-COUNT-2:   scf.yield {{.+}} : vector<4xf32>, vector<4xf32>, vector<4xf32>
//         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]]
//         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[I0]]]
//         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[I0]], %[[I0]]]

func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
  %cst = arith.constant 0.0 : f32
  %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
  %i2 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)

  scf.for %i = %lb to %ub step %step {
    scf.for %j = %lb to %ub step %step {
      %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref<?x?xf32>, vector<4xf32>
      // Disjoint leading dim
      %r1 = vector.transfer_read %buffer[%i1, %i0], %cst: memref<?x?xf32>, vector<4xf32>
      // Non-overlap trailing dim
      %r2 = vector.transfer_read %buffer[%i1, %i2], %cst: memref<?x?xf32>, vector<4xf32>
      %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32>
      %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32>
      %u2 = "some_use"(%r2) : (vector<4xf32>) -> vector<4xf32>
      vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref<?x?xf32>
      vector.transfer_write %u1, %buffer[%i1, %i0] : vector<4xf32>, memref<?x?xf32>
      vector.transfer_write %u2, %buffer[%i1, %i2] : vector<4xf32>, memref<?x?xf32>
    }
  }
  return
}
604
module attributes {transform.with_named_sequence} {
  // Entry point for -transform-interpreter: match every func.func under the
  // payload root and apply hoist_redundant_vector_transfers to the matches.
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %funcs : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
614
615// -----
616
// Test that read-write pairs whose dynamic index ranges overlap are NOT hoisted:
// every transfer must still appear inside both loops.

//   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_overlapping_dynamic
// CHECK-COUNT-2:   scf.for
// CHECK-COUNT-2:     vector.transfer_read
// CHECK-COUNT-2:     vector.transfer_write

func.func @hoist_vector_transfer_pairs_overlapping_dynamic(
    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
  %pad = arith.constant 0.0 : f32
  %col3 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)

  scf.for %outer = %lb to %ub step %step {
    scf.for %inner = %lb to %ub step %step {
      %v0 = vector.transfer_read %buffer[%i0, %i0], %pad: memref<?x?xf32>, vector<4xf32>
      // Starts 3 elements after the read above but spans 4, so both touch
      // trailing-dim element %i0 + 3.
      %v1 = vector.transfer_read %buffer[%i0, %col3], %pad: memref<?x?xf32>, vector<4xf32>
      %w0 = "some_use"(%v0) : (vector<4xf32>) -> vector<4xf32>
      %w1 = "some_use"(%v1) : (vector<4xf32>) -> vector<4xf32>
      vector.transfer_write %w0, %buffer[%i0, %i0] : vector<4xf32>, memref<?x?xf32>
      vector.transfer_write %w1, %buffer[%i0, %col3] : vector<4xf32>, memref<?x?xf32>
    }
  }
  return
}
642
module attributes {transform.with_named_sequence} {
  // Entry point for -transform-interpreter: match every func.func under the
  // payload root and apply hoist_redundant_vector_transfers to the matches.
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %funcs : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
652
653// -----
654
// Test that we can hoist out 2-D read-write pairs whose indices are dynamic values.
// With %i2 = (%i1 floordiv 32) * 16, %i3 = %i2 + 8 and %i4 = %i2 + 16, each
// vector<16x8xf32> access spans 8 trailing-dim elements from its column index,
// so the column ranges [%i2, %i2 + 8), [%i2 + 8, %i2 + 16) and
// [%i2 + 16, %i2 + 24) are pairwise disjoint.
// The function name is distinct from the 1-D disjoint-dynamic test so that this
// CHECK-LABEL can only match this function.

//   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_2d_disjoint_dynamic
// CHECK-COUNT-3:   vector.transfer_read
// CHECK-COUNT-2:   %{{.+}}:3 = scf.for {{.+}} -> (vector<16x8xf32>, vector<16x8xf32>, vector<16x8xf32>)
// CHECK-COUNT-2:   scf.yield {{.+}} : vector<16x8xf32>, vector<16x8xf32>, vector<16x8xf32>
// CHECK-COUNT-3:   vector.transfer_write
//         CHECK:   return

func.func @hoist_vector_transfer_pairs_2d_disjoint_dynamic(
    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index, %i1 : index) {
  %cst = arith.constant 0.0 : f32
  %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
  %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
  %i4 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 16)>(%i1)

  scf.for %i = %lb to %ub step %step {
    scf.for %j = %lb to %ub step %step {
      %r0 = vector.transfer_read %buffer[%i0, %i2], %cst: memref<?x?xf32>, vector<16x8xf32>
      %r1 = vector.transfer_read %buffer[%i0, %i3], %cst: memref<?x?xf32>, vector<16x8xf32>
      %r2 = vector.transfer_read %buffer[%i0, %i4], %cst: memref<?x?xf32>, vector<16x8xf32>
      %u0 = "some_use"(%r0) : (vector<16x8xf32>) -> vector<16x8xf32>
      %u1 = "some_use"(%r1) : (vector<16x8xf32>) -> vector<16x8xf32>
      %u2 = "some_use"(%r2) : (vector<16x8xf32>) -> vector<16x8xf32>
      // Writes are issued in the reverse order of the reads.
      vector.transfer_write %u2, %buffer[%i0, %i4] : vector<16x8xf32>, memref<?x?xf32>
      vector.transfer_write %u1, %buffer[%i0, %i3] : vector<16x8xf32>, memref<?x?xf32>
      vector.transfer_write %u0, %buffer[%i0, %i2] : vector<16x8xf32>, memref<?x?xf32>
    }
  }
  return
}
686
module attributes {transform.with_named_sequence} {
  // Entry point for -transform-interpreter: match every func.func under the
  // payload root and apply hoist_redundant_vector_transfers to the matches.
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %funcs : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
696
697// -----
698
// Test hoisting of a vector.extract/vector.broadcast pair: the loop extracts
// row 0 of its vector<3x4xf32> iter_arg, transforms it, and broadcasts the result
// back to vector<3x4xf32>. After hoisting, the loop carries the vector<4xf32>
// directly, with the extract before the loop and the broadcast after it.

// CHECK-LABEL:  func.func @hoist_vector_broadcasts
//  CHECK-SAME:    (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
//       CHECK:    %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
//  CHECK-NEXT:    %[[LOOP:.+]] = scf.for {{.*}} {
//  CHECK-NEXT:      %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
//  CHECK-NEXT:      scf.yield %[[USE]] : vector<4xf32>
//  CHECK-NEXT:    }
//  CHECK-NEXT:    %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
//  CHECK-NEXT:    return %[[BCAST]] : vector<3x4xf32>

func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
  %result = scf.for %iv = %lb to %ub step %step iter_args(%acc = %vec) -> vector<3x4xf32> {
    %row = vector.extract %acc[0] : vector<4xf32> from vector<3x4xf32>
    %updated = "some_use"(%row) : (vector<4xf32>) -> vector<4xf32>
    %splat = vector.broadcast %updated : vector<4xf32> to vector<3x4xf32>
    scf.yield %splat : vector<3x4xf32>
  }
  return %result : vector<3x4xf32>
}
720
module attributes {transform.with_named_sequence} {
  // Entry point for -transform-interpreter: match every func.func under the
  // payload root and apply hoist_redundant_vector_broadcasts to the matches.
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_broadcasts %funcs : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
730
731// -----
732
// Test hoisting of vector.extract/vector.broadcast pairs with dynamic position.
// The extract position %pos is a function argument, so it is defined outside
// the loop and the extract can still be placed before it.

// CHECK-LABEL:  func.func @hoist_vector_broadcasts_dynamic
//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
//       CHECK:        %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
//       CHECK-NEXT:   %[[LOOP:.+]] = scf.for {{.*}} {
//       CHECK-NEXT:     %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
//       CHECK-NEXT:     scf.yield %[[USE]] : vector<4xf32>
//       CHECK-NEXT:   }
//       CHECK-NEXT:   %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
//       CHECK-NEXT:   return %[[BCAST]] : vector<3x4xf32>

func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
    %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
    scf.yield %broadcast : vector<3x4xf32>
  }
  return %bcast_vec : vector<3x4xf32>
}
754
module attributes {transform.with_named_sequence} {
  // Entry point for -transform-interpreter: match every func.func under the
  // payload root and apply hoist_redundant_vector_broadcasts to the matches.
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_broadcasts %funcs : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
764
765// -----
766
// Test hoisting of vector.extract/vector.broadcast pairs when the loop carries
// several iter_args: each iter_arg has its own extract/broadcast pair, and the
// checks expect both extracts before the loop and both broadcasts after it.

// CHECK-LABEL:  func.func @hoist_vector_broadcasts_multiple
//  CHECK-SAME:    (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
//  CHECK-SAME:     %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
//   CHECK-DAG:    %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
//   CHECK-DAG:    %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
//  CHECK-NEXT:    %[[LOOP:.+]]:2 = scf.for {{.*}} {
//   CHECK-DAG:      %[[USE1:.+]] = "some_use1"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
//   CHECK-DAG:      %[[USE2:.+]] = "some_use2"({{.*}}) : (vector<5xf32>) -> vector<5xf32>
//  CHECK-NEXT:      scf.yield %[[USE1]], %[[USE2]] : vector<4xf32>, vector<5xf32>
//  CHECK-NEXT:    }
//   CHECK-DAG:    %[[BCAST1:.+]] = vector.broadcast %[[LOOP]]#0 : vector<4xf32> to vector<3x4xf32>
//   CHECK-DAG:    %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
//  CHECK-NEXT:    return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>

func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) ->  (vector<3x4xf32>, vector<3x5xf32>) {
  %res:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc0 = %vec1, %acc1 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
    %row0 = vector.extract %acc0[0] : vector<4xf32> from vector<3x4xf32>
    %row1 = vector.extract %acc1[1] : vector<5xf32> from vector<3x5xf32>
    %new0 = "some_use1"(%row0) : (vector<4xf32>) -> vector<4xf32>
    %new1 = "some_use2"(%row1) : (vector<5xf32>) -> vector<5xf32>
    %splat0 = vector.broadcast %new0 : vector<4xf32> to vector<3x4xf32>
    %splat1 = vector.broadcast %new1 : vector<5xf32> to vector<3x5xf32>
    scf.yield %splat0, %splat1 : vector<3x4xf32>, vector<3x5xf32>
  }
  return %res#0, %res#1 : vector<3x4xf32>, vector<3x5xf32>
}
795
module attributes {transform.with_named_sequence} {
  // Entry point for -transform-interpreter: match every func.func under the
  // payload root and apply hoist_redundant_vector_broadcasts to the matches.
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_broadcasts %funcs : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
805