// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
//
// RUN: transform-opt-ch4 %s \
// RUN:   --transform-interpreter='entry-point=__transform_main_v2' \
// RUN:   --verify-diagnostics

// ****************************** IMPORTANT NOTE ******************************
//
// If you are changing this file, you may also need to change
// mlir/docs/Tutorials/Transform accordingly.
//
// ****************************************************************************

// Original function to optimize: a matrix multiplication, followed by an
// elementwise addition of %bias and an elementwise max with 0 (ReLU).
//
// Both RUN lines pass `--verify-diagnostics`, so whichever entry point is
// used, the transform script must emit exactly the remarks annotated below.
// Each annotation applies to the next line that is not itself an annotation,
// so do not insert any other line (including a comment) between an
// annotation and the operation it refers to.
func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
                   -> tensor<512x512xf32> {
  // Matrix-matrix multiplication.
  // expected-remark @below {{matmul}}
  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise addition.
  // expected-remark @below {{elementwise binary}}
  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise max with 0 (ReLU).
  %c0f = arith.constant 0.0 : f32
  // expected-remark @below {{elementwise binary}}
  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
    ins(%biased, %c0f : tensor<512x512xf32>, f32)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %relued : tensor<512x512xf32>
}

// The module containing named sequences must carry the
// `transform.with_named_sequence` attribute, which enables verification of
// the named sequences it contains.
module @transforms attributes { transform.with_named_sequence } {
  // Entry point. This takes as the only argument the root operation (typically
  // pass root) given to the transform interpreter. It is the one exercised by
  // the first RUN line, which does not select an entry point explicitly. It
  // matches elementwise and matmul operations independently of each other.
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.readonly}) {
    // Collect operations that match the criteria specified in the named
    // sequence. If the named sequence fails with a silenceable failure,
    // silences it (the message is forwarded to the debug stream). If the named
    // sequence succeeds, appends its results to the results of this operation.
    %elemwise = transform.collect_matching @match_elemwise in %root
      : (!transform.any_op) -> !transform.any_op
    %matmul = transform.collect_matching @match_matmul in %root
      : (!transform.any_op) -> !transform.any_op

    // Report every collected operation by invoking the action sequences.
    transform.include @print_elemwise failures(propagate) (%elemwise)
      : (!transform.any_op) -> ()
    transform.include @print_matmul failures(propagate) (%matmul)
      : (!transform.any_op) -> ()

    transform.yield
  }

  // Alternative entry point, selected by the second RUN line through
  // `entry-point=__transform_main_v2`. Instead of matching each kind of
  // operation separately, it matches the whole
  // matmul -> elementwise -> elementwise chain at once.
  transform.named_sequence @__transform_main_v2(
      %root: !transform.any_op {transform.readonly}) {
    // Collect groups of operations that match the criteria specified in the
    // named sequence. Each result accumulates the corresponding value yielded
    // by @match_matmul_elemwise for every successful match.
    %matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    // Combine the handles to both elementwise operations of the chain so that
    // a single action sequence can report all of them.
    %elemwise = transform.merge_handles %el1, %el2 : !transform.any_op

    transform.include @print_elemwise failures(propagate) (%elemwise)
      : (!transform.any_op) -> ()
    transform.include @print_matmul failures(propagate) (%matmul)
      : (!transform.any_op) -> ()

    transform.yield
  }

  // This is a matcher sequence. It is given an operation to match, and the
  // match is considered successful unless any nested operation produces a
  // failure. On success, the values yielded by this sequence are appended to
  // the results of the `transform.collect_matching` operation that invoked it.
  transform.named_sequence @match_elemwise(
      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %entry ["linalg.elemwise_binary"]
      : !transform.any_op
    transform.yield %entry : !transform.any_op
  }
  // Matcher sequence of the same shape, accepting only matmul operations.
  transform.named_sequence @match_matmul(
      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %entry ["linalg.matmul"] : !transform.any_op
    transform.yield %entry : !transform.any_op
  }

  // These are action sequences. Each one emits a remark with a fixed message
  // at the payload operations associated with its handle argument.
  transform.named_sequence @print_elemwise(
      %elemwise_binary: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at
      %elemwise_binary, "elementwise binary" : !transform.any_op
    transform.yield
  }
  transform.named_sequence @print_matmul(
      %matmul: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op
    transform.yield
  }

  // This is also a matcher sequence. It is similarly given an operation to
  // match, and nested operations must succeed in order for a match to be
  // deemed successful. It walks the use-def chain backwards, starting from the
  // last operation, because each operand (use) has exactly one definition.
  // Only operand #0 is followed at each step; the other operands (e.g. the
  // bias or the scalar constant in @fc_relu) are not inspected.
  transform.named_sequence @match_matmul_elemwise(
      %last: !transform.any_op {transform.readonly})
      -> (!transform.any_op, !transform.any_op, !transform.any_op) {
    // The last operation must be an elementwise binary.
    transform.match.operation_name %last ["linalg.elemwise_binary"]
      : !transform.any_op
    // Its first operand must be defined by another operation, to which we
    // will get a handle here. We are guaranteed that the first operand exists
    // because we know the operation is binary, but even in absence of such a
    // guarantee, this operation would have produced a silenceable failure when
    // `%last` does not have enough operands.
    %middle = transform.get_producer_of_operand %last[0]
      : (!transform.any_op) -> !transform.any_op
    // The defining operation must itself be an elementwise binary.
    transform.match.operation_name %middle ["linalg.elemwise_binary"]
      : !transform.any_op
    // And the first operand of that operation must be defined by yet another
    // operation.
    %matmul = transform.get_producer_of_operand %middle[0]
      : (!transform.any_op) -> !transform.any_op
    // And that operation is a matmul.
    transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
    // We will yield the handles to the matmul and the two elementwise
    // operations separately, in use-def order (matmul, middle, last); the
    // caller unpacks them in this same order.
    transform.yield %matmul, %middle, %last
      : !transform.any_op, !transform.any_op, !transform.any_op
  }
}