// xref: /llvm-project/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir (revision f11bda78c8fc551cf3e22cd5caa4005c329b904f)
// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s

/// Test payload: depthwise 1-D convolution on i8 tensors with stride 1 and
/// dilation 1. The filter is tensor<1x3xi8>, i.e. it has a single row, so the
/// expected vectorized form for this function contains one multiply-accumulate
/// step only.
func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x3xi8>,
                                                   %filter: tensor<1x3xi8>,
                                                   %output: tensor<1x8x3xi8>) -> (tensor<1x8x3xi8>) {
  %res = linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<1> : vector<1xi64>,
    strides = dense<1> : vector<1xi64>}
    ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
    outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
  return %res : tensor<1x8x3xi8>
}

/// Transform script run by -transform-interpreter on this split of the file.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Step 1: collect every linalg.depthwise_conv_1d_nwc_wc op in the payload.
    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    // Step 2: move up to the closest isolated-from-above ancestor of the match.
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    // Step 3: vectorize the ops nested under that ancestor, with the
    // flatten_1d_depthwise_conv option enabled.
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
/// Expected IR: the 8x3 trailing dims of input and output are shape_cast into a
/// single dim of 24 elements, and the 3-element filter row is repeated 8 times
/// with vector.shuffle before being broadcast to the flattened shape.
// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor
// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x3xi8>,
// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x3xi8>,
// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> {

// CHECK-DAG:       %[[C0_IDX:.*]] = arith.constant 0 : index

/// Read the whole data in one shot.
// CHECK:           %[[V_INPUT_R:.*]] = vector.transfer_read %[[INPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
// CHECK:           %[[V_FILTER_R:.*]] = vector.transfer_read %[[FILTER]][%[[C0_IDX]], %[[C0_IDX]]]
// CHECK:           %[[V_OUTPUT_R:.*]] = vector.transfer_read %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]

// CHECK:           %[[V_FILTER_0:.*]] = vector.extract %[[V_FILTER_R]][0] : vector<3xi8> from vector<1x3xi8>

/// w == 0, kw == 0
// CHECK:           %[[SC_INPUT:.*]] = vector.shape_cast %[[V_INPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
// CHECK:           %[[SC_OUTPUT:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
// CHECK:           %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
// CHECK-SAME:        [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
// CHECK:           %[[B_FILTER:.*]] = vector.broadcast %[[SH_FILTER_0]] : vector<24xi8> to vector<1x24xi8>
// CHECK:           %[[MULI:.*]] = arith.muli %[[SC_INPUT]], %[[B_FILTER]] : vector<1x24xi8>
// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[SC_OUTPUT]] : vector<1x24xi8>

/// Write the result back in one shot.
// CHECK:           %[[SC_ADDI:.*]] = vector.shape_cast %[[ADDI]] : vector<1x24xi8> to vector<1x8x3xi8>
// CHECK:           vector.transfer_write %[[SC_ADDI]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]

// -----

/// Test payload: depthwise 1-D convolution on f32 memrefs with dilation 2 and
/// stride 1. The filter is memref<2x4xf32> (two rows); the op updates %output
/// in place and the function returns nothing.
func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x4xf32>,
                                                                %filter: memref<2x4xf32>,
                                                                %output: memref<3x2x4xf32>) {
  linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
    outs(%output : memref<3x2x4xf32>)
  return
}

/// Expected IR: two input slices (W offsets 0 and 2, one per filter row) are
/// each flattened from 3x2x4 to 3x8 and accumulated with vector.fma; the
/// 4-element filter rows are repeated twice with vector.shuffle.
//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2
//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)

//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32

/// Read the whole data in one shot.
//      CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
//      CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
//      CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]

//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>

//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xf32> from vector<2x4xf32>
//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xf32> from vector<2x4xf32>


/// w == 0, kw == 0
// CHECK:           %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xf32> to vector<3x8xf32>
// CHECK:           %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xf32> to vector<3x8xf32>
// CHECK:           %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
// CHECK-SAME:        [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
// CHECK:           %[[B_FILTER_0:.*]] = vector.broadcast %[[SH_FILTER_0]] : vector<8xf32> to vector<3x8xf32>
// CHECK:           %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>

/// w == 0, kw == 1
// CHECK:           %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xf32> to vector<3x8xf32>
// CHECK:           %[[SH_FILTER_1:.*]] = vector.shuffle %[[V_FILTER_1]], %[[V_FILTER_1]]
// CHECK-SAME:        [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
// CHECK:           %[[B_FILTER_1:.*]] = vector.broadcast %[[SH_FILTER_1]] : vector<8xf32> to vector<3x8xf32>
// CHECK:           %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>

