// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics

// This test runs the @match_matmul_elemwise matcher over two Matmul+ReLU
// payloads that differ only in which operand of the final `max` carries the
// biased matmul result. For every match, @print_matmul_elemwise emits a remark
// on each of the three matched operations, tagged with that operand position.

// Matmul+ReLU. The biased value is operand 0 of the ReLU `max`, so the
// reported position is 0.
func.func @fc_relu_operands_00(
    %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
    %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // Matrix-matrix multiplication.
  // expected-remark @below {{matmul # 0}}
  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise addition.
  // expected-remark @below {{add # 0}}
  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise max with 0 (ReLU).
  %c0f = arith.constant 0.0 : f32
  // expected-remark @below {{max # 0}}
  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
    ins(%biased, %c0f : tensor<512x512xf32>, f32)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %relued : tensor<512x512xf32>
}

// Matmul+ReLU with swapped operands of the ReLU `max`: the constant comes
// first and the biased value is operand 1, so the reported position is 1.
func.func @fc_relu_operands_01(
    %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
    %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // Matrix-matrix multiplication.
  // expected-remark @below {{matmul # 1}}
  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise addition.
  // expected-remark @below {{add # 1}}
  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise max with 0 (ReLU).
  %c0f = arith.constant 0.0 : f32
  // expected-remark @below {{max # 1}}
  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
    ins(%c0f, %biased : f32, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %relued : tensor<512x512xf32>
}

// The module containing named sequences must carry the
// `transform.with_named_sequence` attribute, which enables verification of
// those sequences.
module @transforms attributes { transform.with_named_sequence } {
  // Entry point. It takes as its only argument the root operation (typically
  // the pass root) given to the transform interpreter.
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.consumed}) {

    // Traverses the payload IR associated with the operand handle, invoking
    // @match_matmul_elemwise on each of the operations. If the named sequence
    // succeeds, i.e., if none of the nested match (transform) operations
    // produced a silenceable failure, invokes @print_matmul_elemwise and
    // forwards the values yielded as arguments of the new invocation. If the
    // named sequence fails with a silenceable failure, silences it (the message
    // is forwarded to the debug stream). Definite failures are propagated
    // immediately and unconditionally, as usual.
    transform.foreach_match in %root
      @match_matmul_elemwise -> @print_matmul_elemwise
      : (!transform.any_op) -> !transform.any_op

    transform.yield
  }

  // This is an action sequence. It receives the three operation handles and
  // the operand position yielded by @match_matmul_elemwise. It emits one
  // remark at each matched operation, formatted as "<name> # <position>".
  transform.named_sequence @print_matmul_elemwise(
      %matmul: !transform.any_op {transform.readonly},
      %add: !transform.any_op {transform.readonly},
      %max: !transform.any_op {transform.readonly},
      %pos: !transform.param<i32> {transform.readonly}) {
    transform.debug.emit_param_as_remark %pos, "matmul #" at %matmul
      : !transform.param<i32>, !transform.any_op
    transform.debug.emit_param_as_remark %pos, "add #" at %add
      : !transform.param<i32>, !transform.any_op
    transform.debug.emit_param_as_remark %pos, "max #" at %max
      : !transform.param<i32>, !transform.any_op
    transform.yield
  }

  // This is a matcher sequence. It is given an operation to match, and all
  // nested match operations must succeed for the match to be deemed
  // successful. It starts from the last operation in the use-def chain (the
  // ReLU `max`) and walks backwards through operand definitions. Walking
  // backwards works because each operand (use) has exactly one definition.
  transform.named_sequence @match_matmul_elemwise(
      %last: !transform.any_op {transform.readonly})
      -> (!transform.any_op, !transform.any_op, !transform.any_op,
          !transform.param<i32>) {
    // The last operation must be an elementwise binary.
    transform.match.operation_name %last ["linalg.elemwise_binary"]
      : !transform.any_op

    // One of its operands must be defined by another operation, and we get a
    // handle to that operation here. This uses a newly defined operation that
    // tries to match the operands one by one, using the match operations
    // nested in its region. %pos holds the position of the operand that
    // satisfied the region (0 or 1 for the payload functions above).
    %pos, %middle = transform.match.my.has_operand_satisfying %last
        : (!transform.any_op) -> (!transform.param<i32>, !transform.any_op) {
    ^bb0(%operand: !transform.any_value):
      // The operand must be defined by an operation.
      %def = transform.get_defining_op %operand
        : (!transform.any_value) -> !transform.any_op
      // The defining operation must itself be an elementwise binary.
      transform.match.operation_name %def ["linalg.elemwise_binary"]
        : !transform.any_op
      transform.yield %def : !transform.any_op
    }

    // The first operand of that operation must in turn be defined by yet
    // another operation.
    %matmul = transform.get_producer_of_operand %middle[0]
      : (!transform.any_op) -> !transform.any_op
    // And that operation must be a matmul.
    transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
    // Yield the handles to the matmul and to the two elementwise operations
    // separately, followed by the matched operand position.
    transform.yield %matmul, %middle, %last, %pos
      : !transform.any_op, !transform.any_op, !transform.any_op,
        !transform.param<i32>
  }
}