// RUN: mlir-opt -transform-interpreter -cse -split-input-file %s | FileCheck %s

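// Tile the matmul with sizes [10, 20] and fuse the linalg.fill that
// initializes its accumulator into the resulting loop nest.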
func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %gemm : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %matmul [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
//      CHECK: func.func @gemm_fill_fusion(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
//      CHECK:   %[[INIT:.+]] = tensor.empty
//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME:       iter_args(%[[ITERARG0:.+]] = %[[INIT]])
//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
//      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME:           outs(%[[INIT_TILE]] :
//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME:           outs(%[[FILL_TILE]] :
//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
//      CHECK:       scf.yield %[[INSERT]]

// -----

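// Tile the trailing elementwise (bias-add) generic and fuse both the matmul
// and its fill producer into the tiled loop nest.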
func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : tensor<?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  %generic = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
      %add = arith.addf %b0, %b1 : f32
      linalg.yield %add : f32
  } -> tensor<?x?xf32>
  return %generic : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
//      CHECK: func.func @gemm_generic_fusion(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>)
//      CHECK:   %[[INIT:.+]] = tensor.empty
//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME:       iter_args(%[[ITERARG0:.+]] = %[[INIT]])
//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
//      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME:           outs(%[[INIT_TILE]] :
//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME:           outs(%[[FILL_TILE]] :
//  CHECK-DAG:       %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]]
//  CHECK-DAG:       %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
//      CHECK:       %[[GENERIC_TILE:.+]] = linalg.generic
// CHECK-SAME:           ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
// CHECK-SAME:           outs(%[[OUTS_TILE]] :
//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
//      CHECK:       scf.yield %[[INSERT]]

// -----

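// Tile the second matmul along the parallel M dimension only and fuse the
// entire producer chain (fill0, gemm0, fill1) into the single loop.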
func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
  %init0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm0 = linalg.matmul
      ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
  %init1 = tensor.empty(%d0, %d2) : tensor<?x?xf32>
  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm1 = linalg.matmul
      ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %gemm1 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %mm1, %mm2 = transform.split_handle %matmuls
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %a, %b = transform.structured.fuse %mm2 [10]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
//      CHECK: func.func @gemm_gemm_fusion(
// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]]
//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]]
//  CHECK-DAG:   %[[INIT0:.+]] = tensor.empty(%[[D0]], %[[D1]])
//  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]]
//      CHECK:   %[[INIT1:.+]] = tensor.empty(%[[D0]], %[[D2]])
//      CHECK:   scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-SAME:       iter_args(%[[ITERARG:.+]] = %[[INIT1]])
//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]], 0]
//      CHECK:     %[[FILL0_TILE:.+]] = linalg.fill
// CHECK-SAME:         outs(%[[INIT0_TILE]] :
//      CHECK:     %[[GEMM0_TILE:.+]] = linalg.matmul
// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME:         outs(%[[FILL0_TILE]] :
//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0]
//      CHECK:     %[[FILL1_TILE:.+]] = linalg.fill
// CHECK-SAME:         outs(%[[INIT1_TILE]] :
//      CHECK:     %[[GEMM1_TILE:.+]] = linalg.matmul
// CHECK-SAME:         ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
// CHECK-SAME:         outs(%[[FILL1_TILE]] :
//      CHECK:     %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0]
//      CHECK:     scf.yield %[[INSERT]]

// -----

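// Tile a transpose generic consumer and fuse the matmul and fill producers;
// the consumer output tile is indexed with the induction variables swapped
// relative to the producer tiles.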
func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %init0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  %init1 = tensor.empty(%d1, %d0) : tensor<?x?xf32>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%gemm : tensor<?x?xf32>) outs(%init1 : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      linalg.yield %b0 : f32
  } -> tensor<?x?xf32>
  return %transpose : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
//      CHECK: func.func @gemm_transpose_fusion(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
//  CHECK-DAG:   %[[INIT0:.+]] = tensor.empty(%[[D0]], %[[D1]])
//  CHECK-DAG:   %[[INIT1:.+]] = tensor.empty(%[[D1]], %[[D0]])
//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME:       iter_args(%[[ITERARG0:.+]] = %[[INIT1]])
//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
//  CHECK-DAG:       %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV0]], %[[IV1]]]
//      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME:           outs(%[[INIT0_TILE]] :
//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME:           outs(%[[FILL_TILE]] :
//  CHECK-DAG:       %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
//      CHECK:       %[[GENERIC_TILE:.+]] = linalg.generic
// CHECK-SAME:           ins(%[[GEMM_TILE]] :
// CHECK-SAME:           outs(%[[OUTS_TILE]] :
//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
//      CHECK:       scf.yield %[[INSERT]]

// -----

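// Same fill + matmul producer chain, but tiled with the interchange[1, 0]
// option: the outer loop (IV0) iterates the second dimension and the inner
// loop (IV1) the first, so all tiles below use swapped subscripts.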
func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %cst = arith.constant 0.0 : f32
  %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%2 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
      ^bb0(%b0 : f32, %b1 : f32):
        %4 = arith.addf %b0, %b0 : f32
        linalg.yield %4 : f32
      } -> tensor<?x?xf32>
  return %3 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic [10, 20] interchange[1, 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
//      CHECK: func.func @interchange_matmul_fusion(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
//      CHECK:   %[[INIT:.+]] = tensor.empty
//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME:       iter_args(%[[ITERARG0:.+]] = %[[INIT]])
//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]]
//      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME:           outs(%[[INIT_TILE]] :
//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME:           outs(%[[FILL_TILE]] :
//      CHECK:       %[[INIT_TILE_2:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
//      CHECK:       %[[GENERIC_TILE:.+]] = linalg.generic
// CHECK-SAME:           ins(%[[GEMM_TILE]] :
// CHECK-SAME:           outs(%[[INIT_TILE_2]] :
//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
//      CHECK:       scf.yield %[[INSERT]]

// -----

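// The generic consumer reads the matmul result twice; fusion should
// materialize only a single tiled matmul inside the loop nest, used for both
// input operands.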
func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
                         %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
  %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
  %5 = tensor.empty(%3, %4) : tensor<?x?xf32>
  %6 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%5 : tensor<?x?xf32>) {
    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
      %7 = arith.addf %arg3, %arg4 : f32
      linalg.yield %7 : f32
    } -> tensor<?x?xf32>
  return %6 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
//       CHECK: func @matmul_plus_matmul
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//       CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
//  CHECK-SAME:     iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
//       CHECK:     %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
//  CHECK-SAME:       iter_args(%[[ARG6:.+]] = %[[ARG4]])
//   CHECK-DAG:       %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
//   CHECK-DAG:       %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
//   CHECK-DAG:       %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
//       CHECK:       %[[MATMUL:.+]] = linalg.matmul
//  CHECK-SAME:         ins(%[[ST_ARG0]], %[[ST_ARG1]] :
//  CHECK-SAME:         outs(%[[ST_ARG2]] :
//       CHECK:       %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
//       CHECK:       %[[ST_RESULT:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[MATMUL]], %[[MATMUL]] :
//  CHECK-SAME:         outs(%[[ST_ARG6]] :
//       CHECK:       %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
//  CHECK-SAME:         into %[[ARG6]][%[[IV0]], %[[IV1]]]
//       CHECK:       scf.yield %[[UPDATE]]
//       CHECK:     scf.yield %[[YIELD]]
//       CHECK:   return %[[RESULT]]

// -----

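// The consumer reads the matmul result both directly and transposed, so the
// producer is fused twice: once with the original tile indexing and once with
// the induction variables swapped.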
func.func @matmul_plus_transpose_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
                         %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
  %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
  %5 = tensor.empty(%3, %4) : tensor<?x?xf32>
  %6 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d1, d0)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%5 : tensor<?x?xf32>) {
    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
      %7 = arith.addf %arg3, %arg4 : f32
      linalg.yield %7 : f32
    } -> tensor<?x?xf32>
  return %6 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
//       CHECK: func @matmul_plus_transpose_matmul
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//       CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
//  CHECK-SAME:     iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
//       CHECK:     %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
//  CHECK-SAME:       iter_args(%[[ARG6:.+]] = %[[ARG4]])
//   CHECK-DAG:       %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
//   CHECK-DAG:       %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
//   CHECK-DAG:       %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
//       CHECK:       %[[LHS:.+]] = linalg.matmul
//  CHECK-SAME:         ins(%[[ST_ARG0]], %[[ST_ARG1]]
//  CHECK-SAME:           : tensor<?x?xf32>, tensor<?x?xf32>)
//  CHECK-SAME:         outs(%[[ST_ARG2]] : tensor<?x?xf32>)
//   CHECK-DAG:       %[[STR_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
//   CHECK-DAG:       %[[STR_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
//   CHECK-DAG:       %[[STR_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV0]]]
//       CHECK:       %[[RHS:.+]] = linalg.matmul
//  CHECK-SAME:         ins(%[[STR_ARG0]], %[[STR_ARG1]] :
//  CHECK-SAME:         outs(%[[STR_ARG2]] :
//       CHECK:       %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
//       CHECK:       %[[ST_RESULT:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[LHS]], %[[RHS]] :
//  CHECK-SAME:         outs(%[[ST_ARG6]] :
//       CHECK:       %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
//  CHECK-SAME:         into %[[ARG6]][%[[IV0]], %[[IV1]]]
//       CHECK:       scf.yield %[[UPDATE]]
//       CHECK:     scf.yield %[[YIELD]]
//       CHECK:   return %[[RESULT]]

// -----

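// Tile the last matmul of a three-matmul chain along M and fuse the first two
// matmuls into the loop; the original untiled matmuls remain only to provide
// the dynamic dimensions.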
func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
    %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>, %arg4: tensor<?x?xf32>,
    %arg5: tensor<?x?xf32>, %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
  %1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
  %2 = linalg.matmul ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
  return %2 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %mm1, %mm2, %mm3 = transform.split_handle %matmuls
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %a, %b = transform.structured.fuse %mm3 [10]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
//       CHECK: #[[MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
//       CHECK: func @matmul_sequence_fusion(
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[ORIG_GEMM1:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] :
//   CHECK-DAG:   %[[ORIG_GEMM2:.+]] = linalg.matmul ins(%[[ORIG_GEMM1]], %[[ARG3]] :
//   CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]]
//   CHECK-DAG:   %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]]
//   CHECK-DAG:   %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]]
//       CHECK:   %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] =
//  CHECK-SAME:       iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
//   CHECK-DAG:     %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
//   CHECK-DAG:     %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
//   CHECK-DAG:     %[[TILE_M:.+]] = affine.min #[[MAP]](%[[IV]])[%[[M]]]
//   CHECK-DAG:     %[[SLICE_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[TILE_M]], %[[N0]]]
//   CHECK-DAG:     %[[SLICE_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
//   CHECK-DAG:     %[[SLICE_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV]], 0] [%[[TILE_M]], %[[N1]]]
//   CHECK-DAG:     %[[TILE_GEMM1:.+]] = linalg.matmul ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] :
//  CHECK-SAME:         outs(%[[SLICE_ARG2]] :
//   CHECK-DAG:     %[[SLICE_ARG3:.+]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[N1]], %[[N2]]]
//   CHECK-DAG:     %[[SLICE_ARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]]
//   CHECK-DAG:     %[[TILE_GEMM2:.+]] = linalg.matmul ins(%[[TILE_GEMM1]], %[[SLICE_ARG3]] :
//  CHECK-SAME:         outs(%[[SLICE_ARG4]] :
//   CHECK-DAG:     %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
//   CHECK-DAG:     %[[SLICE_ARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
//   CHECK-DAG:     %[[TILE_GEMM3:.+]] = linalg.matmul
//  CHECK-SAME:         ins(%[[TILE_GEMM2]], %[[SLICE_ARG5]] :
//  CHECK-SAME:         outs(%[[SLICE_ARG6]] :
//       CHECK:     %[[UPDATE:.+]] = tensor.insert_slice %[[TILE_GEMM3]] into %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
//       CHECK:     scf.yield %[[UPDATE]]

// -----

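// A softmax-style sequence of two reductions followed by an elementwise
// division: tiling the final generic along the parallel dimension pulls the
// whole reduction chain into one loop.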
func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 0xFF800000 : f32
  %0 = tensor.empty() : tensor<30xf32>
  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
  %2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%arg0 : tensor<30x3xf32>) outs(%1 : tensor<30xf32>) {
    ^bb0(%arg1: f32, %arg2: f32):
      %8 = arith.maximumf %arg2, %arg1 : f32
      linalg.yield %8 : f32
    } -> tensor<30xf32>
  %3 = tensor.empty() : tensor<30x3xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
  %5:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%arg0, %2 : tensor<30x3xf32>, tensor<30xf32>) outs(%4, %3 : tensor<30xf32>, tensor<30x3xf32>) {
    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
      %8 = arith.subf %arg1, %arg2 : f32
      %9 = math.exp %8 : f32
      %10 = arith.addf %arg3, %9 : f32
      linalg.yield %10, %9 : f32, f32
    } -> (tensor<30xf32>, tensor<30x3xf32>)
  %6 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%5#1, %5#0 : tensor<30x3xf32>, tensor<30xf32>) outs(%3 : tensor<30x3xf32>) {
    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
      %8 = arith.divf %arg1, %arg2 : f32
      linalg.yield %8 : f32
    } -> tensor<30x3xf32>
  return %6 : tensor<30x3xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %generic1, %generic2, %generic3 = transform.split_handle %generics
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %a, %b = transform.structured.fuse %generic3 [10]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
//       CHECK: func @reduction_sequence(%[[ARG0:.+]]: tensor<30x3xf32>)
//   CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<30xf32>
//   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<30x3xf32>
//       CHECK:   %[[RESULT:[a-zA-Z0-9]+]] = scf.for %[[IV:[a-zA-Z0-9]+]]
//  CHECK-SAME:       iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]])
//   CHECK-DAG:     %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0]
//   CHECK-DAG:     %[[INIT0_SLICE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]]]
//       CHECK:     %[[FILL0:.+]] = linalg.fill
//  CHECK-SAME:         outs(%[[INIT0_SLICE]] :
//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[ARG0_SLICE]] :
//  CHECK-SAME:         outs(%[[FILL0]] :
//       CHECK:     %[[FILL1:.+]] = linalg.fill
//  CHECK-SAME:         outs(%[[INIT0_SLICE]] :
//       CHECK:     %[[INIT1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
//       CHECK:     %[[GENERIC1:.+]]:2 = linalg.generic
//  CHECK-SAME:         ins(%[[ARG0_SLICE]], %[[GENERIC0]] :
//  CHECK-SAME:         outs(%[[FILL1]], %[[INIT1_SLICE]] :
//       CHECK:     %[[ITERARG0_SLICE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
//       CHECK:     %[[GENERIC2:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[GENERIC1]]#1, %[[GENERIC1]]#0 :
//  CHECK-SAME:         outs(%[[ITERARG0_SLICE]] :
//   CHECK-DAG:     %[[INSERTSLICE:.+]] = tensor.insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0]
//       CHECK:     scf.yield %[[INSERTSLICE]]
//       CHECK:   return %[[RESULT]]

// -----

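// Tile a tensor.pad consumer and fuse its generic producer; the padded
// boundary is handled with an scf.if whose else-branch computes, pads, and
// casts the interior tile.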
func.func @pad_producer_fusion(%arg0 : tensor<10xf32>) -> tensor<16xf32> {
  %0 = tensor.empty() : tensor<10xf32>
  %1 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%arg0 : tensor<10xf32>) outs(%0 : tensor<10xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      %2 = arith.addf %b0, %b0 : f32
      linalg.yield %2 : f32
  } -> tensor<10xf32>
  %cst = arith.constant 0.0 : f32
  %2 = tensor.pad %1 low[4] high[2] {
    ^bb0(%arg1 : index):
      tensor.yield %cst : f32
  } : tensor<10xf32> to tensor<16xf32>
  return %2 : tensor<16xf32>
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %pad = transform.structured.match ops{["tensor.pad"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b = transform.structured.fuse %pad [8]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK-LABEL: func @pad_producer_fusion
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<10xf32>
//       CHECK:   %[[FOR_RESULT:.+]] = scf.for
//       CHECK:     %[[IF_RESULT:.+]] = scf.if
//       CHECK:     else
//       CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
//       CHECK:       %[[GENERIC:.+]] = linalg.generic
//  CHECK-SAME:           ins(%[[SLICE]] :
//       CHECK:       %[[PAD:.+]] = tensor.pad %[[GENERIC]]
//       CHECK:       %[[CAST:.+]] = tensor.cast %[[PAD]]
//       CHECK:       scf.yield %[[CAST]]
//       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]]
//       CHECK:     scf.yield %[[INSERT_SLICE]]
//       CHECK:   return %[[FOR_RESULT]]

// -----

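// Fuse a tensor.unpack producer into a tiled generic; the tile does not line
// up with the unpack inner tiles, so an extra extract_slice of the unpacked
// tile is needed before the tiled generic.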
func.func @imperfect_unpack_producer_fusion(%source: tensor<1x1x288x8x4xf32>, %dest: tensor<1x2x1152xf32>) -> tensor<1x2x1152xf32> {
  %0 = tensor.unpack %source
      outer_dims_perm = [0, 1, 2]
      inner_dims_pos = [1, 2]
      inner_tiles = [8, 4] into %dest
      : tensor<1x1x288x8x4xf32> -> tensor<1x2x1152xf32>
  %1 = tensor.empty() : tensor<1x2x1152xf32>
  %cst = arith.constant 1.0 : f32
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
                       iterator_types = ["parallel", "parallel", "parallel"]}
                       ins(%0 : tensor<1x2x1152xf32>)
                       outs(%1 : tensor<1x2x1152xf32>) {
  ^bb0(%in: f32, %out: f32):
    %7 = arith.addf %in, %cst : f32
    linalg.yield %7 : f32
  } -> tensor<1x2x1152xf32>
  return %2 : tensor<1x2x1152xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b = transform.structured.fuse %generic [0, 1, 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// CHECK-LABEL: func @imperfect_unpack_producer_fusion
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<1x1x288x8x4xf32>
//  CHECK-SAME:     %[[ARG1:.+]]: tensor<1x2x1152xf32>
//       CHECK:   %[[FOR_RESULT:.+]] = scf.for{{.*}}iter_args(%[[ITER_ARG:.+]] = {{.*}})
//       CHECK:     %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
//       CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[SLICE]]
//   CHECK-DAG:     %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]]
//       CHECK:     %[[GENERIC:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[UNPACK_SLICE]]
//  CHECK-SAME:         outs(%[[INIT_SLICE]]
//       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]]
//       CHECK:     scf.yield %[[INSERT_SLICE]]
//       CHECK:   return %[[FOR_RESULT]]