// RUN: mlir-opt %s --transform-interpreter --verify-diagnostics

// Test for structured-op matchers driven by `transform.foreach_match`: the
// matcher looks for a `linalg.fill`-initialized reduction, optionally preceded
// and/or followed by an elementwise op, and the action reports every matched
// piece as a remark that `--verify-diagnostics` checks against the
// `expected-remark` annotations in the payload functions below.

module attributes { transform.with_named_sequence } {
  // Helper matcher for the optional op before/after the reduction. Succeeds
  // when %entry is a structured op for which all of the following hold:
  //   - every dimension satisfies the `parallel` constraint;
  //   - every input satisfies the `projected_permutation` constraint;
  //   - every init satisfies the `permutation` constraint;
  //   - the number of inits is equal to 1.
  // On success, yields %entry unchanged.
  transform.named_sequence @_reduce_leading_trailing(%entry: !transform.any_op {transform.readonly})
      -> (!transform.any_op) {
    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>

    transform.match.structured %entry : !transform.any_op {
    ^bb0(%struct: !transform.any_op):
      transform.match.structured.dim %struct[all] {parallel} : !transform.any_op
      transform.match.structured.input %struct[all] {projected_permutation} : !transform.any_op
      transform.match.structured.init %struct[all] {permutation} : !transform.any_op
      %ni = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi eq %ni, %c1 : !transform.param<i64>
    }
    transform.yield %entry : !transform.any_op
  }

  // Main matcher. Succeeds when %entry is a structured op such that:
  //   - its rank is >= 2 and <= 4;
  //   - its last dimension is a reduction and all other dimensions are
  //     parallel;
  //   - it has exactly one input and exactly one init, both satisfying the
  //     `projected_permutation` constraint;
  //   - its body satisfies the `reduction_position = 0` constraint;
  //   - its init value is defined by a `linalg.fill`.
  // The ops around the reduction are optional: the op associated with input 0
  // ("leading") and the single user of result 0 ("trailing") are captured only
  // if they match @_reduce_leading_trailing; otherwise the corresponding
  // handle is left empty.
  //
  // Yields, in order: leading op, fill op, the reduction itself (%entry),
  // trailing op, rank, dimension sizes, and the elemental bitwidth of the init.
  transform.named_sequence @fill_reduce_leading_trailing(%entry: !transform.any_op {transform.readonly})
      -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
    %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
    %c4 = transform.param.constant 4 : i64 -> !transform.param<i64>

