xref: /llvm-project/mlir/test/Dialect/Linalg/transform-op-decompose.mlir (revision 36c3466aef6c8bfde0ddc736b8403e2c45f5e1c6)
1// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s
2
3// CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
4// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
5
// Size-1 H (input) and KH (filter) extents let the 2-D NHWC/HWCF convolution
// be rewritten as a 1-D NWC/WCF convolution on slices of the operands, whose
// result is inserted back into the init tensor.
// CHECK-LABEL: @conv_2d_nhwc_hwcf
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // The single-space regex after the op name stops it from matching a longer
  // op name that merely shares this prefix; operand wiring is pinned as well.
  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_nwc_wcf{{ }}
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x1x?x?xf32>
}
23
// Size-1 H (input) and KH (filter) extents let the 2-D NCHW/FCHW convolution
// be rewritten as a 1-D NCW/FCW convolution on slices of the operands, whose
// result is inserted back into the init tensor.
// CHECK-LABEL: @conv_2d_nchw_fchw
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // The single-space regex after the op name stops it from matching a longer
  // op name that merely shares this prefix; operand wiring is pinned as well.
  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_ncw_fcw{{ }}
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x?x1x?xf32>
}
41
// Static-shape depthwise case with stride 2: the size-1 H (input) and KH
// (filter) extents let it be rewritten as a 1-D depthwise NWC/WC convolution
// on slices of the operands, reinserted into the tensor.empty init.
// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
  // CHECK: %[[RES:.+]] = tensor.empty
  %init = tensor.empty() : tensor<1x1x56x96xf32>
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
  // The single-space regex after the op name keeps this from also matching
  // linalg.depthwise_conv_1d_nwc_wcm.
  // CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc{{ }}
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICERES]]
  // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
         ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
         outs(%init: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
  // Pin that the function returns the reinserted result, like the other cases.
  // CHECK: return %[[INSERTED]]
  return %0: tensor<1x1x56x96xf32>
}
61
// Un-batched 2-D convolution with a size-1 leading extent on every operand:
// expected to become a plain 1-D convolution on slices of the operands.
// CHECK-LABEL: @conv_2d
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<1x?xf32>)
func.func @conv_2d(%input: tensor<1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<1x?xf32>) -> tensor<1x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // The single-space regex after the op name keeps this from also matching
  // linalg.conv_1d_nwc_wcf / linalg.conv_1d_ncw_fcw.
  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d{{ }}
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.conv_2d
     ins (%input, %filter: tensor<1x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<1x?xf32>) -> tensor<1x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<1x?xf32>
}
78
// Size-1 H input extent and size-1 leading window extent: the 2-D NHWC sum
// pooling is expected to become a 1-D NWC sum pooling on operand slices.
// CHECK-LABEL: @pooling_nhwc_sum
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_sum(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // The single-space regex after the op name stops it from matching a longer
  // op name that merely shares this prefix; operand wiring is pinned as well.
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_sum{{ }}
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x1x?x?xf32>
}
96
// Size-1 H input extent and size-1 leading window extent: the 2-D NCHW sum
// pooling is expected to become a 1-D NCW sum pooling on operand slices.
// CHECK-LABEL: @pooling_nchw_sum
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // The single-space regex after the op name stops it from matching a longer
  // op name that merely shares this prefix; operand wiring is pinned as well.
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_sum{{ }}
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x?x1x?xf32>
}
114
// Size-1 H input extent and size-1 leading window extent: the 2-D NHWC max
// pooling is expected to become a 1-D NWC max pooling on operand slices.
// CHECK-LABEL: @pooling_nhwc_max
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // The single-space regex after the op name keeps this from also matching
  // linalg.pooling_nwc_max_unsigned.
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max{{ }}
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x1x?x?xf32>
}
132
// Size-1 H input extent and size-1 leading window extent: the 2-D NHWC
// unsigned max pooling is expected to become its 1-D NWC counterpart on
// operand slices.
// CHECK-LABEL: @pooling_nhwc_max_unsigned
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // The single-space regex after the op name stops it from matching a longer
  // op name that merely shares this prefix; operand wiring is pinned as well.
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned{{ }}
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
                               strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x1x?x?xf32>
}
150
// Size-1 H input extent and size-1 leading window extent: the 2-D NHWC min
// pooling is expected to become a 1-D NWC min pooling on operand slices.
// CHECK-LABEL: @pooling_nhwc_min
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // The single-space regex after the op name keeps this from also matching
  // linalg.pooling_nwc_min_unsigned.
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min{{ }}
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x1x?x?xf32>
}
168
// Size-1 H input extent and size-1 leading window extent: the 2-D NHWC
// unsigned min pooling is expected to become its 1-D NWC counterpart on
// operand slices.
// CHECK-LABEL: @pooling_nhwc_min_unsigned
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // The single-space regex after the op name stops it from matching a longer
  // op name that merely shares this prefix; operand wiring is pinned as well.
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned{{ }}
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
                               strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x1x?x?xf32>
}
186
// Size-1 H input extent and size-1 leading window extent: the 2-D NCHW max
// pooling is expected to become a 1-D NCW max pooling on operand slices.
// CHECK-LABEL: @pooling_nchw_max
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // The single-space regex after the op name stops it from matching a longer
  // op name that merely shares this prefix; operand wiring is pinned as well.
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_max{{ }}
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x?x1x?xf32>
}
204
// Expected decomposition of linalg.softmax over dimension 2: a max reduction,
// an elementwise subtract+exp, a sum reduction, then an elementwise divide.
// The two elementwise stages use %dst as their init.
func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
  return %1 : tensor<2x16x32xf32>
}

