// RUN: mlir-opt %s \
// RUN:   --pass-pipeline="builtin.module(transform-interpreter{ \
// RUN:        debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary},\
// RUN:        canonicalize,cse,symbol-dce)" \
// RUN:   --split-input-file --verify-diagnostics

// ****************************** IMPORTANT NOTE ******************************
//
// If you are changing this file, you may also need to change
// mlir/docs/Tutorials/Transform accordingly.
//
// ****************************************************************************

// Original function to optimize.
func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // Matrix-matrix multiplication.

  // expected-note @below {{nested payload op}}
  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise addition.

  // expected-note @below {{ancestor payload op}}
  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise max with 0 (ReLU).
  %c0f = arith.constant 0.0 : f32
  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
    ins(%biased, %c0f : tensor<512x512xf32>, f32)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %relued : tensor<512x512xf32>
}

// Declaration of the "microkernel" function that we will be targeting.
func.func private @microkernel(
    %lhs: tensor<4x512xf32>,
    %rhs: tensor<512x4xf32>,
    %bias: tensor<4x4xf32>,
    %init: tensor<4x4xf32>,
    %output: tensor<4x4xf32>) -> tensor<4x4xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %arg0: !transform.any_op,
      %arg1: !transform.op<"linalg.matmul">,
      %arg2: !transform.op<"linalg.elemwise_binary">) {
    // Since the %arg2 handle is associated with both elementwise operations,
    // we need to split it into two handles so we can target only the second
    // elementwise operation.
    %add, %max = transform.split_handle %arg2
        : (!transform.op<"linalg.elemwise_binary">)
        -> (!transform.any_op, !transform.any_op)

    // The actual tiling transformation takes tile sizes as attributes. It
    // produces a handle to the loop generated during tiling.
    %tiled, %loop = transform.structured.tile_using_forall %max tile_sizes [8, 32]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // We can now fuse the other operations into the loop. Here, we fuse
    // operations one by one. This requires the operation that is being fused
    // to define the value used within the loop, so the order of such fusions
    // is important. We could also use "transform.merge_handles" to obtain
    // a single handle to all operations and give it to
    // `fuse_into_containing_op`, which would take care of the ordering in
    // this case.
    %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop
        : (!transform.any_op, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)
    %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2
        : (!transform.op<"linalg.matmul">, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)
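
    // Note that `fuse_into_containing_op` consumes the handle to the producer
    // operation (%add and %arg1 above), so only the newly produced handles
    // (%add_fused and %matmul_fused) remain valid from this point on. This is
    // the same handle-invalidation rule that the end of this test triggers
    // deliberately.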

    // Tile again to get the desired size. Note that this time this tiles the
    // "add" operation and fuses matmul into the loop, but doesn't affect the
    // "max" operation. This illustrates the precise targeting that the
    // transform dialect provides; otherwise, it would be difficult to
    // differentiate "add" and "max", both of which have the same operation
    // kind.
    %tiled_second, %loop_second = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_fused_2, %loop_second_2 =
        transform.structured.fuse_into_containing_op %matmul_fused into %loop_second
        : (!transform.any_op, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)

    // Since outlining is currently only implemented for region-holding
    // operations such as loops, use tiling with size 1 to materialize the
    // outer loop that is going to be outlined.
    %_0, %loop_third = transform.structured.tile_using_forall %tiled_second tile_sizes [1]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // expected-note @below {{handle to invalidated ops}}
    %f, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third
        : (!transform.any_op, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)

    // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}}
    %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
        : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">)

    // expected-error @below {{uses a handle invalidated by a previously executed transform op}}
    transform.debug.emit_remark_at %f, "fused" : !transform.any_op

    transform.yield
  }
}