// RUN: mlir-opt --transform-interpreter --split-input-file %s -verify-diagnostics | FileCheck %s

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_op
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_op(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
    %1 = affine.apply #map0()[%d0, %arg0]

    // CHECK: scf.forall {{.*}} {
    %2 = scf.forall (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
      %3 = affine.apply #map1(%arg3)[%arg0]
      %4 = affine.min #map2(%arg3)[%d0, %arg0]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
      // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<?xf32>
  }

  // Check no failure when nothing happens.
  func.func @dummy1() { return }
  func.func @dummy2() { return }
  func.func @dummy3() { return }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.fill">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.fill is tileable. The op is tiled and fused.
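      // The op returns two handles: one to the tiled producer inside the loop
      // and one to the containing scf.forall.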
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0] -> (64 ceildiv s0)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>

module {
  // CHECK-LABEL: func.func @fuse_untileable_op
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
  func.func @fuse_untileable_op(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
    %0 = tensor.empty(%arg0) : tensor<?xf32>
    %1 = affine.apply #map0()[%arg0]

    // CHECK: scf.forall {{.*}} {
    %2 = scf.forall (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor<64xf32>) {
      // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty
      %3 = affine.apply #map1(%arg3)[%arg0]
      %4 = affine.min #map2(%arg3)[%arg0]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>

      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]]
      %7 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
      }
    }
    // CHECK: }

    func.return %2 : tensor<64xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["tensor.empty"]} in %arg1 : (!transform.any_op) -> !transform.op<"tensor.empty">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // tensor.empty is not tileable. The op is cloned and fused.
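      // Untileable producers are cloned into the loop instead: the CHECK lines
      // above verify that the elemwise_unary now reads from a tensor.empty
      // created inside the forall.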
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

module {
  func.func @foo(%0: tensor<f32>) -> tensor<f32> {
    return %0 : tensor<f32>
  }

  // CHECK-LABEL: func.func @fuse_tileable_op_rank_reducing
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_op_rank_reducing(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>

    // CHECK: scf.forall {{.*}} -> (tensor<?xf32>) {
    %2 = scf.forall (%arg3) in (%d0) shared_outs(%o = %0) -> (tensor<?xf32>) {
      %5 = tensor.extract_slice %o[%arg3] [1] [1] : tensor<?xf32> to tensor<f32>

      // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [1] [1] : tensor<?xf32> to tensor<1xf32>
      // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<1xf32>) -> tensor<1xf32>
      // CHECK: tensor.extract_slice %{{.*}}[0] [1] [1] : tensor<1xf32> to tensor<f32>
      // CHECK: func.call @foo(%{{.*}}) : (tensor<f32>) -> tensor<f32>
      %7 = func.call @foo(%5) : (tensor<f32>) -> tensor<f32>

      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [1] [1] : tensor<f32> into tensor<?xf32>
        tensor.parallel_insert_slice %7 into %o[%arg3] [1] [1] : tensor<f32> into tensor<?xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.fill">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.fill is tileable. The op is tiled and fused.
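      // The consumer uses a rank-reducing extract_slice (tensor<?xf32> to
      // tensor<f32>). As the CHECK lines above show, the fill is tiled to
      // tensor<1xf32> and a rank-reducing slice bridges the mismatch.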
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_op_through_bbarg(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
    %1 = affine.apply #map0()[%d0, %arg0]

    // CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[OUT]]) -> (tensor<?xf32>) {
    %2 = scf.forall (%arg3) in (%1) shared_outs(%o = %0) -> (tensor<?xf32>) {
      %3 = affine.apply #map1(%arg3)[%arg0]
      %4 = affine.min #map2(%arg3)[%d0, %arg0]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
      // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
      %6 = tensor.extract_slice %arg1[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T2:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T1]]
      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op

      // linalg.fill is tileable. The op is tiled and fused.
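      // The producer reaches the loop only through the shared_outs bbarg, so
      // fusion rewrites the forall to iterate on %[[OUT]] directly and applies
      // the fill to slices of the bbarg.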
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_multi_output_op
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_multi_output_op(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    %0:2 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %d = arith.addf %a, %b : f32
      %e = arith.addf %d, %c : f32
      linalg.yield %d, %e : f32, f32
    } -> (tensor<?xf32>, tensor<?xf32>)
    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>

    %1 = affine.apply #map0()[%d0, %idx]

    // CHECK: scf.forall {{.*}} {
    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      %3 = affine.apply #map1(%i)[%idx]
      %4 = affine.min #map2(%i)[%d0, %idx]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
      // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}} ins(%[[T0]]
      %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]#0
      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.generic is tileable. The op is tiled and fused.
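      // Only result #0 of the producer is used inside the loop, but the fused
      // op still computes both results (%[[T1]]:2 above).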
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

module {
  // CHECK-LABEL: func.func @fuse_repeated
  func.func @fuse_repeated(%fill: tensor<2xf32>, %output: tensor<2xf32>) -> tensor<2xf32> {
    %c0 = arith.constant 0.0 : f32
    %0 = linalg.fill ins(%c0 : f32) outs(%fill : tensor<2xf32>) -> tensor<2xf32>

    // CHECK: scf.forall
    %1 = scf.forall (%i) in (2) shared_outs(%arg1 = %output) -> (tensor<2xf32>) {
      %2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32>
      %3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to tensor<1xf32>
      // CHECK: %[[FUSED:.+]] = linalg.fill
      // CHECK: elemwise_unary ins(%[[FUSED]]
      %4 = linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32>
      }
    }

    return %1 : tensor<2xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op

      // Create a new handle that points to `linalg.fill` twice.
      %2 = transform.merge_handles %0, %0 : !transform.any_op

      // It shouldn't be a problem to fuse this handle.
      transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_multi_output_op_multi_use
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_multi_output_op_multi_use(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
      -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    // CHECK: %[[G0:.*]]:2 = linalg.generic
    %0:2 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %d = arith.addf %a, %b : f32
      %e = arith.addf %d, %c : f32
      linalg.yield %d, %e : f32, f32
    } -> (tensor<?xf32>, tensor<?xf32>)
    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>

    %1 = affine.apply #map0()[%d0, %idx]

    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
    // expected-remark @below{{new containing op}}
    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
      %3 = affine.apply #map1(%i)[%idx]
      // CHECK: %[[I1:.*]] = affine.min {{.*}}
      %4 = affine.min #map2(%i)[%d0, %idx]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}}
      %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %[[T1]]#0 into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: return %[[R0]]#0, %[[R0]]#1, %[[G0]]#1
    func.return %2, %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
    // CHECK: }
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.generic is tileable. The op is tiled and fused.
      %fused, %containing = transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.debug.emit_remark_at %containing, "new containing op" : !transform.any_op
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_mixed_dominating_uses
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_mixed_dominating_uses(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
      -> (tensor<?xf32>, tensor<?xf32>) {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    // CHECK: %[[G0:.*]] = linalg.generic
    %0 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
    ^bb0(%a: f32, %b: f32):
      %d = arith.addf %a, %b : f32
      linalg.yield %d : f32
    } -> tensor<?xf32>
    // CHECK: %[[D0:.*]] = tensor.dim %[[G0]]
    %d0 = tensor.dim %0, %c0 : tensor<?xf32>

    %1 = affine.apply #map0()[%d0, %idx]

    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
      %3 = affine.apply #map1(%i)[%idx]
      // CHECK: %[[I1:.*]] = affine.min {{.*}}
      %4 = affine.min #map2(%i)[%d0, %idx]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: return %[[R0]]#0, %[[R0]]#1
    func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
    // CHECK: }
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.generic is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0, d1) -> (d0)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_reductions
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?x?xf32>
  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_reductions(%idx: index, %in: tensor<?x?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
      -> (tensor<?xf32>, tensor<?xf32>) {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    %0 = linalg.generic {
      indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]
    } ins(%in : tensor<?x?xf32>) outs(%out_1 : tensor<?xf32>) {
    ^bb0(%a: f32, %b: f32):
      %d = arith.maximumf %a, %b : f32
      linalg.yield %d : f32
    } -> tensor<?xf32>
    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>

    %1 = affine.apply #map0()[%d0, %idx]

    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
      %3 = affine.apply #map1(%i)[%idx]
      // CHECK: %[[I1:.*]] = affine.min {{.*}}
      %4 = affine.min #map2(%i)[%d0, %idx]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: return %[[R0]]#0, %[[R0]]#1
    func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
    // CHECK: }
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.generic is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
#map3 = affine_map<(d0) -> (d0)>

module {
  // CHECK-LABEL: func.func @fuse_tileable_using_new_handle
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_using_new_handle(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
      -> (tensor<?xf32>, tensor<?xf32>) {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    %0 = linalg.generic {
      indexing_maps = [#map3, #map3], iterator_types = ["parallel"]
    } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
    ^bb0(%a: f32, %b: f32):
      %d = arith.addf %a, %b : f32
      linalg.yield %d : f32
    } -> tensor<?xf32>

    %1 = linalg.generic {
      indexing_maps = [#map3, #map3], iterator_types = ["parallel"]
    } ins(%0 : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
    ^bb0(%a: f32, %b: f32):
      %d = arith.mulf %a, %b : f32
      linalg.yield %d : f32
    } -> tensor<?xf32>
    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>

    %2 = affine.apply #map0()[%d0, %idx]

    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
    %3 = scf.forall (%i) in (%2) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
      %4 = affine.apply #map1(%i)[%idx]
      // CHECK: %[[I1:.*]] = affine.min {{.*}}
      %5 = affine.min #map2(%i)[%d0, %idx]
      %6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
      // CHECK: %[[T2:.*]] = linalg.generic {{.*}}
      %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>

      %8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
        tensor.parallel_insert_slice %8 into %o[%4] [%5] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: return %[[R0]]#0, %[[R0]]#1
    func.return %3, %1 : tensor<?xf32>, tensor<?xf32>
    // CHECK: }
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
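      // Both linalg.generic ops match the same handle; split it so that each
      // producer can be fused individually, reusing the forall handle returned
      // by the first fusion.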
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %add, %reduce = transform.split_handle %0 : (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.generic">, !transform.op<"linalg.generic">)
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      %fused_ops, %new_forall = transform.structured.fuse_into_containing_op %reduce into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
      %fused_ops_2, %new_forall_2 = transform.structured.fuse_into_containing_op %add into %new_forall
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
      transform.yield
    }
  }
}

// -----

// This is a regression test. Make sure that the transform succeeds and valid
// IR is generated.

module {
  // CHECK-LABEL: func.func @softmax_dispatch_0_generic_16x128x128_f32
  func.func @softmax_dispatch_0_generic_16x128x128_f32() -> tensor<16x128x128xf32> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<5.000000e+00> : tensor<16x128x128xf32>
    %cst_1 = arith.constant 5.000000e+00 : f32
    %1 = tensor.empty() : tensor<16x128xf32>
    %2 = tensor.empty() : tensor<16x128x128xf32>
    %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<16x128xf32>) -> tensor<16x128xf32>
    %4 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<16x128xf32>) -> tensor<16x128xf32>
    %5 = linalg.generic {producer, indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%cst : tensor<16x128x128xf32>) outs(%4 : tensor<16x128xf32>) {
    ^bb0(%in: f32, %out: f32):
      %8 = arith.maximumf %in, %out : f32
      linalg.yield %8 : f32
    } -> tensor<16x128xf32>
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    %7 = scf.forall (%arg0, %arg1) in (16, 32) shared_outs(%arg2 = %2) -> (tensor<16x128x128xf32>) {
      %11 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg1)
      %extracted_slice = tensor.extract_slice %5[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
      %extracted_slice_3 = tensor.extract_slice %2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
      %extracted_slice_4 = tensor.extract_slice %3[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
      %15:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice : tensor<1x4xf32>) outs(%extracted_slice_3, %extracted_slice_4 : tensor<1x4x128xf32>, tensor<1x4xf32>) {
      ^bb0(%in: f32, %out: f32, %out_9: f32):
        %22 = arith.subf %cst_1, %in : f32
        %23 = math.exp %22 : f32
        %24 = arith.addf %23, %out_9 : f32
        linalg.yield %23, %24 : f32, f32
      } -> (tensor<1x4x128xf32>, tensor<1x4xf32>)
      %extracted_slice_5 = tensor.extract_slice %5[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
      %extracted_slice_6 = tensor.extract_slice %2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
      %extracted_slice_7 = tensor.extract_slice %3[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
      %19:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice_5 : tensor<1x4xf32>) outs(%extracted_slice_6, %extracted_slice_7 : tensor<1x4x128xf32>, tensor<1x4xf32>) {
      ^bb0(%in: f32, %out: f32, %out_9: f32):
        %22 = arith.subf %cst_1, %in : f32
        %23 = math.exp %22 : f32
        %24 = arith.addf %23, %out_9 : f32
        linalg.yield %23, %24 : f32, f32
      } -> (tensor<1x4x128xf32>, tensor<1x4xf32>)
      %extracted_slice_8 = tensor.extract_slice %arg2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
      %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15#0, %19#1 : tensor<1x4x128xf32>, tensor<1x4xf32>) outs(%extracted_slice_8 : tensor<1x4x128xf32>) {
      ^bb0(%in: f32, %in_9: f32, %out: f32):
        %22 = arith.divf %in, %in_9 : f32
        linalg.yield %22 : f32
      } -> tensor<1x4x128xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %20 into %arg2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<1x4x128xf32> into tensor<16x128x128xf32>
      }
    }
    return %7 : tensor<16x128x128xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match attributes{producer} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}


////////////////////////////////////////////////////////////////////////////////
// Tests below are expected to fail.
////////////////////////////////////////////////////////////////////////////////

// -----

// NO-CHECK-LABEL-ON-EXPECTED-ERROR
func.func @copy_1d_1024xf16(%arg0: tensor<123x456xf32>, %arg1: tensor<456x789xf32>, %arg2 : tensor<123x789xf32>) -> tensor<123x789xf32> {
  %0 = arith.constant 0.000000e+00 : f32
  %1 = linalg.fill ins(%0 : f32) outs(%arg2 : tensor<123x789xf32>) -> tensor<123x789xf32>
  // expected-note @below {{containing op}}
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<123x456xf32>, tensor<456x789xf32>) outs(%1 : tensor<123x789xf32>) -> tensor<123x789xf32>
  return %2 : tensor<123x789xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %tiled_op, %forall_op = transform.structured.tile_using_forall %1
      num_threads [] tile_sizes [50, 16]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Note that we pass in %tiled_op, which isn't a container op.
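    // Since the producer has no use inside %tiled_op, the transform fails and
    // emits the diagnostic below.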
    // expected-error @+2 {{could not find next producer to fuse into container}}
    %fused_op, %new_containing_op =
      transform.structured.fuse_into_containing_op %0 into %tiled_op
        : (!transform.any_op, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}