xref: /llvm-project/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir (revision 618f231a6d3ef41d231e2a4d1e2eca4c0d709802)
1// RUN: mlir-opt --transform-interpreter --scf-for-loop-canonicalization --canonicalize --split-input-file %s | FileCheck %s
2// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s --check-prefix=NOCANON
3
4// This implements a 2D multisize tiling with target sizes [3, 10].
// Transform schedule: split the matched linalg.generic along dimension 0 with
// multi-size tiling (target size 3), tile both halves, then repeat the same
// split-and-tile along dimension 1 (target size 10) inside each tiled op.
5module attributes {transform.with_named_sequence} {
6  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
7    %generic = transform.structured.match ops{["linalg.generic"]} in %root : (!transform.any_op) -> !transform.any_op
    // Results are used below as: #0 = low tile size, #1 = high tile size,
    // #2 = split point. Here they are handles, not parameters.
8    %row_sizes:3 = transform.structured.multitile_sizes %generic { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
9    %row_split = transform.structured.split %generic after %row_sizes#2 { dimension = 0 } : !transform.any_op, !transform.any_op
10    %row_parts:2 = transform.split_handle %row_split : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Tile the part before the split point with the low size and the part
    // after it with the high size.
11    %tiled_low:2 = transform.structured.tile_using_for %row_parts#0 tile_sizes [%row_sizes#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
12    %tiled_high:2 = transform.structured.tile_using_for %row_parts#1 tile_sizes [%row_sizes#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Gather the first result of both tilings so the second dimension can be
    // processed uniformly.
13    %tiled_ops = transform.merge_handles %tiled_low#0, %tiled_high#0 : !transform.any_op
14    transform.foreach %tiled_ops : !transform.any_op {
15    ^bb0(%tile: !transform.any_op):
      // Sizes for dimension 1 are recomputed per op inside the loop.
16      %col_low, %col_high, %col_split = transform.structured.multitile_sizes %tile { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
17      %col_halves = transform.structured.split %tile after %col_split { dimension = 1 } : !transform.any_op, !transform.any_op
18      %left, %right = transform.split_handle %col_halves : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      // A tile size of 0 leaves dimension 0 untiled.
19      transform.structured.tile_using_for %left tile_sizes [0, %col_low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
20      transform.structured.tile_using_for %right tile_sizes [0, %col_high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
21    }
22    transform.yield
23  }
24}
25
// External declaration (no body) of the scalar callee. It is invoked from the
// linalg.generic region in @two_d with the input element and both iteration
// indices.
26func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
27
28// Without canonicalization, tile sizes are computed dynamically as affine maps.
29// NOCANON-LABEL: @two_d
30// NOCANON-COUNT-8: affine.apply
31// NOCANON:         scf.for
32
33// CHECK-LABEL: @two_d
34// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
// Payload under test: an elementwise linalg.generic over a 10x34 tensor whose
// region calls @elem with the input element and both iteration indices.
35func.func @two_d(%arg0: tensor<10x34xf32>,
36                 %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
37  %res = linalg.generic {
38    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
39                     affine_map<(d0, d1) -> (d0, d1)>],
40    iterator_types = ["parallel", "parallel"]
41  }
42  ins(%arg0: tensor<10x34xf32>)
43  outs(%arg1: tensor<10x34xf32>) {
44  ^bb0(%in: f32, %out: f32):
45    %row = linalg.index 0 : index
46    %col = linalg.index 1 : index
47    %elem_val = func.call @elem(%in, %row, %col) : (f32, index, index) -> f32
48    linalg.yield %elem_val : f32
49  } -> tensor<10x34xf32>
50
51  // 2D multi-size tiling should produce four quadrants with sizes
52  //   (2, 8), (2, 9), (3, 8), (3, 9)
53  // respectively, and in this order.
54  // Check the full code for the first quadrant, the data flow for the second
55  // quadrant and only the overall code structure for the remaining quadrants.
56  // The canonicalizer is able to recover static shapes of the four
57  // linalg.generic instances; use those to differentiate the quadrants.
58
59  // CHECK:      %[[SLICE_1_IN:.+]] = tensor.extract_slice %[[IN]][0, 0] [4, 34] [1, 1]
60  // CHECK:      %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1]
61  // CHECK:      scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]])
62  // CHECK:        %[[OUTSLICE_1_IN:.+]] = tensor.extract_slice %[[SLICE_1_IN]][%[[I1]], 0] [2, 34] [1, 1]
63  // CHECK:        %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 34] [1, 1]
64
65  // CHECK:        %[[SLICE_2_IN:.+]] = tensor.extract_slice %[[OUTSLICE_1_IN]][0, 0] [2, 16] [1, 1]
66  // CHECK:        %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
67  // CHECK:        %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]])
68  // CHECK:          %[[INSLICE_2:.+]] = tensor.extract_slice %[[SLICE_2_IN]][0, %[[I2]]] [2, 8] [1, 1]
69  // CHECK:          %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [2, 8] [1, 1]
70  // CHECK:          %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<2x8xf32>) outs(%[[OUTSLICE_2]] : tensor<2x8xf32>)
71  // CHECK:          %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
72  // CHECK:          scf.yield %[[RESPARTIAL]]
73
74  // CHECK:        %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
75  // CHECK:        %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
76  // CHECK:        scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
77  // CHECK-COUNT-2:  tensor.extract_slice
78  // CHECK:          linalg.generic {{.*}} ins(%{{.*}} : tensor<2x9xf32>)
79  // CHECK:          tensor.insert_slice
80  // CHECK:          scf.yield
81  // CHECK:        %[[INSERTED_2:.+]] = tensor.insert_slice %{{.*}} into %[[INSERTED]]
82  // CHECK:        %[[INSERTED_3:.+]] = tensor.insert_slice %[[INSERTED_2]] into %[[ITERARG_1]]
83  // CHECK:        scf.yield %[[INSERTED_3]]
84
85  // CHECK:        tensor.insert_slice
86  // CHECK:        tensor.extract_slice
87  // CHECK:        scf.for
88  // CHECK-COUNT-2:  tensor.extract_slice
89  // CHECK:          scf.for
90  // CHECK-COUNT-2:    tensor.extract_slice
91  // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x8xf32>)
92  // CHECK:            tensor.insert_slice
93  // CHECK:            scf.yield
94  // CHECK:          tensor.insert_slice
95  // CHECK:          tensor.extract_slice
96  // CHECK:          scf.for
97  // CHECK-COUNT-2:    tensor.extract_slice
98  // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x9xf32>)
99  // CHECK:            tensor.insert_slice
100  // CHECK:            scf.yield
101  // CHECK-COUNT-2:  tensor.insert_slice
102  // CHECK:          scf.yield
103  // CHECK:        %[[RESULT:.+]] = tensor.insert_slice
104  // CHECK:        return %[[RESULT]]
105
106  return %res : tensor<10x34xf32>
107}
108
109// -----
110
// Transform schedule equivalent to the handle-based one earlier in this file,
// except that the multi-size tiling sizes are produced as
// !transform.param<i64> values instead of op handles.
111module attributes {transform.with_named_sequence} {
112  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
113    %generic = transform.structured.match ops{["linalg.generic"]} in %root : (!transform.any_op) -> !transform.any_op
    // Sizes for both dimensions are computed up front on the original op.
    // Results are used below as: #0 = low tile size, #1 = high tile size,
    // #2 = split point.
114    %row_sizes:3 = transform.structured.multitile_sizes %generic { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.param<i64>
115    %col_sizes:3 = transform.structured.multitile_sizes %generic { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.param<i64>
116    %row_split = transform.structured.split %generic after %row_sizes#2 { dimension = 0 } : !transform.any_op, !transform.param<i64>
117    %row_parts:2 = transform.split_handle %row_split : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
118    %tiled_low:2 = transform.structured.tile_using_for %row_parts#0 tile_sizes [%row_sizes#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
119    %tiled_high:2 = transform.structured.tile_using_for %row_parts#1 tile_sizes [%row_sizes#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
120    %tiled_ops = transform.merge_handles %tiled_low#0, %tiled_high#0 : !transform.any_op
    // Replicate the dimension-1 sizes against the merged handle so they can be
    // passed to the foreach alongside it.
121    %col_sizes_rep:3 = transform.replicate num(%tiled_ops) %col_sizes#0, %col_sizes#1, %col_sizes#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
122    transform.foreach %tiled_ops, %col_sizes_rep#0, %col_sizes_rep#1, %col_sizes_rep#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
123    ^bb0(%tile: !transform.any_op, %col_low: !transform.param<i64>, %col_high: !transform.param<i64>, %col_split: !transform.param<i64>):
124      %col_halves = transform.structured.split %tile after %col_split { dimension = 1 } : !transform.any_op, !transform.param<i64>
125      %left, %right = transform.split_handle %col_halves : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      // A tile size of 0 leaves dimension 0 untiled.
126      transform.structured.tile_using_for %left tile_sizes [0, %col_low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
127      transform.structured.tile_using_for %right tile_sizes [0, %col_high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
128    }
129    transform.yield
130  }
131}
132
// External declaration (no body) of the scalar callee. It is invoked from the
// linalg.generic region in @two_d_param with the input element and both
// iteration indices.
133func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
134
135// Even without canonicalization, tile sizes can be computed statically thanks
136// to parameters.
137// NOCANON-LABEL: @two_d_param
138// NOCANON-NOT:   affine.apply
139// NOCANON:       scf.for
140
141// CHECK-LABEL: @two_d_param
142// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
// Payload under test for the parameter-based schedule: an elementwise
// linalg.generic over a 10x34 tensor whose region calls @elem with the input
// element and both iteration indices.
143func.func @two_d_param(%arg0: tensor<10x34xf32>,
144                       %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
145  %res = linalg.generic {
146    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
147                     affine_map<(d0, d1) -> (d0, d1)>],
148    iterator_types = ["parallel", "parallel"]
149  }
150  ins(%arg0: tensor<10x34xf32>)
151  outs(%arg1: tensor<10x34xf32>) {
152  ^bb0(%in: f32, %out: f32):
153    %row = linalg.index 0 : index
154    %col = linalg.index 1 : index
155    %elem_val = func.call @elem(%in, %row, %col) : (f32, index, index) -> f32
156    linalg.yield %elem_val : f32
157  } -> tensor<10x34xf32>
158
  // The expectations below check the full code for the (2, 8) quadrant, the
  // data flow for the (2, 9) quadrant and only the overall structure for the
  // (3, 8) and (3, 9) quadrants.
159  // CHECK:      %[[SLICE_1_IN:.+]] = tensor.extract_slice %[[IN]][0, 0] [4, 34] [1, 1]
160  // CHECK:      %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1]
161  // CHECK:      scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]])
162  // CHECK:        %[[OUTSLICE_1_IN:.+]] = tensor.extract_slice %[[SLICE_1_IN]][%[[I1]], 0] [2, 34] [1, 1]
163  // CHECK:        %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 34] [1, 1]
164
165  // CHECK:        %[[SLICE_2_IN:.+]] = tensor.extract_slice %[[OUTSLICE_1_IN]][0, 0] [2, 16] [1, 1]
166  // CHECK:        %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
167  // CHECK:        %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]])
168  // CHECK:          %[[INSLICE_2:.+]] = tensor.extract_slice %[[SLICE_2_IN]][0, %[[I2]]] [2, 8] [1, 1]
169  // CHECK:          %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [2, 8] [1, 1]
170  // CHECK:          %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<2x8xf32>) outs(%[[OUTSLICE_2]] : tensor<2x8xf32>)
171  // CHECK:          %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
172  // CHECK:          scf.yield %[[RESPARTIAL]]
173
174  // CHECK:        %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
175  // CHECK:        %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
176  // CHECK:        scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
177  // CHECK-COUNT-2:  tensor.extract_slice
178  // CHECK:          linalg.generic {{.*}} ins(%{{.*}} : tensor<2x9xf32>)
179  // CHECK:          tensor.insert_slice
180  // CHECK:          scf.yield
181  // CHECK:        %[[INSERTED_2:.+]] = tensor.insert_slice %{{.*}} into %[[INSERTED]]
182  // CHECK:        %[[INSERTED_3:.+]] = tensor.insert_slice %[[INSERTED_2]] into %[[ITERARG_1]]
183  // CHECK:        scf.yield %[[INSERTED_3]]
184
185  // CHECK:        tensor.insert_slice
186  // CHECK:        tensor.extract_slice
187  // CHECK:        scf.for
188  // CHECK-COUNT-2:  tensor.extract_slice
189  // CHECK:          scf.for
190  // CHECK-COUNT-2:    tensor.extract_slice
191  // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x8xf32>)
192  // CHECK:            tensor.insert_slice
193  // CHECK:            scf.yield
194  // CHECK:          tensor.insert_slice
195  // CHECK:          tensor.extract_slice
196  // CHECK:          scf.for
197  // CHECK-COUNT-2:    tensor.extract_slice
198  // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x9xf32>)
199  // CHECK:            tensor.insert_slice
200  // CHECK:            scf.yield
201  // CHECK-COUNT-2:  tensor.insert_slice
202  // CHECK:          scf.yield
203  // CHECK:        %[[RESULT:.+]] = tensor.insert_slice
204  // CHECK:        return %[[RESULT]]
205
206  return %res : tensor<10x34xf32>
207}
208