// RUN: mlir-opt %s \
// RUN:   --pass-pipeline="builtin.module(transform-interpreter{ \
// RUN:        debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary},\
// RUN:        canonicalize,cse,symbol-dce)" |\
// RUN: FileCheck %s

// ****************************** IMPORTANT NOTE ******************************
//
// If you are changing this file, you may also need to change
// mlir/docs/Tutorials/Transform accordingly.
//
// ****************************************************************************

// Original function to optimize.
func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
                   -> tensor<512x512xf32> {
  // Matrix-matrix multiplication.
  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise addition.
  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise max with 0 (ReLU).
  %c0f = arith.constant 0.0 : f32
  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
    ins(%biased, %c0f : tensor<512x512xf32>, f32)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %relued : tensor<512x512xf32>
}

// CHECK: func @outlined
// CHECK:   linalg.matmul
// CHECK:   linalg.elemwise_binary {fun = #linalg.binary_fn<add>}

// CHECK-LABEL: func @fc_relu
// CHECK: scf.forall
// CHECK:   scf.forall
// CHECK:     %[[SLICE4:.+]] = tensor.extract_slice
// CHECK:     %[[SLICE5:.+]] = tensor.extract_slice
// CHECK:     %[[SLICE6:.+]] = tensor.extract_slice
// CHECK:     %[[SLICE7:.+]] = tensor.extract_slice
// CHECK:     %[[SLICE8:.+]] = tensor.extract_slice
// CHECK:     func.call @outlined(%[[SLICE4]], %[[SLICE5]], %[[SLICE6]], %[[SLICE7]], %[[SLICE8]])
// CHECK-NOT:   linalg.matmul
// CHECK-NOT:   linalg.elemwise_binary
// CHECK:     scf.forall.in_parallel
// CHECK:   linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>}
// CHECK:   scf.forall.in_parallel

// Declaration of the "microkernel" function that we will be targeting.
func.func private @microkernel(
    %lhs: tensor<4x512xf32>,
    %rhs: tensor<512x4xf32>,
    %bias: tensor<4x4xf32>,
    %init: tensor<4x4xf32>,
    %output: tensor<4x4xf32>) -> tensor<4x4xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %arg0: !transform.any_op,
      %arg1: !transform.op<"linalg.matmul">,
      %arg2: !transform.op<"linalg.elemwise_binary">) {
    // Since the %arg2 handle is associated with both elementwise operations,
    // we need to split it into two handles so we can target only the second
    // elementwise operation.
    %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">)
        -> (!transform.any_op, !transform.any_op)

    // The actual tiling transformation takes tile sizes as attributes. It
    // produces a handle to the tiled operation and a handle to the loop
    // generated during tiling.
    %tiled, %loop = transform.structured.tile_using_forall %max tile_sizes [8, 32]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // We can now fuse the other operations into the loop. Here, we fuse
    // operations one-by-one. This requires the operation that is being fused
    // to define the value used within the loop, so the order of such fusions
    // is important. We could also use "transform.merge_handles" to obtain
    // a single handle to all operations and give it to `fuse_into_containing_op`
    // that would take care of the ordering in this case.
    %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2
        : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Tile again to get the desired size. Note that this time it tiles the
    // "add" operation and fuses matmul into the loop, but doesn't affect the
    // "max" operation. This illustrates the precise targeting with the transform
    // dialect: otherwise, it would be difficult to differentiate "add" and "max",
    // both of which have the same operation kind.
    %tiled_second, %loop_second = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_fused_2, %loop_second_2 =
        transform.structured.fuse_into_containing_op %matmul_fused into %loop_second
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Since outlining is currently only implemented for region-holding operations
    // such as loops, tile to size 1 to materialize the outer loop that is
    // going to be outlined.
    %_0, %loop_third = transform.structured.tile_using_forall %tiled_second tile_sizes [1]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Outlining produces a handle to the newly created function and a handle
    // to the call that replaces the outlined loop.
    %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
        : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">)

    transform.yield
  }
}