// xref: /llvm-project/mlir/test/Dialect/Linalg/transpose-matmul.mlir (revision 79225349748bb556fd027cc0bfeb73b1e9a632f4)
// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-a.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-A
// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-b.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-B

// Fully static 2-D case: A is 16x8 and B is 8x16. The transpose-A library
// materializes the transpose of A with a [1, 0] linalg.transpose and switches
// to linalg.matmul_transpose_a; the transpose-B library does the same for B
// and switches to linalg.matmul_transpose_b. In both runs the zero-filled
// accumulator is reused as the output operand.
// CHECK-LABEL:   func.func @matmul_static(
// CHECK-SAME:                             %[[A:.*]]: tensor<16x8xf32>,
// CHECK-SAME:                             %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C_INIT:.*]] = tensor.empty() : tensor<16x16xf32>
// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// CHECK:           return %[[C]] : tensor<16x16xf32>
// CHECK:         }
func.func @matmul_static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> tensor<16x16xf32> {
  %zero = arith.constant 0.000000e+00 : f32
  %empty = tensor.empty() : tensor<16x16xf32>
  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<16x16xf32>) -> tensor<16x16xf32>
  %res = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>)
                       outs(%acc : tensor<16x16xf32>) -> tensor<16x16xf32>
  return %res : tensor<16x16xf32>
}

// -----

// Fully dynamic 2-D case. Besides emitting the transpose, each library must
// size the transposed init tensor with tensor.dim values. The dims already
// computed for the result (A dim 0, B dim 1) are reused, and only the missing
// one (A dim 1 or B dim 0) is newly created.
// CHECK-LABEL:   func.func @matmul_dynamic(
// CHECK-SAME:                              %[[A:.*]]: tensor<?x?xf32>,
// CHECK-SAME:                              %[[B:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[C1:.*]] = arith.constant 1 : index
// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
// CHECK:           %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[B_DIM1]]) : tensor<?x?xf32>
// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// TRANSPOSE-A:     %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// TRANSPOSE-B:     %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM1]], %[[B_DIM0]]) : tensor<?x?xf32>
// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK:           return %[[C]] : tensor<?x?xf32>
// CHECK:         }
func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %zero = arith.constant 0.000000e+00 : f32
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  // Result shape is (rows of A) x (cols of B).
  %rows = tensor.dim %A, %idx0 : tensor<?x?xf32>
  %cols = tensor.dim %B, %idx1 : tensor<?x?xf32>
  %empty = tensor.empty(%rows, %cols) : tensor<?x?xf32>
  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
  %res = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
                       outs(%acc : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %res : tensor<?x?xf32>
}

// -----

// Mixed static/dynamic 2-D case: only the row count of A (and of the result)
// is dynamic. Transposing A therefore needs exactly one dynamic size, the
// already-computed A dim 0, while transposing the fully static B needs none.
// CHECK-LABEL:   func.func @matmul_mixed(
// CHECK-SAME:                            %[[A:.*]]: tensor<?x8xf32>,
// CHECK-SAME:                            %[[B:.*]]: tensor<8x16xf32>) -> tensor<?x16xf32> {
// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x8xf32>
// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<?x16xf32>
// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// CHECK:           return %[[C]] : tensor<?x16xf32>
// CHECK:         }
func.func @matmul_mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
  %cst = arith.constant 0.0 : f32
  // Only dim 0 of A is dynamic, so index 0 is the only constant needed.
  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %A, %c0 : tensor<?x8xf32>
  %init = tensor.empty(%d0) : tensor<?x16xf32>
  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x16xf32>) -> tensor<?x16xf32>
  %0 = linalg.matmul ins(%A, %B : tensor<?x8xf32>, tensor<8x16xf32>) outs(%C : tensor<?x16xf32>) -> tensor<?x16xf32>
  return %0 : tensor<?x16xf32>
}

// -----