/// Write the result back in one shot.
//      CHECK:   %[[SC_FMA_1:.*]] = vector.shape_cast %[[FMA_1]] : vector<3x8xf32> to vector<3x2x4xf32>
//      CHECK:   vector.transfer_write %[[SC_FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]


/// Transform script run by -transform-interpreter on this split of the file.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Step 1: collect every linalg.depthwise_conv_1d_nwc_wc op in the payload.
    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    // Step 2: move up to the closest isolated-from-above ancestor of the match.
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    // Step 3: vectorize the ops nested under that ancestor, with the
    // flatten_1d_depthwise_conv option enabled.
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

/// Test payload: depthwise 1-D convolution with i8 input and filter memrefs
/// and an i32 output memref, dilation 2 and stride 1. The filter is
/// memref<2x4xi8> (two rows); the op updates %output in place.
func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2(%input: memref<3x5x4xi8>,
                                                              %filter: memref<2x4xi8>,
                                                              %output: memref<3x2x4xi32>) {
  linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
    outs(%output : memref<3x2x4xi32>)
  return
}

/// Expected IR: same flattening as the f32 case (3x2x4 -> 3x8), but the i8
/// input slices and the shuffled i8 filter rows are sign-extended to i32 with
/// arith.extsi before the arith.muli / arith.addi accumulation.
//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2
//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xi8>, %[[FILTER:[0-9a-z]+]]: memref<2x4xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xi32>)

//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index

/// Read the whole data in one shot.
//      CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
//      CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
//      CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]

//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>

//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xi8> from vector<2x4xi8>
//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xi8> from vector<2x4xi8>

/// w == 0, kw == 0
//      CHECK:  %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x8xi8>
//      CHECK:  %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xi32> to vector<3x8xi32>
//      CHECK:  %[[EXT_INPUT_0:.*]] = arith.extsi %[[SC_V_INPUT_0]] : vector<3x8xi8> to vector<3x8xi32>
//      CHECK:  %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
//      CHECK-SAME:  [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
//      CHECK:  %[[EXT_FILTER_0:.*]] = arith.extsi %[[SH_FILTER_0]] : vector<8xi8> to vector<8xi32>
//      CHECK:  %[[B_FILTER_0:.*]] = vector.broadcast %[[EXT_FILTER_0]] : vector<8xi32> to vector<3x8xi32>
//      CHECK:  %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[B_FILTER_0]] : vector<3x8xi32>
//      CHECK:  %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xi32>

/// w == 0, kw == 1
//      CHECK:  %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x8xi8>
//      CHECK:  %[[EXT_INPUT_1:.*]] = arith.extsi %[[SC_V_INPUT_1]] : vector<3x8xi8> to vector<3x8xi32>
//      CHECK:  %[[SH_FILTER_1:.*]] = vector.shuffle %[[V_FILTER_1]], %[[V_FILTER_1]]
//      CHECK-SAME:  [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
//      CHECK:  %[[EXT_FILTER_1:.*]] = arith.extsi %[[SH_FILTER_1]] : vector<8xi8> to vector<8xi32>
//      CHECK:  %[[B_FILTER_1:.*]] = vector.broadcast %[[EXT_FILTER_1]] : vector<8xi32> to vector<3x8xi32>
//      CHECK:  %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[B_FILTER_1]] : vector<3x8xi32>
//      CHECK:  %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x8xi32>

/// Write the result back in one shot.
//      CHECK:   %[[SC_ADD_1:.*]] = vector.shape_cast %[[ADD_1]] : vector<3x8xi32> to vector<3x2x4xi32>
//      CHECK:   vector.transfer_write %[[SC_ADD_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]

/// Transform script run by -transform-interpreter on this split of the file.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Step 1: collect every linalg.depthwise_conv_1d_nwc_wc op in the payload.
    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    // Step 2: move up to the closest isolated-from-above ancestor of the match.
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    // Step 3: vectorize the ops nested under that ancestor, with the
    // flatten_1d_depthwise_conv option enabled.
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

