xref: /llvm-project/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir (revision e4384149b58f7c3d19c5d38bc46038c660b77ca9)
1// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
2
// Contraction trait for a 1-D dot product: both operands are indexed by the
// single reduction dimension `i`, and the accumulator/result is a scalar.
3#dotp_accesses = [
4  affine_map<(i) -> (i)>,
5  affine_map<(i) -> (i)>,
6  affine_map<(i) -> ()>
7]
8#dotp_trait = {
9  indexing_maps = #dotp_accesses,
10  iterator_types = ["reduction"]
11}
12
13// CHECK-LABEL: func @extract_contract1
14// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
15// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
16// CHECK-SAME: %[[C:.*2]]: f32
17// CHECK:      %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
18// CHECK:      %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xf32> into f32
19// CHECK:      return %[[R]] : f32
20
// f32 dot product of two 4-element vectors with a scalar accumulator. The
// expectations above require an elementwise arith.mulf followed by a single
// vector.reduction <add> that takes the accumulator as its extra operand.
21func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
22  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
23    : vector<4xf32>, vector<4xf32> into f32
24  return %0 : f32
25}
26
// Masked variant of the f32 dot product: the contraction is wrapped in
// vector.mask. The expectations require an unmasked arith.mulf and a
// vector.reduction <add> wrapped in vector.mask with the same mask. The
// reduction operands are matched through the captured values rather than
// literal printed SSA names, so the test verifies the dataflow and does not
// depend on the printer's value numbering.
27// CHECK-LABEL: func @masked_extract_contract1
28//  CHECK-SAME:   %[[A:.*0]]: vector<4xf32>, %[[B:.*1]]: vector<4xf32>, %[[C:.*2]]: f32
29//  CHECK-SAME:   %[[M:.*]]: vector<4xi1>
30//       CHECK:   %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
31//       CHECK:   %[[R:.*]] = vector.mask %[[M]] { vector.reduction <add>, %[[F]], %[[C]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
32//       CHECK:   return %[[R]] : f32
33
34func.func @masked_extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32, %mask: vector<4xi1>) -> f32 {
35  %0 = vector.mask %mask { vector.contract #dotp_trait %arg0, %arg1, %arg2 : vector<4xf32>, vector<4xf32> into f32 } : vector<4xi1> -> f32
36  return %0 : f32
37}
38
39// CHECK-LABEL: func @extract_contract1_int
40// CHECK-SAME: %[[A:.*0]]: vector<4xi32>,
41// CHECK-SAME: %[[B:.*1]]: vector<4xi32>,
42// CHECK-SAME: %[[C:.*2]]: i32
43// CHECK:      %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32>
44// CHECK:      %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xi32> into i32
45// CHECK:      return %[[R]] : i32
46
// i32 counterpart of the 4-element dot product. The expectations above
// require arith.muli followed by vector.reduction <add> with the accumulator.
47func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 {
48  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
49    : vector<4xi32>, vector<4xi32> into i32
50  return %0 : i32
51}
52
// Contraction trait for matrix-times-vector: lhs is indexed (i, j), rhs (j),
// result (i); `i` is parallel and `j` is the reduction dimension.
53#matvec_accesses = [
54  affine_map<(i, j) -> (i, j)>,
55  affine_map<(i, j) -> (j)>,
56  affine_map<(i, j) -> (i)>
57]
58#matvec_trait = {
59  indexing_maps = #matvec_accesses,
60  iterator_types = ["parallel", "reduction"]
61}
62
63// CHECK-LABEL: func @extract_contract2
64// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
65// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
66// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
67// CHECK:      %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
68// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
69// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] : vector<3xf32>
70// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xf32> into f32
71// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
72// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
73// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
74// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xf32> into f32
75// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
76// CHECK:      %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32>
77// CHECK:      return %[[T10]] : vector<2xf32>
78
// f32 matrix-vector contraction (2x3 times 3 into 2). The expectations above
// require, for each of the two rows: extract the row, multiply it with the rhs
// vector, reduce, and insert the scalar into a zero-initialized result; the
// accumulator is added once at the end with arith.addf.
79func.func @extract_contract2(%arg0: vector<2x3xf32>,
80                        %arg1: vector<3xf32>,
81                        %arg2: vector<2xf32>) -> vector<2xf32> {
82  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
83    : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
84  return %0 : vector<2xf32>
85}
86
87// CHECK-LABEL: func @extract_contract2_int
88// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>,
89// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
90// CHECK-SAME: %[[C:.*2]]: vector<2xi32>
91// CHECK:      %[[R:.*]] = arith.constant dense<0> : vector<2xi32>
92// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xi32> from vector<2x3xi32>
93// CHECK:      %[[T2:.*]] = arith.muli %[[T0]], %[[B]] : vector<3xi32>
94// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xi32> into i32
95// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32>
96// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xi32> from vector<2x3xi32>
97// CHECK:      %[[T7:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
98// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xi32> into i32
99// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32>
100// CHECK:      %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32>
101// CHECK:      return %[[T10]] : vector<2xi32>
// i32 counterpart of the 2x3 matrix-vector contraction; the expectations above
// mirror the f32 case with arith.muli / arith.addi and a dense<0> initial value.
102func.func @extract_contract2_int(%arg0: vector<2x3xi32>,
103                        %arg1: vector<3xi32>,
104                        %arg2: vector<2xi32>) -> vector<2xi32> {
105  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
106    : vector<2x3xi32>, vector<3xi32> into vector<2xi32>
107  return %0 : vector<2xi32>
108}
109
// Contraction trait with the operand roles swapped relative to matvec: lhs is
// the vector indexed (j), rhs is the matrix indexed (i, j), result is (i);
// `i` is parallel and `j` is the reduction dimension.
110#vecmat_accesses = [
111  affine_map<(i, j) -> (j)>,
112  affine_map<(i, j) -> (i, j)>,
113  affine_map<(i, j) -> (i)>
114]
115#vecmat_trait = {
116  indexing_maps = #vecmat_accesses,
117  iterator_types = ["parallel", "reduction"]
118}
119
120// CHECK-LABEL: func @extract_contract3
121// CHECK-SAME: %[[A:.*0]]: vector<3xf32>,
122// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
123// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
124// CHECK:      %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
125// CHECK:      %[[T0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<2x3xf32>
126// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[A]] : vector<3xf32>
127// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xf32> into f32
128// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
129// CHECK:      %[[T5:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<2x3xf32>
130// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[A]] : vector<3xf32>
131// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xf32> into f32
132// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
133// CHECK:      %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32>
134// CHECK:      return %[[T10]] : vector<2xf32>
135
// f32 contraction with a vector lhs (3) and a matrix rhs (2x3) into 2. The
// expectations above require rows to be extracted from the rhs matrix and
// multiplied with the lhs vector (matrix row first in the arith.mulf), then
// reduced and inserted, with the accumulator added at the end.
136func.func @extract_contract3(%arg0: vector<3xf32>,
137                        %arg1: vector<2x3xf32>,
138                        %arg2: vector<2xf32>) -> vector<2xf32> {
139  %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2
140    : vector<3xf32>, vector<2x3xf32> into vector<2xf32>
141  return %0 : vector<2xf32>
142}
143
// Contraction trait for matrix-times-matrix: lhs (i, k), rhs (k, j),
// result (i, j); `i` and `j` are parallel and `k` is the reduction dimension.
144#matmat_accesses = [
145  affine_map<(i, j, k) -> (i, k)>,
146  affine_map<(i, j, k) -> (k, j)>,
147  affine_map<(i, j, k) -> (i, j)>
148]
149#matmat_trait = {
150  indexing_maps = #matmat_accesses,
151  iterator_types = ["parallel", "parallel", "reduction"]
152}
153
// f32 2x2 matrix-matrix contraction. The expectations require the rhs to be
// transposed once, then one dot product (extract, mulf, reduction, insert) per
// result element in row-major order, and a final arith.addf with the
// accumulator. The transpose operand is matched through the captured rhs
// argument rather than a literal printed SSA name, consistent with the other
// tests in this file.
154// CHECK-LABEL: func @extract_contract4
155// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
156// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
157// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
158// CHECK:    %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
159// CHECK:    %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
160// CHECK:    %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
161// CHECK:    %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
162// CHECK:    %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
163// CHECK:    %[[T10:.*]] = vector.reduction <add>, %[[T9]] : vector<2xf32> into f32
164// CHECK:    %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
165//
166// CHECK:    %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
167// CHECK:    %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32>
168// CHECK:    %[[T20:.*]] = vector.reduction <add>, %[[T19]] : vector<2xf32> into f32
169// CHECK:    %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
170//
171// CHECK:    %[[T23:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
172// CHECK:    %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
173// CHECK:    %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32>
174// CHECK:    %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32
175// CHECK:    %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
176//
177// CHECK:    %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
178// CHECK:    %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32>
179// CHECK:    %[[T42:.*]] = vector.reduction <add>, %[[T41]] : vector<2xf32> into f32
180// CHECK:    %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
181//
182// CHECK:    %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32>
183// CHECK:    return %[[T52]] : vector<2x2xf32>
184
185func.func @extract_contract4(%arg0: vector<2x2xf32>,
186                        %arg1: vector<2x2xf32>,
187                        %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
188  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
189    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
190  return %0 : vector<2x2xf32>
191}
192
193
// Contraction trait for a full 2-D contraction: both operands are indexed
// (i, j), both dimensions are reductions, and the result is a scalar.
194#contraction2d_accesses = [
195  affine_map<(i, j) -> (i, j)>,
196  affine_map<(i, j) -> (i, j)>,
197  affine_map<(i, j) -> ()>
198]
199#contraction2d_trait = {
200  indexing_maps = #contraction2d_accesses,
201  iterator_types = ["reduction", "reduction"]
202}
203
204// CHECK-LABEL: func @full_contract1
205// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
206// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
207// CHECK-SAME: %[[C:.*2]]: f32
208// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
209// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<2x3xf32>
210// CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32>
211// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[C]] : vector<3xf32> into f32
212// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
213// CHECK:      %[[T6:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<2x3xf32>
214// CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32>
215// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]], %[[T3]] : vector<3xf32> into f32
216// CHECK:      return %[[T8]] : f32
217
// Full contraction of two 2x3 f32 matrices into a scalar. The expectations
// above require one mulf + reduction per row, with the accumulator threaded
// through: the first reduction takes the incoming accumulator and the second
// takes the first reduction's result.
218func.func @full_contract1(%arg0: vector<2x3xf32>,
219                     %arg1: vector<2x3xf32>,
220                     %arg2: f32) -> f32 {
221  %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2
222    : vector<2x3xf32>, vector<2x3xf32> into f32
223  return %0 : f32
224}
225
// Same full 2-D contraction into a scalar, but the rhs is accessed
// transposed: lhs is indexed (i, j) while rhs is indexed (j, i).
226#contraction2d_trans_accesses = [
227  affine_map<(i, j) -> (i, j)>,
228  affine_map<(i, j) -> (j, i)>,
229  affine_map<(i, j) -> ()>
230]
231#contraction2d_trans_trait = {
232  indexing_maps = #contraction2d_trans_accesses,
233  iterator_types = ["reduction", "reduction"]
234}
235
236// CHECK-LABEL: func @full_contract2
237// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
238// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>,
239// CHECK-SAME: %[[C:.*2]]: f32
240// CHECK:      %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3xf32>
241// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
242// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0, 0] : f32 from vector<3x2xf32>
243// CHECK:      %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32>
244// CHECK:      %[[T4:.*]] = vector.extract %[[B]][1, 0] : f32 from vector<3x2xf32>
245// CHECK:      %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32>
246// CHECK:      %[[T7:.*]] = vector.extract %[[B]][2, 0] : f32 from vector<3x2xf32>
247// CHECK:      %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32>
248// CHECK:      %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32>
249// CHECK:      %[[T11:.*]] = vector.reduction <add>, %[[T10]], %[[C]] : vector<3xf32> into f32
250//
251// CHECK:      %[[T12:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
252// CHECK:      %[[T13:.*]] = vector.extract %[[B]][0, 1] : f32 from vector<3x2xf32>
253// CHECK:      %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32>
254// CHECK:      %[[T16:.*]] = vector.extract %[[B]][1, 1] : f32 from vector<3x2xf32>
255// CHECK:      %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32>
256// CHECK:      %[[T19:.*]] = vector.extract %[[B]][2, 1] : f32 from vector<3x2xf32>
257// CHECK:      %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32>
258// CHECK:      %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32>
259// CHECK:      %[[T23:.*]] = vector.reduction <add>, %[[T22]], %[[T11]] : vector<3xf32> into f32
260// CHECK:      return %[[T23]] : f32
261
// Full contraction of a 2x3 lhs with a 3x2 rhs accessed transposed. The
// expectations above require each rhs column to be rebuilt element by element
// (scalar extract + insert into a zero vector) before the mulf + reduction,
// with the accumulator threaded through the two reductions.
262func.func @full_contract2(%arg0: vector<2x3xf32>,
263                     %arg1: vector<3x2xf32>,
264                     %arg2: f32) -> f32 {
265  %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2
266    : vector<2x3xf32>, vector<3x2xf32> into f32
267  return %0 : f32
268}
269
270// CHECK-LABEL: @contract_one_sided_unit_reduction_dim
271// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>)
272// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32>
273// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<2xi32> from vector<1x2xi32>
274// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2xi32> from vector<2x2xi32>
275// CHECK:     %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32>
276// CHECK:     %[[R0:.+]] = vector.reduction <add>, %[[M0]] : vector<2xi32> into i32
277// CHECK:     %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32>
278// CHECK:     %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2xi32> from vector<2x2xi32>
279// CHECK:     %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32>
280// CHECK:     %[[R1:.+]] = vector.reduction <add>, %[[M1]] : vector<2xi32> into i32
281// CHECK:     %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32>
282// CHECK:     %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32>
283// CHECK:     return %[[S]] : vector<2xi32>
284
// Contraction where reduction dimension d0 appears only in the lhs map and has
// size 1 there (lhs is vector<1x2xi32>); d1 is parallel and d2 is the shared
// reduction dimension. The expectations above require the unit dimension to be
// peeled with a single vector.extract of lhs row 0, after which each rhs row
// is multiplied with that row, reduced, and inserted into the result.
285func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> {
286  %res = vector.contract {
287    indexing_maps = [
288      affine_map<(d0, d1, d2) -> (d0, d2)>,
289      affine_map<(d0, d1, d2) -> (d1, d2)>,
290      affine_map<(d0, d1, d2) -> (d1)>
291    ],
292    iterator_types = ["reduction", "parallel", "reduction"],
293    kind = #vector.kind<add>
294  } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32>
295  return %res : vector<2xi32>
296}
297
// Transform script run by --transform-interpreter: it matches every func.func
// in the payload module and applies the vector contraction lowering patterns
// with the "dot" lowering strategy to them.
298module attributes {transform.with_named_sequence} {
299  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
300    %f = transform.structured.match ops{["func.func"]} in %module_op
301      : (!transform.any_op) -> !transform.any_op
302
303    transform.apply_patterns to %f {
304      transform.apply_patterns.vector.lower_contraction lowering_strategy = "dot"
305    } : !transform.any_op
306    transform.yield
307  }
308}
309