// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s

// Exercises `transform.structured.decompose` (applied to every op matched as a
// LinalgOp) and `transform.structured.decompose_interface` (applied to
// linalg.softmax); the transform script driving both is the module at the end
// of this file.
//
// Each 2-D convolution/pooling payload below has size-1 extent along one
// spatial dimension (the `1` in the operand types) and is expected to become:
// extract_slice of input, filter/window and init -> the corresponding 1-D op
// consuming exactly those slices -> insert_slice of its result back into the
// init, which is what the function returns.

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: @conv_2d_nhwc_hwcf
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?x?x?xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x1x?x?xf32>
func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_nwc_wcf
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x1x?x?xf32>
}

// CHECK-LABEL: @conv_2d_nchw_fchw
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_ncw_fcw
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x?x1x?xf32>
}

// Static shapes with stride 2; the init comes from a tensor.empty in the body
// rather than from a function argument.
// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<1x1x113x96xf32>
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x3x96xf32>
func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
  // CHECK: %[[INIT:.+]] = tensor.empty
  %init = tensor.empty() : tensor<1x1x56x96xf32>
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[INIT]]
  // CHECK: %[[SLICERES:.+]] = linalg.depthwise_conv_1d_nwc_wc
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[INIT]]
  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
     ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
    outs(%init: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
  // CHECK: return %[[RES]]
  return %0: tensor<1x1x56x96xf32>
}

// CHECK-LABEL: @conv_2d
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<1x?xf32>)
func.func @conv_2d(%input: tensor<1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<1x?xf32>) -> tensor<1x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.conv_2d
     ins (%input, %filter: tensor<1x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<1x?xf32>) -> tensor<1x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<1x?xf32>
}

// Pooling variants: the second operand is a 2-D window whose first extent is 1.

// CHECK-LABEL: @pooling_nhwc_sum
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_sum(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_sum
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x1x?x?xf32>
}

// CHECK-LABEL: @pooling_nchw_sum
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_sum
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x?x1x?xf32>
}

// CHECK-LABEL: @pooling_nhwc_max
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x1x?x?xf32>
}

// CHECK-LABEL: @pooling_nhwc_max_unsigned
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
                                         strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x1x?x?xf32>
}

// CHECK-LABEL: @pooling_nhwc_min
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x1x?x?xf32>
}

// CHECK-LABEL: @pooling_nhwc_min_unsigned
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
                                         strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x1x?x?xf32>
}

// CHECK-LABEL: @pooling_nchw_max
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_max
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICE2]]
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
  // CHECK: return %[[RES]]
  return %0 : tensor<?x?x1x?xf32>
}

// Softmax over dimension 2. The expectations below describe four generics:
// a max reduction (maxnumf), an elementwise subf + exp, a sum reduction
// (addf), and an elementwise divf writing into %dst.
func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
  return %1 : tensor<2x16x32xf32>
}

// CHECK-LABEL: func.func @softmax(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFFC00000 : f32
// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
// CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D8]] : f32
// CHECK: } -> tensor<2x16xf32>
// CHECK: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
// CHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
// CHECK: %[[D9:.+]] = math.exp %[[D8]] : f32
// CHECK: linalg.yield %[[D9]] : f32
// CHECK: } -> tensor<2x16x32xf32>
// CHECK: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
// CHECK: %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
// CHECK-SAME: "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D8]] : f32
// CHECK: } -> tensor<2x16xf32>
// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
// CHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
// CHECK: linalg.yield %[[D8]] : f32
// CHECK: } -> tensor<2x16x32xf32>
// CHECK: return %[[D7]] : tensor<2x16x32xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Decompose every payload op that matches the LinalgOp interface.
    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.decompose %0 : (!transform.any_op) -> !transform.any_op

    // linalg.softmax is matched by name and decomposed through
    // decompose_interface.
    %2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}