xref: /llvm-project/mlir/test/Dialect/Linalg/transpose-conv2d.mlir (revision fcdb848596c33cf05c8b6e99296a171482719493)
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

// Exercises transform.structured.transpose_conv2d (driven by the transform
// module at the bottom of this file). Each linalg.conv_2d_nhwc_fhwc(_q) op is
// expected to be rewritten into:
//   1. a linalg.transpose of the filter with permutation [1, 2, 3, 0]
//      (8x2x2x6 -> 2x2x6x8 in every case below), and
//   2. a linalg.conv_2d_nhwc_hwcf(_q) consuming the transposed filter, with
//      the original strides/dilations attributes carried over unchanged.

// f64 elements, strides 2, dilations 1.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_f64
// CHECK-SAME: (%[[IN:.+]]: tensor<1x4x4x6xf64>, %[[FHWC:.+]]: tensor<8x2x2x6xf64>, %[[OUT:.+]]: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64> {
// CHECK-DAG:    %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x6x8xf64>
// CHECK:    %[[HWCF:.+]] = linalg.transpose ins(%[[FHWC]] : tensor<8x2x2x6xf64>) outs(%[[EMPTY]] : tensor<2x2x6x8xf64>) permutation = [1, 2, 3, 0]
// CHECK:    %[[RES:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[IN]], %[[HWCF]] : tensor<1x4x4x6xf64>, tensor<2x2x6x8xf64>) outs(%[[OUT]] : tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64>
// CHECK:    return %[[RES]] : tensor<1x2x2x8xf64>
func.func @conv_2d_nhwc_fhwc_f64(%input: tensor<1x4x4x6xf64>,
                                 %filter: tensor<8x2x2x6xf64>,
                                 %init: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64> {
  %conv = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter : tensor<1x4x4x6xf64>, tensor<8x2x2x6xf64>)
      outs(%init : tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64>
  return %conv : tensor<1x2x2x8xf64>
}

// f32 elements, strides 2, dilations 1.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32
// CHECK-SAME: (%[[IN:.+]]: tensor<1x4x4x6xf32>, %[[FHWC:.+]]: tensor<8x2x2x6xf32>, %[[OUT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
// CHECK-DAG:    %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
// CHECK:    %[[HWCF:.+]] = linalg.transpose ins(%[[FHWC]] : tensor<8x2x2x6xf32>) outs(%[[EMPTY]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
// CHECK:    %[[RES:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[IN]], %[[HWCF]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[OUT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
// CHECK:    return %[[RES]] : tensor<1x2x2x8xf32>
func.func @conv_2d_nhwc_fhwc_f32(%input: tensor<1x4x4x6xf32>,
                                 %filter: tensor<8x2x2x6xf32>,
                                 %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
  %conv = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter : tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
      outs(%init : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
  return %conv : tensor<1x2x2x8xf32>
}

// f16 elements, strides 2, dilations 1.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_f16
// CHECK-SAME: (%[[IN:.+]]: tensor<1x4x4x6xf16>, %[[FHWC:.+]]: tensor<8x2x2x6xf16>, %[[OUT:.+]]: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16> {
// CHECK-DAG:    %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x6x8xf16>
// CHECK:    %[[HWCF:.+]] = linalg.transpose ins(%[[FHWC]] : tensor<8x2x2x6xf16>) outs(%[[EMPTY]] : tensor<2x2x6x8xf16>) permutation = [1, 2, 3, 0]
// CHECK:    %[[RES:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[IN]], %[[HWCF]] : tensor<1x4x4x6xf16>, tensor<2x2x6x8xf16>) outs(%[[OUT]] : tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16>
// CHECK:    return %[[RES]] : tensor<1x2x2x8xf16>
func.func @conv_2d_nhwc_fhwc_f16(%input: tensor<1x4x4x6xf16>,
                                 %filter: tensor<8x2x2x6xf16>,
                                 %init: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16> {
  %conv = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter : tensor<1x4x4x6xf16>, tensor<8x2x2x6xf16>)
      outs(%init : tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16>
  return %conv : tensor<1x2x2x8xf16>
}

// bf16 elements, strides 2, dilations 1.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_b16
// CHECK-SAME: (%[[IN:.+]]: tensor<1x4x4x6xbf16>, %[[FHWC:.+]]: tensor<8x2x2x6xbf16>, %[[OUT:.+]]: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16> {
// CHECK-DAG:    %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x6x8xbf16>
// CHECK:    %[[HWCF:.+]] = linalg.transpose ins(%[[FHWC]] : tensor<8x2x2x6xbf16>) outs(%[[EMPTY]] : tensor<2x2x6x8xbf16>) permutation = [1, 2, 3, 0]
// CHECK:    %[[RES:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[IN]], %[[HWCF]] : tensor<1x4x4x6xbf16>, tensor<2x2x6x8xbf16>) outs(%[[OUT]] : tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16>
// CHECK:    return %[[RES]] : tensor<1x2x2x8xbf16>
func.func @conv_2d_nhwc_fhwc_b16(%input: tensor<1x4x4x6xbf16>,
                                 %filter: tensor<8x2x2x6xbf16>,
                                 %init: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16> {
  %conv = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter : tensor<1x4x4x6xbf16>, tensor<8x2x2x6xbf16>)
      outs(%init : tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16>
  return %conv : tensor<1x2x2x8xbf16>
}

// i64 elements, strides 2, dilations 1.
// The label names the full function so it cannot accidentally match another
// @conv_2d_nhwc_fhwc* function in this file.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_i64
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi64>, %[[FILTER:.+]]: tensor<8x2x2x6xi64>, %[[INIT:.+]]: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64> {
// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi64>
// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi64>) outs(%[[NEWF]] : tensor<2x2x6x8xi64>) permutation = [1, 2, 3, 0]
// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi64>, tensor<2x2x6x8xi64>) outs(%[[INIT]] : tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64>
// CHECK:    return %[[CONV]] : tensor<1x2x2x8xi64>
func.func @conv_2d_nhwc_fhwc_i64(%input: tensor<1x4x4x6xi64>, %filter: tensor<8x2x2x6xi64>, %init: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64> {
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<2> : tensor<2xi64>}
     ins (%input, %filter: tensor<1x4x4x6xi64>, tensor<8x2x2x6xi64>)
    outs (%init: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64>
  return %0 : tensor<1x2x2x8xi64>
}

// i32 elements, strides 2, dilations 1.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_i32
// CHECK-SAME: (%[[IN:.+]]: tensor<1x4x4x6xi32>, %[[FHWC:.+]]: tensor<8x2x2x6xi32>, %[[OUT:.+]]: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32> {
// CHECK-DAG:    %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x6x8xi32>
// CHECK:    %[[HWCF:.+]] = linalg.transpose ins(%[[FHWC]] : tensor<8x2x2x6xi32>) outs(%[[EMPTY]] : tensor<2x2x6x8xi32>) permutation = [1, 2, 3, 0]
// CHECK:    %[[RES:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[IN]], %[[HWCF]] : tensor<1x4x4x6xi32>, tensor<2x2x6x8xi32>) outs(%[[OUT]] : tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32>
// CHECK:    return %[[RES]] : tensor<1x2x2x8xi32>
func.func @conv_2d_nhwc_fhwc_i32(%input: tensor<1x4x4x6xi32>,
                                 %filter: tensor<8x2x2x6xi32>,
                                 %init: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32> {
  %conv = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter : tensor<1x4x4x6xi32>, tensor<8x2x2x6xi32>)
      outs(%init : tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32>
  return %conv : tensor<1x2x2x8xi32>
}

// i16 elements, strides 2, dilations 1.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_i16
// CHECK-SAME: (%[[IN:.+]]: tensor<1x4x4x6xi16>, %[[FHWC:.+]]: tensor<8x2x2x6xi16>, %[[OUT:.+]]: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16> {
// CHECK-DAG:    %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x6x8xi16>
// CHECK:    %[[HWCF:.+]] = linalg.transpose ins(%[[FHWC]] : tensor<8x2x2x6xi16>) outs(%[[EMPTY]] : tensor<2x2x6x8xi16>) permutation = [1, 2, 3, 0]
// CHECK:    %[[RES:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[IN]], %[[HWCF]] : tensor<1x4x4x6xi16>, tensor<2x2x6x8xi16>) outs(%[[OUT]] : tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16>
// CHECK:    return %[[RES]] : tensor<1x2x2x8xi16>
func.func @conv_2d_nhwc_fhwc_i16(%input: tensor<1x4x4x6xi16>,
                                 %filter: tensor<8x2x2x6xi16>,
                                 %init: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16> {
  %conv = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter : tensor<1x4x4x6xi16>, tensor<8x2x2x6xi16>)
      outs(%init : tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16>
  return %conv : tensor<1x2x2x8xi16>
}

// i8 elements, strides 2, dilations 1.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_i8
// CHECK-SAME: (%[[IN:.+]]: tensor<1x4x4x6xi8>, %[[FHWC:.+]]: tensor<8x2x2x6xi8>, %[[OUT:.+]]: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8> {
// CHECK-DAG:    %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x6x8xi8>
// CHECK:    %[[HWCF:.+]] = linalg.transpose ins(%[[FHWC]] : tensor<8x2x2x6xi8>) outs(%[[EMPTY]] : tensor<2x2x6x8xi8>) permutation = [1, 2, 3, 0]
// CHECK:    %[[RES:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[IN]], %[[HWCF]] : tensor<1x4x4x6xi8>, tensor<2x2x6x8xi8>) outs(%[[OUT]] : tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8>
// CHECK:    return %[[RES]] : tensor<1x2x2x8xi8>
func.func @conv_2d_nhwc_fhwc_i8(%input: tensor<1x4x4x6xi8>,
                                %filter: tensor<8x2x2x6xi8>,
                                %init: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8> {
  %conv = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter : tensor<1x4x4x6xi8>, tensor<8x2x2x6xi8>)
      outs(%init : tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8>
  return %conv : tensor<1x2x2x8xi8>
}

// Quantized variant: conv_2d_nhwc_fhwc_q takes two extra i32 scalar operands
// (%a, %b; presumably the input/filter zero points — confirm against the
// linalg op definition). They must be forwarded unchanged to the rewritten
// conv_2d_nhwc_hwcf_q after the filter transpose.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_q
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>, %[[A:.+]]: i32, %[[B:.+]]: i32) -> tensor<1x2x2x8xf32> {
// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]], %[[A]], %[[B]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>, i32, i32) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
// CHECK:    return %[[CONV]] : tensor<1x2x2x8xf32>
func.func @conv_2d_nhwc_fhwc_q(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>, %a: i32, %b: i32) -> tensor<1x2x2x8xf32> {
  %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>,
                                   strides = dense<2> : tensor<2xi64>}
     ins (%input, %filter, %a, %b: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>, i32, i32)
    outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
  return %0 : tensor<1x2x2x8xf32>
}

// Unit stride (dilations 1): a 2x2 window over the 4x4 input yields a 3x3
// output; strides = dense<1> must survive the rewrite.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32_unit_stride
// CHECK-SAME: (%[[IN:.+]]: tensor<1x4x4x6xf32>, %[[FHWC:.+]]: tensor<8x2x2x6xf32>, %[[OUT:.+]]: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> {
// CHECK-DAG:    %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
// CHECK:    %[[HWCF:.+]] = linalg.transpose ins(%[[FHWC]] : tensor<8x2x2x6xf32>) outs(%[[EMPTY]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
// CHECK:    %[[RES:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[IN]], %[[HWCF]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[OUT]] : tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32>
// CHECK:    return %[[RES]] : tensor<1x3x3x8xf32>
func.func @conv_2d_nhwc_fhwc_f32_unit_stride(%input: tensor<1x4x4x6xf32>,
                                             %filter: tensor<8x2x2x6xf32>,
                                             %init: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> {
  %conv = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
      ins(%input, %filter : tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
      outs(%init : tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32>
  return %conv : tensor<1x3x3x8xf32>
}

// Dilation 2 with unit stride: the dilated 2x2 window spans 3x3, so the 4x4
// input yields a 2x2 output; dilations = dense<2> must survive the rewrite.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32_2_dialation
// CHECK-SAME: (%[[IN:.+]]: tensor<1x4x4x6xf32>, %[[FHWC:.+]]: tensor<8x2x2x6xf32>, %[[OUT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
// CHECK-DAG:    %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
// CHECK:    %[[HWCF:.+]] = linalg.transpose ins(%[[FHWC]] : tensor<8x2x2x6xf32>) outs(%[[EMPTY]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
// CHECK:    %[[RES:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[IN]], %[[HWCF]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[OUT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
// CHECK:    return %[[RES]] : tensor<1x2x2x8xf32>
func.func @conv_2d_nhwc_fhwc_f32_2_dialation(%input: tensor<1x4x4x6xf32>,
                                             %filter: tensor<8x2x2x6xf32>,
                                             %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
  %conv = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
      ins(%input, %filter : tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
      outs(%init : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
  return %conv : tensor<1x2x2x8xf32>
}

// Buffer (memref) form: the transposed filter buffer comes from memref.alloc,
// the transpose writes into it in place, and the conv op produces no SSA
// result, so the function simply returns %init.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_memref
// CHECK-SAME: (%[[IN:.+]]: memref<1x4x4x6xf32>, %[[FHWC:.+]]: memref<8x2x2x6xf32>, %[[OUT:.+]]: memref<1x2x2x8xf32>) -> memref<1x2x2x8xf32> {
// CHECK-DAG:    %[[BUF:.+]] = memref.alloc() : memref<2x2x6x8xf32>
// CHECK:    linalg.transpose ins(%[[FHWC]] : memref<8x2x2x6xf32>) outs(%[[BUF]] : memref<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
// CHECK:    linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[IN]], %[[BUF]] : memref<1x4x4x6xf32>, memref<2x2x6x8xf32>) outs(%[[OUT]] : memref<1x2x2x8xf32>)
// CHECK:    return %[[OUT]] : memref<1x2x2x8xf32>
func.func @conv_2d_nhwc_fhwc_memref(%input: memref<1x4x4x6xf32>,
                                    %filter: memref<8x2x2x6xf32>,
                                    %init: memref<1x2x2x8xf32>) -> memref<1x2x2x8xf32> {
  linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter : memref<1x4x4x6xf32>, memref<8x2x2x6xf32>)
      outs(%init : memref<1x2x2x8xf32>)
  return %init : memref<1x2x2x8xf32>
}

// Transform script run by --transform-interpreter: find every FHWC convolution
// (plain and quantized) under the payload root and apply transpose_conv2d.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    // Collect all linalg.conv_2d_nhwc_fhwc and linalg.conv_2d_nhwc_fhwc_q ops.
    %fhwc_convs = transform.structured.match
        ops{["linalg.conv_2d_nhwc_fhwc", "linalg.conv_2d_nhwc_fhwc_q"]} in %root
        : (!transform.any_op) -> !transform.any_op
    // Rewrite each matched op; the resulting handle is not used afterwards.
    %hwcf_convs = transform.structured.transpose_conv2d %fhwc_convs
        : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}