    %rk, %dms, %bw, %operand_o, %init_v, %trailing_o = transform.match.structured failures(propagate) %entry
        : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
                                  !transform.any_op, !transform.any_value, !transform.any_op) {
    ^bb0(%struct: !transform.any_op):
      // Rank must be within [2, 4].
      %rank = transform.match.structured.rank %struct : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi ge %rank, %c2 : !transform.param<i64>
      transform.match.param.cmpi le %rank, %c4 : !transform.param<i64>

      // Trailing dimension is the reduction, everything else is parallel.
      transform.match.structured.dim %struct[-1] {reduction} : !transform.any_op
      transform.match.structured.dim %struct[except(-1)] {parallel} : !transform.any_op
      %dims = transform.match.structured.dim %struct[all] : (!transform.any_op) -> !transform.param<i64>

      // Exactly one input and one init.
      %n_inputs = transform.match.structured.num_inputs %struct : (!transform.any_op) -> !transform.param<i64>
      %n_outputs = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi eq %n_inputs, %c1 : !transform.param<i64>
      transform.match.param.cmpi eq %n_outputs, %c1 : !transform.param<i64>

      transform.match.structured.input %struct[0] {projected_permutation} : !transform.any_op
      transform.match.structured.init %struct[0] {projected_permutation} : !transform.any_op
      %init = transform.match.structured.init %struct[0] : (!transform.any_op) -> !transform.any_value

      // This dance is necessary to create an empty handle if there is no single
      // user, without failing the entire match: the inner matcher propagates
      // its failure, and the enclosing sequence suppresses it.
      %trailing_optional = transform.sequence %struct : (!transform.any_op) -> !transform.any_op failures(suppress) {
      ^bb0(%struct_inner: !transform.any_op):
        %result = transform.match.structured failures(propagate) %struct_inner : (!transform.any_op) -> !transform.any_op {
        ^bb0(%struct_inner_inner: !transform.any_op):
          %result_inner = transform.match.structured.result %struct_inner_inner[0] {single} : (!transform.any_op) -> !transform.any_op
          %trailing = transform.include @_reduce_leading_trailing failures(propagate) (%result_inner) : (!transform.any_op) -> !transform.any_op
          transform.match.structured.yield %trailing : !transform.any_op
        }
        transform.yield %result: !transform.any_op
      }

      // Suppress errors as a way to implement optionality. We cannot suppress
      // them in the include because it keeps matching after "get_defining_op"
      // fails, which breaks the single-op precondition of the following ops.
      // We don't want to propagate that failure though.
      //
      // Additionally, we cannot put the sequence inside the call because its
      // first operand must be an operation handle (the verifier asserts!) and
      // there is no such handle available there.
      //
      // TODO: extend the structured matching to gracefully handle empty handles
      // or provide the suppress-errors-but-stop failure mode for includes to
      // implement optionality.
      %operand_optional = transform.sequence %struct : (!transform.any_op) -> !transform.any_op failures(suppress) {
      ^bb0(%struct_inner: !transform.any_op):
        %operand3 = transform.match.structured failures(propagate) %struct_inner : (!transform.any_op) -> !transform.any_op {
        ^bb1(%struct_inner_inner: !transform.any_op):
          %operand = transform.match.structured.input %struct_inner_inner[0] : (!transform.any_op) -> !transform.any_op
          %operand2 = transform.include @_reduce_leading_trailing failures(propagate) (%operand) : (!transform.any_op) -> !transform.any_op
          transform.match.structured.yield %operand2 : !transform.any_op
        }
        transform.yield %operand3 : !transform.any_op
      }

      %bitwidth = transform.match.structured.elemental_bitwidth %init : (!transform.any_value) -> !transform.param<i64>

      transform.match.structured.body %struct { reduction_position = 0 } : !transform.any_op
      transform.match.structured.yield %rank, %dims, %bitwidth, %operand_optional, %init, %trailing_optional
          : !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
            !transform.any_op, !transform.any_value, !transform.any_op
    }

    // The init of the reduction must be produced by a linalg.fill.
    %init_o = transform.get_defining_op %init_v : (!transform.any_value) -> !transform.any_op
    transform.match.operation_name %init_o ["linalg.fill"] : !transform.any_op

    transform.yield %operand_o, %init_o, %entry, %trailing_o, %rk, %dms, %bw
        : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
  }

  // Action paired with @fill_reduce_leading_trailing; its arguments are the
  // values yielded by that matcher, in the same order. Emits one remark at each
  // of the four op handles ("leading", "fill", "reduction", "trailing") and
  // reports the rank, dimensions and bitwidth parameters as remarks attached
  // to the reduction op.
  transform.named_sequence @print_reduce_leading_trailing(
      %leading: !transform.any_op {transform.readonly},
      %fill: !transform.any_op {transform.readonly},
      %reduction: !transform.any_op {transform.readonly},
      %trailing: !transform.any_op {transform.readonly},
      %rank: !transform.param<i64> {transform.readonly},
      %dims: !transform.param<i64> {transform.readonly},
      %bitwidth: !transform.param<i64> {transform.readonly}) {
    transform.debug.emit_remark_at %leading, "leading" : !transform.any_op
    transform.debug.emit_remark_at %fill, "fill" : !transform.any_op
    transform.debug.emit_remark_at %reduction, "reduction" : !transform.any_op
    transform.debug.emit_remark_at %trailing, "trailing" : !transform.any_op
    transform.debug.emit_param_as_remark %rank, "rank" at %reduction : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %dims, "dimensions" at %reduction : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %bitwidth, "bitwidth" at %reduction : !transform.param<i64>, !transform.any_op
    transform.yield
  }

  // Applies the matcher/action pair above to the payload under %root.
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %root
        @fill_reduce_leading_trailing -> @print_reduce_leading_trailing
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

!in_tensor_t = tensor<8x64xf32>
!out_tensor_t = tensor<8xf32>

