// RUN: mlir-opt %s -linalg-block-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s

func.func @block_matmul(
    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %0 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_matmul(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
// CHECK-SAME: into %[[PACK_DST_0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<8x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
// CHECK-SAME: into %[[PACK_DST_1]] : tensor<128x128xf32> -> tensor<8x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[PACK_DST_2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>

// -----

func.func @block_matmul_dynamic(
    %A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// CHECK-DAG: #[[$MAP_M:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
// CHECK-DAG: #[[$MAP_N:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_matmul_dynamic(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>, %[[B:[0-9a-z]+]]: tensor<?x?xf32>, %[[C:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[A_M:.+]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[A_K:.+]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[A_OUTER_TILE_M:.+]] = affine.apply #[[$MAP_M]]()[%[[A_M]]]
// CHECK-DAG: %[[A_OUTER_TILE_K:.+]] = affine.apply #[[$MAP_K]]()[%[[A_K]]]
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty(%[[A_OUTER_TILE_M]], %[[A_OUTER_TILE_K]]) : tensor<?x?x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME: padding_value(%[[ZERO]] : f32)
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
// CHECK-SAME: into %[[PACK_DST_0]] : tensor<?x?xf32> -> tensor<?x?x32x64xf32>
// CHECK-DAG: %[[B_K:.+]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[B_N:.+]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[B_OUTER_TILE_K:.+]] = affine.apply #[[$MAP_K]]()[%[[B_K]]]
// CHECK-DAG: %[[B_OUTER_TILE_N:.+]] = affine.apply #[[$MAP_N]]()[%[[B_N]]]
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty(%[[B_OUTER_TILE_N]], %[[B_OUTER_TILE_K]]) : tensor<?x?x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME: padding_value(%[[ZERO]] : f32)
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
// CHECK-SAME: into %[[PACK_DST_1]] : tensor<?x?xf32> -> tensor<?x?x16x64xf32>
// CHECK-DAG: %[[C_M:.+]] = tensor.dim %[[C]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[C_N:.+]] = tensor.dim %[[C]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[C_OUTER_TILE_M:.+]] = affine.apply #[[$MAP_M]]()[%[[C_M]]]
// CHECK-DAG: %[[C_OUTER_TILE_N:.+]] = affine.apply #[[$MAP_N]]()[%[[C_N]]]
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty(%[[C_OUTER_TILE_M]], %[[C_OUTER_TILE_N]]) : tensor<?x?x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME: padding_value(%[[ZERO]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[PACK_DST_2]] : tensor<?x?xf32> -> tensor<?x?x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<?x?x32x64xf32>, tensor<?x?x16x64xf32>) outs(%[[C_PACKED]] : tensor<?x?x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<?x?x32x16xf32> -> tensor<?x?xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<?x?xf32>

// -----

func.func @block_matmul_with_constant(
    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %cst_acc = arith.constant dense<0.0> : tensor<128x128xf32>
  %0 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%cst_acc : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

// CHECK-LABEL: func @block_matmul_with_constant(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-DAG: %[[CST_ACC_PACKED:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
// CHECK-DAG: %[[RES_DST:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[CST_ACC_PACKED]] : tensor<4x8x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[RES_DST]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>

// -----

func.func @block_matmul_with_producer(
    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %cst = arith.constant 0.0 : f32
  %acc = linalg.fill ins(%cst : f32) outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
  %1 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%acc : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %1 : tensor<128x128xf32>
}

// CHECK-LABEL: func @block_matmul_with_producer(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[FILL_DST_PACKED:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
// CHECK: %[[ACC_PACKED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[FILL_DST_PACKED]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[ACC_PACKED]] : tensor<4x8x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>

// -----

func.func @block_matmul_with_consumer(
    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>, %D: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %0 = tensor.empty() : tensor<128x128xf32>
  %1 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
  %2 = linalg.add ins(%1, %D : tensor<128x128xf32>, tensor<128x128xf32>)
                  outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %2 : tensor<128x128xf32>
}

// CHECK-LABEL: func @block_matmul_with_consumer(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>, %[[D:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-DAG: %[[RES_DST:.+]] = tensor.empty() : tensor<128x128xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: outs({{.*}} : tensor<4x8x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
// CHECK: %[[ADD_RES:.+]] = linalg.add
// CHECK-SAME: ins(%[[RES_UNPACKED]], %[[D]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[RES_DST]] : tensor<128x128xf32>)
// CHECK: return %[[ADD_RES]] : tensor<128x128xf32>

// -----

func.func @block_batch_matmul(
    %A: tensor<512x64x128xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
  %0 = linalg.batch_matmul ins(%A, %B : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
                           outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
  return %0 : tensor<512x64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>

// CHECK-LABEL: func @block_batch_matmul(
// CHECK-SAME: %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME: outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 64]
// CHECK-SAME: into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME: outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 64]
// CHECK-SAME: into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
// CHECK-SAME: into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>

// -----

func.func @block_matmul_transpose_a(
    %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.matmul_transpose_a ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_matmul_transpose_a(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x64xf32>, %[[B:[0-9a-z]+]]: tensor<128x64xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 64]
// CHECK-SAME: into %[[PACK_DST_0]] : tensor<128x64xf32> -> tensor<2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
// CHECK-SAME: into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>

// -----

func.func @block_batch_matmul_transpose_a(
    %A: tensor<512x128x64xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
  %0 = linalg.batch_matmul_transpose_a ins(%A, %B : tensor<512x128x64xf32>, tensor<512x128x64xf32>)
                                       outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
  return %0 : tensor<512x64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>

// CHECK-LABEL: func @block_batch_matmul_transpose_a(
// CHECK-SAME: %[[A:.+]]: tensor<512x128x64xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME: outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [32, 64]
// CHECK-SAME: into %[[PACK_DST_0]] : tensor<512x128x64xf32> -> tensor<512x2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME: outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 64]
// CHECK-SAME: into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
// CHECK-SAME: into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>

// -----

func.func @block_matmul_transpose_b(
    %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.matmul_transpose_b ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_matmul_transpose_b(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<64x128xf32>, %[[B:[0-9a-z]+]]: tensor<64x128xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
// CHECK-SAME: into %[[PACK_DST_0]] : tensor<64x128xf32> -> tensor<2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 64]
// CHECK-SAME: into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>

// -----

func.func @block_batch_matmul_transpose_b(
    %A: tensor<512x64x128xf32>, %B: tensor<512x64x128xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
  %0 = linalg.batch_matmul_transpose_b ins(%A, %B : tensor<512x64x128xf32>, tensor<512x64x128xf32>)
                                       outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
  return %0 : tensor<512x64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>

// CHECK-LABEL: func @block_batch_matmul_transpose_b(
// CHECK-SAME: %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x64x128xf32>, %[[C:.+]]: tensor<512x64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME: outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 64]
// CHECK-SAME: into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME: outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 64]
// CHECK-SAME: into %[[PACK_DST_1]] : tensor<512x64x128xf32> -> tensor<512x4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
// CHECK-SAME: into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @block_generic_matmul(
    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
      outs(%C : tensor<128x128xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.mulf %in, %in_0 : f32
    %2 = arith.addf %out, %1 : f32
    linalg.yield %2 : f32
  } -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_generic_matmul(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
// CHECK-SAME: into %[[PACK_DST_0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<8x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
// CHECK-SAME: into %[[PACK_DST_1]] : tensor<128x128xf32> -> tensor<8x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[PACK_DST_2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>

// -----

#map = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @block_generic_matmul_transpose_a(
    %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
      outs(%C : tensor<64x64xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.mulf %in, %in_0 : f32
    %2 = arith.addf %out, %1 : f32
    linalg.yield %2 : f32
  } -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_generic_matmul_transpose_a(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x64xf32>, %[[B:[0-9a-z]+]]: tensor<128x64xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 64]
// CHECK-SAME: into %[[PACK_DST_0]] : tensor<128x64xf32> -> tensor<2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
// CHECK-SAME: into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @block_generic_matmul_transpose_b(
    %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
      outs(%C : tensor<64x64xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.mulf %in, %in_0 : f32
    %2 = arith.addf %out, %1 : f32
    linalg.yield %2 : f32
  } -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_generic_matmul_transpose_b(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<64x128xf32>, %[[B:[0-9a-z]+]]: tensor<64x128xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
// CHECK-SAME: into %[[PACK_DST_0]] : tensor<64x128xf32> -> tensor<2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 64]
// CHECK-SAME: into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>

func.func @non_contraction_generic(
    %A: tensor<64x128xf32>) -> tensor<64x128xf32> {
  %c0 = arith.constant 0.000000e+00 : f32
  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
      outs(%A : tensor<64x128xf32>) {
  ^bb0(%out: f32):
    %1 = arith.maximumf %out, %c0 : f32
    linalg.yield %1 : f32
  } -> tensor<64x128xf32>
  return %0 : tensor<64x128xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: func @non_contraction_generic(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<64x128xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-NOT: tensor.pack
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: outs(%[[A]] : tensor<64x128xf32>)
// CHECK-NOT: tensor.unpack
// CHECK: return %[[GENERIC]] : tensor<64x128xf32>