/// Test payload: depthwise 1-D convolution on i8 tensors with stride 2 and
/// dilation 1. The filter is tensor<3x4xi8> (three rows) and the output is
/// tensor<3x3x4xi8>, so the expected vectorized form has 3 x 3 = 9
/// multiply-accumulate steps.
func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2(%input: tensor<3x9x4xi8>,
                                                            %filter: tensor<3x4xi8>,
                                                            %output: tensor<3x3x4xi8>) -> tensor<3x3x4xi8> {
  %res = linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<1> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
    ins(%input, %filter : tensor<3x9x4xi8>, tensor<3x4xi8>)
    outs(%output : tensor<3x3x4xi8>) -> tensor<3x3x4xi8>
  return %res : tensor<3x3x4xi8>
}
/// Expected IR: with stride 2 the input is sliced one W position at a time.
/// Steps are emitted with kw as the outer index and w (0..2) as the inner one;
/// the input slice for step (w, kw) starts at W offset 2 * w + kw, and each
/// step accumulates into the running value for output position w.
// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2
// CHECK-SAME:      %[[INPUT:.*]]: tensor<3x9x4xi8>,
// CHECK-SAME:      %[[FILTER:.*]]: tensor<3x4xi8>,
// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<3x3x4xi8>) -> tensor<3x3x4xi8> {

// CHECK-DAG:           %[[C0_IDX:.*]] = arith.constant 0 : index
// CHECK-DAG:           %[[C0_I8:.*]] = arith.constant 0 : i8

/// Read the whole data in one shot.
// CHECK:           %[[V_INPUT_R:.*]] = vector.transfer_read %[[INPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]], %[[C0_I8]]
// CHECK:           %[[V_FILTER_R:.*]] = vector.transfer_read %[[FILTER]][%[[C0_IDX]], %[[C0_IDX]]], %[[C0_I8]]
// CHECK:           %[[V_OUTPUT_R:.*]] = vector.transfer_read %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]], %[[C0_I8]]

// CHECK:           %[[V_INPUT_0:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:        {offsets = [0, 0, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[V_INPUT_1:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:        {offsets = [0, 2, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[V_INPUT_2:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:        {offsets = [0, 4, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[V_INPUT_3:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:        {offsets = [0, 1, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[V_INPUT_4:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:        {offsets = [0, 3, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[V_INPUT_5:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:        {offsets = [0, 5, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[V_INPUT_6:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:        {offsets = [0, 2, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[V_INPUT_7:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:        {offsets = [0, 4, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[V_INPUT_8:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME:        {offsets = [0, 6, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>

// CHECK:           %[[V_FILTER_0:.*]] = vector.extract %[[V_FILTER_R]][0] : vector<4xi8> from vector<3x4xi8>
// CHECK:           %[[V_FILTER_1:.*]] = vector.extract %[[V_FILTER_R]][1] : vector<4xi8> from vector<3x4xi8>
// CHECK:           %[[V_FILTER_2:.*]] = vector.extract %[[V_FILTER_R]][2] : vector<4xi8> from vector<3x4xi8>

// CHECK:           %[[V_OUTPUT_0:.*]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
// CHECK-SAME:        {offsets = [0, 0, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x3x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[V_OUTPUT_1:.*]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
// CHECK-SAME:        {offsets = [0, 1, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x3x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[V_OUTPUT_2:.*]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
// CHECK-SAME:        {offsets = [0, 2, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x3x4xi8> to vector<3x1x4xi8>

/// w == 0, kw == 0
// CHECK:           %[[VAL_23:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK:           %[[VAL_24:.*]] = vector.shape_cast %[[V_OUTPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK:           %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
// CHECK:           %[[VAL_27:.*]] = arith.muli %[[VAL_23]], %[[B_FILTER_0]] : vector<3x4xi8>
// CHECK:           %[[VAL_28:.*]] = arith.addi %[[VAL_27]], %[[VAL_24]] : vector<3x4xi8>

/// w == 1, kw == 0
// CHECK:           %[[VAL_29:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK:           %[[VAL_30:.*]] = vector.shape_cast %[[V_OUTPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK:           %[[B_FILTER_0_1:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
// CHECK:           %[[VAL_33:.*]] = arith.muli %[[VAL_29]], %[[B_FILTER_0_1]] : vector<3x4xi8>
// CHECK:           %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_30]] : vector<3x4xi8>

/// w == 2, kw == 0
// CHECK:           %[[VAL_35:.*]] = vector.shape_cast %[[V_INPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK:           %[[VAL_36:.*]] = vector.shape_cast %[[V_OUTPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK:           %[[B_FILTER_0_2:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
// CHECK:           %[[VAL_39:.*]] = arith.muli %[[VAL_35]], %[[B_FILTER_0_2]] : vector<3x4xi8>
// CHECK:           %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_36]] : vector<3x4xi8>

