// RUN: mlir-opt %s -transform-interpreter | FileCheck %s

// Tests for transform.structured.transpose_conv2d.
//
// Each linalg.conv_2d_nhwc_fhwc / linalg.conv_2d_nhwc_fhwc_q op below is
// expected to be rewritten into:
//   1. a new filter buffer (tensor.empty for tensors, memref.alloc for memrefs)
//      of shape HxWxCxF,
//   2. a linalg.transpose of the original FxHxWxC filter into that buffer with
//      permutation [1, 2, 3, 0], and
//   3. the equivalent linalg.conv_2d_nhwc_hwcf(_q) op consuming the transposed
//      filter, with the original strides and dilations carried over unchanged.
// The element-type variants below exercise the same rewrite for float,
// bfloat and integer element types.

// CHECK-LABEL: @conv_2d_nhwc_fhwc_f64
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf64>, %[[FILTER:.+]]: tensor<8x2x2x6xf64>, %[[INIT:.+]]: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64> {
// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf64>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf64>) outs(%[[NEWF]] : tensor<2x2x6x8xf64>) permutation = [1, 2, 3, 0]
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf64>, tensor<2x2x6x8xf64>) outs(%[[INIT]] : tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64>
// CHECK: return %[[CONV]] : tensor<1x2x2x8xf64>
func.func @conv_2d_nhwc_fhwc_f64(%input: tensor<1x4x4x6xf64>, %filter: tensor<8x2x2x6xf64>, %init: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64> {
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<2> : tensor<2xi64>}
     ins (%input, %filter: tensor<1x4x4x6xf64>, tensor<8x2x2x6xf64>)
    outs (%init: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64>
  return %0 : tensor<1x2x2x8xf64>
}

// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
// CHECK: return %[[CONV]] : tensor<1x2x2x8xf32>
func.func @conv_2d_nhwc_fhwc_f32(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<2> : tensor<2xi64>}
     ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
    outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
  return %0 : tensor<1x2x2x8xf32>
}

// CHECK-LABEL: @conv_2d_nhwc_fhwc_f16
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf16>, %[[FILTER:.+]]: tensor<8x2x2x6xf16>, %[[INIT:.+]]: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16> {
// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf16>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf16>) outs(%[[NEWF]] : tensor<2x2x6x8xf16>) permutation = [1, 2, 3, 0]
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf16>, tensor<2x2x6x8xf16>) outs(%[[INIT]] : tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16>
// CHECK: return %[[CONV]] : tensor<1x2x2x8xf16>
func.func @conv_2d_nhwc_fhwc_f16(%input: tensor<1x4x4x6xf16>, %filter: tensor<8x2x2x6xf16>, %init: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16> {
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<2> : tensor<2xi64>}
     ins (%input, %filter: tensor<1x4x4x6xf16>, tensor<8x2x2x6xf16>)
    outs (%init: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16>
  return %0 : tensor<1x2x2x8xf16>
}

// CHECK-LABEL: @conv_2d_nhwc_fhwc_bf16
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xbf16>, %[[FILTER:.+]]: tensor<8x2x2x6xbf16>, %[[INIT:.+]]: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16> {
// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xbf16>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xbf16>) outs(%[[NEWF]] : tensor<2x2x6x8xbf16>) permutation = [1, 2, 3, 0]
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xbf16>, tensor<2x2x6x8xbf16>) outs(%[[INIT]] : tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16>
// CHECK: return %[[CONV]] : tensor<1x2x2x8xbf16>
func.func @conv_2d_nhwc_fhwc_bf16(%input: tensor<1x4x4x6xbf16>, %filter: tensor<8x2x2x6xbf16>, %init: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16> {
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<2> : tensor<2xi64>}
     ins (%input, %filter: tensor<1x4x4x6xbf16>, tensor<8x2x2x6xbf16>)
    outs (%init: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16>
  return %0 : tensor<1x2x2x8xbf16>
}

// CHECK-LABEL: @conv_2d_nhwc_fhwc_i64
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi64>, %[[FILTER:.+]]: tensor<8x2x2x6xi64>, %[[INIT:.+]]: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64> {
// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi64>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi64>) outs(%[[NEWF]] : tensor<2x2x6x8xi64>) permutation = [1, 2, 3, 0]
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi64>, tensor<2x2x6x8xi64>) outs(%[[INIT]] : tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64>
// CHECK: return %[[CONV]] : tensor<1x2x2x8xi64>
func.func @conv_2d_nhwc_fhwc_i64(%input: tensor<1x4x4x6xi64>, %filter: tensor<8x2x2x6xi64>, %init: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64> {
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<2> : tensor<2xi64>}
     ins (%input, %filter: tensor<1x4x4x6xi64>, tensor<8x2x2x6xi64>)
    outs (%init: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64>
  return %0 : tensor<1x2x2x8xi64>
}

// CHECK-LABEL: @conv_2d_nhwc_fhwc_i32
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi32>, %[[FILTER:.+]]: tensor<8x2x2x6xi32>, %[[INIT:.+]]: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32> {
// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi32>) outs(%[[NEWF]] : tensor<2x2x6x8xi32>) permutation = [1, 2, 3, 0]
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi32>, tensor<2x2x6x8xi32>) outs(%[[INIT]] : tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32>
// CHECK: return %[[CONV]] : tensor<1x2x2x8xi32>
func.func @conv_2d_nhwc_fhwc_i32(%input: tensor<1x4x4x6xi32>, %filter: tensor<8x2x2x6xi32>, %init: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32> {
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<2> : tensor<2xi64>}
     ins (%input, %filter: tensor<1x4x4x6xi32>, tensor<8x2x2x6xi32>)
    outs (%init: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32>
  return %0 : tensor<1x2x2x8xi32>
}

// CHECK-LABEL: @conv_2d_nhwc_fhwc_i16
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi16>, %[[FILTER:.+]]: tensor<8x2x2x6xi16>, %[[INIT:.+]]: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16> {
// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi16>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi16>) outs(%[[NEWF]] : tensor<2x2x6x8xi16>) permutation = [1, 2, 3, 0]
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi16>, tensor<2x2x6x8xi16>) outs(%[[INIT]] : tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16>
// CHECK: return %[[CONV]] : tensor<1x2x2x8xi16>
func.func @conv_2d_nhwc_fhwc_i16(%input: tensor<1x4x4x6xi16>, %filter: tensor<8x2x2x6xi16>, %init: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16> {
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<2> : tensor<2xi64>}
     ins (%input, %filter: tensor<1x4x4x6xi16>, tensor<8x2x2x6xi16>)
    outs (%init: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16>
  return %0 : tensor<1x2x2x8xi16>
}

// CHECK-LABEL: @conv_2d_nhwc_fhwc_i8
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi8>, %[[FILTER:.+]]: tensor<8x2x2x6xi8>, %[[INIT:.+]]: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8> {
// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi8>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi8>) outs(%[[NEWF]] : tensor<2x2x6x8xi8>) permutation = [1, 2, 3, 0]
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi8>, tensor<2x2x6x8xi8>) outs(%[[INIT]] : tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8>
// CHECK: return %[[CONV]] : tensor<1x2x2x8xi8>
func.func @conv_2d_nhwc_fhwc_i8(%input: tensor<1x4x4x6xi8>, %filter: tensor<8x2x2x6xi8>, %init: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8> {
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<2> : tensor<2xi64>}
     ins (%input, %filter: tensor<1x4x4x6xi8>, tensor<8x2x2x6xi8>)
    outs (%init: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8>
  return %0 : tensor<1x2x2x8xi8>
}

// Quantized variant: the two extra i32 scalar operands (%a, %b) must be
// forwarded unchanged, in order, to linalg.conv_2d_nhwc_hwcf_q.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_q
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>, %[[A:.+]]: i32, %[[B:.+]]: i32) -> tensor<1x2x2x8xf32> {
// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]], %[[A]], %[[B]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>, i32, i32) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
// CHECK: return %[[CONV]] : tensor<1x2x2x8xf32>
func.func @conv_2d_nhwc_fhwc_q(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>, %a: i32, %b: i32) -> tensor<1x2x2x8xf32> {
  %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>,
                                   strides = dense<2> : tensor<2xi64>}
     ins (%input, %filter, %a, %b: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>, i32, i32)
    outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
  return %0 : tensor<1x2x2x8xf32>
}

// Unit stride: output spatial extent is (4 - 2) / 1 + 1 = 3.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32_unit_stride
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> {
// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32>
// CHECK: return %[[CONV]] : tensor<1x3x3x8xf32>
func.func @conv_2d_nhwc_fhwc_f32_unit_stride(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> {
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
    outs (%init: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32>
  return %0 : tensor<1x3x3x8xf32>
}

// Dilation 2: the 2-wide window spans 1 + (2 - 1) * 2 = 3 input elements, so
// the output spatial extent is 4 - 3 + 1 = 2.
// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32_2_dilation
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
// CHECK: return %[[CONV]] : tensor<1x2x2x8xf32>
func.func @conv_2d_nhwc_fhwc_f32_2_dilation(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<2> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
    outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
  return %0 : tensor<1x2x2x8xf32>
}

// Buffer semantics: the transposed filter is written into a freshly allocated
// memref, and the new convolution writes into %init in place (no SSA result).
// CHECK-LABEL: @conv_2d_nhwc_fhwc_memref
// CHECK-SAME: (%[[INPUT:.+]]: memref<1x4x4x6xf32>, %[[FILTER:.+]]: memref<8x2x2x6xf32>, %[[INIT:.+]]: memref<1x2x2x8xf32>) -> memref<1x2x2x8xf32> {
// CHECK-DAG: %[[NEWF:.+]] = memref.alloc() : memref<2x2x6x8xf32>
// CHECK: linalg.transpose ins(%[[FILTER]] : memref<8x2x2x6xf32>) outs(%[[NEWF]] : memref<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
// CHECK: linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[NEWF]] : memref<1x4x4x6xf32>, memref<2x2x6x8xf32>) outs(%[[INIT]] : memref<1x2x2x8xf32>)
// CHECK: return %[[INIT]] : memref<1x2x2x8xf32>
func.func @conv_2d_nhwc_fhwc_memref(%input: memref<1x4x4x6xf32>, %filter: memref<8x2x2x6xf32>, %init: memref<1x2x2x8xf32>) -> memref<1x2x2x8xf32> {
  linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
                            strides = dense<2> : tensor<2xi64>}
     ins (%input, %filter: memref<1x4x4x6xf32>, memref<8x2x2x6xf32>)
    outs (%init: memref<1x2x2x8xf32>)
  return %init : memref<1x2x2x8xf32>
}

// Transform script: match every linalg.conv_2d_nhwc_fhwc and
// linalg.conv_2d_nhwc_fhwc_q op in the payload and apply transpose_conv2d to
// each of them. The handle to the rewritten ops (%1) is intentionally unused.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc", "linalg.conv_2d_nhwc_fhwc_q"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}