// Values defined inside each generic body are captured fresh rather than
// reusing a name captured in an earlier body, so these checks do not depend
// on how the printer numbers SSA values across sibling regions.
// CHECK-LABEL:      func.func @softmax(
// CHECK-SAME:           %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK-DAG:        %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
// CHECK-DAG:        %[[CST:.+]] = arith.constant 0xFFC00000 : f32
// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
// CHECK:        %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
// CHECK-SAME:     "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK:          %[[MAX:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
// CHECK:          linalg.yield %[[MAX]] : f32
// CHECK:        } -> tensor<2x16xf32>
// CHECK:        %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
// CHECK-SAME:     outs(%[[DST]] : tensor<2x16x32xf32>) {
// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK:          %[[SUB:.+]] = arith.subf %[[IN]], %[[IN_1]] : f32
// CHECK:          %[[EXP:.+]] = math.exp %[[SUB]] : f32
// CHECK:          linalg.yield %[[EXP]] : f32
// CHECK:        } -> tensor<2x16x32xf32>
// CHECK:        %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
// CHECK:        %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
// CHECK-SAME:     "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK:          %[[ADD:.+]] = arith.addf %[[IN]], %[[OUT]] : f32
// CHECK:          linalg.yield %[[ADD]] : f32
// CHECK:        } -> tensor<2x16xf32>
// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
// CHECK-SAME:     outs(%[[DST]] : tensor<2x16x32xf32>) {
// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK:          %[[DIV:.+]] = arith.divf %[[IN]], %[[IN_1]] : f32
// CHECK:          linalg.yield %[[DIV]] : f32
// CHECK:        } -> tensor<2x16x32xf32>
// CHECK:        return %[[D7]] : tensor<2x16x32xf32>
245
// Transform script driving this file's payload: every LinalgOp is decomposed
// first, then linalg.softmax is decomposed through its decomposition interface.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.readonly}) {
    // Match LinalgOps before the softmax rewrite runs. The generics and fills
    // that the softmax decomposition creates are therefore not in this handle.
    %linalg_ops = transform.structured.match interface{LinalgOp} in %root
        : (!transform.any_op) -> !transform.any_op
    %decomposed = transform.structured.decompose %linalg_ops
        : (!transform.any_op) -> !transform.any_op

    %softmax_ops = transform.structured.match ops{["linalg.softmax"]} in %root
        : (!transform.any_op) -> !transform.any_op
    %softmax_decomposed = transform.structured.decompose_interface %softmax_ops
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
256