// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s

// Check that the im2col patterns are properly connected with the
// transform dialect.

// Non static shapes are not supported.
// Check that we emit an error.
// TODO: Hook up the rewriter errors in transform dialect.
func.func @conv_non_static(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  // expected-note@below {{when applied to this op}}
  %0 = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
     ins(%arg0, %arg1: tensor<?x?x?x?xf32>, tensor<3x3x4x16xf32>)
    outs(%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  return %0 : tensor<?x?x?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{failed to apply}}
    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// Check that we get the proper handles for the img2col tensor producer
// and the final instruction.

// CHECK: IR printer: tensor_producer
// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)

// Collapsed indices.
// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index

// Compute input channel/convolved indices.
// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]

// Extract from the input tensor.
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32

// CHECK: IR printer: transformed
// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: @conv_16433136
// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: #[[MAP0]]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
// CHECK: linalg.yield %{{.+}} : f32
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
// CHECK-SAME: #[[MAP1]]
// CHECK-SAME: #[[MAP2]]
// CHECK-SAME: #[[MAP3]]
// CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<36x16xf32>)
// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<1x196x16xf32>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
// CHECK: return %[[RESULT]]

func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
  %0 = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
     ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
  return %0 : tensor<1x14x14x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
    transform.print %transformed {name = "transformed"}: !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>
// CHECK: @depthwise_conv_hwc_114x16x3
// CHECK-SAME: %[[INPUT:.+]]: tensor<1x114x114x16xf32>
// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x16xf32>
// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x112x112x16xf32>
// CHECK: %[[INPUT_T_INIT:.+]] = tensor.empty() : tensor<1x16x114x114xf32>
// CHECK: %[[INPUT_T:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INPUT]] : tensor<1x114x114x16xf32>) outs(%[[INPUT_T_INIT]] : tensor<1x16x114x114xf32>) {
// CHECK-NEXT: ^bb0(%[[ARG3:.+]]: f32, %[[ARG4:.+]]: f32):
// CHECK-NEXT: linalg.yield %[[ARG3]] : f32
// CHECK-NEXT: } -> tensor<1x16x114x114xf32>
// CHECK: %[[FILTER_T_INIT:.+]] = tensor.empty() : tensor<16x3x3xf32>
// CHECK: %[[FILTER_T:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[FILTER]] : tensor<3x3x16xf32>) outs(%[[FILTER_T_INIT]] : tensor<16x3x3xf32>) {
// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
// CHECK: linalg.yield
// CHECK: } -> tensor<16x3x3xf32>
// CHECK: %[[INIT_OUTPUT_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112xf32>
// CHECK: %[[OUTPUT_T:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[OUTPUT]] : tensor<1x112x112x16xf32>) outs(%[[INIT_OUTPUT_TENSOR]] : tensor<1x16x112x112xf32>) {
// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: } -> tensor<1x16x112x112xf32>
// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x16x112x112x3x3xf32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP5]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INPUT_T]] : tensor<1x16x114x114xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x16x112x112x3x3xf32>) {
// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: } -> tensor<1x16x112x112x3x3xf32>
// CHECK: %[[COL_TENSOR_R:.+]] = tensor.collapse_shape %[[COL_TENSOR]]
// CHECK-SAME: tensor<1x16x112x112x3x3xf32> into tensor<16x12544x9xf32>
// CHECK: %[[FILTER_T_R:.+]] = tensor.collapse_shape %[[FILTER_T]]
// CHECK-SAME: tensor<16x3x3xf32> into tensor<16x9xf32>
// CHECK: %[[OUTPUT_T_R:.+]] = tensor.collapse_shape %[[OUTPUT_T]]
// CHECK-SAME: tensor<1x16x112x112xf32> into tensor<16x12544xf32>
// CHECK: %[[BMV_RESULT:.+]] = linalg.batch_matvec ins(%[[COL_TENSOR_R]], %[[FILTER_T_R]] : tensor<16x12544x9xf32>, tensor<16x9xf32>) outs(%[[OUTPUT_T_R]] : tensor<16x12544xf32>) -> tensor<16x12544xf32>
// CHECK: %[[RESULT_R:.+]] = tensor.expand_shape %[[BMV_RESULT]]
// CHECK-SAME: tensor<16x12544xf32> into tensor<1x16x112x112xf32>
// CHECK: %[[RESULT_INIT:.+]] = tensor.empty() : tensor<1x112x112x16xf32>
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP6]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[RESULT_R]] : tensor<1x16x112x112xf32>) outs(%[[RESULT_INIT]] : tensor<1x112x112x16xf32>) {
// CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32):
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: } -> tensor<1x112x112x16xf32>
// CHECK: return %[[RESULT]] : tensor<1x112x112x16xf32>
func.func @depthwise_conv_hwc_114x16x3(%input: tensor<1x114x114x16xf32>, %filter: tensor<3x3x16xf32>, %output: tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32> {
  %0 = linalg.depthwise_conv_2d_nhwc_hwc {
    dilations = dense<1> : tensor<2xi64>,
    strides = dense<1> : tensor<2xi64>
  } ins(%input, %filter : tensor<1x114x114x16xf32>, tensor<3x3x16xf32>) outs(%output : tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32>
  return %0 : tensor<1x112x112x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.depthwise_conv_2d_nhwc_hwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

// CHECK: func.func @batch_nhwc_conv
// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x16x16x4xf32>, %[[FILTER:.+]]: tensor<3x3x4x16xf32>, %[[INIT:.+]]: tensor<8x14x14x16xf32>)
// CHECK-DAG: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
// CHECK-DAG: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x196x36xf32>
// CHECK: %[[IMG2COL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: outs(%[[IT]] : tensor<8x196x36xf32>)
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[IMG2COL]], %[[CS_FILTER]] : tensor<8x196x36xf32>, tensor<36x16xf32>)
// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x196x16xf32>)
// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<8x196x16xf32>
// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] output_shape [8, 14, 14, 16] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32>
// CHECK: return %[[CS_FINAL]]
func.func @batch_nhwc_conv(%arg0: tensor<8x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> {
  %0 = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
     ins(%arg0, %arg1: tensor<8x16x16x4xf32>, tensor<3x3x4x16xf32>)
    outs(%arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32>
  return %0 : tensor<8x14x14x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// Im2col maps
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 floordiv 9)>
// CHECK-DAG: #[[MAP7:.+]] = affine_map<()[s0, s1] -> (s0 floordiv 14 + (s1 mod 9) floordiv 3)>
// CHECK-DAG: #[[MAP8:.+]] = affine_map<()[s0, s1] -> (s0 + s1 - (s0 floordiv 14) * 14 - (s1 floordiv 3) * 3)>


// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

// CHECK: func.func @batch_nchw_conv
// CHECK-SAME: (%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.+]]: tensor<16x4x3x3xf32>, %[[INIT:.+]]: tensor<8x16x14x14xf32>)
// CHECK-DAG: %[[CS_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x4x3x3xf32> into tensor<16x36xf32>
// CHECK-DAG: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x36x196xf32>
// CHECK: %[[IMG2COL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: outs(%[[IT]] : tensor<8x36x196xf32>)
// Collapsed indices.
// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
// CHECK: %[[KINDEX:.+]] = linalg.index 1 : index
// CHECK: %[[NINDEX:.+]] = linalg.index 2 : index

// Compute input channel/convolved indices.
// CHECK: %[[ICINDEX:.+]] = affine.apply #[[MAP1]]()[%[[KINDEX]]]
// CHECK: %[[CONVH:.+]] = affine.apply #[[MAP7]]()[%[[NINDEX]], %[[KINDEX]]]
// CHECK: %[[CONVW:.+]] = affine.apply #[[MAP8]]()[%[[NINDEX]], %[[KINDEX]]]

// Extract from the input tensor.
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
// CHECK-SAME: %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32>
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[CS_FILTER]], %[[IMG2COL]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<8x16x196xf32>
// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] output_shape [8, 16, 14, 14] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
// CHECK: return %[[CS_FINAL]]
func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
  %0 = linalg.conv_2d_nchw_fchw
    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
     ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
    outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
  return %0 : tensor<8x16x14x14xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// CHECK: IR printer: tensor_producer
// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)

// Collapsed indices.
// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index

// Compute input channel/convolved indices.
// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]

// Extract from the input tensor.
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32

// CHECK: IR printer: transformed
// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: @conv_2d_nhwc_fhwc
// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
// CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
// CHECK-SAME: #[[MAP0]]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
// CHECK: linalg.yield %{{.+}} : f32
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
// CHECK-SAME: #[[MAP1]]
// CHECK-SAME: #[[MAP2]]
// CHECK-SAME: #[[MAP3]]
// CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<1x196x16xf32>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
// CHECK: return %[[RESULT]]

func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
  %0 = linalg.conv_2d_nhwc_fhwc
    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
     ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
  return %0 : tensor<1x14x14x16xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
    transform.print %transformed {name = "transformed"}: !transform.any_op
    transform.yield
  }
}

// -----

// Check for signed extend when the input type is smaller than the accumulator type.

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: @conv_integer_extend
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xi8>, tensor<36x16xi8>)
// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xi32>)
// CHECK: ^bb0(%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8, %[[ARG2:.+]]: i32)
// CHECK: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
// CHECK: %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32
// CHECK: %[[ADD:.+]] = arith.addi %[[MUL]], %[[ARG2]] : i32
// CHECK: linalg.yield %[[ADD]] : i32
// CHECK: } -> tensor<1x196x16xi32>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xi32> into tensor<1x14x14x16xi32>
// CHECK: return %[[RESULT]]

