xref: /llvm-project/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir (revision 4c3db2588e8b38f75744def6e2dd17c556950e46)
1// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 \
2// RUN: lhs-transpose-outer-blocks=false lhs-transpose-inner-blocks=false \
3// RUN: rhs-transpose-outer-blocks=true rhs-transpose-inner-blocks=true" \
4// RUN: -canonicalize | FileCheck %s --check-prefix=MMT4D
5
6// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 \
7// RUN: lhs-transpose-outer-blocks=false lhs-transpose-inner-blocks=false \
8// RUN: rhs-transpose-outer-blocks=false rhs-transpose-inner-blocks=false" \
9// RUN: -canonicalize | FileCheck %s --check-prefix=MM4D
10
11// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 \
12// RUN: lhs-transpose-outer-blocks=true lhs-transpose-inner-blocks=true \
13// RUN: rhs-transpose-outer-blocks=false rhs-transpose-inner-blocks=false" \
14// RUN: -canonicalize | FileCheck %s --check-prefix=MTM4D
15
// Baseline input: a plain linalg.matmul of a 64x128 LHS by a 128x64 RHS,
// accumulating into a 64x64 tensor. Every RUN line above feeds this function
// through the block-pack pass with a different operand layout configuration.
func.func @block_matmul(
    %lhs: tensor<64x128xf32>,
    %rhs: tensor<128x64xf32>,
    %acc: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %res = linalg.matmul
      ins(%lhs, %rhs : tensor<64x128xf32>, tensor<128x64xf32>)
      outs(%acc : tensor<64x64xf32>) -> tensor<64x64xf32>
  return %res : tensor<64x64xf32>
}
22
// Same contraction expressed with linalg.matmul_transpose_a: the LHS operand
// is supplied as 128x64 while the RHS stays 128x64 and the result is 64x64.
// The check blocks below expect the same packed linalg.generic as for the
// plain matmul case.
func.func @block_matmul_transpose_a(
    %lhs: tensor<128x64xf32>,
    %rhs: tensor<128x64xf32>,
    %acc: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %res = linalg.matmul_transpose_a
      ins(%lhs, %rhs : tensor<128x64xf32>, tensor<128x64xf32>)
      outs(%acc : tensor<64x64xf32>) -> tensor<64x64xf32>
  return %res : tensor<64x64xf32>
}
29
// Same contraction expressed with linalg.matmul_transpose_b: the RHS operand
// is supplied as 64x128 while the LHS stays 64x128 and the result is 64x64.
// The check blocks below expect the same packed linalg.generic as for the
// plain matmul case.
func.func @block_matmul_transpose_b(
    %lhs: tensor<64x128xf32>,
    %rhs: tensor<64x128xf32>,
    %acc: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %res = linalg.matmul_transpose_b
      ins(%lhs, %rhs : tensor<64x128xf32>, tensor<64x128xf32>)
      outs(%acc : tensor<64x64xf32>) -> tensor<64x64xf32>
  return %res : tensor<64x64xf32>
}
36
37// MMT4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
38// MMT4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
39// MMT4D-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
40// MMT4D-LABEL: func @block_matmul
41// MMT4D-COUNT-3: tensor.pack
42// MMT4D: linalg.generic
43// MMT4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
44// MMT4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
45// MMT4D-COUNT-1: tensor.unpack
46// MMT4D-LABEL: func @block_matmul_transpose_a
47// MMT4D-COUNT-3: tensor.pack
48// MMT4D: linalg.generic
49// MMT4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
50// MMT4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
51// MMT4D-COUNT-1: tensor.unpack
52// MMT4D-LABEL: func @block_matmul_transpose_b
53// MMT4D-COUNT-3: tensor.pack
54// MMT4D: linalg.generic
55// MMT4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
56// MMT4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
57// MMT4D-COUNT-1: tensor.unpack
58
59// MM4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
60// MM4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
61// MM4D-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
62// MM4D-LABEL: func @block_matmul
63// MM4D-COUNT-3: tensor.pack
64// MM4D: linalg.generic
65// MM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
66// MM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
67// MM4D-COUNT-1: tensor.unpack
68// MM4D-LABEL: func @block_matmul_transpose_a
69// MM4D-COUNT-3: tensor.pack
70// MM4D: linalg.generic
71// MM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
72// MM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
73// MM4D-COUNT-1: tensor.unpack
74// MM4D-LABEL: func @block_matmul_transpose_b
75// MM4D-COUNT-3: tensor.pack
76// MM4D: linalg.generic
77// MM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
78// MM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
79// MM4D-COUNT-1: tensor.unpack
80
81// MTM4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d5, d3)>
82// MTM4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
83// MTM4D-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
84// MTM4D-LABEL: func @block_matmul
85// MTM4D-COUNT-3: tensor.pack
86// MTM4D: linalg.generic
87// MTM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
88// MTM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
89// MTM4D-COUNT-1: tensor.unpack
90// MTM4D-LABEL: func @block_matmul_transpose_a
91// MTM4D-COUNT-3: tensor.pack
92// MTM4D: linalg.generic
93// MTM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
94// MTM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
95// MTM4D-COUNT-1: tensor.unpack
96// MTM4D-LABEL: func @block_matmul_transpose_b
97// MTM4D-COUNT-3: tensor.pack
98// MTM4D: linalg.generic
99// MTM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
100// MTM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
101// MTM4D-COUNT-1: tensor.unpack
102