// RUN: mlir-opt --transform-interpreter --mlir-print-local-scope --split-input-file --verify-diagnostics --cse %s | FileCheck %s

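// Tile a 128x128 matmul with static tile sizes [4, 4, 4]. The CHECK lines
// below expect three nested scf.for loops that extract 4x4 slices, compute a
// small matmul on them, and insert the result back into the accumulator.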
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [4, 4, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// CHECK-LABEL: func @tile_linalg_matmul(
// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME:  -> tensor<128x128xf32> {
func.func @tile_linalg_matmul(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<128x128xf32>) {
//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<128x128xf32>) {
//      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<128x128xf32>) {
//      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
//      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
//      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
//      CHECK:       %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<4x4xf32>, tensor<4x4xf32>)
// CHECK-SAME:                                   outs(%[[sTC]] : tensor<4x4xf32>)  -> tensor<4x4xf32>
//      CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<4x4xf32> into tensor<128x128xf32>
//      CHECK:       scf.yield %[[TD]] : tensor<128x128xf32>
//      CHECK:     scf.yield %[[TD2]] : tensor<128x128xf32>
//      CHECK:   scf.yield %[[TD1]] : tensor<128x128xf32>
  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>

//      CHECK: return %[[TD0]] : tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

// -----

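// Tile the same matmul with dynamic tile sizes: the first two sizes are taken
// from the handle to the func.call that produces the size at runtime, so the
// extracted slices have dynamic dimensions.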
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["func.call"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %2, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [%1, %1, 4] : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

func.func private @get_dynamic_tile_size() -> index

// CHECK-LABEL: func @tile_linalg_matmul_dynamic(
// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME:  -> tensor<128x128xf32> {
func.func @tile_linalg_matmul_dynamic(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<128x128xf32>) {
//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<128x128xf32>) {
//      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<128x128xf32>) {
//      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor<?x4xf32>
//      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x?xf32>
//      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor<?x?xf32>
//      CHECK:       %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x4xf32>, tensor<4x?xf32>)
// CHECK-SAME:                                   outs(%[[sTC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
//      CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<?x?xf32> into tensor<128x128xf32>
//      CHECK:       scf.yield %[[TD]] : tensor<128x128xf32>
//      CHECK:     scf.yield %[[TD2]] : tensor<128x128xf32>
//      CHECK:   scf.yield %[[TD1]] : tensor<128x128xf32>
  %sz = func.call @get_dynamic_tile_size() : () -> index
  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>

//      CHECK: return %[[TD0]] : tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

// -----

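// Negative test: a tile-size parameter must carry as many values as there are
// target ops; here it does not match the two matmuls and an error is expected.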
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-note @below {{for this parameter}}
    %1 = transform.test_produce_param (0 : i64) : !transform.param<i64>
    // expected-error @below {{expected as many parameter values (0) as target ops (2)}}
    transform.structured.tile_using_for %0 tile_sizes [%1, %1, %1]
      : (!transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
      -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

func.func @tile_linalg_matmul(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
    -> (tensor<128x128xf32>, tensor<128x128xf32>) {
  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  %1 = linalg.matmul  ins(%0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32>
}

// -----

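// Negative test: a handle used as a dynamic tile size must be associated with
// as many size-producing ops as there are target ops; the arith.constant match
// is empty here, so an error is expected.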
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-note @below {{for this handle}}
    %1 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{expected as many dynamic size-producing operations (0) as target ops (2)}}
    transform.structured.tile_using_for %0 tile_sizes [%1, %1, 1]
      : (!transform.any_op, !transform.any_op, !transform.any_op)
      -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

func.func @tile_linalg_matmul(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
    -> (tensor<128x128xf32>, tensor<128x128xf32>) {
  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  %1 = linalg.matmul  ins(%0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32>
}

// -----

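// Tile tensor.pad with tile_using_forall: each scf.forall iteration either
// materializes the padding value with tensor.generate or re-pads the extracted
// slice with tensor.pad nofold.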
// CHECK-LABEL: tile_tensor_pad
func.func @tile_tensor_pad(
  %arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index)
    -> tensor<20x40xf32>
{
  // CHECK: scf.forall
  // CHECK:   scf.if
  // CHECK:     tensor.generate
  // CHECK:   else
  // CHECK:     tensor.pad {{.*}} nofold
  %0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] {
        ^bb0(%arg9: index, %arg10: index):
          tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<20x40xf32>
  return %0 : tensor<20x40xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.tile_using_forall %0 tile_sizes[1, 1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

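// Scalable tiling of a 1-D linalg.generic: the tile size [4] expands to
// 4 * vector.vscale, and the per-iteration slice size is clamped with
// affine.min.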
#map = affine_map<(d0) -> (d0)>

module {
  func.func @scalable_tile(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: f32) -> tensor<?xf32> {
    %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) outs(%arg2 : tensor<?xf32>) {
    ^bb0(%in_1: f32, %in_2: f32, %out: f32):
      %1 = arith.addf %in_1, %in_2 : f32
      %2 = arith.mulf %arg3, %1 : f32
      linalg.yield %2 : f32
    } -> tensor<?xf32>
    return %0 : tensor<?xf32>
  }
}

// CHECK-LABEL:   func.func @scalable_tile(
// CHECK-SAME:      %[[ARG_0:.*]]: tensor<?xf32>, %[[ARG_1:.*]]: tensor<?xf32>, %[[ARG_2:.*]]: tensor<?xf32>,
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[C0]] : tensor<?xf32>
// CHECK:           %[[VEC_SIZE:.*]] = arith.constant 4 : index
// CHECK:           %[[VS:.*]] = vector.vscale
// CHECK:           %[[STEP:.*]] = arith.muli %[[VEC_SIZE]], %[[VS]] : index
// CHECK:           scf.for %[[IV:.*]] = %[[C0]] to %[[DIM]] step %[[STEP]] iter_args(%[[VAL:.*]] = %[[ARG_2]]) -> (tensor<?xf32>) {
// CHECK:             %[[SIZE:.*]] = affine.min affine_map<(d0)[s0, s1] -> (-d0 + s0, s1)>(%[[IV]])[%[[DIM]], %[[STEP]]]
// CHECK:             %[[SLICE_ARG0:.*]] = tensor.extract_slice %[[ARG_0]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
// CHECK:             %[[SLICE_ARG1:.*]] = tensor.extract_slice %[[ARG_1]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
// CHECK:             %[[SLICE_ARG2:.*]] = tensor.extract_slice %[[VAL]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
// CHECK:             linalg.generic {indexing_maps = {{.*}}, iterator_types = ["parallel"]} ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] : tensor<?xf32>, tensor<?xf32>) outs(%[[SLICE_ARG2]] : tensor<?xf32>) {

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loop = transform.structured.tile_using_for %0 tile_sizes [[4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

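// Mix scalable and fixed-length tile sizes on a matmul: the two outer loops
// step by the constant 4 and the innermost loop steps by 4 * vector.vscale.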
// CHECK-LABEL:   func.func @scalable_and_fixed_length_tile
//   CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
//   CHECK-DAG:     %[[VS:.*]] = vector.vscale
//   CHECK-DAG:     %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
//   CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG:     %[[C128:.*]] = arith.constant 128 : index
//       CHECK:     scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[C4]]
//       CHECK:       scf.for %[[VAL_16:.*]] = %[[C0]] to %[[C128]] step %[[C4]]
//       CHECK:         scf.for %{{.*}} = %[[C0]] to %[[C128]] step %[[STEP_2]]

func.func @scalable_and_fixed_length_tile(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>

  return %0 : tensor<128x128xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

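// Negative test: providing more tile sizes than the target op has loops is
// rejected.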
func.func @too_many_tiles(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
                          %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
  // expected-note @below {{target op}}
  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{too many tiles provided, expected at most 3 found 4}}
    %1, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

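// Negative test: tile_using_for declares a single `loops` result but the tile
// sizes produce three loops, which is reported as an error.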
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{op expected number of loops to tile (3) to match number of `loops` results (1)}}
    %1, %loops = transform.structured.tile_using_for %0 tile_sizes [4, 4, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

func.func @tile_linalg_matmul(
  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32> {
  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%arg2: tensor<128x128xf32>)
    -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}