func.func @conv_integer_extend(%arg0: tensor<1x16x16x4xi8>, %arg1: tensor<3x3x4x16xi8>, %arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32> {
  %0 = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
     ins(%arg0, %arg1: tensor<1x16x16x4xi8>, tensor<3x3x4x16xi8>)
    outs(%arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32>
  return %0 : tensor<1x14x14x16xi32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
    transform.print %transformed {name = "transformed"}: !transform.any_op
    transform.yield
  }
}

// -----

// Check for compatible complex case.

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: @conv_complex
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xcomplex<f32>>, tensor<36x16xcomplex<f32>>)
// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xcomplex<f32>>)
// CHECK: ^bb0(%[[ARG0:.+]]: complex<f32>, %[[ARG1:.+]]: complex<f32>, %[[ARG2:.+]]: complex<f32>)
// CHECK: %[[MUL:.+]] = complex.mul %[[ARG0]], %[[ARG1]] : complex<f32>
// CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
// CHECK: linalg.yield %[[ADD]] : complex<f32>
// CHECK: } -> tensor<1x196x16xcomplex<f32>>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
// CHECK: return %[[RESULT]]

func.func @conv_complex(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xcomplex<f32>>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
  %0 = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
     ins(%arg0, %arg1: tensor<1x16x16x4xcomplex<f32>>, tensor<3x3x4x16xcomplex<f32>>)
    outs(%arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>>
  return %0 : tensor<1x14x14x16xcomplex<f32>>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
    transform.print %transformed {name = "transformed"}: !transform.any_op
    transform.yield
  }
}

