// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics

// Matmul as a named operation.
// Note: %bias is part of the signature but is not used by the body.
func.func @named(
    %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
    %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // expected-remark @below {{matmul}}
  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %matmul : tensor<512x512xf32>
}

// Matmul as a generic operation.
// Note: %bias is part of the signature but is not used by the body.
func.func @generic(
    %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
    %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // expected-remark @below {{matmul}}
  %matmul = linalg.generic {
    iterator_types = ["parallel", "parallel", "reduction"],
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d2)>,
      affine_map<(d0, d1, d2) -> (d2, d1)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>]
  } ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output: tensor<512x512xf32>) {
  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
    %0 = arith.mulf %arg0, %arg1 : f32
    %1 = arith.addf %0, %arg2 : f32
    linalg.yield %1 : f32
  } -> tensor<512x512xf32>
  func.return %matmul : tensor<512x512xf32>
}

// The module containing named sequences must carry the
// `transform.with_named_sequence` attribute, which enables their verification.
module @transforms attributes { transform.with_named_sequence } {
  // Entry point. This takes as the only argument the root operation (typically
  // pass root) given to the transform interpreter.
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.consumed}) {

    // Traverses the payload IR associated with the operand handle, invoking
    // @match_generic_matmul on each of the operations. If the named sequence
    // succeeds, i.e., if none of the nested match (transform) operations
    // produced a silenceable failure, invokes @print_generic_matmul and
    // forwards the values yielded as arguments of the new invocation. If the
    // named sequence fails with a silenceable failure, silences it (the message
    // is forwarded to the debug stream). Definite failures are propagated
    // immediately and unconditionally, as usual.
    transform.foreach_match in %root
      @match_generic_matmul -> @print_generic_matmul
      : (!transform.any_op) -> !transform.any_op

    transform.yield
  }

  // This is an action sequence. It emits the "matmul" remark at the payload
  // operation it is given; the `expected-remark` annotations above check it.
  transform.named_sequence @print_generic_matmul(
      %matmul: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op
    transform.yield
  }

  // This is a matcher sequence. It yields the candidate operation unchanged
  // when every nested match operation below succeeds.
  transform.named_sequence @match_generic_matmul(
      %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op {
    // Match a structured linear algebra operation.
    transform.match.structured %candidate : !transform.any_op {
    ^bb0(%c: !transform.any_op):
      // With a rank equal to 3.
      %rank = transform.match.structured.rank %c
        : (!transform.any_op) -> !transform.param<i64>
      %c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %rank, %c3 : !transform.param<i64>

      // With 2 inputs.
      %n_ins = transform.match.structured.num_inputs %c
        : (!transform.any_op) -> !transform.param<i64>
      %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %n_ins, %c2 : !transform.param<i64>

      // With 1 output (note that structured ops in destination passing style
      // have as many inits as outputs).
      %n_inits = transform.match.structured.num_inits %c
        : (!transform.any_op) -> !transform.param<i64>
      %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64>

      // All inputs and inits are accessed with a projected permutation.
      transform.match.structured.input %c[all] {projected_permutation}
        : !transform.any_op
      transform.match.structured.init %c[0] {projected_permutation}
        : !transform.any_op

      // The body is a mulf/addf contraction with appropriate dimensions.
      transform.match.structured.body %c
        { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
      %batch, %lhs, %rhs, %reduction =
        transform.match.structured.classify_contraction_dims %c
        : (!transform.any_op)
        -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
            !transform.param<i64>)

      // There is exactly one each of lhs, rhs and reduction dimensions, and
      // zero batch dimensions. %c1 is the constant defined for the init count
      // check above.
      %n_batch = transform.num_associations %batch
        : (!transform.param<i64>) -> !transform.param<i64>
      %n_lhs = transform.num_associations %lhs
        : (!transform.param<i64>) -> !transform.param<i64>
      %n_rhs = transform.num_associations %rhs
        : (!transform.param<i64>) -> !transform.param<i64>
      %n_reduction = transform.num_associations %reduction
        : (!transform.param<i64>) -> !transform.param<i64>
      %c0 = transform.param.constant 0 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %n_batch, %c0 : !transform.param<i64>
      transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param<i64>
      transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param<i64>
      transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param<i64>
    }
    transform.yield %candidate : !transform.any_op
  }
}