xref: /llvm-project/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir (revision 5a9bdd85ee4d8527e2cedf44f3ce26ff414f9b6a)
1// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
2
/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
/// matvec operations:
///   b += A * x.
/// (b and x are 1-d vectors, A is a 2-d matrix). Currently, three variants
/// are tested:
///   * plain (no mask, fixed-width vectors),
///   * masked (mask + fixed-width vectors),
///   * scalable (mask + scalable vectors).
///
/// TODO: These tests were extracted from 2 different files. If you find the
/// formatting inconsistent, please update accordingly.
14
// Indexing maps and traits shared by the tests below. Dim `m` always indexes
// the result b and is the "parallel" iterator; dim `k` is always the
// "reduction" iterator. Traits 1-4 iterate over (m, k), traits 5-8 over
// (k, m). Within each group the matrix A is laid out either as (m, k) or as
// (k, m), and is either the first (LHS) or the second (RHS) operand.

// 1: iterate (m, k); LHS = A(m, k), RHS = x(k).
#matvec_accesses_1 = [affine_map<(m, k) -> (m, k)>,
                      affine_map<(m, k) -> (k)>,
                      affine_map<(m, k) -> (m)>]
#matvec_trait_1 = {indexing_maps = #matvec_accesses_1,
                   iterator_types = ["parallel", "reduction"]}

// Trait 1 with `maxnumf` as the combining kind.
#matvecmax_trait = {indexing_maps = #matvec_accesses_1,
                    iterator_types = ["parallel", "reduction"],
                    kind = #vector.kind<maxnumf>}

// 2: iterate (m, k); LHS = A(k, m), RHS = x(k).
#matvec_accesses_2 = [affine_map<(m, k) -> (k, m)>,
                      affine_map<(m, k) -> (k)>,
                      affine_map<(m, k) -> (m)>]
#matvec_trait_2 = {indexing_maps = #matvec_accesses_2,
                   iterator_types = ["parallel", "reduction"]}

// 3: iterate (m, k); LHS = x(k), RHS = A(m, k).
#matvec_accesses_3 = [affine_map<(m, k) -> (k)>,
                      affine_map<(m, k) -> (m, k)>,
                      affine_map<(m, k) -> (m)>]
#matvec_trait_3 = {indexing_maps = #matvec_accesses_3,
                   iterator_types = ["parallel", "reduction"]}

// 4: iterate (m, k); LHS = x(k), RHS = A(k, m).
#matvec_accesses_4 = [affine_map<(m, k) -> (k)>,
                      affine_map<(m, k) -> (k, m)>,
                      affine_map<(m, k) -> (m)>]
#matvec_trait_4 = {indexing_maps = #matvec_accesses_4,
                   iterator_types = ["parallel", "reduction"]}

// 5: iterate (k, m); LHS = A(m, k), RHS = x(k).
#matvec_accesses_5 = [affine_map<(k, m) -> (m, k)>,
                      affine_map<(k, m) -> (k)>,
                      affine_map<(k, m) -> (m)>]
#matvec_trait_5 = {indexing_maps = #matvec_accesses_5,
                   iterator_types = ["reduction", "parallel"]}

// 6: iterate (k, m); LHS = A(k, m), RHS = x(k).
#matvec_accesses_6 = [affine_map<(k, m) -> (k, m)>,
                      affine_map<(k, m) -> (k)>,
                      affine_map<(k, m) -> (m)>]
#matvec_trait_6 = {indexing_maps = #matvec_accesses_6,
                   iterator_types = ["reduction", "parallel"]}

// 7: iterate (k, m); LHS = x(k), RHS = A(m, k).
#matvec_accesses_7 = [affine_map<(k, m) -> (k)>,
                      affine_map<(k, m) -> (m, k)>,
                      affine_map<(k, m) -> (m)>]
#matvec_trait_7 = {indexing_maps = #matvec_accesses_7,
                   iterator_types = ["reduction", "parallel"]}

// 8: iterate (k, m); LHS = x(k), RHS = A(k, m).
#matvec_accesses_8 = [affine_map<(k, m) -> (k)>,
                      affine_map<(k, m) -> (k, m)>,
                      affine_map<(k, m) -> (m)>]
#matvec_trait_8 = {indexing_maps = #matvec_accesses_8,
                   iterator_types = ["reduction", "parallel"]}

// ============================================================================
//  Matvec 1 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func @matvec_mk_k_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32

