xref: /llvm-project/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (revision 956c0707d9098499a2682297b71f46b0a562eed9)
1// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
2
3// CHECK-LABEL: @transpose_conv2d
4func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x18x19x5xf32> {
  // Baseline case: f32 operands, stride 1x1 and an all-zero out_pad.
  // Expected rewrite: the weight (%arg1) is reversed along axis 1 and then
  // axis 2, and the result feeds a regular tosa.conv2d together with the
  // original input and bias.
  // The expected pad of 2, 2, 5, 5 equals (kernel extent - 1) for the 3x6
  // kernel dims of the 5x3x6x3 weight, which is consistent with the
  // 16x14 -> 18x19 spatial growth in the result type.
5  // CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
6  // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
7  // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2
8  // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, stride = array<i64: 1, 1>
9  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
10  return %0 : tensor<2x18x19x5xf32>
11}
12
13// -----
14
15// CHECK-LABEL: @transpose_conv2d_quantized
16
17func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) {
  // Quantized (i8 operands, i32 accumulator) variant of the unit-stride case.
  // Same expected structure as the f32 test: two weight reversals feeding a
  // tosa.conv2d with pad 2, 2, 5, 5. In addition, the conv2d must carry the
  // acc_type and the quantization_info (input_zp = -22, weight_zp = 42) that
  // were present on the original transpose_conv2d.
18  // CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
19  // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
20  // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
21  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32>
22  return %0 : tensor<2x18x19x5xi32>
23}
24
25// -----
26
27// CHECK-LABEL: @transpose_conv2d_quantized_padded
28func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) {
  // Quantized unit-stride case with a non-zero out_pad of 1, 2, 3, 4.
  // Expected rewrite: the weight is reversed along axis 1, then axis 2, and
  // the doubly reversed weight feeds tosa.conv2d. The expected conv2d pad of
  // 3, 4, 8, 9 is the zero-out_pad padding (2, 2, 5, 5 for the 3x6 kernel)
  // plus the out_pad values element-wise.
  // SSA values are captured in FileCheck variables instead of being matched
  // by their printed numbers, so the test verifies the def-use chain and is
  // not sensitive to value renumbering.
29  // CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
30  // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
31  // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2
32  // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>,
33  // CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
34  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {
35    acc_type = i32,
36    out_pad = array<i64: 1, 2, 3, 4>,
37    quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>,
38    out_shape = array<i64: -1, -1, -1, -1>,
39    stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x21x26x5xi32>
40  return %0 : tensor<2x21x26x5xi32>
41}
42
43// -----
44
45// CHECK-LABEL: @transpose_conv2d_strided
46
47func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
  // Strided case (stride = 2, 3) on f32 operands. The expectations describe
  // a rewrite into a stride-1 tosa.conv2d surrounded by data-movement ops:
  //   * weight: pad -> reshape to 6-D -> transpose -> reshape to 30x2x2x3
  //     (30 = 5 output channels * 2 * 3 stride positions) -> two reversals;
  //   * input: padded by one element on each side of both spatial dims;
  //   * result: reshape -> transpose -> reshape -> slice to 2x35x47x5, after
  //     which the original bias (%arg2) is reshaped and added, because the
  //     conv2d itself is given an all-zero 30-element bias.
  // The trailing tensor.cast only erases the static spatial dims for the
  // function's return type; none of the expectations match it.
48  // Manipulate the weight matrix to handle striding.
49  // CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
50  // CHECK-DAG: %[[TRANSV:.+]]  = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
51  // CHECK-DAG: %[[PADW:.+]]  = tosa.pad %arg1, %[[PADV]]
52  // CHECK-DAG: %[[RESW1:.+]]  = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
53  // CHECK-DAG: %[[TRANS:.+]]  = tosa.transpose %[[RESW1]], %[[TRANSV]]
54  // CHECK-DAG: %[[RESW2:.+]]  = tosa.reshape %[[TRANS]] {new_shape = array<i64: 30, 2, 2, 3>}
55  // CHECK-DAG: %[[REV1:.+]]  = tosa.reverse %[[RESW2]] {axis = 1 : i32}
56  // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
57
58  // Pad out the input matrix to handle the transpose conv.
59  // CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
60  // CHECK-DAG: %[[TRANS2:.+]]  = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
61  // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]]
62
63  // Manipulate the final shape.
64  // CHECK-DAG: %[[BIAS:.+]]  = "tosa.const"() <{value = dense<0.000000e+00> : tensor<30xf32>}
65  // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
66  // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 2, 18, 16, 2, 3, 5>}
67  // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
68  // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]]
69  // CHECK-DAG: %[[START:.*]] = tosa.const_shape  {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
70  // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape  {value = dense<[2, 35, 47, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
71  // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
72  // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2
73  // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]]
74  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
75  %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
76  return %1 : tensor<2x?x?x5xf32>
77}
78
79// -----
80
81// CHECK-LABEL: @transpose_conv2d_strided_quantized
82
83func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
  // Quantized (i8 operands, i32 accumulator) variant of the strided case
  // (stride = 2, 3). The expected op structure is the same as in the f32
  // strided test; the quantization-specific expectations are:
  //   * the weight pad carries pad_quant with input_zp = 42 (the weight_zp);
  //   * the input pad carries pad_quant with input_zp = -22 (the input_zp);
  //   * the stride-1 conv2d keeps the original conv_quant attribute and
  //     acc_type, and is given an all-zero i32 bias of 30 elements.
