xref: /llvm-project/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir (revision 17afa5befb4cbe86c22c25ae1603433c8bd21551)
1// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
2
3/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
4/// matmul operations:
5///   C += A * B.
6/// (A, B and C are 2-d matrices). ATM three different variants are tested:
7///   * plain (no mask, fixed-width vectors),
8///   * masked (fixed-width vectors),
9///   * scalable (scalable vectors, with and without a mask).
10/// In order for the "vector.contract -> vector.outerproduct" patterns to work,
11/// only the non-reduction dimensions can be scalable (*). For matmul operations
12/// that is set to be the N dimension (i.e. columns of the output matrix when it
13/// is indexed as (m, n)), which matches how matrix multiplications are normally
14/// implemented for e.g. Arm SVE. However, making the M dimension scalable
15/// (i.e. rows of the output matrix) should work as well.
16///
17/// (*) The conversion tested in this file unrolls along the reduction
18/// dimension, which is not supported for scalable vectors.
19
// Trait 0: A is indexed as (m, k), B as (k, n) and C as (m, n).
// m and n are parallel dimensions; k is the reduction dimension.
20#matmat_accesses_0 = [
21  affine_map<(m, n, k) -> (m, k)>,
22  affine_map<(m, n, k) -> (k, n)>,
23  affine_map<(m, n, k) -> (m, n)>
24]
25#matmat_trait_0 = {
26  indexing_maps = #matmat_accesses_0,
27  iterator_types = ["parallel", "parallel", "reduction"]
28}
29
// Trait 1: A is indexed as (m, k), B as (n, k) and C as (m, n), i.e. B has
// the reduction dimension k as its trailing dimension.
30#matmat_accesses_1 = [
31  affine_map<(m, n, k) -> (m, k)>,
32  affine_map<(m, n, k) -> (n, k)>,
33  affine_map<(m, n, k) -> (m, n)>
34]
35#matmat_trait_1 = {
36  indexing_maps = #matmat_accesses_1,
37  iterator_types = ["parallel", "parallel", "reduction"]
38}
39
// Trait 2: A is indexed as (k, m), B as (k, n) and C as (m, n), i.e. both
// inputs have the reduction dimension k as their leading dimension.
40#matmat_accesses_2 = [
41  affine_map<(m, n, k) -> (k, m)>,
42  affine_map<(m, n, k) -> (k, n)>,
43  affine_map<(m, n, k) -> (m, n)>
44]
45#matmat_trait_2 = {
46  indexing_maps = #matmat_accesses_2,
47  iterator_types = ["parallel", "parallel", "reduction"]
48}
49
// Trait 3: A is indexed as (k, m), B as (n, k) and C as (m, n).
// k (the third iterator) is the reduction dimension.
50#matmat_accesses_3 = [
51  affine_map<(m, n, k) -> (k, m)>,
52  affine_map<(m, n, k) -> (n, k)>,
53  affine_map<(m, n, k) -> (m, n)>
54]
55#matmat_trait_3 = {
56  indexing_maps = #matmat_accesses_3,
57  iterator_types = ["parallel", "parallel", "reduction"]
58}
59
// Trait 4: A is indexed as (m, k), B as (k, n) and C as (n, m), i.e. the
// accumulator/result is indexed with n first and m second.
60#matmat_accesses_4 = [
61  affine_map<(m, n, k) -> (m, k)>,
62  affine_map<(m, n, k) -> (k, n)>,
63  affine_map<(m, n, k) -> (n, m)>
64]
65#matmat_trait_4 = {
66  indexing_maps = #matmat_accesses_4,
67  iterator_types = ["parallel", "parallel", "reduction"]
68}
69
70// ============================================================================
71//  Matmul 0 (plain + masked + mixed types; each also with scalable vectors)
72// ============================================================================
73// CHECK-LABEL: func @matmul
74// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
75// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
76// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
77//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
78// CHECK-SAME:  : vector<2x4xf32> to vector<4x2xf32>
79//
80//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
81//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<4x3xf32>
82//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
83// CHECK-SAME:  : vector<2xf32>, vector<3xf32>
84//
85//      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
86//      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<4x3xf32>
87//      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
88// CHECK-SAME:  : vector<2xf32>, vector<3xf32>
89//
90//      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
91//      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<3xf32> from vector<4x3xf32>
92//      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
93// CHECK-SAME:  : vector<2xf32>, vector<3xf32>
94//
95//      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
96//      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<3xf32> from vector<4x3xf32>
97//      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
98// CHECK-SAME:  : vector<2xf32>, vector<3xf32>
99//
100//      CHECK: return %[[c3]] : vector<2x3xf32>
// Unmasked, fixed-width contraction using #matmat_trait_0: A is 2x4, B is 4x3
// and C is 2x3, so the reduction dimension k has size 4.
101func.func @matmul(%A: vector<2x4xf32>,
102                  %B: vector<4x3xf32>,
103                  %C: vector<2x3xf32>) -> vector<2x3xf32> {
104  %0 = vector.contract #matmat_trait_0 %A, %B, %C
105    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
106  return %0 : vector<2x3xf32>
107}
108
109// CHECK-LABEL: func @matmul_scalable
110// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
111// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>,
112// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
113//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
114// CHECK-SAME:  : vector<2x4xf32> to vector<4x2xf32>
115//
116//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
117//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32>
118//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
119// CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
120//
121//      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
122//      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32>
123//      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
124// CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
125//
126//      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
127//      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32>
128//      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
129// CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
130//
131//      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
132//      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32>
133//      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
134// CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
135//
136//      CHECK: return %[[c3]] : vector<2x[3]xf32>
// Same contraction as @matmul (#matmat_trait_0, unmasked), but the n dimension
// is scalable: B is 4x[3] and C is 2x[3]. The reduction dimension (size 4)
// stays fixed-width.
137func.func @matmul_scalable(%A: vector<2x4xf32>,
138                           %B: vector<4x[3]xf32>,
139                           %C: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
140  %0 = vector.contract #matmat_trait_0 %A, %B, %C
141    : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
142  return %0 : vector<2x[3]xf32>
143}
144
145// CHECK-LABEL: func.func @masked_matmul(
146// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
147// CHECK-SAME:    %{{.*}}: vector<5x7xf32>,
148// CHECK-SAME:    %{{.*}}: vector<3x7xf32>,
149// CHECK-SAME:    %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
150// CHECK:         %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
151// CHECK:         %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
152// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
153// CHECK:         %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
154// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
155// CHECK:         %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
156// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
157// CHECK:         %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
158// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
159// CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
160// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
161
// Masked, fixed-width contraction using #matmat_trait_0 (A: 3x5, B: 5x7,
// C: 3x7). The mask %m has shape 3x7x5, which matches the (m, n, k) iteration
// space of the contraction (m = 3, n = 7, k = 5).
162func.func @masked_matmul(%A: vector<3x5xf32>,
163                         %B: vector<5x7xf32>,
164                         %C: vector<3x7xf32>,
165                         %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
166  %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
167  : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
168  return %0 : vector<3x7xf32>
169}
170
171// CHECK-LABEL: func.func @masked_matmul_scalable(
172// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
173// CHECK-SAME:    %{{.*}}: vector<5x[7]xf32>,
174// CHECK-SAME:    %{{.*}}: vector<3x[7]xf32>,
175// CHECK-SAME:    %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
176// CHECK:         %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
177// CHECK:         %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
178// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
179// CHECK:         %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
180// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
181// CHECK:         %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
182// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
183// CHECK:         %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
184// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
185// CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
186// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
187
// Same masked contraction as @masked_matmul (#matmat_trait_0), but the n
// dimension is scalable: B is 5x[7], C is 3x[7] and the mask is 3x[7]x5.
// The reduction dimension (size 5) stays fixed-width.
188func.func @masked_matmul_scalable(%A: vector<3x5xf32>,
189                                  %B: vector<5x[7]xf32>,
190                                  %C: vector<3x[7]xf32>,
191                                  %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
192  %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
193  : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
194  return %0 : vector<3x[7]xf32>
195}
196
197// CHECK-LABEL: func @matmul_mixed
198// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
199// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
200// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
201//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
202//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
203//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf16> from vector<1x3xf16>
204//      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
205//      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
206//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
207//      CHECK: return %[[c0]] : vector<2x3xf32>
// Mixed-precision contraction using #matmat_trait_0: the inputs A (2x1) and
// B (1x3) are f16 while the accumulator/result C (2x3) is f32. The reduction
// dimension has size 1.
208func.func @matmul_mixed(%A: vector<2x1xf16>,
209                        %B: vector<1x3xf16>,
210                        %C: vector<2x3xf32>) -> vector<2x3xf32>
211{
212  %0 = vector.contract #matmat_trait_0 %A, %B, %C
213    : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
214  return %0 : vector<2x3xf32>
215}
216
217// CHECK-LABEL: func @matmul_mixed_scalable
218// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
219// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
220// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
221//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
222//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
223//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16>
224//      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
225//      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
226//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
227//      CHECK: return %[[c0]] : vector<2x[3]xf32>
// Same mixed-precision contraction as @matmul_mixed (f16 inputs, f32
// accumulator, #matmat_trait_0), but the n dimension is scalable: B is 1x[3]
// and C is 2x[3].
228func.func @matmul_mixed_scalable(%A: vector<2x1xf16>,
229                                 %B: vector<1x[3]xf16>,
230                                 %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
231{
232  %0 = vector.contract #matmat_trait_0 %A, %B, %C
233    : vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
234  return %0 : vector<2x[3]xf32>
235}
236
237// ============================================================================
238//  Matmul 1 (plain + scalable)
239// ============================================================================
240// CHECK-LABEL: func @matmul_1
241// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
242// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
243// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
244//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
245//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
246//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
247//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
248//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
249//      CHECK: return %[[c0]] : vector<2x3xf32>
// Fixed-width contraction using #matmat_trait_1, where B is indexed as (n, k):
// A is 2x1, B is 3x1 and C is 2x3, so the reduction dimension has size 1.
250func.func @matmul_1(%A: vector<2x1xf32>,
251                    %B: vector<3x1xf32>,
252                    %C: vector<2x3xf32>) -> vector<2x3xf32>
253{
254  %0 = vector.contract #matmat_trait_1 %A, %B, %C
255    : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
256  return %0 : vector<2x3xf32>
257}
258
259// CHECK-LABEL: func @matmul_1_scalable
260// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
261// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
262// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
263//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
264//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
265//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
266//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
267//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
268//      CHECK: return %[[c0]] : vector<2x[3]xf32>
// Same contraction as @matmul_1 (#matmat_trait_1, B indexed as (n, k)), but
// the n dimension is scalable: B is [3]x1 and C is 2x[3].
269func.func @matmul_1_scalable(%A: vector<2x1xf32>,
270                             %B: vector<[3]x1xf32>,
271                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
272{
273  %0 = vector.contract #matmat_trait_1 %A, %B, %C
274    : vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
275  return %0 : vector<2x[3]xf32>
276}
277
278// ============================================================================
279//  Matmul 2 (plain + scalable)
280// ============================================================================
281// CHECK-LABEL: func @matmul_2
282// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
283// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
284// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
285//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
286//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
287//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
288//      CHECK: return %[[c0]] : vector<2x3xf32>
// Fixed-width contraction using #matmat_trait_2, where A is indexed as (k, m)
// and B as (k, n): A is 1x2, B is 1x3 and C is 2x3 (reduction size 1).
289func.func @matmul_2(%A: vector<1x2xf32>,
290                    %B: vector<1x3xf32>,
291                    %C: vector<2x3xf32>) -> vector<2x3xf32>
292{
293  %0 = vector.contract #matmat_trait_2 %A, %B, %C
294    : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
295  return %0 : vector<2x3xf32>
296}
297
298// CHECK-LABEL: func @matmul_2_scalable
299// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
300// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
301// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
302//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
303//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
304//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
305//      CHECK: return %[[c0]] : vector<2x[3]xf32>
// Same contraction as @matmul_2 (#matmat_trait_2), but the n dimension is
// scalable: B is 1x[3] and C is 2x[3].
306func.func @matmul_2_scalable(%A: vector<1x2xf32>,
307                             %B: vector<1x[3]xf32>,
308                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
309{
310  %0 = vector.contract #matmat_trait_2 %A, %B, %C
311    : vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
312  return %0 : vector<2x[3]xf32>
313}
314
315// ============================================================================
316//  Matmul 3 (plain + scalable)
317// ============================================================================
318// CHECK-LABEL: func @matmul_3
319// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
320// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
321// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
322//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
323//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
324//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
325//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
326//      CHECK: return %[[c0]] : vector<2x3xf32>
// Fixed-width contraction using #matmat_trait_3, where A is indexed as (k, m)
// and B as (n, k): A is 1x2, B is 3x1 and C is 2x3 (reduction size 1).
327func.func @matmul_3(%A: vector<1x2xf32>,
328                    %B: vector<3x1xf32>,
329                    %C: vector<2x3xf32>) -> vector<2x3xf32>
330{
331  %0 = vector.contract #matmat_trait_3 %A, %B, %C
332    : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
333  return %0 : vector<2x3xf32>
334}
335
336// CHECK-LABEL: func @matmul_3_scalable
337// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
338// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
339// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
340//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
341//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
342//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
343//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
344//      CHECK: return %[[c0]] : vector<2x[3]xf32>
// Same contraction as @matmul_3 (#matmat_trait_3), but the n dimension is
// scalable: B is [3]x1 and C is 2x[3].
345func.func @matmul_3_scalable(%A: vector<1x2xf32>,
346                             %B: vector<[3]x1xf32>,
347                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
348{
349  %0 = vector.contract #matmat_trait_3 %A, %B, %C
350    : vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
351  return %0 : vector<2x[3]xf32>
352}
353
354// ============================================================================
355//  Matmul 4 (plain + scalable)
356// ============================================================================
357// CHECK-LABEL: func @matmul_4
358// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
359// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
360// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
361//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
362//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
363//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
364//      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
365//      CHECK: return %[[c0]] : vector<3x2xf32>
// Fixed-width contraction using #matmat_trait_4, where the accumulator/result
// is indexed as (n, m): A is 2x1, B is 1x3 and C is 3x2 (reduction size 1).
366func.func @matmul_4(%A: vector<2x1xf32>,
367                    %B: vector<1x3xf32>,
368                    %C: vector<3x2xf32>) -> vector<3x2xf32>
369{
370  %0 = vector.contract #matmat_trait_4 %A, %B, %C
371    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
372  return %0 : vector<3x2xf32>
373}
374
375// CHECK-LABEL: func @matmul_4_scalable
376// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>,
377// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
378// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32>
379//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
380//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
381//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
382//      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
383//      CHECK: return %[[c0]] : vector<3x[2]xf32>
// Same contraction as @matmul_4 (#matmat_trait_4, C indexed as (n, m)), but
// here the m dimension is the scalable one: A is [2]x1 and C is 3x[2], while
// B (1x3) is fixed-width.
384func.func @matmul_4_scalable(%A: vector<[2]x1xf32>,
385                             %B: vector<1x3xf32>,
386                             %C: vector<3x[2]xf32>) -> vector<3x[2]xf32>
387{
388  %0 = vector.contract #matmat_trait_4 %A, %B, %C
389    : vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
390  return %0 : vector<3x[2]xf32>
391}
392
393// ============================================================================
394//  TD sequence
395// ============================================================================
// Transform sequence for this test (the mlir-opt invocation at the top of the
// file passes --transform-interpreter). It matches every func.func op in the
// module it is given and applies the vector contraction lowering patterns
// with the "outerproduct" strategy to the matched functions.
396module attributes {transform.with_named_sequence} {
397  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
398    %f = transform.structured.match ops{["func.func"]} in %module_op
399      : (!transform.any_op) -> !transform.any_op
400
401    transform.apply_patterns to %f {
402      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
403    } : !transform.any_op
404    transform.yield
405  }
406}
407