xref: /llvm-project/mlir/test/Dialect/Linalg/named-ops.mlir (revision 0d4efa27252cbbea4b5672d4d8ffc15a3ba51d83)
1// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
2
// Parses linalg.depthwise_conv_1d_nwc_wcm on statically shaped tensors and
// expects the op to appear in the printed output.
// CHECK-LABEL: func @depthwise_conv_1d_nwc_wcm
func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
  // Zero-filled tensor that is passed as the op's outs operand.
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<1x10x8x8xf32>
  %acc = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x10x8x8xf32>) -> tensor<1x10x8x8xf32>
  // CHECK: depthwise_conv_1d_nwc_wcm
  %conv = linalg.depthwise_conv_1d_nwc_wcm
    {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : tensor<1x12x8xf32>, tensor<3x8x8xf32>)
    outs(%acc : tensor<1x10x8x8xf32>) -> tensor<1x10x8x8xf32>
  func.return %conv : tensor<1x10x8x8xf32>
}
14
15// -----
16
// Parses linalg.depthwise_conv_1d_nwc_wc on statically shaped tensors and
// expects the op to appear in the printed output.
// CHECK-LABEL: func @depthwise_conv_1d_nwc_wc
func.func @depthwise_conv_1d_nwc_wc(%input: tensor<1x12x8xf32>, %filter: tensor<3x8xf32>) -> tensor<1x10x8xf32> {
  // Zero-filled tensor that is passed as the op's outs operand.
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<1x10x8xf32>
  %acc = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x10x8xf32>) -> tensor<1x10x8xf32>
  // CHECK: depthwise_conv_1d_nwc_wc
  %conv = linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : tensor<1x12x8xf32>, tensor<3x8xf32>)
    outs(%acc : tensor<1x10x8xf32>) -> tensor<1x10x8xf32>
  func.return %conv : tensor<1x10x8xf32>
}
28
29// -----
30
// Parses linalg.depthwise_conv_1d_ncw_cw on statically shaped tensors and
// expects the op to appear in the printed output.
// CHECK-LABEL: func @depthwise_conv_1d_ncw_cw
func.func @depthwise_conv_1d_ncw_cw(%input: tensor<1x8x12xf32>, %filter: tensor<8x3xf32>) -> tensor<1x8x10xf32> {
  // Zero-filled tensor that is passed as the op's outs operand.
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<1x8x10xf32>
  %acc = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x8x10xf32>) -> tensor<1x8x10xf32>
  // CHECK: depthwise_conv_1d_ncw_cw
  %conv = linalg.depthwise_conv_1d_ncw_cw
    {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : tensor<1x8x12xf32>, tensor<8x3xf32>)
    outs(%acc : tensor<1x8x10xf32>) -> tensor<1x8x10xf32>
  func.return %conv : tensor<1x8x10xf32>
}
42
43// -----
44
// Parses linalg.depthwise_conv_2d_nhwc_hwcm on tensors and verifies that the
// attributes, operand types and result type are printed back unchanged.
// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwcm_tensor
func.func @depthwise_conv_2d_nhwc_hwcm_tensor(%input: tensor<2x4x5x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x3x4x2x3xf32> {
  // Zero-filled tensor that is passed as the op's outs operand.
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<2x3x4x2x3xf32>
  %acc = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
  // CHECK:      %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwcm
  // CHECK-SAME:   {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<2x3x4x2x3xf32>)
  %conv = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
    outs(%acc : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
  func.return %conv : tensor<2x3x4x2x3xf32>
}
60
// Parses linalg.depthwise_conv_2d_nhwc_hwcm on memref operands (no result
// value) and verifies the printed attributes and operand types.
// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwcm_memref
func.func @depthwise_conv_2d_nhwc_hwcm_memref(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
  // CHECK:      linalg.depthwise_conv_2d_nhwc_hwcm
  // CHECK-SAME:   {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<2x3x4x2x3xf32>)
  linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
    outs(%output : memref<2x3x4x2x3xf32>)
  func.return
}
73
// Parses linalg.depthwise_conv_1d_nwc_wc with vector-typed attributes and a
// stride of 2, and verifies that attributes and types are printed unchanged.
// The op-name directive spells out the full op name so that it cannot be
// satisfied by another op sharing the `depthwise_conv_1d_nw` prefix.
// CHECK-LABEL: func @depthwise_conv_1d_nw_tensor
func.func @depthwise_conv_1d_nw_tensor(%input: tensor<1x113x96xf32>, %filter: tensor<3x96xf32>) -> tensor<1x56x96xf32> {
  %init = tensor.empty() : tensor<1x56x96xf32>
  // CHECK:      %{{.+}} = linalg.depthwise_conv_1d_nwc_wc
  // CHECK-SAME:   {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x113x96xf32>, tensor<3x96xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<1x56x96xf32>) -> tensor<1x56x96xf32>
  %0 = linalg.depthwise_conv_1d_nwc_wc {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
         ins(%input, %filter: tensor<1x113x96xf32>, tensor<3x96xf32>)
         outs(%init: tensor<1x56x96xf32>) -> tensor<1x56x96xf32>
  return %0: tensor<1x56x96xf32>
}
86
// Parses linalg.depthwise_conv_2d_nhwc_hwc on tensors with vector-typed
// attributes and a stride of 2; attributes and types must print unchanged.
// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_tensor
func.func @depthwise_conv_2d_nhwc_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> {
  %empty = tensor.empty() : tensor<1x56x56x96xf32>
  // CHECK:      %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwc
  // CHECK-SAME:   {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
  %conv = linalg.depthwise_conv_2d_nhwc_hwc
    {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
    ins(%input, %filter : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
    outs(%empty : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
  func.return %conv : tensor<1x56x56x96xf32>
}
99
// Parses linalg.depthwise_conv_2d_nhwc_hwc on memref operands (no result
// value) and verifies the printed attributes and operand types.
// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_memref
func.func @depthwise_conv_2d_nhwc_hwc_memref(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
  // CHECK:      linalg.depthwise_conv_2d_nhwc_hwc
  // CHECK-SAME:   {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<1x56x56x96xf32>)
  linalg.depthwise_conv_2d_nhwc_hwc
    {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
    ins(%input, %filter : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
    outs(%output : memref<1x56x56x96xf32>)
  func.return
}
111
// Parses linalg.depthwise_conv_2d_nchw_chw on tensors with vector-typed
// attributes and a stride of 2; attributes and types must print unchanged.
// CHECK-LABEL: func @depthwise_conv_2d_nchw_chw_tensor
func.func @depthwise_conv_2d_nchw_chw_tensor(%input: tensor<1x96x113x113xf32>, %filter: tensor<96x3x3xf32>) -> tensor<1x96x56x56xf32> {
  %empty = tensor.empty() : tensor<1x96x56x56xf32>
  // CHECK:      %{{.+}} = linalg.depthwise_conv_2d_nchw_chw
  // CHECK-SAME:   {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x96x113x113xf32>, tensor<96x3x3xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<1x96x56x56xf32>) -> tensor<1x96x56x56xf32>
  %conv = linalg.depthwise_conv_2d_nchw_chw
    {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
    ins(%input, %filter : tensor<1x96x113x113xf32>, tensor<96x3x3xf32>)
    outs(%empty : tensor<1x96x56x56xf32>) -> tensor<1x96x56x56xf32>
  func.return %conv : tensor<1x96x56x56xf32>
}
124
// Parses linalg.depthwise_conv_2d_nchw_chw on memref operands (no result
// value) and verifies the printed attributes and operand types.
// CHECK-LABEL: func @depthwise_conv_2d_nchw_chw_memref
func.func @depthwise_conv_2d_nchw_chw_memref(%input: memref<1x96x113x113xf32>, %filter: memref<96x3x3xf32>, %output: memref<1x96x56x56xf32>) {
  // CHECK:      linalg.depthwise_conv_2d_nchw_chw
  // CHECK-SAME:   {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<1x96x113x113xf32>, memref<96x3x3xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<1x96x56x56xf32>)
  linalg.depthwise_conv_2d_nchw_chw
    {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
    ins(%input, %filter : memref<1x96x113x113xf32>, memref<96x3x3xf32>)
    outs(%output : memref<1x96x56x56xf32>)
  func.return
}
136
// Parses linalg.depthwise_conv_2d_nhwc_hwcm on tensors with a dilation of 2
// and verifies that attributes and types are printed back unchanged.
// The label below anchors the following directives to this function; without
// it they would be matched anywhere after the previous label.
// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwcm_tensor_dilated
func.func @depthwise_conv_2d_nhwc_hwcm_tensor_dilated(%input: tensor<2x8x9x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x6x7x2x3xf32> {
  // Zero-filled tensor that is passed as the op's outs operand.
  %zero = arith.constant 0.000000e+00 : f32
  %init = tensor.empty() : tensor<2x6x7x2x3xf32>
  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x6x7x2x3xf32>) -> tensor<2x6x7x2x3xf32>
  // CHECK:      %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwcm
  // CHECK-SAME:   {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<2x8x9x2xf32>, tensor<2x2x2x3xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<2x6x7x2x3xf32>)
  %0 = linalg.depthwise_conv_2d_nhwc_hwcm
     { dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
     ins(%input, %filter : tensor<2x8x9x2xf32>, tensor<2x2x2x3xf32>)
    outs(%fill : tensor<2x6x7x2x3xf32>) -> tensor<2x6x7x2x3xf32>
  return %0 : tensor<2x6x7x2x3xf32>
}
151
// Parses linalg.depthwise_conv_2d_nhwc_hwcm on memref operands with a dilation
// of 2 and verifies the printed attributes and operand types.
// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwcm_memref_dilated
func.func @depthwise_conv_2d_nhwc_hwcm_memref_dilated(%input: memref<2x8x9x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x6x7x2x3xf32>) {
  // CHECK:      linalg.depthwise_conv_2d_nhwc_hwcm
  // CHECK-SAME:   {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<2x8x9x2xf32>, memref<2x2x2x3xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<2x6x7x2x3xf32>)
  linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : memref<2x8x9x2xf32>, memref<2x2x2x3xf32>)
    outs(%output : memref<2x6x7x2x3xf32>)
  func.return
}
164
165// -----
166
// The op is written without strides/dilations; the printed form must not
// mention either attribute.
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_default_attributes
func.func @depthwise_conv_2d_input_nhwc_filter_default_attributes(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
  // CHECK:      linalg.depthwise_conv_2d_nhwc_hwc
  // CHECK-NOT:  strides =
  // CHECK-NOT:  dilations =
  linalg.depthwise_conv_2d_nhwc_hwc
      ins(%input, %filter : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
      outs(%output : memref<1x56x56x96xf32>)
  func.return
}
177
178// -----
179
// Negative test: `strides` is supplied through the properties syntax `<{...}>`
// with f32 elements, which must be rejected with the diagnostic below.
func.func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type_properties(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
  // expected-error @+1 {{invalid properties {dilations = dense<1> : vector<2xi64>, operandSegmentSizes = array<i32: 2, 1>, strides = dense<2.000000e+00> : vector<2xf32>} for op linalg.depthwise_conv_2d_nhwc_hwc: Invalid attribute `strides` in property conversion: dense<2.000000e+00> : vector<2xf32>}}
  linalg.depthwise_conv_2d_nhwc_hwc <{dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}>
      ins(%input, %filter : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
      outs(%output : memref<1x56x56x96xf32>)
  func.return
}
187
188// -----
189
// Negative test: `strides` is supplied in the attribute dictionary with f32
// elements, which violates the attribute constraint quoted below.
func.func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
  // expected-error @+1 {{op attribute 'strides' failed to satisfy constraint: 64-bit signless int elements attribute of shape [2]}}
  linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}
      ins(%input, %filter : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
      outs(%output : memref<1x56x56x96xf32>)
  func.return
}
197
198// -----
199
// Negative test: `strides` has three elements where the constraint quoted
// below requires shape [2].
func.func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
  // expected-error @+1 {{op attribute 'strides' failed to satisfy constraint: 64-bit signless int elements attribute of shape [2]}}
  linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64>}
      ins(%input, %filter : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
      outs(%output : memref<1x56x56x96xf32>)
  func.return
}
207
208// -----
209
// Parses linalg.depthwise_conv_3d_ndhwc_dhwcm with per-dimension strides
// [2, 1, 3] and expects the op to appear in the printed output.
// CHECK-LABEL: func @depthwise_conv_3d_ndhwc_dhwcm
func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<2x6x13x12x6xf32>, %filter: tensor<2x1x3x6x6xf32>) -> tensor<2x3x13x4x6x6xf32> {
  // Zero-filled tensor that is passed as the op's outs operand.
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<2x3x13x4x6x6xf32>
  %acc = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x3x13x4x6x6xf32>) -> tensor<2x3x13x4x6x6xf32>
  // CHECK: depthwise_conv_3d_ndhwc_dhwcm
  %conv = linalg.depthwise_conv_3d_ndhwc_dhwcm
    {dilations = dense<1> : tensor<3xi64>, strides = dense<[2, 1, 3]> : tensor<3xi64>}
    ins(%input, %filter : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6x6xf32>)
    outs(%acc : tensor<2x3x13x4x6x6xf32>) -> tensor<2x3x13x4x6x6xf32>
  func.return %conv : tensor<2x3x13x4x6x6xf32>
}
221
222// -----
223
// Parses linalg.depthwise_conv_3d_ndhwc_dhwc with per-dimension strides
// [2, 1, 3] and expects the op to appear in the printed output.
// CHECK-LABEL: func @depthwise_conv_3d_ndhwc_dhwc
func.func @depthwise_conv_3d_ndhwc_dhwc(%input: tensor<2x6x13x12x6xf32>, %filter: tensor<2x1x3x6xf32>) -> tensor<2x3x13x4x6xf32> {
  // Zero-filled tensor that is passed as the op's outs operand.
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<2x3x13x4x6xf32>
  %acc = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x3x13x4x6xf32>) -> tensor<2x3x13x4x6xf32>
  // CHECK: depthwise_conv_3d_ndhwc_dhwc
  %conv = linalg.depthwise_conv_3d_ndhwc_dhwc
    {dilations = dense<1> : tensor<3xi64>, strides = dense<[2, 1, 3]> : tensor<3xi64>}
    ins(%input, %filter : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6xf32>)
    outs(%acc : tensor<2x3x13x4x6xf32>) -> tensor<2x3x13x4x6xf32>
  func.return %conv : tensor<2x3x13x4x6xf32>
}
235
236// -----
237
// Parses linalg.depthwise_conv_3d_ncdhw_cdhw with per-dimension strides
// [2, 1, 3] and expects the op to appear in the printed output.
// CHECK-LABEL: func @depthwise_conv_3d_ncdhw_cdhw
func.func @depthwise_conv_3d_ncdhw_cdhw(%input: tensor<2x6x6x13x12xf32>, %filter: tensor<6x2x1x3xf32>) -> tensor<2x6x3x13x4xf32> {
  // Zero-filled tensor that is passed as the op's outs operand.
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<2x6x3x13x4xf32>
  %acc = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x6x3x13x4xf32>) -> tensor<2x6x3x13x4xf32>
  // CHECK: depthwise_conv_3d_ncdhw_cdhw
  %conv = linalg.depthwise_conv_3d_ncdhw_cdhw
    {dilations = dense<1> : tensor<3xi64>, strides = dense<[2, 1, 3]> : tensor<3xi64>}
    ins(%input, %filter : tensor<2x6x6x13x12xf32>, tensor<6x2x1x3xf32>)
    outs(%acc : tensor<2x6x3x13x4xf32>) -> tensor<2x6x3x13x4xf32>
  func.return %conv : tensor<2x6x3x13x4xf32>
}
249
250// -----
251
// Parses linalg.conv_1d_nwc_wcf on dynamically shaped tensors and verifies the
// printed attributes, operand types and result type.
// CHECK-LABEL: func @conv_1d_nwc_wcf
func.func @conv_1d_nwc_wcf(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK:      %{{.+}} = linalg.conv_1d_nwc_wcf
  // CHECK-SAME:   dilations = dense<1> : tensor<1xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<1xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  %conv = linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
    outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  func.return %conv : tensor<?x?x?xf32>
}
265
266// -----
267
// Parses linalg.conv_1d_nwc_wcf on dynamically shaped memrefs (no result
// value) and verifies the printed attributes and operand types.
// CHECK-LABEL: func @conv_1d_nwc_wcf
func.func @conv_1d_nwc_wcf(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
  // CHECK:      linalg.conv_1d_nwc_wcf
  // CHECK-SAME:   dilations = dense<1> : tensor<1xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<1xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs(%output : memref<?x?x?xf32>)
  func.return
}
281
282// -----
283
// Parses linalg.conv_1d_ncw_fcw on dynamically shaped tensors and verifies the
// printed attributes, operand types and result type.
// CHECK-LABEL: func @conv_1d_ncw_fcw
func.func @conv_1d_ncw_fcw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK:      %{{.+}} = linalg.conv_1d_ncw_fcw
  // CHECK-SAME:   dilations = dense<1> : tensor<1xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<1xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  %conv = linalg.conv_1d_ncw_fcw {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
    outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  func.return %conv : tensor<?x?x?xf32>
}
297
298// -----
299
// Parses linalg.conv_1d_ncw_fcw on dynamically shaped memrefs (no result
// value) and verifies the printed attributes and operand types.
// CHECK-LABEL: func @conv_1d_ncw_fcw
func.func @conv_1d_ncw_fcw(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
  // CHECK:      linalg.conv_1d_ncw_fcw
  // CHECK-SAME:   dilations = dense<1> : tensor<1xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<1xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.conv_1d_ncw_fcw {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs(%output : memref<?x?x?xf32>)
  func.return
}
313
314// -----
315
// Parses linalg.conv_2d_nhwc_hwcf on dynamically shaped tensors and verifies
// the printed attributes, operand types and result type.
// CHECK-LABEL: func @conv_2d_nhwc_hwcf
func.func @conv_2d_nhwc_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  // CHECK:      %{{.+}} = linalg.conv_2d_nhwc_hwcf
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  %conv = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
    outs(%init : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  func.return %conv : tensor<?x?x?x?xf32>
}
329
330// -----
331
// Parses linalg.conv_2d_ngchw_fgchw on dynamically shaped rank-5 tensors and
// verifies the printed attributes, operand types and result type.
// CHECK-LABEL: func @conv_2d_ngchw_fgchw
func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
  // CHECK:      %{{.+}} = linalg.conv_2d_ngchw_fgchw
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
  %conv = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
    outs(%init : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
  func.return %conv : tensor<?x?x?x?x?xf32>
}
345
346// -----
347
// Parses linalg.conv_2d_nhwc_fhwc on dynamically shaped tensors and verifies
// the printed attributes, operand types and result type.
// CHECK-LABEL: func @conv_2d_nhwc_fhwc
func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  // CHECK:      %{{.+}} = linalg.conv_2d_nhwc_fhwc
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  %conv = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
    outs(%init : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  func.return %conv : tensor<?x?x?x?xf32>
}
361
362// -----
363
// Parses linalg.conv_2d_nhwc_fhwc where only the leading dimension of the
// input and output is dynamic; types must be printed back unchanged.
// CHECK-LABEL: func @conv_2d_nhwc_fhwc_static
func.func @conv_2d_nhwc_fhwc_static(%input: tensor<?x128x128x32xf32>, %filter: tensor<64x3x3x32xf32>, %init: tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32> {
  // CHECK:      %{{.+}} = linalg.conv_2d_nhwc_fhwc
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x128x128x32xf32>, tensor<64x3x3x32xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32>
  %conv = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : tensor<?x128x128x32xf32>, tensor<64x3x3x32xf32>)
    outs(%init : tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32>
  func.return %conv : tensor<?x126x126x64xf32>
}
377
378// -----
379
// Parses linalg.conv_2d_nhwc_hwcf on dynamically shaped memrefs (no result
// value) and verifies the printed attributes and operand types.
// CHECK-LABEL: func @conv_2d_nhwc_hwcf
func.func @conv_2d_nhwc_hwcf(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
  // CHECK:      linalg.conv_2d_nhwc_hwcf
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?x?xf32>)
  linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
    outs(%output : memref<?x?x?x?xf32>)
  func.return
}
393
394// -----
395
// Parses linalg.conv_2d_ngchw_fgchw on dynamically shaped rank-5 memrefs (no
// result value) and verifies the printed attributes and operand types.
// CHECK-LABEL: func @conv_2d_ngchw_fgchw
func.func @conv_2d_ngchw_fgchw(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
  // CHECK:      linalg.conv_2d_ngchw_fgchw
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?x?x?xf32>)
  linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
    outs(%output : memref<?x?x?x?x?xf32>)
  func.return
}
409
410// -----
411
// Parses linalg.conv_2d_nhwgc_gfhwc on dynamically shaped rank-5 memrefs (no
// result value) and verifies the printed attributes and operand types.
// CHECK-LABEL: func @conv_2d_nhwgc_gfhwc
func.func @conv_2d_nhwgc_gfhwc(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
  // CHECK:      linalg.conv_2d_nhwgc_gfhwc
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?x?x?xf32>)
  linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
    outs(%output : memref<?x?x?x?x?xf32>)
  func.return
}
425
426// -----
427
// Parses linalg.conv_2d_nhwgc_gfhwc on statically shaped tensors and verifies
// the printed attributes, operand types and result type.
// CHECK-LABEL: func @conv_2d_nhwgc_gfhwc_tensor
func.func @conv_2d_nhwgc_gfhwc_tensor(%input: tensor<1x28x28x2x3xf32>, %filter: tensor<2x8x3x3x3xf32>, %output: tensor<1x26x26x2x8xf32>) -> tensor<1x26x26x2x8xf32> {
  // CHECK:      linalg.conv_2d_nhwgc_gfhwc
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x28x28x2x3xf32>, tensor<2x8x3x3x3xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<1x26x26x2x8xf32>) -> tensor<1x26x26x2x8xf32>
  %conv = linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : tensor<1x28x28x2x3xf32>, tensor<2x8x3x3x3xf32>)
    outs(%output : tensor<1x26x26x2x8xf32>) -> tensor<1x26x26x2x8xf32>
  func.return %conv : tensor<1x26x26x2x8xf32>
}
441
442// -----
443
// Parses linalg.conv_2d_ngchw_fgchw on statically shaped tensors and verifies
// the printed attributes, operand types and result type.
// CHECK-LABEL: func @conv_2d_ngchw_fgchw_dimensions
func.func @conv_2d_ngchw_fgchw_dimensions(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<2x5x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
  // CHECK:      linalg.conv_2d_ngchw_fgchw
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x5x3x32x32xf32>, tensor<2x5x3x3x3xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
  %conv = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : tensor<1x5x3x32x32xf32>, tensor<2x5x3x3x3xf32>)
    outs(%init : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
  func.return %conv : tensor<1x5x2x30x30xf32>
}
457
458// -----
459
// Parses linalg.conv_2d_ngchw_gfchw on statically shaped tensors and verifies
// the printed attributes, operand types and result type.
// CHECK-LABEL: func @conv_2d_ngchw_gfchw
func.func @conv_2d_ngchw_gfchw(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<5x2x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
  // CHECK:      linalg.conv_2d_ngchw_gfchw
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
  %conv = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter : tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)
    outs(%init : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
  func.return %conv : tensor<1x5x2x30x30xf32>
}
473
474// -----
475
// Parses linalg.conv_2d_ngchw_gfchw_q with i8 input/filter tensors, two i32
// scalar operands and an i32 output tensor; all types must print unchanged.
// CHECK-LABEL: func @conv_2d_ngchw_gfchw_q
func.func @conv_2d_ngchw_gfchw_q(%input: tensor<1x5x3x32x32xi8>, %filter: tensor<5x2x3x3x3xi8>, %inputzp: i32, %filterzp: i32, %init: tensor<1x5x2x30x30xi32>) -> tensor<1x5x2x30x30xi32> {
  // CHECK:      linalg.conv_2d_ngchw_gfchw_q
  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x5x3x32x32xi8>, tensor<5x2x3x3x3xi8>, i32, i32)
  // CHECK-SAME:   outs(%{{.+}} : tensor<1x5x2x30x30xi32>) -> tensor<1x5x2x30x30xi32>
  %conv = linalg.conv_2d_ngchw_gfchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %filter, %inputzp, %filterzp : tensor<1x5x3x32x32xi8>, tensor<5x2x3x3x3xi8>, i32, i32)
    outs(%init : tensor<1x5x2x30x30xi32>) -> tensor<1x5x2x30x30xi32>
  func.return %conv : tensor<1x5x2x30x30xi32>
}
489// -----
490
// Parses linalg.conv_3d_ndhwc_dhwcf on dynamically shaped tensors and verifies
// the printed attributes, operand types and result type.
// CHECK-LABEL: func @conv_3d_ndhwc_dhwcf
func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
  // CHECK:      %{{.+}} = linalg.conv_3d_ndhwc_dhwcf
  // CHECK-SAME:   dilations = dense<1> : tensor<3xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<3xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
  %conv = linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
    ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
    outs(%init : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
  func.return %conv : tensor<?x?x?x?x?xf32>
}
504
505// -----
506
// Parses linalg.conv_3d_ndhwc_dhwcf on dynamically shaped memrefs (no result
// value) and verifies the printed attributes and operand types.
// CHECK-LABEL: func @conv_3d_ndhwc_dhwcf
func.func @conv_3d_ndhwc_dhwcf(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
  // CHECK:      linalg.conv_3d_ndhwc_dhwcf
  // CHECK-SAME:   dilations = dense<1> : tensor<3xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<3xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?x?x?xf32>)
  linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
    ins(%input, %filter : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
    outs(%output : memref<?x?x?x?x?xf32>)
  func.return
}
520
521// -----
522
// Parses linalg.conv_3d_ncdhw_fcdhw on dynamically shaped tensors and verifies
// the printed attributes, operand types and result type.
// CHECK-LABEL: func @conv_3d_ncdhw_fcdhw
func.func @conv_3d_ncdhw_fcdhw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
  // CHECK:      %{{.+}} = linalg.conv_3d_ncdhw_fcdhw
  // CHECK-SAME:   dilations = dense<1> : tensor<3xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<3xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
  %conv = linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
    ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
    outs(%init : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
  func.return %conv : tensor<?x?x?x?x?xf32>
}
536
537// -----
538
// Parses linalg.conv_3d_ncdhw_fcdhw on dynamically shaped memrefs (no result
// value) and verifies the printed attributes and operand types.
// CHECK-LABEL: func @conv_3d_ncdhw_fcdhw
func.func @conv_3d_ncdhw_fcdhw(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
  // CHECK:      linalg.conv_3d_ncdhw_fcdhw
  // CHECK-SAME:   dilations = dense<1> : tensor<3xi64>
  // CHECK-SAME:   strides = dense<1> : tensor<3xi64>
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?x?x?xf32>)
  linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
    ins(%input, %filter : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
    outs(%output : memref<?x?x?x?x?xf32>)
  func.return
}
552
553// -----
554
// Parses linalg.pooling_nhwc_sum on tensors and verifies the printed
// attributes, operand types and result type.
// CHECK-LABEL: func @pooling_nhwc_sum_tensor
// CHECK:         %{{.+}} = linalg.pooling_nhwc_sum
// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
func.func @pooling_nhwc_sum_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> {
  // Uninitialized 3x3 tensor passed as the op's second ins operand.
  %window = tensor.empty() : tensor<3x3xf32>
  // Zero-filled tensor that is passed as the op's outs operand.
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<1x2x2x1xf32>
  %acc = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
  %pooled = linalg.pooling_nhwc_sum
    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %window : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
    outs(%acc : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
  func.return %pooled : tensor<1x2x2x1xf32>
}
571
572// -----
573
574// CHECK-LABEL: func @pooling_nwc_sum_tensor
575// CHECK:         %{{.+}} = linalg.pooling_nwc_sum
576// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
577// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
578// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x1xf32>, tensor<3xf32>)
579// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
// Tensor form of linalg.pooling_nwc_sum (f32): zero-filled accumulator, second input built with tensor.empty.
580func.func @pooling_nwc_sum_tensor(%input: tensor<1x4x1xf32>) -> tensor<1x2x1xf32> {
581  %window = tensor.empty() : tensor<3xf32>
582  %empty = tensor.empty() : tensor<1x2x1xf32>
583  %zero = arith.constant 0.000000e+00 : f32
584  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
585  %pooled = linalg.pooling_nwc_sum {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
586    ins(%input, %window : tensor<1x4x1xf32>, tensor<3xf32>)
587    outs(%acc : tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
588  return %pooled : tensor<1x2x1xf32>
589}
590
591// -----
592
593// CHECK-LABEL: func @pooling_nhwc_sum
594// CHECK:         linalg.pooling_nhwc_sum
595// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
596// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
597// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xf32>, memref<3x3xf32>)
598// CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x1xf32>)
// Memref form of linalg.pooling_nhwc_sum (f32): all operands are function arguments and the op yields no SSA result.
599func.func @pooling_nhwc_sum(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
600  linalg.pooling_nhwc_sum {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
601    ins(%input, %fake : memref<1x4x4x1xf32>, memref<3x3xf32>)
602    outs(%output : memref<1x2x2x1xf32>)
603  return
604}
605
606// -----
607
608// CHECK-LABEL: func @pooling_nwc_sum
609// CHECK:         linalg.pooling_nwc_sum
610// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
611// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
612// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x1xf32>, memref<3xf32>)
613// CHECK-SAME:      outs(%{{.+}} : memref<1x2x1xf32>)
// Memref form of linalg.pooling_nwc_sum (f32): all operands are function arguments and the op yields no SSA result.
614func.func @pooling_nwc_sum(%input: memref<1x4x1xf32>, %fake: memref<3xf32>, %output: memref<1x2x1xf32>) {
615  linalg.pooling_nwc_sum {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
616    ins(%input, %fake : memref<1x4x1xf32>, memref<3xf32>)
617    outs(%output : memref<1x2x1xf32>)
618  return
619}
620
621// -----
622
623// CHECK-LABEL: func @pooling_nchw_sum_tensor
624// CHECK:         %{{.+}} = linalg.pooling_nchw_sum
625// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
626// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
627// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x1x4x4xf32>, tensor<3x3xf32>)
628// CHECK-SAME:      outs(%{{.+}} : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
// Tensor form of linalg.pooling_nchw_sum (f32): zero-filled accumulator, second input built with tensor.empty.
629func.func @pooling_nchw_sum_tensor(%input: tensor<1x1x4x4xf32>) -> tensor<1x1x2x2xf32> {
630  %window = tensor.empty() : tensor<3x3xf32>
631  %empty = tensor.empty() : tensor<1x1x2x2xf32>
632  %zero = arith.constant 0.000000e+00 : f32
633  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
634  %pooled = linalg.pooling_nchw_sum {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
635    ins(%input, %window : tensor<1x1x4x4xf32>, tensor<3x3xf32>)
636    outs(%acc : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
637  return %pooled : tensor<1x1x2x2xf32>
638}
639
640// -----
641
642// CHECK-LABEL: func @pooling_ncw_sum_tensor
643// CHECK:         %{{.+}} = linalg.pooling_ncw_sum
644// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
645// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
646// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x1x4xf32>, tensor<3xf32>)
647// CHECK-SAME:      outs(%{{.+}} : tensor<1x1x2xf32>) -> tensor<1x1x2xf32>
// Tensor form of linalg.pooling_ncw_sum (f32): zero-filled accumulator, second input built with tensor.empty.
648func.func @pooling_ncw_sum_tensor(%input: tensor<1x1x4xf32>) -> tensor<1x1x2xf32> {
649  %window = tensor.empty() : tensor<3xf32>
650  %empty = tensor.empty() : tensor<1x1x2xf32>
651  %zero = arith.constant 0.000000e+00 : f32
652  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x1x2xf32>) -> tensor<1x1x2xf32>
653  %pooled = linalg.pooling_ncw_sum {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
654    ins(%input, %window : tensor<1x1x4xf32>, tensor<3xf32>)
655    outs(%acc : tensor<1x1x2xf32>) -> tensor<1x1x2xf32>
656  return %pooled : tensor<1x1x2xf32>
657}
658
659// -----
660
661// CHECK-LABEL: func @pooling_nchw_sum
662// CHECK:         linalg.pooling_nchw_sum
663// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
664// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
665// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x1x4x4xf32>, memref<3x3xf32>)
666// CHECK-SAME:      outs(%{{.+}} : memref<1x1x2x2xf32>)
// Memref form of linalg.pooling_nchw_sum (f32): all operands are function arguments and the op yields no SSA result.
667func.func @pooling_nchw_sum(%input: memref<1x1x4x4xf32>, %fake: memref<3x3xf32>, %output: memref<1x1x2x2xf32>) {
668  linalg.pooling_nchw_sum {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
669    ins(%input, %fake : memref<1x1x4x4xf32>, memref<3x3xf32>)
670    outs(%output : memref<1x1x2x2xf32>)
671  return
672}
673
674// -----
675
676// CHECK-LABEL: func @pooling_ncw_sum
677// CHECK:         linalg.pooling_ncw_sum
678// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
679// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
680// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x1x4xf32>, memref<3xf32>)
681// CHECK-SAME:      outs(%{{.+}} : memref<1x1x2xf32>)
// Memref form of linalg.pooling_ncw_sum (f32): all operands are function arguments and the op yields no SSA result.
682func.func @pooling_ncw_sum(%input: memref<1x1x4xf32>, %fake: memref<3xf32>, %output: memref<1x1x2xf32>) {
683  linalg.pooling_ncw_sum {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
684    ins(%input, %fake : memref<1x1x4xf32>, memref<3xf32>)
685    outs(%output : memref<1x1x2xf32>)
686  return
687}
688
689// -----
690
691// CHECK-LABEL: func @pooling_nhwc_max_tensor
692// CHECK:         %{{.+}} = linalg.pooling_nhwc_max
693// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
694// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
695// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
696// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
// Tensor form of linalg.pooling_nhwc_max (f32): accumulator filled with 0.0, second input built with tensor.empty.
697func.func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> {
698  %window = tensor.empty() : tensor<3x3xf32>
699  %empty = tensor.empty() : tensor<1x2x2x1xf32>
700  %zero = arith.constant 0.000000e+00 : f32
701  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
702  %pooled = linalg.pooling_nhwc_max {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
703    ins(%input, %window : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
704    outs(%acc : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
705  return %pooled : tensor<1x2x2x1xf32>
706}
707
708// -----
709// CHECK-LABEL: func @pooling_nwc_max_tensor
710// CHECK:         %{{.+}} = linalg.pooling_nwc_max
711// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
712// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
713// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x1xf32>, tensor<3xf32>)
714// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
// Tensor form of linalg.pooling_nwc_max (f32): accumulator filled with 0.0, second input built with tensor.empty.
715func.func @pooling_nwc_max_tensor(%input: tensor<1x4x1xf32>) -> tensor<1x2x1xf32> {
716  %window = tensor.empty() : tensor<3xf32>
717  %empty = tensor.empty() : tensor<1x2x1xf32>
718  %zero = arith.constant 0.000000e+00 : f32
719  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
720  %pooled = linalg.pooling_nwc_max {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
721    ins(%input, %window : tensor<1x4x1xf32>, tensor<3xf32>)
722    outs(%acc : tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
723  return %pooled : tensor<1x2x1xf32>
724}
725
726// -----
727// CHECK-LABEL: func @pooling_nchw_max_tensor
728// CHECK:         %{{.+}} = linalg.pooling_nchw_max
729// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
730// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
731// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x1x4x4xf32>, tensor<3x3xf32>)
732// CHECK-SAME:      outs(%{{.+}} : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
733
// Tensor form of linalg.pooling_nchw_max (f32): accumulator filled with 0.0, second input built with tensor.empty.
734func.func @pooling_nchw_max_tensor(%input: tensor<1x1x4x4xf32>) -> tensor<1x1x2x2xf32> {
735  %window = tensor.empty() : tensor<3x3xf32>
736  %empty = tensor.empty() : tensor<1x1x2x2xf32>
737  %zero = arith.constant 0.000000e+00 : f32
738  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
739  %pooled = linalg.pooling_nchw_max {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
740    ins(%input, %window : tensor<1x1x4x4xf32>, tensor<3x3xf32>)
741    outs(%acc : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
742  return %pooled : tensor<1x1x2x2xf32>
743}
744
745// -----
746// CHECK-LABEL: func @pooling_ncw_max_tensor
747// CHECK:         %{{.+}} = linalg.pooling_ncw_max
748// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
749// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
750// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x1x4xf32>, tensor<3xf32>)
751// CHECK-SAME:      outs(%{{.+}} : tensor<1x1x2xf32>) -> tensor<1x1x2xf32>
752
// Tensor form of linalg.pooling_ncw_max (f32): accumulator filled with 0.0, second input built with tensor.empty.
753func.func @pooling_ncw_max_tensor(%input: tensor<1x1x4xf32>) -> tensor<1x1x2xf32> {
754  %window = tensor.empty() : tensor<3xf32>
755  %empty = tensor.empty() : tensor<1x1x2xf32>
756  %zero = arith.constant 0.000000e+00 : f32
757  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x1x2xf32>) -> tensor<1x1x2xf32>
758  %pooled = linalg.pooling_ncw_max {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
759    ins(%input, %window : tensor<1x1x4xf32>, tensor<3xf32>)
760    outs(%acc : tensor<1x1x2xf32>) -> tensor<1x1x2xf32>
761  return %pooled : tensor<1x1x2xf32>
762}
763
764// -----
765
766// CHECK-LABEL: func @pooling_nhwc_max
767// CHECK:         linalg.pooling_nhwc_max
768// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
769// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
770// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xf32>, memref<3x3xf32>)
771// CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x1xf32>)
// Memref form of linalg.pooling_nhwc_max (f32): all operands are function arguments and the op yields no SSA result.
772func.func @pooling_nhwc_max(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
773  linalg.pooling_nhwc_max {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
774    ins(%input, %fake : memref<1x4x4x1xf32>, memref<3x3xf32>)
775    outs(%output : memref<1x2x2x1xf32>)
776  return
777}
778
779// -----
780
781// CHECK-LABEL: func @pooling_nwc_max
782// CHECK:         linalg.pooling_nwc_max
783// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
784// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
785// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x1xf32>, memref<3xf32>)
786// CHECK-SAME:      outs(%{{.+}} : memref<1x2x1xf32>)
// Memref form of linalg.pooling_nwc_max (f32): all operands are function arguments and the op yields no SSA result.
787func.func @pooling_nwc_max(%input: memref<1x4x1xf32>, %fake: memref<3xf32>, %output: memref<1x2x1xf32>) {
788  linalg.pooling_nwc_max {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
789    ins(%input, %fake : memref<1x4x1xf32>, memref<3xf32>)
790    outs(%output : memref<1x2x1xf32>)
791  return
792}
793
794// -----
795
796// CHECK-LABEL: func @pooling_nhwc_i8_max_tensor
797// CHECK:         %{{.+}} = linalg.pooling_nhwc_max
798// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
799// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
800// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi8>, tensor<3x3xi8>)
801// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8>
// Tensor form of linalg.pooling_nhwc_max with i8 elements: accumulator filled with 0, second input built with tensor.empty.
802func.func @pooling_nhwc_i8_max_tensor(%input: tensor<1x4x4x1xi8>) -> tensor<1x2x2x1xi8> {
803  %window = tensor.empty() : tensor<3x3xi8>
804  %empty = tensor.empty() : tensor<1x2x2x1xi8>
805  %zero = arith.constant 0 : i8
806  %acc = linalg.fill ins(%zero : i8) outs(%empty : tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8>
807  %pooled = linalg.pooling_nhwc_max {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
808    ins(%input, %window : tensor<1x4x4x1xi8>, tensor<3x3xi8>)
809    outs(%acc : tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8>
810  return %pooled : tensor<1x2x2x1xi8>
811}
812
813// -----
814
815// CHECK-LABEL: func @pooling_nwc_i8_max_tensor
816// CHECK:         %{{.+}} = linalg.pooling_nwc_max
817// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
818// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
819// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x1xi8>, tensor<3xi8>)
820// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x1xi8>) -> tensor<1x2x1xi8>
// Tensor form of linalg.pooling_nwc_max with i8 elements: accumulator filled with 0, second input built with tensor.empty.
821func.func @pooling_nwc_i8_max_tensor(%input: tensor<1x4x1xi8>) -> tensor<1x2x1xi8> {
822  %window = tensor.empty() : tensor<3xi8>
823  %empty = tensor.empty() : tensor<1x2x1xi8>
824  %zero = arith.constant 0 : i8
825  %acc = linalg.fill ins(%zero : i8) outs(%empty : tensor<1x2x1xi8>) -> tensor<1x2x1xi8>
826  %pooled = linalg.pooling_nwc_max {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
827    ins(%input, %window : tensor<1x4x1xi8>, tensor<3xi8>)
828    outs(%acc : tensor<1x2x1xi8>) -> tensor<1x2x1xi8>
829  return %pooled : tensor<1x2x1xi8>
830}
831
832// -----
833
834// CHECK-LABEL: func @pooling_nhwc_i8_max
835// CHECK:         linalg.pooling_nhwc_max
836// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
837// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
838// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi8>, memref<3x3xi8>)
839// CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x1xi8>)
// Memref form of linalg.pooling_nhwc_max with i8 elements: all operands are function arguments; no SSA result.
840func.func @pooling_nhwc_i8_max(%input: memref<1x4x4x1xi8>, %fake: memref<3x3xi8>, %output: memref<1x2x2x1xi8>) {
841  linalg.pooling_nhwc_max {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
842    ins(%input, %fake : memref<1x4x4x1xi8>, memref<3x3xi8>)
843    outs(%output : memref<1x2x2x1xi8>)
844  return
845}
846
847// -----
848
849// CHECK-LABEL: func @pooling_nwc_i8_max
850// CHECK:         linalg.pooling_nwc_max
851// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
852// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
853// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x1xi8>, memref<3xi8>)
854// CHECK-SAME:      outs(%{{.+}} : memref<1x2x1xi8>)
// Memref form of linalg.pooling_nwc_max with i8 elements: all operands are function arguments; no SSA result.
855func.func @pooling_nwc_i8_max(%input: memref<1x4x1xi8>, %fake: memref<3xi8>, %output: memref<1x2x1xi8>) {
856  linalg.pooling_nwc_max {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
857    ins(%input, %fake : memref<1x4x1xi8>, memref<3xi8>)
858    outs(%output : memref<1x2x1xi8>)
859  return
860}
861
862// -----
863
864// CHECK-LABEL: func @pooling_nhwc_i16_max_tensor
865// CHECK:         %{{.+}} = linalg.pooling_nhwc_max
866// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
867// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
868// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi16>, tensor<3x3xi16>)
869// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16>
// Tensor form of linalg.pooling_nhwc_max with i16 elements: accumulator filled with 0, second input built with tensor.empty.
870func.func @pooling_nhwc_i16_max_tensor(%input: tensor<1x4x4x1xi16>) -> tensor<1x2x2x1xi16> {
871  %window = tensor.empty() : tensor<3x3xi16>
872  %empty = tensor.empty() : tensor<1x2x2x1xi16>
873  %zero = arith.constant 0 : i16
874  %acc = linalg.fill ins(%zero : i16) outs(%empty : tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16>
875  %pooled = linalg.pooling_nhwc_max {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
876    ins(%input, %window : tensor<1x4x4x1xi16>, tensor<3x3xi16>)
877    outs(%acc : tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16>
878  return %pooled : tensor<1x2x2x1xi16>
879}
880
881// -----
882
883// CHECK-LABEL: func @pooling_nwc_i16_max_tensor
884// CHECK:         %{{.+}} = linalg.pooling_nwc_max
885// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
886// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
887// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x1xi16>, tensor<3xi16>)
888// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x1xi16>) -> tensor<1x2x1xi16>
// Tensor form of linalg.pooling_nwc_max with i16 elements: accumulator filled with 0, second input built with tensor.empty.
889func.func @pooling_nwc_i16_max_tensor(%input: tensor<1x4x1xi16>) -> tensor<1x2x1xi16> {
890  %window = tensor.empty() : tensor<3xi16>
891  %empty = tensor.empty() : tensor<1x2x1xi16>
892  %zero = arith.constant 0 : i16
893  %acc = linalg.fill ins(%zero : i16) outs(%empty : tensor<1x2x1xi16>) -> tensor<1x2x1xi16>
894  %pooled = linalg.pooling_nwc_max {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
895    ins(%input, %window : tensor<1x4x1xi16>, tensor<3xi16>)
896    outs(%acc : tensor<1x2x1xi16>) -> tensor<1x2x1xi16>
897  return %pooled : tensor<1x2x1xi16>
898}
899
900// -----
901
902// CHECK-LABEL: func @pooling_nhwc_i16_max
903// CHECK:         linalg.pooling_nhwc_max
904// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
905// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
906// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi16>, memref<3x3xi16>)
907// CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x1xi16>)
// Memref form of linalg.pooling_nhwc_max with i16 elements: all operands are function arguments; no SSA result.
908func.func @pooling_nhwc_i16_max(%input: memref<1x4x4x1xi16>, %fake: memref<3x3xi16>, %output: memref<1x2x2x1xi16>) {
909  linalg.pooling_nhwc_max {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
910    ins(%input, %fake : memref<1x4x4x1xi16>, memref<3x3xi16>)
911    outs(%output : memref<1x2x2x1xi16>)
912  return
913}
914
915// -----
916
917// CHECK-LABEL: func @pooling_nwc_i16_max
918// CHECK:         linalg.pooling_nwc_max
919// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
920// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
921// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x1xi16>, memref<3xi16>)
922// CHECK-SAME:      outs(%{{.+}} : memref<1x2x1xi16>)
// Memref form of linalg.pooling_nwc_max with i16 elements: all operands are function arguments; no SSA result.
923func.func @pooling_nwc_i16_max(%input: memref<1x4x1xi16>, %fake: memref<3xi16>, %output: memref<1x2x1xi16>) {
924  linalg.pooling_nwc_max {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
925    ins(%input, %fake : memref<1x4x1xi16>, memref<3xi16>)
926    outs(%output : memref<1x2x1xi16>)
927  return
928}
929
930// -----
931
932// CHECK-LABEL: func @pooling_nhwc_i32_max_tensor
933// CHECK:         %{{.+}} = linalg.pooling_nhwc_max
934// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
935// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
936// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
937// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
// Tensor form of linalg.pooling_nhwc_max with i32 elements: accumulator filled with 0, second input built with tensor.empty.
938func.func @pooling_nhwc_i32_max_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
939  %window = tensor.empty() : tensor<3x3xi32>
940  %empty = tensor.empty() : tensor<1x2x2x1xi32>
941  %zero = arith.constant 0 : i32
942  %acc = linalg.fill ins(%zero : i32) outs(%empty : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
943  %pooled = linalg.pooling_nhwc_max {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
944    ins(%input, %window : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
945    outs(%acc : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
946  return %pooled : tensor<1x2x2x1xi32>
947}
948
949// -----
950
951// CHECK-LABEL: func @pooling_nwc_i32_max_tensor
952// CHECK:         %{{.+}} = linalg.pooling_nwc_max
953// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
954// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
955// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x1xi32>, tensor<3xi32>)
956// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x1xi32>) -> tensor<1x2x1xi32>
// Tensor form of linalg.pooling_nwc_max with i32 elements: accumulator filled with 0, second input built with tensor.empty.
957func.func @pooling_nwc_i32_max_tensor(%input: tensor<1x4x1xi32>) -> tensor<1x2x1xi32> {
958  %window = tensor.empty() : tensor<3xi32>
959  %empty = tensor.empty() : tensor<1x2x1xi32>
960  %zero = arith.constant 0 : i32
961  %acc = linalg.fill ins(%zero : i32) outs(%empty : tensor<1x2x1xi32>) -> tensor<1x2x1xi32>
962  %pooled = linalg.pooling_nwc_max {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
963    ins(%input, %window : tensor<1x4x1xi32>, tensor<3xi32>)
964    outs(%acc : tensor<1x2x1xi32>) -> tensor<1x2x1xi32>
965  return %pooled : tensor<1x2x1xi32>
966}
967
968// -----
969
970// CHECK-LABEL: func @pooling_nhwc_i32_max
971// CHECK:         linalg.pooling_nhwc_max
972// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
973// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
974// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi32>, memref<3x3xi32>)
975// CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x1xi32>)
// Memref form of linalg.pooling_nhwc_max with i32 elements: all operands are function arguments; no SSA result.
976func.func @pooling_nhwc_i32_max(%input: memref<1x4x4x1xi32>, %fake: memref<3x3xi32>, %output: memref<1x2x2x1xi32>) {
977  linalg.pooling_nhwc_max {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
978    ins(%input, %fake : memref<1x4x4x1xi32>, memref<3x3xi32>)
979    outs(%output : memref<1x2x2x1xi32>)
980  return
981}
982
983// -----
984
985// CHECK-LABEL: func @pooling_nwc_i32_max
986// CHECK:         linalg.pooling_nwc_max
987// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
988// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
989// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x1xi32>, memref<3xi32>)
990// CHECK-SAME:      outs(%{{.+}} : memref<1x2x1xi32>)
// Memref form of linalg.pooling_nwc_max with i32 elements: all operands are function arguments; no SSA result.
991func.func @pooling_nwc_i32_max(%input: memref<1x4x1xi32>, %fake: memref<3xi32>, %output: memref<1x2x1xi32>) {
992  linalg.pooling_nwc_max {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
993    ins(%input, %fake : memref<1x4x1xi32>, memref<3xi32>)
994    outs(%output : memref<1x2x1xi32>)
995  return
996}
997
998
999// -----
1000
1001// CHECK-LABEL: func @pooling_nhwc_min_tensor
1002// CHECK:         %{{.+}} = linalg.pooling_nhwc_min
1003// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
1004// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
1005// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
1006// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
// Tensor form of linalg.pooling_nhwc_min (f32): accumulator filled with 0.0, second input built with tensor.empty.
1007func.func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> {
1008  %window = tensor.empty() : tensor<3x3xf32>
1009  %empty = tensor.empty() : tensor<1x2x2x1xf32>
1010  %zero = arith.constant 0.000000e+00 : f32
1011  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
1012  %pooled = linalg.pooling_nhwc_min {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
1013    ins(%input, %window : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
1014    outs(%acc : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
1015  return %pooled : tensor<1x2x2x1xf32>
1016}
1017
1018// -----
1019
1020// CHECK-LABEL: func @pooling_nwc_min_tensor
1021// CHECK:         %{{.+}} = linalg.pooling_nwc_min
1022// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
1023// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
1024// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x1xf32>, tensor<3xf32>)
1025// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
// Tensor form of linalg.pooling_nwc_min (f32): accumulator filled with 0.0, second input built with tensor.empty.
1026func.func @pooling_nwc_min_tensor(%input: tensor<1x4x1xf32>) -> tensor<1x2x1xf32> {
1027  %window = tensor.empty() : tensor<3xf32>
1028  %empty = tensor.empty() : tensor<1x2x1xf32>
1029  %zero = arith.constant 0.000000e+00 : f32
1030  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
1031  %pooled = linalg.pooling_nwc_min {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
1032    ins(%input, %window : tensor<1x4x1xf32>, tensor<3xf32>)
1033    outs(%acc : tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
1034  return %pooled : tensor<1x2x1xf32>
1035}
1036
1037// -----
1038
1039// CHECK-LABEL: func @pooling_nhwc_min
1040// CHECK:         linalg.pooling_nhwc_min
1041// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
1042// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
1043// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xf32>, memref<3x3xf32>)
1044// CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x1xf32>)
// Memref form of linalg.pooling_nhwc_min (f32): all operands are function arguments and the op yields no SSA result.
1045func.func @pooling_nhwc_min(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
1046  linalg.pooling_nhwc_min {strides = dense<1> : tensor<2xi64>, dilations = dense<1> : tensor<2xi64>}
1047    ins(%input, %fake : memref<1x4x4x1xf32>, memref<3x3xf32>)
1048    outs(%output : memref<1x2x2x1xf32>)
1049  return
1050}
1051
1052// -----
1053
1054// CHECK-LABEL: func @pooling_nwc_min
1055// CHECK:         linalg.pooling_nwc_min
1056// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
1057// CHECK-SAME:      strides = dense<1> : tensor<1xi64>
1058// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x1xf32>, memref<3xf32>)
1059// CHECK-SAME:      outs(%{{.+}} : memref<1x2x1xf32>)
// Memref form of linalg.pooling_nwc_min (f32): all operands are function arguments and the op yields no SSA result.
1060func.func @pooling_nwc_min(%input: memref<1x4x1xf32>, %fake: memref<3xf32>, %output: memref<1x2x1xf32>) {
1061  linalg.pooling_nwc_min {strides = dense<1> : tensor<1xi64>, dilations = dense<1> : tensor<1xi64>}
1062    ins(%input, %fake : memref<1x4x1xf32>, memref<3xf32>)
1063    outs(%output : memref<1x2x1xf32>)
1064  return
1065}
1066
1067// -----
1068
1069// CHECK-LABEL: func @pooling_ndhwc_sum_tensor
1070// CHECK:         %{{.+}} = linalg.pooling_ndhwc_sum
1071// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>
1072// CHECK-SAME:      strides = dense<1> : tensor<3xi64>
1073// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
1074// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
// Tensor form of linalg.pooling_ndhwc_sum (f32, 5-D input): zero-filled accumulator, second input built with tensor.empty.
1075func.func @pooling_ndhwc_sum_tensor(%input: tensor<1x4x4x4x1xf32>) -> tensor<1x2x2x2x1xf32> {
1076  %window = tensor.empty() : tensor<3x3x3xf32>
1077  %empty = tensor.empty() : tensor<1x2x2x2x1xf32>
1078  %zero = arith.constant 0.000000e+00 : f32
1079  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
1080  %pooled = linalg.pooling_ndhwc_sum {strides = dense<1> : tensor<3xi64>, dilations = dense<1> : tensor<3xi64>}
1081    ins(%input, %window : tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
1082    outs(%acc : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
1083  return %pooled : tensor<1x2x2x2x1xf32>
1084}
1085
1086// -----
1087
1088// CHECK-LABEL: func @pooling_ndhwc_sum
1089// CHECK:         linalg.pooling_ndhwc_sum
1090// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>
1091// CHECK-SAME:      strides = dense<1> : tensor<3xi64>
1092// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x4x1xf32>, memref<3x3x3xf32>)
1093// CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x2x1xf32>)
// Memref form of linalg.pooling_ndhwc_sum (f32): all operands are function arguments and the op yields no SSA result.
1094func.func @pooling_ndhwc_sum(%input: memref<1x4x4x4x1xf32>, %fake: memref<3x3x3xf32>, %output: memref<1x2x2x2x1xf32>) {
1095  linalg.pooling_ndhwc_sum {strides = dense<1> : tensor<3xi64>, dilations = dense<1> : tensor<3xi64>}
1096    ins(%input, %fake : memref<1x4x4x4x1xf32>, memref<3x3x3xf32>)
1097    outs(%output : memref<1x2x2x2x1xf32>)
1098  return
1099}
1100
1101// -----
1102
1103// CHECK-LABEL: func @pooling_ndhwc_max_tensor
1104// CHECK:         %{{.+}} = linalg.pooling_ndhwc_max
1105// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>
1106// CHECK-SAME:      strides = dense<1> : tensor<3xi64>
1107// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
1108// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
// Tensor form of linalg.pooling_ndhwc_max (f32, 5-D input): accumulator filled with 0.0, second input built with tensor.empty.
1109func.func @pooling_ndhwc_max_tensor(%input: tensor<1x4x4x4x1xf32>) -> tensor<1x2x2x2x1xf32> {
1110  %window = tensor.empty() : tensor<3x3x3xf32>
1111  %empty = tensor.empty() : tensor<1x2x2x2x1xf32>
1112  %zero = arith.constant 0.000000e+00 : f32
1113  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
1114  %pooled = linalg.pooling_ndhwc_max {strides = dense<1> : tensor<3xi64>, dilations = dense<1> : tensor<3xi64>}
1115    ins(%input, %window : tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
1116    outs(%acc : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
1117  return %pooled : tensor<1x2x2x2x1xf32>
1118}
1119
1120// -----
1121
1122// CHECK-LABEL: func @pooling_ndhwc_max
1123// CHECK:         linalg.pooling_ndhwc_max
1124// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>
1125// CHECK-SAME:      strides = dense<1> : tensor<3xi64>
1126// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x4x1xf32>, memref<3x3x3xf32>)
1127// CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x2x1xf32>)
// Memref form of linalg.pooling_ndhwc_max (f32): all operands are function arguments and the op yields no SSA result.
1128func.func @pooling_ndhwc_max(%input: memref<1x4x4x4x1xf32>, %fake: memref<3x3x3xf32>, %output: memref<1x2x2x2x1xf32>) {
1129  linalg.pooling_ndhwc_max {strides = dense<1> : tensor<3xi64>, dilations = dense<1> : tensor<3xi64>}
1130    ins(%input, %fake : memref<1x4x4x4x1xf32>, memref<3x3x3xf32>)
1131    outs(%output : memref<1x2x2x2x1xf32>)
1132  return
1133}
1134
1135// -----
1136
1137// CHECK-LABEL: func @pooling_ndhwc_min_tensor
1138// CHECK:         %{{.+}} = linalg.pooling_ndhwc_min
1139// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>
1140// CHECK-SAME:      strides = dense<1> : tensor<3xi64>
1141// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
1142// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
// Tensor form of linalg.pooling_ndhwc_min (f32, 5-D input): accumulator filled with 0.0, second input built with tensor.empty.
1143func.func @pooling_ndhwc_min_tensor(%input: tensor<1x4x4x4x1xf32>) -> tensor<1x2x2x2x1xf32> {
1144  %window = tensor.empty() : tensor<3x3x3xf32>
1145  %empty = tensor.empty() : tensor<1x2x2x2x1xf32>
1146  %zero = arith.constant 0.000000e+00 : f32
1147  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
1148  %pooled = linalg.pooling_ndhwc_min {strides = dense<1> : tensor<3xi64>, dilations = dense<1> : tensor<3xi64>}
1149    ins(%input, %window : tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
1150    outs(%acc : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
1151  return %pooled : tensor<1x2x2x2x1xf32>
1152}
1153
1154// -----
1155
1156// CHECK-LABEL: func @pooling_ndhwc_min
1157// CHECK:         linalg.pooling_ndhwc_min
1158// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>
1159// CHECK-SAME:      strides = dense<1> : tensor<3xi64>
1160// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x4x1xf32>, memref<3x3x3xf32>)
1161// CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x2x1xf32>)
// Memref form of linalg.pooling_ndhwc_min (f32): all operands are function arguments and the op yields no SSA result.
1162func.func @pooling_ndhwc_min(%input: memref<1x4x4x4x1xf32>, %fake: memref<3x3x3xf32>, %output: memref<1x2x2x2x1xf32>) {
1163  linalg.pooling_ndhwc_min {strides = dense<1> : tensor<3xi64>, dilations = dense<1> : tensor<3xi64>}
1164    ins(%input, %fake : memref<1x4x4x4x1xf32>, memref<3x3x3xf32>)
1165    outs(%output : memref<1x2x2x2x1xf32>)
1166  return
1167}
1168
1169// -----
1170
// Negative test: a generic-form linalg.conv_2d_nhwc_hwcf carrying hand-written memoized indexing maps.
// In #map0 (listed first, for the input operand) the second result is `d1 * 2` with no `d4` term, while
// the third result is `d2 * 2 + d5`. NOTE(review): presumably that missing term is what the verifier
// rejects; confirm against the convolution interface verifier.
1171#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 2, d2 * 2 + d5, d6)>
1172#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
1173#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
1174func.func @conv_interface_wrong_input_indexing_map(
1175    %arg0 : tensor<?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
1176  // expected-error @+1 {{unexpected input index map for convolutions}}
1177  %conv = "linalg.conv_2d_nhwc_hwcf"(%arg0, %arg1, %arg2) ({
1178    ^bb0(%in: f32, %ker: f32, %acc : f32):
1179      %prod = "arith.mulf"(%in, %ker) : (f32, f32) -> f32
1180      %sum = "arith.addf"(%acc, %prod) : (f32, f32) -> f32
1181      "linalg.yield"(%sum) : (f32) -> ()
1182    }) {dilations = dense<1> : tensor<2xi64>, linalg.memoized_indexing_maps = [#map0, #map1, #map2], operandSegmentSizes = array<i32: 2, 1>, strides = dense<2> : tensor<2xi64>} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
1183  return %conv : tensor<?x?x?x?xf32>
1184}
1185
1186// -----
1187
// Negative test: #map1 (listed second, for the filter operand) has five results, the last being the
// non-dimension expression `d5 + 1`; the filter operand type is 5-D to match. Despite the function
// name, the op still has three operands (segment sizes 2, 1) and the diagnostic checked below is
// about the output/filter maps not being projected permutations, not about an operand count.
1188#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
1189#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3, d5 + 1)>
1190#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
1191func.func @conv_interface_wrong_num_operands(
1192    %arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
1193  // expected-error @+1 {{expected output/filter indexing maps to be projected permutations}}
1194  %conv = "linalg.conv_2d_nhwc_hwcf"(%arg0, %arg1, %arg2) ({
1195    ^bb0(%in: f32, %ker: f32, %acc : f32):
1196      %prod = "arith.mulf"(%in, %ker) : (f32, f32) -> f32
1197      %sum = "arith.addf"(%acc, %prod) : (f32, f32) -> f32
1198      "linalg.yield"(%sum) : (f32) -> ()
1199    }) {dilations = dense<1> : tensor<2xi64>, linalg.memoized_indexing_maps = [#map0, #map1, #map2], operandSegmentSizes = array<i32: 2, 1>, strides = dense<1> : tensor<2xi64>} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
1200  return %conv : tensor<?x?x?x?xf32>
1201}
1202
1203// -----
1204
// Round-trip of linalg.batch_reduce_matmul on statically shaped tensors.
// The label pattern includes the first argument type so that it pins this
// function and not the memref variant that shares the same symbol name.
// CHECK-LABEL: func @batch_reduce_matmul(%{{.+}}: tensor<8x128x256xf32>
func.func @batch_reduce_matmul(%arg0: tensor<8x128x256xf32>, %arg1: tensor<8x256x512xf32>, %arg2: tensor<128x512xf32>) -> tensor<128x512xf32> {
  // CHECK: %{{.+}} = linalg.batch_reduce_matmul
  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<8x128x256xf32>, tensor<8x256x512xf32>)
  // CHECK-SAME: outs(%{{.+}} : tensor<128x512xf32>) -> tensor<128x512xf32>
  %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<8x128x256xf32>, tensor<8x256x512xf32>) outs(%arg2: tensor<128x512xf32>) -> tensor<128x512xf32>
  return %0: tensor<128x512xf32>
}
1212
1213// -----
1214
// Round-trip of linalg.batch_reduce_matmul on dynamically shaped memrefs.
// The label pattern includes the first argument type so that it cannot latch
// onto the tensor variant that shares the same symbol name.
// CHECK-LABEL: func @batch_reduce_matmul(%{{.+}}: memref<?x?x?xf32>
func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
  // CHECK: linalg.batch_reduce_matmul
  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
  // CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
  linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
  return
}
1222
1223// -----
1224
// Round-trip of the named op linalg.matmul_transpose_a on static memrefs.
// CHECK-LABEL: func @matmul_transpose_a
// CHECK:         linalg.matmul_transpose_a
// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
// CHECK-SAME:      outs(%{{.+}} : memref<3x7xf32>)
func.func @matmul_transpose_a(
    %lhs: memref<5x3xf32>, %rhs: memref<5x7xf32>, %acc: memref<3x7xf32>) {
  linalg.matmul_transpose_a
    ins(%lhs, %rhs : memref<5x3xf32>, memref<5x7xf32>)
    outs(%acc : memref<3x7xf32>)
  return
}
1233
1234// -----
1235
// Round-trip of linalg.matmul whose LHS transpose is expressed through
// explicit indexing maps. The maps are captured and matched on the printed op
// so that the test fails if they are dropped or altered, in line with the
// other *_explicit tests in this file.
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_transpose_a_explicit
//       CHECK:   linalg.matmul
//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
//  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d2, d0)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                      ]
                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                      outs(%arg2: memref<3x7xf32>)
  return
}
1250
1251// -----
1252
// Round-trip of linalg.matmul with the RHS transpose given through explicit
// indexing maps; the printed op must reference the three captured maps.
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_transpose_b_explicit(
// CHECK-SAME:    %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME:    %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME:    %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK:         linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK:         return
// CHECK:       }
func.func @matmul_transpose_b_explicit(
    %lhs: memref<3x5xf32>, %rhs_t: memref<7x5xf32>, %acc: memref<3x7xf32>) {
  linalg.matmul
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d2)>,
      affine_map<(d0, d1, d2) -> (d1, d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>
    ]
    ins(%lhs, %rhs_t : memref<3x5xf32>, memref<7x5xf32>)
    outs(%acc : memref<3x7xf32>)
  return
}
1275
1276// -----
1277
// Round-trip of linalg.matmul with both operands transposed through explicit
// indexing maps; the printed op must reference the three captured maps.
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
// CHECK-SAME:    %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME:    %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME:    %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK:         linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK:         return
// CHECK:       }
func.func @matmul_transpose_a_b_explicit(
    %lhs_t: memref<5x3xf32>, %rhs_t: memref<7x5xf32>, %acc: memref<3x7xf32>) {
  linalg.matmul
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d2, d0)>,
      affine_map<(d0, d1, d2) -> (d1, d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>
    ]
    ins(%lhs_t, %rhs_t : memref<5x3xf32>, memref<7x5xf32>)
    outs(%acc : memref<3x7xf32>)
  return
}
1300
1301// -----
1302
// Round-trip of linalg.matmul with a rank-1 LHS that is broadcast through the
// explicit indexing map (d2). The captured maps are matched on the printed op
// so the test fails if the broadcast map does not survive the round trip.
func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d2)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                     ]
                     ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_a
//       CHECK:   linalg.matmul
//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
//  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1320
1321// -----
1322
// Round-trip of linalg.matmul with a rank-1 LHS that is broadcast through the
// explicit indexing map (d2). The captured maps are matched on the printed op
// so the test fails if the broadcast map does not survive the round trip.
// NOTE(review): apart from the symbol name, this input is identical to
// @matmul_bcast_a; confirm whether a different broadcast dimension was meant.
func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d2)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                     ]
                     ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_a_dim1
//       CHECK:   linalg.matmul
//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
//  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1340
1341// -----
1342
// Round-trip of linalg.matmul with a rank-1 RHS that is broadcast through the
// explicit indexing map (d2). The captured maps are matched on the printed op
// so the test fails if the broadcast map does not survive the round trip.
func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                     ]
                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_b
//       CHECK:   linalg.matmul
//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
//  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1360
1361// -----
1362
// Round-trip of linalg.matmul with both operands rank-1 and broadcast through
// the same explicit map (d2); the printed op reuses one alias for both inputs.
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_bcast_a_b(
// CHECK-SAME:    %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME:    %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK:         linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
// CHECK:         return
// CHECK:       }
func.func @matmul_bcast_a_b(
    %lhs: memref<5xf32>, %rhs: memref<5xf32>, %acc: memref<3x7xf32>) {
  linalg.matmul
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>
    ]
    ins(%lhs, %rhs : memref<5xf32>, memref<5xf32>)
    outs(%acc : memref<3x7xf32>)
  return
}
1382
1383// -----
1384
// Round-trip of linalg.matmul with a rank-1 RHS that is broadcast through the
// explicit indexing map (d2). The captured maps are matched on the printed op
// so the test fails if the broadcast map does not survive the round trip.
// NOTE(review): apart from the symbol name, this input is identical to
// @matmul_bcast_b; confirm whether a different broadcast dimension was meant.
func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                     ]
                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_b_dim1
//       CHECK:   linalg.matmul
//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
//  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1402
1403// -----
1404
// Round-trip of linalg.matmul with a broadcast rank-1 LHS where every operand
// has dynamic dimensions.
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @dynamic_matmul_bcast_a(
// CHECK-SAME:    %[[VAL_0:.*]]: memref<?xf32>,
// CHECK-SAME:    %[[VAL_1:.*]]: memref<?x?xf32>,
// CHECK-SAME:    %[[VAL_2:.*]]: memref<?x?xf32>) {
// CHECK:         linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK:         return
// CHECK:       }
func.func @dynamic_matmul_bcast_a(
    %lhs: memref<?xf32>, %rhs: memref<?x?xf32>, %acc: memref<?x?xf32>) {
  linalg.matmul
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d2, d1)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>
    ]
    ins(%lhs, %rhs : memref<?xf32>, memref<?x?xf32>)
    outs(%acc : memref<?x?xf32>)
  return
}
1426
1427// -----
1428
// Round-trip of linalg.matmul combining a broadcast rank-1 LHS with a
// transposed RHS, both expressed through explicit indexing maps.
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_bcast_a_transpose_b(
// CHECK-SAME:    %[[VAL_0:.*]]: memref<5xf32>,
// CHECK-SAME:    %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME:    %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK:         linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK:         return
// CHECK:       }
func.func @matmul_bcast_a_transpose_b(
    %lhs: memref<5xf32>, %rhs_t: memref<7x5xf32>, %acc: memref<3x7xf32>) {
  linalg.matmul
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d1, d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>
    ]
    ins(%lhs, %rhs_t : memref<5xf32>, memref<7x5xf32>)
    outs(%acc : memref<3x7xf32>)
  return
}
1450
1451// -----
1452
// Round-trip of linalg.matmul combining a transposed LHS with a broadcast
// rank-1 RHS, both expressed through explicit indexing maps.
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_bcast_b_transpose_a(
// CHECK-SAME:    %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME:    %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME:    %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK:         linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK:         return
// CHECK:       }
func.func @matmul_bcast_b_transpose_a(
    %lhs_t: memref<5x3xf32>, %rhs: memref<5xf32>, %acc: memref<3x7xf32>) {
  linalg.matmul
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d2, d0)>,
      affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>
    ]
    ins(%lhs_t, %rhs : memref<5x3xf32>, memref<5xf32>)
    outs(%acc : memref<3x7xf32>)
  return
}
1474
1475// -----
1476
// Round-trip of the named op linalg.matmul_transpose_b on static memrefs.
// CHECK-LABEL: func @matmul_transpose_b
// CHECK:         linalg.matmul_transpose_b
// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
// CHECK-SAME:      outs(%{{.+}} : memref<3x7xf32>)
func.func @matmul_transpose_b(
    %lhs: memref<3x5xf32>, %rhs_t: memref<7x5xf32>, %acc: memref<3x7xf32>) {
  linalg.matmul_transpose_b
    ins(%lhs, %rhs_t : memref<3x5xf32>, memref<7x5xf32>)
    outs(%acc : memref<3x7xf32>)
  return
}
1485
1486// -----
1487
// Round-trip of the named op linalg.batch_matmul_transpose_a on static memrefs.
// CHECK-LABEL: func @batchmatmul_transpose_a
// CHECK:         linalg.batch_matmul_transpose_a
// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
// CHECK-SAME:      outs(%{{.+}} : memref<2x3x7xf32>)
func.func @batchmatmul_transpose_a(
    %lhs_t: memref<2x5x3xf32>, %rhs: memref<2x5x7xf32>,
    %acc: memref<2x3x7xf32>) {
  linalg.batch_matmul_transpose_a
    ins(%lhs_t, %rhs : memref<2x5x3xf32>, memref<2x5x7xf32>)
    outs(%acc : memref<2x3x7xf32>)
  return
}
1496
1497// -----
1498
// Round-trip of the named op linalg.batch_matmul_transpose_b on static memrefs.
// CHECK-LABEL: func @batchmatmul_transpose_b
// CHECK:         linalg.batch_matmul_transpose_b
// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
// CHECK-SAME:      outs(%{{.+}} : memref<2x3x7xf32>)
func.func @batchmatmul_transpose_b(
    %lhs: memref<2x3x5xf32>, %rhs_t: memref<2x7x5xf32>,
    %acc: memref<2x3x7xf32>) {
  linalg.batch_matmul_transpose_b
    ins(%lhs, %rhs_t : memref<2x3x5xf32>, memref<2x7x5xf32>)
    outs(%acc : memref<2x3x7xf32>)
  return
}
1507
1508// -----
1509
// Round-trip of linalg.contract over four iteration dimensions; d0 appears in
// all three maps and d3 only in the two input maps.
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-LABEL: func @contract
// CHECK:         linalg.contract
// CHECK-SAME:      indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x5x7xf32>)
// CHECK-SAME:      outs(%{{.+}} : memref<2x3x7xf32>)
func.func @contract(%lhs: memref<2x3x5xf32>,
                    %rhs: memref<2x5x7xf32>,
                    %acc: memref<2x3x7xf32>) {
  linalg.contract
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
    ]
    ins(%lhs, %rhs : memref<2x3x5xf32>, memref<2x5x7xf32>)
    outs(%acc : memref<2x3x7xf32>)
  return
}
1528
1529// -----
1530
// Round-trip of linalg.contract expressing a matmul whose LHS is a rank-1
// memref accessed through the map (d2).
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @contract_matmul_bcast_a
func.func @contract_matmul_bcast_a(
    %lhs: memref<5xf32>, %rhs: memref<5x7xf32>, %acc: memref<3x7xf32>) {
  // CHECK:      linalg.contract
  // CHECK-SAME:   indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<3x7xf32>)
  linalg.contract
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d2, d1)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>
    ]
    ins(%lhs, %rhs : memref<5xf32>, memref<5x7xf32>)
    outs(%acc : memref<3x7xf32>)
  return
}
1548
1549// -----
1550
// Round-trip of linalg.contract expressing a matmul whose RHS is a rank-1
// memref accessed through the map (d2).
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @contract_matmul_bcast_b
func.func @contract_matmul_bcast_b(
    %lhs: memref<3x5xf32>, %rhs: memref<5xf32>, %acc: memref<3x7xf32>) {
  // CHECK:      linalg.contract
  // CHECK-SAME:   indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<3x7xf32>)
  linalg.contract
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d2)>,
      affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>
    ]
    ins(%lhs, %rhs : memref<3x5xf32>, memref<5xf32>)
    outs(%acc : memref<3x7xf32>)
  return
}
1568
1569// -----
1570
// Round-trip of linalg.contract where both inputs are rank-1 memrefs accessed
// through the same map (d2); the printed op reuses one alias for both.
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
func.func @contract_matmul_bcast_a_b(%lhs: memref<5xf32>,
                                     %rhs: memref<5xf32>,
                                     %acc: memref<3x7xf32>) {
  // CHECK:      linalg.contract
  // CHECK-SAME:   indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_A]], #[[$ACCESS_B]]]
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<3x7xf32>)
  linalg.contract
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>
    ]
    ins(%lhs, %rhs : memref<5xf32>, memref<5xf32>)
    outs(%acc : memref<3x7xf32>)
  return
}
1588
1589// -----
1590
// Round-trip of linalg.contract combining a rank-1 LHS accessed through (d2)
// with an RHS accessed through the transposed map (d1, d2).
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func.func @contract_matmul_bcast_a_transpose_b
func.func @contract_matmul_bcast_a_transpose_b(%lhs: memref<5xf32>,
                                               %rhs_t: memref<7x5xf32>,
                                               %acc: memref<3x7xf32>) {
  // CHECK:      linalg.contract
  // CHECK-SAME:   indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<7x5xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<3x7xf32>)
  linalg.contract
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d1, d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>
    ]
    ins(%lhs, %rhs_t : memref<5xf32>, memref<7x5xf32>)
    outs(%acc : memref<3x7xf32>)
  return
}
1609
1610// -----
1611
// Round-trip of linalg.contract combining an LHS accessed through the
// transposed map (d2, d0) with a rank-1 RHS accessed through (d2).
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func.func @contract_matmul_bcast_b_transpose_a
func.func @contract_matmul_bcast_b_transpose_a(
    %lhs_t: memref<5x3xf32>, %rhs: memref<5xf32>, %acc: memref<3x7xf32>) {
  // CHECK:      linalg.contract
  // CHECK-SAME:   indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<3x7xf32>)
  linalg.contract
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d2, d0)>,
      affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>
    ]
    ins(%lhs_t, %rhs : memref<5x3xf32>, memref<5xf32>)
    outs(%acc : memref<3x7xf32>)
  return
}
1629
1630// -----
1631
// Round-trip of linalg.mmt4d on statically shaped 4-D tensors.
// CHECK-LABEL: func @mmt4d
func.func @mmt4d(
    %lhs: tensor<10x32x8x1xf32>,
    %rhs: tensor<80x32x4x1xf32>,
    %acc: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
  // CHECK:      %{{.+}} = linalg.mmt4d
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
  %result = linalg.mmt4d
    ins(%lhs, %rhs : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>)
    outs(%acc : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
  return %result : tensor<10x80x8x4xf32>
}
1640
1641// -----
1642
// Round-trip of linalg.batch_mmt4d on statically shaped 5-D tensors.
// CHECK-LABEL: func @batch_mmt4d
func.func @batch_mmt4d(
    %lhs: tensor<128x10x32x8x1xf32>,
    %rhs: tensor<128x80x32x4x1xf32>,
    %acc: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
  // CHECK:      %{{.+}} = linalg.batch_mmt4d
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<128x10x32x8x1xf32>, tensor<128x80x32x4x1xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32>
  %result = linalg.batch_mmt4d
    ins(%lhs, %rhs : tensor<128x10x32x8x1xf32>, tensor<128x80x32x4x1xf32>)
    outs(%acc : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32>
  return %result : tensor<128x10x80x8x4xf32>
}
1651
1652// -----
1653
// Round-trip of linalg.add on dynamically shaped memrefs.
// CHECK-LABEL: func @add_dynamic
func.func @add_dynamic(
    %lhs: memref<?x?x?xf32>, %rhs: memref<?x?x?xf32>,
    %out: memref<?x?x?xf32>) {
  // CHECK:      linalg.add
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.add
    ins(%lhs, %rhs : memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs(%out : memref<?x?x?xf32>)
  return
}
1662
1663// -----
1664
// Round-trip of linalg.add on statically shaped memrefs.
// CHECK-LABEL: func @add_static
func.func @add_static(
    %lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf32>,
    %out: memref<4x8x16xf32>) {
  // CHECK:      linalg.add
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.add
    ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf32>)
    outs(%out : memref<4x8x16xf32>)
  return
}
1673
1674// -----
1675
// Round-trip of linalg.add on tensors, writing into a tensor.empty init.
// CHECK-LABEL: func @add_tensor
func.func @add_tensor(
    %lhs: tensor<4x8x16xf32>,
    %rhs: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK:      linalg.add
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<4x8x16xf32>)
  %result = linalg.add
    ins(%lhs, %rhs : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %result : tensor<4x8x16xf32>
}
1685
1686// -----
1687
// Round-trip of linalg.sub on dynamically shaped memrefs.
// CHECK-LABEL: func @sub_dynamic
func.func @sub_dynamic(
    %lhs: memref<?x?x?xf32>, %rhs: memref<?x?x?xf32>,
    %out: memref<?x?x?xf32>) {
  // CHECK:      linalg.sub
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.sub
    ins(%lhs, %rhs : memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs(%out : memref<?x?x?xf32>)
  return
}
1696
1697// -----
1698
// Round-trip of linalg.sub on statically shaped memrefs.
// CHECK-LABEL: func @sub_static
func.func @sub_static(
    %lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf32>,
    %out: memref<4x8x16xf32>) {
  // CHECK:      linalg.sub
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.sub
    ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf32>)
    outs(%out : memref<4x8x16xf32>)
  return
}
1707
1708// -----
1709
// Round-trip of linalg.sub on tensors, writing into a tensor.empty init.
// CHECK-LABEL: func @sub_tensor
func.func @sub_tensor(
    %lhs: tensor<4x8x16xf32>,
    %rhs: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK:      linalg.sub
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<4x8x16xf32>)
  %result = linalg.sub
    ins(%lhs, %rhs : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %result : tensor<4x8x16xf32>
}
1719
1720// -----
1721
// Round-trip of linalg.mul on dynamically shaped memrefs.
// CHECK-LABEL: func @mul_dynamic
func.func @mul_dynamic(
    %lhs: memref<?x?x?xf32>, %rhs: memref<?x?x?xf32>,
    %out: memref<?x?x?xf32>) {
  // CHECK:      linalg.mul
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.mul
    ins(%lhs, %rhs : memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs(%out : memref<?x?x?xf32>)
  return
}
1730
1731// -----
1732
// Round-trip of linalg.mul on statically shaped memrefs.
// CHECK-LABEL: func @mul_static
func.func @mul_static(
    %lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf32>,
    %out: memref<4x8x16xf32>) {
  // CHECK:      linalg.mul
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.mul
    ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf32>)
    outs(%out : memref<4x8x16xf32>)
  return
}
1741
1742// -----
1743
// Round-trip of linalg.mul on tensors, writing into a tensor.empty init.
// CHECK-LABEL: func @mul_tensor
func.func @mul_tensor(
    %lhs: tensor<4x8x16xf32>,
    %rhs: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK:      linalg.mul
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<4x8x16xf32>)
  %result = linalg.mul
    ins(%lhs, %rhs : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %result : tensor<4x8x16xf32>
}
1753
1754// -----
1755
// Round-trip of linalg.div on dynamically shaped memrefs.
// CHECK-LABEL: func @div_dynamic
func.func @div_dynamic(
    %lhs: memref<?x?x?xf32>, %rhs: memref<?x?x?xf32>,
    %out: memref<?x?x?xf32>) {
  // CHECK:      linalg.div
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.div
    ins(%lhs, %rhs : memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs(%out : memref<?x?x?xf32>)
  return
}
1764
1765// -----
1766
// Round-trip of linalg.div on statically shaped memrefs.
// CHECK-LABEL: func @div_static
func.func @div_static(
    %lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf32>,
    %out: memref<4x8x16xf32>) {
  // CHECK:      linalg.div
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.div
    ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf32>)
    outs(%out : memref<4x8x16xf32>)
  return
}
1775
1776// -----
1777
// Round-trip of linalg.div on tensors, writing into a tensor.empty init.
// CHECK-LABEL: func @div_tensor
func.func @div_tensor(
    %lhs: tensor<4x8x16xf32>,
    %rhs: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK:      linalg.div
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<4x8x16xf32>)
  %result = linalg.div
    ins(%lhs, %rhs : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %result : tensor<4x8x16xf32>
}
1787
1788// -----
1789
// Round-trip of linalg.div_unsigned on dynamically shaped i32 memrefs.
// CHECK-LABEL: func @div_unsigned_dynamic
func.func @div_unsigned_dynamic(
    %lhs: memref<?x?x?xi32>, %rhs: memref<?x?x?xi32>,
    %out: memref<?x?x?xi32>) {
  // CHECK:      linalg.div_unsigned
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?xi32>, memref<?x?x?xi32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?xi32>)
  linalg.div_unsigned
    ins(%lhs, %rhs : memref<?x?x?xi32>, memref<?x?x?xi32>)
    outs(%out : memref<?x?x?xi32>)
  return
}
1798
1799// -----
1800
// Round-trip of linalg.div_unsigned on statically shaped i32 memrefs.
// CHECK-LABEL: func @div_unsigned_static
func.func @div_unsigned_static(
    %lhs: memref<4x8x16xi32>, %rhs: memref<4x8x16xi32>,
    %out: memref<4x8x16xi32>) {
  // CHECK:      linalg.div_unsigned
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<4x8x16xi32>, memref<4x8x16xi32>)
  // CHECK-SAME:   outs(%{{.+}} : memref<4x8x16xi32>)
  linalg.div_unsigned
    ins(%lhs, %rhs : memref<4x8x16xi32>, memref<4x8x16xi32>)
    outs(%out : memref<4x8x16xi32>)
  return
}
1809
1810// -----
1811
// Round-trip of linalg.div_unsigned on i32 tensors, writing into a
// tensor.empty init.
// CHECK-LABEL: func @div_unsigned_tensor
func.func @div_unsigned_tensor(
    %lhs: tensor<4x8x16xi32>,
    %rhs: tensor<4x8x16xi32>) -> tensor<4x8x16xi32> {
  %init = tensor.empty() : tensor<4x8x16xi32>
  // CHECK:      linalg.div_unsigned
  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<4x8x16xi32>, tensor<4x8x16xi32>)
  // CHECK-SAME:   outs(%{{.+}} : tensor<4x8x16xi32>)
  %result = linalg.div_unsigned
    ins(%lhs, %rhs : tensor<4x8x16xi32>, tensor<4x8x16xi32>)
    outs(%init : tensor<4x8x16xi32>) -> tensor<4x8x16xi32>
  return %result : tensor<4x8x16xi32>
}
1821
1822// -----
1823
// Round-trip of the unary op linalg.exp on dynamically shaped memrefs.
// CHECK-LABEL: func @exp_dynamic
func.func @exp_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK:      linalg.exp
  // CHECK-SAME:   ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.exp
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
1831
1832// -----
1833
// Round-trip of the unary op linalg.exp on statically shaped memrefs.
// CHECK-LABEL: func @exp_static
func.func @exp_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK:      linalg.exp
  // CHECK-SAME:   ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.exp
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
1841
1842// -----
1843
// Round-trip of the unary op linalg.exp on tensors, writing into a
// tensor.empty init.
// CHECK-LABEL: func @exp_tensor
func.func @exp_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK:      linalg.exp
  // CHECK-SAME:   ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %result = linalg.exp
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %result : tensor<4x8x16xf32>
}
1852
1853// -----
1854
// Round-trip of the unary op linalg.log on dynamically shaped memrefs.
// CHECK-LABEL: func @log_dynamic
func.func @log_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK:      linalg.log
  // CHECK-SAME:   ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.log
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
1862
1863// -----
1864
// Parse/print check for linalg.log with statically shaped 4x8x16 memref operands.
// CHECK-LABEL: func @log_static
func.func @log_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.log
  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.log
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
1872
1873// -----
1874
// Parse/print check for linalg.log on static tensors; the destination operand
// is produced by tensor.empty and the op result is returned.
// CHECK-LABEL: func @log_tensor
func.func @log_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.log
  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.log
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
1883
1884// -----
1885
// Parse/print check for linalg.abs with fully dynamic rank-3 memref operands.
// CHECK-LABEL: func @abs_dynamic
func.func @abs_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.abs
  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.abs
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
1893
1894// -----
1895
// Parse/print check for linalg.abs with statically shaped 4x8x16 memref operands.
// CHECK-LABEL: func @abs_static
func.func @abs_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.abs
  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.abs
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
1903
1904// -----
1905
// Parse/print check for linalg.abs on static tensors; the destination operand
// is produced by tensor.empty and the op result is returned.
// CHECK-LABEL: func @abs_tensor
func.func @abs_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.abs
  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.abs
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
1914
1915// -----
1916
// Parse/print check for linalg.ceil with fully dynamic rank-3 memref operands.
// CHECK-LABEL: func @ceil_dynamic
func.func @ceil_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.ceil
  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.ceil
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
1924
1925// -----
1926
// Parse/print check for linalg.ceil with statically shaped 4x8x16 memref operands.
// CHECK-LABEL: func @ceil_static
func.func @ceil_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.ceil
  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.ceil
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
1934
1935// -----
1936
// Parse/print check for linalg.ceil on static tensors; the destination operand
// is produced by tensor.empty and the op result is returned.
// CHECK-LABEL: func @ceil_tensor
func.func @ceil_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.ceil
  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.ceil
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
1945
1946// -----
1947
// Parse/print check for linalg.floor with fully dynamic rank-3 memref operands.
// CHECK-LABEL: func @floor_dynamic
func.func @floor_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.floor
  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.floor
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
1955
1956// -----
1957
// Parse/print check for linalg.floor with statically shaped 4x8x16 memref operands.
// CHECK-LABEL: func @floor_static
func.func @floor_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.floor
  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.floor
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
1965
1966// -----
1967
// Parse/print check for linalg.floor on static tensors; the destination operand
// is produced by tensor.empty and the op result is returned.
// CHECK-LABEL: func @floor_tensor
func.func @floor_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.floor
  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.floor
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
1976
1977// -----
1978
// Parse/print check for linalg.negf with fully dynamic rank-3 memref operands.
// CHECK-LABEL: func @negf_dynamic
func.func @negf_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.negf
  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.negf
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
1986
1987// -----
1988
// Parse/print check for linalg.negf with statically shaped 4x8x16 memref operands.
// CHECK-LABEL: func @negf_static
func.func @negf_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.negf
  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.negf
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
1996
1997// -----
1998
// Parse/print check for linalg.negf on static tensors; the destination operand
// is produced by tensor.empty and the op result is returned.
// CHECK-LABEL: func @negf_tensor
func.func @negf_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.negf
  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.negf
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
2007
2008// -----
2009
// Parse/print check for linalg.reciprocal with fully dynamic rank-3 memref operands.
// CHECK-LABEL: func @reciprocal_dynamic
func.func @reciprocal_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.reciprocal
  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.reciprocal
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
2017
2018// -----
2019
// Parse/print check for linalg.reciprocal with statically shaped 4x8x16 memref operands.
// CHECK-LABEL: func @reciprocal_static
func.func @reciprocal_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.reciprocal
  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.reciprocal
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
2027
2028// -----
2029
// Parse/print check for linalg.reciprocal on static tensors; the destination
// operand is produced by tensor.empty and the op result is returned.
// CHECK-LABEL: func @reciprocal_tensor
func.func @reciprocal_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.reciprocal
  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.reciprocal
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
2038
2039// -----
2040
// Parse/print check for linalg.round with fully dynamic rank-3 memref operands.
// CHECK-LABEL: func @round_dynamic
func.func @round_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.round
  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.round
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
2048
2049// -----
2050
// Parse/print check for linalg.round with statically shaped 4x8x16 memref operands.
// CHECK-LABEL: func @round_static
func.func @round_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.round
  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.round
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
2058
2059// -----
2060
// Parse/print check for linalg.round on static tensors; the destination operand
// is produced by tensor.empty and the op result is returned.
// CHECK-LABEL: func @round_tensor
func.func @round_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.round
  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.round
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
2069
2070// -----
2071
// Parse/print check for linalg.sqrt with fully dynamic rank-3 memref operands.
// CHECK-LABEL: func @sqrt_dynamic
func.func @sqrt_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.sqrt
  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.sqrt
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
2079
2080// -----
2081
// Parse/print check for linalg.sqrt with statically shaped 4x8x16 memref operands.
// CHECK-LABEL: func @sqrt_static
func.func @sqrt_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.sqrt
  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.sqrt
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
2089
2090// -----
2091
// Parse/print check for linalg.sqrt on static tensors; the destination operand
// is produced by tensor.empty and the op result is returned.
// CHECK-LABEL: func @sqrt_tensor
func.func @sqrt_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.sqrt
  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.sqrt
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
2100
2101// -----
2102
// Parse/print check for linalg.rsqrt with fully dynamic rank-3 memref operands.
// CHECK-LABEL: func @rsqrt_dynamic
func.func @rsqrt_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.rsqrt
  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.rsqrt
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
2110
2111// -----
2112
// Parse/print check for linalg.rsqrt with statically shaped 4x8x16 memref operands.
// CHECK-LABEL: func @rsqrt_static
func.func @rsqrt_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.rsqrt
  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.rsqrt
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
2120
2121// -----
2122
// Parse/print check for linalg.rsqrt on static tensors; the destination operand
// is produced by tensor.empty and the op result is returned.
// CHECK-LABEL: func @rsqrt_tensor
func.func @rsqrt_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.rsqrt
  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.rsqrt
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
2131
2132// -----
2133
// Parse/print check for linalg.square with fully dynamic rank-3 memref operands.
// CHECK-LABEL: func @square_dynamic
func.func @square_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.square
  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.square
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
2141
2142// -----
2143
// Parse/print check for linalg.square with statically shaped 4x8x16 memref operands.
// CHECK-LABEL: func @square_static
func.func @square_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.square
  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.square
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
2151
2152// -----
2153
// Parse/print check for linalg.square on static tensors; the destination operand
// is produced by tensor.empty and the op result is returned.
// CHECK-LABEL: func @square_tensor
func.func @square_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.square
  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.square
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
2162
2163// -----
2164
// Parse/print check for linalg.tanh with fully dynamic rank-3 memref operands.
// CHECK-LABEL: func @tanh_dynamic
func.func @tanh_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.tanh
  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.tanh
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
2172
2173// -----
2174
// Parse/print check for linalg.tanh with statically shaped 4x8x16 memref operands.
// CHECK-LABEL: func @tanh_static
func.func @tanh_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.tanh
  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.tanh
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
2182
2183// -----
2184
// Parse/print check for linalg.tanh on static tensors; the destination operand
// is produced by tensor.empty and the op result is returned.
// CHECK-LABEL: func @tanh_tensor
func.func @tanh_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.tanh
  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.tanh
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
2193
2194// -----
2195
// Parse/print check for linalg.erf with fully dynamic rank-3 memref operands.
// CHECK-LABEL: func @erf_dynamic
func.func @erf_dynamic(%src: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.erf
  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.erf
    ins(%src : memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
2203
2204// -----
2205
// Parse/print check for linalg.erf with statically shaped 4x8x16 memref operands.
// CHECK-LABEL: func @erf_static
func.func @erf_static(%src: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.erf
  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.erf
    ins(%src : memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
2213
2214// -----
2215
// Parse/print check for linalg.erf on static tensors; the destination operand
// is produced by tensor.empty and the op result is returned.
// CHECK-LABEL: func @erf_tensor
func.func @erf_tensor(%src: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.erf
  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.erf
    ins(%src : tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
2224
2225// -----
2226
// Parse/print check for the two-input linalg.max with fully dynamic rank-3
// memref operands.
// CHECK-LABEL: func @max_dynamic
func.func @max_dynamic(%lhs: memref<?x?x?xf32>, %rhs: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.max
  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.max
    ins(%lhs, %rhs : memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
2235
2236// -----
2237
// Parse/print check for the two-input linalg.max with statically shaped
// 4x8x16 memref operands.
// CHECK-LABEL: func @max_static
func.func @max_static(%lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.max
  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.max
    ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
2246
2247// -----
2248
// Parse/print check for the two-input linalg.max on static tensors; the
// destination operand is produced by tensor.empty and the result is returned.
// CHECK-LABEL: func @max_tensor
func.func @max_tensor(%lhs: tensor<4x8x16xf32>, %rhs: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.max
  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.max
    ins(%lhs, %rhs : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
2258
2259// -----
2260
// Parse/print check for the two-input linalg.min with fully dynamic rank-3
// memref operands.
// CHECK-LABEL: func @min_dynamic
func.func @min_dynamic(%lhs: memref<?x?x?xf32>, %rhs: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.min
  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.min
    ins(%lhs, %rhs : memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
2269
2270// -----
2271
// Parse/print check for the two-input linalg.min with statically shaped
// 4x8x16 memref operands.
// CHECK-LABEL: func @min_static
func.func @min_static(%lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.min
  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.min
    ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
2280
2281// -----
2282
// Parse/print check for the two-input linalg.min on static tensors; the
// destination operand is produced by tensor.empty and the result is returned.
// CHECK-LABEL: func @min_tensor
func.func @min_tensor(%lhs: tensor<4x8x16xf32>, %rhs: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.min
  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.min
    ins(%lhs, %rhs : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
2292
2293// -----
2294
// Parse/print check for the two-input linalg.powf with fully dynamic rank-3
// memref operands.
// CHECK-LABEL: func @powf_dynamic
func.func @powf_dynamic(%lhs: memref<?x?x?xf32>, %rhs: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.powf
  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.powf
    ins(%lhs, %rhs : memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
2303
2304// -----
2305
// Parse/print check for the two-input linalg.powf with statically shaped
// 4x8x16 memref operands.
// CHECK-LABEL: func @powf_static
func.func @powf_static(%lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.powf
  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.powf
    ins(%lhs, %rhs : memref<4x8x16xf32>, memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
2314
2315// -----
2316
// Parse/print check for the two-input linalg.powf on static tensors; the
// destination operand is produced by tensor.empty and the result is returned.
// CHECK-LABEL: func @powf_tensor
func.func @powf_tensor(%lhs: tensor<4x8x16xf32>, %rhs: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.powf
  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.powf
    ins(%lhs, %rhs : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
2326
2327// -----
2328
// Parse/print check for linalg.fill into rank-0 tensors: one with an f32
// element type filled from an f32 scalar, and one whose element type is
// vector<2x4xf32> filled from a vector value. Both ops are matched explicitly
// so the test fails if either one is dropped or printed with different types,
// rather than passing on the function label alone.
// CHECK-LABEL: func @fill_tensor
func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
  %e0 = tensor.empty() : tensor<f32>
  // CHECK: linalg.fill
  // CHECK-SAME: ins(%{{.+}} : f32) outs(%{{.+}} : tensor<f32>)
  %0 = linalg.fill ins(%f : f32) outs(%e0 : tensor<f32>) -> tensor<f32>
  %e1 = tensor.empty() : tensor<vector<2x4xf32>>
  // CHECK: linalg.fill
  // CHECK-SAME: ins(%{{.+}} : vector<2x4xf32>) outs(%{{.+}} : tensor<vector<2x4xf32>>)
  %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
  return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
}
2337
2338// -----
2339
// Parse/print check for the three-input linalg.select (an i1 memref followed by
// two f32 memrefs) with fully dynamic rank-3 operands.
// CHECK-LABEL: func @select_dynamic
func.func @select_dynamic(%cond: memref<?x?x?xi1>, %lhs: memref<?x?x?xf32>, %rhs: memref<?x?x?xf32>, %dst: memref<?x?x?xf32>) {
  // CHECK: linalg.select
  // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>)
  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.select
    ins(%cond, %lhs, %rhs : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs(%dst : memref<?x?x?xf32>)
  return
}
2348
2349// -----
2350
// Parse/print check for the three-input linalg.select (an i1 memref followed by
// two f32 memrefs) with statically shaped 4x8x16 operands.
// CHECK-LABEL: func @select_static
func.func @select_static(%cond: memref<4x8x16xi1>, %lhs: memref<4x8x16xf32>, %rhs: memref<4x8x16xf32>, %dst: memref<4x8x16xf32>) {
  // CHECK: linalg.select
  // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>)
  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.select
    ins(%cond, %lhs, %rhs : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>)
    outs(%dst : memref<4x8x16xf32>)
  return
}
2359
2360// -----
2361
// Parse/print check for the three-input linalg.select (an i1 tensor followed by
// two f32 tensors) on static tensors; the destination operand is produced by
// tensor.empty and the result is returned.
// CHECK-LABEL: func @select_tensor
func.func @select_tensor(%cond: tensor<4x8x16xi1>, %lhs: tensor<4x8x16xf32>, %rhs: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %init = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.select
  // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
  %res = linalg.select
    ins(%cond, %lhs, %rhs : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %res : tensor<4x8x16xf32>
}
2371