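// Source: llvm-project/mlir/test/Examples/transform/Ch4/multiple.mlir (revision 2798b72ae7e5caad793169b77cbac47fe2362d0f)
// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics

// Matmul+bias+ReLU: a matrix multiplication followed by an elementwise bias
// addition and an elementwise max with zero.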
func.func @fc_relu_operands_00(
    %a: tensor<512x512xf32>, %b: tensor<512x512xf32>,
    %bias: tensor<512x512xf32>, %out: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // Scalar zero, used as the second operand of the ReLU max at the end.
  %zero = arith.constant 0.0 : f32

  // Matrix product of %a and %b, with %out as the output operand.
  // expected-remark @below {{matmul # 0}}
  %prod = linalg.matmul
      ins(%a, %b : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%out : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Add the bias elementwise.
  // expected-remark @below {{add # 0}}
  %with_bias = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
      ins(%prod, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%out : tensor<512x512xf32>) -> tensor<512x512xf32>

  // ReLU: elementwise max against the zero scalar. The biased tensor is
  // operand #0 of this op, which is the position the remark reports.
  // expected-remark @below {{max # 0}}
  %activated = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
      ins(%with_bias, %zero : tensor<512x512xf32>, f32)
      outs(%out : tensor<512x512xf32>) -> tensor<512x512xf32>

  return %activated : tensor<512x512xf32>
}

// Matmul+bias+ReLU with the operands of the ReLU max swapped: the zero scalar
// comes first and the biased tensor second.
func.func @fc_relu_operands_01(
    %a: tensor<512x512xf32>, %b: tensor<512x512xf32>,
    %bias: tensor<512x512xf32>, %out: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // Scalar zero, used as the first operand of the ReLU max at the end.
  %zero = arith.constant 0.0 : f32

  // Matrix product of %a and %b, with %out as the output operand.
  // expected-remark @below {{matmul # 1}}
  %prod = linalg.matmul
      ins(%a, %b : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%out : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Add the bias elementwise.
  // expected-remark @below {{add # 1}}
  %with_bias = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
      ins(%prod, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%out : tensor<512x512xf32>) -> tensor<512x512xf32>

  // ReLU with swapped operands: the zero scalar is operand #0 and the biased
  // tensor is operand #1, which is the position the remark reports.
  // expected-remark @below {{max # 1}}
  %activated = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
      ins(%zero, %with_bias : f32, tensor<512x512xf32>)
      outs(%out : tensor<512x512xf32>) -> tensor<512x512xf32>

  return %activated : tensor<512x512xf32>
}

// The module containing named sequences must carry the
// `transform.with_named_sequence` attribute, which enables their verification.
module @transforms attributes { transform.with_named_sequence } {
  // Entry point of the transform script. Its only argument is the root
  // payload operation handed to the transform interpreter (typically the
  // root of the pass).
  transform.named_sequence @__transform_main(
      %payload: !transform.any_op {transform.consumed}) {
    // Visit the payload operations associated with %payload and run
    // @match_matmul_elemwise on each. Whenever the matcher succeeds (none of
    // its nested match operations reported a silenceable failure), the values
    // it yields become the arguments of a call to @print_matmul_elemwise.
    // A silenceable failure in the matcher just means "no match here": it is
    // silenced and its message goes to the debug stream. A definite failure is
    // propagated immediately, as usual.
    transform.foreach_match in %payload
        @match_matmul_elemwise -> @print_matmul_elemwise
        : (!transform.any_op) -> !transform.any_op

    transform.yield
  }

  // Action sequence: attaches a remark to each of the three matched payload
  // operations. The remark names the operation's role in the chain and
  // carries the operand position %index reported by the matcher.
  transform.named_sequence @print_matmul_elemwise(
      %mm: !transform.any_op {transform.readonly},
      %add_op: !transform.any_op {transform.readonly},
      %max_op: !transform.any_op {transform.readonly},
      %index: !transform.param<i32> {transform.readonly}) {
    transform.debug.emit_param_as_remark %index, "matmul #" at %mm
      : !transform.param<i32>, !transform.any_op
    transform.debug.emit_param_as_remark %index, "add #" at %add_op
      : !transform.param<i32>, !transform.any_op
    transform.debug.emit_param_as_remark %index, "max #" at %max_op
      : !transform.param<i32>, !transform.any_op
    transform.yield
  }

  // Matcher sequence: given a candidate "tail" operation, succeeds when it is
  // the end of a linalg.matmul -> linalg.elemwise_binary ->
  // linalg.elemwise_binary chain. Every nested match operation must succeed
  // for the whole match to succeed. The chain is walked backwards along
  // use-def edges, because each operand (use) has exactly one definition to
  // step to. Yields the head, middle and tail operations, followed by the
  // position of the tail operand that is produced by the middle operation.
  transform.named_sequence @match_matmul_elemwise(
      %tail: !transform.any_op {transform.readonly})
      -> (!transform.any_op, !transform.any_op, !transform.any_op,
          !transform.param<i32>) {
    // The tail of the chain has to be an elementwise binary operation.
    transform.match.operation_name %tail ["linalg.elemwise_binary"]
      : !transform.any_op

    // Use the custom operation from this chapter to try the tail's operands
    // one at a time against the nested region. It produces the position of an
    // operand for which the region succeeds, along with the handle that the
    // region yields for it.
    %operand_index, %mid = transform.match.my.has_operand_satisfying %tail
        : (!transform.any_op) -> (!transform.param<i32>, !transform.any_op) {
    ^bb0(%candidate: !transform.any_value):
      // The candidate operand has to be defined by an operation...
      %defining = transform.get_defining_op %candidate
        : (!transform.any_value) -> !transform.any_op
      // ...and that operation has to be an elementwise binary as well.
      transform.match.operation_name %defining ["linalg.elemwise_binary"]
        : !transform.any_op
      transform.yield %defining : !transform.any_op
    }

    // The head of the chain is the producer of operand #0 of the middle
    // operation, and it has to be a matmul.
    %head = transform.get_producer_of_operand %mid[0]
      : (!transform.any_op) -> !transform.any_op
    transform.match.operation_name %head ["linalg.matmul"] : !transform.any_op

    // Hand back the chain from head to tail, then the operand position.
    transform.yield %head, %mid, %tail, %operand_index
      : !transform.any_op, !transform.any_op, !transform.any_op,
        !transform.param<i32>
  }
}