// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s

// Check that the im2col patterns are properly connected with the
// transform dialect.

// Non static shapes are not supported.
// Check that we emit an error.
// TODO: Hook up the rewriter errors in transform dialect.
func.func @conv_non_static(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
    // expected-note@below {{when applied to this op}}
    %0 = linalg.conv_2d_nhwc_hwcf
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
       ins(%arg0, %arg1: tensor<?x?x?x?xf32>, tensor<3x3x4x16xf32>)
      outs(%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
    return %0 : tensor<?x?x?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{failed to apply}}
    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// Check that we get the proper handles for the img2col tensor producer
// and the final instruction.

// CHECK: IR printer: tensor_producer
// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)

// Collapsed indices.
// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index

// Compute input channel/convolved indices.
// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]

// Extract from the input tensor.
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32

// CHECK: IR printer: transformed
// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
//      CHECK: @conv_16433136
//      CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
//      CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
//      CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
//  CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
//  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
//      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
//      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
//           CHECK-SAME: #[[MAP0]]
//                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
//                CHECK: linalg.yield %{{.+}} : f32
//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
//           CHECK-SAME: #[[MAP1]]
//           CHECK-SAME: #[[MAP2]]
//           CHECK-SAME: #[[MAP3]]
//           CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<36x16xf32>)
//           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
//                CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
//                CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
//                CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
//                CHECK:     linalg.yield %[[ADD]] : f32
//                CHECK: } -> tensor<1x196x16xf32>
//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
//      CHECK: return %[[RESULT]]

func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
    %0 = linalg.conv_2d_nhwc_hwcf
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
       ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
      outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
    return %0 : tensor<1x14x14x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
    transform.print %transformed {name = "transformed"}: !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>
// CHECK: @depthwise_conv_hwc_114x16x3
// CHECK-SAME: %[[INPUT:.+]]: tensor<1x114x114x16xf32>
// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x16xf32>
// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x112x112x16xf32>
//      CHECK: %[[INPUT_T_INIT:.+]] = tensor.empty() : tensor<1x16x114x114xf32>
//      CHECK: %[[INPUT_T:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INPUT]] : tensor<1x114x114x16xf32>) outs(%[[INPUT_T_INIT]] : tensor<1x16x114x114xf32>) {
// CHECK-NEXT: ^bb0(%[[ARG3:.+]]: f32, %[[ARG4:.+]]: f32):
// CHECK-NEXT:     linalg.yield %[[ARG3]] : f32
// CHECK-NEXT:  } -> tensor<1x16x114x114xf32>
//      CHECK: %[[FILTER_T_INIT:.+]] = tensor.empty() : tensor<16x3x3xf32>
//      CHECK: %[[FILTER_T:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[FILTER]] : tensor<3x3x16xf32>) outs(%[[FILTER_T_INIT]] : tensor<16x3x3xf32>) {
// CHECK-NEXT:      ^bb0(%{{.*}}: f32, %{{.*}}: f32):
//      CHECK:      linalg.yield
//      CHECK:    } -> tensor<16x3x3xf32>
//      CHECK: %[[INIT_OUTPUT_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112xf32>
//      CHECK: %[[OUTPUT_T:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[OUTPUT]] : tensor<1x112x112x16xf32>) outs(%[[INIT_OUTPUT_TENSOR]] : tensor<1x16x112x112xf32>) {
// CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32):
// CHECK-NEXT:     linalg.yield
// CHECK-NEXT:  } -> tensor<1x16x112x112xf32>
//      CHECK:  %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112x3x3xf32>
//      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP5]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:   ins(%[[INPUT_T]] : tensor<1x16x114x114xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x16x112x112x3x3xf32>) {
// CHECK-NEXT:      ^bb0(%{{.*}}: f32, %{{.*}}: f32):
// CHECK-NEXT:         linalg.yield
// CHECK-NEXT:    } -> tensor<1x16x112x112x3x3xf32>
//      CHECK: %[[COL_TENSOR_R:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
// CHECK-SAME:    tensor<1x16x112x112x3x3xf32> into tensor<16x12544x9xf32>
//      CHECK: %[[FILTER_T_R:.+]] = tensor.collapse_shape %[[FILTER_T]]
// CHECK-SAME:    tensor<16x3x3xf32> into tensor<16x9xf32>
//      CHECK: %[[OUTPUT_T_R:.+]] = tensor.collapse_shape %[[OUTPUT_T]]
// CHECK-SAME:    tensor<1x16x112x112xf32> into tensor<16x12544xf32>
//      CHECK: %[[BMV_RESULT:.+]] = linalg.batch_matvec ins(%[[COL_TENSOR_R]], %[[FILTER_T_R]] : tensor<16x12544x9xf32>, tensor<16x9xf32>) outs(%[[OUTPUT_T_R]] : tensor<16x12544xf32>) -> tensor<16x12544xf32>
//      CHECK: %[[RESULT_R:.+]] = tensor.expand_shape %[[BMV_RESULT]]
// CHECK-SAME:    tensor<16x12544xf32> into tensor<1x16x112x112xf32>
//      CHECK: %[[RESULT_INIT:.+]] = tensor.empty() : tensor<1x112x112x16xf32>
//      CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP6]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[RESULT_R]] : tensor<1x16x112x112xf32>) outs(%[[RESULT_INIT]] : tensor<1x112x112x16xf32>) {
// CHECK-NEXT:      ^bb0(%{{.*}}: f32, %{{.*}}: f32):
// CHECK-NEXT:      linalg.yield
// CHECK-NEXT:    } -> tensor<1x112x112x16xf32>
//      CHECK: return %[[RESULT]] : tensor<1x112x112x16xf32>
func.func @depthwise_conv_hwc_114x16x3(%input: tensor<1x114x114x16xf32>, %filter: tensor<3x3x16xf32>, %output: tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32> {
    %0 = linalg.depthwise_conv_2d_nhwc_hwc {
      dilations = dense<1> : tensor<2xi64>,
      strides = dense<1> : tensor<2xi64>
    } ins(%input, %filter : tensor<1x114x114x16xf32>, tensor<3x3x16xf32>) outs(%output : tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32>
    return %0 : tensor<1x112x112x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.depthwise_conv_2d_nhwc_hwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//  CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
//  CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
//  CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

//      CHECK: func.func @batch_nhwc_conv
// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x16x16x4xf32>, %[[FILTER:.+]]: tensor<3x3x4x16xf32>, %[[INIT:.+]]: tensor<8x14x14x16xf32>)
//  CHECK-DAG:   %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
//  CHECK-DAG:   %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
//      CHECK:   %[[IT:.+]] = tensor.empty() : tensor<8x196x36xf32>
//      CHECK:   %[[IMG2COL:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[MAP]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME:   outs(%[[IT]] : tensor<8x196x36xf32>)
//      CHECK:   %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME:   ins(%[[IMG2COL]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>)
// CHECK-SAME:   outs(%[[CS_RESULT]] : tensor<8x196x16xf32>)
//      CHECK:   ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
//      CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
//      CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
//      CHECK:     linalg.yield %[[ADD]] : f32
//      CHECK:   } -> tensor<8x196x16xf32>
//      CHECK:   %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] output_shape [8, 14, 14, 16] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32>
//      CHECK:   return %[[CS_FINAL]]
func.func @batch_nhwc_conv(%arg0: tensor<8x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> {
    %0 = linalg.conv_2d_nhwc_hwcf
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
       ins(%arg0, %arg1: tensor<8x16x16x4xf32>, tensor<3x3x4x16xf32>)
      outs(%arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32>
    return %0 : tensor<8x14x14x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

//  Im2col maps
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 floordiv 9)>
//  CHECK-DAG: #[[MAP7:.+]] = affine_map<()[s0, s1] -> (s0 floordiv 14 + (s1 mod 9) floordiv 3)>
//  CHECK-DAG: #[[MAP8:.+]] = affine_map<()[s0, s1] -> (s0 + s1 - (s0 floordiv 14) * 14 - (s1 floordiv 3) * 3)>


//  CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
//  CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
//  CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

//      CHECK: func.func @batch_nchw_conv
// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.+]]: tensor<16x4x3x3xf32>, %[[INIT:.+]]: tensor<8x16x14x14xf32>)
//  CHECK-DAG:   %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32>
//  CHECK-DAG:   %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
//      CHECK:   %[[IT:.+]] = tensor.empty() : tensor<8x36x196xf32>
//      CHECK:   %[[IMG2COL:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[MAP]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME:   outs(%[[IT]] : tensor<8x36x196xf32>)
//      Collapsed indices.
//      CHECK:       %[[BINDEX:.+]] = linalg.index 0 : index
//      CHECK:       %[[KINDEX:.+]] = linalg.index 1 : index
//      CHECK:       %[[NINDEX:.+]] = linalg.index 2 : index

//      Compute input channel/convolved indices.
//      CHECK:       %[[ICINDEX:.+]] = affine.apply #[[MAP1]]()[%[[KINDEX]]]
//      CHECK:       %[[CONVH:.+]] = affine.apply #[[MAP7]]()[%[[NINDEX]], %[[KINDEX]]]
//      CHECK:       %[[CONVW:.+]] = affine.apply #[[MAP8]]()[%[[NINDEX]], %[[KINDEX]]]

//      Extract from the input tensor.
//      CHECK:       %[[EXTRACTED_INPUT:.+]] = tensor.extract
//      CHECK-SAME:  %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32>
//      CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
//      CHECK:   %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME:   ins(%[[CS_FILTER]], %[[IMG2COL]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
// CHECK-SAME:   outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
//      CHECK:   ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
//      CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
//      CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
//      CHECK:     linalg.yield %[[ADD]] : f32
//      CHECK:   } -> tensor<8x16x196xf32>
//      CHECK:   %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] output_shape [8, 16, 14, 14] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
//      CHECK:   return %[[CS_FINAL]]
func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
    %0 = linalg.conv_2d_nchw_fchw
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
       ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
      outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
    return %0 : tensor<8x16x14x14xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// CHECK: IR printer: tensor_producer
// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)

// Collapsed indices.
// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index

// Compute input channel/convolved indices.
// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]

// Extract from the input tensor.
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32

// CHECK: IR printer: transformed
// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
//      CHECK: @conv_2d_nhwc_fhwc
//      CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
//      CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
//      CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
//  CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
//  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
//      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
//      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
//           CHECK-SAME: #[[MAP0]]
//                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
//                CHECK: linalg.yield %{{.+}} : f32
//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
//           CHECK-SAME: #[[MAP1]]
//           CHECK-SAME: #[[MAP2]]
//           CHECK-SAME: #[[MAP3]]
//           CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
//           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
//                CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
//                CHECK:     %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
//                CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
//                CHECK:     linalg.yield %[[ADD]] : f32
//                CHECK: } -> tensor<1x196x16xf32>
//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
//      CHECK: return %[[RESULT]]

func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
    %0 = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
       ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
      outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
    return %0 : tensor<1x14x14x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
    transform.print %transformed {name = "transformed"}: !transform.any_op
    transform.yield
  }
}

// -----

// Check for signed extend when the input type is smaller than the accumulator type.

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
//      CHECK: @conv_integer_extend
//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
//           CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xi8>, tensor<36x16xi8>)
//           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xi32>)
//                CHECK: ^bb0(%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8, %[[ARG2:.+]]: i32)
//                CHECK:     %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
//                CHECK:     %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
//                CHECK:     %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32
//                CHECK:     %[[ADD:.+]] = arith.addi %[[MUL]], %[[ARG2]] : i32
//                CHECK:     linalg.yield %[[ADD]] : i32
//                CHECK: } -> tensor<1x196x16xi32>
//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xi32> into tensor<1x14x14x16xi32>
//      CHECK: return %[[RESULT]]

func.func @conv_integer_extend(%arg0: tensor<1x16x16x4xi8>, %arg1: tensor<3x3x4x16xi8>, %arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32> {
    %0 = linalg.conv_2d_nhwc_hwcf
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
       ins(%arg0, %arg1: tensor<1x16x16x4xi8>, tensor<3x3x4x16xi8>)
      outs(%arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32>
    return %0 : tensor<1x14x14x16xi32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
    transform.print %transformed {name = "transformed"}: !transform.any_op
    transform.yield
  }
}

// -----

// Check for compatible complex case.

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
//      CHECK: @conv_complex
//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
//           CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xcomplex<f32>>, tensor<36x16xcomplex<f32>>)
//           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xcomplex<f32>>)
//                CHECK: ^bb0(%[[ARG0:.+]]: complex<f32>, %[[ARG1:.+]]: complex<f32>, %[[ARG2:.+]]: complex<f32>)
//                CHECK:     %[[MUL:.+]] = complex.mul %[[ARG0]], %[[ARG1]] : complex<f32>
//                CHECK:     %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
//                CHECK:     linalg.yield %[[ADD]] : complex<f32>
//                CHECK: } -> tensor<1x196x16xcomplex<f32>>
//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
//      CHECK: return %[[RESULT]]

func.func @conv_complex(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xcomplex<f32>>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
    %0 = linalg.conv_2d_nhwc_hwcf
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
       ins(%arg0, %arg1: tensor<1x16x16x4xcomplex<f32>>, tensor<3x3x4x16xcomplex<f32>>)
      outs(%arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>>
    return %0 : tensor<1x14x14x16xcomplex<f32>>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
    transform.print %transformed {name = "transformed"}: !transform.any_op
    transform.yield
  }
}

// -----

// Check for compatible complex extended case.

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
//      CHECK: @conv_complex_extended
//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
//           CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xcomplex<f32>>, tensor<36x16xcomplex<f16>>)
//           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xcomplex<f32>>)
//                CHECK: ^bb0(%[[ARG0:.+]]: complex<f32>, %[[ARG1:.+]]: complex<f16>, %[[ARG2:.+]]: complex<f32>)
//                CHECK:     %[[REAL:.+]] = complex.re %[[ARG1]] : complex<f16>
//                CHECK:     %[[IMAG:.+]] = complex.im %[[ARG1]] : complex<f16>
//                CHECK:     %[[REEXT:.+]] = arith.extf %[[REAL]] : f16 to f32
//                CHECK:     %[[IMEXT:.+]] = arith.extf %[[IMAG]] : f16 to f32
//                CHECK:     %[[COMPLEX:.+]] = complex.create %[[REEXT]], %[[IMEXT]] : complex<f32>
//                CHECK:     %[[MUL:.+]] = complex.mul %[[ARG0]], %[[COMPLEX]] : complex<f32>
//                CHECK:     %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
//                CHECK:     linalg.yield %[[ADD]] : complex<f32>
//                CHECK: } -> tensor<1x196x16xcomplex<f32>>
//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
//      CHECK: return %[[RESULT]]

func.func @conv_complex_extended(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xcomplex<f16>>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
    %0 = linalg.conv_2d_nhwc_hwcf
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
       ins(%arg0, %arg1: tensor<1x16x16x4xcomplex<f32>>, tensor<3x3x4x16xcomplex<f16>>)
      outs(%arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>>
    return %0 : tensor<1x14x14x16xcomplex<f32>>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
    transform.print %transformed {name = "transformed"}: !transform.any_op
    transform.yield
  }
}

// -----

// Check for compatible complex extended case.

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
//      CHECK: @conv_complex_f16_extended
//      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
//           CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xcomplex<f32>>, tensor<36x16xf16>)
//           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xcomplex<f32>>)
//                CHECK: ^bb0(%[[ARG0:.+]]: complex<f32>, %[[ARG1:.+]]: f16, %[[ARG2:.+]]: complex<f32>)
//                CHECK:     %[[EXT:.+]] = arith.extf %[[ARG1]] : f16 to f32
//                CHECK:     %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
//                CHECK:     %[[COMPLEX:.+]] = complex.create %[[EXT]], %[[ZERO]]
//                CHECK:     %[[MUL:.+]] = complex.mul %[[ARG0]], %[[COMPLEX]] : complex<f32>
//                CHECK:     %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
//                CHECK:     linalg.yield %[[ADD]] : complex<f32>
//                CHECK: } -> tensor<1x196x16xcomplex<f32>>
//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
//      CHECK: return %[[RESULT]]

func.func @conv_complex_f16_extended(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xf16>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
    %0 = linalg.conv_2d_nhwc_hwcf
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
       ins(%arg0, %arg1: tensor<1x16x16x4xcomplex<f32>>, tensor<3x3x4x16xf16>)
      outs(%arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>>
    return %0 : tensor<1x14x14x16xcomplex<f32>>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
    transform.print %transformed {name = "transformed"}: !transform.any_op
    transform.yield
  }
}