// A is laid out as (m, k), so it is transposed first; each row of the
// transposed matrix is then combined with one element of x.
func.func @matvec_mk_k_m(%A: vector<2x2xf32>, %x: vector<2xf32>,
                         %b: vector<2xf32>) -> vector<2xf32> {
  %result = vector.contract #matvec_trait_1 %A, %x, %b
      : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %result : vector<2xf32>
}

// CHECK-LABEL:   func.func @masked_matvec_mk_k_m(
// CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
// CHECK-SAME:      %{{.*}}: vector<3xf32>,
// CHECK-SAME:      %{{.*}}: vector<2xf32>,
// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// The mask covers the (m, k) iteration space. It is transposed once, and row i
// of the transposed mask guards the i-th outer product.
func.func @masked_matvec_mk_k_m(%A: vector<2x3xf32>, %x: vector<3xf32>,
                                %b: vector<2xf32>, %m: vector<2x3xi1>) -> vector<2xf32> {
  %result = vector.mask %m {
    vector.contract #matvec_trait_1 %A, %x, %b : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
  } : vector<2x3xi1> -> vector<2xf32>
  return %result : vector<2xf32>
}

// CHECK-LABEL:   func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
// CHECK-SAME:      %{{.*}}: vector<[2]x3xf32>,
// CHECK-SAME:      %{{.*}}: vector<3xf32>,
// CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
// CHECK-SAME:      %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// Scalable counterpart of @masked_matvec_mk_k_m: the parallel (m) dim is
// scalable, the reduction (k) dim stays fixed-width.
func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
    %A: vector<[2]x3xf32>, %x: vector<3xf32>, %b: vector<[2]xf32>,
    %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
  %result = vector.mask %m {
    vector.contract #matvec_trait_1 %A, %x, %b : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32>
  } : vector<[2]x3xi1> -> vector<[2]xf32>
  return %result : vector<[2]xf32>
}

// ============================================================================
//  Matvec 1 - max (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func @matvec_mk_k_m_max
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32

// Same layout as @matvec_mk_k_m, but the combining kind (maxnumf) must be
// propagated to every generated outer product.
func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
                             %x: vector<2xf32>,
                             %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}

// CHECK-LABEL:   func.func @masked_matvec_mk_k_m_max(
// CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
// CHECK-SAME:      %{{.*}}: vector<3xf32>,
// CHECK-SAME:      %{{.*}}: vector<2xf32>,
// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

// Masked maxnumf matvec: the (m, k) mask is transposed and each of its rows
// guards one maxnumf outer product.
func.func @masked_matvec_mk_k_m_max(%A: vector<2x3xf32>, %x: vector<3xf32>,
                                    %b: vector<2xf32>, %m: vector<2x3xi1>) -> vector<2xf32> {
  %result = vector.mask %m {
    vector.contract #matvecmax_trait %A, %x, %b : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
  } : vector<2x3xi1> -> vector<2xf32>
  return %result : vector<2xf32>
}

// CHECK-LABEL:   func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(
// CHECK-SAME:      %{{.*}}: vector<[2]x3xf32>,
// CHECK-SAME:      %{{.*}}: vector<3xf32>,
// CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
// CHECK-SAME:      %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxnumf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>

// Scalable counterpart of @masked_matvec_mk_k_m_max: the parallel (m) dim is
// scalable.
func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(
    %A: vector<[2]x3xf32>, %x: vector<3xf32>, %b: vector<[2]xf32>,
    %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
  %result = vector.mask %m {
    vector.contract #matvecmax_trait %A, %x, %b : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32>
  } : vector<[2]x3xi1> -> vector<[2]xf32>
  return %result : vector<[2]xf32>
}

// ============================================================================
//  Matvec 2 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func @matvec_km_k_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32

// A is already laid out as (k, m): its rows are extracted directly and no
// transpose is needed.
func.func @matvec_km_k_m(%A: vector<2x2xf32>, %x: vector<2xf32>,
                         %b: vector<2xf32>) -> vector<2xf32> {
  %result = vector.contract #matvec_trait_2 %A, %x, %b
      : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %result : vector<2xf32>
}