// Fully static batched case with a batch size of 2. The inserted transpose
// keeps the batch dimension in place and swaps only the two inner dimensions
// (permutation [0, 2, 1]). The op becomes linalg.batch_matmul_transpose_a or
// linalg.batch_matmul_transpose_b, depending on the library loaded.
// CHECK-LABEL:   func.func @batch_matmul_static(
// CHECK-SAME:                                   %[[A:.*]]: tensor<2x16x8xf32>,
// CHECK-SAME:                                   %[[B:.*]]: tensor<2x8x16xf32>) -> tensor<2x16x16xf32> {
// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C_INIT:.*]] = tensor.empty() : tensor<2x16x16xf32>
// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x8x16xf32>
// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<2x16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<2x8x16xf32>) permutation = [0, 2, 1]
// TRANSPOSE-A:     %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x16xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x16x8xf32>
// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<2x8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<2x16x8xf32>) permutation = [0, 2, 1]
// TRANSPOSE-B:     %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<2x16x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// CHECK:           return %[[C]] : tensor<2x16x16xf32>
// CHECK:         }
func.func @batch_matmul_static(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>) -> tensor<2x16x16xf32> {
  %zero = arith.constant 0.000000e+00 : f32
  %empty = tensor.empty() : tensor<2x16x16xf32>
  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
  %res = linalg.batch_matmul ins(%A, %B : tensor<2x16x8xf32>, tensor<2x8x16xf32>)
                             outs(%acc : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
  return %res : tensor<2x16x16xf32>
}

// -----

// Fully dynamic batched case. The transposed init tensor keeps the batch size
// first and swaps the two inner sizes. The tensor.dim values already computed
// for the result (A dims 0 and 1, B dim 2) are reused, and only the missing
// dims are newly created.
// CHECK-LABEL:   func.func @batch_matmul_dynamic(
// CHECK-SAME:                                    %[[A:.*]]: tensor<?x?x?xf32>,
// CHECK-SAME:                                    %[[B:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[C1:.*]] = arith.constant 1 : index
// CHECK:           %[[C2:.*]] = arith.constant 2 : index
// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?x?xf32>
// CHECK:           %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?x?xf32>
// CHECK:           %[[B_DIM2:.*]] = tensor.dim %[[B]], %[[C2]] : tensor<?x?x?xf32>
// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[A_DIM1]], %[[B_DIM2]]) : tensor<?x?x?xf32>
// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// TRANSPOSE-A:     %[[A_DIM2:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf32>
// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[A_DIM2]], %[[A_DIM1]]) : tensor<?x?x?xf32>
// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?x?xf32>) permutation = [0, 2, 1]
// TRANSPOSE-A:     %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// TRANSPOSE-B:     %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?x?xf32>
// TRANSPOSE-B:     %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?x?xf32>
// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM0]], %[[B_DIM2]], %[[B_DIM1]]) : tensor<?x?x?xf32>
// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?x?xf32>) permutation = [0, 2, 1]
// TRANSPOSE-B:     %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK:           return %[[C]] : tensor<?x?x?xf32>
// CHECK:         }
func.func @batch_matmul_dynamic(%A: tensor<?x?x?xf32>, %B: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %zero = arith.constant 0.000000e+00 : f32
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  %idx2 = arith.constant 2 : index
  // Result shape is (batch of A) x (rows of A) x (cols of B).
  %batch = tensor.dim %A, %idx0 : tensor<?x?x?xf32>
  %rows = tensor.dim %A, %idx1 : tensor<?x?x?xf32>
  %cols = tensor.dim %B, %idx2 : tensor<?x?x?xf32>
  %empty = tensor.empty(%batch, %rows, %cols) : tensor<?x?x?xf32>
  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  %res = linalg.batch_matmul ins(%A, %B : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
                             outs(%acc : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  return %res : tensor<?x?x?xf32>
}

// -----

// Mixed static/dynamic batched case: only dim 1 (the rows) of A and of the
// result is dynamic. Transposing A therefore needs exactly one dynamic size,
// the already-computed A dim 1, while transposing the fully static B needs
// none.
// CHECK-LABEL:   func.func @batch_matmul_mixed(
// CHECK-SAME:                                  %[[A:.*]]: tensor<2x?x8xf32>,
// CHECK-SAME:                                  %[[B:.*]]: tensor<2x8x16xf32>) -> tensor<2x?x16xf32> {
// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[C1:.*]] = arith.constant 1 : index
// CHECK:           %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<2x?x8xf32>
// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM1]]) : tensor<2x?x16xf32>
// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]]) : tensor<2x8x?xf32>
// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<2x?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<2x8x?xf32>) permutation = [0, 2, 1]
// TRANSPOSE-A:     %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x?xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x16x8xf32>
// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<2x8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<2x16x8xf32>) permutation = [0, 2, 1]
// TRANSPOSE-B:     %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<2x?x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
// CHECK:           return %[[C]] : tensor<2x?x16xf32>
// CHECK:         }
func.func @batch_matmul_mixed(%A: tensor<2x?x8xf32>, %B: tensor<2x8x16xf32>) -> (tensor<2x?x16xf32>) {
  %cst = arith.constant 0.0 : f32
  // Only dim 1 of A is dynamic, so index 1 is the only constant needed.
  %c1 = arith.constant 1 : index
  %d1 = tensor.dim %A, %c1 : tensor<2x?x8xf32>
  %init = tensor.empty(%d1) : tensor<2x?x16xf32>
  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
  %0 = linalg.batch_matmul ins(%A, %B : tensor<2x?x8xf32>, tensor<2x8x16xf32>) outs(%C : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
  return %0 : tensor<2x?x16xf32>
}