84  // Manipulate the weight matrix to handle striding.
85  // CHECK-DAG: %[[PADV:.+]]  = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
86  // CHECK-DAG: %[[TRANSV:.+]]  = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
87  // CHECK-DAG: %[[PADW:.+]]  = tosa.pad %arg1, %[[PADV]] {quantization_info = #tosa.pad_quant<input_zp = 42>}
88  // CHECK-DAG: %[[RESW1:.+]]  = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
89  // CHECK-DAG: %[[TRANS:.+]]  = tosa.transpose %[[RESW1]], %[[TRANSV]]
90  // CHECK-DAG: %[[RESW2:.+]]  = tosa.reshape %[[TRANS]] {new_shape = array<i64: 30, 2, 2, 3>}
91  // CHECK-DAG: %[[REV1:.+]]  = tosa.reverse %[[RESW2]] {axis = 1 : i32}
92  // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
93
94  // Pad out the input matrix to handle the transpose conv.
95  // CHECK-DAG: %[[PAD:.+]]  = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
96  // CHECK-DAG: %[[TRANS2:.+]]  = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
97  // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {quantization_info = #tosa.pad_quant<input_zp = -22>}
98
99  // Manipulate the final shape.
100  // CHECK-DAG: %[[BIAS:.+]]  = "tosa.const"() <{value = dense<0> : tensor<30xi32>}
101  // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
102  // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 2, 18, 16, 2, 3, 5>}
103  // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
104  // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]]
105  // CHECK-DAG: %[[START:.*]] = tosa.const_shape  {value = dense<0> : tensor<4xindex>}
106  // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape  {value = dense<[2, 35, 47, 5]> : tensor<4xindex>}
107  // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
108  // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2
109  // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]]
110  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
111  return %0 : tensor<2x35x47x5xi32>
112}
113
114// -----
115
116// CHECK-LABEL: @transpose_conv2d_strided_overpad
117func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
  // Quantized strided case (stride = 1, 2) with out_pad = 2, 0, 0, 1.
  // Expected rewrite, in order: the weight is padded, reshaped, transposed,
  // reshaped to 2x2x1x1 and reversed along axis 1; the input is padded by
  // one row on each side of H; a tosa.conv2d with zero padding, unit stride
  // and a zero bias consumes both; the conv result is reshaped, transposed
  // and reshaped to 1x17x2x1; a tosa.pad then adds two leading rows along H
  // (the out_pad top value) to reach 1x19x2x1; finally the reshaped original
  // bias is added.
  // NOTE(review): the explicit result pad presumably exists because this
  // out_pad cannot be absorbed by slicing the conv output; confirm against
  // the decomposition pass.
118  // CHECK-DAG: %[[WEIGHT_PAD:.+]] = tosa.const_shape  {value = dense<[0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
119  // CHECK-DAG: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
120  // CHECK-DAG: %[[INPUT_PAD:.+]] = tosa.const_shape  {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
121  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
122  // CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
123  // CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape  {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
124  // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = 93>}
125  // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
126  // CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
127  // CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]] {new_shape = array<i64: 2, 2, 1, 1>}
128  // CHECK: %[[REVERSE:.+]] = tosa.reverse %[[RESHAPE_WEIGHT_1]] {axis = 1 : i32}
129  // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = -103>}
130  // CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]]
  // The conv2d attributes are matched with a plain same-line directive in
  // the dense-array syntax, exactly as in the strided quantized test.
131  // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>, stride = array<i64: 1, 1>}
132  // CHECK: %[[RESHAPE_RESULT_0:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 1, 17, 1, 1, 2, 1>}
133  // CHECK: %[[TRANSPOSE_RESULT:.+]] = tosa.transpose %[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]]
134  // CHECK: %[[RESHAPE_RESULT_1:.+]] = tosa.reshape %[[TRANSPOSE_RESULT]] {new_shape = array<i64: 1, 17, 2, 1>}
135  // CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]]
136  // CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
137  // CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]]
138  %2 =  tosa.transpose_conv2d %arg0, %arg1, %arg2 {
139    acc_type = i32,
140    out_pad = array<i64: 2, 0, 0, 1>,
141    out_shape = array<i64: 1, -1, -1, 1>,
142    stride = array<i64: 1, 2>,
143    quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>} :
144    (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>) -> tensor<1x19x2x1xi32>
145  "func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
146}
147