// RUN: transform-opt-ch2 %s \
// RUN:   --pass-pipeline="builtin.module(transform-interpreter{ \
// RUN:   debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary},\
// RUN:   canonicalize,cse,symbol-dce)" |\
// RUN: FileCheck %s

// ****************************** IMPORTANT NOTE ******************************
//
// If you are changing this file, you may also need to change
// mlir/docs/Tutorials/Transform accordingly.
//
// *****************************************************************************

// Original function to optimize.
func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
                   -> tensor<512x512xf32> {
  // Matrix-matrix multiplication.
  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise addition.
  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise max with 0 (ReLU).
  %c0f = arith.constant 0.0 : f32
  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
    ins(%biased, %c0f : tensor<512x512xf32>, f32)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %relued : tensor<512x512xf32>
}

// CHECK-LABEL: func @fc_relu
// CHECK: scf.forall
// CHECK: scf.forall
// CHECK: %[[SLICE4:.+]] = tensor.extract_slice
// CHECK: %[[SLICE5:.+]] = tensor.extract_slice
// CHECK: %[[SLICE6:.+]] = tensor.extract_slice
// CHECK: %[[SLICE7:.+]] = tensor.extract_slice
// CHECK: %[[SLICE8:.+]] = tensor.extract_slice
// CHECK: func.call @microkernel(%[[SLICE4]], %[[SLICE5]], %[[SLICE6]], %[[SLICE7]], %[[SLICE8]])
// CHECK-NOT: linalg.matmul
// CHECK-NOT: linalg.elemwise_binary
// CHECK: scf.forall.in_parallel
// CHECK: linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>}
// CHECK: scf.forall.in_parallel

// Declaration of the "microkernel" function that we will be targeting.
func.func private @microkernel(
    %lhs: tensor<4x512xf32>,
    %rhs: tensor<512x4xf32>,
    %bias: tensor<4x4xf32>,
    %init: tensor<4x4xf32>,
    %output: tensor<4x4xf32>) -> tensor<4x4xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %arg0: !transform.any_op,
      %arg1: !transform.op<"linalg.matmul">,
      %arg2: !transform.op<"linalg.elemwise_binary">) {
    // Since the %arg2 handle is associated with both elementwise operations,
    // we need to split it into two handles so we can target only the second
    // elementwise operation.
    %add, %max = transform.split_handle %arg2
        : (!transform.op<"linalg.elemwise_binary">)
        -> (!transform.any_op, !transform.any_op)

    // The actual tiling transformation takes tile sizes as attributes. It
    // produces a handle to the loop generated during tiling.
    %tiled, %loop = transform.structured.tile_using_forall %max tile_sizes [8, 32]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // We can now fuse the other operations into the loop. Here, we fuse
    // operations one-by-one. This requires the operation that is being fused
    // to define the value used within the loop, so the order of such fusions
    // is important. We could also use "transform.merge_handles" to obtain
    // a single handle to all operations and give it to
    // `fuse_into_containing_op`, which would take care of the ordering in
    // this case.
    %add_fused, %loop2 =
        transform.structured.fuse_into_containing_op %add into %loop
        : (!transform.any_op, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)
    %matmul_fused, %loop3 =
        transform.structured.fuse_into_containing_op %arg1 into %loop2
        : (!transform.op<"linalg.matmul">, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)
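
    // Commented-out sketch of the merged-handle alternative mentioned above,
    // kept as a comment so the test itself is unchanged. It assumes the typed
    // matmul handle is first cast to !transform.any_op so both handles share
    // one type; the value names are illustrative only.
    //
    //   %matmul_generic = transform.cast %arg1
    //       : !transform.op<"linalg.matmul"> to !transform.any_op
    //   %producers = transform.merge_handles %add, %matmul_generic
    //       : !transform.any_op
    //   %fused, %fused_loop =
    //       transform.structured.fuse_into_containing_op %producers into %loop
    //       : (!transform.any_op, !transform.any_op)
    //       -> (!transform.any_op, !transform.any_op)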

    // Tile again to get the desired size. Note that this time this tiles the
    // "add" operation and fuses matmul into the loop, but doesn't affect the
    // "max" operation. This illustrates the precise targeting with the
    // transform dialect. Otherwise, it would be difficult to differentiate
    // "add" and "max", both of which have the same operation kind.
    %tiled_second, %loop_second =
        transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_fused_2, %loop_second_2 =
        transform.structured.fuse_into_containing_op %matmul_fused into %loop_second
        : (!transform.any_op, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)

    // Since outlining is currently only implemented for region-holding
    // operations such as loops, use tiling with tile size 1 to materialize the
    // outer loop that is going to be outlined.
    %_0, %loop_third =
        transform.structured.tile_using_forall %tiled_second tile_sizes [1]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %_1, %outline_target =
        transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third
        : (!transform.any_op, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)
    %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Rewrite the call target.
    transform.my.change_call_target %call, "microkernel" : !transform.any_op

    transform.yield
  }
}