// CHECK-LABEL: @masked_matvec_km_k_m
// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_km_k_m(%A: vector<2x4xf32>,
                                %x: vector<2xf32>,
                                %b: vector<4xf32>,
                                %mask: vector<4x2xi1>) -> vector<4xf32> {
  // A is already laid out as (k, m), so only the (m, k) mask is transposed.
  // A must not be transposed on either side of the mask transpose.
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_2 %A, %x, %b
      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
                                                      %x: vector<2xf32>,
                                                      %b: vector<[4]xf32>,
                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
  // Scalable counterpart of @masked_matvec_km_k_m: only the mask is
  // transposed, and A must not be transposed on either side of it.
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_2 %A, %x, %b
      : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<[4]x2xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Matvec 3 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func @matvec_k_mk_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32

// The vector x is the LHS operand and A, laid out as (m, k), is the RHS; A is
// transposed before the outer products are formed.
func.func @matvec_k_mk_m(%A: vector<2x2xf32>, %x: vector<2xf32>,
                         %b: vector<2xf32>) -> vector<2xf32> {
  %result = vector.contract #matvec_trait_3 %x, %A, %b
      : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  return %result : vector<2xf32>
}

// CHECK-LABEL: @masked_matvec_k_mk_m
// CHECK-SAME:  %[[A:.+]]: vector<4x2xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>, %x: vector<2xf32>,
                                %b: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
  // Both A (m, k) and the (m, k) mask need transposing, A first.
  // CHECK:         vector.transpose %[[A]]
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %result = vector.mask %mask {
    vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %result : vector<4xf32>
}

// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
// CHECK-SAME:  %[[A:.+]]: vector<[4]x2xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
func.func @masked_matvec_k_mk_m_scalable_parallel_dim(
    %A: vector<[4]x2xf32>, %x: vector<2xf32>, %b: vector<[4]xf32>,
    %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
  // Scalable counterpart of @masked_matvec_k_mk_m: A, then the mask, are
  // transposed.
  // CHECK:         vector.transpose %[[A]]
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %result = vector.mask %mask {
    vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<[4]x2xi1> -> vector<[4]xf32>
  return %result : vector<[4]xf32>
}

// ============================================================================
//  Matvec 4 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: func @matvec_k_km_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32

// The vector x is the LHS operand and A, laid out as (k, m), is the RHS; the
// rows of A are used directly without a transpose.
func.func @matvec_k_km_m(%A: vector<2x2xf32>, %x: vector<2xf32>,
                         %b: vector<2xf32>) -> vector<2xf32> {
  %result = vector.contract #matvec_trait_4 %x, %A, %b
      : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  return %result : vector<2xf32>
}

// CHECK-LABEL: @masked_matvec_k_km_m
// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_k_km_m(%A: vector<2x4xf32>,
                                %x: vector<2xf32>,
                                %b: vector<4xf32>,
                                %mask: vector<4x2xi1>) -> vector<4xf32> {
  // A is already laid out as (k, m), so only the (m, k) mask is transposed.
  // A must not be transposed on either side of the mask transpose.
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_4 %x, %A, %b
      : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
  } : vector<4x2xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
func.func @masked_matvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
                                                      %x: vector<2xf32>,
                                                      %b: vector<[4]xf32>,
                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
  // Scalable counterpart of @masked_matvec_k_km_m: only the mask is
  // transposed, and A must not be transposed on either side of it.
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK:         vector.transpose %[[MASK]]
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_4 %x, %A, %b
      : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<[4]x2xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Matvec 5 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL:   func.func @tmatvec_mk_k_m(
// CHECK-SAME:      %[[A:.*]]: vector<2x2xf32>,
// CHECK-SAME:      %[[X:.*]]: vector<2xf32>,
// CHECK-SAME:      %[[B:.*]]: vector<2xf32>) -> vector<2xf32> {
// CHECK:           %[[VAL_3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK:           %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK:           %[[VAL_5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK:           %[[VAL_6:.*]] = vector.outerproduct %[[VAL_4]], %[[VAL_5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:           %[[VAL_7:.*]] = vector.extract %[[VAL_3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK:           %[[VAL_8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK:           %[[VAL_9:.*]] = vector.outerproduct %[[VAL_7]], %[[VAL_8]], %[[VAL_6]] {kind = #vector.kind<add>} : vector<2xf32>, f32

// Reduction-first iteration order (k, m) with A laid out as (m, k): A is
// transposed before the outer products are formed.
func.func @tmatvec_mk_k_m(%A: vector<2x2xf32>, %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
  %result = vector.contract #matvec_trait_5 %A, %x, %b
      : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %result : vector<2xf32>
}

// CHECK-LABEL: @masked_tmatvec_mk_k_m
// CHECK-SAME:  %[[A:.+]]: vector<4x2xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_mk_k_m(%A: vector<4x2xf32>,
                                 %x: vector<2xf32>,
                                 %b: vector<4xf32>,
                                 %mask: vector<2x4xi1>) -> vector<4xf32> {
  // The mask already follows the (k, m) iteration order, so only A is
  // transposed. The mask must not be transposed on either side of A's
  // transpose.
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK:         vector.transpose %[[A]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_5 %A, %x, %b
      : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
// CHECK-SAME:  %[[A:.+]]: vector<[4]x2xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
                                                       %x: vector<2xf32>,
                                                       %b: vector<[4]xf32>,
                                                       %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
  // Scalable counterpart of @masked_tmatvec_mk_k_m: only A is transposed, and
  // the mask must not be transposed on either side of it.
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK:         vector.transpose %[[A]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_5 %A, %x, %b
      : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<2x[4]xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

// ============================================================================
//  Matvec 6 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL:   func.func @tmatvec_km_k_m(
// CHECK-SAME:      %[[A:.*]]: vector<2x2xf32>,
// CHECK-SAME:      %[[X:.*]]: vector<2xf32>,
// CHECK-SAME:      %[[B:.*]]: vector<2xf32>) -> vector<2xf32> {
// CHECK:           %[[VAL_3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK:           %[[VAL_4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK:           %[[VAL_5:.*]] = vector.outerproduct %[[VAL_3]], %[[VAL_4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:           %[[VAL_6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK:           %[[VAL_7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK:           %[[VAL_8:.*]] = vector.outerproduct %[[VAL_6]], %[[VAL_7]], %[[VAL_5]] {kind = #vector.kind<add>} : vector<2xf32>, f32

// Reduction-first iteration order (k, m) with A laid out as (k, m): the rows
// of A are used directly without a transpose.
func.func @tmatvec_km_k_m(%A: vector<2x2xf32>, %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
  %result = vector.contract #matvec_trait_6 %A, %x, %b
      : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  return %result : vector<2xf32>
}

// CHECK-LABEL: @masked_tmatvec_km_k_m
// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_km_k_m(%A: vector<2x4xf32>, %x: vector<2xf32>,
                                 %b: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
  // Both A and the mask already match the (k, m) iteration order, so neither
  // of them is transposed.
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %result = vector.mask %mask {
    vector.contract #matvec_trait_6 %A, %x, %b : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %result : vector<4xf32>
}

// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(
    %A: vector<2x[4]xf32>, %x: vector<2xf32>, %b: vector<[4]xf32>,
    %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
  // Scalable counterpart of @masked_tmatvec_km_k_m: no transposes at all.
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %result = vector.mask %mask {
    vector.contract #matvec_trait_6 %A, %x, %b : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<2x[4]xi1> -> vector<[4]xf32>
  return %result : vector<[4]xf32>
}

// ============================================================================
//  Matvec 7 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL:   func.func @tmatvec_k_mk_m(
// CHECK-SAME:      %[[A:.*]]: vector<2x2xf32>,
// CHECK-SAME:      %[[X:.*]]: vector<2xf32>,
// CHECK-SAME:      %[[B:.*]]: vector<2xf32>) -> vector<2xf32> {
// CHECK:           %[[VAL_3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK:           %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK:           %[[VAL_5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK:           %[[VAL_6:.*]] = vector.outerproduct %[[VAL_4]], %[[VAL_5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK:           %[[VAL_7:.*]] = vector.extract %[[VAL_3]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK:           %[[VAL_8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK:           %[[VAL_9:.*]] = vector.outerproduct %[[VAL_7]], %[[VAL_8]], %[[VAL_6]] {kind = #vector.kind<add>} : vector<2xf32>, f32

// Reduction-first iteration order (k, m); x is the LHS operand and A, laid
// out as (m, k), is the RHS, so A is transposed first.
func.func @tmatvec_k_mk_m(%A: vector<2x2xf32>, %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
  %result = vector.contract #matvec_trait_7 %x, %A, %b
      : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  return %result : vector<2xf32>
}

// CHECK-LABEL: @masked_tmatvec_k_mk_m
// CHECK-SAME:  %[[A:.+]]: vector<4x2xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
func.func @masked_tmatvec_k_mk_m(%A: vector<4x2xf32>,
                                 %x: vector<2xf32>,
                                 %b: vector<4xf32>,
                                 %mask: vector<2x4xi1>) -> vector<4xf32> {
  // The mask already follows the (k, m) iteration order, so only A is
  // transposed. The mask must not be transposed on either side of A's
  // transpose.
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK:         vector.transpose %[[A]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_7 %x, %A, %b
      : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %res : vector<4xf32>
}

// CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim
// CHECK-SAME:  %[[A:.+]]: vector<[4]x2xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
                                                       %x: vector<2xf32>,
                                                       %b: vector<[4]xf32>,
                                                       %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
  // Scalable counterpart of @masked_tmatvec_k_mk_m: only A is transposed, and
  // the mask must not be transposed on either side of it.
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK:         vector.transpose %[[A]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %res = vector.mask %mask {
    vector.contract #matvec_trait_7 %x, %A, %b
      : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<2x[4]xi1> -> vector<[4]xf32>
  return %res : vector<[4]xf32>
}

586// ============================================================================
587//  Matvec 8 (plain + masked + scalable)
588// ============================================================================
// CHECK-LABEL: func @tmatvec_m_mk_k
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: return %[[T8]] : vector<2xf32>
//
// Plain (unmasked, fixed-width) variant using #matvec_trait_8. The rows fed
// to the outer products are extracted straight from %A, and the last outer
// product in the accumulation chain is the function result.
// NOTE(review): the masked variants of this trait are named `..._k_km_m`,
// while this one uses `m_mk_k`; consider aligning the names.
func.func @tmatvec_m_mk_k(%A: vector<2x2xf32>,
                          %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.contract #matvec_trait_8 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  return %0 : vector<2xf32>
}
605
// CHECK-LABEL: @masked_tmatvec_k_km_m
// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
//
// Masked, fixed-width variant using #matvec_trait_8. Expected lowering:
// neither %A nor %mask is transposed, and two individually masked
// vector.outerproduct ops are emitted.
func.func @masked_tmatvec_k_km_m(
    %A: vector<2x4xf32>, %x: vector<2xf32>, %b: vector<4xf32>,
    %mask: vector<2x4xi1>) -> vector<4xf32> {
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
  %masked_res = vector.mask %mask {
    vector.contract #matvec_trait_8 %x, %A, %b
      : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
  } : vector<2x4xi1> -> vector<4xf32>
  return %masked_res : vector<4xf32>
}
624
// CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim
// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
//
// Masked variant using #matvec_trait_8 where only the parallel dimension is
// scalable. Expected lowering: neither %A nor %mask is transposed, and two
// individually masked, scalable vector.outerproduct ops are emitted.
func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(
    %A: vector<2x[4]xf32>, %x: vector<2xf32>, %b: vector<[4]xf32>,
    %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
  // CHECK-NOT:     vector.transpose %[[A]]
  // CHECK-NOT:     vector.transpose %[[MASK]]
  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
  %masked_res = vector.mask %mask {
    vector.contract #matvec_trait_8 %x, %A, %b
      : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
  } : vector<2x[4]xi1> -> vector<[4]xf32>
  return %masked_res : vector<[4]xf32>
}
643
// Negative test (uses #matvec_trait_1): the reduction dimension is scalable,
// and unrolling a scalable reduction dimension is not supported. The pattern
// must bail out and leave the masked vector.contract in place.
// CHECK-LABEL: @masked_extract_contract2_scalable_reduction_dim(
// CHECK:         vector.contract {{.*}} : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
func.func @masked_extract_contract2_scalable_reduction_dim(
    %A: vector<[2]x[3]xf32>, %x: vector<[3]xf32>, %b: vector<[2]xf32>,
    %mask: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
  %res = vector.mask %mask {
    vector.contract #matvec_trait_1 %A, %x, %b
      : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
  } : vector<[2]x[3]xi1> -> vector<[2]xf32>
  return %res : vector<[2]xf32>
}
655
656// ============================================================================
657//  TD sequence
658// ============================================================================
module attributes {transform.with_named_sequence} {
  // Entry point for --transform-interpreter: match every func.func in the
  // payload and apply the vector.contract lowering patterns to it, using the
  // "outerproduct" lowering strategy.
  transform.named_sequence @__transform_main(%payload_root : !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %payload_root
      : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %funcs {
      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
    } : !transform.op<"func.func">
    transform.yield
  }
}
668