xref: /llvm-project/mlir/test/Examples/transform/Ch4/features.mlir (revision 2798b72ae7e5caad793169b77cbac47fe2362d0f)
// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
24cb2ef4fSOleksandr "Alex" Zinenko
// Matmul as a named operation.
// Payload: a 512x512 matrix product expressed with the named linalg.matmul
// op. The transform script is expected to flag it with a "matmul" remark.
// %bias is not referenced by the body.
func.func @named(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                 %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // expected-remark @below {{matmul}}
  %product = linalg.matmul
      ins(%lhs, %rhs : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %product : tensor<512x512xf32>
}
134cb2ef4fSOleksandr "Alex" Zinenko
// The same matmul expressed as a generic operation.
// Payload: the same 512x512 matrix product written out as linalg.generic.
// Over the iteration space (d0, d1, d2), the maps read lhs[d0, d2] and
// rhs[d2, d1] and accumulate into output[d0, d1]; d2 is the reduction
// dimension. The body multiplies the two inputs and adds the accumulator.
// %bias is not referenced by the body.
func.func @generic(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // expected-remark @below {{matmul}}
  %product = linalg.generic {
      iterator_types = ["parallel", "parallel", "reduction"],
      indexing_maps = [
        affine_map<(d0, d1, d2) -> (d0, d2)>,
        affine_map<(d0, d1, d2) -> (d2, d1)>,
        affine_map<(d0, d1, d2) -> (d0, d1)>
      ]
    } ins(%lhs, %rhs : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%output : tensor<512x512xf32>) {
  ^bb0(%a: f32, %b: f32, %acc: f32):
    %prod = arith.mulf %a, %b : f32
    %sum = arith.addf %prod, %acc : f32
    linalg.yield %sum : f32
  } -> tensor<512x512xf32>
  func.return %product : tensor<512x512xf32>
}
354cb2ef4fSOleksandr "Alex" Zinenko
// The module containing named sequences must carry the
// `transform.with_named_sequence` attribute so that they can be verified.
module @transforms attributes { transform.with_named_sequence } {
  // Entry point. Its only argument is the root operation (typically the pass
  // root) given to the transform interpreter.
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.consumed}) {

    // Traverses the payload IR associated with the operand handle, invoking
    // @match_generic_matmul on each of the operations. If the matcher
    // succeeds, i.e., if none of its nested match (transform) operations
    // produced a silenceable failure, invokes @print_generic_matmul and
    // forwards the values yielded by the matcher as arguments of that
    // invocation. If the matcher fails with a silenceable failure, the failure
    // is silenced (the message is forwarded to the debug stream). Definite
    // failures are propagated immediately and unconditionally, as usual.
    transform.foreach_match in %root
      @match_generic_matmul -> @print_generic_matmul
      : (!transform.any_op) -> !transform.any_op

    transform.yield
  }

  // Action sequence: emits a "matmul" remark at each payload operation
  // associated with %matmul. The handle is only read, not consumed.
  transform.named_sequence @print_generic_matmul(
      %matmul: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op
    transform.yield
  }

  // Matcher sequence: succeeds only on a structured operation shaped like a
  // plain, non-batched matrix multiplication, and yields the candidate handle
  // unchanged. Any failing check below makes the whole match fail.
  transform.named_sequence @match_generic_matmul(
      %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op {
    // Match a structured linear algebra operation.
    transform.match.structured %candidate : !transform.any_op {
    ^bb0(%c: !transform.any_op):
      // With a rank (number of loops) equal to 3.
      %rank = transform.match.structured.rank %c
        : (!transform.any_op) -> !transform.param<i64>
      %c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %rank, %c3 : !transform.param<i64>

      // With 2 inputs.
      %n_ins = transform.match.structured.num_inputs %c
        : (!transform.any_op) -> !transform.param<i64>
      %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %n_ins, %c2 : !transform.param<i64>

      // With 1 output (note that structured ops in destination-passing style
      // have as many inits as outputs).
      %n_inits = transform.match.structured.num_inits %c
        : (!transform.any_op) -> !transform.param<i64>
      %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64>

      // All inputs and inits are accessed with a projected permutation.
      transform.match.structured.input %c[all] {projected_permutation}
        : !transform.any_op
      transform.match.structured.init %c[0] {projected_permutation}
        : !transform.any_op

      // The body is a mulf/addf contraction. Its loop dimensions are then
      // split into batch, lhs, rhs and reduction groups; the number of
      // associations of each result is the size of that group.
      transform.match.structured.body %c
        { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
      %batch, %lhs, %rhs, %reduction =
      transform.match.structured.classify_contraction_dims %c
        : (!transform.any_op)
        -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
            !transform.param<i64>)

      // There is exactly one lhs, one rhs and one reduction dimension, and no
      // batch dimensions.
      %n_batch = transform.num_associations %batch
        : (!transform.param<i64>) -> !transform.param<i64>
      %n_lhs = transform.num_associations %lhs
        : (!transform.param<i64>) -> !transform.param<i64>
      %n_rhs = transform.num_associations %rhs
        : (!transform.param<i64>) -> !transform.param<i64>
      %n_reduction = transform.num_associations %reduction
        : (!transform.param<i64>) -> !transform.param<i64>
      %c0 = transform.param.constant 0 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %n_batch, %c0 : !transform.param<i64>
      transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param<i64>
      transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param<i64>
      transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param<i64>
    }
    transform.yield %candidate : !transform.any_op
  }
}
124