// xref: /llvm-project/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir (revision e4384149b58f7c3d19c5d38bc46038c660b77ca9)
// RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s

// This is a simple tile-and-fuse example with a single fusion group: the op
// tagged `__root__` is tiled into an scf.forall, and the ops tagged
// `__producer__` are then fused into that loop.

module {
  // Expected output: one scf.forall whose body contains the fill, the matmul
  // and the generic, in that order.
  // CHECK: func @foo
  // CHECK:   scf.forall {{.*}} {
  // CHECK:     linalg.fill
  // CHECK:     linalg.matmul
  // CHECK:     linalg.generic
  // CHECK:   }
  //
  // Payload function. %sz0 and %sz1 are not referenced in the body.
  func.func @foo(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?xf32>,
                 %D: tensor<?x?xf32>, %sz0: index, %sz1: index)
      -> tensor<?x?xf32>
  {
    %cst = arith.constant 0.000000e+00 : f32

    // Producer 1: fill %D with the zero constant.
    %5 = linalg.fill
        {__producer__}
        ins(%cst : f32)
        outs(%D : tensor<?x?xf32>) -> tensor<?x?xf32>

    // Producer 2: matmul of %A and %B, accumulating into the filled tensor.
    %6 = linalg.matmul
        {__producer__}
        ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
        outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>

    // Root: elementwise op over the matmul result. %C is indexed by d0 only,
    // so each of its elements is reused along d1.
    %7 = linalg.generic
        {__root__,
         indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                          affine_map<(d0, d1) -> (d0, d1)>,
                          affine_map<(d0, d1) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel"]
        }
        ins(%C, %6 : tensor<?xf32>, tensor<?x?xf32>)
        outs(%D : tensor<?x?xf32>) {
    // %arg2 comes from %C, %arg3 from the matmul result, %arg4 from %D.
    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
      // Yield 0.0 where the %C element is > 0.0, otherwise max(%arg3, 0.0).
      %16 = arith.maximumf %arg3, %cst : f32
      %17 = arith.cmpf ogt, %arg2, %cst : f32
      %18 = arith.select %17, %cst, %16 : f32
      linalg.yield %18 : f32
    } -> tensor<?x?xf32>
    return %7 : tensor<?x?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      // Find the root and all producers.
      %root = transform.structured.match attributes{"__root__"} in %arg1 : (!transform.any_op) -> !transform.any_op
      %producers = transform.structured.match attributes{"__producer__"} in %arg1 : (!transform.any_op) -> !transform.any_op

      // Tile the root.
      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [10, 20]
           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

      // Fuse all producers. The two results of the fusion op are not used.
      transform.structured.fuse_into_containing_op %producers into %forall_op
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

// Reverse the order of the payload ops passed to the
// fuse_into_containing_op op. Fusion should still work.

module {
  // Expected output is the same as without the reversal: one scf.forall whose
  // body contains the fill, the matmul and the generic, in that order.
  // CHECK: func @foo
  // CHECK:   scf.forall {{.*}} {
  // CHECK:     linalg.fill
  // CHECK:     linalg.matmul
  // CHECK:     linalg.generic
  // CHECK:   }
  //
  // Payload function. %sz0 and %sz1 are not referenced in the body.
  func.func @foo(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?xf32>,
                 %D: tensor<?x?xf32>, %sz0: index, %sz1: index)
      -> tensor<?x?xf32>
  {
    %cst = arith.constant 0.000000e+00 : f32

    // Producer 1: fill %D with the zero constant.
    %5 = linalg.fill
        {__producer__}
        ins(%cst : f32)
        outs(%D : tensor<?x?xf32>) -> tensor<?x?xf32>

    // Producer 2: matmul of %A and %B, accumulating into the filled tensor.
    %6 = linalg.matmul
        {__producer__}
        ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
        outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>

    // Root: elementwise op over the matmul result. %C is indexed by d0 only,
    // so each of its elements is reused along d1.
    %7 = linalg.generic
        {__root__,
         indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                          affine_map<(d0, d1) -> (d0, d1)>,
                          affine_map<(d0, d1) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel"]
        }
        ins(%C, %6 : tensor<?xf32>, tensor<?x?xf32>)
        outs(%D : tensor<?x?xf32>) {
    // %arg2 comes from %C, %arg3 from the matmul result, %arg4 from %D.
    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
      // Yield 0.0 where the %C element is > 0.0, otherwise max(%arg3, 0.0).
      %16 = arith.maximumf %arg3, %cst : f32
      %17 = arith.cmpf ogt, %arg2, %cst : f32
      %18 = arith.select %17, %cst, %16 : f32
      linalg.yield %18 : f32
    } -> tensor<?x?xf32>
    return %7 : tensor<?x?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      // Find the root and all producers.
      %root = transform.structured.match attributes{"__root__"} in %arg1 : (!transform.any_op) -> !transform.any_op
      %producers = transform.structured.match attributes{"__producer__"} in %arg1 : (!transform.any_op) -> !transform.any_op
      // Reverse the producer handle; this is the only difference from the
      // first test case.
      %reversed_producers = transform.test_reverse_payload_ops %producers : (!transform.any_op) -> !transform.any_op

      // Tile the root.
      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [10, 20]
           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

      // Fuse all producers, passing them in reversed order. The two results of
      // the fusion op are not used.
      transform.structured.fuse_into_containing_op %reversed_producers into %forall_op
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}