/// w == 0, kw == 1
// CHECK:           %[[VAL_41:.*]] = vector.shape_cast %[[V_INPUT_3]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK:           %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
// CHECK:           %[[VAL_44:.*]] = arith.muli %[[VAL_41]], %[[B_FILTER_1]] : vector<3x4xi8>
// CHECK:           %[[VAL_45:.*]] = arith.addi %[[VAL_44]], %[[VAL_28]] : vector<3x4xi8>

/// w == 1, kw == 1
// CHECK:           %[[VAL_46:.*]] = vector.shape_cast %[[V_INPUT_4]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK:           %[[B_FILTER_1_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
// CHECK:           %[[VAL_49:.*]] = arith.muli %[[VAL_46]], %[[B_FILTER_1_1]] : vector<3x4xi8>
// CHECK:           %[[VAL_50:.*]] = arith.addi %[[VAL_49]], %[[VAL_34]] : vector<3x4xi8>

/// w == 2, kw == 1
// CHECK:           %[[VAL_51:.*]] = vector.shape_cast %[[V_INPUT_5]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK:           %[[B_FILTER_1_2:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
// CHECK:           %[[VAL_54:.*]] = arith.muli %[[VAL_51]], %[[B_FILTER_1_2]] : vector<3x4xi8>
// CHECK:           %[[VAL_55:.*]] = arith.addi %[[VAL_54]], %[[VAL_40]] : vector<3x4xi8>

/// w == 0, kw == 2
// CHECK:           %[[VAL_56:.*]] = vector.shape_cast %[[V_INPUT_6]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK:           %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
// CHECK:           %[[VAL_59:.*]] = arith.muli %[[VAL_56]], %[[B_FILTER_2]] : vector<3x4xi8>
// CHECK:           %[[VAL_60:.*]] = arith.addi %[[VAL_59]], %[[VAL_45]] : vector<3x4xi8>

/// w == 1, kw == 2
/// (The first shape_cast below restores the unit W dim of the finished
/// w == 0 result; it is emitted before the w == 1 computation.)
// CHECK:           %[[VAL_61:.*]] = vector.shape_cast %[[VAL_60]] : vector<3x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[VAL_62:.*]] = vector.shape_cast %[[V_INPUT_7]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK:           %[[B_FILTER_2_1:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
// CHECK:           %[[VAL_65:.*]] = arith.muli %[[VAL_62]], %[[B_FILTER_2_1]] : vector<3x4xi8>
// CHECK:           %[[VAL_66:.*]] = arith.addi %[[VAL_65]], %[[VAL_50]] : vector<3x4xi8>

/// w == 2, kw == 2
/// (The first shape_cast below restores the unit W dim of the finished
/// w == 1 result.)
// CHECK:           %[[VAL_67:.*]] = vector.shape_cast %[[VAL_66]] : vector<3x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[VAL_68:.*]] = vector.shape_cast %[[V_INPUT_8]] : vector<3x1x4xi8> to vector<3x4xi8>
// CHECK:           %[[B_FILTER_2_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
// CHECK:           %[[VAL_71:.*]] = arith.muli %[[VAL_68]], %[[B_FILTER_2_2]] : vector<3x4xi8>
// CHECK:           %[[VAL_72:.*]] = arith.addi %[[VAL_71]], %[[VAL_55]] : vector<3x4xi8>

/// Write the result back.
// CHECK:           %[[VAL_73:.*]] = vector.shape_cast %[[VAL_72]] : vector<3x4xi8> to vector<3x1x4xi8>
// CHECK:           %[[VAL_74:.*]] = vector.insert_strided_slice %[[VAL_61]], %[[V_OUTPUT_R]]
// CHECK-SAME:        {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x1x4xi8> into vector<3x3x4xi8>
// CHECK:           %[[VAL_75:.*]] = vector.insert_strided_slice %[[VAL_67]], %[[VAL_74]]
// CHECK-SAME:        {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<3x1x4xi8> into vector<3x3x4xi8>
// CHECK:           %[[VAL_76:.*]] = vector.insert_strided_slice %[[VAL_73]], %[[VAL_75]]
// CHECK-SAME:        {offsets = [0, 2, 0], strides = [1, 1, 1]} : vector<3x1x4xi8> into vector<3x3x4xi8>
// CHECK:           %[[VAL_77:.*]] = vector.transfer_write %[[VAL_76]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]

/// Transform script run by -transform-interpreter on this split of the file.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Step 1: collect every linalg.depthwise_conv_1d_nwc_wc op in the payload.
    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    // Step 2: move up to the closest isolated-from-above ancestor of the match.
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    // Step 3: vectorize the ops nested under that ancestor, with the
    // flatten_1d_depthwise_conv option enabled.
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
