xref: /llvm-project/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp (revision b9ff67099ad6da931976e66f1510c5af2558a86e)
1 //===- StructuredOpsUtilsTest.cpp - StructuredOpsUtils unit tests ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "gmock/gmock.h"
13 #include "gtest/gtest.h"
14 
15 using namespace mlir;
16 using testing::Not;
17 using testing::Truly;
18 
19 namespace {
20 
21 TEST(isRowMajorMatmul, Simple) {
22   MLIRContext context;
23 
24   AffineExpr m, n, k;
25   bindDims(&context, m, n, k);
26   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
27   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
28   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
29   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
30 
31   EXPECT_THAT(maps, Truly(isRowMajorMatmul));
32 }
33 
34 TEST(isRowMajorMatmul, BindingShifted) {
35   MLIRContext context;
36 
37   AffineExpr m, n, k;
38   bindDims(&context, k, m, n); // bind in different order
39   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
40   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
41   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
42   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
43 
44   EXPECT_THAT(maps, Truly(isRowMajorMatmul));
45 }
46 
47 TEST(isRowMajorMatmul, BindingSwapped) {
48   MLIRContext context;
49 
50   AffineExpr m, n, k;
51   bindDims(&context, k, n, m); // bind in different order
52   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
53   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
54   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
55   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
56 
57   EXPECT_THAT(maps, Truly(isRowMajorMatmul));
58 }
59 
60 TEST(isRowMajorMatmul, ColumnMajor) {
61   MLIRContext context;
62 
63   AffineExpr m, n, k;
64   bindDims(&context, m, n, k);
65   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
66   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
67   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
68   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
69 
70   EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
71 }
72 
73 TEST(isRowMajorMatmul, FirstInputSwapped) {
74   MLIRContext context;
75 
76   AffineExpr m, n, k;
77   bindDims(&context, m, n, k);
78   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context));
79   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
80   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
81   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
82 
83   EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
84 }
85 
86 TEST(isRowMajorMatmul, TooFewMaps) {
87   MLIRContext context;
88 
89   AffineExpr m, n, k;
90   bindDims(&context, m, n, k);
91   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
92   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
93   auto maps = ArrayAttr::get(&context, {mapA, mapB});
94 
95   EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
96 }
97 
98 TEST(isRowMajorMatmul, TooManyMaps) {
99   MLIRContext context;
100 
101   AffineExpr m, n, k;
102   bindDims(&context, m, n, k);
103   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
104   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
105   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
106   auto mapD = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context));
107 
108   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC, mapD});
109 
110   EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
111 }
112 
113 TEST(isRowMajorMatmul, TooFewDims) {
114   MLIRContext context;
115 
116   AffineExpr m, n, k;
117   bindDims(&context, m, n, k);
118   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
119   auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
120   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
121   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
122 
123   EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
124 }
125 
126 TEST(isRowMajorMatmul, TooFewOutputs) {
127   MLIRContext context;
128 
129   AffineExpr m, n, k;
130   bindDims(&context, m, n, k);
131   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m}, &context));
132   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
133   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
134   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
135 
136   EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
137 }
138 
139 TEST(isColumnMajorMatmul, Simple) {
140   MLIRContext context;
141 
142   AffineExpr m, n, k;
143   bindDims(&context, m, n, k);
144   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
145   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
146   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
147   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
148 
149   EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
150 }
151 
152 TEST(isColumnMajorMatmul, BindingShifted) {
153   MLIRContext context;
154 
155   AffineExpr m, n, k;
156   bindDims(&context, k, m, n); // bind in different order
157   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
158   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
159   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
160   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
161 
162   EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
163 }
164 
165 TEST(isColumnMajorMatmul, BindingSwapped) {
166   MLIRContext context;
167 
168   AffineExpr m, n, k;
169   bindDims(&context, k, n, m); // bind in different order
170   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
171   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
172   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
173   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
174 
175   EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
176 }
177 
178 TEST(isColumnMajorMatmul, RowMajor) {
179   MLIRContext context;
180 
181   AffineExpr m, n, k;
182   bindDims(&context, m, n, k);
183   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
184   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
185   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
186   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
187 
188   EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul)));
189 }
190 
191 TEST(isColumnMajorMatmul, FirstInputSwapped) {
192   MLIRContext context;
193 
194   AffineExpr m, n, k;
195   bindDims(&context, m, n, k);
196   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, &context));
197   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
198   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
199   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
200 
201   EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul)));
202 }
203 
204 TEST(isRowMajorBatchMatmul, Simple) {
205   MLIRContext context;
206 
207   AffineExpr batch, m, n, k;
208   bindDims(&context, batch, m, n, k);
209   auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
210   auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
211   auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
212   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
213 
214   EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
215 }
216 
217 TEST(isRowMajorBatchMatmul, BindingShifted) {
218   MLIRContext context;
219 
220   AffineExpr batch, m, n, k;
221   bindDims(&context, k, batch, m, n); // bind in different order
222   auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
223   auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
224   auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
225   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
226 
227   EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
228 }
229 
230 TEST(isRowMajorBatchMatmul, BindingSwapped) {
231   MLIRContext context;
232 
233   AffineExpr batch, m, n, k;
234   bindDims(&context, batch, k, n, m); // bind in different order
235   auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
236   auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
237   auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
238   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
239 
240   EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
241 }
242 
243 TEST(isRowMajorBatchMatmul, FirstInputSwapped) {
244   MLIRContext context;
245 
246   AffineExpr batch, m, n, k;
247   bindDims(&context, batch, m, n, k);
248   auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, m}, &context));
249   auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
250   auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
251   auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
252 
253   EXPECT_THAT(maps, Not(Truly(isRowMajorBatchMatmul)));
254 }
255 
256 } // namespace
257