1 //===- StructuredOpsUtilsTest.cpp - StructuredOpsUtils unit tests ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "gmock/gmock.h"
13 #include "gtest/gtest.h"
14
15 using namespace mlir;
16 using testing::Not;
17 using testing::Truly;
18
19 namespace {
20
// Canonical row-major matmul C(m, n) += A(m, k) * B(k, n) must be recognized.
TEST(isRowMajorMatmul, Simple) {
  MLIRContext ctx;

  // Canonical binding: d0 = m, d1 = n, d2 = k.
  AffineExpr m, n, k;
  bindDims(&ctx, m, n, k);

  AffineMap lhsMap = AffineMap::get(3, 0, {m, k}, &ctx);
  AffineMap rhsMap = AffineMap::get(3, 0, {k, n}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {m, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isRowMajorMatmul));
}
33
// A row-major matmul whose loop dims are bound in a rotated order
// (d0 = k, d1 = m, d2 = n) is still expected to be recognized.
TEST(isRowMajorMatmul, BindingShifted) {
  MLIRContext ctx;

  AffineExpr m, n, k;
  bindDims(&ctx, k, m, n); // bind in different order

  AffineMap lhsMap = AffineMap::get(3, 0, {m, k}, &ctx);
  AffineMap rhsMap = AffineMap::get(3, 0, {k, n}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {m, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isRowMajorMatmul));
}
46
// A row-major matmul whose loop dims are bound in reversed order
// (d0 = k, d1 = n, d2 = m) is still expected to be recognized.
TEST(isRowMajorMatmul, BindingSwapped) {
  MLIRContext ctx;

  AffineExpr m, n, k;
  bindDims(&ctx, k, n, m); // bind in different order

  AffineMap lhsMap = AffineMap::get(3, 0, {m, k}, &ctx);
  AffineMap rhsMap = AffineMap::get(3, 0, {k, n}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {m, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isRowMajorMatmul));
}
59
// A genuine column-major matmul (KxN * MxK -> NxM) must not be mistaken for a
// row-major one.
TEST(isRowMajorMatmul, ColumnMajor) {
  MLIRContext context;

  AffineExpr m, n, k;
  bindDims(&context, m, n, k);
  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
  // In column-major layout the result is stored transposed, i.e. indexed as
  // (n, m). Indexing it as (m, n) would just describe a row-major product with
  // swapped operands, which is not the case this test is meant to cover.
  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, &context));
  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});

  EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
}
72
// Transposing only the first operand (A indexed as (k, m)) breaks the
// row-major pattern and must be rejected.
TEST(isRowMajorMatmul, FirstInputSwapped) {
  MLIRContext ctx;

  AffineExpr m, n, k;
  bindDims(&ctx, m, n, k);

  AffineMap lhsMap = AffineMap::get(3, 0, {k, m}, &ctx);
  AffineMap rhsMap = AffineMap::get(3, 0, {k, n}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {m, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Not(Truly(isRowMajorMatmul)));
}
85
// Only two indexing maps (no output map) must be rejected.
TEST(isRowMajorMatmul, TooFewMaps) {
  MLIRContext ctx;

  AffineExpr m, n, k;
  bindDims(&ctx, m, n, k);

  AffineMap lhsMap = AffineMap::get(3, 0, {m, k}, &ctx);
  AffineMap rhsMap = AffineMap::get(3, 0, {k, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap)});

  EXPECT_THAT(indexingMaps, Not(Truly(isRowMajorMatmul)));
}
97
// A valid row-major triple followed by an extra fourth map must be rejected.
TEST(isRowMajorMatmul, TooManyMaps) {
  MLIRContext ctx;

  AffineExpr m, n, k;
  bindDims(&ctx, m, n, k);

  AffineMap lhsMap = AffineMap::get(3, 0, {m, k}, &ctx);
  AffineMap rhsMap = AffineMap::get(3, 0, {k, n}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {m, n}, &ctx);
  AffineMap extraMap = AffineMap::get(3, 0, {k, m}, &ctx);

  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap), AffineMapAttr::get(extraMap)});

  EXPECT_THAT(indexingMaps, Not(Truly(isRowMajorMatmul)));
}
112
// The first map yields only one result ("output" of the affine map) instead
// of two, so the triple must be rejected.
TEST(isRowMajorMatmul, TooFewOutputs) {
  MLIRContext ctx;

  AffineExpr m, n, k;
  bindDims(&ctx, m, n, k);

  AffineMap lhsMap = AffineMap::get(3, 0, {m}, &ctx);
  AffineMap rhsMap = AffineMap::get(3, 0, {k, n}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {m, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Not(Truly(isRowMajorMatmul)));
}
125
// Canonical column-major matmul: KxN * MxK -> NxM.
TEST(isColumnMajorMatmul, Simple) {
  MLIRContext context;

  AffineExpr m, n, k;
  bindDims(&context, m, n, k);
  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
  // The column-major result is the transpose of the row-major one, so it is
  // indexed as (n, m). Indexing it as (m, n) would describe a row-major product
  // with swapped operands instead.
  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, &context));
  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});

  EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
}
138
// Column-major matmul (KxN * MxK -> NxM) with dims bound in a rotated order
// (d0 = k, d1 = m, d2 = n) is still expected to be recognized.
TEST(isColumnMajorMatmul, BindingShifted) {
  MLIRContext context;

  AffineExpr m, n, k;
  bindDims(&context, k, m, n); // bind in different order
  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
  // Column-major result is stored transposed: indexed as (n, m).
  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, &context));
  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});

  EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
}
151
// Column-major matmul (KxN * MxK -> NxM) with dims bound in reversed order
// (d0 = k, d1 = n, d2 = m) is still expected to be recognized.
TEST(isColumnMajorMatmul, BindingSwapped) {
  MLIRContext context;

  AffineExpr m, n, k;
  bindDims(&context, k, n, m); // bind in different order
  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
  // Column-major result is stored transposed: indexed as (n, m).
  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, &context));
  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});

  EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
}
164
// A plain row-major matmul C(m, n) += A(m, k) * B(k, n) must not be reported
// as column-major.
TEST(isColumnMajorMatmul, RowMajor) {
  MLIRContext ctx;

  AffineExpr m, n, k;
  bindDims(&ctx, m, n, k);

  AffineMap lhsMap = AffineMap::get(3, 0, {m, k}, &ctx);
  AffineMap rhsMap = AffineMap::get(3, 0, {k, n}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {m, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Not(Truly(isColumnMajorMatmul)));
}
177
// Column-major matmul whose first operand is transposed (indexed (n, k)
// instead of (k, n)) must be rejected. Everything else matches the canonical
// column-major pattern, so the swap is the only reason for rejection.
TEST(isColumnMajorMatmul, FirstInputSwapped) {
  MLIRContext context;

  AffineExpr m, n, k;
  bindDims(&context, m, n, k);
  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, &context));
  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
  // Column-major result is stored transposed: indexed as (n, m).
  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, &context));
  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});

  EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul)));
}
190
// Canonical batched row-major matmul
// C(b, m, n) += A(b, m, k) * B(b, k, n) must be recognized.
TEST(isRowMajorBatchMatmul, Simple) {
  MLIRContext ctx;

  AffineExpr batch, m, n, k;
  bindDims(&ctx, batch, m, n, k);

  AffineMap lhsMap = AffineMap::get(4, 0, {batch, m, k}, &ctx);
  AffineMap rhsMap = AffineMap::get(4, 0, {batch, k, n}, &ctx);
  AffineMap outMap = AffineMap::get(4, 0, {batch, m, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isRowMajorBatchMatmul));
}
203
// Batched row-major matmul with dims bound in a rotated order
// (d0 = k, d1 = batch, d2 = m, d3 = n) is still expected to be recognized.
TEST(isRowMajorBatchMatmul, BindingShifted) {
  MLIRContext ctx;

  AffineExpr batch, m, n, k;
  bindDims(&ctx, k, batch, m, n); // bind in different order

  AffineMap lhsMap = AffineMap::get(4, 0, {batch, m, k}, &ctx);
  AffineMap rhsMap = AffineMap::get(4, 0, {batch, k, n}, &ctx);
  AffineMap outMap = AffineMap::get(4, 0, {batch, m, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isRowMajorBatchMatmul));
}
216
// Batched row-major matmul with m/n/k bound in reversed order
// (d0 = batch, d1 = k, d2 = n, d3 = m) is still expected to be recognized.
TEST(isRowMajorBatchMatmul, BindingSwapped) {
  MLIRContext ctx;

  AffineExpr batch, m, n, k;
  bindDims(&ctx, batch, k, n, m); // bind in different order

  AffineMap lhsMap = AffineMap::get(4, 0, {batch, m, k}, &ctx);
  AffineMap rhsMap = AffineMap::get(4, 0, {batch, k, n}, &ctx);
  AffineMap outMap = AffineMap::get(4, 0, {batch, m, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isRowMajorBatchMatmul));
}
229
// Transposing the per-batch first operand (A indexed (b, k, m)) breaks the
// batched row-major pattern and must be rejected.
TEST(isRowMajorBatchMatmul, FirstInputSwapped) {
  MLIRContext ctx;

  AffineExpr batch, m, n, k;
  bindDims(&ctx, batch, m, n, k);

  AffineMap lhsMap = AffineMap::get(4, 0, {batch, k, m}, &ctx);
  AffineMap rhsMap = AffineMap::get(4, 0, {batch, k, n}, &ctx);
  AffineMap outMap = AffineMap::get(4, 0, {batch, m, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Not(Truly(isRowMajorBatchMatmul)));
}
242
// Canonical vector-matrix product y(n) += x(k) * M(k, n) must be recognized.
TEST(isVecmat, Simple) {
  MLIRContext ctx;

  AffineExpr k, n;
  bindDims(&ctx, k, n);

  AffineMap vecMap = AffineMap::get(2, 0, {k}, &ctx);
  AffineMap matMap = AffineMap::get(2, 0, {k, n}, &ctx);
  AffineMap outMap = AffineMap::get(2, 0, {n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(vecMap), AffineMapAttr::get(matMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isVecmat));
}
255
// Vector-matrix product with dims bound as (d0 = n, d1 = k) is still expected
// to be recognized.
TEST(isVecmat, BindingSwapped) {
  MLIRContext ctx;

  AffineExpr k, n;
  bindDims(&ctx, n, k); // bind in different order

  AffineMap vecMap = AffineMap::get(2, 0, {k}, &ctx);
  AffineMap matMap = AffineMap::get(2, 0, {k, n}, &ctx);
  AffineMap outMap = AffineMap::get(2, 0, {n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(vecMap), AffineMapAttr::get(matMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isVecmat));
}
268
// A transposed matrix operand (indexed (n, k)) is not a vecmat and must be
// rejected.
TEST(isVecmat, WrongDimOrderMatrix) {
  MLIRContext ctx;

  AffineExpr k, n;
  bindDims(&ctx, k, n);

  AffineMap vecMap = AffineMap::get(2, 0, {k}, &ctx);
  AffineMap matMap = AffineMap::get(2, 0, {n, k}, &ctx);
  AffineMap outMap = AffineMap::get(2, 0, {n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(vecMap), AffineMapAttr::get(matMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Not(Truly(isVecmat)));
}
281
// Canonical matrix-vector product y(n) += M(n, k) * x(k) must be recognized.
TEST(isMatvec, Simple) {
  MLIRContext ctx;

  AffineExpr k, n;
  bindDims(&ctx, k, n);

  AffineMap matMap = AffineMap::get(2, 0, {n, k}, &ctx);
  AffineMap vecMap = AffineMap::get(2, 0, {k}, &ctx);
  AffineMap outMap = AffineMap::get(2, 0, {n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(matMap), AffineMapAttr::get(vecMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isMatvec));
}
294
// Matrix-vector product with dims bound as (d0 = n, d1 = k) is still expected
// to be recognized.
TEST(isMatvec, BindingSwapped) {
  MLIRContext ctx;

  AffineExpr k, n;
  bindDims(&ctx, n, k); // bind in different order

  AffineMap matMap = AffineMap::get(2, 0, {n, k}, &ctx);
  AffineMap vecMap = AffineMap::get(2, 0, {k}, &ctx);
  AffineMap outMap = AffineMap::get(2, 0, {n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(matMap), AffineMapAttr::get(vecMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isMatvec));
}
307
// A transposed matrix operand (indexed (k, n)) is not a matvec and must be
// rejected.
TEST(isMatvec, WrongDimOrderMatrix) {
  MLIRContext ctx;

  AffineExpr k, n;
  bindDims(&ctx, k, n);

  AffineMap matMap = AffineMap::get(2, 0, {k, n}, &ctx);
  AffineMap vecMap = AffineMap::get(2, 0, {k}, &ctx);
  AffineMap outMap = AffineMap::get(2, 0, {n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(matMap), AffineMapAttr::get(vecMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Not(Truly(isMatvec)));
}
320
// Canonical batched matvec y(b, n) += M(b, n, k) * x(b, k) must be recognized.
TEST(isBatchMatvec, Simple) {
  MLIRContext ctx;

  AffineExpr batch, k, n;
  bindDims(&ctx, batch, k, n);

  AffineMap matMap = AffineMap::get(3, 0, {batch, n, k}, &ctx);
  AffineMap vecMap = AffineMap::get(3, 0, {batch, k}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {batch, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(matMap), AffineMapAttr::get(vecMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isBatchMatvec));
}
333
// Batched matvec with dims bound as (d0 = batch, d1 = n, d2 = k) is still
// expected to be recognized.
TEST(isBatchMatvec, BindingSwapped) {
  MLIRContext ctx;

  AffineExpr batch, k, n;
  bindDims(&ctx, batch, n, k); // bind in different order

  AffineMap matMap = AffineMap::get(3, 0, {batch, n, k}, &ctx);
  AffineMap vecMap = AffineMap::get(3, 0, {batch, k}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {batch, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(matMap), AffineMapAttr::get(vecMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isBatchMatvec));
}
346
// A plain (non-batched) row-major matmul must not be reported as a batched
// matvec.
TEST(isBatchMatvec, Matmul) {
  MLIRContext ctx;

  AffineExpr m, n, k;
  bindDims(&ctx, m, n, k);

  AffineMap lhsMap = AffineMap::get(3, 0, {m, k}, &ctx);
  AffineMap rhsMap = AffineMap::get(3, 0, {k, n}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {m, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Not(Truly(isBatchMatvec)));
}
359
// A per-batch transposed matrix operand (indexed (b, k, n)) is not a batched
// matvec and must be rejected.
TEST(isBatchMatvec, WrongDimOrderMatrix) {
  MLIRContext ctx;

  AffineExpr batch, k, n;
  bindDims(&ctx, batch, k, n);

  AffineMap matMap = AffineMap::get(3, 0, {batch, k, n}, &ctx);
  AffineMap vecMap = AffineMap::get(3, 0, {batch, k}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {batch, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(matMap), AffineMapAttr::get(vecMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Not(Truly(isBatchMatvec)));
}
372
// Canonical batched vecmat y(b, n) += x(b, k) * M(b, k, n) must be recognized.
TEST(isBatchVecmat, Simple) {
  MLIRContext ctx;

  AffineExpr batch, k, n;
  bindDims(&ctx, batch, k, n);

  AffineMap vecMap = AffineMap::get(3, 0, {batch, k}, &ctx);
  AffineMap matMap = AffineMap::get(3, 0, {batch, k, n}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {batch, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(vecMap), AffineMapAttr::get(matMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isBatchVecmat));
}
385
// Batched vecmat with dims bound as (d0 = batch, d1 = n, d2 = k) is still
// expected to be recognized.
TEST(isBatchVecmat, BindingSwapped) {
  MLIRContext ctx;

  AffineExpr batch, k, n;
  bindDims(&ctx, batch, n, k); // bind in different order

  AffineMap vecMap = AffineMap::get(3, 0, {batch, k}, &ctx);
  AffineMap matMap = AffineMap::get(3, 0, {batch, k, n}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {batch, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(vecMap), AffineMapAttr::get(matMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Truly(isBatchVecmat));
}
398
// A plain (non-batched) row-major matmul must not be reported as a batched
// vecmat.
TEST(isBatchVecmat, Matmul) {
  MLIRContext ctx;

  AffineExpr m, n, k;
  bindDims(&ctx, m, n, k);

  AffineMap lhsMap = AffineMap::get(3, 0, {m, k}, &ctx);
  AffineMap rhsMap = AffineMap::get(3, 0, {k, n}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {m, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(lhsMap), AffineMapAttr::get(rhsMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Not(Truly(isBatchVecmat)));
}
411
// A per-batch transposed matrix operand (indexed (b, n, k)) is not a batched
// vecmat and must be rejected.
TEST(isBatchVecmat, WrongDimOrderMatrix) {
  MLIRContext ctx;

  AffineExpr batch, k, n;
  bindDims(&ctx, batch, k, n);

  AffineMap vecMap = AffineMap::get(3, 0, {batch, k}, &ctx);
  AffineMap matMap = AffineMap::get(3, 0, {batch, n, k}, &ctx);
  AffineMap outMap = AffineMap::get(3, 0, {batch, n}, &ctx);
  ArrayAttr indexingMaps = ArrayAttr::get(
      &ctx, {AffineMapAttr::get(vecMap), AffineMapAttr::get(matMap),
             AffineMapAttr::get(outMap)});

  EXPECT_THAT(indexingMaps, Not(Truly(isBatchVecmat)));
}
424
425 } // namespace
426