xref: /llvm-project/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp (revision 8a80e331506e3e3db390ed0b482c7cbe216f7afc)
1 //===- StructuredOpsUtilsTest.cpp - StructuredOpsUtils unit tests ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "gmock/gmock.h"
13 #include "gtest/gtest.h"
14 
15 using namespace mlir;
16 using testing::Not;
17 using testing::Truly;
18 
19 namespace {
20 
TEST(isRowMajorMatmul,Simple)21 TEST(isRowMajorMatmul, Simple) {
22   MLIRContext context;
23 
24   AffineExpr m, n, k;
25   bindDims(&context, m, n, k);
26   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
27   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
28   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
29   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
30 
31   EXPECT_THAT(maps, Truly(isRowMajorMatmul));
32 }
33 
TEST(isRowMajorMatmul,BindingShifted)34 TEST(isRowMajorMatmul, BindingShifted) {
35   MLIRContext context;
36 
37   AffineExpr m, n, k;
38   bindDims(&context, k, m, n); // bind in different order
39   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
40   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
41   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
42   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
43 
44   EXPECT_THAT(maps, Truly(isRowMajorMatmul));
45 }
46 
TEST(isRowMajorMatmul,BindingSwapped)47 TEST(isRowMajorMatmul, BindingSwapped) {
48   MLIRContext context;
49 
50   AffineExpr m, n, k;
51   bindDims(&context, k, n, m); // bind in different order
52   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
53   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
54   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
55   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
56 
57   EXPECT_THAT(maps, Truly(isRowMajorMatmul));
58 }
59 
TEST(isRowMajorMatmul,ColumnMajor)60 TEST(isRowMajorMatmul, ColumnMajor) {
61   MLIRContext context;
62 
63   AffineExpr m, n, k;
64   bindDims(&context, m, n, k);
65   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
66   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
67   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
68   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
69 
70   EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
71 }
72 
TEST(isRowMajorMatmul,FirstInputSwapped)73 TEST(isRowMajorMatmul, FirstInputSwapped) {
74   MLIRContext context;
75 
76   AffineExpr m, n, k;
77   bindDims(&context, m, n, k);
78   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context));
79   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
80   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
81   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
82 
83   EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
84 }
85 
TEST(isRowMajorMatmul,TooFewMaps)86 TEST(isRowMajorMatmul, TooFewMaps) {
87   MLIRContext context;
88 
89   AffineExpr m, n, k;
90   bindDims(&context, m, n, k);
91   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
92   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
93   auto maps = ArrayAttr::get(&context, {mapA, mapB});
94 
95   EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
96 }
97 
TEST(isRowMajorMatmul,TooManyMaps)98 TEST(isRowMajorMatmul, TooManyMaps) {
99   MLIRContext context;
100 
101   AffineExpr m, n, k;
102   bindDims(&context, m, n, k);
103   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
104   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
105   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
106   auto mapD = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context));
107 
108   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC, mapD});
109 
110   EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
111 }
112 
TEST(isRowMajorMatmul,TooFewOutputs)113 TEST(isRowMajorMatmul, TooFewOutputs) {
114   MLIRContext context;
115 
116   AffineExpr m, n, k;
117   bindDims(&context, m, n, k);
118   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m}, &context));
119   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
120   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
121   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
122 
123   EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
124 }
125 
TEST(isColumnMajorMatmul,Simple)126 TEST(isColumnMajorMatmul, Simple) {
127   MLIRContext context;
128 
129   AffineExpr m, n, k;
130   bindDims(&context, m, n, k);
131   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
132   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
133   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
134   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
135 
136   EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
137 }
138 
TEST(isColumnMajorMatmul,BindingShifted)139 TEST(isColumnMajorMatmul, BindingShifted) {
140   MLIRContext context;
141 
142   AffineExpr m, n, k;
143   bindDims(&context, k, m, n); // bind in different order
144   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
145   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
146   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
147   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
148 
149   EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
150 }
151 
TEST(isColumnMajorMatmul,BindingSwapped)152 TEST(isColumnMajorMatmul, BindingSwapped) {
153   MLIRContext context;
154 
155   AffineExpr m, n, k;
156   bindDims(&context, k, n, m); // bind in different order
157   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
158   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
159   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
160   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
161 
162   EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
163 }
164 
TEST(isColumnMajorMatmul,RowMajor)165 TEST(isColumnMajorMatmul, RowMajor) {
166   MLIRContext context;
167 
168   AffineExpr m, n, k;
169   bindDims(&context, m, n, k);
170   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
171   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
172   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
173   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
174 
175   EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul)));
176 }
177 
TEST(isColumnMajorMatmul,FirstInputSwapped)178 TEST(isColumnMajorMatmul, FirstInputSwapped) {
179   MLIRContext context;
180 
181   AffineExpr m, n, k;
182   bindDims(&context, m, n, k);
183   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, &context));
184   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
185   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
186   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
187 
188   EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul)));
189 }
190 
TEST(isRowMajorBatchMatmul,Simple)191 TEST(isRowMajorBatchMatmul, Simple) {
192   MLIRContext context;
193 
194   AffineExpr batch, m, n, k;
195   bindDims(&context, batch, m, n, k);
196   auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
197   auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
198   auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
199   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
200 
201   EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
202 }
203 
TEST(isRowMajorBatchMatmul,BindingShifted)204 TEST(isRowMajorBatchMatmul, BindingShifted) {
205   MLIRContext context;
206 
207   AffineExpr batch, m, n, k;
208   bindDims(&context, k, batch, m, n); // bind in different order
209   auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
210   auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
211   auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
212   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
213 
214   EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
215 }
216 
TEST(isRowMajorBatchMatmul,BindingSwapped)217 TEST(isRowMajorBatchMatmul, BindingSwapped) {
218   MLIRContext context;
219 
220   AffineExpr batch, m, n, k;
221   bindDims(&context, batch, k, n, m); // bind in different order
222   auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
223   auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
224   auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
225   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
226 
227   EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
228 }
229 
TEST(isRowMajorBatchMatmul,FirstInputSwapped)230 TEST(isRowMajorBatchMatmul, FirstInputSwapped) {
231   MLIRContext context;
232 
233   AffineExpr batch, m, n, k;
234   bindDims(&context, batch, m, n, k);
235   auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, m}, &context));
236   auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
237   auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
238   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
239 
240   EXPECT_THAT(maps, Not(Truly(isRowMajorBatchMatmul)));
241 }
242 
TEST(isVecmat,Simple)243 TEST(isVecmat, Simple) {
244   MLIRContext context;
245 
246   AffineExpr k, n;
247   bindDims(&context, k, n);
248   auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
249   auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
250   auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
251   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
252 
253   EXPECT_THAT(maps, Truly(isVecmat));
254 }
255 
TEST(isVecmat,BindingSwapped)256 TEST(isVecmat, BindingSwapped) {
257   MLIRContext context;
258 
259   AffineExpr k, n;
260   bindDims(&context, n, k); // bind in different order
261   auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
262   auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
263   auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
264   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
265 
266   EXPECT_THAT(maps, Truly(isVecmat));
267 }
268 
TEST(isVecmat,WrongDimOrderMatrix)269 TEST(isVecmat, WrongDimOrderMatrix) {
270   MLIRContext context;
271 
272   AffineExpr k, n;
273   bindDims(&context, k, n);
274   auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
275   auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
276   auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
277   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
278 
279   EXPECT_THAT(maps, Not(Truly(isVecmat)));
280 }
281 
TEST(isMatvec,Simple)282 TEST(isMatvec, Simple) {
283   MLIRContext context;
284 
285   AffineExpr k, n;
286   bindDims(&context, k, n);
287   auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
288   auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
289   auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
290   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
291 
292   EXPECT_THAT(maps, Truly(isMatvec));
293 }
294 
TEST(isMatvec,BindingSwapped)295 TEST(isMatvec, BindingSwapped) {
296   MLIRContext context;
297 
298   AffineExpr k, n;
299   bindDims(&context, n, k); // bind in different order
300   auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
301   auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
302   auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
303   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
304 
305   EXPECT_THAT(maps, Truly(isMatvec));
306 }
307 
TEST(isMatvec,WrongDimOrderMatrix)308 TEST(isMatvec, WrongDimOrderMatrix) {
309   MLIRContext context;
310 
311   AffineExpr k, n;
312   bindDims(&context, k, n);
313   auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
314   auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
315   auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
316   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
317 
318   EXPECT_THAT(maps, Not(Truly(isMatvec)));
319 }
320 
TEST(isBatchMatvec,Simple)321 TEST(isBatchMatvec, Simple) {
322   MLIRContext context;
323 
324   AffineExpr batch, k, n;
325   bindDims(&context, batch, k, n);
326   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
327   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
328   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
329   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
330 
331   EXPECT_THAT(maps, Truly(isBatchMatvec));
332 }
333 
TEST(isBatchMatvec,BindingSwapped)334 TEST(isBatchMatvec, BindingSwapped) {
335   MLIRContext context;
336 
337   AffineExpr batch, k, n;
338   bindDims(&context, batch, n, k); // bind in different order
339   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
340   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
341   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
342   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
343 
344   EXPECT_THAT(maps, Truly(isBatchMatvec));
345 }
346 
TEST(isBatchMatvec,Matmul)347 TEST(isBatchMatvec, Matmul) {
348   MLIRContext context;
349 
350   AffineExpr m, n, k;
351   bindDims(&context, m, n, k);
352   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
353   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
354   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
355   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
356 
357   EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
358 }
359 
TEST(isBatchMatvec,WrongDimOrderMatrix)360 TEST(isBatchMatvec, WrongDimOrderMatrix) {
361   MLIRContext context;
362 
363   AffineExpr batch, k, n;
364   bindDims(&context, batch, k, n);
365   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
366   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
367   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
368   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
369 
370   EXPECT_THAT(maps, Not(Truly(isBatchMatvec)));
371 }
372 
TEST(isBatchVecmat,Simple)373 TEST(isBatchVecmat, Simple) {
374   MLIRContext context;
375 
376   AffineExpr batch, k, n;
377   bindDims(&context, batch, k, n);
378   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
379   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
380   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
381   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
382 
383   EXPECT_THAT(maps, Truly(isBatchVecmat));
384 }
385 
TEST(isBatchVecmat,BindingSwapped)386 TEST(isBatchVecmat, BindingSwapped) {
387   MLIRContext context;
388 
389   AffineExpr batch, k, n;
390   bindDims(&context, batch, n, k); // bind in different order
391   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
392   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context));
393   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
394   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
395 
396   EXPECT_THAT(maps, Truly(isBatchVecmat));
397 }
398 
TEST(isBatchVecmat,Matmul)399 TEST(isBatchVecmat, Matmul) {
400   MLIRContext context;
401 
402   AffineExpr m, n, k;
403   bindDims(&context, m, n, k);
404   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
405   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
406   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
407   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
408 
409   EXPECT_THAT(maps, Not(Truly(isBatchVecmat)));
410 }
411 
TEST(isBatchVecmat,WrongDimOrderMatrix)412 TEST(isBatchVecmat, WrongDimOrderMatrix) {
413   MLIRContext context;
414 
415   AffineExpr batch, k, n;
416   bindDims(&context, batch, k, n);
417   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
418   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
419   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
420   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
421 
422   EXPECT_THAT(maps, Not(Truly(isBatchVecmat)));
423 }
424 
425 } // namespace
426