xref: /llvm-project/mlir/test/Dialect/Linalg/vectorize-convolution.mlir (revision 234193bae6cf8b19703c6e543b100517bb99a9f7)
1// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-linalg-to-vector-patterns %s | FileCheck %s
2
// 1-D convolution in NWC/WCF layout on f32 memrefs: 4x6x3 input, 1x3x8 filter
// (kernel width 1), 4x2x8 output, dilation 1, stride 3.
3func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf32>, %output: memref<4x2x8xf32>) {
4  linalg.conv_1d_nwc_wcf
5    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
6    ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>)
7    outs(%output : memref<4x2x8xf32>)
8  return
9}
10
11// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
12// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
13// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
14
15//      CHECK: func @conv1d_nwc_4x2x8_memref
16// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<1x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
17
18//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
19//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
20
21/// Read the whole data in one shot.
22//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
23//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
24//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
25
26//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
27// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
28//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
29// CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
30
31//      CHECK:    %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x8xf32> from vector<1x3x8xf32>
32
33//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
34// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
35//      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
36// CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
37
38/// w == 0, kw == 0
39//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
40// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
41// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
42// CHECK-SAME:       kind = #vector.kind<add>
43// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
44// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
45
46/// w == 1, kw == 0
47//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
48// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
49// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
50// CHECK-SAME:       kind = #vector.kind<add>
51// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
52// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
53
54/// w == 0, kw == 0
55//      CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]
56// CHECK-SAME:     {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
57/// w == 1, kw == 0
58//      CHECK:   %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_1]], %[[RES_0]]
59// CHECK-SAME:     {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
60
61// Write the result back in one shot.
62//      CHECK:   vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
63
64// -----
65
66// This test is the same as the one above but uses the i1 element type; the only
67// difference is that the combining kind for `vector.contract` is `OR`.
68func.func @conv1d_nwc_4x2x8_memref_i1(%input: memref<4x6x3xi1>, %filter: memref<1x3x8xi1>, %output: memref<4x2x8xi1>) {
69  linalg.conv_1d_nwc_wcf
70    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
71    ins(%input, %filter : memref<4x6x3xi1>, memref<1x3x8xi1>)
72    outs(%output : memref<4x2x8xi1>)
73  return
74}
75// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
76// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
77// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
78
79//      CHECK: func @conv1d_nwc_4x2x8_memref_i1
80/// w == 0, kw == 0
81//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
82// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
83// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
84// CHECK-SAME:       kind = #vector.kind<or>
85// CHECK-SAME:       : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
86
87/// w == 1, kw == 0
88//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
89// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
90// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
91// CHECK-SAME:       kind = #vector.kind<or>
92// CHECK-SAME:       : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
93
94// -----
95
96// The i8i8i32 case (i8 input and filter, i32 output) is similar to the f32 case,
97// so checking one case is enough for test coverage.
98func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: memref<1x3x8xi8>, %output: memref<4x2x8xi32>) {
99  linalg.conv_1d_nwc_wcf
100    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
101    ins(%input, %filter : memref<4x6x3xi8>, memref<1x3x8xi8>)
102    outs(%output : memref<4x2x8xi32>)
103  return
104}
105
106// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
107// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
108// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
109
110//      CHECK: func @conv1d_nwc_4x2x8_i8i8i32_memref
111// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xi8>, %[[FILTER:.+]]: memref<1x3x8xi8>, %[[OUTPUT:.+]]: memref<4x2x8xi32>)
112
113//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
114//  CHECK-DAG:   %[[C0_I8:.+]] = arith.constant 0 : i8
115//  CHECK-DAG:   %[[C0_I32:.+]] = arith.constant 0 : i32
116
117/// Read the whole data in one shot.
118//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
119//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
120//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I32]]
121
122//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
123// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
124//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
125// CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
126
127//      CHECK:    %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x8xi8> from vector<1x3x8xi8>
128
129//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
130// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xi32> to vector<4x1x8xi32>
131//      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
132// CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xi32> to vector<4x1x8xi32>
133
134/// w == 0, kw == 0
135//      CHECK:   %[[EXT_LHS_0:.+]] = arith.extsi %[[V_INPUT_0]] : vector<4x1x3xi8> to vector<4x1x3xi32>
136//      CHECK:   %[[EXT_RHS_0:.+]] = arith.extsi %[[V_FILTER]] : vector<3x8xi8> to vector<3x8xi32>
137//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
138// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
139// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
140// CHECK-SAME:     %[[EXT_LHS_0]], %[[EXT_RHS_0]], %[[V_OUTPUT_0]]
141// CHECK-SAME:     : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32>
142
143/// w == 1, kw == 0
144//      CHECK:   %[[EXT_LHS_1:.+]] = arith.extsi %[[V_INPUT_1]] : vector<4x1x3xi8> to vector<4x1x3xi32>
145//      CHECK:   %[[EXT_RHS_1:.+]] = arith.extsi %[[V_FILTER]] : vector<3x8xi8> to vector<3x8xi32>
146//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
147// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
148// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
149// CHECK-SAME:     %[[EXT_LHS_1]], %[[EXT_RHS_1]], %[[V_OUTPUT_1]]
150// CHECK-SAME:     : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32>
151
152/// w == 0, kw == 0
153//      CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]
154// CHECK-SAME:     {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xi32> into vector<4x2x8xi32>
155/// w == 1, kw == 0
156//      CHECK:   %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_1]], %[[RES_0]]
157// CHECK-SAME:     {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xi32> into vector<4x2x8xi32>
158
159// Write the result back in one shot.
160//      CHECK:   vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
161
162// -----
163
// 1-D convolution in NWC/WCF layout on f32 memrefs with a 2x3x8 filter
// (kernel width 2), dilation 2 and stride 3.
164func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
165  linalg.conv_1d_nwc_wcf
166    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
167    ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
168    outs(%output : memref<4x2x8xf32>)
169  return
170}
171
172// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
173// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
174// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
175
176//      CHECK: func @conv1d_nwc_4x2x8_memref
177// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
178
179//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
180//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
181
182/// Read the whole data in one shot.
183//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
184//  CHECK-DAG:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
185//  CHECK-DAG:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
186
187//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
188// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
189//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
190// CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
191//      CHECK:   %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
192// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
193//      CHECK:   %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
194// CHECK-SAME:     {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
195
196//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x8xf32> from vector<2x3x8xf32>
197//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<3x8xf32> from vector<2x3x8xf32>
198
199//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
200// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
201//      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
202// CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
203
204/// w == 0, kw == 0
205//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
206// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
207// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
208// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_0]]
209// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
210/// w == 1, kw == 0
211//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
212// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
213// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
214// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_0]], %[[V_OUTPUT_1]]
215// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
216/// w == 0, kw == 1
217//      CHECK:   %[[CONTRACT_2:.+]] = vector.contract {
218// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
219// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
220// CHECK-SAME:     %[[V_INPUT_2]], %[[V_FILTER_1]], %[[CONTRACT_0]]
221// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
222/// w == 1, kw == 1
223//      CHECK:   %[[CONTRACT_3:.+]] = vector.contract {
224// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
225// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
226// CHECK-SAME:     %[[V_INPUT_3]], %[[V_FILTER_1]], %[[CONTRACT_1]]
227// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
228
229/// w == 0, kw == 0
230//      CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_2]], %[[V_OUTPUT_R]]
231// CHECK-SAME:     {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
232/// w == 1, kw == 0
233//      CHECK:   %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_3]], %[[RES_0]]
234// CHECK-SAME:     {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
235
236// Write the result back in one shot.
237//      CHECK:   vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
238
239// -----
240
// 1-D convolution in NWC/WCF layout on f32 memrefs with a 2x3x8 filter
// (kernel width 2), dilation 2 and unit stride.
241func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
242  linalg.conv_1d_nwc_wcf
243    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
244    ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
245    outs(%output : memref<4x2x8xf32>)
246  return
247}
248
249// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
250// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
251// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
252
253//      CHECK: func @conv1d_nwc_4x2x8_memref
254// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
255
256//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
257//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
258
259/// Read the whole data in one shot.
260//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
261//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
262//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
263
264//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
265// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
266//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
267// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
268
269//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x8xf32> from vector<2x3x8xf32>
270//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<3x8xf32> from vector<2x3x8xf32>
271
272/// w == 0, kw == 0
273//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
274// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
275// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
276// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
277// CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
278/// w == 0, kw == 1
279//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
280// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
281// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
282// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
283// CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
284
285// Write the result back in one shot.
286//      CHECK:   vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
287
288// -----
289
// 1-D convolution in NCW/FCW layout on f32 memrefs: 4x3x6 input, 8x3x1 filter
// (kernel width 1), 4x8x2 output, dilation 1, stride 3.
290func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x1xf32>, %output: memref<4x8x2xf32>) {
291  linalg.conv_1d_ncw_fcw
292    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
293    ins(%input, %filter : memref<4x3x6xf32>, memref<8x3x1xf32>)
294    outs(%output : memref<4x8x2xf32>)
295  return
296}
297
298// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
299// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
300// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
301
302//      CHECK: func @conv1d_ncw_4x8x2_memref
303// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x6xf32>, %[[FILTER:.+]]: memref<8x3x1xf32>, %[[OUTPUT:.+]]: memref<4x8x2xf32>)
304
305//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
306//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
307
308/// Read the whole data in one shot.
309//  CHECK-DAG:   %[[V_NWC_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
310//  CHECK-DAG:  %[[V_NWC_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
311//  CHECK-DAG:  %[[V_NWC_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
312
313/// Transpose the read vectors to nwc (wcf for the filter) format.
314//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transpose %[[V_NWC_INPUT_R]], [0, 2, 1]
315//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transpose %[[V_NWC_FILTER_R]], [2, 1, 0]
316//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transpose %[[V_NWC_OUTPUT_R]], [0, 2, 1]
317
318//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
319// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
320//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
321// CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
322
323//      CHECK:    %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x8xf32> from vector<1x3x8xf32>
324
325//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
326// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
327//      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
328// CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
329
330/// w == 0, kw == 0
331//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
332// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
333// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
334// CHECK-SAME:       kind = #vector.kind<add>
335// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
336// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
337
338/// w == 1, kw == 0
339//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
340// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
341// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
342// CHECK-SAME:       kind = #vector.kind<add>
343// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
344// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
345
346/// w == 0, kw == 0
347//      CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]
348// CHECK-SAME:     {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
349/// w == 1, kw == 0
350//      CHECK:   %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_1]], %[[RES_0]]
351// CHECK-SAME:     {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
352
353/// Transpose result to ncw format.
354//  CHECK:  %[[RES_2:.+]] = vector.transpose %[[RES_1]], [0, 2, 1]
355
356// Write the result back in one shot.
357//      CHECK:   vector.transfer_write %[[RES_2]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
358
359// -----
360
361// This test is the same as the one above but uses the i1 element type; the only
362// difference is that the combining kind for `vector.contract` is `OR`.
363func.func @conv1d_ncw_4x8x2_memref_i1(%input: memref<4x3x6xi1>, %filter: memref<8x3x1xi1>, %output: memref<4x8x2xi1>) {
364  linalg.conv_1d_ncw_fcw
365    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
366    ins(%input, %filter : memref<4x3x6xi1>, memref<8x3x1xi1>)
367    outs(%output : memref<4x8x2xi1>)
368  return
369}
370
371// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
372// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
373// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
374
375//      CHECK: func @conv1d_ncw_4x8x2_memref_i1
376/// w == 0, kw == 0
377//      CHECK:   vector.contract {
378// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
379// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
380// CHECK-SAME:       kind = #vector.kind<or>
381// CHECK-SAME:       : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
382
383/// w == 1, kw == 0
384//      CHECK:   vector.contract {
385// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
386// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
387// CHECK-SAME:       kind = #vector.kind<or>
388// CHECK-SAME:       : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
389
390// -----
391
// 1-D convolution in NCW/FCW layout on f32 memrefs with an 8x3x2 filter
// (kernel width 2), dilation 2 and stride 3.
392func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x2xf32>, %output: memref<4x8x2xf32>) {
393  linalg.conv_1d_ncw_fcw
394    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
395    ins(%input, %filter : memref<4x3x6xf32>, memref<8x3x2xf32>)
396    outs(%output : memref<4x8x2xf32>)
397  return
398}
399
400// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
401// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
402// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
403
404//      CHECK: func @conv1d_ncw_4x8x2_memref
405// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x6xf32>, %[[FILTER:.+]]: memref<8x3x2xf32>, %[[OUTPUT:.+]]: memref<4x8x2xf32>)
406
407//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
408//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
409
410/// Read the whole data in one shot.
411//  CHECK-DAG:   %[[V_NWC_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
412//  CHECK-DAG:   %[[V_NWC_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
413//  CHECK-DAG:   %[[V_NWC_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
414
415/// Transpose the read vectors to nwc (wcf for the filter) format.
416//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transpose %[[V_NWC_INPUT_R]], [0, 2, 1]
417//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transpose %[[V_NWC_FILTER_R]], [2, 1, 0]
418//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transpose %[[V_NWC_OUTPUT_R]], [0, 2, 1]
419
420//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
421// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
422//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
423// CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
424//      CHECK:   %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
425// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
426//      CHECK:   %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
427// CHECK-SAME:     {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
428
429//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x8xf32> from vector<2x3x8xf32>
430//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<3x8xf32> from vector<2x3x8xf32>
431
432//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
433// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
434//      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
435// CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
436
437/// w == 0, kw == 0
438//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
439// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
440// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
441// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_0]]
442// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
443/// w == 1, kw == 0
444//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
445// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
446// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
447// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_0]], %[[V_OUTPUT_1]]
448// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
449/// w == 0, kw == 1
450//      CHECK:   %[[CONTRACT_2:.+]] = vector.contract {
451// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
452// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
453// CHECK-SAME:     %[[V_INPUT_2]], %[[V_FILTER_1]], %[[CONTRACT_0]]
454// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
455/// w == 1, kw == 1
456//      CHECK:   %[[CONTRACT_3:.+]] = vector.contract {
457// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
458// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
459// CHECK-SAME:     %[[V_INPUT_3]], %[[V_FILTER_1]], %[[CONTRACT_1]]
460// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
461
462/// w == 0, kw == 0
463//      CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_2]], %[[V_OUTPUT_R]]
464// CHECK-SAME:     {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
465/// w == 1, kw == 0
466//      CHECK:   %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_3]], %[[RES_0]]
467// CHECK-SAME:     {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
468
469/// Transpose result to ncw format.
470//  CHECK:  %[[RES_2:.+]] = vector.transpose %[[RES_1]], [0, 2, 1]
471
472// Write the result back in one shot.
473//      CHECK:   vector.transfer_write %[[RES_2]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
474
475// -----
476
// 1-D convolution in NCW/FCW layout on f32 memrefs with an 8x3x2 filter
// (kernel width 2), dilation 2 and unit stride.
477func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x2xf32>, %output: memref<4x8x2xf32>) {
478  linalg.conv_1d_ncw_fcw
479    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
480    ins(%input, %filter : memref<4x3x6xf32>, memref<8x3x2xf32>)
481    outs(%output : memref<4x8x2xf32>)
482  return
483}
484
485// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
486// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
487// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
488
489//      CHECK: func @conv1d_ncw_4x8x2_memref
490// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x6xf32>, %[[FILTER:.+]]: memref<8x3x2xf32>, %[[OUTPUT:.+]]: memref<4x8x2xf32>)
491
492//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
493//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
494
495/// Read the whole data in one shot.
496//  CHECK-DAG:   %[[V_NWC_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
497//  CHECK-DAG:  %[[V_NWC_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
498//  CHECK-DAG:  %[[V_NWC_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
499
500/// Transpose the read vectors to nwc (wcf for the filter) format.
501//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transpose %[[V_NWC_INPUT_R]], [0, 2, 1]
502//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transpose %[[V_NWC_FILTER_R]], [2, 1, 0]
503//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transpose %[[V_NWC_OUTPUT_R]], [0, 2, 1]
504
505//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
506// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
507//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
508// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
509
510//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x8xf32> from vector<2x3x8xf32>
511//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<3x8xf32> from vector<2x3x8xf32>
512
513/// w == 0, kw == 0
514//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
515// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
516// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
517// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
518// CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
519/// w == 0, kw == 1
520//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
521// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
522// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
523// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
524// CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
525
526/// Transpose result to ncw format.
527//  CHECK:  %[[RES:.+]] = vector.transpose %[[CONTRACT_1]], [0, 2, 1]
528
529// Write the result back in one shot.
530//      CHECK:   vector.transfer_write %[[RES]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
531
532
533// -----
534
// Test input: plain 1-D convolution (linalg.conv_1d) on tensors, with no batch
// or channel dimension. An 11-element input and a 4-element filter accumulate
// into the 8-element `%output` init tensor; the op result is returned.
// No strides/dilations attributes are set on the op.
535func.func @conv1d_8_tensor(%input: tensor<11xf32>, %filter: tensor<4xf32>, %output: tensor<8xf32>) -> tensor<8xf32> {
536  %0 = linalg.conv_1d ins(%input, %filter : tensor<11xf32>, tensor<4xf32>)
537                     outs(%output : tensor<8xf32>) -> tensor<8xf32>
538  return %0 : tensor<8xf32>
539}
540
541//      CHECK: func @conv1d_8_tensor
542// CHECK-SAME: (%[[INPUT:.+]]: tensor<11xf32>, %[[FILTER:.+]]: tensor<4xf32>, %[[OUTPUT:.+]]: tensor<8xf32>)
543
544//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
545//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
546
547/// Read the whole data in one shot.
548//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]]], %[[F0]]
549//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]]], %[[F0]]
550//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]]], %[[F0]]
551
552//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
553// CHECK-SAME:     {offsets = [0], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
554//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
555// CHECK-SAME:     {offsets = [1], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
556//      CHECK:   %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
557// CHECK-SAME:     {offsets = [2], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
558//      CHECK:   %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
559// CHECK-SAME:     {offsets = [3], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
560
561//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : f32 from vector<4xf32>
562//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : f32 from vector<4xf32>
563//      CHECK:  %[[V_FILTER_2:.+]] = vector.extract %[[V_FILTER_R]][2] : f32 from vector<4xf32>
564//      CHECK:  %[[V_FILTER_3:.+]] = vector.extract %[[V_FILTER_R]][3] : f32 from vector<4xf32>
565
566/// w == 0, kw == 0
567//      CHECK:   %[[RES_0:.+]] = vector.outerproduct
568// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]] {kind = #vector.kind<add>}
569// CHECK-SAME:     : vector<8xf32>, f32
570/// w == 0, kw == 1
571//      CHECK:   %[[RES_1:.+]] = vector.outerproduct
572// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_1]], %[[RES_0]] {kind = #vector.kind<add>}
573// CHECK-SAME:     : vector<8xf32>, f32
574/// w == 0, kw == 2
575//      CHECK:   %[[RES_2:.+]] = vector.outerproduct
576// CHECK-SAME:     %[[V_INPUT_2]], %[[V_FILTER_2]], %[[RES_1]] {kind = #vector.kind<add>}
577// CHECK-SAME:     : vector<8xf32>, f32
578/// w == 0, kw == 3
579//      CHECK:   %[[RES_3:.+]] = vector.outerproduct
580// CHECK-SAME:     %[[V_INPUT_3]], %[[V_FILTER_3]], %[[RES_2]] {kind = #vector.kind<add>}
581// CHECK-SAME:     : vector<8xf32>, f32
582
583// Write the result back in one shot.
584//      CHECK:   vector.transfer_write %[[RES_3]], %[[OUTPUT]][%[[C0]]]
585
586// -----
587
// Test input: depthwise 1-D convolution (nwc input, wc filter) on f32 memrefs
// with dilation 2 and stride 1. The op has no result; it updates `%output`
// (3x2x4) in place from a 3x5x4 input and a 2x4 filter.
588func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
589  linalg.depthwise_conv_1d_nwc_wc
590    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
591    ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
592    outs(%output : memref<3x2x4xf32>)
593  return
594}
595
596//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref
597//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
598
599//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
600//   CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
601
602/// Read the whole data in one shot.
603//      CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
604//      CHECK:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
605//      CHECK:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
606
607//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
608// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
609//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
610// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
611
612//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xf32> from vector<2x4xf32>
613//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xf32> from vector<2x4xf32>
614
615/// w == 0, kw = 0
616//      CHECK:  %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
617//      CHECK:  %[[FMA_0:.*]] = vector.fma %[[V_INPUT_0]], %[[B_FILTER_0]], %[[V_OUTPUT_R]] : vector<3x2x4xf32>
618
619/// w == 0, kw = 1
620//      CHECK:  %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xf32> to vector<3x2x4xf32>
621//      CHECK:  %[[FMA_1:.*]] = vector.fma %[[V_INPUT_1]], %[[B_FILTER_1]], %[[FMA_0]] : vector<3x2x4xf32>
622
623// Write the result back in one shot.
624//      CHECK:   vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
625
626
627// -----
628
// Test input: integer variant of the depthwise 1-D convolution test with the
// same shapes, dilation 2 and stride 1. Input and filter are i8 while the
// output accumulator is i32, so the expected output below widens both operands
// (arith.extsi) before the multiply/add.
629func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %filter: memref<2x4xi8>, %output: memref<3x2x4xi32>) {
630  linalg.depthwise_conv_1d_nwc_wc
631    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
632    ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
633    outs(%output : memref<3x2x4xi32>)
634  return
635}
636
637//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref
638//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xi8>, %[[FILTER:[0-9a-z]+]]: memref<2x4xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xi32>)
639
640//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
641
642/// Read the whole data in one shot.
643//      CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
644//      CHECK:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
645//      CHECK:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
646
647//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
648// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
649//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
650// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
651
652//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xi8> from vector<2x4xi8>
653//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xi8> from vector<2x4xi8>
654
655/// w == 0, kw = 0
656//      CHECK:  %[[EXT_INPUT_0:.*]] = arith.extsi %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x2x4xi32>
657//      CHECK:  %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8>
658//      CHECK:  %[[EXT_FILTER_0:.*]] = arith.extsi %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x2x4xi32>
659//      CHECK:  %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x2x4xi32>
660//      CHECK:  %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[V_OUTPUT_R]] : vector<3x2x4xi32>
661
662/// w == 0, kw = 1
663//      CHECK:  %[[EXT_INPUT_1:.*]] = arith.extsi %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x2x4xi32>
664//      CHECK:  %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8>
665//      CHECK:  %[[EXT_FILTER_1:.*]] = arith.extsi %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x2x4xi32>
666//      CHECK:  %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x2x4xi32>
667//      CHECK:  %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x2x4xi32>
668
669// Write the result back in one shot.
670//      CHECK:   vector.transfer_write %[[ADD_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
671
672// -----
673
// Test input: nwc/wcf 1-D convolution with mixed floating-point types: f16
// input and filter accumulating into an f32 output memref. Dilation and stride
// are both 1. Note the attributes here are typed vector<1xi64>, unlike the
// tensor<1xi64> used by most other tests in this file.
674func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) {
675  linalg.conv_1d_nwc_wcf
676  {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
677   ins(%input, %filter : memref<1x2x3xf16>, memref<1x3x2xf16>)
678   outs(%output : memref<1x2x2xf32>)
679  return
680}
681
682//       CHECK: func @conv_1d_nwc_wcf_mixed_type_memref
683//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<1x2x3xf16>, %[[FILTER:[0-9a-z]+]]: memref<1x3x2xf16>, %[[OUTPUT:[0-9a-z]+]]: memref<1x2x2xf32>)
684
685//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
686//   CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
687
688/// Read the whole data in one shot.
689//      CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
690//      CHECK:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]]
691//      CHECK:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
692//      CHECK:   %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x2xf16> from vector<1x3x2xf16>
693//      CHECK:   %[[CONT:.*]] = vector.contract
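// NOTE(review): the pattern line below carries no FileCheck directive prefix, so
// FileCheck skips it and the vector.contract operands/types are never verified.
// Confirm it against the actual output before promoting it to a directive.
694//        {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>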
695//      CHECK:   vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
696
697// -----
698
// Test input: nwc/wcf 1-D convolution mixing integer and floating-point types:
// i8 input and filter accumulating into an f32 output memref, dilation and
// stride 1. The expected output converts both operands with arith.sitofp
// before the vector.contract.
699func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter: memref<1x3x2xi8>, %output: memref<1x2x2xf32>) {
700  linalg.conv_1d_nwc_wcf
701  {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
702   ins(%input, %filter : memref<1x2x3xi8>, memref<1x3x2xi8>)
703   outs(%output : memref<1x2x2xf32>)
704  return
705}
706
707
// The function arguments are referenced through the captured INPUT / FILTER /
// OUTPUT names rather than hard-coded SSA names, matching the other tests in
// this file and keeping the test independent of argument naming.
708// CHECK-LABEL: func @conv_1d_nwc_wcf_mixed_int_fp_memref
709// CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<1x2x3xi8>, %[[FILTER:[0-9a-z]+]]: memref<1x3x2xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<1x2x2xf32>)
710// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
711// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index
712// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i8
713// CHECK: %[[READ0:.+]] = vector.transfer_read %[[INPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[C0]]
714// CHECK: %[[READ1:.+]] = vector.transfer_read %[[FILTER]][%[[I0]], %[[I0]], %[[I0]]], %[[C0]]
715// CHECK: %[[READ2:.+]] = vector.transfer_read %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[CST]]
716// CHECK: %[[EXT:.+]] = vector.extract %[[READ1]][0] : vector<3x2xi8> from vector<1x3x2xi8>
717// CHECK: %[[CAST0:.+]] = arith.sitofp %[[READ0]] : vector<1x2x3xi8> to vector<1x2x3xf32>
718// CHECK: %[[CAST1:.+]] = arith.sitofp %[[EXT]] : vector<3x2xi8> to vector<3x2xf32>
719// CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[CAST0]], %[[CAST1]], %[[READ2]]
720// CHECK: vector.transfer_write %[[CONTRACT]], %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]]
721
722// -----
723
// Test input: nwc sum pooling on f32 memrefs with a 1-element window,
// dilation 1 and stride 3; 4x4x3 input reduced into the 4x2x3 `%output`.
// The expected output below never reads `%filter` (presumably only its shape
// matters). The name suffix appears to encode
// <window>_<output width>_<dilation>_<stride>.
724func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
725  linalg.pooling_nwc_sum
726    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
727    ins(%input, %filter : memref<4x4x3xf32>, memref<1xf32>)
728    outs(%output : memref<4x2x3xf32>)
729  return
730}
731
732// CHECK-LABEL: func.func @pooling_nwc_sum_memref_1_2_1_3
733// CHECK-SAME: (%[[INPUT:.+]]: memref<4x4x3xf32>, %[[FILTER:.+]]: memref<1xf32>, %[[OUTPUT:.+]]: memref<4x2x3xf32>)
734// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
735// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
736// CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x4x3xf32>, vector<4x4x3xf32>
737// CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
738// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
739// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
740// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
741// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
742// CHECK: %[[V6:.+]] = arith.addf %[[V2]], %[[V4]] : vector<4x1x3xf32>
743// CHECK: %[[V7:.+]] = arith.addf %[[V3]], %[[V5]] : vector<4x1x3xf32>
744// CHECK: %[[V8:.+]] = vector.insert_strided_slice %[[V6]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
745// CHECK: %[[V9:.+]] = vector.insert_strided_slice %[[V7]], %[[V8]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
746// CHECK: vector.transfer_write %[[V9]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
747
748// -----
749
// Test input: nwc max pooling on f32 memrefs, using the same shapes and
// attributes as the sum-pooling test above (1-element window, dilation 1,
// stride 3); only the reduction differs (arith.maximumf in the expected output).
750func.func @pooling_nwc_max_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
751  linalg.pooling_nwc_max
752    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
753    ins(%input, %filter : memref<4x4x3xf32>, memref<1xf32>)
754    outs(%output : memref<4x2x3xf32>)
755  return
756}
757
758// CHECK-LABEL: func.func @pooling_nwc_max_memref_1_2_1_3
759// CHECK-SAME: (%[[INPUT:.+]]: memref<4x4x3xf32>, %[[FILTER:.+]]: memref<1xf32>, %[[OUTPUT:.+]]: memref<4x2x3xf32>)
760// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
761// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
762// CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x4x3xf32>, vector<4x4x3xf32>
763// CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
764// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
765// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
766// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
767// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
768// CHECK: %[[V6:.+]] = arith.maximumf %[[V2]], %[[V4]] : vector<4x1x3xf32>
769// CHECK: %[[V7:.+]] = arith.maximumf %[[V3]], %[[V5]] : vector<4x1x3xf32>
770// CHECK: %[[V8:.+]] = vector.insert_strided_slice %[[V6]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
771// CHECK: %[[V9:.+]] = vector.insert_strided_slice %[[V7]], %[[V8]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
772// CHECK: vector.transfer_write %[[V9]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
773
774// -----
775
776// The i8i8i32 case is similar to the f32 case, so checking one configuration
777// is enough for test coverage. Here: nwc sum pooling with i8 input/filter and
// an i32 output, 1-element window, dilation 1, stride 3.
778func.func @pooling_nwc_sum_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %filter: memref<1xi8>, %output: memref<4x2x3xi32>) {
779  linalg.pooling_nwc_sum
780    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
781    ins(%input, %filter : memref<4x4x3xi8>, memref<1xi8>)
782    outs(%output : memref<4x2x3xi32>)
783  return
784}
785
786// CHECK-LABEL: func.func @pooling_nwc_sum_i8i8i32_memref_1_2_1_3
787// CHECK-SAME: (%[[INPUT:.+]]: memref<4x4x3xi8>, %[[FILTER:.+]]: memref<1xi8>, %[[OUTPUT:.+]]: memref<4x2x3xi32>)
788// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
789// CHECK-DAG: %[[Vc0_i8:.+]] = arith.constant 0 : i8
790// CHECK-DAG: %[[Vc0_i32:.+]] = arith.constant 0 : i32
791// CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i8]] {in_bounds = [true, true, true]} : memref<4x4x3xi8>, vector<4x4x3xi8>
792// CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i32]] {in_bounds = [true, true, true]} : memref<4x2x3xi32>, vector<4x2x3xi32>
793// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
794// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
795// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32>
796// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32>
797// CHECK: %[[V6:.+]] = arith.extsi %[[V2]] : vector<4x1x3xi8> to vector<4x1x3xi32>
798// CHECK: %[[V7:.+]] = arith.addi %[[V6]], %[[V4]] : vector<4x1x3xi32>
799// CHECK: %[[V8:.+]] = arith.extsi %[[V3]] : vector<4x1x3xi8> to vector<4x1x3xi32>
800// CHECK: %[[V9:.+]] = arith.addi %[[V8]], %[[V5]] : vector<4x1x3xi32>
801// CHECK: %[[V10:.+]] = vector.insert_strided_slice %[[V7]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32>
802// CHECK: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32>
803// CHECK: vector.transfer_write %[[V11]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xi32>, memref<4x2x3xi32>
804// CHECK: return
805
806// -----
807
808// The i8i8i32 case is similar to the f32 case, so checking one configuration
809// is enough for test coverage. Here: nwc max pooling with i8 input/filter and
// an i32 output, 1-element window, dilation 1, stride 3.
810func.func @pooling_nwc_max_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %filter: memref<1xi8>, %output: memref<4x2x3xi32>) {
811  linalg.pooling_nwc_max
812    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
813    ins(%input, %filter : memref<4x4x3xi8>, memref<1xi8>)
814    outs(%output : memref<4x2x3xi32>)
815  return
816}
817
818// CHECK-LABEL: func.func @pooling_nwc_max_i8i8i32_memref_1_2_1_3
819// CHECK-SAME: (%[[INPUT:.+]]: memref<4x4x3xi8>, %[[FILTER:.+]]: memref<1xi8>, %[[OUTPUT:.+]]: memref<4x2x3xi32>)
820// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
821// CHECK-DAG: %[[Vc0_i8:.+]] = arith.constant 0 : i8
822// CHECK-DAG: %[[Vc0_i32:.+]] = arith.constant 0 : i32
823// CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i8]] {in_bounds = [true, true, true]} : memref<4x4x3xi8>, vector<4x4x3xi8>
824// CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i32]] {in_bounds = [true, true, true]} : memref<4x2x3xi32>, vector<4x2x3xi32>
825// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
826// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
827// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32>
828// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32>
829// CHECK: %[[V6:.+]] = arith.extsi %[[V2]] : vector<4x1x3xi8> to vector<4x1x3xi32>
830// CHECK: %[[V7:.+]] = arith.maxsi %[[V6]], %[[V4]] : vector<4x1x3xi32>
831// CHECK: %[[V8:.+]] = arith.extsi %[[V3]] : vector<4x1x3xi8> to vector<4x1x3xi32>
832// CHECK: %[[V9:.+]] = arith.maxsi %[[V8]], %[[V5]] : vector<4x1x3xi32>
833// CHECK: %[[V10:.+]] = vector.insert_strided_slice %[[V7]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32>
834// CHECK: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32>
835// CHECK: vector.transfer_write %[[V11]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xi32>, memref<4x2x3xi32>
836// CHECK: return
837
838// -----
839
// Test input: nwc sum pooling on f32 memrefs with a 2-element window,
// dilation 2 and stride 3; 4x6x3 input reduced into the 4x2x3 `%output`.
// The expected output slices the input at w offsets 0/3 (first window tap) and
// 2/5 (second tap), accumulates each pair into the matching output column, and
// writes the reassembled vector back.
840func.func @pooling_nwc_sum_memref_2_2_2_3(%input: memref<4x6x3xf32>, %filter: memref<2xf32>, %output: memref<4x2x3xf32>) {
841  linalg.pooling_nwc_sum
842    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
843    ins(%input, %filter : memref<4x6x3xf32>, memref<2xf32>)
844    outs(%output : memref<4x2x3xf32>)
845  return
846}
847
848// CHECK-LABEL: func.func @pooling_nwc_sum_memref_2_2_2_3
849// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2xf32>, %[[OUTPUT:.+]]: memref<4x2x3xf32>)
850// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
851// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
852// CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x6x3xf32>, vector<4x6x3xf32>
853// CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
854// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
855// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
856// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
857// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
858// CHECK: %[[V6:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
859// CHECK: %[[V7:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
860// CHECK: %[[V8:.+]] = arith.addf %[[V2]], %[[V6]] : vector<4x1x3xf32>
861// CHECK: %[[V9:.+]] = arith.addf %[[V3]], %[[V7]] : vector<4x1x3xf32>
862// CHECK: %[[V10:.+]] = arith.addf %[[V4]], %[[V8]] : vector<4x1x3xf32>
863// CHECK: %[[V11:.+]] = arith.addf %[[V5]], %[[V9]] : vector<4x1x3xf32>
864// CHECK: %[[V12:.+]] = vector.insert_strided_slice %[[V10]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
865// CHECK: %[[V13:.+]] = vector.insert_strided_slice %[[V11]], %[[V12]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
// The write must *use* V13 (no ":.+" redefinition) so the stored value is
// actually tied to the final insert_strided_slice result.
866// CHECK: vector.transfer_write %[[V13]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
867
868
869// -----
870
// Test input: ncw sum pooling on f32 memrefs with a 1-element window,
// dilation 1 and stride 3; 4x3x4 input reduced into the 4x3x2 `%output`.
// The expected output transposes input and output to nwc ([0, 2, 1]), performs
// the same slicing/accumulation as the nwc test, then transposes the result
// back before writing it.
871func.func @pooling_ncw_sum_memref_1_2_1_3(%input: memref<4x3x4xf32>, %filter: memref<1xf32>, %output: memref<4x3x2xf32>) {
872  linalg.pooling_ncw_sum
873    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
874    ins(%input, %filter : memref<4x3x4xf32>, memref<1xf32>)
875    outs(%output : memref<4x3x2xf32>)
876  return
877}
878
879// CHECK-LABEL: func.func @pooling_ncw_sum_memref_1_2_1_3
880// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x4xf32>, %[[FILTER:.+]]: memref<1xf32>, %[[OUTPUT:.+]]: memref<4x3x2xf32>)
881// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
882// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
883// CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x4xf32>, vector<4x3x4xf32>
884// CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x2xf32>, vector<4x3x2xf32>
885// CHECK: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x3x4xf32> to vector<4x4x3xf32>
886// CHECK: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
887// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
888// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
889// CHECK: %[[V6:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
890// CHECK: %[[V7:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
891// CHECK: %[[V8:.+]] = arith.addf %[[V4]], %[[V6]] : vector<4x1x3xf32>
892// CHECK: %[[V9:.+]] = arith.addf %[[V5]], %[[V7]] : vector<4x1x3xf32>
893// CHECK: %[[V10:.+]] = vector.insert_strided_slice %[[V8]], %[[V3]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
894// CHECK: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
895// CHECK: %[[V12:.+]] = vector.transpose %[[V11]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
// The write must *use* V12 (no ":.+" redefinition) so the stored value is
// actually tied to the transposed result.
896// CHECK: vector.transfer_write %[[V12]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x3x2xf32>, memref<4x3x2xf32>
897
898
899// -----
900
// Test input: nwc sum pooling with mixed floating-point types: f16 input and
// filter, f32 output; 1-element window, dilation 1, stride 1. The expected
// output extends the f16 input to f32 (arith.extf) and adds it to the output
// vector with no slicing.
901func.func @pooling_nwc_sum_mixed_type_memref_1_2_1_1(%input: memref<1x2x3xf16>, %filter: memref<1xf16>, %output: memref<1x2x3xf32>) {
902  linalg.pooling_nwc_sum
903  {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
904   ins(%input, %filter : memref<1x2x3xf16>, memref<1xf16>)
905   outs(%output : memref<1x2x3xf32>)
906  return
907}
908
909// CHECK-LABEL: func.func @pooling_nwc_sum_mixed_type_memref_1_2_1_1
910// CHECK-SAME: (%[[INPUT:.+]]: memref<1x2x3xf16>, %[[FILTER:.+]]: memref<1xf16>, %[[OUTPUT:.+]]: memref<1x2x3xf32>)
911// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
912// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f16
913// CHECK-DAG: %[[Vcst_0:.+]] = arith.constant 0.000000e+00 : f32
914// CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<1x2x3xf16>, vector<1x2x3xf16>
915// CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst_0]] {in_bounds = [true, true, true]} : memref<1x2x3xf32>, vector<1x2x3xf32>
916// CHECK: %[[V2:.+]] = arith.extf %[[V0]] : vector<1x2x3xf16> to vector<1x2x3xf32>
917// CHECK: %[[V3:.+]] = arith.addf %[[V2]], %[[V1]] : vector<1x2x3xf32>
// The write must *use* V3 (no ":.+" redefinition) so the stored value is
// actually tied to the arith.addf result.
918// CHECK: vector.transfer_write %[[V3]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<1x2x3xf32>, memref<1x2x3xf32>
919
920// -----
921
// Test input: nwc sum pooling on f32 memrefs with a 2-element window,
// dilation 2 and stride 1; 4x4x3 input reduced into the 4x2x3 `%output`.
// With stride 1 the expected output handles the whole output width at once:
// one full-width input slice per window tap (w offsets 0 and 2), chained
// through two arith.addf ops.
922func.func @pooling_nwc_sum_memref_2_2_2_1(%input: memref<4x4x3xf32>, %filter: memref<2xf32>, %output: memref<4x2x3xf32>) {
923  linalg.pooling_nwc_sum
924    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
925    ins(%input, %filter : memref<4x4x3xf32>, memref<2xf32>)
926    outs(%output : memref<4x2x3xf32>)
927  return
928}
929
930// CHECK-LABEL: func.func @pooling_nwc_sum_memref_2_2_2_1
931// CHECK-SAME: (%[[INPUT:.+]]: memref<4x4x3xf32>, %[[FILTER:.+]]: memref<2xf32>, %[[OUTPUT:.+]]: memref<4x2x3xf32>)
932// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
933// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
934// CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x4x3xf32>, vector<4x4x3xf32>
935// CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
936// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
937// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
938// CHECK: %[[V4:.+]] = arith.addf %[[V2]], %[[V1]] : vector<4x2x3xf32>
939// CHECK: %[[V5:.+]] = arith.addf %[[V3]], %[[V4]] : vector<4x2x3xf32>
// The write must *use* V5 (no ":.+" redefinition) so the stored value is
// actually tied to the last arith.addf result.
940// CHECK: vector.transfer_write %[[V5]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
941
942
943// -----
944
// Test input: ncw sum pooling on f32 memrefs with a 2-element window,
// dilation 2 and stride 3; 4x3x6 input reduced into the 4x3x2 `%output`.
// The expected output transposes input and output to nwc ([0, 2, 1]), slices
// the input at w offsets 0/3 and 2/5, accumulates per output column, then
// transposes the result back before writing it.
945func.func @pooling_ncw_sum_memref_2_2_2_3(%input: memref<4x3x6xf32>, %filter: memref<2xf32>, %output: memref<4x3x2xf32>) {
946  linalg.pooling_ncw_sum
947    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
948    ins(%input, %filter : memref<4x3x6xf32>, memref<2xf32>)
949    outs(%output : memref<4x3x2xf32>)
950  return
951}
952
953// CHECK-LABEL: func.func @pooling_ncw_sum_memref_2_2_2_3
954// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x6xf32>, %[[FILTER:.+]]: memref<2xf32>, %[[OUTPUT:.+]]: memref<4x3x2xf32>)
955// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
956// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
957// CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x6xf32>, vector<4x3x6xf32>
958// CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x2xf32>, vector<4x3x2xf32>
959// CHECK: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x3x6xf32> to vector<4x6x3xf32>
960// CHECK: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
961// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
962// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
963// CHECK: %[[V6:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
964// CHECK: %[[V7:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
965// CHECK: %[[V8:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
966// CHECK: %[[V9:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
967// CHECK: %[[V10:.+]] = arith.addf %[[V4]], %[[V8]] : vector<4x1x3xf32>
968// CHECK: %[[V11:.+]] = arith.addf %[[V5]], %[[V9]] : vector<4x1x3xf32>
969// CHECK: %[[V12:.+]] = arith.addf %[[V6]], %[[V10]] : vector<4x1x3xf32>
970// CHECK: %[[V13:.+]] = arith.addf %[[V7]], %[[V11]] : vector<4x1x3xf32>
971// CHECK: %[[V14:.+]] = vector.insert_strided_slice %[[V12]], %[[V3]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
972// CHECK: %[[V15:.+]] = vector.insert_strided_slice %[[V13]], %[[V14]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
973// CHECK: %[[V16:.+]] = vector.transpose %[[V15]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
// The write must *use* V16 (no ":.+" redefinition) so the stored value is
// actually tied to the transposed result.
974// CHECK: vector.transfer_write %[[V16]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x3x2xf32>, memref<4x3x2xf32>
975
976// -----
977
/// 1-D sum pooling over a memref<4x2x5xf32> input in NCW layout, with a
/// 2-element window (shape of %filter), dilation 2 and stride 1, accumulating
/// into a memref<4x2x3xf32> output.
/// NOTE(review): the `2_3_2_1` suffix presumably encodes window size / output
/// width / dilation / stride -- confirm against the sibling tests.
978func.func @pooling_ncw_sum_memref_2_3_2_1(%input: memref<4x2x5xf32>, %filter: memref<2xf32>, %output: memref<4x2x3xf32>) {
979  linalg.pooling_ncw_sum
980    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
981    ins(%input, %filter : memref<4x2x5xf32>, memref<2xf32>)
982    outs(%output : memref<4x2x3xf32>)
983  return
984}
985
986// CHECK-LABEL: func.func @pooling_ncw_sum_memref_2_3_2_1
987// CHECK-SAME: (%[[INPUT:.+]]: memref<4x2x5xf32>, %[[FILTER:.+]]: memref<2xf32>, %[[OUTPUT:.+]]: memref<4x2x3xf32>)
988// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
989// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
/// Input and output are each read whole; the filter is never read.
990// CHECK: %[[V0:.+]] = vector.transfer_read %[[INPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x5xf32>, vector<4x2x5xf32>
991// CHECK: %[[V1:.+]] = vector.transfer_read %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
/// Both vectors are transposed with permutation [0, 2, 1] before slicing.
992// CHECK: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x2x5xf32> to vector<4x5x2xf32>
993// CHECK: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
/// With stride 1 each window element covers all three output positions in a
/// single slice: offset 0 for the first element, offset 2 (dilation) for the
/// second.
994// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 3, 2], strides = [1, 1, 1]} : vector<4x5x2xf32> to vector<4x3x2xf32>
995// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 2, 0], sizes = [4, 3, 2], strides = [1, 1, 1]} : vector<4x5x2xf32> to vector<4x3x2xf32>
/// Accumulate both window elements onto the transposed output vector.
996// CHECK: %[[V6:.+]] = arith.addf %[[V4]], %[[V3]] : vector<4x3x2xf32>
997// CHECK: %[[V7:.+]] = arith.addf %[[V5]], %[[V6]] : vector<4x3x2xf32>
998// CHECK: %[[V8:.+]] = vector.transpose %[[V7]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
/// Use (not re-capture) V8 so the stored value is verified to be the
/// transposed accumulation result.
999// CHECK: vector.transfer_write %[[V8]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
1000