// RUN: mlir-opt %s \
// RUN:   --pass-pipeline="builtin.module(transform-interpreter{ \
// RUN:        debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary},\
// RUN:        canonicalize,cse,symbol-dce)" |\
// RUN: FileCheck %s

// ****************************** IMPORTANT NOTE ******************************
//
// If you are changing this file, you may also need to change
// mlir/docs/Tutorials/Transform accordingly.
//
// ****************************************************************************

// Original function to optimize.
func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
                   -> tensor<512x512xf32> {
  // Matrix-matrix multiplication.
  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise addition.
  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise max with 0 (ReLU).
  %c0f = arith.constant 0.0 : f32
  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
    ins(%biased, %c0f : tensor<512x512xf32>, f32)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %relued : tensor<512x512xf32>
}

// CHECK: func @outlined
// CHECK:   linalg.matmul
// CHECK:   linalg.elemwise_binary {fun = #linalg.binary_fn<add>}

// CHECK-LABEL: func @fc_relu
// CHECK:   scf.forall
// CHECK:     scf.forall
// CHECK:       %[[SLICE4:.+]] = tensor.extract_slice
// CHECK:       %[[SLICE5:.+]] = tensor.extract_slice
// CHECK:       %[[SLICE6:.+]] = tensor.extract_slice
// CHECK:       %[[SLICE7:.+]] = tensor.extract_slice
// CHECK:       %[[SLICE8:.+]] = tensor.extract_slice
// CHECK:       func.call @outlined(%[[SLICE4]], %[[SLICE5]], %[[SLICE6]], %[[SLICE7]], %[[SLICE8]])
// CHECK-NOT:   linalg.matmul
// CHECK-NOT:   linalg.elemwise_binary
// CHECK:     scf.forall.in_parallel
// CHECK:   linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>}
// CHECK:   scf.forall.in_parallel

// Declaration of the "microkernel" function that we will be targeting.
func.func private @microkernel(
    %lhs: tensor<4x512xf32>,
    %rhs: tensor<512x4xf32>,
    %bias: tensor<4x4xf32>,
    %init: tensor<4x4xf32>,
    %output: tensor<4x4xf32>) -> tensor<4x4xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %arg0: !transform.any_op,
      %arg1: !transform.op<"linalg.matmul">,
      %arg2: !transform.op<"linalg.elemwise_binary">) {
    // Since the %arg2 handle is associated with both elementwise operations,
    // we need to split it into two handles so we can target only the second
    // elementwise operation.
    %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">)
        -> (!transform.any_op, !transform.any_op)

    // The actual tiling transformation takes tile sizes as attributes. It produces a
    // handle to the loop generated during tiling.
    %tiled, %loop = transform.structured.tile_using_forall %max tile_sizes [8, 32]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // We can now fuse the other operations into the loop. Here, we fuse
    // operations one-by-one. This requires the operation that is being fused
    // to define the value used within the loop, so the order of such fusions
    // is important.
    // We could also use "transform.merge_handles" to obtain a single handle
    // to all operations and give it to `fuse_into_containing_op`, which would
    // take care of the ordering in this case (a commented-out sketch of that
    // variant appears at the end of this file).
    %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2
        : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Tile again to get the desired size. Note that this time this tiles the
    // "add" operation and fuses matmul into the loop, but doesn't affect the
    // "max" operation. This illustrates the precise targeting with the transform
    // dialect. Otherwise, it would be difficult to differentiate "add" and "max",
    // both of which have the same operation kind.
    %tiled_second, %loop_second = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_fused_2, %loop_second_2 =
        transform.structured.fuse_into_containing_op %matmul_fused into %loop_second
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Since outlining is currently only implemented for region-holding operations
    // such as loops, use tiling with tile size 1 to materialize the outer loop
    // that is going to be outlined.
    %_0, %loop_third = transform.structured.tile_using_forall %tiled_second tile_sizes [1]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
        : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">)

    transform.yield
  }
}
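
// The fusion comment in @__transform_main above mentions "transform.merge_handles"
// as an alternative to fusing producers one-by-one. The commented-out lines below
// are a minimal sketch of that variant and are not exercised by this test; they
// assume the %add, %arg1 and %loop handles from the sequence above, and the names
// %matmul_any, %producers, %fused and %fused_loop are illustrative only. Because
// merge_handles requires all of its operands to have the same type, the typed
// matmul handle would first be cast to !transform.any_op.
//
//   %matmul_any = transform.cast %arg1
//       : !transform.op<"linalg.matmul"> to !transform.any_op
//   %producers = transform.merge_handles %add, %matmul_any : !transform.any_op
//   %fused, %fused_loop = transform.structured.fuse_into_containing_op %producers into %loop
//       : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
//
// A single fuse_into_containing_op would then replace the two separate fusions,
// with the operation itself responsible for fusing the producers in a valid order.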