// RUN: mlir-opt %s -linalg-block-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s

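// Block factors 32,16,64 set the M, N, and K tile sizes; they match the
// inner_tiles of the packed layouts checked in the tests below.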
func.func @block_matmul(
    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %0 = linalg.matmul  ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
                      outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_matmul(
// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<8x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x128xf32> -> tensor<8x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>

// -----

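// Dynamic shapes: the packed destinations are sized with affine.apply (ceildiv)
// on the runtime dims, and each pack carries a zero padding_value.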
func.func @block_matmul_dynamic(
    %A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul  ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
                      outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// CHECK-DAG: #[[$MAP_M:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
// CHECK-DAG: #[[$MAP_N:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_matmul_dynamic(
// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<?x?xf32>, %[[B:[0-9a-z]+]]: tensor<?x?xf32>, %[[C:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[A_M:.+]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[A_K:.+]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[A_OUTER_TILE_M:.+]] = affine.apply #[[$MAP_M]]()[%[[A_M]]]
// CHECK-DAG: %[[A_OUTER_TILE_K:.+]] = affine.apply #[[$MAP_K]]()[%[[A_K]]]
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty(%[[A_OUTER_TILE_M]], %[[A_OUTER_TILE_K]]) : tensor<?x?x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<?x?xf32> -> tensor<?x?x32x64xf32>
// CHECK-DAG: %[[B_K:.+]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[B_N:.+]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[B_OUTER_TILE_K:.+]] = affine.apply #[[$MAP_K]]()[%[[B_K]]]
// CHECK-DAG: %[[B_OUTER_TILE_N:.+]] = affine.apply #[[$MAP_N]]()[%[[B_N]]]
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty(%[[B_OUTER_TILE_N]], %[[B_OUTER_TILE_K]]) : tensor<?x?x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<?x?xf32> -> tensor<?x?x16x64xf32>
// CHECK-DAG: %[[C_M:.+]] = tensor.dim %[[C]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[C_N:.+]] = tensor.dim %[[C]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[C_OUTER_TILE_M:.+]] = affine.apply #[[$MAP_M]]()[%[[C_M]]]
// CHECK-DAG: %[[C_OUTER_TILE_N:.+]] = affine.apply #[[$MAP_N]]()[%[[C_N]]]
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty(%[[C_OUTER_TILE_M]], %[[C_OUTER_TILE_N]]) : tensor<?x?x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<?x?xf32> -> tensor<?x?x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<?x?x32x64xf32>, tensor<?x?x16x64xf32>) outs(%[[C_PACKED]] : tensor<?x?x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<?x?x32x16xf32> -> tensor<?x?xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<?x?xf32>

// -----

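// A constant (dense zero) accumulator folds into its packed form, so no
// tensor.pack of the accumulator is emitted.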
func.func @block_matmul_with_constant(
    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %cst_acc = arith.constant dense<0.0> : tensor<128x128xf32>
  %0 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
                      outs(%cst_acc : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

// CHECK-LABEL: func @block_matmul_with_constant(
// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-DAG: %[[CST_ACC_PACKED:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
// CHECK-DAG: %[[RES_DST:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[CST_ACC_PACKED]] : tensor<4x8x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[RES_DST]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>

// -----

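// A linalg.fill producer is propagated through packing: the fill writes
// directly into the packed accumulator tensor.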
func.func @block_matmul_with_producer(
    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %cst = arith.constant 0.0 : f32
  %acc = linalg.fill ins(%cst : f32) outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
  %1 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
                      outs(%acc : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %1 : tensor<128x128xf32>
}

// CHECK-LABEL: func @block_matmul_with_producer(
// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[FILL_DST_PACKED:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
// CHECK: %[[ACC_PACKED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[FILL_DST_PACKED]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[ACC_PACKED]] : tensor<4x8x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>

// -----

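// A consumer (linalg.add) stays unpacked: the matmul result is unpacked before
// it is consumed.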
func.func @block_matmul_with_consumer(
    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>, %D: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %0 = tensor.empty() : tensor<128x128xf32>
  %1 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
                     outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
  %2 = linalg.add ins(%1, %D : tensor<128x128xf32>, tensor<128x128xf32>)
                  outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %2 : tensor<128x128xf32>
}

// CHECK-LABEL: func @block_matmul_with_consumer(
// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>, %[[D:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-DAG: %[[RES_DST:.+]] = tensor.empty() : tensor<128x128xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  outs({{.*}} : tensor<4x8x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
// CHECK: %[[ADD_RES:.+]] = linalg.add
// CHECK-SAME:  ins(%[[RES_UNPACKED]], %[[D]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[RES_DST]] : tensor<128x128xf32>)
// CHECK: return %[[ADD_RES]] : tensor<128x128xf32>

// -----

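// batch_matmul: the batch dimension stays outermost and only M, N, and K are
// blocked.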
func.func @block_batch_matmul(
    %A: tensor<512x64x128xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
  %0 = linalg.batch_matmul ins(%A, %B : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
                           outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
  return %0 : tensor<512x64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>

// CHECK-LABEL: func @block_batch_matmul(
// CHECK-SAME:   %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 64]
// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 64]
// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>

// -----

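// matmul_transpose_a: A is K x M, so it is packed with transposed outer and
// inner dims ([1, 0]).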
func.func @block_matmul_transpose_a(
    %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.matmul_transpose_a ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_matmul_transpose_a(
// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x64xf32>, %[[B:[0-9a-z]+]]: tensor<128x64xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 64]
// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x64xf32> -> tensor<2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>

// -----

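// batch_matmul_transpose_a: as above, with a leading batch dimension that is
// not blocked.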
func.func @block_batch_matmul_transpose_a(
    %A: tensor<512x128x64xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
  %0 = linalg.batch_matmul_transpose_a ins(%A, %B : tensor<512x128x64xf32>, tensor<512x128x64xf32>)
                                       outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
  return %0 : tensor<512x64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>

// CHECK-LABEL: func @block_batch_matmul_transpose_a(
// CHECK-SAME:   %[[A:.+]]: tensor<512x128x64xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [32, 64]
// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x128x64xf32> -> tensor<512x2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 64]
// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>

// -----

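// matmul_transpose_b: B is N x K, so it packs with identity outer and inner
// dims.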
func.func @block_matmul_transpose_b(
    %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.matmul_transpose_b ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_matmul_transpose_b(
// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<64x128xf32>, %[[B:[0-9a-z]+]]: tensor<64x128xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<64x128xf32> -> tensor<2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 64]
// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>

// -----

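// batch_matmul_transpose_b: as above, with a leading batch dimension that is
// not blocked.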
func.func @block_batch_matmul_transpose_b(
    %A: tensor<512x64x128xf32>, %B: tensor<512x64x128xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
  %0 = linalg.batch_matmul_transpose_b ins(%A, %B : tensor<512x64x128xf32>, tensor<512x64x128xf32>)
                                       outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
  return %0 : tensor<512x64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>

// CHECK-LABEL: func @block_batch_matmul_transpose_b(
// CHECK-SAME:   %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x64x128xf32>, %[[C:.+]]: tensor<512x64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 64]
// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 64]
// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x64x128xf32> -> tensor<512x4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>

// -----

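// A linalg.generic with matmul semantics is recognized and packed the same way
// as the named op.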
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @block_generic_matmul(
    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
    outs(%C : tensor<128x128xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.mulf %in, %in_0 : f32
    %2 = arith.addf %out, %1 : f32
    linalg.yield %2 : f32
  } -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_generic_matmul(
// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<8x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x128xf32> -> tensor<8x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>

// -----

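// Generic with a transposed-A indexing map; packed like matmul_transpose_a.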
#map = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @block_generic_matmul_transpose_a(
    %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
    outs(%C : tensor<64x64xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.mulf %in, %in_0 : f32
    %2 = arith.addf %out, %1 : f32
    linalg.yield %2 : f32
  } -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_generic_matmul_transpose_a(
// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x64xf32>, %[[B:[0-9a-z]+]]: tensor<128x64xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 64]
// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x64xf32> -> tensor<2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>

// -----

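// Generic with a transposed-B indexing map; packed like matmul_transpose_b.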
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @block_generic_matmul_transpose_b(
    %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
    outs(%C : tensor<64x64xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.mulf %in, %in_0 : f32
    %2 = arith.addf %out, %1 : f32
    linalg.yield %2 : f32
  } -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_generic_matmul_transpose_b(
// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<64x128xf32>, %[[B:[0-9a-z]+]]: tensor<64x128xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<64x128xf32> -> tensor<2x2x32x64xf32>
// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 64]
// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x16x64xf32>
// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>

// -----

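// An elementwise (non-contraction) generic must be left untouched: no pack or
// unpack ops are introduced.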
#map = affine_map<(d0, d1) -> (d0, d1)>

func.func @non_contraction_generic(
    %A: tensor<64x128xf32>) -> tensor<64x128xf32> {
  %c0 = arith.constant 0.000000e+00 : f32
  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
    outs(%A : tensor<64x128xf32>) {
  ^bb0(%out: f32):
    %1 = arith.maximumf %out, %c0 : f32
    linalg.yield %1 : f32
  } -> tensor<64x128xf32>
  return %0 : tensor<64x128xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: func @non_contraction_generic(
// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<64x128xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-NOT: tensor.pack
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME:  indexing_maps = [#[[$MAP]]]
// CHECK-SAME:  iterator_types = ["parallel", "parallel"]
// CHECK-SAME:  outs(%[[A]] : tensor<64x128xf32>)
// CHECK-NOT: tensor.unpack
// CHECK: return %[[GENERIC]] : tensor<64x128xf32>