// Elementwise op feeding the reduction; nothing consumes the reduction except
// the return, so only "leading", "fill" and the reduction remarks are expected.
func.func @eltwise_reduce(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %cst = arith.constant -0.000000e+00 : f32

  %0 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
  %2 = tensor.empty() : !in_tensor_t
  // expected-remark @below {{leading}}
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = arith.addf %arg3, %arg3 : f32
      %5 = arith.addf %4, %4 : f32
      linalg.yield %5 : f32
    } -> !in_tensor_t

  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %6 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = arith.addf %arg3, %arg4 : f32
      linalg.yield %4 : f32
    } -> !out_tensor_t

  return %6 : !out_tensor_t
}

// Reduction reading the function argument directly, followed by an elementwise
// op; "fill", the reduction remarks and "trailing" are expected, no "leading".
func.func @reduce_eltwise(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %cst = arith.constant -0.000000e+00 : f32

  %0 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %5 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%arg : !in_tensor_t) outs(%1 : !out_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = arith.addf %arg3, %arg4 : f32
      linalg.yield %4 : f32
    } -> !out_tensor_t

  %6 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{trailing}}
  %7 = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%5 : !out_tensor_t) outs(%6 : !out_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = math.sqrt %arg3 : f32
      linalg.yield %4 : f32
    } -> !out_tensor_t
  return %7 : !out_tensor_t
}

// Elementwise op, reduction, elementwise op: all of "leading", "fill", the
// reduction remarks and "trailing" are expected.
func.func @eltwise_reduce_eltwise(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %cst = arith.constant -0.000000e+00 : f32

  %0 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
  %2 = tensor.empty() : !in_tensor_t
  // expected-remark @below {{leading}}
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = arith.addf %arg3, %arg3 : f32
      %5 = arith.addf %4, %4 : f32
      linalg.yield %5 : f32
    } -> !in_tensor_t

  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %6 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = arith.addf %arg3, %arg4 : f32
      linalg.yield %4 : f32
    } -> !out_tensor_t

  %7 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{trailing}}
  %8 = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%6 : !out_tensor_t) outs(%7 : !out_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = math.sqrt %arg3 : f32
      linalg.yield %4 : f32
    } -> !out_tensor_t

  return %8 : !out_tensor_t
}

// Same chain as @eltwise_reduce_eltwise, but the leading elementwise op is
// defined before the fill instead of after it; the same remarks are expected.
func.func @eltwise_reduce_eltwise_swapped(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %cst = arith.constant -0.000000e+00 : f32

  %2 = tensor.empty() : !in_tensor_t
  // expected-remark @below {{leading}}
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = arith.addf %arg3, %arg3 : f32
      %5 = arith.addf %4, %4 : f32
      linalg.yield %5 : f32
    } -> !in_tensor_t

  %0 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %6 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = arith.addf %arg3, %arg4 : f32
      linalg.yield %4 : f32
    } -> !out_tensor_t

  %7 = tensor.empty() : !out_tensor_t
  // expected-remark @below {{trailing}}
  %8 = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%6 : !out_tensor_t) outs(%7 : !out_tensor_t) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = math.sqrt %arg3 : f32
      linalg.yield %4 : f32
    } -> !out_tensor_t

  return %8 : !out_tensor_t
}

// A fill + reduction pair next to a second linalg.fill whose result is only
// returned. Remarks are expected on the first fill and on the reduction; the
// second fill carries no expected remark.
func.func @reduction_with_extra_op_in_func(%arg0: tensor<8x479xf32>, %arg1: tensor<32x32xf32>) -> (tensor<8xf32>, tensor<32xf32>) {
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<8xf32>
  // expected-remark @below {{fill}}
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8xf32>) -> tensor<8xf32>
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 479 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %result = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%arg0 : tensor<8x479xf32>)
    outs(%fill : tensor<8xf32>) {
    ^bb0(%in: f32, %out: f32):
      %6 = arith.addf %in, %out : f32
      linalg.yield %6 : f32
    } -> tensor<8xf32>

  %empty2 = tensor.empty() : tensor<32xf32>
  %fill2 = linalg.fill ins(%cst : f32) outs(%empty2 : tensor<32xf32>) -> tensor<32xf32>
  return %result, %fill2 : tensor<8xf32>, tensor<32xf32>
}