// RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s

// This is a simple tile-and-fuse example with a single fusion group.
//
// The payload is a fill -> matmul -> elementwise generic chain. The transform
// script tiles the generic (tagged `__root__`) into an scf.forall and then
// fuses the two ops tagged `__producer__` into that loop, so all three ops
// are expected to appear inside the scf.forall body.

module {
  // CHECK: func @foo
  // CHECK:   scf.forall {{.*}} {
  // CHECK:     linalg.fill
  // CHECK:     linalg.matmul
  // CHECK:     linalg.generic
  // CHECK:   }
  // NOTE(review): %sz0 and %sz1 are not referenced anywhere in the body.
  func.func @foo(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?xf32>,
                 %D: tensor<?x?xf32>, %sz0: index, %sz1: index)
      -> tensor<?x?xf32>
  {
    %cst = arith.constant 0.000000e+00 : f32
    // Producer #1: fills %D with %cst; its result is the matmul init operand.
    %5 = linalg.fill
        {__producer__}
        ins(%cst : f32)
        outs(%D : tensor<?x?xf32>) -> tensor<?x?xf32>
    // Producer #2: matmul of %A and %B accumulating into the filled tensor.
    %6 = linalg.matmul
        {__producer__}
        ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
        outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>
    // Root: elementwise op over (d0, d1). %C is indexed by d0 only, while the
    // matmul result and the output %D are indexed by (d0, d1).
    %7 = linalg.generic
        {__root__,
         indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                          affine_map<(d0, d1) -> (d0, d1)>,
                          affine_map<(d0, d1) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel"]
        }
        ins(%C, %6 : tensor<?xf32>, tensor<?x?xf32>)
        outs(%D : tensor<?x?xf32>) {
    // %arg2 <- element of %C, %arg3 <- element of %6, %arg4 <- element of %D
    // (%arg4 is unused).
    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
      %16 = arith.maximumf %arg3, %cst : f32
      %17 = arith.cmpf ogt, %arg2, %cst : f32
      %18 = arith.select %17, %cst, %16 : f32
      linalg.yield %18 : f32
    } -> tensor<?x?xf32>
    return %7 : tensor<?x?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      // Find the root and all producers.
      %root = transform.structured.match attributes{"__root__"} in %arg1 : (!transform.any_op) -> !transform.any_op
      %producers = transform.structured.match attributes{"__producer__"} in %arg1 : (!transform.any_op) -> !transform.any_op

      // Tile the root.
      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [10, 20]
          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

      // Fuse all producers.
      transform.structured.fuse_into_containing_op %producers into %forall_op
          : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

// Reverse the order of the producer payload ops passed to the
// fuse_into_containing_op op. Fusion should still work.
//
// The payload function is the same as in the previous test; only the
// transform script differs, handing the matched producers to the fusion op
// through transform.test_reverse_payload_ops.

module {
  // CHECK: func @foo
  // CHECK:   scf.forall {{.*}} {
  // CHECK:     linalg.fill
  // CHECK:     linalg.matmul
  // CHECK:     linalg.generic
  // CHECK:   }
  // NOTE(review): %sz0 and %sz1 are not referenced anywhere in the body.
  func.func @foo(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?xf32>,
                 %D: tensor<?x?xf32>, %sz0: index, %sz1: index)
      -> tensor<?x?xf32>
  {
    %cst = arith.constant 0.000000e+00 : f32
    // Producer #1: fills %D with %cst; its result is the matmul init operand.
    %5 = linalg.fill
        {__producer__}
        ins(%cst : f32)
        outs(%D : tensor<?x?xf32>) -> tensor<?x?xf32>
    // Producer #2: matmul of %A and %B accumulating into the filled tensor.
    %6 = linalg.matmul
        {__producer__}
        ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
        outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>
    // Root: elementwise op over (d0, d1). %C is indexed by d0 only, while the
    // matmul result and the output %D are indexed by (d0, d1).
    %7 = linalg.generic
        {__root__,
         indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                          affine_map<(d0, d1) -> (d0, d1)>,
                          affine_map<(d0, d1) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel"]
        }
        ins(%C, %6 : tensor<?xf32>, tensor<?x?xf32>)
        outs(%D : tensor<?x?xf32>) {
    // %arg2 <- element of %C, %arg3 <- element of %6, %arg4 <- element of %D
    // (%arg4 is unused).
    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
      %16 = arith.maximumf %arg3, %cst : f32
      %17 = arith.cmpf ogt, %arg2, %cst : f32
      %18 = arith.select %17, %cst, %16 : f32
      linalg.yield %18 : f32
    } -> tensor<?x?xf32>
    return %7 : tensor<?x?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      // Find the root and all producers.
      %root = transform.structured.match attributes{"__root__"} in %arg1 : (!transform.any_op) -> !transform.any_op
      %producers = transform.structured.match attributes{"__producer__"} in %arg1 : (!transform.any_op) -> !transform.any_op
      // Hand the producers to the fusion op in reversed payload order.
      %reversed_producers = transform.test_reverse_payload_ops %producers : (!transform.any_op) -> !transform.any_op

      // Tile the root.
      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [10, 20]
          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

      // Fuse all producers.
      transform.structured.fuse_into_containing_op %reversed_producers into %forall_op
          : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}