xref: /llvm-project/mlir/test/Dialect/Vector/vector-contract-matmul-transforms.mlir (revision 5041fe8439c161e5d9d8f7774f7ca95af46d880e)
1// RUN: mlir-opt %s -test-vector-contraction-prepare-for-mmt-lowering | FileCheck %s
2
// Negative test: the contraction below has a single "reduction" iterator over
// 1-D vectors and accumulates into a scalar, so it is not a matmul-shaped
// contraction. The checks require the `vector.contract` to be immediately
// followed by `return`, i.e. no extra ops are inserted around it.
3// CHECK-LABEL: func.func @not_matmul
4// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xf32>, [[ARG1:%.+]]: vector<4xf32>, [[ARG2:%.+]]: f32)
5// CHECK-NEXT:    vector.contract
6// CHECK-NEXT:    return
7func.func @not_matmul(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
8  %0 = vector.contract {indexing_maps = [affine_map<(d0) -> (d0)>,
9                                         affine_map<(d0) -> (d0)>,
10                                         affine_map<(d0) -> ()>],
11                        iterator_types = ["reduction"],
12                        kind = #vector.kind<add>} %arg0, %arg1, %arg2 :
13         vector<4xf32>, vector<4xf32> into f32
14  return %0 : f32
15}
16
// Baseline case: LHS is indexed (d0, d2), RHS (d1, d2) and the accumulator
// (d0, d1), with d2 the reduction dimension. The checks expect the contract to
// keep the original operand order with no ops inserted before it.
17// This contraction is already in the canonical form.
18// CHECK-LABEL: func.func @matmul_mk_nk_mn_4x4xi32
19// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
20// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[ARG0]], [[ARG1]], [[ARG2]]
21// CHECK-NEXT:    return [[RES]]
22func.func @matmul_mk_nk_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
23  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
24                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
25                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
26                          iterator_types = ["parallel", "parallel", "reduction"],
27                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
28  return %res : vector<4x4xi32>
29}
30
// The RHS is indexed (d2, d1), i.e. with the reduction dimension first.
// The checks expect a single `vector.transpose` of the RHS with permutation
// [1, 0], whose result feeds the contract in place of the original RHS.
31// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi32
32// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
33// CHECK-NEXT:    [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
34// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[ARG0]], [[TRANS]], [[ARG2]]
35// CHECK-NEXT:    return [[RES]]
36func.func @matmul_mk_kn_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
37  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
38                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
39                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
40                          iterator_types = ["parallel", "parallel", "reduction"],
41                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
42  return %res : vector<4x4xi32>
43}
44
// Both contract operands are produced by `arith.extsi` (i8 -> i32) and the RHS
// is indexed (d2, d1). The checks expect the RHS transpose to be applied to the
// narrow i8 value *before* the sign extension (4x16xi8 -> 16x4xi8 -> 16x4xi32),
// while the LHS extension is left as is.
45// CHECK-LABEL: func.func @matmul_mk_kn_mn_8x16xi8_extsi_i32
46// CHECK-SAME:    ([[ARG0:%.+]]: vector<8x4xi8>, [[ARG1:%.+]]: vector<4x16xi8>, [[ARG2:%.+]]: vector<8x16xi32>)
47// CHECK-NEXT:    [[LHS:%.+]]   = arith.extsi [[ARG0]] : vector<8x4xi8> to vector<8x4xi32>
48// CHECK-NEXT:    [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x16xi8> to vector<16x4xi8>
49// CHECK-NEXT:    [[RHS:%.+]]   = arith.extsi [[TRANS]] : vector<16x4xi8> to vector<16x4xi32>
50// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
51// CHECK-NEXT:    return [[RES]]
52func.func @matmul_mk_kn_mn_8x16xi8_extsi_i32(%arg0: vector<8x4xi8>, %arg1: vector<4x16xi8>, %arg2: vector<8x16xi32>) -> vector<8x16xi32> {
53  %lhs = arith.extsi %arg0: vector<8x4xi8> to vector<8x4xi32>
54  %rhs = arith.extsi %arg1: vector<4x16xi8> to vector<4x16xi32>
55  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
56                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
57                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
58                          iterator_types = ["parallel", "parallel", "reduction"],
59                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<8x4xi32>, vector<4x16xi32> into vector<8x16xi32>
60  return %res : vector<8x16xi32>
61}
62
// Same indexing maps as the square (d0, d2) x (d2, d1) case, but with a 4x16
// LHS and a 16x4 RHS. The checks expect the RHS transpose to change the type
// from vector<16x4xi32> to vector<4x16xi32>.
63// Check that non-square shapes are also handled.
64// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x16xi32
65// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x16xi32>, [[ARG1:%.+]]: vector<16x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
66// CHECK-NEXT:    [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<16x4xi32> to vector<4x16xi32>
67// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[ARG0]], [[TRANS]], [[ARG2]]
68// CHECK-NEXT:    return [[RES]]
69func.func @matmul_mk_kn_mn_4x16xi32(%arg0: vector<4x16xi32>, %arg1: vector<16x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
70  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
71                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
72                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
73                          iterator_types = ["parallel", "parallel", "reduction"],
74                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x16xi32>, vector<16x4xi32> into vector<4x4xi32>
75  return %res : vector<4x4xi32>
76}
77
// Zero-extension variant of the extsi test: both operands come from
// `arith.extui` (i8 -> i32) and the RHS is indexed (d2, d1). The checks expect
// the RHS transpose to happen on the i8 value, followed by the `arith.extui`
// of the transposed vector.
78// CHECK-LABEL: func.func @matmul_mk_kn_mn_8x16xi8_extui_i32
79// CHECK-SAME:    ([[ARG0:%.+]]: vector<8x4xi8>, [[ARG1:%.+]]: vector<4x16xi8>, [[ARG2:%.+]]: vector<8x16xi32>)
80// CHECK-NEXT:    [[LHS:%.+]]   = arith.extui [[ARG0]] : vector<8x4xi8> to vector<8x4xi32>
81// CHECK-NEXT:    [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x16xi8> to vector<16x4xi8>
82// CHECK-NEXT:    [[RHS:%.+]]   = arith.extui [[TRANS]] : vector<16x4xi8> to vector<16x4xi32>
83// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
84// CHECK-NEXT:    return [[RES]]
85func.func @matmul_mk_kn_mn_8x16xi8_extui_i32(%arg0: vector<8x4xi8>, %arg1: vector<4x16xi8>, %arg2: vector<8x16xi32>) -> vector<8x16xi32> {
86  %lhs = arith.extui %arg0: vector<8x4xi8> to vector<8x4xi32>
87  %rhs = arith.extui %arg1: vector<4x16xi8> to vector<4x16xi32>
88  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
89                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
90                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
91                          iterator_types = ["parallel", "parallel", "reduction"],
92                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<8x4xi32>, vector<4x16xi32> into vector<8x16xi32>
93  return %res : vector<8x16xi32>
94}
95
// The LHS is indexed (d2, d0), i.e. with the reduction dimension first, while
// the RHS (d1, d2) and accumulator (d0, d1) match the baseline case. The checks
// expect only the LHS to be transposed with permutation [1, 0].
96// CHECK-LABEL: func.func @matmul_km_nk_mn_4x4xi32
97// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
98// CHECK-NEXT:    [[TRANS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
99// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[TRANS]], [[ARG1]], [[ARG2]]
100// CHECK-NEXT:    return [[RES]]
101func.func @matmul_km_nk_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
102  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
103                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
104                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
105                          iterator_types = ["parallel", "parallel", "reduction"],
106                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
107  return %res : vector<4x4xi32>
108}
109
// Both inputs have the reduction dimension first: LHS is indexed (d2, d0) and
// RHS (d2, d1). The checks expect both operands to be transposed. CHECK-DAG is
// used because the relative order of the two transposes is not significant.
110// CHECK-LABEL: func.func @matmul_km_kn_mn_4x4xi32
111// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
112// CHECK-DAG:     [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
113// CHECK-DAG:     [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
114// CHECK-NEXT:    [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
115// CHECK-NEXT:    return [[RES]]
116func.func @matmul_km_kn_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
117  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
118                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
119                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
120                          iterator_types = ["parallel", "parallel", "reduction"],
121                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
122  return %res : vector<4x4xi32>
123}
124
// Mixed-extension case: the LHS is sign-extended (`arith.extsi`) and the RHS is
// zero-extended (`arith.extui`), and both are indexed with the reduction
// dimension first. The checks expect each operand to be transposed while still
// i8 and then extended with its original extension kind. CHECK-DAG is used so
// the LHS and RHS chains may be emitted in either order.
125// CHECK-LABEL: func.func @matmul_km_kn_mn_8x16xi8_mixed_ext_i32
126// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x8xi8>, [[ARG1:%.+]]: vector<4x16xi8>, [[ARG2:%.+]]: vector<8x16xi32>)
127// CHECK-DAG:     [[LHST:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x8xi8> to vector<8x4xi8>
128// CHECK-DAG:     [[LHS:%.+]]  = arith.extsi [[LHST]] : vector<8x4xi8> to vector<8x4xi32>
129// CHECK-DAG:     [[RHST:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x16xi8> to vector<16x4xi8>
130// CHECK-DAG:     [[RHS:%.+]]  = arith.extui [[RHST]] : vector<16x4xi8> to vector<16x4xi32>
131// CHECK-NEXT:    [[RES:%.+]]  = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
132// CHECK-NEXT:    return [[RES]]
133func.func @matmul_km_kn_mn_8x16xi8_mixed_ext_i32(%arg0: vector<4x8xi8>, %arg1: vector<4x16xi8>, %arg2: vector<8x16xi32>) -> vector<8x16xi32> {
134  %lhs = arith.extsi %arg0 : vector<4x8xi8> to vector<4x8xi32>
135  %rhs = arith.extui %arg1 : vector<4x16xi8> to vector<4x16xi32>
136  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
137                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
138                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
139                          iterator_types = ["parallel", "parallel", "reduction"],
140                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x8xi32>, vector<4x16xi32> into vector<8x16xi32>
141  return %res : vector<8x16xi32>
142}
143
// The accumulator/result is indexed (d1, d0) instead of (d0, d1), while LHS
// (d0, d2) and RHS (d1, d2) already have the reduction dimension last.
// The checks expect no transposes; instead the two inputs are swapped in the
// contract's operand list (ARG1 first, then ARG0).
144// CHECK-LABEL: func.func @matmul_mk_nk_nm_4x4xi32
145// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
146// CHECK-NEXT:    [[RES:%.+]]   = vector.contract {{.+}} [[ARG1]], [[ARG0]], [[ARG2]]
147// CHECK-NEXT:    return [[RES]]
148func.func @matmul_mk_nk_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
149  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
150                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
151                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
152                          iterator_types = ["parallel", "parallel", "reduction"],
153                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
154  return %res : vector<4x4xi32>
155}
156
// Combines both adjustments: LHS (d2, d0) and RHS (d2, d1) have the reduction
// dimension first, and the result is indexed (d1, d0). The checks expect both
// inputs to be transposed (in either order, hence CHECK-DAG) and then passed
// to the contract in swapped order (transposed RHS first, transposed LHS second).
157// CHECK-LABEL: func.func @matmul_km_kn_nm_4x4xi32
158// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
159// CHECK-DAG:     [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
160// CHECK-DAG:     [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
161// CHECK-NEXT:    [[RES:%.+]] = vector.contract {{.+}} [[RHS]], [[LHS]], [[ARG2]]
162// CHECK-NEXT:    return [[RES]]
163func.func @matmul_km_kn_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
164  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
165                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
166                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
167                          iterator_types = ["parallel", "parallel", "reduction"],
168                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
169  return %res : vector<4x4xi32>
170}
171
// The RHS is indexed (d2, d1) and the result (d1, d0). The checks expect the
// RHS to be transposed and the inputs to be swapped in the contract's operand
// list (transposed RHS first, original LHS second). Only one transpose is
// expected, so it is matched with CHECK-NEXT (as in the other single-transpose
// tests) to require that it directly follows the function signature.
172// CHECK-LABEL: func.func @matmul_mk_kn_nm_4x4xi32
173// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
174// CHECK-NEXT:    [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
175// CHECK-NEXT:    [[RES:%.+]] = vector.contract {{.+}} [[RHS]], [[ARG0]], [[ARG2]]
176// CHECK-NEXT:    return [[RES]]
177func.func @matmul_mk_kn_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
178  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
179                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
180                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
181                          iterator_types = ["parallel", "parallel", "reduction"],
182                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
183  return %res : vector<4x4xi32>
184}
185
// The LHS is indexed (d2, d0) and the result (d1, d0). The checks expect the
// LHS to be transposed and the inputs to be swapped in the contract's operand
// list (original RHS first, transposed LHS second). Only one transpose is
// expected, so it is matched with CHECK-NEXT (as in the other single-transpose
// tests) to require that it directly follows the function signature.
186// CHECK-LABEL: func.func @matmul_km_nk_nm_4x4xi32
187// CHECK-SAME:    ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
188// CHECK-NEXT:    [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
189// CHECK-NEXT:    [[RES:%.+]] = vector.contract {{.+}} [[ARG1]], [[LHS]], [[ARG2]]
190// CHECK-NEXT:    return [[RES]]
191func.func @matmul_km_nk_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
192  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
193                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
194                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
195                          iterator_types = ["parallel", "parallel", "reduction"],
196                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
197  return %res : vector<4x4xi32>
198}
199