// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics

// Matmul as a named operation. `%bias` is not used by the body; the signature
// mirrors the one of @generic below.
//
// The remark is expected here as well: the transform script below only checks
// structural properties (rank, operand counts, indexing maps, body contraction
// shape), never the operation name, so the named op matches too.
func.func @named(
    %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
    %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // expected-remark @below {{matmul}}
  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %matmul : tensor<512x512xf32>
}

// Matmul as a generic operation.
// Same computation as @named, spelled out as linalg.generic: d0 and d1 are
// parallel (they index the output), d2 is the reduction dimension shared by
// the lhs (d0, d2) and rhs (d2, d1) accesses. The body multiplies the lhs and
// rhs elements and adds the current output element. `%bias` is not used.
func.func @generic(
    %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
    %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // expected-remark @below {{matmul}}
  %matmul = linalg.generic {
    iterator_types = ["parallel", "parallel", "reduction"],
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d2)>,
      affine_map<(d0, d1, d2) -> (d2, d1)>,
      affine_map<(d0, d1, d2) -> (d0, d1)>]
  } ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output: tensor<512x512xf32>) {
  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
    %0 = arith.mulf %arg0, %arg1 : f32
    %1 = arith.addf %0, %arg2 : f32
    linalg.yield %1 : f32
  } -> tensor<512x512xf32>
  func.return %matmul : tensor<512x512xf32>
}

// The module containing named sequences must carry the
// `transform.with_named_sequence` attribute, which allows the named sequences
// to be defined in it and enables their verification.
module @transforms attributes { transform.with_named_sequence } {
  // Entry point. Its only argument is the root operation (typically the pass
  // root) given to the transform interpreter.
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.consumed}) {

    // Traverses the payload IR associated with the operand handle, invoking
    // @match_generic_matmul on each of the operations. If the matcher
    // succeeds, i.e., if none of its nested match (transform) operations
    // produced a silenceable failure, invokes @print_generic_matmul and
    // forwards the values yielded by the matcher as arguments to that
    // invocation. If the matcher fails with a silenceable failure, the failure
    // is silenced (the message is forwarded to the debug stream). Definite
    // failures are propagated immediately and unconditionally, as usual.
    transform.foreach_match in %root
      @match_generic_matmul -> @print_generic_matmul
      : (!transform.any_op) -> !transform.any_op

    transform.yield
  }

  // This is an action sequence: it emits a "matmul" remark at every payload
  // operation associated with %matmul, which is what the expected-remark
  // annotations on the payload functions check for.
  transform.named_sequence @print_generic_matmul(
      %matmul: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op
    transform.yield
  }

  // This is a matcher sequence: it yields %candidate when the operation looks
  // like a plain (non-batched) matrix multiplication. Any failing match op
  // below produces a silenceable failure, rejecting the candidate.
  transform.named_sequence @match_generic_matmul(
      %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op {
    // Match a structured linear algebra operation.
    transform.match.structured %candidate : !transform.any_op {
    ^bb0(%c: !transform.any_op):
      // With a rank equal to 3.
      %rank = transform.match.structured.rank %c
        : (!transform.any_op) -> !transform.param<i64>
      %c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %rank, %c3 : !transform.param<i64>

      // With 2 inputs.
      %n_ins = transform.match.structured.num_inputs %c
        : (!transform.any_op) -> !transform.param<i64>
      %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %n_ins, %c2 : !transform.param<i64>

      // With 1 output (note that structured ops in destination-passing style
      // have as many inits as outputs).
      %n_inits = transform.match.structured.num_inits %c
        : (!transform.any_op) -> !transform.param<i64>
      %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64>

      // All inputs and inits are accessed with a projected permutation.
      transform.match.structured.input %c[all] {projected_permutation}
        : !transform.any_op
      transform.match.structured.init %c[0] {projected_permutation}
        : !transform.any_op

      // The body is a mulf/addf contraction with appropriate dimensions.
      transform.match.structured.body %c
        { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
      %batch, %lhs, %rhs, %reduction =
      transform.match.structured.classify_contraction_dims %c
        : (!transform.any_op)
        -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
            !transform.param<i64>)

      // There is exactly one lhs, one rhs and one reduction dimension, and no
      // batch dimension. Each classified group is a param list, so its size
      // is the number of associated values.
      %n_batch = transform.num_associations %batch
        : (!transform.param<i64>) -> !transform.param<i64>
      %n_lhs = transform.num_associations %lhs
        : (!transform.param<i64>) -> !transform.param<i64>
      %n_rhs = transform.num_associations %rhs
        : (!transform.param<i64>) -> !transform.param<i64>
      %n_reduction = transform.num_associations %reduction
        : (!transform.param<i64>) -> !transform.param<i64>
      %c0 = transform.param.constant 0 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %n_batch, %c0 : !transform.param<i64>
      transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param<i64>
      transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param<i64>
      transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param<i64>
    }
    transform.yield %candidate : !transform.any_op
  }
}