// -----

// Check for compatible complex extended case.

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: @conv_complex_extended
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xcomplex<f32>>, tensor<36x16xcomplex<f16>>)
// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xcomplex<f32>>)
// CHECK: ^bb0(%[[ARG0:.+]]: complex<f32>, %[[ARG1:.+]]: complex<f16>, %[[ARG2:.+]]: complex<f32>)
// CHECK: %[[REAL:.+]] = complex.re %[[ARG1]] : complex<f16>
// CHECK: %[[IMAG:.+]] = complex.im %[[ARG1]] : complex<f16>
// CHECK: %[[REEXT:.+]] = arith.extf %[[REAL]] : f16 to f32
// CHECK: %[[IMEXT:.+]] = arith.extf %[[IMAG]] : f16 to f32
// CHECK: %[[COMPLEX:.+]] = complex.create %[[REEXT]], %[[IMEXT]] : complex<f32>
// CHECK: %[[MUL:.+]] = complex.mul %[[ARG0]], %[[COMPLEX]] : complex<f32>
// CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
// CHECK: linalg.yield %[[ADD]] : complex<f32>
// CHECK: } -> tensor<1x196x16xcomplex<f32>>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
// CHECK: return %[[RESULT]]

func.func @conv_complex_extended(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xcomplex<f16>>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
  %0 = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
     ins(%arg0, %arg1: tensor<1x16x16x4xcomplex<f32>>, tensor<3x3x4x16xcomplex<f16>>)
    outs(%arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>>
  return %0 : tensor<1x14x14x16xcomplex<f32>>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
    transform.print %transformed {name = "transformed"}: !transform.any_op
    transform.yield
  }
}

// -----

// Check for compatible complex extended case.

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: @conv_complex_f16_extended
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xcomplex<f32>>, tensor<36x16xf16>)
// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xcomplex<f32>>)
// CHECK: ^bb0(%[[ARG0:.+]]: complex<f32>, %[[ARG1:.+]]: f16, %[[ARG2:.+]]: complex<f32>)
// CHECK: %[[EXT:.+]] = arith.extf %[[ARG1]] : f16 to f32
// CHECK: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[COMPLEX:.+]] = complex.create %[[EXT]], %[[ZERO]]
// CHECK: %[[MUL:.+]] = complex.mul %[[ARG0]], %[[COMPLEX]] : complex<f32>
// CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
// CHECK: linalg.yield %[[ADD]] : complex<f32>
// CHECK: } -> tensor<1x196x16xcomplex<f32>>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
// CHECK: return %[[RESULT]]

func.func @conv_complex_f16_extended(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xf16>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
  %0 = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
     ins(%arg0, %arg1: tensor<1x16x16x4xcomplex<f32>>, tensor<3x3x4x16xf16>)
    outs(%arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>>
  return %0 : tensor<1x14x14x16xcomplex<f32>>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
    transform.print %transformed {name = "transformed"}: !transform.any_op
    transform.yield
  }
}