xref: /llvm-project/mlir/test/Examples/transform/Ch4/sequence.mlir (revision 2798b72ae7e5caad793169b77cbac47fe2362d0f)
1// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
2//
3// RUN: transform-opt-ch4 %s \
4// RUN:              --transform-interpreter='entry-point=__transform_main_v2' \
5// RUN:              --verify-diagnostics
6
7// ****************************** IMPORTANT NOTE ******************************
8//
9// If you are changing this file, you may also need to change
// mlir/docs/Tutorials/transform accordingly.
11//
12// ****************************************************************************
13
14// Original function to optimize.
func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
                   -> tensor<512x512xf32> {
  // Scalar 0.0 that the ReLU stage compares every element against.
  %zero = arith.constant 0.0 : f32

  // Stage 1: matrix-matrix product of %lhs and %rhs, with %output as the
  // destination tensor.
  // expected-remark @below {{matmul}}
  %product = linalg.matmul
      ins(%lhs, %rhs : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Stage 2: elementwise addition of the bias.
  // expected-remark @below {{elementwise binary}}
  %with_bias = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
    ins(%product, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Stage 3: ReLU, expressed as an elementwise max against the scalar zero.
  // expected-remark @below {{elementwise binary}}
  %activated = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
    ins(%with_bias, %zero : tensor<512x512xf32>, f32)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  return %activated : tensor<512x512xf32>
}
37
// The module containing named sequences must carry the
// `transform.with_named_sequence` attribute, which enables their verification.
module @transforms attributes { transform.with_named_sequence } {
  // Entry point. This takes as the only argument the root operation (typically
  // pass root) given to the transform interpreter. It collects elementwise and
  // matmul operations with two separate single-operation matchers and then
  // emits a remark on each collected operation.
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.readonly}) {
    // Collect operations that match the criteria specified in the named
    // sequence. If the named sequence fails with a silenceable failure,
    // silences it (the message is forwarded to the debug stream). If the named
    // sequence succeeds, appends its results to the results of this operation.
    %elemwise = transform.collect_matching @match_elemwise in %root
      : (!transform.any_op) -> !transform.any_op
    %matmul = transform.collect_matching @match_matmul in %root
      : (!transform.any_op) -> !transform.any_op

    // Apply the action sequences to everything collected above; failures
    // inside them are propagated rather than suppressed.
    transform.include @print_elemwise failures(propagate)  (%elemwise)
      : (!transform.any_op) -> ()
    transform.include @print_matmul failures(propagate)  (%matmul)
      : (!transform.any_op) -> ()

    transform.yield
  }

  // Alternative entry point, selected by the second RUN line through
  // `entry-point=__transform_main_v2`. Instead of matching operations one at a
  // time, it matches complete matmul -> elementwise -> elementwise chains.
  transform.named_sequence @__transform_main_v2(
      %root: !transform.any_op {transform.readonly}) {
    // Collect groups of operations that match the criteria specified in the
    // named sequence. Each successful match appends the values the matcher
    // yields to the corresponding result, in yield order.
    %matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    // Combine both elementwise handles so that one action sequence covers them.
    // NOTE(review): `deduplicate` is not requested. If one operation took part
    // in several matches (e.g. a middle elementwise op with two elementwise
    // users), it would presumably appear, and be remarked on, more than once.
    // The payload in this file has a single chain, so this does not arise here.
    %elemwise = transform.merge_handles %el1, %el2 : !transform.any_op

    transform.include @print_elemwise failures(propagate)  (%elemwise)
      : (!transform.any_op) -> ()
    transform.include @print_matmul failures(propagate)  (%matmul)
      : (!transform.any_op) -> ()

    transform.yield
  }

  // The next two are matcher sequences. Each is given an operation to match,
  // and the match is considered successful unless any nested operation
  // produces a failure. On success, the `transform.collect_matching` that
  // invoked the matcher appends the values yielded here to its own results.
  //
  // Matches `linalg.elemwise_binary` operations.
  transform.named_sequence @match_elemwise(
      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %entry ["linalg.elemwise_binary"]
      : !transform.any_op
    transform.yield %entry : !transform.any_op
  }
  // Matches `linalg.matmul` operations.
  transform.named_sequence @match_matmul(
      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %entry ["linalg.matmul"] : !transform.any_op
    transform.yield %entry : !transform.any_op
  }

  // The next two are action sequences. They receive handles collected by the
  // entry points and emit a remark at each payload operation associated with
  // their argument.
  transform.named_sequence @print_elemwise(
      %elemwise_binary: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at
      %elemwise_binary, "elementwise binary" : !transform.any_op
    transform.yield
  }
  transform.named_sequence @print_matmul(
      %matmul: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op
    transform.yield
  }

  // This is also a matcher sequence. It is similarly given an operation to
  // match and nested operations must succeed in order for a match to be deemed
  // successful. It starts matching from the last operation in the use-def chain
  // and walks back through producers because each operand (use) has exactly
  // one definition, whereas walking forward would have to consider every use.
  transform.named_sequence @match_matmul_elemwise(
      %last: !transform.any_op {transform.readonly})
      -> (!transform.any_op, !transform.any_op, !transform.any_op) {
    // The last operation must be an elementwise binary.
    transform.match.operation_name %last ["linalg.elemwise_binary"]
      : !transform.any_op
    // Its first operand must be defined by another operation, to which we
    // will get a handle here. We are guaranteed that the first operand exists
    // because we know the operation is binary, but even in absence of such a
    // guarantee, this operation would have produced a silenceable failure when
    // `%last` does not have enough operands.
    %middle = transform.get_producer_of_operand %last[0]
      : (!transform.any_op) -> !transform.any_op
    // The defining operation must itself be an elementwise binary.
    transform.match.operation_name %middle ["linalg.elemwise_binary"]
      : !transform.any_op
    // And the first operand of that operation must be defined by yet another
    // operation.
    %matmul = transform.get_producer_of_operand %middle[0]
      : (!transform.any_op) -> !transform.any_op
    // And that operation is a matmul.
    transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
    // We will yield the handles to the matmul and the two elementwise
    // operations separately.
    transform.yield %matmul, %middle, %last
      : !transform.any_op, !transform.any_op, !transform.any_op
  }
}
140