// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 allow-padding=1" \
// RUN:   -canonicalize | FileCheck %s

// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 allow-padding=0" \
// RUN:   -canonicalize | FileCheck %s --check-prefix=NOPAD

// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 allow-padding=1 mnk-padded-multiples=256,512,384" \
// RUN:   -canonicalize | FileCheck %s --check-prefix=PAD-MULT

// Verifies block packing of a matmul whose dimensions (123x125x124) are not
// multiples of the block factors (32,16,64):
//   - with padding allowed, operands are padded up to the block multiples;
//   - with padding disallowed (NOPAD), the matmul is left unchanged;
//   - with explicit mnk-padded-multiples (PAD-MULT), dims pad to 256/512/384.
func.func @block_matmul_padding(
    %A: tensor<123x125xf32>, %B: tensor<125x124xf32>, %C: tensor<123x124xf32>) -> tensor<123x124xf32> {
  %0 = linalg.matmul ins(%A, %B : tensor<123x125xf32>, tensor<125x124xf32>)
                     outs(%C : tensor<123x124xf32>) -> tensor<123x124xf32>
  return %0 : tensor<123x124xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
// CHECK-LABEL: func @block_matmul_padding(
// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<123x125xf32>, %[[B:[0-9a-z]+]]: tensor<125x124xf32>, %[[C:[0-9a-z]+]]: tensor<123x124xf32>
// CHECK-DAG:  %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<123x125xf32> -> tensor<4x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<8x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<125x124xf32> -> tensor<8x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<123x124xf32> -> tensor<4x8x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<123x124xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<123x124xf32>

// NOPAD-LABEL: func @block_matmul_padding(
// NOPAD-SAME:    %[[A:[0-9a-z]+]]: tensor<123x125xf32>, %[[B:[0-9a-z]+]]: tensor<125x124xf32>, %[[C:[0-9a-z]+]]: tensor<123x124xf32>
// NOPAD-NOT: tensor.pack
// NOPAD: linalg.matmul ins(%[[A]], %[[B]] : tensor<123x125xf32>, tensor<125x124xf32>)
// NOPAD-SAME:  outs(%[[C]] : tensor<123x124xf32>) -> tensor<123x124xf32>
// NOPAD-NOT: tensor.unpack

// PAD-MULT-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// PAD-MULT-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// PAD-MULT-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
// PAD-MULT-LABEL: func @block_matmul_padding(
// PAD-MULT-SAME:    %[[A:[0-9a-z]+]]: tensor<123x125xf32>, %[[B:[0-9a-z]+]]: tensor<125x124xf32>, %[[C:[0-9a-z]+]]: tensor<123x124xf32>
// PAD-MULT-DAG:  %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
// PAD-MULT: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<1x1x256x384xf32>
// PAD-MULT: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// PAD-MULT-SAME:  padding_value(%[[ZERO]] : f32)
// PAD-MULT-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [256, 384]
// PAD-MULT-SAME:  into %[[PACK_DST_0]] : tensor<123x125xf32> -> tensor<1x1x256x384xf32>
// PAD-MULT: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<1x1x512x384xf32>
// PAD-MULT: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// PAD-MULT-SAME:  padding_value(%[[ZERO]] : f32)
// PAD-MULT-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [512, 384]
// PAD-MULT-SAME:  into %[[PACK_DST_1]] : tensor<125x124xf32> -> tensor<1x1x512x384xf32>
// PAD-MULT: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<1x1x256x512xf32>
// PAD-MULT: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// PAD-MULT-SAME:  padding_value(%[[ZERO]] : f32)
// PAD-MULT-SAME:  inner_dims_pos = [0, 1] inner_tiles = [256, 512]
// PAD-MULT-SAME:  into %[[PACK_DST_2]] : tensor<123x124xf32> -> tensor<1x1x256x512xf32>
// PAD-MULT: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// PAD-MULT-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// PAD-MULT-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// PAD-MULT-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<1x1x256x384xf32>, tensor<1x1x512x384xf32>) outs(%[[C_PACKED]] : tensor<1x1x256x512xf32>)
// PAD-MULT: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// PAD-MULT-SAME:  inner_dims_pos = [0, 1] inner_tiles = [256, 512]
// PAD-MULT-SAME:  into %[[C]] : tensor<1x1x256x512xf32> -> tensor<123x124xf32>
// PAD-MULT: return %[[RES_UNPACKED]] : tensor<123x124xf32>